anchor_utils.py

import math
from typing import List, Optional

import torch
from torch import nn, Tensor

from .image_list import ImageList


class AnchorGenerator(nn.Module):
  7. """
  8. Module that generates anchors for a set of feature maps and
  9. image sizes.
  10. The module support computing anchors at multiple sizes and aspect ratios
  11. per feature map. This module assumes aspect ratio = height / width for
  12. each anchor.
  13. sizes and aspect_ratios should have the same number of elements, and it should
  14. correspond to the number of feature maps.
  15. sizes[i] and aspect_ratios[i] can have an arbitrary number of elements,
  16. and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors
  17. per spatial location for feature map i.
  18. Args:
  19. sizes (Tuple[Tuple[int]]):
  20. aspect_ratios (Tuple[Tuple[float]]):
  21. """

    __annotations__ = {
        "cell_anchors": List[torch.Tensor],
    }

    def __init__(
        self,
        sizes=((128, 256, 512),),
        aspect_ratios=((0.5, 1.0, 2.0),),
    ):
        super().__init__()

        if not isinstance(sizes[0], (list, tuple)):
            # TODO change this
            sizes = tuple((s,) for s in sizes)
        if not isinstance(aspect_ratios[0], (list, tuple)):
            aspect_ratios = (aspect_ratios,) * len(sizes)

        self.sizes = sizes
        self.aspect_ratios = aspect_ratios
        self.cell_anchors = [
            self.generate_anchors(size, aspect_ratio) for size, aspect_ratio in zip(sizes, aspect_ratios)
        ]

    # TODO: https://github.com/pytorch/pytorch/issues/26792
    # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
    # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
    # This method assumes aspect ratio = height / width for an anchor.
    def generate_anchors(
        self,
        scales: List[int],
        aspect_ratios: List[float],
        dtype: torch.dtype = torch.float32,
        device: torch.device = torch.device("cpu"),
    ) -> Tensor:
        scales = torch.as_tensor(scales, dtype=dtype, device=device)
        aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
        h_ratios = torch.sqrt(aspect_ratios)
        w_ratios = 1 / h_ratios
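
        # For each (ratio, scale) pair this yields h / w == aspect_ratio and w * h == scale ** 2,
        # so every base anchor keeps (roughly) the requested area while varying its shape; the
        # stacked result below is a zero-centered box in (x1, y1, x2, y2) format.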
        ws = (w_ratios[:, None] * scales[None, :]).view(-1)
        hs = (h_ratios[:, None] * scales[None, :]).view(-1)

        base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
        return base_anchors.round()

    def set_cell_anchors(self, dtype: torch.dtype, device: torch.device):
        self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) for cell_anchor in self.cell_anchors]

    def num_anchors_per_location(self) -> List[int]:
        return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]

    # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
    # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
    def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]:
        anchors = []
        cell_anchors = self.cell_anchors
        torch._assert(cell_anchors is not None, "cell_anchors should not be None")
        torch._assert(
            len(grid_sizes) == len(strides) == len(cell_anchors),
            "Anchors should be Tuple[Tuple[int]] because each feature "
            "map could potentially have different sizes and aspect ratios. "
            "There needs to be a match between the number of "
            "feature maps passed and the number of sizes / aspect ratios specified.",
        )

        for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors):
            grid_height, grid_width = size
            stride_height, stride_width = stride
            device = base_anchors.device

            # For output anchor, compute [x_center, y_center, x_center, y_center]
            shifts_x = torch.arange(0, grid_width, dtype=torch.int32, device=device) * stride_width
            shifts_y = torch.arange(0, grid_height, dtype=torch.int32, device=device) * stride_height
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)
            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)

            # For every (base anchor, output anchor) pair,
            # offset each zero-centered base anchor by the center of the output anchor.
            anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4))

        return anchors

    def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
        image_size = image_list.tensors.shape[-2:]
        dtype, device = feature_maps[0].dtype, feature_maps[0].device
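        # Stride of each feature map relative to the padded input image, computed per
        # spatial dimension with integer division.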
        strides = [
            [
                torch.empty((), dtype=torch.int64, device=device).fill_(image_size[0] // g[0]),
                torch.empty((), dtype=torch.int64, device=device).fill_(image_size[1] // g[1]),
            ]
            for g in grid_sizes
        ]
        self.set_cell_anchors(dtype, device)
        anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides)
        anchors: List[List[torch.Tensor]] = []
        for _ in range(len(image_list.image_sizes)):
            anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps]
            anchors.append(anchors_in_image)
        anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
        return anchors
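

# Illustrative usage sketch for AnchorGenerator (a sketch, not part of the module's API);
# the feature-map shapes below are assumptions chosen only for this example:
#
#     anchor_gen = AnchorGenerator(sizes=((32,), (64,)), aspect_ratios=((0.5, 1.0, 2.0),) * 2)
#     images = ImageList(torch.rand(2, 3, 224, 224), [(224, 224), (224, 224)])
#     features = [torch.rand(2, 8, 28, 28), torch.rand(2, 8, 14, 14)]
#     anchors = anchor_gen(images, features)  # list with one (N, 4) tensor per image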


class DefaultBoxGenerator(nn.Module):
    """
    This module generates the default boxes of SSD for a set of feature maps and image sizes.

    Args:
        aspect_ratios (List[List[int]]): A list with all the aspect ratios used in each feature map.
        min_ratio (float): The minimum scale :math:`\text{s}_{\text{min}}` of the default boxes used in the estimation
            of the scales of each feature map. It is used only if the ``scales`` parameter is not provided.
        max_ratio (float): The maximum scale :math:`\text{s}_{\text{max}}` of the default boxes used in the estimation
            of the scales of each feature map. It is used only if the ``scales`` parameter is not provided.
        scales (List[float], optional): The scales of the default boxes. If not provided it will be estimated using
            the ``min_ratio`` and ``max_ratio`` parameters.
        steps (List[int], optional): A hyper-parameter that affects the tiling of the default boxes. If not provided
            it will be estimated from the data.
        clip (bool): Whether the standardized values of default boxes should be clipped between 0 and 1. The clipping
            is applied while the boxes are encoded in format ``(cx, cy, w, h)``.
    """

    def __init__(
        self,
        aspect_ratios: List[List[int]],
        min_ratio: float = 0.15,
        max_ratio: float = 0.9,
        scales: Optional[List[float]] = None,
        steps: Optional[List[int]] = None,
        clip: bool = True,
    ):
        super().__init__()
        if steps is not None and len(aspect_ratios) != len(steps):
            raise ValueError("aspect_ratios and steps should have the same length")
        self.aspect_ratios = aspect_ratios
        self.steps = steps
        self.clip = clip
        num_outputs = len(aspect_ratios)

        # Estimation of default boxes scales
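        # The scales are spaced linearly between min_ratio and max_ratio (one per feature map),
        # following the SSD paper; the trailing 1.0 only serves as scales[k + 1] when computing
        # the extra aspect-ratio-1 box for the last feature map in _generate_wh_pairs.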
        if scales is None:
            if num_outputs > 1:
                range_ratio = max_ratio - min_ratio
                self.scales = [min_ratio + range_ratio * k / (num_outputs - 1.0) for k in range(num_outputs)]
                self.scales.append(1.0)
            else:
                self.scales = [min_ratio, max_ratio]
        else:
            self.scales = scales

        self._wh_pairs = self._generate_wh_pairs(num_outputs)

    def _generate_wh_pairs(
        self, num_outputs: int, dtype: torch.dtype = torch.float32, device: torch.device = torch.device("cpu")
    ) -> List[Tensor]:
        _wh_pairs: List[Tensor] = []
        for k in range(num_outputs):
            # Adding the 2 default width-height pairs for aspect ratio 1 and scale s'k
            s_k = self.scales[k]
            s_prime_k = math.sqrt(self.scales[k] * self.scales[k + 1])
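            # s'_k = sqrt(s_k * s_{k+1}) above is the intermediate scale the SSD paper uses
            # for the second aspect-ratio-1 box.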
            wh_pairs = [[s_k, s_k], [s_prime_k, s_prime_k]]

            # Adding 2 pairs for each aspect ratio of the feature map k
            for ar in self.aspect_ratios[k]:
                sq_ar = math.sqrt(ar)
                w = self.scales[k] * sq_ar
                h = self.scales[k] / sq_ar
                wh_pairs.extend([[w, h], [h, w]])

            _wh_pairs.append(torch.as_tensor(wh_pairs, dtype=dtype, device=device))
        return _wh_pairs

    def num_anchors_per_location(self) -> List[int]:
        # Estimate num of anchors based on aspect ratios: 2 default boxes + 2 per aspect ratio of the feature map.
        return [2 + 2 * len(r) for r in self.aspect_ratios]

    # Default Boxes calculation based on page 6 of SSD paper
    def _grid_default_boxes(
        self, grid_sizes: List[List[int]], image_size: List[int], dtype: torch.dtype = torch.float32
    ) -> Tensor:
        default_boxes = []
        for k, f_k in enumerate(grid_sizes):
            # Now add the default boxes for each width-height pair
            if self.steps is not None:
                x_f_k = image_size[1] / self.steps[k]
                y_f_k = image_size[0] / self.steps[k]
            else:
                y_f_k, x_f_k = f_k
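
            # Default box centers (cx, cy) for every cell, expressed in relative [0, 1] image coordinates.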
            shifts_x = ((torch.arange(0, f_k[1]) + 0.5) / x_f_k).to(dtype=dtype)
            shifts_y = ((torch.arange(0, f_k[0]) + 0.5) / y_f_k).to(dtype=dtype)
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)
            shifts = torch.stack((shift_x, shift_y) * len(self._wh_pairs[k]), dim=-1).reshape(-1, 2)

            # Clipping the default boxes while the boxes are encoded in format (cx, cy, w, h)
            _wh_pair = self._wh_pairs[k].clamp(min=0, max=1) if self.clip else self._wh_pairs[k]
            wh_pairs = _wh_pair.repeat((f_k[0] * f_k[1]), 1)

            default_box = torch.cat((shifts, wh_pairs), dim=1)
            default_boxes.append(default_box)

        return torch.cat(default_boxes, dim=0)

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"aspect_ratios={self.aspect_ratios}"
            f", clip={self.clip}"
            f", scales={self.scales}"
            f", steps={self.steps}"
            ")"
        )
        return s

    def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
        image_size = image_list.tensors.shape[-2:]
        dtype, device = feature_maps[0].dtype, feature_maps[0].device
        default_boxes = self._grid_default_boxes(grid_sizes, image_size, dtype=dtype)
        default_boxes = default_boxes.to(device)

        dboxes = []
        x_y_size = torch.tensor([image_size[1], image_size[0]], device=default_boxes.device)
        for _ in image_list.image_sizes:
            dboxes_in_image = default_boxes
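            # Convert from normalized (cx, cy, w, h) to absolute (x1, y1, x2, y2) pixel coordinates.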
            dboxes_in_image = torch.cat(
                [
                    (dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:]) * x_y_size,
                    (dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:]) * x_y_size,
                ],
                -1,
            )
            dboxes.append(dboxes_in_image)
        return dboxes
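

# Illustrative usage sketch for DefaultBoxGenerator (a sketch, not part of the module's API);
# the SSD-like image and feature-map shapes below are assumptions chosen only for this example:
#
#     dbox_gen = DefaultBoxGenerator([[2], [2, 3]])  # 4 and 6 default boxes per location
#     images = ImageList(torch.rand(2, 3, 300, 300), [(300, 300), (300, 300)])
#     features = [torch.rand(2, 16, 38, 38), torch.rand(2, 16, 19, 19)]
#     dboxes = dbox_gen(images, features)  # list with one (num_boxes, 4) tensor per image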