浏览代码

WireDataset

xue50 5 月之前
父节点
当前提交
5ba7c49f94
共有 4 个文件被更改,包括 199 次插入66 次删除
  1. 44 0
      models/utils.py
  2. 5 0
      models/wirenet/head.py
  3. 77 0
      models/wirenet/postprocess.py
  4. 73 66
      models/wirenet/wirepoint_rcnn.py

+ 44 - 0
models/utils.py

@@ -0,0 +1,44 @@
+# import torch
+#
+#
+# def evaluate(model, data_loader, device):
+#     n_threads = torch.get_num_threads()
+#     # FIXME remove this and make paste_masks_in_image run on the GPU
+#     torch.set_num_threads(1)
+#     cpu_device = torch.device("cpu")
+#     model.eval()
+#     metric_logger = utils.MetricLogger(delimiter="  ")
+#     header = "Test:"
+#
+#     coco = get_coco_api_from_dataset(data_loader.dataset)
+#     iou_types = _get_iou_types(model)
+#     coco_evaluator = CocoEvaluator(coco, iou_types)
+#
+#     print(f'start to evaluate!!!')
+#     for images, targets in metric_logger.log_every(data_loader, 10, header):
+#         images = list(img.to(device) for img in images)
+#
+#         if torch.cuda.is_available():
+#             torch.cuda.synchronize()
+#         model_time = time.time()
+#         outputs = model(images)
+#
+#         outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
+#         model_time = time.time() - model_time
+#
+#         res = {target["image_id"]: output for target, output in zip(targets, outputs)}
+#         evaluator_time = time.time()
+#         coco_evaluator.update(res)
+#         evaluator_time = time.time() - evaluator_time
+#         metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)
+#
+#     # gather the stats from all processes
+#     metric_logger.synchronize_between_processes()
+#     print("Averaged stats:", metric_logger)
+#     coco_evaluator.synchronize_between_processes()
+#
+#     # accumulate predictions from all images
+#     coco_evaluator.accumulate()
+#     coco_evaluator.summarize()
+#     torch.set_num_threads(n_threads)
+#     return coco_evaluator

+ 5 - 0
models/wirenet/head.py

@@ -1147,8 +1147,13 @@ class RoIHeads(nn.Module):
             else:
                 pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
                 result.append(pred)
+
                 loss_wirepoint = {}
 
+                # loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
+                # rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, outputs, x, y, idx, loss_weight)
+                # loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
+
             # tmp = wirepoint_features[0][0]
             # plt.imshow(tmp.detach().numpy())
             # wirepoint_logits = self.wirepoint_predictor((outputs,wirepoint_features))

+ 77 - 0
models/wirenet/postprocess.py

@@ -0,0 +1,77 @@
+import numpy as np
+
+
+def pline(x1, y1, x2, y2, x, y):
+    px = x2 - x1
+    py = y2 - y1
+    dd = px * px + py * py
+    u = ((x - x1) * px + (y - y1) * py) / max(1e-9, float(dd))
+    dx = x1 + u * px - x
+    dy = y1 + u * py - y
+    return dx * dx + dy * dy
+
+
+def psegment(x1, y1, x2, y2, x, y):
+    px = x2 - x1
+    py = y2 - y1
+    dd = px * px + py * py
+    u = max(min(((x - x1) * px + (y - y1) * py) / float(dd), 1), 0)
+    dx = x1 + u * px - x
+    dy = y1 + u * py - y
+    return dx * dx + dy * dy
+
+
+def plambda(x1, y1, x2, y2, x, y):
+    px = x2 - x1
+    py = y2 - y1
+    dd = px * px + py * py
+    return ((x - x1) * px + (y - y1) * py) / max(1e-9, float(dd))
+
+
+def postprocess(lines, scores, threshold=0.01, tol=1e9, do_clip=False):
+    nlines, nscores = [], []
+    for (p, q), score in zip(lines, scores):
+        start, end = 0, 1
+        for a, b in nlines:
+            if (
+                min(
+                    max(pline(*p, *q, *a), pline(*p, *q, *b)),
+                    max(pline(*a, *b, *p), pline(*a, *b, *q)),
+                )
+                > threshold ** 2
+            ):
+                continue
+            lambda_a = plambda(*p, *q, *a)
+            lambda_b = plambda(*p, *q, *b)
+            if lambda_a > lambda_b:
+                lambda_a, lambda_b = lambda_b, lambda_a
+            lambda_a -= tol
+            lambda_b += tol
+
+            # case 1: skip (if not do_clip)
+            if start < lambda_a and lambda_b < end:
+                continue
+
+            # not intersect
+            if lambda_b < start or lambda_a > end:
+                continue
+
+            # cover
+            if lambda_a <= start and end <= lambda_b:
+                start = 10
+                break
+
+            # case 2 & 3:
+            if lambda_a <= start and start <= lambda_b:
+                start = lambda_b
+            if lambda_a <= end and end <= lambda_b:
+                end = lambda_a
+
+            if start >= end:
+                break
+
+        if start >= end:
+            continue
+        nlines.append(np.array([p + (q - p) * start, p + (q - p) * end]))
+        nscores.append(score)
+    return np.array(nlines), np.array(nscores)

+ 73 - 66
models/wirenet/wirepoint_rcnn.py

@@ -33,6 +33,8 @@ import matplotlib as mpl
 from skimage import io
 import os.path as osp
 from torchvision.utils import draw_bounding_boxes
+from torchvision import transforms
+from models.wirenet.postprocess import postprocess
 
 FEATURE_DIM = 8
 
