Browse Source

添加3个loihead的forward方法

RenLiqiang 5 months ago
parent
commit
9cf8d177b5

+ 3 - 2
models/base/backbone_factory.py

@@ -279,13 +279,14 @@ def get_swin_transformer_fpn(type='t'):
     # print(f'out:{out}')
     # print(f'out:{out}')
     return  backbone_with_fpn,roi_pooler,anchor_generator
     return  backbone_with_fpn,roi_pooler,anchor_generator
 if __name__ == '__main__':
 if __name__ == '__main__':
-    backbone_with_fpn, roi_pooler, anchor_generator=get_swin_transformer_fpn(type='s')
+    backbone_with_fpn, roi_pooler, anchor_generator=get_swin_transformer_fpn(type='t')
     model=FasterRCNN(backbone=backbone_with_fpn,num_classes=3,box_roi_pool=roi_pooler,rpn_anchor_generator=anchor_generator)
     model=FasterRCNN(backbone=backbone_with_fpn,num_classes=3,box_roi_pool=roi_pooler,rpn_anchor_generator=anchor_generator)
     input=torch.randn(3,3,512,512,device='cuda')
     input=torch.randn(3,3,512,512,device='cuda')
     model.eval()
     model.eval()
     model.to('cuda')
     model.to('cuda')
     out=model(input)
     out=model(input)
-
+    out=backbone_with_fpn(input)
+    print(f'out:{out.shape}')
 
 
     # # maxvit = models.maxvit_t(pretrained=True)
     # # maxvit = models.maxvit_t(pretrained=True)
     # maxvit=MaxVitBackbone()
     # maxvit=MaxVitBackbone()

+ 7 - 1
models/base/high_reso_swin.py

