Ver Fonte

add predictor(transpose conv to upsample),debug maxvitfpn output point

lstrlq há 5 meses atrás
pai
commit
05e2ac222b

+ 28 - 4
models/line_detect/line_detect.py

@@ -154,9 +154,9 @@ class LineDetect(BaseDetectionNet):
             keypoint_layers = tuple(num_points for _ in range(8))
             line_head = LineHeads(8, keypoint_layers)
 
-        # if line_predictor is None:
+        if line_predictor is None:
         #     keypoint_dim_reduced = 512  # == keypoint_layers[-1]
-        #     line_predictor = LinePredictor(keypoint_dim_reduced)
+            line_predictor = LinePredictor(in_channels=128)
 
 
         self.roi_heads.line_roi_pool = line_roi_pool
@@ -286,7 +286,31 @@ class LineHeads(nn.Sequential):
                 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                 nn.init.constant_(m.bias, 0)
 
+class LinePredictor(nn.Module):
+    def __init__(self, in_channels, out_channels=3 ):
+        super().__init__()
+        input_features = in_channels
+        deconv_kernel = 4
+        self.kps_score_lowres = nn.ConvTranspose2d(
+            input_features,
+            out_channels,
+            deconv_kernel,
+            stride=2,
+            padding=deconv_kernel // 2 - 1,
+        )
+        nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
+        nn.init.constant_(self.kps_score_lowres.bias, 0)
+        self.up_scale = 2
+        self.out_channels = out_channels
 
