point_heads.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. import torch
  2. from torch import nn
  3. class PointHeads(nn.Sequential):
  4. def __init__(self, in_channels, layers):
  5. d = []
  6. next_feature = in_channels
  7. for out_channels in layers:
  8. d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
  9. d.append(nn.ReLU(inplace=True))
  10. next_feature = out_channels
  11. super().__init__(*d)
  12. for m in self.children():
  13. if isinstance(m, nn.Conv2d):
  14. nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
  15. nn.init.constant_(m.bias, 0)
  16. class PointPredictor(nn.Module):
  17. def __init__(self, in_channels, out_channels=1 ):
  18. super().__init__()
  19. input_features = in_channels
  20. deconv_kernel = 4
  21. self.kps_score_lowres = nn.ConvTranspose2d(
  22. input_features,
  23. out_channels,
  24. deconv_kernel,
  25. stride=2,
  26. padding=deconv_kernel // 2 - 1,
  27. )
  28. nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
  29. nn.init.constant_(self.kps_score_lowres.bias, 0)
  30. self.up_scale = 2
  31. self.out_channels = out_channels
  32. def forward(self, x):
  33. # print(f'before kps_score_lowres x:{x.shape}')
  34. x = self.kps_score_lowres(x)
  35. # print(f'kps_score_lowres x:{x.shape}')
  36. return torch.nn.functional.interpolate(
  37. x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
  38. )