@@ -944,7 +944,7 @@ def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = T
     weights = Swin_V2_T_Weights.verify(weights)
     weights = Swin_V2_T_Weights.verify(weights)
 
 
     return _swin_transformer(
     return _swin_transformer(
-        patch_size=[1, 1],
+        patch_size=[2, 2],
         embed_dim=96,
         embed_dim=96,
         depths=[2, 2, 6, 2],
         depths=[2, 2, 6, 2],
         num_heads=[3, 6, 12, 24],
         num_heads=[3, 6, 12, 24],
@@ -1033,3 +1033,9 @@ def swin_v2_b(*, weights: Optional[Swin_V2_B_Weights] = None, progress: bool = T
         downsample_layer=PatchMergingV2,
         downsample_layer=PatchMergingV2,
         **kwargs,
         **kwargs,
     )
     )
+
+if __name__ == '__main__':
+    input=torch.randn(3,3,512,512)
+    model=swin_v2_t(weights=None)
+    out=model(input)
+    print(f'out:{out.shape}')

+ 5 - 4
models/line_detect/heads/line_heads.py

@@ -17,7 +17,7 @@ class LineHeads(nn.Sequential):
 
 
 
 
 class LinePredictor(nn.Module):
 class LinePredictor(nn.Module):
-    def __init__(self, in_channels, out_channels=1 ):
+    def __init__(self, in_channels, out_channels=3 ):
         super().__init__()
         super().__init__()
         input_features = in_channels
         input_features = in_channels
         deconv_kernel = 4
         deconv_kernel = 4
@@ -37,6 +37,7 @@ class LinePredictor(nn.Module):
         print(f'before kps_score_lowres x:{x.shape}')
         print(f'before kps_score_lowres x:{x.shape}')
         x = self.kps_score_lowres(x)
         x = self.kps_score_lowres(x)
         print(f'kps_score_lowres x:{x.shape}')
         print(f'kps_score_lowres x:{x.shape}')
-        return torch.nn.functional.interpolate(
-            x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
-        )
+        return x
+        # return torch.nn.functional.interpolate(
+        #     x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
+        # )

+ 42 - 0
models/line_detect/heads/point_heads.py

@@ -0,0 +1,42 @@
+import torch
+from torch import nn
+
+class PointHeads(nn.Sequential):
+    def __init__(self, in_channels, layers):
+        d = []
+        next_feature = in_channels
+        for out_channels in layers:
+            d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
+            d.append(nn.ReLU(inplace=True))
+            next_feature = out_channels
+        super().__init__(*d)
+        for m in self.children():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                nn.init.constant_(m.bias, 0)
+
+
+class PointPredictor(nn.Module):
+    def __init__(self, in_channels, out_channels=1 ):
+        super().__init__()
+        input_features = in_channels
+        deconv_kernel = 4
+        self.kps_score_lowres = nn.ConvTranspose2d(
+            input_features,
+            out_channels,
+            deconv_kernel,
+            stride=2,
+            padding=deconv_kernel // 2 - 1,
+        )
+        nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
+        nn.init.constant_(self.kps_score_lowres.bias, 0)
+        self.up_scale = 2
+        self.out_channels = out_channels
+
+    def forward(self, x):
+        # print(f'before kps_score_lowres x:{x.shape}')
+        x = self.kps_score_lowres(x)
+        # print(f'kps_score_lowres x:{x.shape}')
+        return torch.nn.functional.interpolate(
+            x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
+        )

+ 3 - 3
models/line_detect/line_dataset.py

@@ -100,9 +100,9 @@ class LineDataset(BaseDataset):
 
 
         # target["lines"] = lines.to(torch.float32).view(-1,2,3)
         # target["lines"] = lines.to(torch.float32).view(-1,2,3)
 
 
-        print(f'')
+        # print(f'')
 
 
-        print(f'lines:{target["lines"].shape}')
+        # print(f'lines:{target["lines"].shape}')
         target["img_size"]=shape
         target["img_size"]=shape
 
 
         # validate_keypoints(lines, shape[0], shape[1])
         # validate_keypoints(lines, shape[0], shape[1])
@@ -161,7 +161,7 @@ def get_boxes_lines(objs,shape):
 
 
         # print(f"points:{obj['points']}")
         # print(f"points:{obj['points']}")
         label=obj['label']
         label=obj['label']
-        if label =='line':
+        if label =='line' or label=='dseam1':
             a,b=obj['points'][0],obj['points'][1]
             a,b=obj['points'][0],obj['points'][1]
 
 
             line_point_pairs.append(a)
             line_point_pairs.append(a)

+ 11 - 2
models/line_detect/line_detect.py

@@ -87,10 +87,19 @@ class LineDetect(BaseDetectionNet):
             box_batch_size_per_image=512,
             box_batch_size_per_image=512,
             box_positive_fraction=0.25,
             box_positive_fraction=0.25,
             bbox_reg_weights=None,
             bbox_reg_weights=None,
-            # keypoint parameters
+            # line parameters
             line_roi_pool=None,
             line_roi_pool=None,
             line_head=None,
             line_head=None,
             line_predictor=None,
             line_predictor=None,
+            # point parameters
+            point_roi_pool=None,
+            point_head=None,
+            point_predictor=None,
+
+            # arc parameters
+            arc_roi_pool=None,
+            arc_head=None,
+            arc_predictor=None,
             num_points=3,
             num_points=3,
             **kwargs,
             **kwargs,
     ):
     ):
@@ -161,7 +170,7 @@ class LineDetect(BaseDetectionNet):
 
 
         if line_predictor is None:
         if line_predictor is None:
         #     keypoint_dim_reduced = 512  # == keypoint_layers[-1]
         #     keypoint_dim_reduced = 512  # == keypoint_layers[-1]
-            line_predictor = LinePredictor(in_channels=128)
+            line_predictor = LinePredictor(in_channels=256)
 
 
 
 
         self.roi_heads.line_roi_pool = line_roi_pool
         self.roi_heads.line_roi_pool = line_roi_pool

+ 70 - 30
models/line_detect/loi_heads.py

@@ -795,6 +795,7 @@ class RoIHeads(nn.Module):
                 labels = None
                 labels = None
                 regression_targets = None
                 regression_targets = None
                 matched_idxs = None
                 matched_idxs = None
+        device=features['0'].device
 
 
         box_features = self.box_roi_pool(features, proposals, image_shapes)
         box_features = self.box_roi_pool(features, proposals, image_shapes)
         box_features = self.box_head(box_features)
         box_features = self.box_head(box_features)
