backbone_factory.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. from collections import OrderedDict
  2. import torchvision
  3. from torchvision.models import maxvit_t
  4. from torchvision.models.detection.backbone_utils import BackboneWithFPN
  5. from libs.vision_libs import models
  6. from libs.vision_libs.models import mobilenet_v3_large, EfficientNet_V2_S_Weights, efficientnet_v2_s, \
  7. EfficientNet_V2_M_Weights, efficientnet_v2_m, EfficientNet_V2_L_Weights, efficientnet_v2_l, ConvNeXt_Base_Weights
  8. from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
  9. from libs.vision_libs.models.detection import FasterRCNN
  10. from libs.vision_libs.models.detection.anchor_utils import AnchorGenerator
  11. from libs.vision_libs.models.detection.ssdlite import _mobilenet_extractor
  12. from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights, resnet18
  13. from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
  14. from libs.vision_libs.ops import misc as misc_nn_ops, MultiScaleRoIAlign
  15. from torch import nn
  16. import torch
  17. # from libs.vision_libs.models.detection.backbone_utils import BackboneWithFPN
  18. def get_resnet50_fpn():
  19. is_trained = False
  20. trainable_backbone_layers = _validate_trainable_layers(is_trained, None, 5, 3)
  21. norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
  22. backbone = resnet50(weights=None, progress=True, norm_layer=norm_layer)
  23. backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
  24. return backbone
  25. def get_resnet18_fpn():
  26. is_trained = False
  27. trainable_backbone_layers = _validate_trainable_layers(is_trained, None, 5, 3)
  28. norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
  29. backbone = resnet18(weights=None, progress=True, norm_layer=norm_layer)
  30. backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
  31. return backbone
  32. def get_mobilenet_v3_large_fpn():
  33. is_trained = False
  34. trainable_backbone_layers = _validate_trainable_layers(is_trained, None, 6, 3)
  35. norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
  36. backbone = mobilenet_v3_large(weights=None, progress=True, norm_layer=norm_layer)
  37. backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
  38. return backbone
  39. def get_convnext_fpn():
  40. convnext = models.convnext_base(weights=ConvNeXt_Base_Weights.DEFAULT)
  41. # convnext = models.convnext_small(pretrained=True)
  42. # convnext = models.convnext_large(pretrained=True)
  43. in_channels_list = [128, 256, 512, 1024]
  44. backbone_with_fpn = BackboneWithFPN(
  45. convnext.features,
  46. return_layers={'1': '0', '3': '1', '5': '2', '7': '3'}, # 确保这些键对应到实际的层
  47. in_channels_list=in_channels_list,
  48. out_channels=256
  49. )
  50. return backbone_with_fpn
  51. def get_maxvit_fpn(input_size=(224*7,224*7)):
  52. maxvit = MaxVitBackbone(input_size=input_size)
  53. # print(maxvit.named_children())
  54. # for i,layer in enumerate(maxvit.named_children()):
  55. # print(f'layer:{i}:{layer}')
  56. test_input = torch.randn(1, 3, 224 * 7, 224 * 7)
  57. in_channels_list = [64, 64, 128, 256, 512]
  58. featmap_names = ['0', '1', '2', '3', '4', 'pool']
  59. # print(f'featmap_names:{featmap_names}')
  60. roi_pooler = MultiScaleRoIAlign(
  61. featmap_names=featmap_names,
  62. output_size=7,
  63. sampling_ratio=2
  64. )
  65. backbone_with_fpn = BackboneWithFPN(
  66. maxvit,
  67. return_layers={'stem': '0', 'block0': '1', 'block1': '2', 'block2': '3', 'block3': '4'}, # 确保这些键对应到实际的层
  68. in_channels_list=in_channels_list,
  69. out_channels=256
  70. )
  71. rpn_anchor_generator = get_anchor_generator(backbone_with_fpn, test_input=test_input),
  72. return backbone_with_fpn,rpn_anchor_generator,roi_pooler
  73. def get_efficientnetv2_fpn(name='efficientnet_v2_m', pretrained=True):
  74. # 加载EfficientNetV2模型
  75. if name == 'efficientnet_v2_s':
  76. weights = EfficientNet_V2_S_Weights.IMAGENET1K_V1 if pretrained else None
  77. backbone = efficientnet_v2_s(weights=weights).features
  78. if name == 'efficientnet_v2_m':
  79. weights = EfficientNet_V2_M_Weights.IMAGENET1K_V1 if pretrained else None
  80. backbone = efficientnet_v2_m(weights=weights).features
  81. if name == 'efficientnet_v2_l':
  82. weights = EfficientNet_V2_L_Weights.IMAGENET1K_V1 if pretrained else None
  83. backbone = efficientnet_v2_l(weights=weights).features
  84. # 定义返回的层索引和名称
  85. return_layers = {"2": "0", "3": "1", "4": "2", "5": "3"}
  86. # 获取每个层输出通道数
  87. in_channels_list = []
  88. for layer_idx in [2, 3, 4, 5]:
  89. module = backbone[layer_idx]
  90. if hasattr(module, 'out_channels'):
  91. in_channels_list.append(module.out_channels)
  92. elif hasattr(module[-1], 'out_channels'):
  93. # 如果module本身没有out_channels,检查最后一个子模块
  94. in_channels_list.append(module[-1].out_channels)
  95. else:
  96. raise ValueError(f"Cannot determine out_channels for layer {layer_idx}")
  97. # 使用BackboneWithFPN包装backbone
  98. backbone_with_fpn = BackboneWithFPN(
  99. backbone=backbone,
  100. return_layers=return_layers,
  101. in_channels_list=in_channels_list,
  102. out_channels=256
  103. )
  104. return backbone_with_fpn
