line_head.py 843 B

123456789101112131415161718192021222324
  1. import torch
  2. from torch import nn
  3. class LineRCNNHeads(nn.Sequential):
  4. def __init__(self, input_channels, num_class):
  5. super(LineRCNNHeads, self).__init__()
  6. # print("输入的维度是:", input_channels)
  7. m = int(input_channels / 4)
  8. heads = []
  9. self.head_size = [[2], [1], [2]]
  10. for output_channels in sum(self.head_size, []):
  11. heads.append(
  12. nn.Sequential(
  13. nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
  14. nn.ReLU(inplace=True),
  15. nn.Conv2d(m, output_channels, kernel_size=1),
  16. )
  17. )
  18. self.heads = nn.ModuleList(heads)
  19. assert num_class == sum(sum(self.head_size, []))
  20. def forward(self, x):
  21. return torch.cat([head(x) for head in self.heads], dim=1)