# WirePredictor.py (2.6 KB)
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class WirePredictor(nn.Module):
  5. def __init__(self, in_channels=4, out_channels=1, init_features=32):
  6. super(WirePredictor, self).__init__()
  7. features = init_features
  8. self.encoder1 = self._block(in_channels, features, name="enc1")
  9. self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
  10. self.encoder2 = self._block(features, features * 2, name="enc2")
  11. self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
  12. self.bottleneck = self._block(features * 2, features * 4, name="bottleneck")
  13. self.upconv2 = nn.ConvTranspose2d(
  14. features * 4, features * 2, kernel_size=2, stride=2
  15. )
  16. self.decoder2 = self._block((features * 2) * 2, features * 2, name="dec2")
  17. self.upconv1 = nn.ConvTranspose2d(
  18. features * 2, features, kernel_size=2, stride=2
  19. )
  20. self.decoder1 = self._block(features * 2, features, name="dec1")
  21. # Output for line segment mask
  22. self.conv_mask = nn.Conv2d(
  23. in_channels=features, out_channels=out_channels, kernel_size=1
  24. )
  25. # Output for normal vectors (2 channels for x and y components)
  26. self.conv_normals = nn.Conv2d(
  27. in_channels=features, out_channels=2, kernel_size=1
  28. )
  29. def forward(self, x):
  30. enc1 = self.encoder1(x)
  31. enc2 = self.encoder2(self.pool1(enc1))
  32. bottleneck = self.bottleneck(self.pool2(enc2))
  33. dec2 = self.upconv2(bottleneck)
  34. dec2 = torch.cat((dec2, enc2), dim=1)
  35. dec2 = self.decoder2(dec2)
  36. dec1 = self.upconv1(dec2)
  37. dec1 = torch.cat((dec1, enc1), dim=1)
  38. dec1 = self.decoder1(dec1)
  39. mask = torch.sigmoid(self.conv_mask(dec1))
  40. normals = torch.tanh(self.conv_normals(dec1)) # Normalize to [-1, 1]
  41. return mask, normals
  42. def _block(self, in_channels, features, name):
  43. return nn.Sequential(
  44. nn.Conv2d(in_channels=in_channels, out_channels=features, kernel_size=3, padding=1, bias=False),
  45. nn.BatchNorm2d(num_features=features),
  46. nn.ReLU(inplace=True),
  47. nn.Conv2d(in_channels=features, out_channels=features, kernel_size=3, padding=1, bias=False),
  48. nn.BatchNorm2d(num_features=features),
  49. nn.ReLU(inplace=True),
  50. )
  51. # 测试模型
  52. if __name__ == "__main__":
  53. model = WirePredictor()
  54. x = torch.randn((1, 4, 128, 128)) # 包含法向量信息的输入大小为 128x128
  55. with torch.no_grad():
  56. output_mask, output_normals = model(x)
  57. print(output_mask.shape, output_normals.shape) # 应输出 (1, 1, 128, 128) 和 (1, 2, 128, 128)