arc_heads.py

import math

import torch
import torch.nn.functional as F
from torch import nn


class ArcHeads(nn.Sequential):
    def __init__(self, in_channels, layers):
        d = []
        next_feature = in_channels
        for out_channels in layers:
            d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
            d.append(nn.ReLU(inplace=True))
            next_feature = out_channels
        super().__init__(*d)
        # Kaiming init for every conv layer; the ReLU modules are skipped
        for m in self.children():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                nn.init.constant_(m.bias, 0)
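

# Shape note (added): 3x3 convs with stride 1 and padding 1 preserve H and W,
# so ArcHeads keeps the spatial size of its input. A minimal sketch with
# assumed, illustrative sizes:
#
#     heads = ArcHeads(in_channels=256, layers=(256, 256, 256))
#     y = heads(torch.randn(2, 256, 14, 14))  # -> torch.Size([2, 256, 14, 14])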


class ArcPredictor(nn.Module):
    def __init__(self, in_channels, out_channels=1):
        super().__init__()
        input_features = in_channels
        deconv_kernel = 4
        self.kps_score_lowres = nn.ConvTranspose2d(
            input_features,
            out_channels,
            deconv_kernel,
            stride=2,
            padding=deconv_kernel // 2 - 1,
        )
        nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
        nn.init.constant_(self.kps_score_lowres.bias, 0)
        self.up_scale = 2
        self.out_channels = out_channels

    def forward(self, x):
        print(f"before kps_score_lowres x: {x.shape}")  # debug: input shape
        x = self.kps_score_lowres(x)
        print(f"kps_score_lowres x: {x.shape}")  # debug: shape after the 2x deconv
        return x
        # Optional extra 2x bilinear upsampling (left disabled in the original):
        # return torch.nn.functional.interpolate(
        #     x, scale_factor=float(self.up_scale), mode="bilinear",
        #     align_corners=False, recompute_scale_factor=False,
        # )
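

# Shape note (added): ConvTranspose2d with kernel 4, stride 2, padding 1
# exactly doubles H and W: out = (in - 1) * 2 - 2 * 1 + 4 = 2 * in.
# A sketch with assumed sizes:
#
#     pred = ArcPredictor(in_channels=256, out_channels=1)
#     y = pred(torch.randn(2, 256, 14, 14))  # -> torch.Size([2, 1, 28, 28])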


class ArcEquationHead(nn.Module):
    def __init__(self, num_outputs=7):
        super().__init__()
        # --------------------------------------------------
        # Convolution layers - no fixed H,W assumptions
        # Automatically downsamples using stride=2
        # --------------------------------------------------
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
        )
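        # Note (added): four stride-2 convs downsample H and W by 16x overall,
        # e.g. a 64x64 input reaches the pooling stage as a 4x4, 256-channel map.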
        # --------------------------------------------------
        # Global pooling → no H,W dependency
        # --------------------------------------------------
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        # --------------------------------------------------
        # MLP
        # --------------------------------------------------
        self.mlp = nn.Sequential(
            nn.Linear(256, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_outputs),
        )

    def forward(self, feature_logits):
        """
        Args:
            feature_logits: Tensor [N, 1, H, W]

        Returns:
            Tensor [N, num_outputs] of constrained arc parameters.
        """
        # CNN
        x = self.conv(feature_logits)
        # Global pool
        x = self.gap(x).view(x.size(0), -1)
        # Predict raw params
        raw_params = self.mlp(x)  # -> [N, 7]
        N, _, H, W = feature_logits.shape
        # --------------------------------------------
        # Apply constraints. Built with torch.cat instead of in-place
        # slice assignment, which is fragile under autograd.
        # --------------------------------------------
        cx = torch.sigmoid(raw_params[..., 0:1]) * W                  # cx in [0, W]
        cy = torch.sigmoid(raw_params[..., 1:2]) * H                  # cy in [0, H]
        long_axis = F.relu(raw_params[..., 2:3]) + 1e-6               # long axis > 0
        short_axis = F.relu(raw_params[..., 3:4]) + 1e-6              # short axis > 0
        angles = torch.sigmoid(raw_params[..., 4:7]) * (2 * math.pi)  # angles in [0, 2π]
        return torch.cat([cx, cy, long_axis, short_axis, angles], dim=-1)
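

# A minimal usage sketch (added for illustration; the batch size and the
# 64x64 input resolution are assumptions, not from the original file):
if __name__ == "__main__":
    head = ArcEquationHead(num_outputs=7)
    logits = torch.randn(2, 1, 64, 64)  # e.g. a 1-channel score map from ArcPredictor
    params = head(logits)
    print(params.shape)  # torch.Size([2, 7])
    # cx/cy lie within the input extent, axes are positive, angles in [0, 2π]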