浏览代码

add maxvit (unfinished)

RenLiqiang 7 月之前
父节点
当前提交
7b8fb177d4
共有 1 个文件被更改,包括 44 次插入47 次删除
  1. 44 47
      models/base/backbone_factory.py

+ 44 - 47
models/base/backbone_factory.py

@@ -125,62 +125,59 @@ def get_anchor_generator(backbone, test_input):
 
     return AnchorGenerator(sizes=anchor_sizes , aspect_ratios=aspect_ratios)
 
+
+class MaxVitBackbone(torch.nn.Module):
+    def __init__(self):
+        super(MaxVitBackbone, self).__init__()
+        # Use a subset of MaxViT's layers as the feature extractor
+        maxvit_model = models.maxvit_t(pretrained=True)
+        self.stem = maxvit_model.stem  # stem layer
+        self.block0= maxvit_model.blocks[0]
+        self.block1 = maxvit_model.blocks[1]
+        self.block2 = maxvit_model.blocks[2]
+        self.block3 = maxvit_model.blocks[3]
+
+
+    # forward() applies stem, then block0..block3 in order, and returns only the final output.
+    def forward(self, x):
+        # features = {}
+        x = self.stem(x)
+        x=self.block0(x)
+        x = self.block1(x)
+        x = self.block2(x)
+        x = self.block3(x)
+        return x
+
+
 if __name__ == '__main__':
-    # Create the ConvNeXt backbone
-    convnext = models.convnext_base(pretrained=True)
-    for i,layer in enumerate(convnext.features):
-        print(f'layer{i}:{layer}')
-    # Create a small input tensor used to obtain each layer's output channel count
-    dummy_input = torch.randn(1, 3, 224, 224)
-    # output_channels_list = get_output_channels(convnext.features, dummy_input)
-    # print(f'output_channels_list:{output_channels_list}')
-
-    # Based on earlier experience, choose suitable layer indices
-    selected_layers = [3, 5, 7]  # assume these are the layer indices to use as FPN inputs
-    in_channels_list = [128,256,512,1024]
-    print(f'in_channels_list:{in_channels_list}')
-
-    # Create the FPN
+    # maxvit = models.maxvit_t(pretrained=True)
+    maxvit=MaxVitBackbone()
+    # print(maxvit.named_children())
+
+    for i,layer in enumerate(maxvit.named_children()):
+        print(f'layer:{i}:{layer}')
+
+    in_channels_list = [64,64,128, 256, 512]
     backbone_with_fpn = BackboneWithFPN(
-        convnext.features,
-        return_layers={'1':'0','3': '1', '5': '2', '7': '3'},  # make sure these keys correspond to actual layers
+        maxvit,
+        return_layers={'stem': '0','block0':'1','block1':'2','block2':'3','block3':'4'},  # make sure these keys correspond to actual layers
         in_channels_list=in_channels_list,
         out_channels=256
     )
-
-    # Create the Faster R-CNN model
-    # anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
-    #                                    aspect_ratios=((0.5, 1.0, 2.0),))
-    # anchor_generator = AnchorGenerator(
-    #     sizes=((32,), (64,), (128,)),  # ✅ correct
-    #     aspect_ratios=((0.5, 1.0, 2.0),) * 3  # ✅ correct
-    # )
-    test_input = torch.rand(1, 3, 224, 224)
-    anchor_generator = get_anchor_generator(backbone_with_fpn, test_input)
-    print(f'anchor_generator:{anchor_generator}')
-
-    featmap_names=['0', '1', '2', '3', 'pool']
-    roi_pooler = MultiScaleRoIAlign(
-        featmap_names=featmap_names,
-        output_size=7,
-        sampling_ratio=2
-    )
-
     model = FasterRCNN(
         backbone=backbone_with_fpn,
         num_classes=91,  # the COCO dataset has 91 classes
-        rpn_anchor_generator=anchor_generator,
-        box_roi_pool=roi_pooler
+        # rpn_anchor_generator=anchor_generator,
+        # box_roi_pool=roi_pooler
     )
 
-    # Test the model
-    test_input = torch.randn(1, 3, 800, 800)  # note: the input size should meet Faster R-CNN's requirements
-    model.eval()
-    output = model(test_input)
-    print(f'output: {output}')
+    test_input = torch.randn(1, 3, 896, 896)
+
+    with torch.no_grad():
+        output = backbone_with_fpn(test_input)
 
-    # Test the model
-    dummy_input = torch.randn(1, 3, 800, 800)  # note: the input size should meet Faster R-CNN's requirements
+    print("Output feature maps:")
+    for k, v in output.items():
+        print(f"{k}: {v.shape}")
     model.eval()
-    output = model(dummy_input)
-    print(f'output:{output}')
+    output=model(test_input)