# unet.py — U-Net encoder/decoder with multiple stacked score heads.
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
# Public API of this module: the model class and its factory function.
__all__ = ["UNetWithMultipleStacks", "unet"]
  5. class DoubleConv(nn.Module):
  6. def __init__(self, in_channels, out_channels):
  7. super().__init__()
  8. self.double_conv = nn.Sequential(
  9. nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
  10. nn.BatchNorm2d(out_channels),
  11. nn.ReLU(inplace=True),
  12. nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
  13. nn.BatchNorm2d(out_channels),
  14. nn.ReLU(inplace=True)
  15. )
  16. def forward(self, x):
  17. return self.double_conv(x)
  18. class UNetWithMultipleStacks(nn.Module):
  19. def __init__(self, num_classes, num_stacks=2, base_channels=64):
  20. super().__init__()
  21. self.num_stacks = num_stacks
  22. # 编码器
  23. self.enc1 = DoubleConv(3, base_channels)
  24. self.pool1 = nn.MaxPool2d(2)
  25. self.enc2 = DoubleConv(base_channels, base_channels * 2)
  26. self.pool2 = nn.MaxPool2d(2)
  27. self.enc3 = DoubleConv(base_channels * 2, base_channels * 4)
  28. self.pool3 = nn.MaxPool2d(2)
  29. self.enc4 = DoubleConv(base_channels * 4, base_channels * 8)
  30. self.pool4 = nn.MaxPool2d(2)
  31. self.enc5 = DoubleConv(base_channels * 8, base_channels * 16)
  32. self.pool5 = nn.MaxPool2d(2)
  33. # bottleneck
  34. self.bottleneck = DoubleConv(base_channels * 16, base_channels * 32)
  35. # 解码器
  36. self.upconv5 = nn.ConvTranspose2d(base_channels * 32, base_channels * 16, kernel_size=2, stride=2)
  37. self.dec5 = DoubleConv(base_channels * 32, base_channels * 16)
  38. self.upconv4 = nn.ConvTranspose2d(base_channels * 16, base_channels * 8, kernel_size=2, stride=2)
  39. self.dec4 = DoubleConv(base_channels * 16, base_channels * 8)
  40. self.upconv3 = nn.ConvTranspose2d(base_channels * 8, base_channels * 4, kernel_size=2, stride=2)
  41. self.dec3 = DoubleConv(base_channels * 8, base_channels * 4)
  42. self.upconv2 = nn.ConvTranspose2d(base_channels * 4, base_channels * 2, kernel_size=2, stride=2)
  43. self.dec2 = DoubleConv(base_channels * 4, base_channels * 2)
  44. self.upconv1 = nn.ConvTranspose2d(base_channels * 2, base_channels, kernel_size=2, stride=2)
  45. self.dec1 = DoubleConv(base_channels * 2, base_channels)
  46. # 额外的上采样层,从512降到128
  47. self.final_upsample = nn.Sequential(
  48. nn.Conv2d(base_channels, base_channels, kernel_size=3, padding=1),
  49. nn.BatchNorm2d(base_channels),
  50. nn.ReLU(inplace=True),
  51. nn.Upsample(scale_factor=0.25, mode='bilinear', align_corners=True)
  52. )
  53. # 修改score_layers以匹配256通道
  54. self.score_layers = nn.ModuleList([
  55. nn.Sequential(
  56. nn.Conv2d(256, 128, kernel_size=3, padding=1),
  57. nn.ReLU(inplace=True),
  58. nn.Conv2d(128, num_classes, kernel_size=1)
  59. )
  60. for _ in range(num_stacks)
  61. ])
  62. self.channel_adjust = nn.Conv2d(base_channels, 256, kernel_size=1)
  63. def forward(self, x):
  64. # 编码过程
  65. enc1 = self.enc1(x)
  66. enc2 = self.enc2(self.pool1(enc1))
  67. enc3 = self.enc3(self.pool2(enc2))
  68. enc4 = self.enc4(self.pool3(enc3))
  69. enc5 = self.enc5(self.pool4(enc4))
  70. # 瓶颈层
  71. bottleneck = self.bottleneck(self.pool5(enc5))
  72. # 解码过程
  73. dec5 = self.upconv5(bottleneck)
  74. dec5 = torch.cat([dec5, enc5], dim=1)
  75. dec5 = self.dec5(dec5)
  76. dec4 = self.upconv4(dec5)
  77. dec4 = torch.cat([dec4, enc4], dim=1)
  78. dec4 = self.dec4(dec4)
  79. dec3 = self.upconv3(dec4)
  80. dec3 = torch.cat([dec3, enc3], dim=1)
  81. dec3 = self.dec3(dec3)
  82. dec2 = self.upconv2(dec3)
  83. dec2 = torch.cat([dec2, enc2], dim=1)
  84. dec2 = self.dec2(dec2)
  85. dec1 = self.upconv1(dec2)
  86. dec1 = torch.cat([dec1, enc1], dim=1)
  87. dec1 = self.dec1(dec1)
  88. # 额外的上采样,使输出大小为128
  89. dec1 = self.final_upsample(dec1)
  90. # 调整通道数
  91. dec1 = self.channel_adjust(dec1)
  92. # 多堆栈输出
  93. outputs = []
  94. for score_layer in self.score_layers:
  95. output = score_layer(dec1)
  96. outputs.append(output)
  97. return outputs[::-1], dec1
  98. def unet(**kwargs):
  99. model = UNetWithMultipleStacks(
  100. num_classes=kwargs["num_classes"],
  101. num_stacks=kwargs.get("num_stacks", 2),
  102. base_channels=kwargs.get("base_channels", 64)
  103. )
  104. return model