Parcourir la source

加入maxvitfpn backbone

lstrlq il y a 5 mois
Parent
commit
7972128d83

+ 41 - 8
models/base/backbone_factory.py

@@ -1,5 +1,7 @@
 from collections import OrderedDict
 
+from torchvision.models import maxvit_t
+
 from libs.vision_libs import models
 from libs.vision_libs.models import mobilenet_v3_large, EfficientNet_V2_S_Weights, efficientnet_v2_s, \
     EfficientNet_V2_M_Weights, efficientnet_v2_m, EfficientNet_V2_L_Weights, efficientnet_v2_l, ConvNeXt_Base_Weights
@@ -58,6 +60,33 @@ def get_convnext_fpn():
     )
     return backbone_with_fpn
 
+def get_maxvit_fpn(input_size=(224*7,224*7)):
+    maxvit = MaxVitBackbone(input_size=input_size)
+    # print(maxvit.named_children())
+
+    # for i,layer in enumerate(maxvit.named_children()):
+    #     print(f'layer:{i}:{layer}')
+
+    test_input = torch.randn(1, 3, 224 * 7, 224 * 7)
+
+    in_channels_list = [64, 64, 128, 256, 512]
+    featmap_names = ['0', '1', '2', '3', '4', 'pool']
+    # print(f'featmap_names:{featmap_names}')
+    roi_pooler = MultiScaleRoIAlign(
+        featmap_names=featmap_names,
+        output_size=7,
+        sampling_ratio=2
+    )
+    backbone_with_fpn = BackboneWithFPN(
+        maxvit,
+        return_layers={'stem': '0', 'block0': '1', 'block1': '2', 'block2': '3', 'block3': '4'},  # 确保这些键对应到实际的层
+        in_channels_list=in_channels_list,
+        out_channels=256
+    )
+    rpn_anchor_generator = get_anchor_generator(backbone_with_fpn, test_input=test_input),
+
+    return  backbone_with_fpn,rpn_anchor_generator,roi_pooler
+
 def get_efficientnetv2_fpn(name='efficientnet_v2_m', pretrained=True):
     # 加载EfficientNetV2模型
     if name == 'efficientnet_v2_s':
@@ -130,28 +159,32 @@ def get_anchor_generator(backbone, test_input):
 
 
 class MaxVitBackbone(torch.nn.Module):
-    def __init__(self):
+    def __init__(self,input_size=(224*7,224*7)):
         super(MaxVitBackbone, self).__init__()
-        # 提取MaxVit的部分层作为特征提取器
-        maxvit_model = models.maxvit_t(pretrained=True)
-        self.stem = maxvit_model.stem  # Stem
+        # 提取MaxVit的部分层作为特征提取器
+        maxvit_model =maxvit_t(pretrained=False,input_size=input_size)
+        self.stem = maxvit_model.stem  # Stem层
         self.block0= maxvit_model.blocks[0]
         self.block1 = maxvit_model.blocks[1]
         self.block2 = maxvit_model.blocks[2]
         self.block3 = maxvit_model.blocks[3]
 
-
-
     def forward(self, x):
-        # features = {}
+        print("Input size:", x.shape)
         x = self.stem(x)
-        x=self.block0(x)
+        print("After stem size:", x.shape)
+        x = self.block0(x)
+        print("After block0 size:", x.shape)
         x = self.block1(x)
+        print("After block1 size:", x.shape)
         x = self.block2(x)
+        print("After block2 size:", x.shape)
         x = self.block3(x)
+        print("After block3 size:", x.shape)
         return x
 
 
+
 if __name__ == '__main__':
     # maxvit = models.maxvit_t(pretrained=True)
     maxvit=MaxVitBackbone()

+ 47 - 1
models/line_detect/line_detect.py

@@ -26,7 +26,7 @@ from .loi_heads import RoIHeads
 
 from .trainer import Trainer
 from ..base import backbone_factory
-from ..base.backbone_factory import get_convnext_fpn, get_anchor_generator
+from ..base.backbone_factory import get_convnext_fpn, get_anchor_generator, get_maxvit_fpn, MaxVitBackbone
 # from ..base.backbone_factory import get_convnext_fpn, get_anchor_generator
 from ..base.base_detection_net import BaseDetectionNet
 import torch.nn.functional as F
@@ -399,6 +399,52 @@ def linedetect_newresnet101fpn(
 
     return model
 
+def linedetect_maxvitfpn(
+        *,
+        num_classes: Optional[int] = None,
+        num_points:Optional[int] = None,
+        **kwargs: Any,
+) -> LineDetect:
+    # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
+    # weights_backbone = ResNet50_Weights.verify(weights_backbone)
+    if num_classes is None:
+        num_classes = 3
+    if num_points is None:
+        num_points = 3
+
+    maxvit = MaxVitBackbone(input_size=(224*2,224*2))
+    # print(maxvit.named_children())
+
+    # for i,layer in enumerate(maxvit.named_children()):
+    #     print(f'layer:{i}:{layer}')
+
+    in_channels_list = [64, 64, 128, 256, 512]
+    featmap_names = ['0', '1', '2', '3', '4', 'pool']
+    # print(f'featmap_names:{featmap_names}')
+    roi_pooler = MultiScaleRoIAlign(
+        featmap_names=featmap_names,
+        output_size=7,
+        sampling_ratio=2
+    )
+    backbone_with_fpn = BackboneWithFPN(
+        maxvit,
+        return_layers={'stem': '0', 'block0': '1', 'block1': '2', 'block2': '3', 'block3': '4'},
+        # 确保这些键对应到实际的层
+        in_channels_list=in_channels_list,
+        out_channels=256
+    )
+    test_input = torch.randn(1, 3, 224 * 2, 224 * 2)
+
+    model = LineDetect(
+        backbone=backbone_with_fpn,
+        min_size=224 * 2,
+        max_size=224 * 2,
+        num_classes=91,  # COCO 数据集有 91 类
+        rpn_anchor_generator=get_anchor_generator(backbone_with_fpn, test_input=test_input),
+        box_roi_pool=roi_pooler
+    )
+    return model
+
 
 
 def linedetect_resnet18_fpn(

+ 1 - 1
models/line_detect/loi_heads.py

@@ -1425,7 +1425,7 @@ class RoIHeads(nn.Module):
                     }
                 )
 
-        if self.has_line():
+        if  self.has_line():
             print(f'roi_heads forward has_line()!!!!')
             # print(f'labels:{labels}')
             line_proposals = [p["boxes"] for p in result]

+ 3 - 3
models/line_detect/train_demo.py

@@ -1,7 +1,7 @@
 import torch
 
 from models.line_detect.line_detect import linedetect_newresnet18fpn, linedetect_resnet50_fpn, linedetect_resnet18_fpn, \
-    linedetect_newresnet50fpn
+    linedetect_newresnet50fpn, linedetect_maxvitfpn
 
 from models.line_net.trainer import Trainer
 
@@ -16,8 +16,8 @@ if __name__ == '__main__':
     # model = lineDetect_resnet18_fpn()
 
     # model=linedetect_resnet18_fpn()
-    model=linedetect_newresnet50fpn(num_points=3)
+    # model=linedetect_newresnet50fpn(num_points=3)
     # model = linedetect_newresnet50fpn(num_points=3)
     # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250711_114046/weights/best_val.pth')
-
+    model=linedetect_maxvitfpn()
     model.start_train(cfg='train.yaml')