Browse Source

添加swintransformer_fpn测试代码

RenLiqiang 5 tháng trước
mục cha
commit
f0bc02d23d
1 tập tin đã thay đổi với 125 bổ sung và 26 xóa
  1. 125 26
      models/base/backbone_factory.py

+ 125 - 26
models/base/backbone_factory.py

@@ -1,6 +1,8 @@
 from collections import OrderedDict
 
+import torchvision
 from torchvision.models import maxvit_t
+from torchvision.models.detection.backbone_utils import BackboneWithFPN
 
 from libs.vision_libs import models
 from libs.vision_libs.models import mobilenet_v3_large, EfficientNet_V2_S_Weights, efficientnet_v2_s, \
@@ -16,7 +18,7 @@ from libs.vision_libs.ops import misc as misc_nn_ops, MultiScaleRoIAlign
 from torch import nn
 
 import torch
-from  libs.vision_libs.models.detection.backbone_utils import BackboneWithFPN
+# from  libs.vision_libs.models.detection.backbone_utils import BackboneWithFPN
 
 
 def get_resnet50_fpn():
@@ -184,36 +186,133 @@ class MaxVitBackbone(torch.nn.Module):
         return x
 
 
+from torchvision.models.feature_extraction import create_feature_extractor
 
-if __name__ == '__main__':
-    # maxvit = models.maxvit_t(pretrained=True)
-    maxvit=MaxVitBackbone()
-    # print(maxvit.named_children())
-
-    for i,layer in enumerate(maxvit.named_children()):
-        print(f'layer:{i}:{layer}')
 
-    in_channels_list = [64,64,128, 256, 512]
+if __name__ == '__main__':
+    class Trans(nn.Module):  # Parameter-free module that reorders the axes of a 4-D tensor to (0, 3, 2, 1).
+        def __init__(self):
+            super().__init__()
+        def forward(self,x):
+            x=x.permute(0, 3, 2, 1).contiguous()  # NOTE(review): besides moving dim 3 to dim 1 this also swaps dims 1 and 2; a plain NHWC -> NCHW would be (0, 3, 1, 2) - confirm the spatial transpose is intended. Applying this permutation twice restores the original order.
+            return x
+
+
+    class SwinTransformer(nn.Module):  # Exposes the entries of torchvision swin_v2_t().features as named sub-modules patch_embed, layer1..layer7.
+        def __init__(self):
+            super().__init__()
+            # Build Swin Transformer v2 Tiny (weights=None is passed, so no pretrained checkpoint is requested)
+            swin = torchvision.models.swin_v2_t(weights=None)
+
+            # Keep the layers whose outputs need to be extracted; stage outputs are followed by Trans(), downsample inputs are preceded by Trans()
+            self.patch_embed = swin.features[0]  # layer 0: patch embedding
+            self.layer1 =nn.Sequential(swin.features[1],Trans())  # layer 1: stage1
+            self.layer2 =nn.Sequential(Trans(),swin.features[2]) # layer 2: downsample
+            self.layer3 =nn.Sequential(swin.features[3], Trans()) # layer 3: stage2
+            self.layer4 =nn.Sequential( Trans(),swin.features[4])  # layer 4: downsample
+            self.layer5 =nn.Sequential(swin.features[5], Trans())  # layer 5: stage3
+            self.layer6 =nn.Sequential(Trans(),swin.features[6]) # layer 6: downsample
+            self.layer7 =nn.Sequential(swin.features[7], Trans())  # layer 7: stage4
+
+        def forward(self, x):
+            outputs = {}  # never filled or returned; only the commented-out code below refers to it
+
+            # Patch Embedding
+            x = self.patch_embed(x)  # [B, C, H, W] -> [B, H_, W_, C]
+
+            # Layer 1: stage1
+            x = self.layer1(x)
+
+            # if 'feat1' not in outputs:
+            #     feat = x.permute(0, 3, 1, 2).contiguous()  # NHWC -> NCHW
+            #     outputs['feat1'] = feat
+            print(f'x1:{x.shape}')  # debug output: shape after layer1
+            # Downsample 1
+            x = self.layer2(x)
+
+            # Layer 2: stage2
+            x = self.layer3(x)
+            # if 'feat2' not in outputs:
+            #     feat = x.permute(0, 3, 1, 2).contiguous()
+            #     outputs['feat2'] = feat
+            print(f'x2:{x.shape}')  # debug output: shape after layer3
+
+            # Downsample 2
+            x = self.layer4(x)
+
+            # Layer 3: stage3
+            x = self.layer5(x)
+            print(f'x3:{x.shape}')  # debug output: shape after layer5
+
+            # if 'feat3' not in outputs:
+            #     feat = x.permute(0, 3, 1, 2).contiguous()
+            #     outputs['feat3'] = feat
+
+            # Downsample 3
+            x = self.layer6(x)
+
+            # Layer 4: stage4
+            x = self.layer7(x)
+            # x = x.permute(0, 3, 2, 1).contiguous()
+
+            # if 'feat4' not in outputs:
+            #     feat = x.permute(0, 3, 1, 2).contiguous()
+            #     outputs['feat4'] = feat
+            print(f'x4:{x.shape}')  # debug output: shape after layer7
+
+            return x  # only the layer7 result is returned, as a single tensor
+
+    backbone = SwinTransformer()
+    input=torch.randn(1,3,512,512)  # random test batch; NOTE(review): the name shadows the builtin input()
+    out=backbone(input)  # direct forward pass; triggers the shape prints inside SwinTransformer.forward
+    # print(f'out:{out.keys()}')
+    # for i,layer in enumerate(swin.features.named_children()):
+    #     print(f'layer:{i}:{layer}')
+    # out=swin(input)
+    # print(f'out shape:{out.shape}')
+    #
     backbone_with_fpn = BackboneWithFPN(
-        maxvit,
-        return_layers={'stem': '0','block0':'1','block1':'2','block2':'3','block3':'4'},  # make sure these keys correspond to actual layers
-        in_channels_list=in_channels_list,
+        # swin.features,
+        backbone,
+        return_layers={'layer1': '0', 'layer3': '1', 'layer5': '2', 'layer7': '3'},  # keys are attribute names assigned in SwinTransformer.__init__
+        in_channels_list=[96, 192, 384, 768],  # presumably the channel widths of the four stage outputs - TODO confirm against swin_v2_t
         out_channels=256
     )
