import torch
import torch.nn as nn
import torch.nn.functional as F

class WirePredictor(nn.Module):
    """Two-level U-Net that predicts a line-segment mask and a per-pixel 2D vector field.

    The encoder has two levels, each followed by 2x2 max-pooling. A bottleneck
    block widens to ``4 * init_features`` channels. The decoder upsamples with
    2x2 stride-2 transposed convolutions and concatenates the matching encoder
    feature map (skip connection) before each decoder block. Two 1x1
    convolution heads read the final decoder features.

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of channels produced by the mask head.
        init_features: Channel width of the first encoder level; the second
            level and the bottleneck use 2x and 4x this width.
    """

    def __init__(self, in_channels=4, out_channels=1, init_features=32):
        super().__init__()

        features = init_features
        self.encoder1 = self._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = self._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = self._block(features * 2, features * 4, name="bottleneck")

        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        # Input is the upsampled features concatenated with encoder2's output.
        self.decoder2 = self._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        # Input is the upsampled features concatenated with encoder1's output.
        self.decoder1 = self._block(features * 2, features, name="dec1")

        # Output head for the line segment mask.
        self.conv_mask = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

        # Output head for normal vectors (2 channels, presumably x and y components).
        self.conv_normals = nn.Conv2d(
            in_channels=features, out_channels=2, kernel_size=1
        )

    def forward(self, x):
        """Run the network.

        Args:
            x: Batched input of shape ``(N, in_channels, H, W)``. ``H`` and ``W``
                need not be divisible by 4; upsampled features are zero-padded
                to match their skip connection before concatenation. Each side
                must be at least 4 so both pooling levels produce a non-empty map.

        Returns:
            Tuple ``(mask, normals)``:
            - ``mask``: ``(N, out_channels, H, W)``, sigmoid output in ``[0, 1]``.
            - ``normals``: ``(N, 2, H, W)``, tanh output with each component in
              ``[-1, 1]``. The vectors are not normalized to unit length.
        """
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))

        bottleneck = self.bottleneck(self.pool2(enc2))

        dec2 = self._match_size(self.upconv2(bottleneck), enc2)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self._match_size(self.upconv1(dec2), enc1)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        mask = torch.sigmoid(self.conv_mask(dec1))
        # tanh bounds each component independently to [-1, 1]; it does not
        # make (x, y) a unit vector.
        normals = torch.tanh(self.conv_normals(dec1))

        return mask, normals

    @staticmethod
    def _match_size(up, skip):
        """Zero-pad ``up`` on its last two dims so they equal those of ``skip``.

        Max-pooling floors odd sizes, so a 2x transposed convolution can come
        back one pixel smaller than the encoder map it is concatenated with.
        The difference is never negative because ``2 * (s // 2) <= s``.
        """
        diff_h = skip.size(2) - up.size(2)
        diff_w = skip.size(3) - up.size(3)
        if diff_h == 0 and diff_w == 0:
            return up
        # F.pad pads the last dim first: (left, right, top, bottom).
        return F.pad(
            up,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )

    def _block(self, in_channels, features, name):
        """Build two (3x3 conv -> BatchNorm -> ReLU) stages that preserve H and W.

        Args:
            in_channels: Channels entering the first convolution.
            features: Output channels of both convolutions.
            name: Descriptive label only; it is not used to build the layers.
                The ``nn.Sequential`` indices (0..5) define the ``state_dict`` keys.
        """
        return nn.Sequential(
            # bias=False: the following BatchNorm adds its own learnable shift.
            nn.Conv2d(in_channels=in_channels, out_channels=features, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(num_features=features),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=features, out_channels=features, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(num_features=features),
            nn.ReLU(inplace=True),
        )

# Smoke test: build the model and check the output shapes.
if __name__ == "__main__":
    model = WirePredictor()
    # eval() makes BatchNorm use its running statistics. torch.no_grad() alone
    # does not do this; in train mode the forward pass would normalize with
    # batch statistics and also update the running statistics.
    model.eval()
    # 4-channel 128x128 input (the original note says it includes normal-vector information).
    x = torch.randn((1, 4, 128, 128))
    with torch.no_grad():
        output_mask, output_normals = model(x)
        # Expected: (1, 1, 128, 128) and (1, 2, 128, 128)
        print(output_mask.shape, output_normals.shape)