@@ -893,31 +894,7 @@ class RoIHeads(nn.Module):
                 else:
                 else:
                     pos_matched_idxs = None
                     pos_matched_idxs = None
 
 
-            print(f'line_proposals:{len(line_proposals)}')
-
-            # cs_features= features['0']
-            print(f'features-0:{features['0'].shape}')
-            cs_features = self.channel_compress(features['0'])
-
-
-
-            filtered_proposals = [proposal for proposal in line_proposals if proposal.shape[0] > 0]
-            if len(filtered_proposals) > 0:
-                filtered_proposals_tensor=torch.cat(filtered_proposals)
-                print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
-
-            line_proposals_tensor=torch.cat(line_proposals)
-
-            print(f'line_proposals_tensor:{line_proposals_tensor.shape}')
-
-            roi_features = features_align(cs_features, line_proposals, image_shapes)
-
-            if roi_features is not None:
-                print(f'line_features from align:{roi_features.shape}')
-
-            feature_logits = self.line_head(roi_features)
-            print(f'feature_logits from line_head:{feature_logits.shape}')
-
+            feature_logits = self.head_forward3(features, image_shapes, line_proposals)
 
 
             loss_line = None
             loss_line = None
             loss_line_iou =None
             loss_line_iou =None
@@ -951,11 +928,11 @@ class RoIHeads(nn.Module):
 
 
                 if  loss_line is None:
                 if  loss_line is None:
                     print(f'loss_line is None111')
                     print(f'loss_line is None111')
-                    loss_line = torch.tensor(0.0, device=cs_features.device)
+                    loss_line = torch.tensor(0.0, device=device)
 
 
                 if loss_line_iou is None:
                 if loss_line_iou is None:
                     print(f'loss_line_iou is None111')
                     print(f'loss_line_iou is None111')
-                    loss_line_iou = torch.tensor(0.0, device=cs_features.device)
+                    loss_line_iou = torch.tensor(0.0, device=device)
 
 
                 loss_line = {"loss_line": loss_line}
                 loss_line = {"loss_line": loss_line}
                 loss_line_iou = {'loss_line_iou': loss_line_iou}
                 loss_line_iou = {'loss_line_iou': loss_line_iou}
@@ -981,14 +958,13 @@ class RoIHeads(nn.Module):
                                                       img_size)
                                                       img_size)
 
 
 
 
-
                     if  loss_line is None:
                     if  loss_line is None:
                         print(f'loss_line is None')
                         print(f'loss_line is None')
-                        loss_line=torch.tensor(0.0,device=cs_features.device)
+                        loss_line=torch.tensor(0.0,device=device)
 
 
                     if  loss_line_iou is None:
                     if  loss_line_iou is None:
                         print(f'loss_line_iou is None')
                         print(f'loss_line_iou is None')
-                        loss_line_iou=torch.tensor(0.0,device=cs_features.device)
+                        loss_line_iou=torch.tensor(0.0,device=device)
 
 
 
 
                     loss_line = {"loss_line": loss_line}
                     loss_line = {"loss_line": loss_line}
@@ -1106,3 +1082,67 @@ class RoIHeads(nn.Module):
             losses.update(loss_keypoint)
             losses.update(loss_keypoint)
 
 
         return result, losses
         return result, losses
