# arc_heads.py

import torch
from torch import nn
import torch.nn.functional as F


class ArcHeads(nn.Sequential):
    def __init__(self, in_channels, layers):
        d = []
        next_feature = in_channels
        for out_channels in layers:
            d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
            d.append(nn.ReLU(inplace=True))
            next_feature = out_channels
        super().__init__(*d)
        for m in self.children():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                nn.init.constant_(m.bias, 0)
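
# Shape note (illustrative, not from the original file): the 3x3 convolutions in
# ArcHeads use stride 1 and padding 1, so the spatial size is preserved, e.g.
# ArcHeads(256, (256, 256)) maps [N, 256, H, W] -> [N, 256, H, W].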

class ArcPredictor(nn.Module):
    def __init__(self, in_channels, out_channels=1):
        super().__init__()
        input_features = in_channels
        deconv_kernel = 4
        self.kps_score_lowres = nn.ConvTranspose2d(
            input_features,
            out_channels,
            deconv_kernel,
            stride=2,
            padding=deconv_kernel // 2 - 1,
        )
        nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
        nn.init.constant_(self.kps_score_lowres.bias, 0)
        self.up_scale = 2
        self.out_channels = out_channels

    def forward(self, x):
        print(f'before kps_score_lowres x:{x.shape}')
        x = self.kps_score_lowres(x)
        print(f'kps_score_lowres x:{x.shape}')
        return x
        # return torch.nn.functional.interpolate(
        #     x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
        # )
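
# Output-size check for the transposed convolution above (standard PyTorch
# formula, noted here for reference): with kernel=4, stride=2, padding=1,
#   H_out = (H_in - 1) * stride - 2 * padding + kernel = 2 * H_in,
# so kps_score_lowres doubles the spatial size; re-enabling the commented-out
# interpolate() call would upscale by a further factor of self.up_scale.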


class ArcEquationHead(nn.Module):
    """
    Input:
        feature_logits : [N, 1, H, W]
    Output:
        arc_params : [N, 9]
            0: center x (cx)
            1: center y (cy)
            2: long axis length (a)
            3: short axis length (b)
            4: ellipse angle (theta)
            5: auxiliary point 1, x coordinate
            6: auxiliary point 1, y coordinate
            7: auxiliary point 2, x coordinate
            8: auxiliary point 2, y coordinate
    """
    def __init__(self, num_outputs=9, hidden=512):
        super().__init__()
        # Use GAP to remove spatial dependency
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        # Final MLP that maps the pooled feature -> arc parameters
        self.mlp = nn.Sequential(
            nn.Linear(1, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, num_outputs)
        )
    def forward(self, feature_logits):
        """
        feature_logits: [N, 1, H, W]
        """
        N, _, H, W = feature_logits.shape
        print(f'N:{N}, H:{H}, W:{W}')
        # --------------------------------------------
        # Global average pooling
        # Input  : [N, 1, H, W]
        # Output : [N, 1]
        # --------------------------------------------
        x = self.gap(feature_logits)
        x = x.view(N, -1)
        # Predict raw parameters
        arc_params = self.mlp(x)  # [N, 9]
        # --------------------------------------------
        # Parameter constraints
        # --------------------------------------------
        # H = 1500
        # W = 2000
        # Ellipse center
        arc_params[..., 0] = torch.sigmoid(arc_params[..., 0]) * W  # cx in image width range
        arc_params[..., 1] = torch.sigmoid(arc_params[..., 1]) * H  # cy in image height range
        # Axis lengths must be positive; both are bounded by the image width here
        arc_params[..., 2] = torch.sigmoid(arc_params[..., 2]) * W  # long axis a in (0, W)
        arc_params[..., 3] = torch.sigmoid(arc_params[..., 3]) * W  # short axis b in (0, W)
        # arc_params[..., 2] = F.relu(arc_params[..., 2]) + 1e-6  # a > 0
        # arc_params[..., 3] = F.relu(arc_params[..., 3]) + 1e-6  # b > 0
        # Angle constrained to (0, 2*pi)
        arc_params[..., 4] = torch.sigmoid(arc_params[..., 4]) * (2 * 3.1415926535)
        # ------------------------------------------------
        # Last four values are the two auxiliary points,
        # mapped to the same spatial range as the image
        # ------------------------------------------------
        arc_params[..., 5] = torch.sigmoid(arc_params[..., 5]) * W  # auxiliary point 1, x
        arc_params[..., 6] = torch.sigmoid(arc_params[..., 6]) * H  # auxiliary point 1, y
        arc_params[..., 7] = torch.sigmoid(arc_params[..., 7]) * W  # auxiliary point 2, x
        arc_params[..., 8] = torch.sigmoid(arc_params[..., 8]) * H  # auxiliary point 2, y
        print(f'arc_params in head:{arc_params}')
        return arc_params
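
# Note: the GAP above collapses [N, 1, H, W] to a single pooled value per arc
# before the MLP, so all spatial information is discarded; the commented-out
# variant below instead flattens the full H*W map into the first Linear layer.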

# class ArcEquationHead(nn.Module):
#     """
#     Input:
#         feature_logits : [N, 1, H, W]  # N: number of arcs; H/W: feature map size (corresponding to the original image space)
#     Output:
#         arc_params : [N, 9]  # [cx, cy, a, b, theta, x1, y1, x2, y2]
#     """
#
#     def __init__(self, num_outputs=9, hidden=1024, feat_size=(672, 672)):
#         super().__init__()
#         self.feat_H, self.feat_W = feat_size  # fixed feature map size (must match feature_logits)
#
#         self.flatten = nn.Flatten()
#         self.input_dim = self.feat_H * self.feat_W  # input dimension: H*W (rather than 1)
#
#         self.mlp = nn.Sequential(
#             nn.Linear(self.input_dim, hidden),
#             nn.ReLU(inplace=True),
#             nn.Dropout(0.2),  # guard against overfitting
#             nn.Linear(hidden, hidden // 2),
#             nn.ReLU(inplace=True),
#             nn.Dropout(0.1),
#             nn.Linear(hidden // 2, num_outputs)
#         )
#
#         self._init_weights()
#
#     def _init_weights(self):
#         for m in self.mlp.modules():
#             if isinstance(m, nn.Linear):
#                 nn.init.xavier_uniform_(m.weight)  # uniform initialization to avoid concentrated outputs
#                 if m.bias is not None:
#                     nn.init.zeros_(m.bias)
#
#     def forward(self, feature_logits):
#         N, C, H, W = feature_logits.shape
#         assert H == self.feat_H and W == self.feat_W, "feature map size must match the feat_size given at initialization"
#
#         # 1. Flatten the spatial features: [N, 1, H, W] -> [N, H*W] (keeps per-pixel spatial information)
#         x = self.flatten(feature_logits)  # [N, H*W]
#
#         # 2. Predict raw parameters with the MLP
#         arc_params = self.mlp(x)  # [N, 9]
#
#         # 3. Constrain the parameters (avoid abnormal centers / axis lengths)
#         # Ellipse center: mapped to the feature map size (if the feature map is a downsampled
#         # version of the original image, multiply by the scale factor)
#         arc_params[..., 0] = torch.sigmoid(arc_params[..., 0]) * W  # cx
#         arc_params[..., 1] = torch.sigmoid(arc_params[..., 1]) * H  # cy
#
#         arc_params[..., 2] = torch.sigmoid(arc_params[..., 2]) * W  # long axis a in image width range
#         arc_params[..., 3] = torch.sigmoid(arc_params[..., 3]) * H  # short axis b in image height range
#         # arc_params[..., 2] = F.relu(arc_params[..., 2]) + 1e-6  # a > 0
#         # arc_params[..., 3] = F.relu(arc_params[..., 3]) + 1e-6  # b > 0
#
#         # Angle constrained to (0, 2*pi)
#         arc_params[..., 4] = torch.sigmoid(arc_params[..., 4]) * (2 * 3.1415926535)
#
#         # ------------------------------------------------
#         # Last four values are the two auxiliary points,
#         # mapped to the same spatial range as the image
#         # ------------------------------------------------
#         arc_params[..., 5] = torch.sigmoid(arc_params[..., 5]) * W  # auxiliary point 1, x
#         arc_params[..., 6] = torch.sigmoid(arc_params[..., 6]) * H  # auxiliary point 1, y
#         arc_params[..., 7] = torch.sigmoid(arc_params[..., 7]) * W  # auxiliary point 2, x
#         arc_params[..., 8] = torch.sigmoid(arc_params[..., 8]) * H  # auxiliary point 2, y
#         print(f'arc_params in head:{arc_params}')
#         return arc_params
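

if __name__ == "__main__":
    # Minimal smoke test (not part of the original module): the channel count (256)
    # and the 56x56 feature size are assumed example values, chosen only to
    # illustrate the expected tensor shapes when the three heads are chained.
    feats = torch.randn(2, 256, 56, 56)
    heads = ArcHeads(in_channels=256, layers=(256, 256))        # -> [2, 256, 56, 56]
    predictor = ArcPredictor(in_channels=256, out_channels=1)   # -> [2, 1, 112, 112]
    eq_head = ArcEquationHead(num_outputs=9)
    arc_params = eq_head(predictor(heads(feats)))
    print(arc_params.shape)  # torch.Size([2, 9])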