high_reso_resnet.py

import torch
import torch.nn as nn
from torch import Tensor
from typing import Callable, Dict, List, Optional, Type

from torchvision.models.detection.backbone_utils import BackboneWithFPN


# ----------------------------
# Utility functions
# ----------------------------
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
# ----------------------------
# Bottleneck block
# ----------------------------
class Bottleneck(nn.Module):
    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample downsample the input when stride != 1.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
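
# A quick standalone check of the block (hypothetical usage, not part of the model):
#
#   blk = Bottleneck(64, 64, stride=2,
#                    downsample=nn.Sequential(conv1x1(64, 256, 2), nn.BatchNorm2d(256)))
#   y = blk(torch.randn(1, 64, 56, 56))   # -> torch.Size([1, 256, 28, 28])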
# ----------------------------
# FPN backbone factories
# ----------------------------
def resnet18fpn(out_channels: int = 256) -> BackboneWithFPN:
    # Note: built from Bottleneck blocks ([2, 2, 2]), so the channel widths follow
    # the 4x-expansion layout rather than the standard BasicBlock ResNet-18.
    backbone = ResNet(Bottleneck, [2, 2, 2])
    return_layers = {
        'encoder0': '0',
        'encoder1': '1',
        'encoder2': '2',
        'encoder3': '3',
        # 'encoder4': '5'
    }
    # in_channels_list = [64, 256, 512, 1024, 2048]  # with encoder4 enabled
    in_channels_list = [64, 256, 512, 1024]
    return BackboneWithFPN(
        backbone,
        return_layers=return_layers,
        in_channels_list=in_channels_list,
        out_channels=out_channels,
    )


def resnet50fpn(out_channels: int = 256) -> BackboneWithFPN:
    backbone = ResNet(Bottleneck, [3, 4, 6])
    return_layers = {
        'encoder0': '0',
        'encoder1': '1',
        'encoder2': '2',
        'encoder3': '3',
        # 'encoder4': '5'
    }
    # in_channels_list = [64, 256, 512, 1024, 2048]  # with encoder4 enabled
    in_channels_list = [64, 256, 512, 1024]
    return BackboneWithFPN(
        backbone,
        return_layers=return_layers,
        in_channels_list=in_channels_list,
        out_channels=out_channels,
    )
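
# A minimal sketch of plugging the FPN backbone into a torchvision detector;
# FasterRCNN reads the `out_channels` attribute that BackboneWithFPN sets, and its
# default anchor generator expects five feature maps ('0'..'3' plus 'pool').
# num_classes=2 is an arbitrary example value:
#
#   from torchvision.models.detection import FasterRCNN
#   detector = FasterRCNN(resnet50fpn(), num_classes=2)
#   detector.eval()
#   predictions = detector([torch.randn(3, 512, 512)])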
# ----------------------------
# ResNet main class
# ----------------------------
class ResNet(nn.Module):
    def __init__(self, block: Type[Bottleneck], layers: List[int]) -> None:
        super().__init__()
        self._norm_layer = nn.BatchNorm2d
        self.inplanes = 64
        self.dilation = 1
        self.groups = 1
        self.base_width = 64
        # Stride-1 stem: unlike the standard ResNet stem (7x7 stride-2 conv plus
        # stride-2 pool), this keeps the full input resolution (the "high-reso" design).
        self.encoder0 = nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1, bias=False),
            self._norm_layer(self.inplanes),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
        )
        self.encoder1 = self._make_layer(block, 64, layers[0], stride=2)
        self.encoder2 = self._make_layer(block, 128, layers[1], stride=2)
        self.encoder3 = self._make_layer(block, 256, layers[2], stride=2)
        # self.encoder4 = self._make_layer(block, 512, 3, stride=2)
        # self.encoder5 = self._make_layer(block, 512, 3, stride=2)
        # self.body = nn.ModuleDict({
        #     'encoder0': self.encoder0,
        #     'encoder1': self.encoder1,
        #     'encoder2': self.encoder2,
        #     'encoder3': self.encoder3,
        #     'encoder4': self.encoder4,
        # })
        # self.fpn = self.get_convnext_fpn(
        #     backbone=self.body,
        #     trainable_layers=5,
        #     returned_layers=[0, 1, 2, 3, 4],
        #     extra_blocks=None,
        #     norm_layer=None,
        # )
    def _make_layer(self, block: Type[Bottleneck], planes: int, blocks: int,
                    stride: int = 1, dilate: bool = False) -> nn.Sequential:
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )
        layers = []
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width,
                previous_dilation, norm_layer,
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )
        return nn.Sequential(*layers)
    def _make_decoder_layer(self, block: Type[Bottleneck], in_channels: int,
                            out_channels: int, blocks: int = 1) -> nn.Sequential:
        """Build the residual blocks for the decoder."""
        assert in_channels == out_channels, "in_channels must equal out_channels"
        layers = []
        for _ in range(blocks):
            layers.append(
                block(
                    in_channels,
                    in_channels // block.expansion,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=self._norm_layer,
                )
            )
        return nn.Sequential(*layers)

    def _make_upsample_layer(self, in_channels: int, out_channels: int) -> nn.Module:
        """Upsample with a transposed convolution."""
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
    def _forward_impl(self, x: Tensor) -> Dict[str, Tensor]:
        # out = self.fpn(x)
        x0 = self.encoder0(x)
        print(f'x0: {x0.shape}')  # debug: intermediate feature shapes
        x1 = self.encoder1(x0)
        print(f'x1: {x1.shape}')
        x2 = self.encoder2(x1)
        print(f'x2: {x2.shape}')
        x3 = self.encoder3(x2)
        print(f'x3: {x3.shape}')
        # x4 = self.encoder4(x3)
        # print(f'x4: {x4.shape}')
        out = {
            'encoder0': x0,
            'encoder1': x1,
            'encoder2': x2,
            'encoder3': x3,
            # 'encoder4': x4,
        }
        return out

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        return self._forward_impl(x)
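
# The decoder helpers above (_make_decoder_layer, _make_upsample_layer) are not
# wired into forward(); a minimal sketch of how one decoder stage might compose
# them (hypothetical usage, not part of the original model):
#
#   net = ResNet(Bottleneck, [3, 4, 6])
#   up = net._make_upsample_layer(1024, 1024)                # 2x spatial upsample
#   refine = net._make_decoder_layer(Bottleneck, 1024, 1024)
#   y = refine(up(x3))   # x3: [N, 1024, H, W] -> y: [N, 1024, 2H, 2W]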
# ----------------------------
# Test code
# ----------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = resnet50fpn().to(device)
    input_tensor = torch.randn(1, 3, 512, 512).to(device)
    output_tensor = model(input_tensor)

    # The raw backbone (no FPN) returns the encoder feature dict directly.
    backbone = ResNet(Bottleneck, [3, 4, 6]).to(device)
    features = backbone(input_tensor)
    print("Raw backbone output:", list(features.keys()))

    print(f"Input shape: {input_tensor.shape}")
    print(f"feat_names: {list(output_tensor.keys())}")
    print(f"Output shape 0: {output_tensor['0'].shape}")
    print(f"Output shape 1: {output_tensor['1'].shape}")
    print(f"Output shape 2: {output_tensor['2'].shape}")
    print(f"Output shape 3: {output_tensor['3'].shape}")
    # print(f"Output shape 4: {output_tensor['5'].shape}")
    print(f"Output shape pool: {output_tensor['pool'].shape}")