arc_heads.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. import torch
  2. from torch import nn
  class ArcHeads(nn.Sequential):
      """Stack of 3x3 Conv2d + ReLU stages that preserve spatial size.

      Args:
          in_channels: number of channels of the incoming feature map.
          layers: iterable of output channel counts, one per conv stage.
              An empty iterable produces an empty Sequential.
      """

      def __init__(self, in_channels, layers):
          modules = []
          width = in_channels
          for next_width in layers:
              # kernel 3, stride 1, padding 1 -> H and W are unchanged
              modules.extend(
                  (
                      nn.Conv2d(width, next_width, 3, stride=1, padding=1),
                      nn.ReLU(inplace=True),
                  )
              )
              width = next_width
          super().__init__(*modules)

          # Replace PyTorch's default Conv2d init: He-normal weights
          # (fan_out mode, ReLU gain) and zero bias.
          for conv in filter(lambda mod: isinstance(mod, nn.Conv2d), self.children()):
              nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")
              nn.init.constant_(conv.bias, 0)
  class ArcPredictor(nn.Module):
      """Upsample per-pixel arc scores by 2x with a learned transposed conv.

      A single ConvTranspose2d (kernel 4, stride 2, padding 1) maps
      ``[N, in_channels, H, W]`` to ``[N, out_channels, 2H, 2W]``.

      Args:
          in_channels: channels of the incoming feature map.
          out_channels: number of score maps to produce (default 1).
      """

      def __init__(self, in_channels, out_channels=1):
          super().__init__()
          deconv_kernel = 4
          # With stride 2, padding = k // 2 - 1 makes the output exactly 2x
          # the input: (H - 1) * 2 - 2 * 1 + 4 = 2H.
          self.kps_score_lowres = nn.ConvTranspose2d(
              in_channels,
              out_channels,
              deconv_kernel,
              stride=2,
              padding=deconv_kernel // 2 - 1,
          )
          nn.init.kaiming_normal_(
              self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu"
          )
          nn.init.constant_(self.kps_score_lowres.bias, 0)
          # Kept for callers that read it. forward() applies no extra
          # interpolation, so this factor is not used here.
          self.up_scale = 2
          self.out_channels = out_channels

      def forward(self, x):
          """Return the 2x-upsampled score maps for ``x``.

          Debug ``print`` calls that ran on every forward pass were removed.
          """
          return self.kps_score_lowres(x)
  40. import torch
  41. import torch.nn as nn
  42. import torch.nn.functional as F
  class ArcEquationHead(nn.Module):
      """Regress ellipse-arc parameters from a feature/logit map.

      The map is global-average-pooled to one value per channel. A two-layer
      MLP then produces raw values, which are squashed into valid ranges.

      Input:
          feature_logits : [N, in_channels, H, W]
      Output:
          arc_params : [N, num_outputs], laid out as
              0: center x (cx), in [0, W]
              1: center y (cy), in [0, H]
              2: long axis length (a), > 0
              3: short axis length (b), > 0
              4: ellipse angle (theta), in [0, 2*pi]
              5, 6, 7, 8, ...: auxiliary points as alternating (x, y) pairs,
                  x in [0, W] and y in [0, H].
          The default num_outputs=9 gives two auxiliary points (indices 5-8).
          num_outputs=7 gives one.

      Args:
          num_outputs: total number of predicted values. Must be 5 plus an
              even number (one x/y pair per auxiliary point).
          hidden: width of the MLP hidden layer.
          in_channels: channel count of ``feature_logits`` (default 1).

      Raises:
          ValueError: if ``num_outputs`` is not ``5 + 2 * k`` for some k >= 0.

      NOTE(review): with in_channels=1 the MLP sees a single scalar per
      sample (the spatial mean). cx/cy therefore cannot depend on *where* the
      response is, only on its average level. Confirm this is intended.
      """

      # Full-precision 2*pi (the previous literal was truncated to 3.1415926535).
      _TWO_PI = 2.0 * 3.141592653589793

      def __init__(self, num_outputs=9, hidden=512, in_channels=1):
          super().__init__()
          # 5 geometry values (cx, cy, a, b, theta), then (x, y) pairs.
          num_aux = num_outputs - 5
          if num_aux < 0 or num_aux % 2 != 0:
              raise ValueError(
                  f"num_outputs must be 5 + 2*k (k >= 0), got {num_outputs}"
              )
          self.num_outputs = num_outputs
          # Global average pooling removes the spatial dimensions.
          self.gap = nn.AdaptiveAvgPool2d((1, 1))
          # MLP that maps pooled features -> raw arc parameters.
          self.mlp = nn.Sequential(
              nn.Linear(in_channels, hidden),
              nn.ReLU(inplace=True),
              nn.Linear(hidden, num_outputs),
          )

      def forward(self, feature_logits):
          """Predict constrained arc parameters.

          Args:
              feature_logits: tensor of shape [N, in_channels, H, W].

          Returns:
              Tensor of shape [N, num_outputs] with the layout described on
              the class.
          """
          _, _, height, width = feature_logits.shape
          # [N, C, 1, 1] -> [N, C]
          pooled = torch.flatten(self.gap(feature_logits), 1)
          raw = self.mlp(pooled)  # [N, num_outputs]

          # Constrain each group and concatenate into a new tensor, instead
          # of overwriting columns of `raw` in place.
          cx = torch.sigmoid(raw[..., 0:1]) * width
          cy = torch.sigmoid(raw[..., 1:2]) * height
          # Axis lengths kept strictly positive.
          # NOTE(review): relu has zero gradient for negative inputs, so an
          # axis can get stuck at 1e-6. softplus would avoid that.
          axes = F.relu(raw[..., 2:4]) + 1e-6
          theta = torch.sigmoid(raw[..., 4:5]) * self._TWO_PI
          # Auxiliary points: even offsets from 5 are x, odd offsets are y.
          aux_x = torch.sigmoid(raw[..., 5::2]) * width
          aux_y = torch.sigmoid(raw[..., 6::2]) * height
          # Re-interleave as x0, y0, x1, y1, ...
          aux = torch.stack((aux_x, aux_y), dim=-1).flatten(start_dim=-2)
          return torch.cat((cx, cy, axes, theta, aux), dim=-1)