xue50 vor 5 Monaten
Ursprung
Commit
5453210b96
3 geänderte Dateien mit 44 neuen und 39 gelöschten Zeilen
  1. 31 0
      models/dataset_tool.py
  2. 3 35
      models/wirenet/wirepoint_dataset.py
  3. 10 4
      models/wirenet/wirepoint_rcnn.py

+ 31 - 0
models/dataset_tool.py

@@ -215,6 +215,37 @@ def masks_to_boxes(masks: torch.Tensor, ) -> torch.Tensor:
     return bounding_boxes
 
 
+def line_boxes(target):
+    boxs = []
+    lpre = target['wires']["lpre"].cpu().numpy() * 4
+    vecl_target = target['wires']["lpre_label"].cpu().numpy()
+    lpre = lpre[vecl_target == 1]
+
+    lines = lpre
+    sline = np.ones(lpre.shape[0])
+
+    if len(lines) > 0 and not (lines[0] == 0).all():
+        for i, ((a, b), s) in enumerate(zip(lines, sline)):
+            if i > 0 and (lines[i] == lines[0]).all():
+                break
+            # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1]无明确大小
+            if a[1] > b[1]:
+                ymax = a[1] + 1
+                ymin = b[1] - 1
+            else:
+                ymin = a[1] - 1
+                ymax = b[1] + 1
+            if a[0] > b[0]:
+                xmax = a[0] + 1
+                xmin = b[0] - 1
+            else:
+                xmin = a[0] - 1
+                xmax = b[0] + 1
+            boxs.append([ymin, xmin, ymax, xmax])
+
+    return torch.tensor(boxs)
+
+
 def read_polygon_points_wire(lbl_path, shape):
     """读取 YOLOv8 格式的标注文件并解析多边形轮廓"""
     polygon_points = []

+ 3 - 35
models/wirenet/wirepoint_dataset.py

@@ -22,7 +22,7 @@ from torch.utils.data import Dataset
 from torch.utils.data.dataloader import default_collate
 
 import matplotlib.pyplot as plt
-from models.dataset_tool import masks_to_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
+from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
 
 
 class WirePointDataset(BaseDataset):
@@ -111,44 +111,12 @@ class WirePointDataset(BaseDataset):
 
         # print(torch.stack(masks).shape)    # [线段数, 512, 512]
         target = {}
-        # target["boxes"] = masks_to_boxes(torch.stack(masks))
-        # print(target["boxes"])
         target["labels"] = torch.stack(labels)
         target["masks"] = torch.stack(masks)
         target["image_id"] = torch.tensor(item)
         # return wire_labels, target
         target["wires"] = wire_labels
-
-        boxs = []
-        junc = target['wires']['junc_coords'].cpu().numpy() * 4
-        lpre = target['wires']["lpre"].cpu().numpy() * 4
-        vecl_target = target['wires']["lpre_label"].cpu().numpy()
-        lpre = lpre[vecl_target == 1]
-
-        lines = lpre
-        sline = np.ones(lpre.shape[0])
-
-        if len(lines) > 0 and not (lines[0] == 0).all():
-            for i, ((a, b), s) in enumerate(zip(lines, sline)):
-                if i > 0 and (lines[i] == lines[0]).all():
-                    break
-                # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1]无明确大小
-                if a[1] > b[1]:
-                    ymax = a[1] + 1
-                    ymin = b[1] - 1
-                else:
-                    ymin = a[1] - 1
-                    ymax = b[1] + 1
-                if a[0] > b[0]:
-                    xmax = a[0] + 1
-                    xmin = b[0] - 1
-                else:
-                    xmin = a[0] - 1
-                    xmax = b[0] + 1
-                boxs.append([ymin, xmin, ymax, xmax])
-                # plt.Rectangle([a[1] - 1, b[1] + 1], [a[0] + 1, b[0] - 1], c="g", linewidth=1)
-        target["line_boxes"] = torch.tensor(boxs)
-
+        target["boxes"] = line_boxes(target)
         return target
 
     def show(self, idx):
@@ -184,7 +152,7 @@ class WirePointDataset(BaseDataset):
 
             img_path = os.path.join(self.img_path, self.imgs[idx])
             img = PIL.Image.open(img_path).convert('RGB')
-            boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["line_boxes"],
+            boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
                                               colors="yellow", width=1)
             plt.imshow(boxed_image.permute(1, 2, 0).numpy())
             plt.show()

+ 10 - 4
models/wirenet/wirepoint_rcnn.py

@@ -613,11 +613,17 @@ if __name__ == '__main__':
     )
     model = wirepointrcnn_resnet50_fpn()
 
-    imgs, targets = next(iter(data_loader))
+    for i in cfg['optim']['max_epoch']:
+        model.train()
+        imgs, targets = next(iter(data_loader))
+        pred = model(imgs, targets)
+
+    # imgs, targets = next(iter(data_loader))
+    #
+    # model.train()
+    # pred = model(imgs, targets)
+    # print(f'pred:{pred}')
 
-    model.train()
-    pred = model(imgs, targets)
-    print(f'pred:{pred}')
     # result, losses = model(imgs, targets)
     # print(f'result:{result}')
     # print(f'pred:{losses}')