| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139 |
import logging
from typing import Any, Callable, List, Optional, Type, Union

import torch
from torch import nn

from models.base.high_reso_resnet import Bottleneck, resnet101fpn
class TestModel(nn.Sequential):
    """Encoder/decoder harness: ResNet-101 FPN features fed into ``FPNDecoder``.

    NOTE(review): this subclasses ``nn.Sequential`` but overrides ``forward``,
    so it behaves as a plain ``nn.Module``; the base class is kept only so
    existing ``isinstance`` checks keep working.

    Args:
        block: residual block class (e.g. ``Bottleneck``) passed through to
            ``FPNDecoder`` for its decoder stages.
    """

    def __init__(self, block: "Type[Bottleneck]"):
        super().__init__()
        # 256 matches FPNDecoder's default in_channels; keep the two in sync.
        self.encoder = resnet101fpn(out_channels=256)
        self.decoder = FPNDecoder(block)

    def forward(self, x):
        """Encode ``x`` and decode the resulting feature mapping.

        Assumes the encoder returns a mapping keyed ``'0'``..``'4'`` (the keys
        FPNDecoder reads) -- TODO confirm against ``resnet101fpn``.
        """
        features = self.encoder(x)
        # Debug aid (previously an unconditional print on every forward call).
        logging.getLogger(__name__).debug(
            'encoder feature keys: %s', list(features.keys())
        )
        return self.decoder(features)
