@@ -0,0 +1,139 @@
+import torch
+from torch import nn
+
+from models.base.high_reso_resnet import Bottleneck, resnet101fpn
+from typing import Type
+
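+# Shape-debugging harness: a ResNet-101 FPN encoder (resnet101fpn) feeding a
+# U-Net style decoder (ArcUnet). forward() prints every intermediate shape so the
+# channel/stride bookkeeping in the comments below can be checked against real tensors.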
+class TestModel(nn.Module):
+    def __init__(self, block: Type[Bottleneck]):
+        super().__init__()
+        self.encoder = resnet101fpn(out_channels=256)
+        self.decoder = ArcUnet(block)
+
+    def forward(self, x):
+        res = self.encoder(x)
+        for k in res.keys():
+            print(f'k:{k}')
+        # print(f'res:{res}')
+
+        out = self.decoder(res)
+        return out
+
+
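+# U-Net style decoder over the FPN pyramid. Every FPN level arrives with 256
+# channels (out_channels=256 in the encoder), so each stage works on a fixed
+# 256-in / 512-after-concat channel budget regardless of depth.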
+class ArcUnet(nn.Module):
+    def __init__(self, block: Type[Bottleneck], in_channels=256):
+        super().__init__()
+        self._norm_layer = nn.BatchNorm2d
+        self.inplanes = 64
+        self.dilation = 1
+        self.groups = 1
+        self.base_width = 64
+
+        self.decoder0 = self._make_decoder_layer(block, 256, 256)
+
+        self.upconv4 = self._make_upsample_layer(256, 256)
+        self.upconv3 = self._make_upsample_layer(512, 256)
+        self.upconv2 = self._make_upsample_layer(512, 256)
+        self.upconv1 = self._make_upsample_layer(512, 256)
+        self.upconv0 = self._make_upsample_layer(512, 256)
+
+        self.final_up = self._make_upsample_layer(512, 128)
+
+        self.decoder4 = self._make_decoder_layer(block, 512, 256)
+        self.decoder3 = self._make_decoder_layer(block, 512, 256)
+        self.decoder2 = self._make_decoder_layer(block, 512, 256)
+        self.decoder1 = self._make_decoder_layer(block, 512, 256)
+
+        self.final_conv = nn.Conv2d(128, 1, kernel_size=1)
+
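+    # Each stage: upsample x2 with a transposed conv (256 ch) -> concat the matching
+    # FPN skip (512 ch) -> refine with a Bottleneck stack (stays at 512 ch).
+    # Note: decoder0 and upconv0 are built in __init__ but not used here.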
+    def forward(self, fpn_res):
+        # ------------------
+        # Encoder
+        # ------------------
+        e0 = fpn_res['0']  # [B, 256, H/2, W/2]
+        print(f'e0:{e0.shape}')
+        e1 = fpn_res['1']  # [B, 256, H/4, W/4]
+        print(f'e1:{e1.shape}')
+        e2 = fpn_res['2']  # [B, 256, H/8, W/8]
+        print(f'e2:{e2.shape}')
+        e3 = fpn_res['3']  # [B, 256, H/16, W/16]
+        print(f'e3:{e3.shape}')
+        e4 = fpn_res['4']  # [B, 256, H/32, W/32]
+        print(f'e4:{e4.shape}')
+
+        # ------------------
+        # Decoder
+        # ------------------
+
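+        # Start from the deepest pyramid level (e4) and walk back up to e0.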
+        d4 = self.upconv4(e4)            # [B, 256, H/16, W/16]
+        print(f'd4 = self.upconv4(e4):{d4.shape}')
+        d4 = torch.cat([d4, e3], dim=1)  # [B, 512, H/16, W/16]
+        print(f'd4 = torch.cat([d4, e3], dim=1):{d4.shape}')
+        d4 = self.decoder4(d4)           # [B, 512, H/16, W/16]
+        print(f'd4 = self.decoder4(d4):{d4.shape}')
+
+        d3 = self.upconv3(d4)            # [B, 256, H/8, W/8]
+        print(f'd3 = self.upconv3(d4):{d3.shape}')
+        d3 = torch.cat([d3, e2], dim=1)  # [B, 512, H/8, W/8]
+        print(f'd3 = torch.cat([d3, e2], dim=1):{d3.shape}')
+        d3 = self.decoder3(d3)           # [B, 512, H/8, W/8]
+        print(f'd3 = self.decoder3(d3):{d3.shape}')
+
+        d2 = self.upconv2(d3)            # [B, 256, H/4, W/4]
+        print(f'd2 = self.upconv2(d3):{d2.shape}')
+        d2 = torch.cat([d2, e1], dim=1)  # [B, 512, H/4, W/4]
+        print(f'd2 = torch.cat([d2, e1], dim=1):{d2.shape}')
+        d2 = self.decoder2(d2)           # [B, 512, H/4, W/4]
+        print(f'd2 = self.decoder2(d2):{d2.shape}')
+
+        d1 = self.upconv1(d2)            # [B, 256, H/2, W/2]
+        print(f'd1 = self.upconv1(d2):{d1.shape}')
+        d1 = torch.cat([d1, e0], dim=1)  # [B, 512, H/2, W/2]
+        print(f'd1 = torch.cat([d1, e0], dim=1):{d1.shape}')
+        d1 = self.decoder1(d1)           # [B, 512, H/2, W/2]
+        print(f'd1 = self.decoder1(d1):{d1.shape}')
+
+        # ------------------
+        # Output Head
+        # ------------------
+        d0 = self.final_up(d1)     # [B, 128, H, W]
+        out = self.final_conv(d0)  # [B, 1, H, W]
+        print(f'out:{out.shape}')
+
+        return out
+
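+    # ------------------
+    # Layer builders
+    # ------------------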
+    def _make_upsample_layer(self, in_channels: int, out_channels: int) -> nn.Module:
+        """
+        Upsample by a factor of 2 with a transposed convolution.
+        """
+        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
+
+    def _make_decoder_layer(self, block: Type[Bottleneck], in_channels: int,
+                            out_channels: int, blocks: int = 1) -> nn.Sequential:
+        """
+        Stack `blocks` Bottleneck blocks. The output keeps `in_channels` channels,
+        since each block's expansion restores the channel count it was given.
+        """
+        # assert in_channels == out_channels, "in_channels must equal out_channels"
+        layers = []
+        for _ in range(blocks):
+            layers.append(
+                block(
+                    in_channels,
+                    in_channels // block.expansion,
+                    groups=self.groups,
+                    base_width=self.base_width,
+                    dilation=self.dilation,
+                    norm_layer=self._norm_layer,
+                )
+            )
+        return nn.Sequential(*layers)
+
+if __name__ == '__main__':
+    model = TestModel(Bottleneck)
+    x = torch.randn(3, 3, 512, 512)
+    out = model(x)
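+    # With a 512x512 input this is expected to be [3, 1, 512, 512], assuming the
+    # '0' FPN level sits at stride 2; compare against the per-stage prints above.
+    print(f'final out:{out.shape}')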