import torch
from torch import nn


class LineRCNNHeads(nn.Sequential):
    """Parallel convolutional prediction heads whose outputs are concatenated.

    Each entry of the flattened ``head_size`` layout becomes one branch:
    ``Conv2d(3x3, input_channels -> input_channels // 4) -> ReLU ->
    Conv2d(1x1, -> output_channels)``. ``forward`` runs every branch on the
    same input and concatenates the results along dim 1.

    Args:
        input_channels: Number of channels of the input feature map.
        num_class: Total number of output channels. It must equal the sum of
            all entries in ``head_size``. It is kept as an explicit argument
            so that configuration mismatches are caught early.
        head_size: Nested list of per-branch output channel counts. Defaults
            to ``[[2], [1], [2]]`` (three branches with 2, 1 and 2 channels).

    Raises:
        ValueError: If ``num_class`` does not equal the total number of
            output channels described by ``head_size``.
    """

    def __init__(self, input_channels, num_class, head_size=None):
        # NOTE: the class subclasses nn.Sequential for backward compatibility,
        # but no modules are passed to it; ``forward`` is overridden below.
        super().__init__()
        if head_size is None:
            head_size = [[2], [1], [2]]
        # Copy so callers' lists are never aliased or mutated.
        self.head_size = [list(group) for group in head_size]
        output_channel_counts = [c for group in self.head_size for c in group]

        total = sum(output_channel_counts)
        # Validate with an exception rather than ``assert`` (which is stripped
        # under ``python -O``), and do it before allocating any layers.
        if num_class != total:
            raise ValueError(
                f"num_class ({num_class}) must equal the total number of "
                f"head output channels ({total})"
            )

        # Hidden width of every branch: a quarter of the input channels.
        hidden = input_channels // 4
        self.heads = nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(input_channels, hidden, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden, output_channels, kernel_size=1),
            )
            for output_channels in output_channel_counts
        )

    def forward(self, x):
        """Apply every head to ``x`` and concatenate the outputs along dim 1.

        Presumably ``x`` is ``(N, input_channels, H, W)``, in which case the
        result is ``(N, num_class, H, W)`` (padding=1 keeps H and W) — confirm
        against callers.
        """
        return torch.cat([head(x) for head in self.heads], dim=1)