Bläddra i källkod

修改gitignore

RenLiqiang 5 månader sedan
förälder
incheckning
6ccc414be2
5 ändrade filer med 305 tillägg och 40 borttagningar
  1. 26 0
      .gitignore
  2. 44 0
      models/utils.py
  3. 5 0
      models/wirenet/head.py
  4. 77 0
      models/wirenet/postprocess.py
  5. 153 40
      models/wirenet/wirepoint_rcnn.py

+ 26 - 0
.gitignore

@@ -1,5 +1,31 @@
 .idea
 *.pt
+*.log
+*.onnx
 runs
+logs
+log
+
+/tensorboard/
+logs/
+tensorboard_logs/
+summaries/
+events.out.tfevents.*
+
+# If you have a specific directory for your runs, you can ignore it directly
+/runs/
+
+# Ignore checkpoint files if you don't want to track them
+checkpoint
+*.ckpt.data-*
+*.ckpt.index
+*.ckpt.meta
+
+# Ignore TensorFlow model files that are not necessary for version control
+*.pb
+/*.pbtxt
+# Ignore Jupyter Notebook checkpoints
+/.ipynb_checkpoints/
+
 __pycache__
 train_results

+ 44 - 0
models/utils.py

@@ -0,0 +1,44 @@
+# import torch
+#
+#
+# def evaluate(model, data_loader, device):
+#     n_threads = torch.get_num_threads()
+#     # FIXME remove this and make paste_masks_in_image run on the GPU
+#     torch.set_num_threads(1)
+#     cpu_device = torch.device("cpu")
+#     model.eval()
+#     metric_logger = utils.MetricLogger(delimiter="  ")
+#     header = "Test:"
+#
+#     coco = get_coco_api_from_dataset(data_loader.dataset)
+#     iou_types = _get_iou_types(model)
+#     coco_evaluator = CocoEvaluator(coco, iou_types)
+#
+#     print(f'start to evaluate!!!')
+#     for images, targets in metric_logger.log_every(data_loader, 10, header):
+#         images = list(img.to(device) for img in images)
+#
+#         if torch.cuda.is_available():
+#             torch.cuda.synchronize()
+#         model_time = time.time()
+#         outputs = model(images)
+#
+#         outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
+#         model_time = time.time() - model_time
+#
+#         res = {target["image_id"]: output for target, output in zip(targets, outputs)}
+#         evaluator_time = time.time()
+#         coco_evaluator.update(res)
+#         evaluator_time = time.time() - evaluator_time
+#         metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)
+#
+#     # gather the stats from all processes
+#     metric_logger.synchronize_between_processes()
+#     print("Averaged stats:", metric_logger)
+#     coco_evaluator.synchronize_between_processes()
+#
+#     # accumulate predictions from all images
+#     coco_evaluator.accumulate()
+#     coco_evaluator.summarize()
+#     torch.set_num_threads(n_threads)
+#     return coco_evaluator

+ 5 - 0
models/wirenet/head.py

@@ -1147,8 +1147,13 @@ class RoIHeads(nn.Module):
             else:
                 pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
                 result.append(pred)
+
                 loss_wirepoint = {}
 
+                # loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
+                # rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, outputs, x, y, idx, loss_weight)
+                # loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
+
             # tmp = wirepoint_features[0][0]
             # plt.imshow(tmp.detach().numpy())
             # wirepoint_logits = self.wirepoint_predictor((outputs,wirepoint_features))

+ 77 - 0
models/wirenet/postprocess.py

@@ -0,0 +1,77 @@
+import numpy as np
+
+
+def pline(x1, y1, x2, y2, x, y):
+    px = x2 - x1
+    py = y2 - y1
+    dd = px * px + py * py
+    u = ((x - x1) * px + (y - y1) * py) / max(1e-9, float(dd))
+    dx = x1 + u * px - x
+    dy = y1 + u * py - y
+    return dx * dx + dy * dy
+
+
+def psegment(x1, y1, x2, y2, x, y):
+    px = x2 - x1
+    py = y2 - y1
+    dd = px * px + py * py
+    u = max(min(((x - x1) * px + (y - y1) * py) / float(dd), 1), 0)
+    dx = x1 + u * px - x
+    dy = y1 + u * py - y
+    return dx * dx + dy * dy
+
+
+def plambda(x1, y1, x2, y2, x, y):
+    px = x2 - x1
+    py = y2 - y1
+    dd = px * px + py * py
+    return ((x - x1) * px + (y - y1) * py) / max(1e-9, float(dd))
+
+
+def postprocess(lines, scores, threshold=0.01, tol=1e9, do_clip=False):
+    nlines, nscores = [], []
+    for (p, q), score in zip(lines, scores):
+        start, end = 0, 1
+        for a, b in nlines:
+            if (
+                min(
+                    max(pline(*p, *q, *a), pline(*p, *q, *b)),
+                    max(pline(*a, *b, *p), pline(*a, *b, *q)),
+                )
+                > threshold ** 2
+            ):
+                continue
+            lambda_a = plambda(*p, *q, *a)
+            lambda_b = plambda(*p, *q, *b)
+            if lambda_a > lambda_b:
+                lambda_a, lambda_b = lambda_b, lambda_a
+            lambda_a -= tol
+            lambda_b += tol
+
+            # case 1: skip (if not do_clip)
+            if start < lambda_a and lambda_b < end:
+                continue
+
+            # not intersect
+            if lambda_b < start or lambda_a > end:
+                continue
+
+            # cover
+            if lambda_a <= start and end <= lambda_b:
+                start = 10
+                break
+
+            # case 2 & 3:
+            if lambda_a <= start and start <= lambda_b:
+                start = lambda_b
+            if lambda_a <= end and end <= lambda_b:
+                end = lambda_a
+
+            if start >= end:
+                break
+
+        if start >= end:
+            continue
+        nlines.append(np.array([p + (q - p) * start, p + (q - p) * end]))
+        nscores.append(score)
+    return np.array(nlines), np.array(nscores)

+ 153 - 40
models/wirenet/wirepoint_rcnn.py

@@ -33,6 +33,8 @@ import matplotlib as mpl
 from skimage import io
 import os.path as osp
 from torchvision.utils import draw_bounding_boxes
+from torchvision import transforms
+from models.wirenet.postprocess import postprocess
 
 FEATURE_DIM = 8
 
@@ -583,32 +585,91 @@ def imshow(im):
     plt.colorbar(sm, fraction=0.046)
     plt.xlim([0, im.shape[0]])
     plt.ylim([im.shape[0], 0])
-    plt.show()
+    # plt.show()
+
+
+# def _plot_samples(img, i, result, prefix, epoch):
+#     print(f"prefix:{prefix}")
+#     def draw_vecl(lines, sline, juncs, junts, fn):
+#         directory = os.path.dirname(fn)
+#         if not os.path.exists(directory):
+#             os.makedirs(directory)
+#         imshow(img.permute(1, 2, 0))
+#         if len(lines) > 0 and not (lines[0] == 0).all():
+#             for i, ((a, b), s) in enumerate(zip(lines, sline)):
+#                 if i > 0 and (lines[i] == lines[0]).all():
+#                     break
+#                 plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4)
+#         if not (juncs[0] == 0).all():
+#             for i, j in enumerate(juncs):
+#                 if i > 0 and (i == juncs[0]).all():
+#                     break
+#                 plt.scatter(j[1], j[0], c="red", s=64, zorder=100)
+#         if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
+#             for i, j in enumerate(junts):
+#                 if i > 0 and (i == junts[0]).all():
+#                     break
+#                 plt.scatter(j[1], j[0], c="blue", s=64, zorder=100)
+#         plt.savefig(fn), plt.close()
+#
+#     rjuncs = result["juncs"][i].cpu().numpy() * 4
+#     rjunts = None
+#     if "junts" in result:
+#         rjunts = result["junts"][i].cpu().numpy() * 4
+#
+#     vecl_result = result["lines"][i].cpu().numpy() * 4
+#     score = result["score"][i].cpu().numpy()
+#
+#     draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg")
+#
+#     img1 = cv2.imread(f"{prefix}_vecl_b.jpg")
+#     writer.add_image(f'output_epoch_{epoch}', img1, global_step=epoch)
 
+def _plot_samples(img, i, result, prefix, epoch, writer):
+    # print(f"prefix:{prefix}")
 
-def _plot_samples(img, i, result, prefix, epoch):
-    print(f"prefix:{prefix}")
     def draw_vecl(lines, sline, juncs, junts, fn):
-        if not os.path.exists(fn):
-            os.makedirs(fn)
-        imshow(img.permute(1, 2, 0))
+        # 确保目录存在
+        directory = os.path.dirname(fn)
+        if not os.path.exists(directory):
+            os.makedirs(directory)
+
+        # 绘制图像
+        plt.figure()
+        plt.imshow(img.permute(1, 2, 0).cpu().numpy())
+        plt.axis('off')  # 可选:关闭坐标轴
+
         if len(lines) > 0 and not (lines[0] == 0).all():
-            for i, ((a, b), s) in enumerate(zip(lines, sline)):
-                if i > 0 and (lines[i] == lines[0]).all():
+            for idx, ((a, b), s) in enumerate(zip(lines, sline)):
+                if idx > 0 and (lines[idx] == lines[0]).all():
                     break
-                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4)
+                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=1)
+
         if not (juncs[0] == 0).all():
-            for i, j in enumerate(juncs):
-                if i > 0 and (i == juncs[0]).all():
+            for idx, j in enumerate(juncs):
+                if idx > 0 and (j == juncs[0]).all():
                     break
-                plt.scatter(j[1], j[0], c="red", s=64, zorder=100)
+                plt.scatter(j[1], j[0], c="red", s=20, zorder=100)
+
         if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
-            for i, j in enumerate(junts):
-                if i > 0 and (i == junts[0]).all():
+            for idx, j in enumerate(junts):
+                if idx > 0 and (j == junts[0]).all():
                     break
-                plt.scatter(j[1], j[0], c="blue", s=64, zorder=100)
-        plt.savefig(fn), plt.close()
+                plt.scatter(j[1], j[0], c="blue", s=20, zorder=100)
+
+        # plt.show()
+
+        # 将matplotlib图像转换为numpy数组
+        plt.tight_layout()
+        fig = plt.gcf()
+        fig.canvas.draw()
+        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
+            fig.canvas.get_width_height()[::-1] + (3,))
+        plt.close()
 
+        return image_from_plot
+
+    # 获取结果数据并转换为numpy数组
     rjuncs = result["juncs"][i].cpu().numpy() * 4
     rjunts = None
     if "junts" in result:
@@ -617,10 +678,62 @@ def _plot_samples(img, i, result, prefix, epoch):
     vecl_result = result["lines"][i].cpu().numpy() * 4
     score = result["score"][i].cpu().numpy()
 
-    draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg")
-
-    img1 = cv2.imread(f"{prefix}_vecl_b.jpg")
-    writer.add_image(f'output_epoch_{epoch}', img1, global_step=epoch)
+    # 调用绘图函数并获取图像
+    image_path = f"{prefix}_vecl_b.jpg"
+    image_array = draw_vecl(vecl_result, score, rjuncs, rjunts, image_path)
+
+    # 将numpy数组转换为torch tensor,并写入TensorBoard
+    image_tensor = transforms.ToTensor()(image_array)
+    writer.add_image(f'output_epoch', image_tensor, global_step=epoch)
+    writer.add_image(f'ori_epoch', img, global_step=epoch)
+
+
+def show_line(img, pred, prefix, epoch, write):
+    fn = f"{prefix}_line.jpg"
+    directory = os.path.dirname(fn)
+    if not os.path.exists(directory):
+        os.makedirs(directory)
+    print(fn)
+    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+    H = pred
+
+    im = img.permute(1, 2, 0)
+
+    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
+    scores = H["score"][0].cpu().numpy()
+    for i in range(1, len(lines)):
+        if (lines[i] == lines[0]).all():
+            lines = lines[:i]
+            scores = scores[:i]
+            break
+
+    # postprocess lines to remove overlapped lines
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+
+    for i, t in enumerate([0.5]):
+        plt.gca().set_axis_off()
+        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+        plt.margins(0, 0)
+        for (a, b), s in zip(nlines, nscores):
+            if s < t:
+                continue
+            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+            plt.scatter(a[1], a[0], **PLTOPTS)
+            plt.scatter(b[1], b[0], **PLTOPTS)
+        plt.gca().xaxis.set_major_locator(plt.NullLocator())
+        plt.gca().yaxis.set_major_locator(plt.NullLocator())
+        plt.imshow(im)
+        plt.savefig(fn, bbox_inches="tight")
+        plt.show()
+        plt.close()
+
+
+        img2 = cv2.imread(fn)  # 预测图
+        # img1 = im.resize(img2.shape)  # 原图
+
+        # writer.add_images(f"{epoch}", torch.tensor([img1, img2]), dataformats='NHWC')
+        writer.add_image("output", img2, epoch)
 
 
 if __name__ == '__main__':
@@ -698,27 +811,27 @@ if __name__ == '__main__':
         print(f"epoch:{epoch}")
         model.train()
 
-        for imgs, targets in data_loader_train:
-            losses = model(move_to_device(imgs, device), move_to_device(targets, device))
-            loss = _loss(losses)
-            print(loss)
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-            writer_loss(writer, losses, epoch)
-
-            model.eval()
-            with torch.no_grad():
-                for batch_idx, (imgs, targets) in enumerate(data_loader_val):
-                    pred = model(move_to_device(imgs, device))
-                    # print(f"pred:{pred}")
-
-                    if batch_idx == 0:
-                        result = pred[1]['wires']  # pred[0].keys()   ['boxes', 'labels', 'scores']
-                        print(imgs[0].shape)  # [3,512,512]
-                        # imshow(imgs[0].permute(1, 2, 0))  # 改为(512, 512, 3)
-                        _plot_samples(imgs[0], 0, result, f"{cfg['io']['logdir']}/{epoch}/", epoch)
+        # for imgs, targets in data_loader_train:
+        # losses = model(move_to_device(imgs, device), move_to_device(targets, device))
+        # loss = _loss(losses)
+        # print(loss)
+        # optimizer.zero_grad()
+        # loss.backward()
+        # optimizer.step()
+        # writer_loss(writer, losses, epoch)
 
+        model.eval()
+        with torch.no_grad():
+            for batch_idx, (imgs, targets) in enumerate(data_loader_val):
+                pred = model(move_to_device(imgs, device))
+                # print(f"pred:{pred}")
+
+                if batch_idx == 0:
+                    result = pred[1]['wires']  # pred[0].keys()   ['boxes', 'labels', 'scores']
+                    print(imgs[0].shape)  # [3,512,512]
+                    # imshow(imgs[0].permute(1, 2, 0))  # 改为(512, 512, 3)
+                    _plot_samples(imgs[0], 0, result, f"{cfg['io']['logdir']}/{epoch}/", epoch, writer)
+                    # show_line(imgs[0], result, f"{cfg['io']['logdir']}/{epoch}", epoch, writer)
 
 # imgs, targets = next(iter(data_loader))
 #