import torch
import torch.nn as nn

__all__ = ["UNetWithMultipleStacks", "unet"]


class DoubleConv(nn.Module):
    """Two 3x3 Conv-BN-ReLU blocks; preserves spatial size (padding=1)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class UNetWithMultipleStacks(nn.Module):
    def __init__(self, num_classes, num_stacks=2, base_channels=64):
        super().__init__()
        self.num_stacks = num_stacks
        # Encoder
        self.enc1 = DoubleConv(3, base_channels)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = DoubleConv(base_channels, base_channels * 2)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = DoubleConv(base_channels * 2, base_channels * 4)
        self.pool3 = nn.MaxPool2d(2)
        self.enc4 = DoubleConv(base_channels * 4, base_channels * 8)
        self.pool4 = nn.MaxPool2d(2)
        self.enc5 = DoubleConv(base_channels * 8, base_channels * 16)
        self.pool5 = nn.MaxPool2d(2)
        # Bottleneck
        self.bottleneck = DoubleConv(base_channels * 16, base_channels * 32)
        # Decoder; each DoubleConv consumes upsampled features concatenated with the skip connection
        self.upconv5 = nn.ConvTranspose2d(base_channels * 32, base_channels * 16, kernel_size=2, stride=2)
        self.dec5 = DoubleConv(base_channels * 32, base_channels * 16)
        self.upconv4 = nn.ConvTranspose2d(base_channels * 16, base_channels * 8, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(base_channels * 16, base_channels * 8)
        self.upconv3 = nn.ConvTranspose2d(base_channels * 8, base_channels * 4, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(base_channels * 8, base_channels * 4)
        self.upconv2 = nn.ConvTranspose2d(base_channels * 4, base_channels * 2, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(base_channels * 4, base_channels * 2)
        self.upconv1 = nn.ConvTranspose2d(base_channels * 2, base_channels, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(base_channels * 2, base_channels)
        # Extra resampling stage: despite the name, scale_factor=0.25 downsamples
        # the full-resolution decoder output (e.g. 512 -> 128)
        self.final_upsample = nn.Sequential(
            nn.Conv2d(base_channels, base_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=0.25, mode='bilinear', align_corners=True)
        )
        # Score heads expect 256 input channels (see channel_adjust below)
        self.score_layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(256, 128, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, num_classes, kernel_size=1)
            )
            for _ in range(num_stacks)
        ])
        self.channel_adjust = nn.Conv2d(base_channels, 256, kernel_size=1)

    def forward(self, x):
        # Encoder path
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool1(enc1))
        enc3 = self.enc3(self.pool2(enc2))
        enc4 = self.enc4(self.pool3(enc3))
        enc5 = self.enc5(self.pool4(enc4))
        # Bottleneck
        bottleneck = self.bottleneck(self.pool5(enc5))
        # Decoder path with skip connections
        dec5 = self.upconv5(bottleneck)
        dec5 = torch.cat([dec5, enc5], dim=1)
        dec5 = self.dec5(dec5)
        dec4 = self.upconv4(dec5)
        dec4 = torch.cat([dec4, enc4], dim=1)
        dec4 = self.dec4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat([dec3, enc3], dim=1)
        dec3 = self.dec3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat([dec2, enc2], dim=1)
        dec2 = self.dec2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat([dec1, enc1], dim=1)
        dec1 = self.dec1(dec1)
        # Extra resampling so the output spatial size is 1/4 of the input (e.g. 128)
        dec1 = self.final_upsample(dec1)
        # Adjust channel count to the 256 expected by the score heads
        dec1 = self.channel_adjust(dec1)
        # Multi-stack outputs: one score map per stack, all from the shared feature map
        outputs = []
        for score_layer in self.score_layers:
            output = score_layer(dec1)
            outputs.append(output)
        # Return score maps in reverse stack order plus the shared features
        return outputs[::-1], dec1


def unet(**kwargs):
    model = UNetWithMultipleStacks(
        num_classes=kwargs["num_classes"],
        num_stacks=kwargs.get("num_stacks", 2),
        base_channels=kwargs.get("base_channels", 64)
    )
    return model
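

# Minimal smoke test (illustrative sketch, not part of the original module;
# num_classes=16 is an arbitrary example value). With a 512x512 RGB input,
# five 2x max-pools bring the bottleneck to 16x16, the decoder restores
# 512x512, and final_upsample (scale_factor=0.25) yields 128x128 maps with
# 256 channels after channel_adjust.
if __name__ == "__main__":
    model = unet(num_classes=16, num_stacks=2)
    x = torch.randn(1, 3, 512, 512)
    scores, features = model(x)
    print([tuple(s.shape) for s in scores])  # [(1, 16, 128, 128), (1, 16, 128, 128)]
    print(tuple(features.shape))             # (1, 256, 128, 128)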