Explorar el Código

train circle on 4080

admin hace 4 meses
padre
commit
6da2565b47
Se han modificado 2 ficheros con 20 adiciones y 10 borrados
  1. 19 8
      models/line_detect/line_detect.py
  2. 1 2
      models/line_detect/loi_heads.py

+ 19 - 8
models/line_detect/line_detect.py

@@ -440,11 +440,11 @@ def linedetect_newresnet101fpn(
     # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
     # weights_backbone = ResNet50_Weights.verify(weights_backbone)
     if num_classes is None:
-        num_classes = 3
+        num_classes = 5
     if num_points is None:
         num_points = 3
 
-    size=768
+    size=512
     backbone =resnet101fpn(out_channels=256)
     featmap_names=['0', '1', '2', '3','4','pool']
     # print(f'featmap_names:{featmap_names}')
@@ -462,7 +462,12 @@ def linedetect_newresnet101fpn(
 
     anchor_generator =  AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)
 
-    model = LineDetect(backbone, num_classes,min_size=size,max_size=size, num_points=num_points, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler, **kwargs)
+    model = LineDetect(backbone, num_classes,min_size=size,max_size=size, num_points=num_points, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler,
+                       detect_point=False,
+                       detect_line=False,
+                       detect_arc=False,
+                       detect_circle=True,
+                       **kwargs)
 
     return model
 
@@ -477,11 +482,11 @@ def linedetect_newresnet152fpn(
     # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
     # weights_backbone = ResNet50_Weights.verify(weights_backbone)
     if num_classes is None:
-        num_classes = 4
+        num_classes = 5
     if num_points is None:
         num_points = 3
 
-    size=800
+    size=512
     backbone =resnet101fpn(out_channels=256)
     featmap_names=['0', '1', '2', '3','4','pool']
     # print(f'featmap_names:{featmap_names}')
@@ -499,7 +504,13 @@ def linedetect_newresnet152fpn(
 
     anchor_generator =  AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)
 
-    model = LineDetect(backbone, num_classes,min_size=size,max_size=size, num_points=num_points, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler, **kwargs)
+    model = LineDetect(backbone, num_classes,min_size=size,max_size=size, num_points=num_points, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler,
+
+                       detect_point=False,
+                       detect_line=False,
+                       detect_arc=False,
+                       detect_circle=True,
+                       **kwargs)
 
     return model
 
@@ -565,7 +576,7 @@ def linedetect_high_maxvitfpn(
     # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
     # weights_backbone = ResNet50_Weights.verify(weights_backbone)
     if num_classes is None:
-        num_classes = 3
+        num_classes = 5
     if num_points is None:
         num_points = 3
 
@@ -588,9 +599,9 @@ def linedetect_high_maxvitfpn(
 
     model = LineDetect(
         backbone=maxvitfpn,
+        num_classes=num_classes,
         min_size=size,
         max_size=size,
-        num_classes=3,  # COCO 数据集有 91 类
         rpn_anchor_generator=get_anchor_generator(maxvitfpn, test_input=test_input),
         box_roi_pool=roi_pooler
     )

+ 1 - 2
models/line_detect/loi_heads.py

@@ -1400,8 +1400,7 @@ class RoIHeads(nn.Module):
                         loss_circle = compute_circle_loss(feature_logits, circle_proposals, gt_circles,
                                                         circle_pos_matched_idxs)
 
-                        loss_circle_extra = compute_circle_extra_losses(feature_logits, circle_proposals, gt_circles,
-                                                                         circle_pos_matched_idxs)
+                        loss_circle_extra = compute_circle_extra_losses(feature_logits, circle_proposals, gt_circles,circle_pos_matched_idxs)
 
                     if loss_circle is None:
                         print(f'loss_circle is None111')