backbone_factory.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. from collections import OrderedDict
  2. from libs.vision_libs import models
  3. from libs.vision_libs.models import mobilenet_v3_large, EfficientNet_V2_S_Weights, efficientnet_v2_s, \
  4. EfficientNet_V2_M_Weights, efficientnet_v2_m, EfficientNet_V2_L_Weights, efficientnet_v2_l
  5. from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
  6. from libs.vision_libs.models.detection import FasterRCNN
  7. from libs.vision_libs.models.detection.anchor_utils import AnchorGenerator
  8. from libs.vision_libs.models.detection.ssdlite import _mobilenet_extractor
  9. from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights, resnet18
  10. from libs.vision_libs.models.detection._utils import overwrite_eps
  11. from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
  12. from libs.vision_libs.ops import misc as misc_nn_ops, MultiScaleRoIAlign
  13. from torch import nn
  14. import torch
  15. from torchvision.models.detection.backbone_utils import BackboneWithFPN, resnet_fpn_backbone
  16. from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool
  def get_resnet50_fpn():
      """Create a ResNet-50 feature extractor wrapped by the FPN helper.

      The network is requested with ``weights=None``, and because the local
      "pretrained" flag is fixed to ``False`` the plain ``nn.BatchNorm2d``
      layer class is selected instead of the frozen variant.

      Returns:
          Whatever ``_resnet_fpn_extractor`` builds from the ResNet-50 trunk.
      """
      pretrained = False
      # Positional arguments: the flag, the requested layer count (None), then
      # 5 and 3 -- presumably the maximum and default number of trainable
      # layers; confirm against _validate_trainable_layers.
      num_trainable = _validate_trainable_layers(pretrained, None, 5, 3)
      if pretrained:
          bn_cls = misc_nn_ops.FrozenBatchNorm2d
      else:
          bn_cls = nn.BatchNorm2d
      trunk = resnet50(weights=None, progress=True, norm_layer=bn_cls)
      return _resnet_fpn_extractor(trunk, num_trainable)
  def get_resnet18_fpn():
      """Create a ResNet-18 feature extractor wrapped by the FPN helper.

      Mirrors the ResNet-50 factory: ``weights=None`` is passed to the model
      constructor and, since the local "pretrained" flag is ``False``,
      ``nn.BatchNorm2d`` is used as the normalisation layer.

      Returns:
          Whatever ``_resnet_fpn_extractor`` builds from the ResNet-18 trunk.
      """
      pretrained = False
      # 5 and 3 are presumably the maximum / default trainable-layer counts;
      # confirm against _validate_trainable_layers.
      layers_to_train = _validate_trainable_layers(pretrained, None, 5, 3)
      norm_cls = nn.BatchNorm2d
      if pretrained:
          norm_cls = misc_nn_ops.FrozenBatchNorm2d
      resnet = resnet18(weights=None, progress=True, norm_layer=norm_cls)
      fpn_backbone = _resnet_fpn_extractor(resnet, layers_to_train)
      return fpn_backbone
  def get_mobilenet_v3_large_fpn():
      """Create a MobileNetV3-Large feature extractor with an FPN on top.

      The model is requested with ``weights=None`` and ``nn.BatchNorm2d``
      normalisation (the local ``is_trained`` flag is fixed to ``False``).

      Returns:
          The module produced by the FPN-aware MobileNet extractor.
      """
      # The call below passes (backbone, fpn_flag, trainable_layers), which is
      # the argument order of the helper in detection.backbone_utils.  The
      # file-level `_mobilenet_extractor` comes from detection.ssdlite, which
      # in upstream torchvision expects (backbone, trainable_layers,
      # norm_layer) and would receive an int as `norm_layer`.
      # NOTE(review): assumes the vendored library mirrors upstream
      # torchvision here -- confirm against libs.vision_libs.
      from libs.vision_libs.models.detection.backbone_utils import (
          _mobilenet_extractor as _mobilenet_fpn_extractor,
      )

      is_trained = False
      # 6 and 3 are presumably the maximum / default trainable-layer counts;
      # confirm against _validate_trainable_layers.
      trainable_backbone_layers = _validate_trainable_layers(is_trained, None, 6, 3)
      norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
      backbone = mobilenet_v3_large(weights=None, progress=True, norm_layer=norm_layer)
      # Second positional argument requests the FPN variant of the extractor.
      return _mobilenet_fpn_extractor(backbone, True, trainable_backbone_layers)
  def get_convnext_fpn():
      """Wrap the feature stack of a ConvNeXt-base model in ``BackboneWithFPN``.

      The model is created with ``pretrained=True``.  Outputs of the children
      named '1', '3', '5' and '7' of ``features`` are exposed as FPN inputs
      '0'..'3' with declared channel counts 128, 256, 512 and 1024; the FPN
      emits 256 channels per level.

      Returns:
          The resulting ``BackboneWithFPN`` module.
      """
      base_model = models.convnext_base(pretrained=True)
      stage_channels = [128, 256, 512, 1024]
      # Keys must name real children of `features`; values are the names the
      # FPN will see.
      tapped_layers = {'1': '0', '3': '1', '5': '2', '7': '3'}
      return BackboneWithFPN(
          base_model.features,
          return_layers=tapped_layers,
          in_channels_list=stage_channels,
          out_channels=256,
      )
  def get_efficientnetv2_fpn(name='efficientnet_v2_m', pretrained=True):
      """Build an EfficientNetV2 feature extractor wrapped in ``BackboneWithFPN``.

      Args:
          name: One of ``'efficientnet_v2_s'``, ``'efficientnet_v2_m'`` or
              ``'efficientnet_v2_l'``.
          pretrained: When true, the ``IMAGENET1K_V1`` weights of the chosen
              variant are requested; otherwise ``weights=None`` is passed.

      Returns:
          A ``BackboneWithFPN`` that taps ``features[2..5]`` and emits
          256-channel maps named '0'..'3' (plus whatever the FPN adds).

      Raises:
          ValueError: If ``name`` is not a supported variant, or if the output
              channel count of a tapped stage cannot be determined.
      """
      supported = ('efficientnet_v2_s', 'efficientnet_v2_m', 'efficientnet_v2_l')
      # Validate first: previously an unknown name fell through every `if`
      # and failed later with an unbound `backbone` variable.
      if name not in supported:
          raise ValueError(
              f"Unsupported EfficientNetV2 variant {name!r}; expected one of {supported}"
          )

      # Load the requested EfficientNetV2 variant and keep only its feature stack.
      if name == 'efficientnet_v2_s':
          weights = EfficientNet_V2_S_Weights.IMAGENET1K_V1 if pretrained else None
          backbone = efficientnet_v2_s(weights=weights).features
      elif name == 'efficientnet_v2_m':
          weights = EfficientNet_V2_M_Weights.IMAGENET1K_V1 if pretrained else None
          backbone = efficientnet_v2_m(weights=weights).features
      else:
          weights = EfficientNet_V2_L_Weights.IMAGENET1K_V1 if pretrained else None
          backbone = efficientnet_v2_l(weights=weights).features

      # Indices of the stages fed to the FPN, and the names they are exposed under.
      layer_indices = (2, 3, 4, 5)
      return_layers = {str(idx): str(pos) for pos, idx in enumerate(layer_indices)}

      # Collect the output channel count of every tapped stage.
      in_channels_list = []
      for layer_idx in layer_indices:
          stage = backbone[layer_idx]
          if hasattr(stage, 'out_channels'):
              in_channels_list.append(stage.out_channels)
          elif hasattr(stage[-1], 'out_channels'):
              # The stage itself has no out_channels; use its last sub-module.
              in_channels_list.append(stage[-1].out_channels)
          else:
              raise ValueError(f"Cannot determine out_channels for layer {layer_idx}")

      # Wrap the feature stack with an FPN.
      return BackboneWithFPN(
          backbone=backbone,
          return_layers=return_layers,
          in_channels_list=in_channels_list,
          out_channels=256,
      )
  # Load the ConvNeXt model.
  # NOTE(review): the original indentation was not recoverable here.  If this
  # statement sits at module level it builds a pretrained ConvNeXt-base on
  # every import of this file; the script section further down creates its own
  # instance.  Confirm whether this line is still needed.
  convnext = models.convnext_base(pretrained=True)
  # convnext = models.convnext_tiny(pretrained=True)
  # convnext = models.convnext_small(pretrained=True)
  # print(convnext)
  # # Print all named layers of the model
  # for name, _ in convnext.features[5].named_children():
  # print(name)
  # Adapt ConvNeXt for use with Faster R-CNN
  def get_anchor_generator(backbone, test_input):
      """Build an ``AnchorGenerator`` with one size level per backbone output.

      The backbone is run once on ``test_input`` to count its feature maps.
      Level ``i`` gets the single anchor size ``16 * 2 ** i`` and every level
      uses the aspect ratios ``(0.5, 1.0, 2.0)``.  The discovered names and the
      generated sizes/ratios are printed.

      Args:
          backbone: Callable returning either a mapping of feature maps or a
              single tensor (treated as one feature map named '0').
          test_input: Input passed to ``backbone`` for the probe run.

      Returns:
          An ``AnchorGenerator`` configured for the detected number of levels.
      """
      # Probe in eval mode and without autograd so the dummy input neither
      # builds a graph nor updates normalisation statistics.  Per-module modes
      # are recorded so sub-modules that were deliberately left in eval mode
      # stay that way afterwards.
      saved_modes = []
      if isinstance(backbone, nn.Module):
          saved_modes = [(module, module.training) for module in backbone.modules()]
          backbone.eval()
      try:
          with torch.no_grad():
              features = backbone(test_input)  # all feature maps of the backbone
      finally:
          for module, mode in saved_modes:
              module.training = mode

      # A backbone without FPN may return a bare tensor; treat it as one level.
      if isinstance(features, torch.Tensor):
          features = OrderedDict([('0', features)])

      featmap_names = list(features.keys())
      print(f'featmap_names:{featmap_names}')
      num_features = len(features)  # number of feature maps
      print(f'num_features:{num_features}')

      # One anchor size per level, doubling from 16 upwards.
      anchor_sizes = tuple((16 * 2 ** i,) for i in range(num_features))
      print(f'anchor_sizes:{anchor_sizes }')
      aspect_ratios = ((0.5, 1.0, 2.0),) * num_features
      print(f'aspect_ratios:{aspect_ratios}')
      return AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)
  if __name__ == '__main__':
      # Smoke test: ConvNeXt-base + FPN used as the backbone of Faster R-CNN.

      # Create the ConvNeXt backbone and list its top-level feature layers.
      convnext = models.convnext_base(pretrained=True)
      for i, layer in enumerate(convnext.features):
          print(f'layer{i}:{layer}')

      # Declared output channels of the tapped layers '1', '3', '5', '7'.
      in_channels_list = [128, 256, 512, 1024]
      print(f'in_channels_list:{in_channels_list}')

      # Create the FPN; the return_layers keys must name real children of
      # `convnext.features`.
      backbone_with_fpn = BackboneWithFPN(
          convnext.features,
          return_layers={'1': '0', '3': '1', '5': '2', '7': '3'},
          in_channels_list=in_channels_list,
          out_channels=256,
      )

      # Derive one anchor level per feature map from a small probe input.
      test_input = torch.rand(1, 3, 224, 224)
      anchor_generator = get_anchor_generator(backbone_with_fpn, test_input)
      print(f'anchor_generator:{anchor_generator}')

      # Feature maps handed to the RoI pooler.
      featmap_names = ['0', '1', '2', '3', 'pool']
      roi_pooler = MultiScaleRoIAlign(
          featmap_names=featmap_names,
          output_size=7,
          sampling_ratio=2,
      )

      # Create the Faster R-CNN model.
      model = FasterRCNN(
          backbone=backbone_with_fpn,
          num_classes=91,  # class count used for the COCO dataset
          rpn_anchor_generator=anchor_generator,
          box_roi_pool=roi_pooler,
      )

      # Run a single inference pass.  The original script repeated this step
      # twice via copy-paste and kept autograd enabled; one pass under no_grad
      # exercises the same code path.
      test_input = torch.randn(1, 3, 800, 800)
      model.eval()
      with torch.no_grad():
          output = model(test_input)
      print(f'output: {output}')
  156. print(f'output:{output}')