import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ["UNetWithMultipleStacks", "unet"]


class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> BatchNorm -> ReLU) stages.

    ``padding=1`` keeps the spatial size unchanged; only the channel count
    changes, from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)


def _match_spatial(x, ref):
    """Pad (or crop) ``x`` so its last two dimensions equal those of ``ref``.

    A stride-2 ``ConvTranspose2d`` applied after a 2x2 max-pool yields
    ``2 * floor(s / 2)`` pixels, which is one short of the skip connection
    whenever ``s`` is odd. ``F.pad`` with negative amounts crops, so a
    too-large ``x`` is handled as well.
    """
    diff_h = ref.size(-2) - x.size(-2)
    diff_w = ref.size(-1) - x.size(-1)
    if diff_h == 0 and diff_w == 0:
        return x
    # F.pad order for the last two dims: (left, right, top, bottom).
    return F.pad(
        x,
        [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
    )


class UNetWithMultipleStacks(nn.Module):
    """Five-level U-Net followed by ``num_stacks`` parallel score heads.

    The decoder restores the input resolution. ``final_upsample`` then shrinks
    the map by 4x (e.g. 512 -> 128), and ``channel_adjust`` projects it to 256
    channels. Every head in ``score_layers`` maps that shared 256-channel
    feature map to ``num_classes`` channels.

    Args:
        num_classes: Output channels of each score head.
        num_stacks: Number of score heads.
        base_channels: Width of the first encoder level. Encoder level ``k``
            (1-based) uses ``base_channels * 2**(k-1)`` channels and the
            bottleneck uses ``base_channels * 32``.
        in_channels: Channels of the input tensor (3 by default).
    """

    def __init__(self, num_classes, num_stacks=2, base_channels=64, in_channels=3):
        super().__init__()
        self.num_stacks = num_stacks

        # Encoder: each level doubles the channels; each max-pool halves H and W.
        self.enc1 = DoubleConv(in_channels, base_channels)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = DoubleConv(base_channels, base_channels * 2)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = DoubleConv(base_channels * 2, base_channels * 4)
        self.pool3 = nn.MaxPool2d(2)
        self.enc4 = DoubleConv(base_channels * 4, base_channels * 8)
        self.pool4 = nn.MaxPool2d(2)
        self.enc5 = DoubleConv(base_channels * 8, base_channels * 16)
        self.pool5 = nn.MaxPool2d(2)

        # Bottleneck, at 1/32 of the input resolution.
        self.bottleneck = DoubleConv(base_channels * 16, base_channels * 32)

        # Decoder: each transpose conv doubles H and W and halves the channels.
        # Concatenating the skip connection doubles the channels back, hence
        # each DoubleConv's input width is twice its output width.
        self.upconv5 = nn.ConvTranspose2d(base_channels * 32, base_channels * 16, kernel_size=2, stride=2)
        self.dec5 = DoubleConv(base_channels * 32, base_channels * 16)
        self.upconv4 = nn.ConvTranspose2d(base_channels * 16, base_channels * 8, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(base_channels * 16, base_channels * 8)
        self.upconv3 = nn.ConvTranspose2d(base_channels * 8, base_channels * 4, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(base_channels * 8, base_channels * 4)
        self.upconv2 = nn.ConvTranspose2d(base_channels * 4, base_channels * 2, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(base_channels * 4, base_channels * 2)
        self.upconv1 = nn.ConvTranspose2d(base_channels * 2, base_channels, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(base_channels * 2, base_channels)

        # Despite its name (kept for checkpoint compatibility), this is a 4x
        # *downsample*: scale_factor=0.25 takes e.g. a 512x512 map to 128x128.
        self.final_upsample = nn.Sequential(
            nn.Conv2d(base_channels, base_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=0.25, mode='bilinear', align_corners=True),
        )

        # Score heads consume the 256-channel map produced by channel_adjust.
        self.score_layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(256, 128, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, num_classes, kernel_size=1),
            )
            for _ in range(num_stacks)
        ])
        self.channel_adjust = nn.Conv2d(base_channels, 256, kernel_size=1)

    @staticmethod
    def _up(upconv, dec, x, skip):
        """Upsample ``x``, align it to ``skip``, concatenate both, then convolve."""
        x = _match_spatial(upconv(x), skip)
        return dec(torch.cat([x, skip], dim=1))

    def forward(self, x):
        """Run the U-Net and all score heads.

        Args:
            x: Input batch with ``in_channels`` channels in dim 1. Each spatial
                side must be at least 32 so that the bottleneck is non-empty;
                sides need not be multiples of 32.

        Returns:
            A tuple ``(outputs, features)``:

            - ``outputs``: a list of ``num_stacks`` tensors, each with
              ``num_classes`` channels. The list is in *reverse* head order:
              the last head in ``score_layers`` comes first.
            - ``features``: the 256-channel map fed to every head. Its spatial
              size is a quarter of the input's (floored).
        """
        # Encoder.
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool1(enc1))
        enc3 = self.enc3(self.pool2(enc2))
        enc4 = self.enc4(self.pool3(enc3))
        enc5 = self.enc5(self.pool4(enc4))

        # Bottleneck.
        bottleneck = self.bottleneck(self.pool5(enc5))

        # Decoder with skip connections.
        dec5 = self._up(self.upconv5, self.dec5, bottleneck, enc5)
        dec4 = self._up(self.upconv4, self.dec4, dec5, enc4)
        dec3 = self._up(self.upconv3, self.dec3, dec4, enc3)
        dec2 = self._up(self.upconv2, self.dec2, dec3, enc2)
        dec1 = self._up(self.upconv1, self.dec1, dec2, enc1)

        # Shrink to 1/4 resolution, then project to the heads' 256 channels.
        dec1 = self.final_upsample(dec1)
        dec1 = self.channel_adjust(dec1)

        # Every head reads the same shared feature map.
        outputs = [score_layer(dec1) for score_layer in self.score_layers]
        return outputs[::-1], dec1


def unet(**kwargs):
    """Build a :class:`UNetWithMultipleStacks` from keyword arguments.

    ``num_classes`` is required (a ``KeyError`` is raised if it is missing).
    ``num_stacks`` (default 2), ``base_channels`` (default 64) and
    ``in_channels`` (default 3) are optional. Any other keys are ignored.
    """
    return UNetWithMultipleStacks(
        num_classes=kwargs["num_classes"],
        num_stacks=kwargs.get("num_stacks", 2),
        base_channels=kwargs.get("base_channels", 64),
        in_channels=kwargs.get("in_channels", 3),
    )