Browse Source

单通道提取4个点,有较好效果

admin 4 months ago
parent
commit
2182f9c449

+ 21 - 3
models/line_detect/line_dataset.py

@@ -176,6 +176,24 @@ class LineDataset(BaseDataset):
     def show_img(self, img_path):
         pass
 
+
+def sort_points_clockwise(points):
+    """Order 2-D points by polar angle around the top-left point.
+
+    The reference point is the one with the smallest y (ties broken by the
+    smallest x).  It is always returned first; the remaining points follow
+    in increasing ``arctan2(dy, dx)`` measured from it, which reads as
+    clockwise when the y axis points down (image coordinates).  Points that
+    share an angle are ordered by increasing distance from the reference.
+
+    Args:
+        points: Sequence of ``[x, y]`` pairs (list, tuple or ndarray).
+
+    Returns:
+        list: The same points as nested Python lists, reordered.  An empty
+        input yields an empty list.
+    """
+    points = np.asarray(points)
+    if points.size == 0:
+        return []
+
+    # lexsort uses the LAST key as primary: smallest y first, then smallest x.
+    top_left_idx = np.lexsort((points[:, 0], points[:, 1]))[0]
+    offsets = points - points[top_left_idx]
+
+    angles = np.arctan2(offsets[:, 1], offsets[:, 0])
+    angles[angles < 0] += 2 * np.pi
+
+    # The reference itself has angle arctan2(0, 0) == 0, the same as any
+    # other point on its row.  Using the distance as a secondary key keeps
+    # the reference (distance 0) first instead of leaving the tie to the
+    # input order, which could produce a self-crossing ordering.
+    distances = np.hypot(offsets[:, 0], offsets[:, 1])
+    sorted_indices = np.lexsort((distances, angles))
+
+    return points[sorted_indices].tolist()
 def get_boxes_lines(objs,shape):
     boxes = []
     labels=[]
@@ -185,7 +203,6 @@ def get_boxes_lines(objs,shape):
     line_mask=[]
     circle_4points=[]
 
-
     for obj in objs:
         # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1]无明确大小
 
@@ -237,7 +254,8 @@ def get_boxes_lines(objs,shape):
 
         elif label == 'circle' :
             # print(f'len circle_4points: {len(obj['points'])}')
-            circle_4points.append(obj['points'])
+            points=sort_points_clockwise(obj['points'])
+            circle_4points.append(points)
 
             xmin = max(obj['xmin'] - 6, 0)
 
@@ -251,7 +269,7 @@ def get_boxes_lines(objs,shape):
 
             labels.append(torch.tensor(4))
 
-    boxes=torch.tensor(boxes)
+    boxes=torch.tensor(boxes,dtype=torch.float32)
     print(f'boxes:{boxes.shape}')
     labels=torch.tensor(labels)
     if len(points)==0:

+ 9 - 7
models/line_detect/line_detect.py

@@ -24,6 +24,7 @@ from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extract
 from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
 from .heads.arc_heads import ArcHeads, ArcPredictor
 from .heads.circle_heads import CircleHeads, CirclePredictor
+from .heads.decoder import FPNDecoder
 from .heads.line_heads import LinePredictor
 from .heads.point_heads import PointHeads, PointPredictor
 from .loi_heads import RoIHeads
@@ -202,7 +203,7 @@ class LineDetect(BaseDetectionNet):
         if detect_arc and arc_predictor is None:
             layers = tuple(num_points for _ in range(8))
             # arc_predictor=ArcPredictor(in_channels=256,out_channels=1)
-            arc_predictor=ArcUnet(Bottleneck)
+            arc_predictor=FPNDecoder(Bottleneck)
 
         if detect_circle and circle_head is None:
             layers = tuple(num_points for _ in range(8))
@@ -210,7 +211,7 @@ class LineDetect(BaseDetectionNet):
         if detect_circle and circle_predictor is None:
             layers = tuple(num_points for _ in range(8))
             # arc_predictor=ArcPredictor(in_channels=256,out_channels=1)
-            circle_predictor = CirclePredictor(in_channels=256)
+            circle_predictor = CirclePredictor(in_channels=256,out_channels=4)
 
 
 
@@ -400,7 +401,7 @@ def linedetect_newresnet50fpn(
     if num_points is None:
         num_points = 4
 
-    size=768
+    size=512
     backbone =resnet50fpn(out_channels=256)
     featmap_names=['0', '1', '2', '3','4','pool']
     # print(f'featmap_names:{featmap_names}')
@@ -511,12 +512,12 @@ def linedetect_maxvitfpn(
     # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
     # weights_backbone = ResNet50_Weights.verify(weights_backbone)
     if num_classes is None:
-        num_classes = 4
+        num_classes = 5
     if num_points is None:
         num_points = 3
 
 
-    size=224*3
+    size=224*2
 
 
     maxvit = MaxVitBackbone(input_size=(size,size))
@@ -537,7 +538,7 @@ def linedetect_maxvitfpn(
         return_layers={'stem': '0', 'block0': '1', 'block1': '2', 'block2': '3', 'block3': '4'},
         # 确保这些键对应到实际的层
         in_channels_list=in_channels_list,
-        out_channels=128
+        out_channels=256
     )
     test_input = torch.randn(1, 3,size,size)
 
@@ -550,7 +551,8 @@ def linedetect_maxvitfpn(
         box_roi_pool=roi_pooler,
         detect_line=False,
         detect_point=False,
-        detect_arc=True,
+        detect_arc=False,
+        detect_circle=True,
     )
     return model
 

+ 2 - 2
models/line_detect/train.yaml

@@ -22,8 +22,8 @@ train_params:
   num_workers: 8
   batch_size: 2
   max_epoch: 8000000
-#  augmentation: True
-  augmentation: False
+  augmentation: True
+#  augmentation: False
   optim:
     name: Adam
     lr: 4.0e-4