123456789101112131415161718192021222324 |
- import torch
- from torch import nn
class LineRCNNHeads(nn.Sequential):
    """Bank of parallel convolutional heads whose outputs are channel-concatenated.

    One head is built per entry of the flattened ``head_size`` list. Every head
    is ``Conv2d(3x3, padding=1) -> ReLU -> Conv2d(1x1)``, mapping
    ``input_channels`` to ``input_channels // 4`` hidden channels and then to
    that head's output channel count. Both convolutions keep the spatial size.

    Although the class derives from ``nn.Sequential``, the heads are stored in
    the ``heads`` ``nn.ModuleList`` and ``forward`` is overridden, so they are
    applied in parallel to the same input rather than chained.

    Args:
        input_channels: Number of channels of the incoming feature map.
        num_class: Total number of output channels. Must equal the sum of all
            entries in ``head_size`` (5 for the layout defined here).

    Raises:
        ValueError: If ``num_class`` does not match the summed head sizes.
    """

    def __init__(self, input_channels, num_class):
        super().__init__()

        # Output channel count of each head, grouped; the grouping is kept as
        # a public attribute, the heads themselves are built from the flat list.
        self.head_size = [[2], [1], [2]]
        channels_per_head = [c for group in self.head_size for c in group]

        # Validate before allocating any layers. An explicit raise (instead of
        # ``assert``) keeps the check alive when Python runs with ``-O``.
        expected_channels = sum(channels_per_head)
        if num_class != expected_channels:
            raise ValueError(
                f"num_class must be {expected_channels} to match head_size "
                f"{self.head_size}, got {num_class}"
            )

        hidden_channels = input_channels // 4
        self.heads = nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(input_channels, hidden_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_channels, output_channels, kernel_size=1),
            )
            for output_channels in channels_per_head
        )

    def forward(self, x):
        """Run every head on ``x`` and concatenate the results along dim 1.

        Args:
            x: Input feature map; assumed to be batched as
                ``(N, input_channels, H, W)`` so that dim 1 is the channel
                axis -- TODO confirm against callers.

        Returns:
            Tensor holding the outputs of all heads concatenated on dim 1, in
            the order the heads appear in ``head_size``.
        """
        return torch.cat([head(x) for head in self.heads], dim=1)
|