# arc_unet.py

  1. import torch
  2. from torch import nn
  3. from models.base.high_reso_resnet import Bottleneck, resnet101fpn
  4. from typing import Any, Callable, List, Optional, Type, Union
  5. class TestModel(nn.Sequential):
  6. def __init__(self,block: Type[Union[Bottleneck]]):
  7. super().__init__()
  8. self.encoder=resnet101fpn(out_channels=256)
  9. self.decoder=ArcUnet(block)
  10. def forward(self, x):
  11. res=self.encoder(x)
  12. for k in res.keys():
  13. print(f'k:{k}')
  14. # print(f'res:{res}')
  15. out=self.decoder(res)
  16. return out
  17. class ArcUnet(nn.Sequential):
  18. def __init__(self,block: Type[Union[Bottleneck]], in_channels=256):
  19. super().__init__()
  20. self._norm_layer = nn.BatchNorm2d
  21. self.inplanes = 64
  22. self.dilation = 1
  23. self.groups = 1
  24. self.base_width = 64
  25. self.decoder0=self._make_decoder_layer(block,256,256)
  26. self.upconv4 = self._make_upsample_layer(256 , 256)
  27. self.upconv3 = self._make_upsample_layer(512, 256)
  28. self.upconv2 = self._make_upsample_layer(512, 256)
  29. self.upconv1 = self._make_upsample_layer(512, 256)
  30. self.upconv0 = self._make_upsample_layer(512, 256)
  31. self.final_up = self._make_upsample_layer(512, 128)
  32. self.decoder4 = self._make_decoder_layer(block, 512, 256)
  33. self.decoder3 = self._make_decoder_layer(block, 512, 256)
  34. self.decoder2 = self._make_decoder_layer(block, 512, 256)
  35. self.decoder1 = self._make_decoder_layer(block, 512, 256)
  36. self.final_conv = nn.Conv2d(128, 1, kernel_size=1)
  37. def forward(self, fpn_res):
  38. # ------------------
  39. # Encoder
  40. # ------------------
  41. e0 = fpn_res['0'] # [B, 64, H/4, W/4]
  42. print(f'e0:{e0.shape}')
  43. e1 = fpn_res['1'] # [B, 256, H/4, W/4]
  44. print(f'e1:{e1.shape}')
  45. e2 = fpn_res['2'] # [B, 512, H/8, W/8]
  46. print(f'e2:{e2.shape}')
  47. e3 = fpn_res['3'] # [B, 1024, H/16, W/16]
  48. print(f'e3:{e3.shape}')
  49. e4 = fpn_res['4'] # [B, 2048, H/32, W/32]
  50. print(f'e4:{e4.shape}')
  51. # ------------------
  52. # Decoder
  53. # ------------------
  54. d4 = self.upconv4(e4) # [B, 1024, H/16, W/16]
  55. print(f'd4 = self.upconv4(e4):{d4.shape}')
  56. d4 = torch.cat([d4, e3], dim=1) # [B, 2048, H/16, W/16]
  57. print(f' d4 = torch.cat([d4, e3], dim=1):{d4.shape}')
  58. d4 = self.decoder4(d4) # [B, 2048, H/16, W/16]
  59. print(f'd4 = self.decoder4(d4):{d4.shape}')
  60. d3 = self.upconv3(d4) # [B, 512, H/8, W/8]
  61. print(f'd3 = self.upconv3(d4):{d3.shape}')
  62. d3 = torch.cat([d3, e2], dim=1) # [B, 1024, H/8, W/8]
  63. print(f'd3 = torch.cat([d3, e2], dim=1):{d3.shape}')
  64. d3 = self.decoder3(d3) # [B, 1024, H/8, W/8]
  65. print(f'd3 = self.decoder3(d3):{d3.shape}')
  66. d2 = self.upconv2(d3) # [B, 256, H/4, W/4]
  67. print(f'd2 = self.upconv2(d3):{d2.shape}')
  68. d2 = torch.cat([d2, e1], dim=1) # [B, 512, H/4, W/4]
  69. print(f'd2 = torch.cat([d2, e1], dim=1):{d2.shape}')
  70. d2 = self.decoder2(d2) # [B, 512, H/4, W/4]
  71. print(f'd2 = self.decoder2(d2):{d2.shape}')
  72. d1 = self.upconv1(d2) # [B, 64, H/2, W/2]
  73. print(f'd1 = self.upconv1(d2):{d1.shape}')
  74. d1 = torch.cat([d1, e0], dim=1) # [B, 128, H/2, W/2]
  75. print(f'd1 = torch.cat([d1, e0], dim=1):{d1.shape}')
  76. d1 = self.decoder1(d1) # [B, 128, H/2, W/2]
  77. print(f'd1 =self.decoder1(d1):{d1.shape}')
  78. # ------------------
  79. # Output Head
  80. # ------------------
  81. d0=self.final_up(d1)
  82. out = self.final_conv(d0) # [B, num_classes, H/2, W/2]
  83. print(f'out:{out.shape}')
  84. return out
  85. def _make_upsample_layer(self, in_channels: int, out_channels: int) -> nn.Module:
  86. """
  87. 使用转置卷积进行上采样
  88. """
  89. return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
  90. def _make_decoder_layer(self, block: Type[Union[Bottleneck]], in_channels: int,
  91. out_channels: int, blocks: int = 1) -> nn.Sequential:
  92. """
  93. """
  94. # assert in_channels == out_channels, "in_channels must equal out_channels"
  95. layers = []
  96. for _ in range(blocks):
  97. layers.append(
  98. block(
  99. in_channels,
  100. in_channels // block.expansion,
  101. groups=self.groups,
  102. base_width=self.base_width,
  103. dilation=self.dilation,
  104. norm_layer=self._norm_layer,
  105. )
  106. )
  107. return nn.Sequential(*layers)
  108. if __name__ == '__main__':
  109. model=TestModel(Bottleneck)
  110. x=torch.randn(3,3,512,512)
  111. out=model(x)