12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970 |
import torch
import torch.nn as nn
import torch.nn.functional as F
class WirePredictor(nn.Module):
    """Small U-Net style network with two prediction heads.

    The network has two encoder stages, a bottleneck and two decoder stages
    with skip connections.  Two 1x1 convolution heads are applied to the
    final decoder features:

    * a mask head (``out_channels`` channels, passed through a sigmoid), and
    * a normal-vector head (2 channels for the x and y components, passed
      through tanh).

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the mask head.
        init_features: Channel width of the first encoder stage; deeper
            stages use 2x and 4x this width.
    """

    def __init__(self, in_channels=4, out_channels=1, init_features=32):
        super().__init__()
        features = init_features

        # Encoder path: each stage is followed by a 2x2 max-pool.
        self.encoder1 = self._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = self._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = self._block(
            features * 2, features * 4, name="bottleneck"
        )

        # Decoder path: each stage upsamples by 2 and then consumes the
        # upsampled features concatenated with the matching encoder output,
        # hence the doubled input channel count of the decoder blocks.
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = self._block(
            (features * 2) * 2, features * 2, name="dec2"
        )
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = self._block(features * 2, features, name="dec1")

        # Output for line segment mask
        self.conv_mask = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )
        # Output for normal vectors (2 channels for x and y components)
        self.conv_normals = nn.Conv2d(
            in_channels=features, out_channels=2, kernel_size=1
        )

    def forward(self, x):
        """Run the network on a 4-D input tensor.

        Args:
            x: Input tensor with ``in_channels`` channels in dimension 1 and
                two trailing spatial dimensions.

        Returns:
            A tuple ``(mask, normals)``.  ``mask`` has ``out_channels``
            channels with values from a sigmoid; ``normals`` has 2 channels
            with values from tanh.  Both have the spatial size of ``x``.
        """
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        bottleneck = self.bottleneck(self.pool2(enc2))

        # Max-pooling floors odd sizes, so the upsampled tensor can be one
        # pixel smaller than the skip tensor; pad it before concatenating.
        dec2 = self._pad_to_match(self.upconv2(bottleneck), enc2)
        dec2 = self.decoder2(torch.cat((dec2, enc2), dim=1))

        dec1 = self._pad_to_match(self.upconv1(dec2), enc1)
        dec1 = self.decoder1(torch.cat((dec1, enc1), dim=1))

        mask = torch.sigmoid(self.conv_mask(dec1))
        # tanh bounds each component to [-1, 1]; the 2-vectors are not
        # renormalized to unit length.
        normals = torch.tanh(self.conv_normals(dec1))
        return mask, normals

    @staticmethod
    def _pad_to_match(upsampled, skip):
        """Zero-pad ``upsampled`` so its last two dims equal those of ``skip``.

        Returns ``upsampled`` unchanged when the sizes already agree, so
        inputs whose sides are divisible by 4 behave exactly as before.
        """
        pad_h = skip.size(2) - upsampled.size(2)
        pad_w = skip.size(3) - upsampled.size(3)
        if pad_h == 0 and pad_w == 0:
            return upsampled
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        return F.pad(
            upsampled,
            [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
        )

    def _block(self, in_channels, features, name):
        """Build a (3x3 conv -> batch norm -> ReLU) x 2 block.

        Args:
            in_channels: Channels of the block input.
            features: Channels produced by both convolutions.
            name: Label for the block.  Currently unused; kept so existing
                call sites keep working.

        Returns:
            An ``nn.Sequential`` holding the six layers.
        """
        # bias=False: each convolution is followed directly by BatchNorm2d.
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=features,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=features),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels=features,
                out_channels=features,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=features),
            nn.ReLU(inplace=True),
        )
# Smoke test for the model
if __name__ == "__main__":
    net = WirePredictor()
    # 128x128 input whose channels include the normal-vector information
    sample = torch.randn((1, 4, 128, 128))
    with torch.no_grad():
        predictions = net(sample)
    # Expected shapes: (1, 1, 128, 128) and (1, 2, 128, 128)
    print(*(tensor.shape for tensor in predictions))
|