+
+    def head_forward1(self, features, image_shapes, line_proposals):
+        print(f'line_proposals:{len(line_proposals)}')
+        # cs_features= features['0']
+        print(f'features-0:{features['0'].shape}')
+        cs_features = self.channel_compress(features['0'])
+        filtered_proposals = [proposal for proposal in line_proposals if proposal.shape[0] > 0]
+        if len(filtered_proposals) > 0:
+            filtered_proposals_tensor = torch.cat(filtered_proposals)
+            print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
+        line_proposals_tensor = torch.cat(line_proposals)
+        print(f'line_proposals_tensor:{line_proposals_tensor.shape}')
+        roi_features = features_align(cs_features, line_proposals, image_shapes)
+        if roi_features is not None:
+            print(f'line_features from align:{roi_features.shape}')
+        feature_logits = self.line_head(roi_features)
+        print(f'feature_logits from line_head:{feature_logits.shape}')
+        return feature_logits
+
+    def head_forward2(self, features, image_shapes, line_proposals):
+        print(f'line_proposals:{len(line_proposals)}')
+        # cs_features= features['0']
+        print(f'features-0:{features['0'].shape}')
+        # cs_features = self.channel_compress(features['0'])
+        cs_features=features['0']
+        filtered_proposals = [proposal for proposal in line_proposals if proposal.shape[0] > 0]
+
+        if len(filtered_proposals) > 0:
+            filtered_proposals_tensor = torch.cat(filtered_proposals)
+            print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
+            line_proposals=filtered_proposals
+        line_proposals_tensor = torch.cat(line_proposals)
+        print(f'line_proposals_tensor:{line_proposals_tensor.shape}')
+
+        feature_logits = self.line_head(cs_features)
+        print(f'feature_logits from line_head:{feature_logits.shape}')
+
+        roi_features = features_align(cs_features, line_proposals, image_shapes)
+        if roi_features is not None:
+            print(f'roi_features from align:{roi_features.shape}')
+        return roi_features
+
+    def head_forward3(self, features, image_shapes, line_proposals):
+        print(f'line_proposals:{len(line_proposals)}')
+        # cs_features= features['0']
+        print(f'features-0:{features['0'].shape}')
+        # cs_features = self.channel_compress(features['0'])
+        cs_features=features['0']
+        filtered_proposals = [proposal for proposal in line_proposals if proposal.shape[0] > 0]
+
+        if len(filtered_proposals) > 0:
+            filtered_proposals_tensor = torch.cat(filtered_proposals)
+            print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
+            line_proposals=filtered_proposals
+        line_proposals_tensor = torch.cat(line_proposals)
+        print(f'line_proposals_tensor:{line_proposals_tensor.shape}')
+
+        feature_logits = self.line_predictor(cs_features)
+        print(f'feature_logits from line_head:{feature_logits.shape}')
+
+        roi_features = features_align(cs_features, line_proposals, image_shapes)
+        if roi_features is not None:
+            print(f'roi_features from align:{roi_features.shape}')
+        return roi_features

+ 2 - 2
models/line_detect/train.yaml

@@ -1,7 +1,7 @@
 io:
 io:
   logdir: train_results
   logdir: train_results
-  datadir: /data/share/zjh/Dataset_correct_xanylabel
-#  datadir: \\192.168.50.222/share/rlq/datasets/Dataset_correct_xanylabel
+#  datadir: /data/share/zjh/Dataset_correct_xanylabel
+  datadir: \\192.168.50.222/share/rlq/datasets/250718caisegangban
   data_type: rgb
   data_type: rgb
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
 #  datadir: I:\datasets\wirenet_1000

+ 4 - 3
models/line_detect/train_demo.py

@@ -18,10 +18,11 @@ if __name__ == '__main__':
 
 
     # model=linedetect_resnet18_fpn()
     # model=linedetect_resnet18_fpn()
     # model=linedetect_newresnet18fpn(num_points=3)
     # model=linedetect_newresnet18fpn(num_points=3)
+    # model=linedetect_newresnet50fpn(num_points=3)
     # model = linedetect_newresnet101fpn(num_points=3)
     # model = linedetect_newresnet101fpn(num_points=3)
     # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250711_114046/weights/best_val.pth')
     # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250711_114046/weights/best_val.pth')
     # model=linedetect_maxvitfpn()
     # model=linedetect_maxvitfpn()
-    model=linedetect_high_maxvitfpn()
-    model.load_weights(r'/data/share/rlq/weights/250718maxvit_best_val.pth')
-    # model=linedetect_swin_transformer_fpn(type='t')
+    # model=linedetect_high_maxvitfpn()
+    # model.load_weights(r'/data/share/rlq/weights/250718maxvit_best_val.pth')
+    model=linedetect_swin_transformer_fpn(type='t')
     model.start_train(cfg='train.yaml')
     model.start_train(cfg='train.yaml')