# resnet50_pose.py: truncated ResNet-50 backbone with per-stack score heads.
import torch
import torch.nn as nn
import torchvision.models as models
  4. class ResNet50Backbone(nn.Module):
  5. def __init__(self, num_classes=5, num_stacks=1, pretrained=True):
  6. super(ResNet50Backbone, self).__init__()
  7. # 加载预训练的ResNet50
  8. if pretrained:
  9. resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
  10. else:
  11. resnet = models.resnet50(weights=None)
  12. # 移除最后的全连接层
  13. self.backbone = nn.Sequential(
  14. resnet.conv1,#特征图分辨率降低为1/2,通道数从3升为64
  15. resnet.bn1,
  16. resnet.relu,
  17. resnet.maxpool,#特征图分辨率降低为1/4,通道数仍然为64
  18. resnet.layer1,#stride为1,不改变分辨率,依然为1/4,通道数从64升为64*4=256
  19. # resnet.layer2,#stride为2,特征图分辨率降低为1/8,通道数从256升为128*4=512
  20. # resnet.layer3,#stride为2,特征图分辨率降低为1/16,通道数从512升为256*4=1024
  21. # resnet.layer4,#stride为2,特征图分辨率降低为1/32,通道数从512升为256*4=2048
  22. )
  23. # 多任务输出层
  24. self.score_layers = nn.ModuleList([
  25. nn.Sequential(
  26. nn.Conv2d(256, 128, kernel_size=3, padding=1),
  27. nn.BatchNorm2d(128),
  28. nn.ReLU(inplace=True),
  29. nn.Conv2d(128, num_classes, kernel_size=1)
  30. )
  31. for _ in range(num_stacks)
  32. ])
  33. # 上采样层,确保输出大小为128x128
  34. self.upsample = nn.Upsample(
  35. scale_factor=0.25,
  36. mode='bilinear',
  37. align_corners=True
  38. )
  39. def forward(self, x):
  40. # 主干网络特征提取
  41. x = self.backbone(x)
  42. # # 调整通道数
  43. # x = self.channel_adjust(x)
  44. #
  45. # # 上采样到128x128
  46. # x = self.upsample(x)
  47. # 多堆栈输出
  48. outputs = []
  49. for score_layer in self.score_layers:
  50. output = score_layer(x)
  51. outputs.append(output)
  52. # 返回第一个输出(如果有多个堆栈)
  53. return outputs, x
  54. def resnet50(**kwargs):
  55. model = ResNet50Backbone(
  56. num_classes=kwargs.get("num_classes", 5),
  57. num_stacks=kwargs.get("num_stacks", 1),
  58. pretrained=kwargs.get("pretrained", True)
  59. )
  60. return model
  61. __all__ = ["ResNet50Backbone", "resnet50"]
  62. # 测试网络输出
  63. model = resnet50(num_classes=5, num_stacks=1)
  64. # 方法1:直接传入图像张量
  65. x = torch.randn(2, 3, 512, 512)
  66. outputs, feature = model(x)
  67. print("Outputs length:", len(outputs))
  68. print("Output[0] shape:", outputs[0].shape)
  69. print("Feature shape:", feature.shape)