@@ -575,52 +577,61 @@ sm.set_array([])
 def c(x):
     return sm.to_rgba(x)
 
-
-def imshow(im):
-    plt.close()
-    plt.tight_layout()
-    plt.imshow(im)
-    plt.colorbar(sm, fraction=0.046)
-    plt.xlim([0, im.shape[0]])
-    plt.ylim([im.shape[0], 0])
-    plt.show()
-
-
-def _plot_samples(img, i, result, prefix, epoch):
-    print(f"prefix:{prefix}")
-    def draw_vecl(lines, sline, juncs, junts, fn):
-        if not os.path.exists(fn):
-            os.makedirs(fn)
-        imshow(img.permute(1, 2, 0))
-        if len(lines) > 0 and not (lines[0] == 0).all():
-            for i, ((a, b), s) in enumerate(zip(lines, sline)):
-                if i > 0 and (lines[i] == lines[0]).all():
-                    break
-                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4)
-        if not (juncs[0] == 0).all():
-            for i, j in enumerate(juncs):
-                if i > 0 and (i == juncs[0]).all():
-                    break
-                plt.scatter(j[1], j[0], c="red", s=64, zorder=100)
-        if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
-            for i, j in enumerate(junts):
-                if i > 0 and (i == junts[0]).all():
-                    break
-                plt.scatter(j[1], j[0], c="blue", s=64, zorder=100)
-        plt.savefig(fn), plt.close()
-
-    rjuncs = result["juncs"][i].cpu().numpy() * 4
-    rjunts = None
-    if "junts" in result:
-        rjunts = result["junts"][i].cpu().numpy() * 4
-
-    vecl_result = result["lines"][i].cpu().numpy() * 4
-    score = result["score"][i].cpu().numpy()
-
-    draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg")
-
-    img1 = cv2.imread(f"{prefix}_vecl_b.jpg")
-    writer.add_image(f'output_epoch_{epoch}', img1, global_step=epoch)
+#
+# def imshow(im):
+#     plt.close()
+#     plt.tight_layout()
+#     plt.imshow(im)
+#     plt.colorbar(sm, fraction=0.046)
+#     plt.xlim([0, im.shape[0]])
+#     plt.ylim([im.shape[0], 0])
+#     # plt.show()
+
+
+def show_line(img, pred,  epoch, write):
+    im = img.permute(1, 2, 0)
+    writer.add_image("ori", im, epoch, dataformats="HWC")
+
+    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
+                                      colors="yellow", width=1)
+    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
+
+    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+    H = pred[1]['wires']
+    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
+    scores = H["score"][0].cpu().numpy()
+    for i in range(1, len(lines)):
+        if (lines[i] == lines[0]).all():
+            lines = lines[:i]
+            scores = scores[:i]
+            break
+
+    # postprocess lines to remove overlapped lines
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+
+    for i, t in enumerate([0.5]):
+        plt.gca().set_axis_off()
+        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+        plt.margins(0, 0)
+        for (a, b), s in zip(nlines, nscores):
+            if s < t:
+                continue
+            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+            plt.scatter(a[1], a[0], **PLTOPTS)
+            plt.scatter(b[1], b[0], **PLTOPTS)
+        plt.gca().xaxis.set_major_locator(plt.NullLocator())
+        plt.gca().yaxis.set_major_locator(plt.NullLocator())
+        plt.imshow(im)
+        plt.tight_layout()
+        fig = plt.gcf()
+        fig.canvas.draw()
+        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
+            fig.canvas.get_width_height()[::-1] + (3,))
+        plt.close()
+        img2 = transforms.ToTensor()(image_from_plot)
+
+        writer.add_image("output", img2, epoch)
 
 
 if __name__ == '__main__':
@@ -698,27 +709,23 @@ if __name__ == '__main__':
         print(f"epoch:{epoch}")
         model.train()
 
-        for imgs, targets in data_loader_train:
-            losses = model(move_to_device(imgs, device), move_to_device(targets, device))
-            loss = _loss(losses)
-            print(loss)
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-            writer_loss(writer, losses, epoch)
-
-            model.eval()
-            with torch.no_grad():
-                for batch_idx, (imgs, targets) in enumerate(data_loader_val):
-                    pred = model(move_to_device(imgs, device))
-                    # print(f"pred:{pred}")
-
-                    if batch_idx == 0:
-                        result = pred[1]['wires']  # pred[0].keys()   ['boxes', 'labels', 'scores']
-                        print(imgs[0].shape)  # [3,512,512]
-                        # imshow(imgs[0].permute(1, 2, 0))  # 改为(512, 512, 3)
-                        _plot_samples(imgs[0], 0, result, f"{cfg['io']['logdir']}/{epoch}/", epoch)
+        # for imgs, targets in data_loader_train:
+        # losses = model(move_to_device(imgs, device), move_to_device(targets, device))
+        # loss = _loss(losses)
+        # print(loss)
+        # optimizer.zero_grad()
+        # loss.backward()
+        # optimizer.step()
+        # writer_loss(writer, losses, epoch)
+
+        model.eval()
+        with torch.no_grad():
+            for batch_idx, (imgs, targets) in enumerate(data_loader_val):
+                pred = model(move_to_device(imgs, device))  # # pred[0].keys()   ['boxes', 'labels', 'scores']
+                # print(f"pred:{pred}")
 
+                if batch_idx == 0:
+                    show_line(imgs[0], pred,  epoch, writer)
 
 # imgs, targets = next(iter(data_loader))
 #