class FPNDecoder(nn.Sequential):
    """U-Net-style top-down decoder over five same-width FPN feature maps.

    Starting from the coarsest level ``'4'``, each stage upsamples the running
    map 2x with a transposed convolution, concatenates the next-finer encoder
    level along the channel axis, and refines the result with ``block``. A
    final 2x upsampling and a 1x1 convolution produce the prediction.

    NOTE(review): this subclasses ``nn.Sequential`` but overrides ``forward``,
    so it behaves as a plain ``nn.Module``; the base class is kept only so
    existing ``isinstance`` checks keep working.

    Args:
        block: residual block class (e.g. ``Bottleneck``). Each decoder stage
            is ``block(c, c // block.expansion, groups=..., base_width=...,
            dilation=..., norm_layer=...)``. The layer plan assumes such a
            block keeps ``c`` channels, which holds for a torchvision-style
            Bottleneck when ``c`` is divisible by ``expansion`` -- TODO confirm
            for the project block.
        in_channels: channel count of every encoder feature map. (This was
            previously accepted but ignored, with 256 hard-coded; the default
            reproduces the old layer sizes.)
        head_channels: output channels of the final upsampling layer.
        num_classes: output channels of the final 1x1 convolution.
    """

    def __init__(
        self,
        block: "Type[Bottleneck]",
        in_channels: int = 256,
        head_channels: int = 128,
        num_classes: int = 1,
    ):
        super().__init__()
        # Settings forwarded to every ``block`` (mirroring torchvision's ResNet).
        self._norm_layer = nn.BatchNorm2d
        self.inplanes = 64  # NOTE(review): not read by any method of this class
        self.dilation = 1
        self.groups = 1
        self.base_width = 64

        # Channels after concatenating an upsampled map with a skip feature.
        cat_channels = 2 * in_channels

        # Modules are registered in the original order so state_dict keys,
        # parameter order and seeded initialisation are unchanged.
        # NOTE(review): decoder0 and upconv0 are never used in forward(). They
        # are kept so existing checkpoints load with strict=True, but they get
        # no gradients (DistributedDataParallel then needs
        # find_unused_parameters=True).
        self.decoder0 = self._make_decoder_layer(block, in_channels, in_channels)
        self.upconv4 = self._make_upsample_layer(in_channels, in_channels)
        self.upconv3 = self._make_upsample_layer(cat_channels, in_channels)
        self.upconv2 = self._make_upsample_layer(cat_channels, in_channels)
        self.upconv1 = self._make_upsample_layer(cat_channels, in_channels)
        self.upconv0 = self._make_upsample_layer(cat_channels, in_channels)
        self.final_up = self._make_upsample_layer(cat_channels, head_channels)
        self.decoder4 = self._make_decoder_layer(block, cat_channels, cat_channels)
        self.decoder3 = self._make_decoder_layer(block, cat_channels, cat_channels)
        self.decoder2 = self._make_decoder_layer(block, cat_channels, cat_channels)
        self.decoder1 = self._make_decoder_layer(block, cat_channels, cat_channels)
        self.final_conv = nn.Conv2d(head_channels, num_classes, kernel_size=1)

    def forward(self, fpn_res):
        """Decode the feature mapping ``fpn_res`` into a prediction map.

        Args:
            fpn_res: mapping with keys ``'0'``..``'4'``. Every tensor must be
                ``[B, in_channels, H_k, W_k]`` with ``H_k = 2 * H_{k+1}`` and
                ``W_k = 2 * W_{k+1}`` (level ``'0'`` is the largest), because
                each stride-2 transposed conv output is concatenated with the
                next-finer level. Any other keys are ignored.

        Returns:
            Tensor ``[B, num_classes, 2 * H_0, 2 * W_0]``.
        """
        log = logging.getLogger(__name__)
        feats = [fpn_res[str(level)] for level in range(5)]
        for level, feat in enumerate(feats):
            log.debug('e%d: %s', level, feat.shape)

        # Top-down path. Each stage: upsample 2x -> in_channels, concatenate
        # the next-finer skip -> 2 * in_channels, refine with the block.
        d = feats[4]
        stages = (
            (4, self.upconv4, self.decoder4),
            (3, self.upconv3, self.decoder3),
            (2, self.upconv2, self.decoder2),
            (1, self.upconv1, self.decoder1),
        )
        for level, upsample, decode in stages:
            d = decode(torch.cat([upsample(d), feats[level - 1]], dim=1))
            log.debug('d%d: %s', level, d.shape)

        # Output head: one more 2x upsampling, then a 1x1 projection.
        out = self.final_conv(self.final_up(d))
        log.debug('out: %s', out.shape)
        return out

    def _make_upsample_layer(self, in_channels: int, out_channels: int) -> nn.Module:
        """Return a transposed convolution that upsamples exactly 2x.

        With ``kernel_size=2, stride=2`` and no padding the output size is
        ``2 * input`` in each spatial dimension.
        """
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    def _make_decoder_layer(
        self,
        block: "Type[Bottleneck]",
        in_channels: int,
        out_channels: int,
        blocks: int = 1,
    ) -> nn.Sequential:
        """Stack ``blocks`` instances of ``block``, each built on ``in_channels``.

        Every block is ``block(in_channels, in_channels // block.expansion,
        ...)`` with this decoder's groups/base_width/dilation/norm_layer.

        NOTE(review): ``out_channels`` is accepted for signature compatibility
        but ignored. With a Bottleneck-style block the output keeps
        ``in_channels`` channels -- TODO confirm for the project block.
        """
        return nn.Sequential(*(
            block(
                in_channels,
                in_channels // block.expansion,
                groups=self.groups,
                base_width=self.base_width,
                dilation=self.dilation,
                norm_layer=self._norm_layer,
            )
            for _ in range(blocks)
        ))
if __name__ == '__main__':
    # Shape smoke test. Show DEBUG records emitted through this module's
    # logger while leaving third-party loggers at INFO.
    logging.basicConfig(level=logging.INFO, format='%(message)s')
    logging.getLogger(__name__).setLevel(logging.DEBUG)

    model = TestModel(Bottleneck)
    model.eval()
    x = torch.randn(3, 3, 512, 512)
    # Only shapes are inspected, so skip autograd bookkeeping to save memory.
    with torch.no_grad():
        out = model(x)
    print(f'output shape: {tuple(out.shape)}')
|