import torch
import torch.nn as nn
import torch.nn.functional as F


class WirePredictor(nn.Module):
    def __init__(self, in_channels=4, out_channels=1, init_features=32):
        super(WirePredictor, self).__init__()
        features = init_features

        self.encoder1 = self._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = self._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = self._block(features * 2, features * 4, name="bottleneck")

        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = self._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = self._block(features * 2, features, name="dec1")

        # Output head for the line-segment mask
        self.conv_mask = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )
        # Output head for normal vectors (2 channels: x and y components)
        self.conv_normals = nn.Conv2d(
            in_channels=features, out_channels=2, kernel_size=1
        )

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))

        bottleneck = self.bottleneck(self.pool2(enc2))

        dec2 = self.upconv2(bottleneck)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        mask = torch.sigmoid(self.conv_mask(dec1))
        normals = torch.tanh(self.conv_normals(dec1))  # Constrain components to [-1, 1]
        return mask, normals

    def _block(self, in_channels, features, name):
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=features,
                      kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(num_features=features),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=features, out_channels=features,
                      kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(num_features=features),
            nn.ReLU(inplace=True),
        )


# Quick sanity check of the model
if __name__ == "__main__":
    model = WirePredictor()
    x = torch.randn((1, 4, 128, 128))  # 128x128 input; the extra channel carries normal-vector information
    with torch.no_grad():
        output_mask, output_normals = model(x)
    print(output_mask.shape, output_normals.shape)  # Should print (1, 1, 128, 128) and (1, 2, 128, 128)
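

# A minimal sketch of how the two heads might be supervised (not part of the
# original code). It assumes hypothetical ground-truth tensors: `gt_mask` of
# shape (B, 1, H, W) with values in {0, 1} and `gt_normals` of shape
# (B, 2, H, W) holding per-pixel normal vectors; the 0.5 weight on the normal
# term is an arbitrary illustrative choice, not taken from the source.
def example_training_step(model, images, gt_mask, gt_normals, optimizer):
    model.train()
    optimizer.zero_grad()

    pred_mask, pred_normals = model(images)

    # Binary cross-entropy on the sigmoid mask output.
    mask_loss = F.binary_cross_entropy(pred_mask, gt_mask)

    # L2 loss on the normal vectors, evaluated only on wire pixels so that
    # background pixels do not dominate the regression term.
    wire_pixels = gt_mask.expand_as(gt_normals)
    normal_loss = F.mse_loss(pred_normals * wire_pixels, gt_normals * wire_pixels)

    loss = mask_loss + 0.5 * normal_loss
    loss.backward()
    optimizer.step()
    return loss.item()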