from typing import Any, Callable, List, Optional, Type, Union

import torch
from torch import nn

from models.base.high_reso_resnet import Bottleneck, resnet101fpn


class TestModel(nn.Sequential):
    """Smoke-test model: a ``resnet101fpn`` encoder followed by ``FPNDecoder``.

    The class derives from ``nn.Sequential`` but overrides ``forward``; the
    two sub-modules are registered as the attributes ``encoder`` and
    ``decoder``.

    Args:
        block: Residual block class handed to ``FPNDecoder``.
    """

    def __init__(self, block: Type[Union[Bottleneck]]) -> None:
        super().__init__()
        # The encoder is built with out_channels=256, which matches the
        # default ``in_channels`` of FPNDecoder.
        self.encoder = resnet101fpn(out_channels=256)
        self.decoder = FPNDecoder(block)

    def forward(self, x):
        """Encode ``x``, print the encoder's output keys, then decode.

        Args:
            x: Input passed unchanged to the encoder.

        Returns:
            Whatever ``FPNDecoder.forward`` returns for the encoder output.
        """
        res = self.encoder(x)
        # Debug trace of the feature-map names produced by the encoder.
        for k in res.keys():
            print(f'k:{k}')
        out = self.decoder(res)
        return out


class FPNDecoder(nn.Sequential):
    """U-Net style decoder over five feature maps keyed ``'0'`` .. ``'4'``.

    Starting from ``fpn_res['4']`` each stage upsamples with a stride-2
    transposed convolution, concatenates the next skip feature on the channel
    axis and refines the result with a ``block`` layer. A final transposed
    convolution and a 1x1 convolution produce a single-channel map.

    With ``C = in_channels`` the layer widths are: skip features ``C``,
    concatenated features ``2 * C``, head ``C // 2``. The default ``C = 256``
    yields the 256 / 512 / 128 widths.

    ``decoder0`` and ``upconv0`` are constructed but never called by
    ``forward``; they are kept so that ``state_dict`` keys stay stable.

    Args:
        block: Residual block class; must expose an ``expansion`` attribute
            and accept the keyword arguments used in ``_make_decoder_layer``.
        in_channels: Channel count of every incoming feature map.

    Raises:
        ValueError: If ``in_channels`` is smaller than 2 (the head width
            ``in_channels // 2`` would be zero).
    """

    def __init__(self, block: Type[Union[Bottleneck]], in_channels: int = 256) -> None:
        super().__init__()
        if in_channels < 2:
            raise ValueError(
                f"in_channels must be at least 2, got {in_channels}"
            )

        self._norm_layer = nn.BatchNorm2d
        self.inplanes = 64  # not read anywhere in this class
        self.dilation = 1
        self.groups = 1
        self.base_width = 64

        skip_channels = in_channels        # width of each encoder feature
        fused_channels = 2 * in_channels   # width after torch.cat with a skip
        head_channels = in_channels // 2   # width fed to the final 1x1 conv

        # NOTE: construction order is deliberately unchanged so that seeded
        # parameter initialisation is reproducible.
        self.decoder0 = self._make_decoder_layer(block, skip_channels, skip_channels)

        self.upconv4 = self._make_upsample_layer(skip_channels, skip_channels)
        self.upconv3 = self._make_upsample_layer(fused_channels, skip_channels)
        self.upconv2 = self._make_upsample_layer(fused_channels, skip_channels)
        self.upconv1 = self._make_upsample_layer(fused_channels, skip_channels)
        self.upconv0 = self._make_upsample_layer(fused_channels, skip_channels)
        self.final_up = self._make_upsample_layer(fused_channels, head_channels)

        self.decoder4 = self._make_decoder_layer(block, fused_channels, skip_channels)
        self.decoder3 = self._make_decoder_layer(block, fused_channels, skip_channels)
        self.decoder2 = self._make_decoder_layer(block, fused_channels, skip_channels)
        self.decoder1 = self._make_decoder_layer(block, fused_channels, skip_channels)

        self.final_conv = nn.Conv2d(head_channels, 1, kernel_size=1)

    def forward(self, fpn_res):
        """Decode the feature pyramid into a single-channel map.

        Shape comments use ``C = in_channels``. Channel counts follow from the
        layers built in ``__init__``; absolute spatial strides depend on the
        encoder and are not asserted here. Every ``torch.cat`` requires the
        skip feature to have exactly twice the height and width of the tensor
        that was upsampled.

        Args:
            fpn_res: Mapping indexed with the string keys ``'0'`` .. ``'4'``;
                ``'4'`` is the input to the first upsampling stage.
                Presumably ``'4'`` is the coarsest level - confirm against
                the encoder.

        Returns:
            Output of ``final_conv``: one channel, at twice the spatial size
            of ``fpn_res['0']``.
        """
        # ------------------
        # Encoder features (each expected to carry C channels)
        # ------------------
        e0 = fpn_res['0']
        print(f'e0:{e0.shape}')
        e1 = fpn_res['1']
        print(f'e1:{e1.shape}')
        e2 = fpn_res['2']
        print(f'e2:{e2.shape}')
        e3 = fpn_res['3']
        print(f'e3:{e3.shape}')
        e4 = fpn_res['4']
        print(f'e4:{e4.shape}')

        # ------------------
        # Decoder
        # ------------------
        d4 = self.upconv4(e4)  # C channels, spatial size x2
        print(f'd4 = self.upconv4(e4):{d4.shape}')
        d4 = torch.cat([d4, e3], dim=1)  # 2C channels
        print(f' d4 = torch.cat([d4, e3], dim=1):{d4.shape}')
        d4 = self.decoder4(d4)  # must stay at 2C channels for upconv3
        print(f'd4 = self.decoder4(d4):{d4.shape}')

        d3 = self.upconv3(d4)  # C channels, spatial size x2
        print(f'd3 = self.upconv3(d4):{d3.shape}')
        d3 = torch.cat([d3, e2], dim=1)  # 2C channels
        print(f'd3 = torch.cat([d3, e2], dim=1):{d3.shape}')
        d3 = self.decoder3(d3)  # must stay at 2C channels for upconv2
        print(f'd3 = self.decoder3(d3):{d3.shape}')

        d2 = self.upconv2(d3)  # C channels, spatial size x2
        print(f'd2 = self.upconv2(d3):{d2.shape}')
        d2 = torch.cat([d2, e1], dim=1)  # 2C channels
        print(f'd2 = torch.cat([d2, e1], dim=1):{d2.shape}')
        d2 = self.decoder2(d2)  # must stay at 2C channels for upconv1
        print(f'd2 = self.decoder2(d2):{d2.shape}')

        d1 = self.upconv1(d2)  # C channels, spatial size x2
        print(f'd1 = self.upconv1(d2):{d1.shape}')
        d1 = torch.cat([d1, e0], dim=1)  # 2C channels
        print(f'd1 = torch.cat([d1, e0], dim=1):{d1.shape}')
        d1 = self.decoder1(d1)  # must stay at 2C channels for final_up
        print(f'd1 =self.decoder1(d1):{d1.shape}')

        # ------------------
        # Output head
        # ------------------
        d0 = self.final_up(d1)  # C // 2 channels, spatial size x2
        out = self.final_conv(d0)  # 1 channel, same spatial size as d0
        print(f'out:{out.shape}')
        return out

    def _make_upsample_layer(self, in_channels: int, out_channels: int) -> nn.Module:
        """Build a transposed convolution used for upsampling.

        Args:
            in_channels: Channels of the tensor to upsample.
            out_channels: Channels produced by the layer.

        Returns:
            ``nn.ConvTranspose2d`` with ``kernel_size=2`` and ``stride=2``.
        """
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    def _make_decoder_layer(
        self,
        block: Type[Union[Bottleneck]],
        in_channels: int,
        out_channels: int,
        blocks: int = 1,
    ) -> nn.Sequential:
        """Stack ``blocks`` instances of ``block`` into an ``nn.Sequential``.

        Each instance is created as
        ``block(in_channels, in_channels // block.expansion, ...)``.

        Args:
            block: Residual block class exposing ``expansion``.
            in_channels: First positional argument for every block; also
                determines the second one via ``block.expansion``.
            out_channels: Currently ignored - it never reaches ``block``.
                Kept so existing call sites remain valid.
            blocks: Number of block instances to stack.

        Returns:
            ``nn.Sequential`` holding the created blocks (empty when
            ``blocks`` is 0).
        """
        layers = [
            block(
                in_channels,
                in_channels // block.expansion,
                groups=self.groups,
                base_width=self.base_width,
                dilation=self.dilation,
                norm_layer=self._norm_layer,
            )
            for _ in range(blocks)
        ]
        return nn.Sequential(*layers)


if __name__ == '__main__':
    # Smoke test: run one random batch through the model to trace shapes.
    model = TestModel(Bottleneck)
    x = torch.randn(3, 3, 512, 512)
    out = model(x)