@@ -211,8 +211,10 @@ def get_swin_transformer_fpn(type='t'):
        if type == 'b':
            swin = torchvision.models.swin_v2_b(weights=None)
+        for i, (name, layer) in enumerate(swin.named_children()):
+            print(f'layer {i} ({name}): {layer}')
        # Save the layers to extract features from
-        self.patch_embed = swin.features[0]  # layer 0: patch embedding
+        self.layer0 = swin.features[0]  # layer 0: patch embedding
        self.layer1 = nn.Sequential(swin.features[1], Trans())  # layer 1: stage 1
        self.layer2 = nn.Sequential(Trans(), swin.features[2])  # layer 2: downsample
        self.layer3 = nn.Sequential(swin.features[3], Trans())  # layer 3: stage 2
@@ -223,7 +225,8 @@ def get_swin_transformer_fpn(type='t'):
    def forward(self, x):
-        x = self.patch_embed(x)  # [B, C, H, W] -> [B, H_, W_, C]
+        x = self.layer0(x)  # [B, C, H, W] -> [B, H_, W_, C]
+        print(f'x0:{x.shape}')
        x = self.layer1(x)
        print(f'x1:{x.shape}')
        x = self.layer2(x)
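
Note: `Trans` is referenced in the hunks above but never defined in this diff. torchvision's Swin V2 keeps activations channels-last ([B, H, W, C]) throughout `features`, so `Trans` is presumably a small permute module that moves tensors between NHWC and NCHW around the stages feeding the FPN. Below is a minimal, self-contained sketch under that assumption; the `Trans` implementation and the shape walkthrough are illustrative stand-ins, not code from the patched file.

import torch
import torch.nn as nn
import torchvision


class Trans(nn.Module):
    """Hypothetical permute module; defaults to NHWC -> NCHW."""

    def __init__(self, dims=(0, 3, 1, 2)):
        super().__init__()
        self.dims = dims

    def forward(self, x):
        return x.permute(*self.dims).contiguous()


swin = torchvision.models.swin_v2_b(weights=None)

# The same inspection the added debug loop performs: named_children()
# yields (name, module) pairs, hence the tuple unpacking.
for i, (name, layer) in enumerate(swin.named_children()):
    print(f'layer {i} ({name}): {type(layer).__name__}')

x = torch.randn(1, 3, 256, 256)
x = swin.features[0](x)   # patch embedding -> [1, 64, 64, 128] (NHWC)
x = swin.features[1](x)   # stage 1 blocks keep the NHWC shape
print(Trans()(x).shape)   # -> [1, 128, 64, 64] (NCHW, ready for FPN convs)

One wrinkle worth noting: `nn.Sequential(Trans(), swin.features[2])` needs the reverse NCHW -> NHWC permute, and a fixed permute is not its own inverse, so the real `Trans` most likely takes the permutation as an argument (as sketched, e.g. `Trans((0, 2, 3, 1))` for the reverse) or exists in two variants.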