@@ -18,7 +18,9 @@ from libs.vision_libs.ops import misc as misc_nn_ops, MultiScaleRoIAlign
 from torch import nn
 import torch
-# from libs.vision_libs.models.detection.backbone_utils import BackboneWithFPN
+
+
+
 def get_resnet50_fpn():
@@ -188,8 +190,7 @@ class MaxVitBackbone(torch.nn.Module):
 from torchvision.models.feature_extraction import create_feature_extractor
-
-if __name__ == '__main__':
+def get_swin_transformer_fpn(type='t'):
     class Trans(nn.Module):
         def __init__(self):
             super().__init__()
@@ -199,10 +200,16 @@ if __name__ == '__main__':
     class SwinTransformer(nn.Module):
-        def __init__(self):
+        def __init__(self, type='t'):
             super().__init__()
-            # Load Swin Transformer v2 Tiny
             swin = torchvision.models.swin_v2_t(weights=None)
+            if type == 't':
+                # Load Swin Transformer v2 Tiny
+                swin = torchvision.models.swin_v2_t(weights=None)
+            if type == 's':
+                swin = torchvision.models.swin_v2_s(weights=None)
+            if type == 'b':
+                swin = torchvision.models.swin_v2_b(weights=None)
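+            # NOTE: the unconditional swin_v2_t(...) above acts as the default,
+            # so an unrecognized `type` silently falls back to the tiny model.
+            # A stricter, equivalent sketch (hypothetical `builders` dict over
+            # the same torchvision constructors) would be:
+            #   builders = {'t': torchvision.models.swin_v2_t,
+            #               's': torchvision.models.swin_v2_s,
+            #               'b': torchvision.models.swin_v2_b}
+            #   swin = builders[type](weights=None)  # KeyError on a bad type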

             # keep the layers whose outputs we will extract
             self.patch_embed = swin.features[0]  # layer 0: patch embedding
@@ -215,54 +222,22 @@ if __name__ == '__main__':
             self.layer7 = nn.Sequential(swin.features[7], Trans())  # layer 7: stage 4

         def forward(self, x):
-            outputs = {}
-            # Patch Embedding
             x = self.patch_embed(x)  # [B, C, H, W] -> [B, H_, W_, C]
-
-            # Layer 1: stage1
             x = self.layer1(x)
-
-            # if 'feat1' not in outputs:
-            #     feat = x.permute(0, 3, 1, 2).contiguous()  # NHWC -> NCHW
-            #     outputs['feat1'] = feat
             print(f'x1:{x.shape}')
-            # Downsample 1
             x = self.layer2(x)
-
-            # Layer 2: stage2
             x = self.layer3(x)
-            # if 'feat2' not in outputs:
-            #     feat = x.permute(0, 3, 1, 2).contiguous()
-            #     outputs['feat2'] = feat
             print(f'x2:{x.shape}')
-
-            # Downsample 2
             x = self.layer4(x)
-
-            # Layer 3: stage3
             x = self.layer5(x)
             print(f'x3:{x.shape}')
-
-            # if 'feat3' not in outputs:
-            #     feat = x.permute(0, 3, 1, 2).contiguous()
-            #     outputs['feat3'] = feat
-
-            # Downsample 3
             x = self.layer6(x)
-
-            # Layer 4: stage4
             x = self.layer7(x)
-            # x = x.permute(0, 3, 2, 1).contiguous()
-
-            # if 'feat4' not in outputs:
-            #     feat = x.permute(0, 3, 1, 2).contiguous()
-            #     outputs['feat4'] = feat
             print(f'x4:{x.shape}')
-
             return x

-    backbone = SwinTransformer()
+    backbone = SwinTransformer(type=type)
     input = torch.randn(1, 3, 512, 512)
     out = backbone(input)
     # print(f'out:{out.keys()}')
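+    # Smoke test: with a 512x512 input and the 4x patch embedding, the four
+    # printed stage maps should halve each time: 128, 64, 32, 16 px per side.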
@@ -271,18 +246,42 @@ if __name__ == '__main__':
     # out = swin(input)
     # print(f'out shape:{out.shape}')
     #
+
+    channels_list = [96, 192, 384, 768]
+    if type == 't':
+        channels_list = [96, 192, 384, 768]
+    if type == 's':
+        channels_list = [96, 192, 384, 768]
+    if type == 'b':
+        channels_list = [128, 256, 512, 1024]
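+    # These are the per-stage embedding widths (C, 2C, 4C, 8C) of torchvision's
+    # Swin V2 models: C = 96 for swin_v2_t and swin_v2_s, C = 128 for swin_v2_b,
+    # which is why the 't' and 's' branches share the same list.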
     backbone_with_fpn = BackboneWithFPN(
         # swin.features,
         backbone,
         return_layers={'layer1': '0', 'layer3': '1', 'layer5': '2', 'layer7': '3'},
-        in_channels_list=[96, 192, 384, 768],
+        in_channels_list=channels_list,
         out_channels=256
     )
+    featmap_names = ['0', '1', '2', '3', 'pool']
+    # print(f'featmap_names:{featmap_names}')
+    roi_pooler = MultiScaleRoIAlign(
+        featmap_names=featmap_names,
+        output_size=7,
+        sampling_ratio=2
+    )
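+    # Assumption (matches torchvision's backbone_utils): BackboneWithFPN adds a
+    # LastLevelMaxPool output keyed 'pool' on top of the four mapped stages, so
+    # featmap_names lists five keys; torchvision's stock FasterRCNN pooler only
+    # uses ['0', '1', '2', '3'].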

-    out = backbone_with_fpn(input)
+    # out = backbone_with_fpn(input)
+    anchor_generator = get_anchor_generator(backbone_with_fpn, test_input=input)

-    print(f'out:{out}')
+    # print(f'out:{out}')
+    return backbone_with_fpn, roi_pooler, anchor_generator
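+    # get_anchor_generator is assumed to be a helper defined elsewhere in this
+    # file; presumably it runs test_input through the FPN to discover the output
+    # keys/shapes and builds a matching AnchorGenerator.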

+if __name__ == '__main__':
+    backbone_with_fpn, roi_pooler, anchor_generator = get_swin_transformer_fpn(type='s')
+    model = FasterRCNN(backbone=backbone_with_fpn, num_classes=3,
+                       box_roi_pool=roi_pooler, rpn_anchor_generator=anchor_generator)
+    # A 4-D batch is split into per-image tensors by the detection transform;
+    # a list of [C, H, W] tensors is the documented input format.
+    input = torch.randn(3, 3, 512, 512, device='cuda')
+    model.eval()
+    model.to('cuda')
+    out = model(input)
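+    # Sketch: in eval mode, torchvision's FasterRCNN returns one dict per image
+    # with 'boxes', 'labels' and 'scores' tensors.
+    for det in out:
+        print(det['boxes'].shape, det['scores'])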

 # # maxvit = models.maxvit_t(pretrained=True)