Procházet zdrojové kódy

添加swin_trans_former

RenLiqiang před 5 měsíci
rodič
revize
ad42480896

+ 3 - 3
libs/vision_libs/models/detection/roi_heads.py

@@ -830,9 +830,9 @@ class RoIHeads(nn.Module):
         # keep none checks in if conditional so torchscript will conditionally
         # compile each branch
         if (
-            self.line_roi_pool is not None
-            and self.line_head is not None
-            and self.line_predictor is not None
+            self.keypoint_roi_pool is not None
+            and self.keypoint_head is not None
+            and self.keypoint_predictor is not None
         ):
             keypoint_proposals = [p["boxes"] for p in result]
             if self.training:

+ 40 - 41
models/base/backbone_factory.py

@@ -18,7 +18,9 @@ from libs.vision_libs.ops import misc as misc_nn_ops, MultiScaleRoIAlign
 from torch import nn
 
 import torch
-# from  libs.vision_libs.models.detection.backbone_utils import BackboneWithFPN
+
+
+
 
 
 def get_resnet50_fpn():
@@ -188,8 +190,7 @@ class MaxVitBackbone(torch.nn.Module):
 
 from torchvision.models.feature_extraction import create_feature_extractor
 
-
-if __name__ == '__main__':
+def get_swin_transformer_fpn(type='t'):
     class Trans(nn.Module):
         def __init__(self):
             super().__init__()
@@ -199,10 +200,16 @@ if __name__ == '__main__':
 
 
     class SwinTransformer(nn.Module):
-        def __init__(self):
+        def __init__(self,type='t'):
             super().__init__()
-            # 加载 Swin Transformer v2 Tiny
             swin = torchvision.models.swin_v2_t(weights=None)
+            if type=='t':
+                # 加载 Swin Transformer v2 Tiny
+                swin = torchvision.models.swin_v2_t(weights=None)
+            if type=='s':
+                swin=torchvision.models.swin_v2_s(weights=None)
+            if type=='b':
+                swin=torchvision.models.swin_v2_b(weights=None)
 
             # 保存需要提取的层
             self.patch_embed = swin.features[0]  # 第0层 patch embedding
@@ -215,54 +222,22 @@ if __name__ == '__main__':
             self.layer7 =nn.Sequential(swin.features[7], Trans())  # 第7层 stage4
 
         def forward(self, x):
-            outputs = {}
 
-            # Patch Embedding
             x = self.patch_embed(x)  # [B, C, H, W] -> [B, H_, W_, C]
-
-            # Layer 1: stage1
             x = self.layer1(x)
-
-            # if 'feat1' not in outputs:
-            #     feat = x.permute(0, 3, 1, 2).contiguous()  # NHWC -> NCHW
-            #     outputs['feat1'] = feat
             print(f'x1:{x.shape}')
-            # Downsample 1
             x = self.layer2(x)
-
-            # Layer 2: stage2
             x = self.layer3(x)
-            # if 'feat2' not in outputs:
-            #     feat = x.permute(0, 3, 1, 2).contiguous()
-            #     outputs['feat2'] = feat
             print(f'x2:{x.shape}')
-
-            # Downsample 2
             x = self.layer4(x)
-
-            # Layer 3: stage3
             x = self.layer5(x)
             print(f'x3:{x.shape}')
-
-            # if 'feat3' not in outputs:
-            #     feat = x.permute(0, 3, 1, 2).contiguous()
-            #     outputs['feat3'] = feat
-
-            # Downsample 3
             x = self.layer6(x)
-
-            # Layer 4: stage4
             x = self.layer7(x)
-            # x = x.permute(0, 3, 2, 1).contiguous()
-
-            # if 'feat4' not in outputs:
-            #     feat = x.permute(0, 3, 1, 2).contiguous()
-            #     outputs['feat4'] = feat
             print(f'x4:{x.shape}')
-
             return x
 
-    backbone = SwinTransformer()
+    backbone = SwinTransformer(type=type)
     input=torch.randn(1,3,512,512)
     out=backbone(input)
     # print(f'out:{out.keys()}')
@@ -271,18 +246,42 @@ if __name__ == '__main__':
     # out=swin(input)
     # print(f'out shape:{out.shape}')
     #