+    def forward(self, x):
+        print(f'before kps_score_lowres x:{x.shape}')
+        x = self.kps_score_lowres(x)
+        print(f'kps_score_lowres x:{x.shape}')
+        return  x
+        # return torch.nn.functional.interpolate(
+        #     x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
+        # )
 
 
 def linedetect_newresnet18fpn(
@@ -413,7 +437,7 @@ def linedetect_maxvitfpn(
     if num_points is None:
         num_points = 3
 
-    size=224*4
+    size=224*2
 
     maxvit = MaxVitBackbone(input_size=(size,size))
     # print(maxvit.named_children())
@@ -433,7 +457,7 @@ def linedetect_maxvitfpn(
         return_layers={'stem': '0', 'block0': '1', 'block1': '2', 'block2': '3', 'block3': '4'},
         # 确保这些键对应到实际的层
         in_channels_list=in_channels_list,
-        out_channels=64
+        out_channels=128
     )
     test_input = torch.randn(1, 3,size,size)
 

+ 40 - 22
models/line_detect/loi_heads.py

@@ -1501,8 +1501,9 @@ class RoIHeads(nn.Module):
 
             # print(f'line_features from line_roi_pool:{line_features.shape}')
             #(b,256,512,512)
-            cs_features = self.channel_compress(features['0'])
+            # cs_features = self.channel_compress(features['0'])
             #(b.8,512,512)
+            cs_features= features['0']
 
 
             all_proposals=line_proposals+point_proposals
@@ -1530,25 +1531,42 @@ class RoIHeads(nn.Module):
 
             # line_features = lines_features_align(line_features, filtered_proposals, image_shapes)
 
-            point_features =features_align(cs_features, point_proposals, image_shapes)
+            line_features=None
+            # line_features = features_align(cs_features, line_proposals, image_shapes)
+            # if line_features is not None:
+            #     print(f'line_features:{line_features.shape}')
 
 
-            line_features = features_align(cs_features, line_proposals, image_shapes)
 
+            # if line_features is not None and point_features is not None:
+            #     combine_features = torch.cat((point_features, line_features), dim=0)
+            # elif line_features  is not None:
+            #     combine_features =line_features
+            # elif point_features is not None:
+            #     combine_features =point_features
 
+            # combine_features = point_features
+            # print(f'line_features from features_align:{combine_features.shape}')
 
+            # combine_features = self.line_head(cs_features)
 
 
 
-            print(f'line_features from features_align:{cs_features.shape}')
+            # if point_features is not None:
+            #     print(f'point_features:{point_features.shape}')
 
-            cs_features = self.line_head(cs_features)
             #(N,1,512,512)
-            print(f'line_features from line_head:{cs_features.shape}')
-            # line_logits = self.line_predictor(line_features)
+            # print(f'combine_features from line_head:{combine_features.shape}')
 
-            line_logits = cs_features
-            print(f'line_logits:{line_logits.shape}')
+            combine_features = self.line_predictor(cs_features )
+            print(f'combine_features from line_predictor:{combine_features.shape}')
+
+            point_features = features_align(combine_features, point_proposals, image_shapes)
+            print(f'point_features from  features_align:{point_features.shape}')
+            combine_features=point_features
+
+            # line_logits = combine_features
+            # print(f'line_logits:{line_logits.shape}')
 
             loss_line = {}
             loss_line_iou = {}
@@ -1574,13 +1592,13 @@ class RoIHeads(nn.Module):
                 print(f'gt_points_tensor:{gt_points_tensor.shape}')
                 if gt_lines_tensor.shape[0]>0  and line_features is not None:
                     loss_line = lines_point_pair_loss(
-                        line_features, line_proposals, gt_lines, line_pos_matched_idxs
+                        combine_features, line_proposals, gt_lines, line_pos_matched_idxs
                     )
-                    loss_line_iou = line_iou_loss(line_logits, line_proposals, gt_lines, line_pos_matched_idxs, img_size)
+                    loss_line_iou = line_iou_loss(combine_features, line_proposals, gt_lines, line_pos_matched_idxs, img_size)
 
                 if gt_points_tensor.shape[0]>0 and point_features is not None:
                     loss_point = compute_point_loss(
-                        point_features, point_proposals, gt_points, point_pos_matched_idxs
+                        combine_features, point_proposals, gt_points, point_pos_matched_idxs
                     )
 
                 if not loss_line:
@@ -1607,14 +1625,14 @@ class RoIHeads(nn.Module):
 
                     if gt_lines_tensor.shape[0] > 0 and line_features is not None:
                         loss_line = lines_point_pair_loss(
-                            line_features, line_proposals, gt_lines, line_pos_matched_idxs
+                            combine_features, line_proposals, gt_lines, line_pos_matched_idxs
                         )
-                        loss_line_iou = line_iou_loss(line_features, line_proposals, gt_lines, line_pos_matched_idxs,
+                        loss_line_iou = line_iou_loss(combine_features, line_proposals, gt_lines, line_pos_matched_idxs,
                                                       img_size)
 
                     if gt_points_tensor.shape[0] > 0 and point_features is not None:
                         loss_point = compute_point_loss(
-                            point_features, point_proposals, gt_points, point_pos_matched_idxs
+                            combine_features, point_proposals, gt_points, point_pos_matched_idxs
                         )
 
                     if not loss_line :
@@ -1633,18 +1651,18 @@ class RoIHeads(nn.Module):
 
 
                 else:
-                    if line_logits is None or line_proposals is None:
+                    if combine_features is None or line_proposals is None:
                         raise ValueError(
                             "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
                         )
 
-                    if line_features is not None:
-                        lines_probs, lines_scores = line_inference(line_features,line_proposals)
-                        for keypoint_prob, kps, r in zip(lines_probs, lines_scores, result):
-                            r["lines"] = keypoint_prob
-                            r["liness_scores"] = kps
+                    # if line_features is not None:
+                    #     lines_probs, lines_scores = line_inference(combine_features,line_proposals)
+                    #     for keypoint_prob, kps, r in zip(lines_probs, lines_scores, result):
+                    #         r["lines"] = keypoint_prob
+                    #         r["liness_scores"] = kps
                     if point_features is not None:
-                        point_probs, points_scores=point_inference(point_features, point_proposals,)
+                        point_probs, points_scores=point_inference(combine_features, point_proposals,)
                         for  points, ps, r in zip(point_probs,points_scores, result):
                             print(f'points_prob :{points.shape}')
 

+ 3 - 3
models/line_detect/train_demo.py

@@ -9,7 +9,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 if __name__ == '__main__':
 
     # model = LineNet('line_net.yaml')
-    # model=linenet_resnet50_fpn()
+    # model=linedetect_resnet50_fpn()
     # model = linedetect_resnet50_fpn()
     # model=get_line_net_convnext_fpn(num_classes=2).to(device)
     # model=linenet_newresnet50fpn()
@@ -19,6 +19,6 @@ if __name__ == '__main__':
     # model=linedetect_newresnet50fpn(num_points=3)
     # model = linedetect_newresnet50fpn(num_points=3)
     # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250711_114046/weights/best_val.pth')
-    # model=linedetect_maxvitfpn()
-    model=linedetect_high_maxvitfpn()
+    model=linedetect_maxvitfpn()
+    # model=linedetect_high_maxvitfpn()
     model.start_train(cfg='train.yaml')