-    model = FasterRCNN(
-        backbone=backbone_with_fpn,
-        num_classes=91,  # the COCO dataset has 91 classes
-        # rpn_anchor_generator=anchor_generator,
-        # box_roi_pool=roi_pooler
-    )
 
-    test_input = torch.randn(1, 3, 896, 896)
+    out=backbone_with_fpn(input)  # rebinds out to the BackboneWithFPN result
+
+    print(f'out:{out}')
 
-    with torch.no_grad():
-        output = backbone_with_fpn(test_input)
 
-    print("Output feature maps:")
-    for k, v in output.items():
-        print(f"{k}: {v.shape}")
-    model.eval()
-    output=model(test_input)
+
+    # # maxvit = models.maxvit_t(pretrained=True)
+    # maxvit=MaxVitBackbone()
+    # # print(maxvit.named_children())
+    #
+    # for i,layer in enumerate(maxvit.named_children()):
+    #     print(f'layer:{i}:{layer}')
+    #
+    # in_channels_list = [64,64,128, 256, 512]
+    # backbone_with_fpn = BackboneWithFPN(
+    #     maxvit,
+    #     return_layers={'stem': '0','block0':'1','block1':'2','block2':'3','block3':'4'},  # make sure these keys correspond to actual layers
+    #     in_channels_list=in_channels_list,
+    #     out_channels=256
+    # )
+    # model = FasterRCNN(
+    #     backbone=backbone_with_fpn,
+    #     num_classes=91,  # the COCO dataset has 91 classes
+    #     # rpn_anchor_generator=anchor_generator,
+    #     # box_roi_pool=roi_pooler
+    # )
+    #
+    # test_input = torch.randn(1, 3, 896, 896)
+    #
+    # with torch.no_grad():
+    #     output = backbone_with_fpn(test_input)
+    #
+    # print("Output feature maps:")
+    # for k, v in output.items():
+    #     print(f"{k}: {v.shape}")
+    # model.eval()
+    # output=model(test_input)