Ver Fonte

debug arc

RenLiqiang há 4 meses atrás
pai
commit
b31028059c

+ 5 - 4
models/line_detect/heads/head_losses.py

@@ -751,16 +751,18 @@ def line_inference(x, line_boxes):
 
     return lines_probs, lines_scores
 
-def arc_inference(x, point_boxes):
+def arc_inference(x, arc_boxes):
     # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
 
     points_probs = []
     points_scores = []
 
-    boxes_per_image = [box.size(0) for box in point_boxes]
+    print(f'arc_boxes:{len(arc_boxes)}')
+
+    boxes_per_image = [box.size(0) for box in arc_boxes]
     x2 = x.split(boxes_per_image, dim=0)
 
-    for xx, bb in zip(x2, point_boxes):
+    for xx, bb in zip(x2, arc_boxes):
         point_prob,point_scores = heatmaps_to_arc(xx, bb)
 
         points_probs.append(point_prob.unsqueeze(1))
@@ -768,7 +770,6 @@ def arc_inference(x, point_boxes):
 
     return points_probs,points_scores
 
-import torch.nn.functional as F
 
 import torch.nn.functional as F
 

+ 2 - 2
models/line_detect/line_dataset.py

@@ -234,8 +234,8 @@ def get_boxes_lines(objs,shape):
     if len(points)==0:
         points=None
     else:
-        points=torch.tensor(points)
-    # print(f'read labels:{labels}')
+        points=torch.tensor(points,dtype=torch.float32)
+    print(f'read labels:{labels}')
     # print(f'read points:{points}')
     if len(line_point_pairs)==0:
         line_point_pairs=None

+ 10 - 3
models/line_detect/line_detect.py

@@ -331,11 +331,11 @@ def linedetect_newresnet18fpn(
     # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
     # weights_backbone = ResNet50_Weights.verify(weights_backbone)
     if num_classes is None:
-        num_classes = 3
+        num_classes = 4
     if num_points is None:
         num_points = 3
 
-    size=1024
+    size=512
     backbone =resnet18fpn()
     featmap_names=['0', '1', '2', '3','4','pool']
     # print(f'featmap_names:{featmap_names}')
@@ -353,7 +353,14 @@ def linedetect_newresnet18fpn(
 
     anchor_generator =  AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)
 
-    model = LineDetect(backbone, num_classes,min_size=size,max_size=size, num_points=num_points, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler, **kwargs)
+    model = LineDetect(backbone,
+                       num_classes,min_size=size,max_size=size, num_points=num_points,
+                       rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler,
+                       detect_point=False,
+                       detect_line=False,
+                       detect_arc=True,
+
+                       **kwargs)
 
     return model
 

+ 2 - 2
models/line_detect/train.yaml

@@ -1,13 +1,13 @@
 io:
   logdir: train_results
 
-  datadir: /data/zyh/py_ws/code/a_dataset
+#  datadir: /data/zyh/py_ws/code/a_dataset
 #  datadir: /data/share/zjh/Dataset_correct_xanylabel_tiff
 
 #  datadir: /data/share/rlq/datasets/250718caisegangban
 
 #  datadir: /data/share/rlq/datasets/singepoint_Dataset0709_2
-#  datadir: \\192.168.50.222/share/zyh/arc/a_dataset
+  datadir: \\192.168.50.222/share/zyh/arc/a_dataset
 #  datadir: \\192.168.50.222/share/rlq/datasets/singepoint_Dataset0709_2
 #  datadir: \\192.168.50.222/share/rlq/datasets/250718caisegangban
   data_type: rgb

+ 2 - 2
models/line_detect/train_demo.py

@@ -17,10 +17,10 @@ if __name__ == '__main__':
     # model = lineDetect_resnet18_fpn()
 
     # model=linedetect_resnet18_fpn()
-    # model=linedetect_newresnet18fpn(num_points=3)
+    model=linedetect_newresnet18fpn(num_points=3)
     # model=linedetect_newresnet50fpn(num_points=3)
     # model = linedetect_newresnet101fpn(num_points=3)
-    model = linedetect_newresnet152fpn(num_points=3)
+    # model = linedetect_newresnet152fpn(num_points=3)
     # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250711_114046/weights/best_val.pth')
     # model=linedetect_maxvitfpn()
     # model=linedetect_high_maxvitfpn()