Browse Source

change circle_predictor to decoder

admin 4 months ago
parent
commit
f3b8063559

+ 3 - 1
models/line_detect/line_detect.py

@@ -194,10 +194,12 @@ class LineDetect(BaseDetectionNet):
         if detect_circle and circle_head is None:
             layers = tuple(num_points for _ in range(8))
             circle_head = CircleHeads(8, layers)
+
         if detect_circle and circle_predictor is None:
             layers = tuple(num_points for _ in range(8))
             # arc_predictor=ArcPredictor(in_channels=256,out_channels=1)
-            circle_predictor = CirclePredictor(in_channels=256,out_channels=4)
+            # circle_predictor = CirclePredictor(in_channels=256,out_channels=4)
+            circle_predictor=FPNDecoder(Bottleneck)
 
 
 

+ 2 - 1
models/line_detect/loi_heads.py

@@ -1621,7 +1621,8 @@ class RoIHeads(nn.Module):
         # cs_features= features['0']
         # print(f'features-0:{features['0'].shape}')
         # cs_features = self.channel_compress(features['0'])
-        cs_features=features['0']
+        # cs_features=features['0']
+        cs_features = features
         # filtered_proposals = [proposal for proposal in proposals if proposal.shape[0] > 0]
         #
         # if len(filtered_proposals) > 0:

+ 2 - 2
models/line_detect/train_demo.py

@@ -18,12 +18,12 @@ if __name__ == '__main__':
 
     # model=linedetect_resnet18_fpn()
     # model=linedetect_newresnet18fpn(num_points=4)
-    # model=linedetect_newresnet50fpn(num_points=4)
+    model=linedetect_newresnet50fpn(num_points=4)
     # model = linedetect_newresnet101fpn(num_points=3)
     # model = linedetect_newresnet152fpn(num_points=3)
     # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250711_114046/weights/best_val.pth')
     # model=linedetect_maxvitfpn()
-    model=linedetect_efficientnet(name='efficientnet_v2_l')
+    # model=linedetect_efficientnet(name='efficientnet_v2_l')
     # model=linedetect_high_maxvitfpn()
 
     # model=linedetect_swin_transformer_fpn(type='t')