backbone_factory.py 8.3 KB


  1. from collections import OrderedDict
  2. from torchvision.models import maxvit_t
  3. from libs.vision_libs import models
  4. from libs.vision_libs.models import mobilenet_v3_large, EfficientNet_V2_S_Weights, efficientnet_v2_s, \
  5. EfficientNet_V2_M_Weights, efficientnet_v2_m, EfficientNet_V2_L_Weights, efficientnet_v2_l, ConvNeXt_Base_Weights
  6. from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
  7. from libs.vision_libs.models.detection import FasterRCNN
  8. from libs.vision_libs.models.detection.anchor_utils import AnchorGenerator
  9. from libs.vision_libs.models.detection.ssdlite import _mobilenet_extractor
  10. from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights, resnet18
  11. from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
  12. from libs.vision_libs.ops import misc as misc_nn_ops, MultiScaleRoIAlign
  13. from torch import nn
  14. import torch
  15. from libs.vision_libs.models.detection.backbone_utils import BackboneWithFPN
  16. def get_resnet50_fpn():
  17. is_trained = False
  18. trainable_backbone_layers = _validate_trainable_layers(is_trained, None, 5, 3)
  19. norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
  20. backbone = resnet50(weights=None, progress=True, norm_layer=norm_layer)
  21. backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
  22. return backbone
  23. def get_resnet18_fpn():
  24. is_trained = False
  25. trainable_backbone_layers = _validate_trainable_layers(is_trained, None, 5, 3)
  26. norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
  27. backbone = resnet18(weights=None, progress=True, norm_layer=norm_layer)
  28. backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
  29. return backbone
  30. def get_mobilenet_v3_large_fpn():
  31. is_trained = False
  32. trainable_backbone_layers = _validate_trainable_layers(is_trained, None, 6, 3)
  33. norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
  34. backbone = mobilenet_v3_large(weights=None, progress=True, norm_layer=norm_layer)
  35. backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
  36. return backbone
  37. def get_convnext_fpn():
  38. convnext = models.convnext_base(weights=ConvNeXt_Base_Weights.DEFAULT)
  39. # convnext = models.convnext_small(pretrained=True)
  40. # convnext = models.convnext_large(pretrained=True)
  41. in_channels_list = [128, 256, 512, 1024]
  42. backbone_with_fpn = BackboneWithFPN(
  43. convnext.features,
  44. return_layers={'1': '0', '3': '1', '5': '2', '7': '3'}, # 确保这些键对应到实际的层
  45. in_channels_list=in_channels_list,
  46. out_channels=256
  47. )
  48. return backbone_with_fpn
  49. def get_maxvit_fpn(input_size=(224*7,224*7)):
  50. maxvit = MaxVitBackbone(input_size=input_size)
  51. # print(maxvit.named_children())
  52. # for i,layer in enumerate(maxvit.named_children()):
  53. # print(f'layer:{i}:{layer}')
  54. test_input = torch.randn(1, 3, 224 * 7, 224 * 7)
  55. in_channels_list = [64, 64, 128, 256, 512]
  56. featmap_names = ['0', '1', '2', '3', '4', 'pool']
  57. # print(f'featmap_names:{featmap_names}')
  58. roi_pooler = MultiScaleRoIAlign(
  59. featmap_names=featmap_names,
  60. output_size=7,
  61. sampling_ratio=2
  62. )
  63. backbone_with_fpn = BackboneWithFPN(
  64. maxvit,
  65. return_layers={'stem': '0', 'block0': '1', 'block1': '2', 'block2': '3', 'block3': '4'}, # 确保这些键对应到实际的层
  66. in_channels_list=in_channels_list,
  67. out_channels=256
  68. )
  69. rpn_anchor_generator = get_anchor_generator(backbone_with_fpn, test_input=test_input),
  70. return backbone_with_fpn,rpn_anchor_generator,roi_pooler
  71. def get_efficientnetv2_fpn(name='efficientnet_v2_m', pretrained=True):
  72. # 加载EfficientNetV2模型
  73. if name == 'efficientnet_v2_s':
  74. weights = EfficientNet_V2_S_Weights.IMAGENET1K_V1 if pretrained else None
  75. backbone = efficientnet_v2_s(weights=weights).features
  76. if name == 'efficientnet_v2_m':
  77. weights = EfficientNet_V2_M_Weights.IMAGENET1K_V1 if pretrained else None
  78. backbone = efficientnet_v2_m(weights=weights).features
  79. if name == 'efficientnet_v2_l':
  80. weights = EfficientNet_V2_L_Weights.IMAGENET1K_V1 if pretrained else None
  81. backbone = efficientnet_v2_l(weights=weights).features
  82. # 定义返回的层索引和名称
  83. return_layers = {"2": "0", "3": "1", "4": "2", "5": "3"}
  84. # 获取每个层输出通道数
  85. in_channels_list = []
  86. for layer_idx in [2, 3, 4, 5]:
  87. module = backbone[layer_idx]
  88. if hasattr(module, 'out_channels'):
  89. in_channels_list.append(module.out_channels)
  90. elif hasattr(module[-1], 'out_channels'):
  91. # 如果module本身没有out_channels,检查最后一个子模块
  92. in_channels_list.append(module[-1].out_channels)
  93. else:
  94. raise ValueError(f"Cannot determine out_channels for layer {layer_idx}")
  95. # 使用BackboneWithFPN包装backbone
  96. backbone_with_fpn = BackboneWithFPN(
  97. backbone=backbone,
  98. return_layers=return_layers,
  99. in_channels_list=in_channels_list,
  100. out_channels=256
  101. )
  102. return backbone_with_fpn
  103. # 加载 ConvNeXt 模型
  104. # convnext = models.convnext_base(pretrained=True)
  105. # convnext = models.convnext_tiny(pretrained=True)
  106. # convnext = models.convnext_small(pretrained=True)
  107. # print(convnext)
  108. # # 打印模型的所有命名层
  109. # for name, _ in convnext.features[5].named_children():
  110. # print(name)
