import torch
from torch import nn


class LineRCNNHeads(nn.Sequential):
    """Parallel convolutional prediction heads whose outputs are concatenated.

    Every head is ``Conv2d(3x3, padding=1) -> ReLU -> Conv2d(1x1)`` and is
    applied to the *same* input; the per-head outputs are concatenated along
    dimension 1 (channels). The class subclasses ``nn.Sequential`` but
    overrides ``forward``, so the heads are not chained one after another.

    Args:
        input_channels: Number of channels of the input feature map. The
            hidden width of every head is ``input_channels // 4``, so this
            must be at least 4.
        num_class: Total number of output channels; must equal the sum of
            all entries in ``head_size``.
        head_size: Nested list of per-head output-channel counts (one inner
            list per group; every entry gets its own head). Defaults to
            ``[[2], [1], [2]]``. Stored as ``self.head_size`` in the same
            nested form.

    Raises:
        ValueError: If ``num_class`` does not equal the total of
            ``head_size``, or ``input_channels`` is too small to give a
            non-empty hidden layer.
    """

    def __init__(self, input_channels, num_class, head_size=None):
        super().__init__()
        if head_size is None:
            head_size = [[2], [1], [2]]
        # Copy so a caller-supplied list is never aliased by the module.
        self.head_size = [list(group) for group in head_size]
        # Flatten in order: one head per entry, concatenated in this order.
        out_channels = [c for group in self.head_size for c in group]

        # Validate before allocating any layers; an ``assert`` would be
        # stripped under ``python -O``.
        total = sum(out_channels)
        if num_class != total:
            raise ValueError(
                f"num_class ({num_class}) must equal the total of "
                f"head_size {self.head_size} ({total})"
            )
        hidden_channels = input_channels // 4
        if hidden_channels < 1:
            raise ValueError(
                f"input_channels must be at least 4 to give a non-empty "
                f"hidden layer, got {input_channels}"
            )

        self.heads = nn.ModuleList(
            nn.Sequential(
                # kernel 3 with padding 1 (stride 1) keeps H and W unchanged.
                nn.Conv2d(input_channels, hidden_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_channels, channels, kernel_size=1),
            )
            for channels in out_channels
        )

    def forward(self, x):
        """Apply every head to ``x`` and concatenate the results on dim 1.

        Args:
            x: Input feature map; assumed to be ``(N, input_channels, H, W)``.

        Returns:
            Tensor of shape ``(N, num_class, H, W)`` for such an input, with
            head outputs stacked in ``head_size`` order along the channels.
        """
        return torch.cat([head(x) for head in self.heads], dim=1)