import torch
import torch.nn as nn
from torch import Tensor
from typing import Callable, Dict, Optional, Type

from torchvision.models.detection.backbone_utils import BackboneWithFPN


# ----------------------------
# Utility functions
# ----------------------------
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


# ----------------------------
# Bottleneck block (as provided)
# ----------------------------
class Bottleneck(nn.Module):
    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


# ----------------------------
# ResNet-50 backbone + FPN factory
# ----------------------------
def resnet50fpn(out_channels: int = 256) -> BackboneWithFPN:
    backbone = ResNet(Bottleneck)
    # Map the backbone's child modules to the FPN's output keys. Note that
    # BackboneWithFPN taps these children directly (via IntermediateLayerGetter),
    # so ResNet.forward is bypassed when the backbone is wrapped.
    return_layers = {
        'encoder0': '0',
        'encoder1': '1',
        'encoder2': '2',
        'encoder3': '3',
        # 'encoder4': '4',  # enable together with encoder4 below
    }
    # Output channels of encoder0..encoder3 (append 2048 if encoder4 is enabled).
    in_channels_list = [64, 256, 512, 1024]
    return BackboneWithFPN(
        backbone,
        return_layers=return_layers,
        in_channels_list=in_channels_list,
        out_channels=out_channels,
    )
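
# A hedged sketch, not part of the original module: BackboneWithFPN projects
# each tapped stage to `out_channels` with 1x1 lateral convolutions and, with
# the default LastLevelMaxPool extra block, appends one coarser 'pool' level,
# so the wrapped model returns a dict with keys '0'..'3' plus 'pool'. This
# helper just makes those shapes visible.
def describe_fpn_outputs(model: nn.Module, size: int = 512) -> None:
    """Illustrative helper: print the shape of every FPN feature map."""
    with torch.no_grad():
        feats = model(torch.randn(1, 3, size, size))
    for name, feat in feats.items():
        print(f"{name}: {tuple(feat.shape)}")
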
# ----------------------------
# ResNet backbone
# ----------------------------
class ResNet(nn.Module):
    def __init__(self, block: Type[Bottleneck]) -> None:
        super().__init__()
        self._norm_layer = nn.BatchNorm2d
        self.inplanes = 64
        self.dilation = 1
        self.groups = 1
        self.base_width = 64

        # Unlike the stock ResNet stem (7x7 conv with stride 2 followed by a
        # stride-2 max-pool), this stem keeps full resolution: both the 3x3
        # conv and the max-pool use stride 1.
        self.encoder0 = nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1, bias=False),
            self._norm_layer(self.inplanes),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
        )
        self.encoder1 = self._make_layer(block, 64, 3, stride=2)
        self.encoder2 = self._make_layer(block, 128, 4, stride=2)
        self.encoder3 = self._make_layer(block, 256, 6, stride=2)
        # self.encoder4 = self._make_layer(block, 512, 3, stride=2)

    def _make_layer(self, block: Type[Bottleneck], planes: int, blocks: int,
                    stride: int = 1, dilate: bool = False) -> nn.Sequential:
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut: match the residual branch in stride and width.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups,
                self.base_width, previous_dilation, norm_layer,
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _make_decoder_layer(self, block: Type[Bottleneck], in_channels: int,
                            out_channels: int, blocks: int = 1) -> nn.Sequential:
        """Build the residual blocks of the decoder."""
        assert in_channels == out_channels, "in_channels must equal out_channels"
        layers = []
        for _ in range(blocks):
            layers.append(
                block(
                    in_channels,
                    in_channels // block.expansion,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=self._norm_layer,
                )
            )
        return nn.Sequential(*layers)

    def _make_upsample_layer(self, in_channels: int, out_channels: int) -> nn.Module:
        """Upsample with a transposed convolution."""
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    def _forward_impl(self, x: Tensor) -> Dict[str, Tensor]:
        x0 = self.encoder0(x)
        print(f'x0:{x0.shape}')
        x1 = self.encoder1(x0)
        print(f'x1:{x1.shape}')
        x2 = self.encoder2(x1)
        print(f'x2:{x2.shape}')
        x3 = self.encoder3(x2)
        print(f'x3:{x3.shape}')
        # x4 = self.encoder4(x3)
        return {
            'encoder0': x0,
            'encoder1': x1,
            'encoder2': x2,
            'encoder3': x3,
            # 'encoder4': x4,
        }

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        return self._forward_impl(x)


# ----------------------------
# Test code
# ----------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = resnet50fpn().to(device)
    input_tensor = torch.randn(1, 3, 512, 512).to(device)
    output_tensor = model(input_tensor)

    backbone = ResNet(Bottleneck).to(device)
    features = backbone(input_tensor)
    print("Raw backbone output:", list(features.keys()))

    print(f"Input shape: {input_tensor.shape}")
    print(f"feat_names: {list(output_tensor.keys())}")
    print(f"Output shape0: {output_tensor['0'].shape}")
    print(f"Output shape1: {output_tensor['1'].shape}")
    print(f"Output shape2: {output_tensor['2'].shape}")
    print(f"Output shape3: {output_tensor['3'].shape}")
    # 'pool' is the extra level added by the default LastLevelMaxPool.
    print(f"Output shape pool: {output_tensor['pool'].shape}")
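
    # A minimal integration sketch (an assumption, not part of the original
    # script): BackboneWithFPN exposes `out_channels`, and its five feature
    # maps ('0'..'3' plus 'pool') line up with the five levels expected by
    # torchvision FasterRCNN's default anchor generator, so the backbone can
    # be plugged in directly. min_size/max_size pin the internal resize to 512;
    # note that level '0' stays at full input resolution here (the stem has
    # stride 1), so the RPN sees a very large anchor grid and is slow on CPU.
    from torchvision.models.detection import FasterRCNN

    detector = FasterRCNN(resnet50fpn(), num_classes=2, min_size=512, max_size=512).to(device)
    detector.eval()
    with torch.no_grad():
        detections = detector([input_tensor[0]])
    print(f"Detection keys: {list(detections[0].keys())}")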