Quellcode durchsuchen

keypoint_dataset

xue50 vor 5 Monaten
Ursprung
Commit
f86305e5ad
2 geänderte Dateien mit 16 neuen und 12 gelöschten Zeilen
  1. 7 1
      models/dataset_tool.py
  2. 9 11
      models/keypoint/keypoint_dataset.py

+ 7 - 1
models/dataset_tool.py

@@ -224,11 +224,17 @@ def line_boxes(target):
     lines = lpre
     sline = np.ones(lpre.shape[0])
 
+    keypoints = []
+
     if len(lines) > 0 and not (lines[0] == 0).all():
         for i, ((a, b), s) in enumerate(zip(lines, sline)):
             if i > 0 and (lines[i] == lines[0]).all():
                 break
             # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1]无明确大小
+
+            keypoints.append([a[0], b[0]])
+            keypoints.append([a[1], b[1]])
+
             if a[1] > b[1]:
                 ymax = a[1] + 1
                 ymin = b[1] - 1
@@ -243,7 +249,7 @@ def line_boxes(target):
                 xmax = b[0] + 1
             boxs.append([ymin, xmin, ymax, xmax])
 
-    return torch.tensor(boxs)
+    return torch.tensor(boxs), torch.tensor(keypoints)
 
 
 def read_polygon_points_wire(lbl_path, shape):

+ 9 - 11
models/keypoint/keypoint_dataset.py

@@ -24,6 +24,7 @@ from torch.utils.data.dataloader import default_collate
 import matplotlib.pyplot as plt
 from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
 
+
 def validate_keypoints(keypoints, image_width, image_height):
     for kp in keypoints:
         x, y, v = kp
@@ -123,14 +124,14 @@ class KeypointDataset(BaseDataset):
 
         target["labels"] = torch.stack(labels)
         # print(f'labels:{target["labels"]}')
-        target["boxes"] = line_boxes(target)
+        target["boxes"], keypoints = line_boxes(target)
         # visibility_flags = torch.ones((wire_labels["junc_coords"].shape[0], 1))
 
-        keypoints= wire_labels["junc_coords"]
-        keypoints[:,2]=2
-        # keypoints[:,0]=keypoints[:,0]/shape[0]
-        # keypoints[:, 1] = keypoints[:, 1] / shape[1]
-        target["keypoints"]=keypoints
+        # keypoints= wire_labels["junc_coords"]
+        a = torch.full((keypoints.shape[0],), 2).unsqueeze(1)
+        keypoints = torch.cat((keypoints, a), dim=1)
+        target["keypoints"] = keypoints.to(torch.float32)
+        print(f'boxes:{target["boxes"].shape}')
         # 在 __getitem__ 方法中调用此函数
         validate_keypoints(keypoints, shape[0], shape[1])
         print(f'keypoints:{target["keypoints"].shape}')
@@ -166,7 +167,6 @@ class KeypointDataset(BaseDataset):
                         break
                     plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # 原 s=64
 
-
             img_path = os.path.join(self.img_path, self.imgs[idx])
             img = PIL.Image.open(img_path).convert('RGB')
             boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
@@ -190,13 +190,11 @@ class KeypointDataset(BaseDataset):
         # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
         draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
 
-
     def show_img(self, img_path):
         pass
 
 
-
 if __name__ == '__main__':
-    path=r"I:\wirenet_dateset"
-    dataset= KeypointDataset(dataset_path=path, dataset_type='train')
+    path = r"I:\wirenet_dateset"
+    dataset = KeypointDataset(dataset_path=path, dataset_type='train')
     dataset.show(0)