# high_reso_resnet.py

import torch
import torch.nn as nn
from torch import Tensor
from typing import Callable, Dict, List, Optional, Type

from torchvision.models.detection.backbone_utils import BackboneWithFPN


# ----------------------------
# Utility functions
# ----------------------------
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


# ----------------------------
# Bottleneck block (as provided)
# ----------------------------
class Bottleneck(nn.Module):
    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
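

# Example (sketch, not part of the original file): Bottleneck(planes=64)
# expands its output to planes * expansion = 256 channels, so the identity
# path needs a 1x1 projection whenever inplanes != 256 or stride != 1.
def _bottleneck_shape_demo() -> None:
    block = Bottleneck(
        64, 64,
        downsample=nn.Sequential(conv1x1(64, 256), nn.BatchNorm2d(256)),
    )
    y = block(torch.randn(1, 64, 32, 32))
    assert y.shape == (1, 64 * Bottleneck.expansion, 32, 32)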


# ----------------------------
# FPN backbone factories
# ----------------------------
def resnet18fpn(out_channels: int = 256) -> BackboneWithFPN:
    # NOTE: despite the name, this uses Bottleneck blocks with depths [2, 2, 2],
    # not the BasicBlock layout of the standard ResNet-18.
    backbone = ResNet(Bottleneck, [2, 2, 2])
    return_layers = {
        'encoder0': '0',
        'encoder1': '1',
        'encoder2': '2',
        'encoder3': '3',
        # 'encoder4': '5'
    }
    # in_channels_list = [64, 256, 512, 1024, 2048]  # with a fifth encoder stage
    in_channels_list = [64, 256, 512, 1024]
    return BackboneWithFPN(
        backbone,
        return_layers=return_layers,
        in_channels_list=in_channels_list,
        out_channels=out_channels,
    )


def resnet50fpn(out_channels: int = 256) -> BackboneWithFPN:
    backbone = ResNet(Bottleneck, [3, 4, 6])
    return_layers = {
        'encoder0': '0',
        'encoder1': '1',
        'encoder2': '2',
        'encoder3': '3',
        # 'encoder4': '5'
    }
    # in_channels_list = [64, 256, 512, 1024, 2048]  # with a fifth encoder stage
    in_channels_list = [64, 256, 512, 1024]
    return BackboneWithFPN(
        backbone,
        return_layers=return_layers,
        in_channels_list=in_channels_list,
        out_channels=out_channels,
    )
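

# Example (sketch, assumption: torchvision's detection API): BackboneWithFPN
# exposes an `out_channels` attribute, so the factory output can be plugged
# straight into a torchvision detector. `num_classes=2` and the 512px sizes
# are placeholders, not values from the original file.
def _detector_demo() -> None:
    from torchvision.models.detection import FasterRCNN

    detector = FasterRCNN(resnet50fpn(), num_classes=2, min_size=512, max_size=512)
    detector.eval()
    with torch.no_grad():
        predictions = detector([torch.randn(3, 512, 512)])
    print(predictions[0].keys())  # boxes, labels, scores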


# ----------------------------
# ResNet main class
# ----------------------------
class ResNet(nn.Module):
    def __init__(self, block: Type[Bottleneck], layers: List[int]) -> None:
        super().__init__()
        self._norm_layer = nn.BatchNorm2d
        self.inplanes = 64
        self.dilation = 1
        self.groups = 1
        self.base_width = 64
        # High-resolution stem: unlike the standard ResNet stem (stride-2 7x7
        # conv plus stride-2 max pool), this keeps the full input resolution.
        self.encoder0 = nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1, bias=False),
            self._norm_layer(self.inplanes),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
        )
        # Each encoder stage halves the spatial resolution (stride=2).
        self.encoder1 = self._make_layer(block, 64, layers[0], stride=2)
        self.encoder2 = self._make_layer(block, 128, layers[1], stride=2)
        self.encoder3 = self._make_layer(block, 256, layers[2], stride=2)
    def _make_layer(self, block: Type[Bottleneck], planes: int, blocks: int,
                    stride: int = 1, dilate: bool = False) -> nn.Sequential:
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        # Project the identity path when the first block changes resolution or width.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )
        layers = []
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width,
                previous_dilation, norm_layer,
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )
        return nn.Sequential(*layers)
    def _make_decoder_layer(self, block: Type[Bottleneck], in_channels: int,
                            out_channels: int, blocks: int = 1) -> nn.Sequential:
        """Build the residual blocks for the decoder path."""
        assert in_channels == out_channels, "in_channels must equal out_channels"
        layers = []
        for _ in range(blocks):
            layers.append(
                block(
                    in_channels,
                    in_channels // block.expansion,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=self._norm_layer,
                )
            )
        return nn.Sequential(*layers)
    def _make_upsample_layer(self, in_channels: int, out_channels: int) -> nn.Module:
        """Upsample 2x with a transposed convolution."""
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
    def _forward_impl(self, x: Tensor) -> Dict[str, Tensor]:
        x0 = self.encoder0(x)
        print(f'x0:{x0.shape}')  # debug shape prints
        x1 = self.encoder1(x0)
        print(f'x1:{x1.shape}')
        x2 = self.encoder2(x1)
        print(f'x2:{x2.shape}')
        x3 = self.encoder3(x2)
        print(f'x3:{x3.shape}')
        # x4 = self.encoder4(x3)
        # print(f'x4:{x4.shape}')
        out = {
            'encoder0': x0,
            'encoder1': x1,
            'encoder2': x2,
            'encoder3': x3,
            # 'encoder4': x4,
        }
        return out

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        return self._forward_impl(x)
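

# Example (sketch): the decoder helpers above are not used by the FPN
# factories; one way to chain them is a 2x upsample followed by a residual
# refinement block. The channel sizes match the encoder3 output (1024).
def _decoder_demo() -> None:
    net = ResNet(Bottleneck, [3, 4, 6])
    up = net._make_upsample_layer(1024, 512)  # (N, 1024, H, W) -> (N, 512, 2H, 2W)
    dec = net._make_decoder_layer(Bottleneck, 512, 512)
    y = dec(up(torch.randn(1, 1024, 64, 64)))
    assert y.shape == (1, 512, 128, 128)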


# ----------------------------
# Test code
# ----------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = resnet50fpn().to(device)
    input_tensor = torch.randn(1, 3, 512, 512).to(device)
    output_tensor = model(input_tensor)

    backbone = ResNet(Bottleneck, [3, 4, 6]).to(device)
    features = backbone(input_tensor)
    print("Raw backbone output:", list(features.keys()))

    print(f"Input shape: {input_tensor.shape}")
    print(f"feat_names: {list(output_tensor.keys())}")
    print(f"Output shape 0: {output_tensor['0'].shape}")
    print(f"Output shape 1: {output_tensor['1'].shape}")
    print(f"Output shape 2: {output_tensor['2'].shape}")
    print(f"Output shape 3: {output_tensor['3'].shape}")
    # print(f"Output shape 4: {output_tensor['5'].shape}")
    print(f"Output shape pool: {output_tensor['pool'].shape}")