# transform.py

import math
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import nn, Tensor

from .image_list import ImageList
from .roi_heads import paste_masks_in_image
from ... import _is_tracing


@torch.jit.unused
def _get_shape_onnx(image: Tensor) -> Tensor:
    from torch.onnx import operators

    return operators.shape_as_tensor(image)[-2:]


@torch.jit.unused
def _fake_cast_onnx(v: Tensor) -> float:
    # ONNX requires a tensor but here we fake its type for JIT.
    return v


def _resize_image_and_masks(
    image: Tensor,
    self_min_size: int,
    self_max_size: int,
    target: Optional[Dict[str, Tensor]] = None,
    fixed_size: Optional[Tuple[int, int]] = None,
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
    if _is_tracing():
        im_shape = _get_shape_onnx(image)
    elif torch.jit.is_scripting():
        im_shape = torch.tensor(image.shape[-2:])
    else:
        im_shape = image.shape[-2:]

    size: Optional[List[int]] = None
    scale_factor: Optional[float] = None
    recompute_scale_factor: Optional[bool] = None
    if fixed_size is not None:
        size = [fixed_size[1], fixed_size[0]]
    else:
        if torch.jit.is_scripting() or _is_tracing():
            min_size = torch.min(im_shape).to(dtype=torch.float32)
            max_size = torch.max(im_shape).to(dtype=torch.float32)
            self_min_size_f = float(self_min_size)
            self_max_size_f = float(self_max_size)
            scale = torch.min(self_min_size_f / min_size, self_max_size_f / max_size)

            if _is_tracing():
                scale_factor = _fake_cast_onnx(scale)
            else:
                scale_factor = scale.item()
        else:
            # Do it the normal way
            min_size = min(im_shape)
            max_size = max(im_shape)
            scale_factor = min(self_min_size / min_size, self_max_size / max_size)

        recompute_scale_factor = True

    image = torch.nn.functional.interpolate(
        image[None],
        size=size,
        scale_factor=scale_factor,
        mode="bilinear",
        recompute_scale_factor=recompute_scale_factor,
        align_corners=False,
    )[0]

    if target is None:
        return image, target

    if "masks" in target:
        mask = target["masks"]
        mask = torch.nn.functional.interpolate(
            mask[:, None].float(), size=size, scale_factor=scale_factor, recompute_scale_factor=recompute_scale_factor
        )[:, 0].byte()
        target["masks"] = mask
    return image, target
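
# Worked example of the scaling rule in _resize_image_and_masks (illustration
# only, not part of the original code): with self_min_size=800 and
# self_max_size=1333, a 600x1200 input gives
#     scale = min(800 / 600, 1333 / 1200) = min(1.333, 1.111) = 1.111,
# so the image is resized to roughly 667x1333; the smaller factor wins so the
# longer side never exceeds self_max_size.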


class GeneralizedRCNNTransform(nn.Module):
    """
    Performs input / target transformation before feeding the data to a GeneralizedRCNN
    model.

    The transformations it performs are:
        - input normalization (mean subtraction and std division)
        - input / target resizing to match min_size / max_size

    It returns an ImageList for the inputs, and a List[Dict[str, Tensor]] for the
    targets. A minimal usage sketch (with assumed normalization constants) appears
    at the end of this file.
    """

    def __init__(
        self,
        min_size: int,
        max_size: int,
        image_mean: List[float],
        image_std: List[float],
        size_divisible: int = 32,
        fixed_size: Optional[Tuple[int, int]] = None,
        **kwargs: Any,
    ):
        super().__init__()
        if not isinstance(min_size, (list, tuple)):
            min_size = (min_size,)
        self.min_size = min_size
        self.max_size = max_size
        self.image_mean = image_mean
        self.image_std = image_std
        self.size_divisible = size_divisible
        self.fixed_size = fixed_size
        self._skip_resize = kwargs.pop("_skip_resize", False)

    def forward(
        self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
    ) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]:
        images = [img for img in images]
        if targets is not None:
            # make a copy of targets to avoid modifying it in-place
            # once torchscript supports dict comprehension
            # this can be simplified as follows
            # targets = [{k: v for k, v in t.items()} for t in targets]
            targets_copy: List[Dict[str, Tensor]] = []
            for t in targets:
                data: Dict[str, Tensor] = {}
                for k, v in t.items():
                    data[k] = v
                targets_copy.append(data)
            targets = targets_copy
        for i in range(len(images)):
            image = images[i]
            target_index = targets[i] if targets is not None else None

            if image.dim() != 3:
                raise ValueError(f"images is expected to be a list of 3d tensors of shape [C, H, W], got {image.shape}")
            image = self.normalize(image)
            image, target_index = self.resize(image, target_index)
            images[i] = image
            if targets is not None and target_index is not None:
                targets[i] = target_index

        image_sizes = [img.shape[-2:] for img in images]
        images = self.batch_images(images, size_divisible=self.size_divisible)
        image_sizes_list: List[Tuple[int, int]] = []
        for image_size in image_sizes:
            torch._assert(
                len(image_size) == 2,
                f"Input tensors expected to have in the last two elements H and W, instead got {image_size}",
            )
            image_sizes_list.append((image_size[0], image_size[1]))

        image_list = ImageList(images, image_sizes_list)
        return image_list, targets

    def normalize(self, image: Tensor) -> Tensor:
        if not image.is_floating_point():
            raise TypeError(
                f"Expected input images to be of floating type (in range [0, 1]), "
                f"but found type {image.dtype} instead"
            )
        dtype, device = image.dtype, image.device
        mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
        std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
        return (image - mean[:, None, None]) / std[:, None, None]

    def torch_choice(self, k: List[int]) -> int:
        """
        Implements `random.choice` via torch ops, so it can be compiled with
        TorchScript and we use PyTorch's RNG (not native RNG)
        """
        index = int(torch.empty(1).uniform_(0.0, float(len(k))).item())
        return k[index]

    def resize(
        self,
        image: Tensor,
        target: Optional[Dict[str, Tensor]] = None,
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        h, w = image.shape[-2:]
        if self.training:
            if self._skip_resize:
                return image, target
            size = self.torch_choice(self.min_size)
        else:
            size = self.min_size[-1]
        image, target = _resize_image_and_masks(image, size, self.max_size, target, self.fixed_size)

        if target is None:
            return image, target

        bbox = target["boxes"]
        bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
        target["boxes"] = bbox

        if "keypoints" in target:
            keypoints = target["keypoints"]
            keypoints = resize_keypoints(keypoints, (h, w), image.shape[-2:])
            target["keypoints"] = keypoints
        if "lines" in target:
            lines = target["lines"]
            lines = resize_keypoints(lines, (h, w), image.shape[-2:])
            target["lines"] = lines
        if "points" in target:
            points = target["points"]
            points = resize_keypoints(points, (h, w), image.shape[-2:])
            target["points"] = points
        # if "circle_masks" in target:
        #     circle_masks = target["circle_masks"]
        #     circle_masks = resize_keypoints(circle_masks, (h, w), image.shape[-2:])
        #     target["circle_masks"] = circle_masks
        if "circles" in target:
            circles = target["circles"]
            circles = resize_keypoints(circles, (h, w), image.shape[-2:])
            target["circles"] = circles
        if "mask_ends" in target:
            mask_ends = target["mask_ends"]
            mask_ends = resize_keypoints(mask_ends, (h, w), image.shape[-2:])
            target["mask_ends"] = mask_ends
        if "mask_params" in target:
            # columns 0:2 and 2:4 are scaled as (x, y) coordinates
            mask_params = target["mask_params"]
            mask_params = resize_keypoints(mask_params, (h, w), image.shape[-2:])
            mask_params[:, 2:4] = resize_keypoints(mask_params[:, 2:4], (h, w), image.shape[-2:])
            target["mask_params"] = mask_params
        return image, target

    # _onnx_batch_images() is an implementation of
    # batch_images() that is supported by ONNX tracing.
    @torch.jit.unused
    def _onnx_batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
        max_size = []
        for i in range(images[0].dim()):
            max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
            max_size.append(max_size_i)
        stride = size_divisible
        max_size[1] = (torch.ceil((max_size[1].to(torch.float32)) / stride) * stride).to(torch.int64)
        max_size[2] = (torch.ceil((max_size[2].to(torch.float32)) / stride) * stride).to(torch.int64)
        max_size = tuple(max_size)

        # work around for
        # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
        # which is not yet supported in onnx
        padded_imgs = []
        for img in images:
            padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
            padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
            padded_imgs.append(padded_img)

        return torch.stack(padded_imgs)

    def max_by_axis(self, the_list: List[List[int]]) -> List[int]:
        maxes = the_list[0]
        for sublist in the_list[1:]:
            for index, item in enumerate(sublist):
                maxes[index] = max(maxes[index], item)
        return maxes

    def batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
        if _is_tracing():
            # batch_images() does not export well to ONNX
            # call _onnx_batch_images() instead
            return self._onnx_batch_images(images, size_divisible)

        max_size = self.max_by_axis([list(img.shape) for img in images])
        stride = float(size_divisible)
        max_size = list(max_size)
        max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride)
        max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride)

        batch_shape = [len(images)] + max_size
        batched_imgs = images[0].new_full(batch_shape, 0)
        for i in range(batched_imgs.shape[0]):
            img = images[i]
            batched_imgs[i, : img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)

        return batched_imgs

    def postprocess(
        self,
        result: List[Dict[str, Tensor]],
        image_shapes: List[Tuple[int, int]],
        original_image_sizes: List[Tuple[int, int]],
    ) -> List[Dict[str, Tensor]]:
        if self.training:
            return result
        for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
            boxes = pred["boxes"]
            boxes = resize_boxes(boxes, im_s, o_im_s)
            result[i]["boxes"] = boxes
            if "masks" in pred:
                masks = pred["masks"]
                masks = paste_masks_in_image(masks, boxes, o_im_s)
                result[i]["masks"] = masks
            if "keypoints" in pred:
                keypoints = pred["keypoints"]
                keypoints = resize_keypoints(keypoints, im_s, o_im_s)
                result[i]["keypoints"] = keypoints
            if "lines" in pred:
                lines = pred["lines"]
                lines = resize_keypoints(lines, im_s, o_im_s)
                result[i]["lines"] = lines
            if "points" in pred:
                points = pred["points"]
                points = resize_keypoints(points, im_s, o_im_s)
                result[i]["points"] = points
            if "circles" in pred:
                circles = pred["circles"]
                circles = resize_keypoints(circles, im_s, o_im_s)
                result[i]["circles"] = circles
            if "circle_masks" in pred:
                masks = pred["circle_masks"]
                masks = paste_masks_in_image(masks, boxes, o_im_s)
                result[i]["circle_masks"] = masks
            if "arcs" in pred:
                # columns 0:2, 2:4, 5:7 and 7:9 are scaled as (x, y) coordinates
                arcs = pred["arcs"]
                arcs[:, 0:2] = resize_keypoints(arcs[:, 0:2], im_s, o_im_s)
                arcs[:, 2:4] = resize_keypoints(arcs[:, 2:4], im_s, o_im_s)
                arcs[:, 5:7] = resize_keypoints(arcs[:, 5:7], im_s, o_im_s)
                arcs[:, 7:9] = resize_keypoints(arcs[:, 7:9], im_s, o_im_s)
                result[i]["arcs"] = arcs
        return result

    def __repr__(self) -> str:
        format_string = f"{self.__class__.__name__}("
        _indent = "\n    "
        format_string += f"{_indent}Normalize(mean={self.image_mean}, std={self.image_std})"
        format_string += f"{_indent}Resize(min_size={self.min_size}, max_size={self.max_size}, mode='bilinear')"
        format_string += "\n)"
        return format_string


def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
    ratios = [
        torch.tensor(s, dtype=torch.float32, device=keypoints.device)
        / torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device)
        for s, s_orig in zip(new_size, original_size)
    ]
    ratio_h, ratio_w = ratios
    resized_data = keypoints.clone()
    if torch._C._get_tracing_state():
        resized_data_0 = resized_data[:, :, 0] * ratio_w
        resized_data_1 = resized_data[:, :, 1] * ratio_h
        resized_data = torch.stack((resized_data_0, resized_data_1, resized_data[:, :, 2]), dim=2)
    else:
        resized_data[..., 0] *= ratio_w
        resized_data[..., 1] *= ratio_h
    return resized_data


def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
    ratios = [
        torch.tensor(s, dtype=torch.float32, device=boxes.device)
        / torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
        for s, s_orig in zip(new_size, original_size)
    ]
    ratio_height, ratio_width = ratios
    xmin, ymin, xmax, ymax = boxes.unbind(1)

    xmin = xmin * ratio_width
    xmax = xmax * ratio_width
    ymin = ymin * ratio_height
    ymax = ymax * ratio_height
    return torch.stack((xmin, ymin, xmax, ymax), dim=1)
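

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The ImageNet
# mean/std and the 800/1333 min/max sizes below are assumed, typical Faster
# R-CNN defaults; substitute whatever values your model actually uses.
# Because of the relative imports above, this only runs when the file is
# executed as part of its package (e.g. via `python -m ...`).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    transform = GeneralizedRCNNTransform(
        min_size=800,
        max_size=1333,
        image_mean=[0.485, 0.456, 0.406],  # assumed ImageNet mean
        image_std=[0.229, 0.224, 0.225],  # assumed ImageNet std
    )
    transform.eval()

    # Two dummy images of different sizes; boxes are (x1, y1, x2, y2) in pixels.
    images = [torch.rand(3, 600, 1200), torch.rand(3, 480, 640)]
    targets = [
        {"boxes": torch.tensor([[10.0, 20.0, 200.0, 300.0]])},
        {"boxes": torch.tensor([[5.0, 5.0, 100.0, 150.0]])},
    ]
    image_list, targets = transform(images, targets)
    # image_list.tensors is one padded batch whose H/W are multiples of
    # size_divisible; image_list.image_sizes holds the per-image resized sizes,
    # and each target's boxes have been rescaled to match.
    print(image_list.tensors.shape, image_list.image_sizes)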