# high_reso_resnet.py

import torch
import torch.nn as nn
from torch import Tensor
from typing import Callable, Dict, List, Optional, Type

from torchvision.models.detection.backbone_utils import BackboneWithFPN

# ----------------------------
# Utility functions
# ----------------------------
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

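# A quick shape check for the helpers above (a minimal sketch; run by hand).
# Because padding equals dilation, conv3x3 preserves H and W whenever stride == 1.
def _demo_convs() -> None:
    x = torch.randn(1, 64, 56, 56)
    assert conv3x3(64, 64)(x).shape == (1, 64, 56, 56)              # stride 1: same size
    assert conv3x3(64, 128, stride=2)(x).shape == (1, 128, 28, 28)  # stride 2: halved
    assert conv1x1(64, 256)(x).shape == (1, 256, 56, 56)            # channel projection only
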
# ----------------------------
# Bottleneck block (as provided)
# ----------------------------
class Bottleneck(nn.Module):
    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

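# A minimal shape check for the block above (a sketch; the expected sizes
# assume the default expansion of 4 and a stride-2 downsample branch).
def _demo_bottleneck() -> None:
    downsample = nn.Sequential(
        conv1x1(64, 64 * Bottleneck.expansion, stride=2),
        nn.BatchNorm2d(64 * Bottleneck.expansion),
    )
    block = Bottleneck(64, 64, stride=2, downsample=downsample)
    y = block(torch.randn(1, 64, 56, 56))
    assert y.shape == (1, 256, 28, 28)  # channels x4, spatial halved
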
# ----------------------------
# ResNet-FPN backbone factories
# ----------------------------
def _resnet_fpn(layers: List[int], out_channels: int = 256) -> BackboneWithFPN:
    """Wrap a ResNet encoder with a Feature Pyramid Network.

    All variants below use Bottleneck blocks (expansion 4), so the encoder
    channel counts are [64, 256, 512, 1024, 2048].
    """
    backbone = ResNet(Bottleneck, layers)
    return_layers = {
        'encoder0': '0',
        'encoder1': '1',
        'encoder2': '2',
        'encoder3': '3',
        'encoder4': '4',
    }
    in_channels_list = [64, 256, 512, 1024, 2048]
    return BackboneWithFPN(
        backbone,
        return_layers=return_layers,
        in_channels_list=in_channels_list,
        out_channels=out_channels,
    )


def resnet18fpn(out_channels: int = 256) -> BackboneWithFPN:
    # Note: uses Bottleneck blocks with the [2, 2, 2, 2] layout, not the
    # BasicBlock of the canonical ResNet-18.
    return _resnet_fpn([2, 2, 2, 2], out_channels)


def resnet50fpn(out_channels: int = 256) -> BackboneWithFPN:
    return _resnet_fpn([3, 4, 6, 3], out_channels)


def resnet101fpn(out_channels: int = 256) -> BackboneWithFPN:
    return _resnet_fpn([3, 4, 23, 3], out_channels)


def resnet152fpn(out_channels: int = 256) -> BackboneWithFPN:
    return _resnet_fpn([3, 8, 36, 3], out_channels)

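# Sketch: plugging the FPN backbone into torchvision's FasterRCNN. Hedged:
# the anchor sizes and num_classes below are illustrative assumptions, not
# tuned values. BackboneWithFPN exposes `.out_channels`, which FasterRCNN
# requires; six anchor size tuples are needed because the FPN emits six maps
# ('0'..'4' plus 'pool').
def _demo_detector(num_classes: int = 5):
    from torchvision.models.detection import FasterRCNN
    from torchvision.models.detection.anchor_utils import AnchorGenerator

    backbone = resnet50fpn()
    anchor_gen = AnchorGenerator(
        sizes=tuple((s,) for s in (16, 32, 64, 128, 256, 512)),
        aspect_ratios=((0.5, 1.0, 2.0),) * 6,
    )
    return FasterRCNN(backbone, num_classes=num_classes,
                      rpn_anchor_generator=anchor_gen)
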
# ----------------------------
# ResNet main class
# ----------------------------
class ResNet(nn.Module):
    def __init__(self, block: Type[Bottleneck], layers: List[int]) -> None:
        super().__init__()
        self._norm_layer = nn.BatchNorm2d
        self.inplanes = 64
        self.dilation = 1
        self.groups = 1
        self.base_width = 64
        # High-resolution stem: a single 3x3 stride-1 conv instead of the
        # canonical 7x7 stride-2 conv, so only the max-pool halves the input.
        self.encoder0 = nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1, stride=1, bias=False),
            self._norm_layer(self.inplanes),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # Alternative stem (three 3x3 convs, no downsampling), kept for reference:
        # self.encoder0 = nn.Sequential(
        #     nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1, bias=False),
        #     self._norm_layer(self.inplanes),
        #     nn.ReLU(inplace=True),
        #     nn.Conv2d(self.inplanes, self.inplanes, kernel_size=3, padding=1, bias=False),
        #     self._norm_layer(self.inplanes),
        #     nn.ReLU(inplace=True),
        #     nn.Conv2d(self.inplanes, self.inplanes, kernel_size=3, padding=1, bias=False),
        #     self._norm_layer(self.inplanes),
        #     nn.ReLU(inplace=True),
        #     nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
        # )
        # Unlike torchvision's layer1, encoder1 also downsamples (stride 2).
        self.encoder1 = self._make_layer(block, 64, layers[0], stride=2)
        self.encoder2 = self._make_layer(block, 128, layers[1], stride=2)
        self.encoder3 = self._make_layer(block, 256, layers[2], stride=2)
        self.encoder4 = self._make_layer(block, 512, layers[3], stride=2)
    def _make_layer(self, block: Type[Bottleneck], planes: int, blocks: int,
                    stride: int = 1, dilate: bool = False) -> nn.Sequential:
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )
        layers = []
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width,
                previous_dilation, norm_layer
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )
        return nn.Sequential(*layers)
    def _make_decoder_layer(self, block: Type[Bottleneck], in_channels: int,
                            out_channels: int, blocks: int = 1) -> nn.Sequential:
        """Build residual blocks for the decoder path.

        Not used by the FPN path in this file; kept for decoder experiments.
        """
        assert in_channels == out_channels, "in_channels must equal out_channels"
        layers = []
        for _ in range(blocks):
            layers.append(
                block(
                    in_channels,
                    in_channels // block.expansion,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=self._norm_layer,
                )
            )
        return nn.Sequential(*layers)
    def _make_upsample_layer(self, in_channels: int, out_channels: int) -> nn.Module:
        """Upsample with a transposed convolution (doubles H and W)."""
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
    def _forward_impl(self, x: Tensor) -> Dict[str, Tensor]:
        x0 = self.encoder0(x)
        print(f'x0: {x0.shape}')
        x1 = self.encoder1(x0)
        print(f'x1: {x1.shape}')
        x2 = self.encoder2(x1)
        print(f'x2: {x2.shape}')
        x3 = self.encoder3(x2)
        print(f'x3: {x3.shape}')
        x4 = self.encoder4(x3)
        print(f'x4: {x4.shape}')
        return {
            'encoder0': x0,
            'encoder1': x1,
            'encoder2': x2,
            'encoder3': x3,
            'encoder4': x4,
        }

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        return self._forward_impl(x)

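# Minimal sketches exercising the helpers above (assumptions: default
# Bottleneck expansion of 4; these are hand-run sanity checks, not part of
# the FPN pipeline).
def _demo_decoder_and_upsample() -> None:
    model = ResNet(Bottleneck, [2, 2, 2, 2])
    # After building encoder1..encoder4, inplanes has advanced
    # 64 -> 256 -> 512 -> 1024 -> 2048 (the in_channels_list used above).
    assert model.inplanes == 512 * Bottleneck.expansion
    # _make_decoder_layer preserves channels: planes = in_channels // expansion,
    # and the block re-expands back to in_channels.
    decoder = model._make_decoder_layer(Bottleneck, 256, 256, blocks=2)
    assert decoder(torch.randn(1, 256, 32, 32)).shape == (1, 256, 32, 32)
    # The kernel-2/stride-2 transposed conv exactly doubles H and W.
    up = model._make_upsample_layer(256, 128)
    assert up(torch.randn(1, 256, 32, 32)).shape == (1, 128, 64, 64)
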
# ----------------------------
# Test code
# ----------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model = resnet50fpn().to(device)
    model = resnet18fpn().to(device)
    input_tensor = torch.randn(1, 3, 512, 512).to(device)
    output_tensor = model(input_tensor)
    # To inspect the raw (pre-FPN) encoder outputs instead:
    # backbone = ResNet(Bottleneck, [3, 4, 6, 3]).to(device)
    # features = backbone(input_tensor)
    # print("Raw backbone output:", list(features.keys()))
    print(f"Input shape: {input_tensor.shape}")
    print(f"feat_names: {list(output_tensor.keys())}")
    print(f"Output shape 0: {output_tensor['0'].shape}")
    print(f"Output shape 1: {output_tensor['1'].shape}")
    print(f"Output shape 2: {output_tensor['2'].shape}")
    print(f"Output shape 3: {output_tensor['3'].shape}")
    print(f"Output shape 4: {output_tensor['4'].shape}")
    print(f"Output shape pool: {output_tensor['pool'].shape}")