Forráskód Böngészése

修复 features_align 函数参数传递错误导致 CUDA 报错的 bug

lstrlq 5 hónapja
szülő
commit
601abc124b

+ 11 - 9
models/line_detect/line_detect.py

@@ -102,7 +102,7 @@ class LineDetect(BaseDetectionNet):
             arc_head=None,
             arc_predictor=None,
             num_points=3,
-            detect_point=True,
+            detect_point=False,
             detect_line=True,
             detect_arc=False,
             **kwargs,
@@ -172,19 +172,19 @@ class LineDetect(BaseDetectionNet):
 
 
 
-        if line_head is None:
+        if line_head is None and detect_line:
             layers = tuple(num_points for _ in range(8))
             line_head = LineHeads(8, layers)
 
-        if line_predictor is None:
+        if line_predictor is None and detect_line:
         #     keypoint_dim_reduced = 512  # == keypoint_layers[-1]
-            line_predictor = LinePredictor(in_channels=256)
+            line_predictor = LinePredictor(in_channels=128)
 
-        if point_head is None:
+        if point_head is None and detect_point:
             layers = tuple(num_points for _ in range(8))
             point_head = PointHeads(8, layers)
 
-        if point_predictor is None:
+        if point_predictor is None and detect_point:
         #     keypoint_dim_reduced = 512  # == keypoint_layers[-1]
             point_predictor = PointPredictor(in_channels=128)
 
@@ -462,7 +462,9 @@ def linedetect_maxvitfpn(
         max_size=size,
         num_classes=3,  # COCO 数据集有 91 类
         rpn_anchor_generator=get_anchor_generator(backbone_with_fpn, test_input=test_input),
-        box_roi_pool=roi_pooler
+        box_roi_pool=roi_pooler,
+        detect_line=False,
+        detect_point=True,
     )
     return model
 
@@ -529,8 +531,8 @@ def linedetect_swin_transformer_fpn(
         num_classes=3,  # COCO 数据集有 91 类
         rpn_anchor_generator=anchor_generator,
         box_roi_pool=roi_pooler,
-        detect_line=False,
-        detect_point=True,
+        detect_line=True,
+        detect_point=False,
     )
     return model
 

+ 22 - 22
models/line_detect/loi_heads.py

@@ -586,7 +586,7 @@ class RoIHeads(nn.Module):
         self.detect_arc =detect_arc
 
         self.channel_compress = nn.Sequential(
-            nn.Conv2d(256, 8, kernel_size=1),
+            nn.Conv2d(128, 8, kernel_size=1),
             nn.BatchNorm2d(8),
             nn.ReLU(inplace=True)
         )
@@ -923,7 +923,7 @@ class RoIHeads(nn.Module):
                 else:
                     pos_matched_idxs = None
 
-            feature_logits = self.line_forward3(features, image_shapes, line_proposals)
+            feature_logits = self.line_forward1(features, image_shapes, line_proposals)
 
             loss_line = None
             loss_line_iou =None
@@ -1109,7 +1109,7 @@ class RoIHeads(nn.Module):
                         print(f'start to compute point_loss')
 
                         loss_point = compute_point_loss(feature_logits, point_proposals, gt_points,
-                                                        point_pos_matched_idxs, img_size)
+                                                        point_pos_matched_idxs)
 
                     if loss_point is None:
                         print(f'loss_point is None111')
@@ -1262,7 +1262,7 @@ class RoIHeads(nn.Module):
         feature_logits = self.line_head(cs_features)
         print(f'feature_logits from line_head:{feature_logits.shape}')
 
-        roi_features = features_align(cs_features, line_proposals, image_shapes)
+        roi_features = features_align(feature_logits, line_proposals, image_shapes)
         if roi_features is not None:
             print(f'roi_features from align:{roi_features.shape}')
         return roi_features
@@ -1273,19 +1273,19 @@ class RoIHeads(nn.Module):
         print(f'features-0:{features['0'].shape}')
         # cs_features = self.channel_compress(features['0'])
         cs_features=features['0']
-        filtered_proposals = [proposal for proposal in line_proposals if proposal.shape[0] > 0]
-
-        if len(filtered_proposals) > 0:
-            filtered_proposals_tensor = torch.cat(filtered_proposals)
-            print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
-            line_proposals=filtered_proposals
-        line_proposals_tensor = torch.cat(line_proposals)
-        print(f'line_proposals_tensor:{line_proposals_tensor.shape}')
+        # filtered_proposals = [proposal for proposal in line_proposals if proposal.shape[0] > 0]
+        #
+        # if len(filtered_proposals) > 0:
+        #     filtered_proposals_tensor = torch.cat(filtered_proposals)
+        #     print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
+        #     line_proposals=filtered_proposals
+        # line_proposals_tensor = torch.cat(line_proposals)
+        # print(f'line_proposals_tensor:{line_proposals_tensor.shape}')
 
         feature_logits = self.line_predictor(cs_features)
         print(f'feature_logits from line_head:{feature_logits.shape}')
 
-        roi_features = features_align(cs_features, line_proposals, image_shapes)
+        roi_features = features_align(feature_logits, line_proposals, image_shapes)
         if roi_features is not None:
             print(f'roi_features from align:{roi_features.shape}')
         return roi_features
@@ -1296,19 +1296,19 @@ class RoIHeads(nn.Module):
         print(f'features-0:{features['0'].shape}')
         # cs_features = self.channel_compress(features['0'])
         cs_features=features['0']
-        filtered_proposals = [proposal for proposal in proposals if proposal.shape[0] > 0]
-
-        if len(filtered_proposals) > 0:
-            filtered_proposals_tensor = torch.cat(filtered_proposals)
-            print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
-            proposals=filtered_proposals
-        point_proposals_tensor = torch.cat(proposals)
-        print(f'point_proposals_tensor:{point_proposals_tensor.shape}')
+        # filtered_proposals = [proposal for proposal in proposals if proposal.shape[0] > 0]
+        #
+        # if len(filtered_proposals) > 0:
+        #     filtered_proposals_tensor = torch.cat(filtered_proposals)
+        #     print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
+        #     proposals=filtered_proposals
+        # point_proposals_tensor = torch.cat(proposals)
+        # print(f'point_proposals_tensor:{point_proposals_tensor.shape}')
 
         feature_logits = self.point_predictor(cs_features)
         print(f'feature_logits from line_head:{feature_logits.shape}')
 
-        roi_features = features_align(cs_features, proposals, image_shapes)
+        roi_features = features_align(feature_logits, proposals, image_shapes)
         if roi_features is not None:
             print(f'roi_features from align:{roi_features.shape}')
         return roi_features

+ 5 - 4
models/line_detect/train.yaml

@@ -1,8 +1,9 @@
 io:
   logdir: train_results
-#  datadir: /data/share/zjh/Dataset_correct_xanylabel
+  datadir: /data/share/rlq/datasets/250718caisegangban
+#  datadir: /data/share/rlq/datasets/singepoint_Dataset0709_2
 #  datadir: \\192.168.50.222/share/rlq/datasets/250718caisegangban
-  datadir: \\192.168.50.222/share/rlq/datasets/singepoint_Dataset0709_2
+#  datadir: \\192.168.50.222/share/rlq/datasets/singepoint_Dataset0709_2
   data_type: rgb
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
@@ -15,8 +16,8 @@ train_params:
   num_workers: 8
   batch_size: 2
   max_epoch: 8000000
-  augmentation: True
-#  augmentation: False
+#  augmentation: True
+  augmentation: False
   optim:
     name: Adam
     lr: 4.0e-4