transform.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. import math
  2. from typing import Any, Dict, List, Optional, Tuple
  3. import torch
  4. import torchvision
  5. from torch import nn, Tensor
  6. # from .image_list import ImageList
  7. # from .roi_heads import paste_masks_in_image
  8. from .ROI_heads import paste_masks_in_image
  class ImageList:
      """
      Container pairing a batched image tensor with the size of every image in it.

      Images of possibly different sizes are padded to a common size and stored
      as one tensor; the original size of each image is kept next to it.

      Args:
          tensors (tensor): Tensor containing the (padded) images.
          image_sizes (list[tuple[int, int]]): One size tuple per image.
      """

      def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]) -> None:
          self.tensors = tensors
          self.image_sizes = image_sizes

      def to(self, device: torch.device) -> "ImageList":
          """
          Return a new ImageList whose tensor has been moved to ``device``.

          The ``image_sizes`` list object is shared with ``self``, not copied.
          """
          return ImageList(self.tensors.to(device), self.image_sizes)
  25. def _get_shape_onnx(image: Tensor) -> Tensor:
  26. from torch.onnx import operators
  27. return operators.shape_as_tensor(image)[-2:]
  28. def _fake_cast_onnx(v: Tensor) -> float:
  29. # ONNX requires a tensor but here we fake its type for JIT.
  30. return v
  def _resize_image_and_masks(
      image: Tensor,
      self_min_size: int,
      self_max_size: int,
      target: Optional[Dict[str, Tensor]] = None,
      fixed_size: Optional[Tuple[int, int]] = None,
  ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
      """
      Resize one image, and ``target["masks"]`` if present, with the same geometry.

      If ``fixed_size`` is given, the output has exactly that size. The tuple is
      read as (width, height) and swapped into interpolate's (height, width)
      ``size``. Otherwise a single scale factor is used:
      ``min(self_min_size / shorter_side, self_max_size / longer_side)``. This
      brings the shorter side to ``self_min_size`` unless that would push the
      longer side past ``self_max_size``.

      Args:
          image: image tensor whose last two dims are taken as (H, W). It is
              batched as ``image[None]`` for bilinear interpolation, so it is
              expected to be 3-D.
          self_min_size: target length of the shorter side.
          self_max_size: cap on the length of the longer side.
          target: optional per-image dict. Only "masks" is touched here, and it
              is replaced inside this same dict (in-place mutation).
          fixed_size: optional explicit output size, given as (width, height).

      Returns:
          The resized image and ``target`` (the same dict object, or None).
      """
      # Obtain (H, W) in a form that each execution mode can do arithmetic on.
      if torchvision._is_tracing():
          # Tensor-valued shape, presumably so an exported graph keeps it dynamic.
          im_shape = _get_shape_onnx(image)
      elif torch.jit.is_scripting():
          im_shape = torch.tensor(image.shape[-2:])
      else:
          im_shape = image.shape[-2:]

      # Exactly one of `size` / `scale_factor` ends up non-None and drives interpolate.
      size: Optional[List[int]] = None
      scale_factor: Optional[float] = None
      recompute_scale_factor: Optional[bool] = None
      if fixed_size is not None:
          # fixed_size is (w, h); interpolate expects (h, w).
          size = [fixed_size[1], fixed_size[0]]
      else:
          if torch.jit.is_scripting() or torchvision._is_tracing():
              # Tensor version of the scale computation in the eager branch below.
              min_size = torch.min(im_shape).to(dtype=torch.float32)
              max_size = torch.max(im_shape).to(dtype=torch.float32)
              self_min_size_f = float(self_min_size)
              self_max_size_f = float(self_max_size)
              scale = torch.min(self_min_size_f / min_size, self_max_size_f / max_size)

              if torchvision._is_tracing():
                  # Keep the scale as a tensor; the helper only fakes its static type.
                  scale_factor = _fake_cast_onnx(scale)
              else:
                  scale_factor = scale.item()

          else:
              # Do it the normal way
              min_size = min(im_shape)
              max_size = max(im_shape)
              scale_factor = min(self_min_size / min_size, self_max_size / max_size)

          # Let interpolate derive the output size from scale_factor.
          recompute_scale_factor = True

      image = torch.nn.functional.interpolate(
          image[None],
          size=size,
          scale_factor=scale_factor,
          mode="bilinear",
          recompute_scale_factor=recompute_scale_factor,
          align_corners=False,
      )[0]

      if target is None:
          return image, target

      if "masks" in target:
          mask = target["masks"]
          # Add a channel dim, resize as float with interpolate's default ("nearest")
          # mode, drop the channel dim, and convert back to uint8.
          mask = torch.nn.functional.interpolate(
              mask[:, None].float(), size=size, scale_factor=scale_factor, recompute_scale_factor=recompute_scale_factor
          )[:, 0].byte()
          target["masks"] = mask
      return image, target
  def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
      """
      Rescale box coordinates from an image of ``original_size`` to ``new_size``.

      Both sizes are (height, width). ``boxes`` holds its four coordinates along
      dim 1 in (xmin, ymin, xmax, ymax) order. The x coordinates are multiplied
      by the width ratio and the y coordinates by the height ratio.
      """
      dev = boxes.device
      scale_h, scale_w = [
          torch.tensor(dst, dtype=torch.float32, device=dev)
          / torch.tensor(src, dtype=torch.float32, device=dev)
          for dst, src in zip(new_size, original_size)
      ]
      x_lo, y_lo, x_hi, y_hi = boxes.unbind(1)
      return torch.stack(
          (x_lo * scale_w, y_lo * scale_h, x_hi * scale_w, y_hi * scale_h),
          dim=1,
      )
  def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
      """
      Rescale keypoint coordinates from an image of ``original_size`` to ``new_size``.

      Both sizes are (height, width). Along the last dim of ``keypoints``, entry 0
      is scaled by the width ratio and entry 1 by the height ratio. A new tensor
      is returned; the input is not modified.
      """
      dev = keypoints.device
      ratio_h, ratio_w = [
          torch.tensor(dst, dtype=torch.float32, device=dev)
          / torch.tensor(src, dtype=torch.float32, device=dev)
          for dst, src in zip(new_size, original_size)
      ]
      out = keypoints.clone()
      if not torch._C._get_tracing_state():
          # Eager / scripted path: scale x and y in place on the private copy.
          out[..., 0] *= ratio_w
          out[..., 1] *= ratio_h
          return out
      # Tracing path: rebuild via stack instead of in-place slice writes.
      # NOTE(review): this indexes [:, :, 0..2], so it presumably expects [N, K, 3].
      scaled_x = out[:, :, 0] * ratio_w
      scaled_y = out[:, :, 1] * ratio_h
      return torch.stack((scaled_x, scaled_y, out[:, :, 2]), dim=2)
  class GeneralizedRCNNTransform(nn.Module):
      """
      Performs input / target transformation before feeding the data to a GeneralizedRCNN
      model.

      The transformations it performs are:
          - input normalization (mean subtraction and std division)
          - input / target resizing to match min_size / max_size

      It returns an ImageList for the inputs, and a List[Dict[str, Tensor]] for the targets.

      Args:
          min_size: target length of the shorter image side. A list/tuple of
              candidates is also accepted: in training mode one is drawn at random
              per image, and in eval mode the last one is used.
          max_size: cap on the longer image side after resizing.
          image_mean: per-channel mean subtracted in ``normalize``.
          image_std: per-channel std divided out in ``normalize``.
          size_divisible: batched height and width are rounded up to a multiple of this.
          fixed_size: forwarded to ``_resize_image_and_masks``; when given, it replaces
              the min_size / max_size scaling.
          **kwargs: only ``_skip_resize`` is read (default False). When true, resizing
              is skipped in training mode. Any other key is ignored.
      """

      def __init__(
          self,
          min_size: int,
          max_size: int,
          image_mean: List[float],
          image_std: List[float],
          size_divisible: int = 32,
          fixed_size: Optional[Tuple[int, int]] = None,
          **kwargs: Any,
      ):
          super().__init__()
          # Always store min_size as a sequence so resize() can index / sample it.
          if not isinstance(min_size, (list, tuple)):
              min_size = (min_size,)
          self.min_size = min_size
          self.max_size = max_size
          self.image_mean = image_mean
          self.image_std = image_std
          self.size_divisible = size_divisible
          self.fixed_size = fixed_size
          self._skip_resize = kwargs.pop("_skip_resize", False)

      def forward(
          self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
      ) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]:
          """
          Normalize, resize and batch ``images``, rescaling ``targets`` to match.

          Args:
              images: list of 3-D image tensors ([C, H, W]).
              targets: optional list with one dict per image. ``resize`` requires a
                  "boxes" entry; "masks" and "keypoints" are rescaled when present.

          Returns:
              An ImageList (padded batch plus each image's (H, W) after resizing), and
              the transformed targets. These are shallow copies of the input dicts, so
              the caller's dicts are not modified.

          Raises:
              ValueError: if an image is not 3-D, or if ``targets`` does not contain
                  exactly one entry per image.
          """
          # Shallow-copy the list so writing results back below leaves the caller's list intact.
          images = [img for img in images]
          if targets is not None:
              # A mismatch would otherwise either IndexError below or silently return
              # extra targets that were never resized.
              if len(targets) != len(images):
                  raise ValueError(
                      f"Expected one target per image, got {len(targets)} targets for {len(images)} images"
                  )
              # make a copy of targets to avoid modifying it in-place
              # once torchscript supports dict comprehension
              # this can be simplified as follows
              # targets = [{k: v for k,v in t.items()} for t in targets]
              targets_copy: List[Dict[str, Tensor]] = []
              for t in targets:
                  data: Dict[str, Tensor] = {}
                  for k, v in t.items():
                      data[k] = v
                  targets_copy.append(data)
              targets = targets_copy
          for i in range(len(images)):
              image = images[i]
              target_index = targets[i] if targets is not None else None

              if image.dim() != 3:
                  raise ValueError(f"images is expected to be a list of 3d tensors of shape [C, H, W], got {image.shape}")
              image = self.normalize(image)
              image, target_index = self.resize(image, target_index)
              images[i] = image
              if targets is not None and target_index is not None:
                  targets[i] = target_index

          # Record the per-image sizes before padding them into a single batch tensor.
          image_sizes = [img.shape[-2:] for img in images]
          images = self.batch_images(images, size_divisible=self.size_divisible)
          image_sizes_list: List[Tuple[int, int]] = []
          for image_size in image_sizes:
              torch._assert(
                  len(image_size) == 2,
                  f"Input tensors expected to have in the last two elements H and W, instead got {image_size}",
              )
              image_sizes_list.append((image_size[0], image_size[1]))

          image_list = ImageList(images, image_sizes_list)
          return image_list, targets

      def normalize(self, image: Tensor) -> Tensor:
          """
          Return ``(image - mean) / std`` with the per-channel stats broadcast over (H, W).

          Raises:
              TypeError: if ``image`` is not a floating-point tensor.
          """
          if not image.is_floating_point():
              raise TypeError(
                  f"Expected input images to be of floating type (in range [0, 1]), "
                  f"but found type {image.dtype} instead"
              )
          dtype, device = image.dtype, image.device
          mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
          std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
          return (image - mean[:, None, None]) / std[:, None, None]

      def torch_choice(self, k: List[int]) -> int:
          """
          Implements `random.choice` via torch ops, so it can be compiled with
          TorchScript and we use PyTorch's RNG (not native RNG)
          """
          index = int(torch.empty(1).uniform_(0.0, float(len(k))).item())
          return k[index]

      def resize(
          self,
          image: Tensor,
          target: Optional[Dict[str, Tensor]] = None,
      ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
          """
          Resize one image and rescale its target's boxes / masks / keypoints to match.

          In training mode a random entry of ``min_size`` is used (or nothing is done
          when ``_skip_resize`` is set). In eval mode the last entry is used.
          ``target`` is modified in place and must contain "boxes".
          """
          h, w = image.shape[-2:]
          if self.training:
              if self._skip_resize:
                  return image, target
              size = self.torch_choice(self.min_size)
          else:
              size = self.min_size[-1]
          image, target = _resize_image_and_masks(image, size, self.max_size, target, self.fixed_size)

          if target is None:
              return image, target

          bbox = target["boxes"]
          bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
          target["boxes"] = bbox

          if "keypoints" in target:
              keypoints = target["keypoints"]
              keypoints = resize_keypoints(keypoints, (h, w), image.shape[-2:])
              target["keypoints"] = keypoints
          return image, target

      # _onnx_batch_images() is an implementation of
      # batch_images() that is supported by ONNX tracing.
      @torch.jit.unused
      def _onnx_batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
          """
          Tracing-only variant of ``batch_images`` that pads with ``F.pad`` and stacks.

          NOTE(review): ``torch.stack`` over ``img.shape[i]`` only works when shape
          entries are tensors (as under tracing); in eager mode they are ints.
          """
          max_size = []
          for i in range(images[0].dim()):
              max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
              max_size.append(max_size_i)
          stride = size_divisible
          # Round H and W (dims 1 and 2) up to a multiple of the stride.
          max_size[1] = (torch.ceil((max_size[1].to(torch.float32)) / stride) * stride).to(torch.int64)
          max_size[2] = (torch.ceil((max_size[2].to(torch.float32)) / stride) * stride).to(torch.int64)
          max_size = tuple(max_size)

          # work around for
          # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
          # which is not yet supported in onnx
          padded_imgs = []
          for img in images:
              padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
              # F.pad takes (left, right) pairs starting from the last dim: W, then H, then C.
              padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
              padded_imgs.append(padded_img)

          return torch.stack(padded_imgs)

      def max_by_axis(self, the_list: List[List[int]]) -> List[int]:
          """
          Return the element-wise maximum of equal-length int lists.

          The first sublist is copied before accumulating, so the caller's input
          is never modified.
          """
          maxes = list(the_list[0])
          for sublist in the_list[1:]:
              for index, item in enumerate(sublist):
                  maxes[index] = max(maxes[index], item)
          return maxes

      def batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
          """
          Zero-pad 3-D images into one [N, C, H, W] tensor.

          H and W are the per-axis maxima over the batch, rounded up to a multiple
          of ``size_divisible``. Each image occupies the top-left corner of its slot.
          """
          if torchvision._is_tracing():
              # batch_images() does not export well to ONNX
              # call _onnx_batch_images() instead
              return self._onnx_batch_images(images, size_divisible)

          max_size = self.max_by_axis([list(img.shape) for img in images])
          stride = float(size_divisible)
          max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride)
          max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride)

          batch_shape = [len(images)] + max_size
          batched_imgs = images[0].new_full(batch_shape, 0)
          for i in range(batched_imgs.shape[0]):
              img = images[i]
              batched_imgs[i, : img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)

          return batched_imgs

      def postprocess(
          self,
          result: List[Dict[str, Tensor]],
          image_shapes: List[Tuple[int, int]],
          original_image_sizes: List[Tuple[int, int]],
      ) -> List[Dict[str, Tensor]]:
          """
          Map predictions from resized-image coordinates back to the original images.

          Boxes and keypoints are rescaled from ``image_shapes`` to
          ``original_image_sizes``. Masks are passed through ``paste_masks_in_image``
          with the rescaled boxes. ``result`` is updated in place and returned. In
          training mode it is returned unchanged.
          """
          if self.training:
              return result
          for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
              boxes = pred["boxes"]
              boxes = resize_boxes(boxes, im_s, o_im_s)
              result[i]["boxes"] = boxes
              if "masks" in pred:
                  masks = pred["masks"]
                  masks = paste_masks_in_image(masks, boxes, o_im_s)
                  result[i]["masks"] = masks
              if "keypoints" in pred:
                  keypoints = pred["keypoints"]
                  keypoints = resize_keypoints(keypoints, im_s, o_im_s)
                  result[i]["keypoints"] = keypoints
          return result

      def __repr__(self) -> str:
          format_string = f"{self.__class__.__name__}("
          _indent = "\n "
          format_string += f"{_indent}Normalize(mean={self.image_mean}, std={self.image_std})"
          format_string += f"{_indent}Resize(min_size={self.min_size}, max_size={self.max_size}, mode='bilinear')"
          format_string += "\n)"
          return format_string