Browse Source

save local changes

zhaoyinghan 1 month ago
parent
commit
bf92764c38
1 changed files with 41 additions and 23 deletions
  1. 41 23
      models/line_detect/line_dataset.py

+ 41 - 23
models/line_detect/line_dataset.py

@@ -71,7 +71,7 @@ class LineDataset(BaseDataset):
             img = PIL.Image.open(img_path).convert('RGB')
             w, h = img.size
         # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
-        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w),image=img)
+        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
 
         self.transforms = get_transforms(augmention=self.augmentation)
 
@@ -82,7 +82,7 @@ class LineDataset(BaseDataset):
     def __len__(self):
         return len(self.imgs)
 
-    def read_target(self, item, lbl_path, shape,extra=None,image=None):
+    def read_target(self, item, lbl_path, shape, extra=None):
         # print(f'shape:{shape}')
         # print(f'lbl_path:{lbl_path}')
         with open(lbl_path, 'r') as file:
@@ -122,30 +122,26 @@ class LineDataset(BaseDataset):
             # print_params(arc_angles)
             arc_masks = []
 
+
+
             for i in range(len(arc_params)):
                 arc_param_i = arc_params[i].view(-1)  # shape (5,)
                 arc_angle_i = arc_angles[i].view(-1)  # shape (2,)
                 arc7 = torch.cat([arc_param_i, arc_angle_i], dim=0)  # shape (7,)
 
 
+                # print_params(arc7)
                 mask = arc_to_mask(arc7, shape, line_width=1)
 
                 arc_masks.append(mask)
+                # arc7=arc_params[i] + arc_angles[i].tolist()
+                # arc_masks.append(arc_to_mask(arc7, shape, line_width=1))
 
-            print_params(arc_masks,image)
+            # print(f'circle_masks:{torch.stack(arc_masks, dim=0).shape}')
 
             target['circle_masks'] = torch.stack(arc_masks, dim=0)
-
-            # for i, m in enumerate(target['circle_masks']):
-            #     save_full_mask(
-            #         m,
-            #         name=f"arc_mask_{i}",
-            #         out_dir=r"/home/zhaoyinghan/py_ws/code/circle_huayan/MultiVisionModels/models/line_detect/out_feature_dataset",
-            #         save_png=True,
-            #         save_npy=True,
-            #         image=image,
-            #         show_on_image=True
-            #     )
+            save_full_mask(target['circle_masks'], "arc_masks",
+                           "/home/zhaoyinghan/py_ws/code/circle_huayan/MultiVisionModels/models/line_detect/out_feature_dataset")
 
 
 
@@ -177,9 +173,36 @@ class LineDataset(BaseDataset):
         sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
         sm.set_array([])
 
-        # img_path = os.path.join(self.img_path, self.imgs[idx])
-        # print(f'boxes:{target["boxes"]}')
         img = image
+        # print(f'img:{img.shape}')
+
+        if show_type == 'arc_yuan_point_ellipse':
+            arc_ends = target['mask_ends']
+            arc_params = target['mask_params']
+
+            fig, ax = plt.subplots()
+            ax.imshow(img.permute(1, 2, 0))
+
+            for params in arc_params:
+                if torch.all(params == 0):
+                    continue
+                x, y, a, b, q = params
+                theta = np.radians(q)
+                phi = np.linspace(0, 2 * np.pi, 500)
+                x_ellipse = x + a * np.cos(phi) * np.cos(theta) - b * np.sin(phi) * np.sin(theta)
+                y_ellipse = y + a * np.cos(phi) * np.sin(theta) + b * np.sin(phi) * np.cos(theta)
+
+                plt.plot(x_ellipse, y_ellipse, 'b-', linewidth=2)
+
+            for point2 in arc_ends:
+                if torch.all(point2 == 0):
+                    continue
+                ends_np = point2.cpu().numpy()
+                ax.plot(ends_np[:, 0], ends_np[:, 1], 'ro', markersize=6, label='Arc Endpoints')
+
+            ax.legend()
+            plt.axis('image')  # 保持比例一致
+            plt.show()
 
         if show_type == 'circle_masks':
             boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"],
@@ -281,7 +304,6 @@ def arc_to_mask(arc7, shape, line_width=1):
     # 确保 phi1 -> phi2 是正向(可处理跨 2π 的情况)
     if torch.all(arc7 == 0):
         return torch.zeros(shape, dtype=torch.uint8)
-    print_params(arc7)
 
     xc, yc, a, b, theta, phi1, phi2 = arc7
     H, W = shape
@@ -559,14 +581,10 @@ def get_boxes_lines(objs, shape):
         mask_params = None
 
 
-
-
-
-
     return boxes, line_point_pairs, points, labels, mask_ends, mask_params
 
 
 if __name__ == '__main__':
-    path = r'/data/share/zyh/master_dataset/pokou/251115/a_dataset_pokou_mask'
+    path = r'\\192.168.50.222/share/zyh/master_dataset/pokou/251115/a_dataset_pokou_mask'
     dataset = LineDataset(dataset_path=path, dataset_type='train', augmentation=False, data_type='jpg')
-    dataset.show(9, show_type='circle_masks')
+    dataset.show(19, show_type='arc_yuan_point_ellipse')