# Adapt ConvNeXt for use with Faster R-CNN.
  113. def get_anchor_generator(backbone, test_input):
  114. features = backbone(test_input) # 获取 backbone 输出的所有特征图
  115. featmap_names = list(features.keys())
  116. print(f'featmap_names:{featmap_names}')
  117. num_features = len(features) # 特征图数量
  118. print(f'num_features:{num_features}')
  119. # num_features=num_features-1
  120. # # 定义每层的 anchor 尺寸和比例
  121. # base_sizes = [32, 64, 128] # 支持最多 4 层
  122. # sizes = tuple((size,) for size in base_sizes[:num_features])
  123. anchor_sizes = tuple((int(16 * 2 ** i),) for i in range(num_features)) # 自动生成不同大小
  124. print(f'anchor_sizes:{anchor_sizes }')
  125. aspect_ratios = ((0.5, 1.0, 2.0),) * num_features
  126. print(f'aspect_ratios:{aspect_ratios}')
  127. return AnchorGenerator(sizes=anchor_sizes , aspect_ratios=aspect_ratios)
  128. class MaxVitBackbone(torch.nn.Module):
  129. def __init__(self,input_size=(224*7,224*7)):
  130. super(MaxVitBackbone, self).__init__()
  131. # 提取MaxVit的部分层作为特征提取器
  132. maxvit_model =maxvit_t(pretrained=False,input_size=input_size)
  133. self.stem = maxvit_model.stem # Stem层
  134. self.block0= maxvit_model.blocks[0]
  135. self.block1 = maxvit_model.blocks[1]
  136. self.block2 = maxvit_model.blocks[2]
  137. self.block3 = maxvit_model.blocks[3]
  138. def forward(self, x):
  139. print("Input size:", x.shape)
  140. x = self.stem(x)
  141. print("After stem size:", x.shape)
  142. x = self.block0(x)
  143. print("After block0 size:", x.shape)
  144. x = self.block1(x)
  145. print("After block1 size:", x.shape)
  146. x = self.block2(x)
  147. print("After block2 size:", x.shape)
  148. x = self.block3(x)
  149. print("After block3 size:", x.shape)
  150. return x
  151. if __name__ == '__main__':
  152. # maxvit = models.maxvit_t(pretrained=True)
  153. maxvit=MaxVitBackbone()
  154. # print(maxvit.named_children())
  155. for i,layer in enumerate(maxvit.named_children()):
  156. print(f'layer:{i}:{layer}')
  157. in_channels_list = [64,64,128, 256, 512]
  158. backbone_with_fpn = BackboneWithFPN(
  159. maxvit,
  160. return_layers={'stem': '0','block0':'1','block1':'2','block2':'3','block3':'4'}, # 确保这些键对应到实际的层
  161. in_channels_list=in_channels_list,
  162. out_channels=256
  163. )
  164. model = FasterRCNN(
  165. backbone=backbone_with_fpn,
  166. num_classes=91, # COCO 数据集有 91 类
  167. # rpn_anchor_generator=anchor_generator,
  168. # box_roi_pool=roi_pooler
  169. )
  170. test_input = torch.randn(1, 3, 896, 896)
  171. with torch.no_grad():
  172. output = backbone_with_fpn(test_input)
  173. print("Output feature maps:")
  174. for k, v in output.items():
  175. print(f"{k}: {v.shape}")
  176. model.eval()
  177. output=model(test_input)