# Load a ConvNeXt model
# convnext = models.convnext_base(pretrained=True)
# convnext = models.convnext_tiny(pretrained=True)
# convnext = models.convnext_small(pretrained=True)
# print(convnext)
# # Print all named layers of the model
# for name, _ in convnext.features[5].named_children():
#     print(name)
# Adapt ConvNeXt for use with Faster R-CNN
  115. def get_anchor_generator(backbone, test_input):
  116. features = backbone(test_input) # 获取 backbone 输出的所有特征图
  117. featmap_names = list(features.keys())
  118. print(f'featmap_names:{featmap_names}')
  119. num_features = len(features) # 特征图数量
  120. print(f'num_features:{num_features}')
  121. # num_features=num_features-1
  122. # # 定义每层的 anchor 尺寸和比例
  123. # base_sizes = [32, 64, 128] # 支持最多 4 层
  124. # sizes = tuple((size,) for size in base_sizes[:num_features])
  125. anchor_sizes = tuple((int(16 * 2 ** i),) for i in range(num_features)) # 自动生成不同大小
  126. print(f'anchor_sizes:{anchor_sizes }')
  127. aspect_ratios = ((0.5, 1.0, 2.0),) * num_features
  128. print(f'aspect_ratios:{aspect_ratios}')
  129. return AnchorGenerator(sizes=anchor_sizes , aspect_ratios=aspect_ratios)
  130. class MaxVitBackbone(torch.nn.Module):
  131. def __init__(self,input_size=(224*7,224*7)):
  132. super(MaxVitBackbone, self).__init__()
  133. # 提取MaxVit的部分层作为特征提取器
  134. maxvit_model =maxvit_t(pretrained=False,input_size=input_size)
  135. self.stem = maxvit_model.stem # Stem层
  136. self.block0= maxvit_model.blocks[0]
  137. self.block1 = maxvit_model.blocks[1]
  138. self.block2 = maxvit_model.blocks[2]
  139. self.block3 = maxvit_model.blocks[3]
  140. def forward(self, x):
  141. print("Input size:", x.shape)
  142. x = self.stem(x)
  143. print("After stem size:", x.shape)
  144. x = self.block0(x)
  145. print("After block0 size:", x.shape)
  146. x = self.block1(x)
  147. print("After block1 size:", x.shape)
  148. x = self.block2(x)
  149. print("After block2 size:", x.shape)
  150. x = self.block3(x)
  151. print("After block3 size:", x.shape)
  152. return x
  153. from torchvision.models.feature_extraction import create_feature_extractor
  154. if __name__ == '__main__':
  155. class Trans(nn.Module):
  156. def __init__(self):
  157. super().__init__()
  158. def forward(self,x):
  159. x=x.permute(0, 3, 2, 1).contiguous()
  160. return x
  161. class SwinTransformer(nn.Module):
  162. def __init__(self):
  163. super().__init__()
  164. # 加载 Swin Transformer v2 Tiny
  165. swin = torchvision.models.swin_v2_t(weights=None)
  166. # 保存需要提取的层
  167. self.patch_embed = swin.features[0] # 第0层 patch embedding
  168. self.layer1 =nn.Sequential(swin.features[1],Trans()) # 第1层 stage1
  169. self.layer2 =nn.Sequential(Trans(),swin.features[2]) # 第2层 downsample
  170. self.layer3 =nn.Sequential(swin.features[3], Trans()) # 第3层 stage2
  171. self.layer4 =nn.Sequential( Trans(),swin.features[4]) # 第4层 downsample
  172. self.layer5 =nn.Sequential(swin.features[5], Trans()) # 第5层 stage3
  173. self.layer6 =nn.Sequential(Trans(),swin.features[6]) # 第6层 downsample
  174. self.layer7 =nn.Sequential(swin.features[7], Trans()) # 第7层 stage4
  175. def forward(self, x):
  176. outputs = {}
  177. # Patch Embedding
  178. x = self.patch_embed(x) # [B, C, H, W] -> [B, H_, W_, C]
  179. # Layer 1: stage1
  180. x = self.layer1(x)
  181. # if 'feat1' not in outputs:
  182. # feat = x.permute(0, 3, 1, 2).contiguous() # NHWC -> NCHW
  183. # outputs['feat1'] = feat
  184. print(f'x1:{x.shape}')
  185. # Downsample 1
  186. x = self.layer2(x)
  187. # Layer 2: stage2
  188. x = self.layer3(x)
  189. # if 'feat2' not in outputs:
  190. # feat = x.permute(0, 3, 1, 2).contiguous()
  191. # outputs['feat2'] = feat
  192. print(f'x2:{x.shape}')
  193. # Downsample 2
  194. x = self.layer4(x)
  195. # Layer 3: stage3
  196. x = self.layer5(x)
  197. print(f'x3:{x.shape}')
  198. # if 'feat3' not in outputs:
  199. # feat = x.permute(0, 3, 1, 2).contiguous()
  200. # outputs['feat3'] = feat
  201. # Downsample 3
  202. x = self.layer6(x)
  203. # Layer 4: stage4
  204. x = self.layer7(x)
  205. # x = x.permute(0, 3, 2, 1).contiguous()
  206. # if 'feat4' not in outputs:
  207. # feat = x.permute(0, 3, 1, 2).contiguous()
  208. # outputs['feat4'] = feat
  209. print(f'x4:{x.shape}')
  210. return x
  211. backbone = SwinTransformer()
  212. input=torch.randn(1,3,512,512)
  213. out=backbone(input)
  214. # print(f'out:{out.keys()}')
  215. # for i,layer in enumerate(swin.features.named_children()):
  216. # print(f'layer:{i}:{layer}')
  217. # out=swin(input)
  218. # print(f'out shape:{out.shape}')
  219. #
  220. backbone_with_fpn = BackboneWithFPN(
  221. # swin.features,
  222. backbone,
  223. return_layers={'layer1': '0', 'layer3': '1', 'layer5': '2', 'layer7': '3'},
  224. in_channels_list=[96, 192, 384, 768],
  225. out_channels=256
  226. )
  227. out=backbone_with_fpn(input)
  228. print(f'out:{out}')
  229. # # maxvit = models.maxvit_t(pretrained=True)
  230. # maxvit=MaxVitBackbone()
  231. # # print(maxvit.named_children())
  232. #
  233. # for i,layer in enumerate(maxvit.named_children()):
  234. # print(f'layer:{i}:{layer}')
  235. #
  236. # in_channels_list = [64,64,128, 256, 512]
  237. # backbone_with_fpn = BackboneWithFPN(
  238. # maxvit,
  239. # return_layers={'stem': '0','block0':'1','block1':'2','block2':'3','block3':'4'}, # 确保这些键对应到实际的层
  240. # in_channels_list=in_channels_list,
  241. # out_channels=256
  242. # )
  243. # model = FasterRCNN(
  244. # backbone=backbone_with_fpn,
  245. # num_classes=91, # COCO 数据集有 91 类
  246. # # rpn_anchor_generator=anchor_generator,
  247. # # box_roi_pool=roi_pooler
  248. # )
  249. #
  250. # test_input = torch.randn(1, 3, 896, 896)
  251. #
  252. # with torch.no_grad():
  253. # output = backbone_with_fpn(test_input)
  254. #
  255. # print("Output feature maps:")
  256. # for k, v in output.items():
  257. # print(f"{k}: {v.shape}")
  258. # model.eval()
  259. # output=model(test_input)