فهرست منبع

show_type == 'arc_yuan_point_ellipse'

ljy 1 ماه پیش
والد
کامیت
a7f8c96ce2
1فایلهای تغییر یافته به همراه55 افزوده شده و 32 حذف شده
  1. 55 32
      models/line_detect/line_dataset.py

+ 55 - 32
models/line_detect/line_dataset.py

@@ -118,30 +118,30 @@ class LineDataset(BaseDataset):
             target['mask_params'] = arc_params
 
 
-            arc_angles = compute_arc_angles(arc_ends, arc_params)
-            # print_params(arc_angles)
-            arc_masks = []
-
-
-
-            for i in range(len(arc_params)):
-                arc_param_i = arc_params[i].view(-1)  # shape (5,)
-                arc_angle_i = arc_angles[i].view(-1)  # shape (2,)
-                arc7 = torch.cat([arc_param_i, arc_angle_i], dim=0)  # shape (7,)
-
-
-                # print_params(arc7)
-                mask = arc_to_mask(arc7, shape, line_width=1)
-
-                arc_masks.append(mask)
-                # arc7=arc_params[i] + arc_angles[i].tolist()
-                # arc_masks.append(arc_to_mask(arc7, shape, line_width=1))
-
-            # print(f'circle_masks:{torch.stack(arc_masks, dim=0).shape}')
-
-            target['circle_masks'] = torch.stack(arc_masks, dim=0)
-            save_full_mask(target['circle_masks'], "arc_masks",
-                           "/home/zhaoyinghan/py_ws/code/circle_huayan/MultiVisionModels/models/line_detect/out_feature_dataset")
+            # arc_angles = compute_arc_angles(arc_ends, arc_params)
+            # # print_params(arc_angles)
+            # arc_masks = []
+            #
+            #
+            #
+            # for i in range(len(arc_params)):
+            #     arc_param_i = arc_params[i].view(-1)  # shape (5,)
+            #     arc_angle_i = arc_angles[i].view(-1)  # shape (2,)
+            #     arc7 = torch.cat([arc_param_i, arc_angle_i], dim=0)  # shape (7,)
+            #
+            #
+            #     # print_params(arc7)
+            #     mask = arc_to_mask(arc7, shape, line_width=1)
+            #
+            #     arc_masks.append(mask)
+            #     # arc7=arc_params[i] + arc_angles[i].tolist()
+            #     # arc_masks.append(arc_to_mask(arc7, shape, line_width=1))
+            #
+            # # print(f'circle_masks:{torch.stack(arc_masks, dim=0).shape}')
+            #
+            # target['circle_masks'] = torch.stack(arc_masks, dim=0)
+            # save_full_mask(target['circle_masks'], "arc_masks",
+            #                "/home/zhaoyinghan/py_ws/code/circle_huayan/MultiVisionModels/models/line_detect/out_feature_dataset")
 
 
 
@@ -173,9 +173,36 @@ class LineDataset(BaseDataset):
         sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
         sm.set_array([])
 
-        # img_path = os.path.join(self.img_path, self.imgs[idx])
-        # print(f'boxes:{target["boxes"]}')
         img = image
+        # print(f'img:{img.shape}')
+
+        if show_type == 'arc_yuan_point_ellipse':
+            arc_ends = target['mask_ends']
+            arc_params = target['mask_params']
+
+            fig, ax = plt.subplots()
+            ax.imshow(img.permute(1, 2, 0))
+
+            for params in arc_params:
+                if torch.all(params == 0):
+                    continue
+                x, y, a, b, q = params
+                theta = np.radians(q)
+                phi = np.linspace(0, 2 * np.pi, 500)
+                x_ellipse = x + a * np.cos(phi) * np.cos(theta) - b * np.sin(phi) * np.sin(theta)
+                y_ellipse = y + a * np.cos(phi) * np.sin(theta) + b * np.sin(phi) * np.cos(theta)
+
+                plt.plot(x_ellipse, y_ellipse, 'b-', linewidth=2)
+
+            for point2 in arc_ends:
+                if torch.all(point2 == 0):
+                    continue
+                ends_np = point2.cpu().numpy()
+                ax.plot(ends_np[:, 0], ends_np[:, 1], 'ro', markersize=6, label='Arc Endpoints')
+
+            ax.legend()
+            plt.axis('image')  # 保持比例一致
+            plt.show()
 
         if show_type == 'circle_masks':
             boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"],
@@ -554,14 +581,10 @@ def get_boxes_lines(objs, shape):
         mask_params = None
 
 
-
-
-
-
     return boxes, line_point_pairs, points, labels, mask_ends, mask_params
 
 
 if __name__ == '__main__':
-    path = r'/data/share/zyh/master_dataset/pokou/251115/a_dataset_pokou_mask'
+    path = r'\\192.168.50.222/share/zyh/master_dataset/pokou/251115/a_dataset_pokou_mask'
     dataset = LineDataset(dataset_path=path, dataset_type='train', augmentation=False, data_type='jpg')
-    dataset.show(9, show_type='circle_masks')
+    dataset.show(19, show_type='arc_yuan_point_ellipse')