浏览代码

WireDataset_

xue50 5 月之前
父节点
当前提交
a666f6881b
共有 3 个文件被更改,包括 121 次插入36 次删除
  1. 2 2
      models/dataset_tool.py
  2. 23 0
      models/wirenet/test.py
  3. 96 34
      models/wirenet/wirepoint_dataset.py

+ 2 - 2
models/dataset_tool.py

@@ -173,7 +173,7 @@ def read_masks_from_txt(label_path, shape):
     return labels, masks
 
 
-def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
+def masks_to_boxes(masks: torch.Tensor, ) -> torch.Tensor:
     """
     Compute the bounding boxes around the provided masks.
 
@@ -278,4 +278,4 @@ def adjacency_matrix(n, link):  # 邻接矩阵
     if len(link) > 0:
         mat[link[:, 0], link[:, 1]] = 1
         mat[link[:, 1], link[:, 0]] = 1
-    return mat
+    return mat

+ 23 - 0
models/wirenet/test.py

@@ -0,0 +1,23 @@
+from models.wirenet.wirepoint_dataset import WirePointDataset
+from models.config.config_tool import read_yaml
+
+# image_file = "D:/python/PycharmProjects/data"
+#
+# label_file = "D:/python/PycharmProjects/data/labels/train"
+# dataset_test = WireDataset(image_file)
+# dataset_test.show(0)
+# for i in dataset_test:
+#     print(i)
+cfg = 'wirenet.yaml'
+cfg = read_yaml(cfg)
+print(f'cfg:{cfg}')
+print(cfg['model']['n_dyn_negl'])
+# net = WirepointPredictor()
+
+dataset = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
+# dataset.show(0)
+
+for i in range(len(dataset)):
+    dataset.show(i)
+
+

+ 96 - 34
models/wirenet/wirepoint_dataset.py

@@ -10,6 +10,10 @@ import random
 import cv2
 import PIL
 
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from torchvision.utils import draw_bounding_boxes
+
 import numpy as np
 import numpy.linalg as LA
 import torch
@@ -51,8 +55,10 @@ class WirePointDataset(BaseDataset):
 
         # print(f'img:{img}')
         return img, target
+
     def __len__(self):
         return len(self.imgs)
+
     def read_target(self, item, lbl_path, shape, extra=None):
         # print(f'lbl_path:{lbl_path}')
         with open(lbl_path, 'r') as file:
@@ -103,49 +109,105 @@ class WirePointDataset(BaseDataset):
         elif self.target_type == 'pixel':
             labels, masks = read_masks_from_pixels_wire(lbl_path, shape)
 
+        # print(torch.stack(masks).shape)    # [线段数, 512, 512]
         target = {}
-        target["boxes"] = masks_to_boxes(torch.stack(masks))
+        # target["boxes"] = masks_to_boxes(torch.stack(masks))
+        # print(target["boxes"])
         target["labels"] = torch.stack(labels)
         target["masks"] = torch.stack(masks)
         target["image_id"] = torch.tensor(item)
         # return wire_labels, target
         target["wires"] = wire_labels
+
+        boxs = []
+        junc = target['wires']['junc_coords'].cpu().numpy() * 4
+        lpre = target['wires']["lpre"].cpu().numpy() * 4
+        vecl_target = target['wires']["lpre_label"].cpu().numpy()
+        lpre = lpre[vecl_target == 1]
+
+        lines = lpre
+        sline = np.ones(lpre.shape[0])
+
+        if len(lines) > 0 and not (lines[0] == 0).all():
+            for i, ((a, b), s) in enumerate(zip(lines, sline)):
+                if i > 0 and (lines[i] == lines[0]).all():
+                    break
+                # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1]无明确大小
+                if a[1] > b[1]:
+                    ymax = a[1] + 1
+                    ymin = b[1] - 1
+                else:
+                    ymin = a[1] - 1
+                    ymax = b[1] + 1
+                if a[0] > b[0]:
+                    xmax = a[0] + 1
+                    xmin = b[0] - 1
+                else:
+                    xmin = a[0] - 1
+                    xmax = b[0] + 1
+                boxs.append([ymin, xmin, ymax, xmax])
+                # plt.Rectangle([a[1] - 1, b[1] + 1], [a[0] + 1, b[0] - 1], c="g", linewidth=1)
+        target["line_boxes"] = torch.tensor(boxs)
+
         return target
 
     def show(self, idx):
-        img_path = os.path.join(self.img_path, self.imgs[idx])
-        lbl_path = os.path.join(self.lbl_path, self.imgs[idx][:-3] + 'json')
+        image, target = self.__getitem__(idx)
+
+        cmap = plt.get_cmap("jet")
+        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+        sm.set_array([])
+
+        def imshow(im):
+            plt.close()
+            plt.tight_layout()
+            plt.imshow(im)
+            plt.colorbar(sm, fraction=0.046)
+            plt.xlim([0, im.shape[0]])
+            plt.ylim([im.shape[0], 0])
+
+        def draw_vecl(lines, sline, juncs, junts, fn=None):
+            img_path = os.path.join(self.img_path, self.imgs[idx])
+            imshow(io.imread(img_path))
+            if len(lines) > 0 and not (lines[0] == 0).all():
+                for i, ((a, b), s) in enumerate(zip(lines, sline)):
+                    if i > 0 and (lines[i] == lines[0]).all():
+                        break
+                    plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1]无明确大小
+            if not (juncs[0] == 0).all():
+                for i, j in enumerate(juncs):
+                    if i > 0 and (i == juncs[0]).all():
+                        break
+                    plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # 原 s=64
+
+
+            img_path = os.path.join(self.img_path, self.imgs[idx])
+            img = PIL.Image.open(img_path).convert('RGB')
+            boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["line_boxes"],
+                                              colors="yellow", width=1)
+            plt.imshow(boxed_image.permute(1, 2, 0).numpy())
+            plt.show()
+
+            plt.show()
+            if fn != None:
+                plt.savefig(fn)
+
+        junc = target['wires']['junc_coords'].cpu().numpy() * 4
+        jtyp = target['wires']['jtyp'].cpu().numpy()
+        juncs = junc[jtyp == 0]
+        junts = junc[jtyp == 1]
+
+        lpre = target['wires']["lpre"].cpu().numpy() * 4
+        vecl_target = target['wires']["lpre_label"].cpu().numpy()
+        lpre = lpre[vecl_target == 1]
+
+        # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
+        draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
+
+
+    def show_img(self, img_path):
+        pass
 
-        with open(lbl_path, 'r') as file:
-            lable_all = json.load(file)
 
-        # 可视化图像和标注
-        image = cv2.imread(img_path)  # [H,W,3]  # 默认为BGR格式
-        # print(image.shape)
-        # 绘制每个标注的多边形
-        # for ann in lable_all["segmentations"]:
-        #     segmentation = [[x * 512 for x in ann['data']]]
-        #     # segmentation = [ann['data']]
-        #     # for i in range(len(ann['data'])):
-        #     #     if i % 2 == 0:
-        #     #         segmentation[0][i] *= image.shape[0]
-        #     #     else:
-        #     #         segmentation[0][i] *= image.shape[0]
-        #
-        #     # if isinstance(segmentation, list):
-        #     #     for seg in segmentation:
-        #     #         poly = np.array(seg).reshape((-1, 2)).astype(int)
-        #     #         cv2.polylines(image, [poly], isClosed=True, color=(0, 255, 0), thickness=2)
-        #     #         cv2.fillPoly(image, [poly], color=(0, 255, 0))
-
-
-        #
-        # # 显示图像
-        # cv2.namedWindow('Image with Segmentations', cv2.WINDOW_NORMAL)
-        # cv2.imshow('Image with Segmentations', image)
-        # cv2.waitKey(0)
-        # cv2.destroyAllWindows()
-
-    def show_img(self,img_path):
-        pass