Browse Source

添加swin_trans_former

RenLiqiang 5 months ago
parent
commit
298422fcbf
2 changed files with 6 additions and 2 deletions
  1. 5 2
      models/base/backbone_factory.py
  2. 1 0
      models/line_detect/loi_heads.py

+ 5 - 2
models/base/backbone_factory.py

@@ -211,8 +211,10 @@ def get_swin_transformer_fpn(type='t'):
             if type=='b':
                 swin=torchvision.models.swin_v2_b(weights=None)
 
+            for i,layer in enumerate(swin.named_children()):
+                print(f'layer{i}:,{layer}')
             # 保存需要提取的层
-            self.patch_embed = swin.features[0]  # 第0层 patch embedding
+            self.layer0 = swin.features[0]  # 第0层 patch embedding
             self.layer1 =nn.Sequential(swin.features[1],Trans())  # 第1层 stage1
             self.layer2 =nn.Sequential(Trans(),swin.features[2]) # 第2层 downsample
             self.layer3 =nn.Sequential(swin.features[3], Trans()) # 第3层 stage2
@@ -223,7 +225,8 @@ def get_swin_transformer_fpn(type='t'):
 
         def forward(self, x):
 
-            x = self.patch_embed(x)  # [B, C, H, W] -> [B, H_, W_, C]
+            x = self.layer0(x)  # [B, C, H, W] -> [B, H_, W_, C]
+            print(f'x0:{x.shape}')
             x = self.layer1(x)
             print(f'x1:{x.shape}')
             x = self.layer2(x)

+ 1 - 0
models/line_detect/loi_heads.py

@@ -896,6 +896,7 @@ class RoIHeads(nn.Module):
             print(f'line_proposals:{len(line_proposals)}')
 
             # cs_features= features['0']
+            print(f'features-0:{features['0'].shape}')
             cs_features = self.channel_compress(features['0'])