123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687 |
- import torch
- import torch.nn as nn
- import torchvision.models as models
class ResNet50Backbone(nn.Module):
    """Truncated ResNet-50 feature extractor with one score head per stack.

    Only the ResNet-50 stem (``conv1``, ``bn1``, ``relu``, ``maxpool``) and
    ``layer1`` are kept; ``layer2``-``layer4`` and the final fully connected
    layer are not part of the backbone. Every score head maps the 256-channel
    feature map produced by ``layer1`` to ``num_classes`` channels.

    Args:
        num_classes: Number of output channels produced by each score head.
        num_stacks: Number of independent score heads to build.
        pretrained: If true, load the ``IMAGENET1K_V1`` weights; otherwise the
            network is built with ``weights=None``.
    """

    @staticmethod
    def _build_score_head(num_classes):
        """Return one 3x3 conv -> BN -> ReLU -> 1x1 conv score head."""
        return nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, num_classes, kernel_size=1),
        )

    def __init__(self, num_classes=5, num_stacks=1, pretrained=True):
        super().__init__()

        # Load ResNet-50, optionally with the ImageNet weights.
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        resnet = models.resnet50(weights=weights)

        # Keep only the stem and layer1 (the fully connected layer is dropped).
        kept_stages = (
            resnet.conv1,    # resolution 1/2, channels 3 -> 64
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,  # resolution 1/4, channels stay at 64
            resnet.layer1,   # stride 1: resolution stays 1/4, channels 64 -> 256
        )
        self.backbone = nn.Sequential(*kept_stages)

        # One output head per stack.
        self.score_layers = nn.ModuleList(
            [self._build_score_head(num_classes) for _ in range(num_stacks)]
        )

        # Resampling layer; forward() does not apply it.
        # NOTE(review): scale_factor=0.25 shrinks its input rather than
        # enlarging it - confirm the intended factor before wiring it in.
        self.upsample = nn.Upsample(
            scale_factor=0.25, mode="bilinear", align_corners=True
        )

    def forward(self, x):
        """Extract backbone features and apply every score head to them.

        Args:
            x: Input batch fed to the backbone.

        Returns:
            A tuple ``(outputs, features)``: ``outputs`` is a list holding one
            tensor per score head, and ``features`` is the backbone feature
            map that the heads consumed.
        """
        features = self.backbone(x)
        outputs = [head(features) for head in self.score_layers]
        return outputs, features
def resnet50(**kwargs):
    """Build a ``ResNet50Backbone`` from optional keyword settings.

    Recognised keys and their fallbacks are ``num_classes`` (5),
    ``num_stacks`` (1) and ``pretrained`` (True). Any other key in
    ``kwargs`` is ignored.

    Returns:
        The constructed ``ResNet50Backbone`` instance.
    """
    fallbacks = {"num_classes": 5, "num_stacks": 1, "pretrained": True}
    settings = {key: kwargs.get(key, default) for key, default in fallbacks.items()}
    return ResNet50Backbone(**settings)
__all__ = ["ResNet50Backbone", "resnet50"]


if __name__ == "__main__":
    # Smoke test of the network's output shapes. Guarded so that importing
    # this module does not build a model (which may fetch pretrained weights)
    # or run a forward pass as a side effect.
    model = resnet50(num_classes=5, num_stacks=1)

    # Feed a random image batch straight into the model.
    x = torch.randn(2, 3, 512, 512)
    outputs, feature = model(x)

    print("Outputs length:", len(outputs))
    print("Output[0] shape:", outputs[0].shape)
    print("Feature shape:", feature.shape)
|