+
+    channels_list = [96, 192, 384, 768]
+    if type=='t':
+        channels_list = [96, 192, 384, 768]
+    if type=='s':
+        channels_list = [96, 192, 384, 768]
+    if type=='b':
+        channels_list = [128, 256, 512, 1024]
     backbone_with_fpn = BackboneWithFPN(
         # swin.features,
         backbone,
         return_layers={'layer1': '0', 'layer3': '1', 'layer5': '2', 'layer7': '3'},
-        in_channels_list=[96, 192, 384, 768],
+        in_channels_list=channels_list,
         out_channels=256
     )
+    featmap_names = ['0', '1', '2', '3', 'pool']
+    # print(f'featmap_names:{featmap_names}')
+    roi_pooler = MultiScaleRoIAlign(
+        featmap_names=featmap_names,
+        output_size=7,
+        sampling_ratio=2
+    )
 
-    out=backbone_with_fpn(input)
+    # out=backbone_with_fpn(input)
+    anchor_generator = get_anchor_generator(backbone_with_fpn, test_input=input)
 
-    print(f'out:{out}')
 
+    # print(f'out:{out}')
+    return  backbone_with_fpn,roi_pooler,anchor_generator
+if __name__ == '__main__':
+    backbone_with_fpn, roi_pooler, anchor_generator=get_swin_transformer_fpn(type='s')
+    model=FasterRCNN(backbone=backbone_with_fpn,num_classes=3,box_roi_pool=roi_pooler,rpn_anchor_generator=anchor_generator)
+    input=torch.randn(3,3,512,512,device='cuda')
+    model.eval()
+    model.to('cuda')
+    out=model(input)
 
 
     # # maxvit = models.maxvit_t(pretrained=True)

+ 28 - 1
models/line_detect/line_detect.py

@@ -27,7 +27,8 @@ from .loi_heads import RoIHeads
 
 from .trainer import Trainer
 from ..base import backbone_factory
-from ..base.backbone_factory import get_convnext_fpn, get_anchor_generator, get_maxvit_fpn, MaxVitBackbone
+from ..base.backbone_factory import get_convnext_fpn, get_anchor_generator, get_maxvit_fpn, MaxVitBackbone, \
+    get_swin_transformer_fpn
 # from ..base.backbone_factory import get_convnext_fpn, get_anchor_generator
 from ..base.base_detection_net import BaseDetectionNet
 import torch.nn.functional as F
@@ -475,6 +476,32 @@ def linedetect_high_maxvitfpn(
     )
     return model
 
+def linedetect_swin_transformer_fpn(
+        *,
+        num_classes: Optional[int] = None,
+        num_points: Optional[int] = None,
+        type='t',
+        **kwargs: Any,
+) -> LineDetect:
+    """Build a LineDetect model on the backbone from get_swin_transformer_fpn(type=type)."""
+    # NOTE: num_points and **kwargs are accepted but not forwarded to LineDetect here.
+    if num_classes is None:
+        num_classes = 3
+    if num_points is None:
+        num_points = 3
+    size = 512
+    backbone_with_fpn, roi_pooler, anchor_generator = get_swin_transformer_fpn(type=type)
+    # min_size and max_size are both pinned to the same fixed value (512).
+    model = LineDetect(
+        backbone=backbone_with_fpn,
+        min_size=size,
+        max_size=size,
+        num_classes=num_classes,  # previously hard-coded to 3, ignoring the argument
+        rpn_anchor_generator=anchor_generator,
+        box_roi_pool=roi_pooler
+    )
+    return model
+
 def linedetect_resnet18_fpn(
         *,
         num_classes: Optional[int] = None,

+ 3 - 2
models/line_detect/train_demo.py

@@ -1,7 +1,7 @@
 import torch
 
 from models.line_detect.line_detect import linedetect_newresnet18fpn, linedetect_resnet50_fpn, linedetect_resnet18_fpn, \
-    linedetect_newresnet50fpn, linedetect_maxvitfpn, linedetect_high_maxvitfpn
+    linedetect_newresnet50fpn, linedetect_maxvitfpn, linedetect_high_maxvitfpn, linedetect_swin_transformer_fpn
 
 from models.line_net.trainer import Trainer
 
@@ -20,5 +20,6 @@ if __name__ == '__main__':
     # model = linedetect_newresnet50fpn(num_points=3)
     # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250711_114046/weights/best_val.pth')
     # model=linedetect_maxvitfpn()
-    model=linedetect_high_maxvitfpn()
+    # model=linedetect_high_maxvitfpn()
+    model=linedetect_swin_transformer_fpn(type='s')
     model.start_train(cfg='train.yaml')