_auto_augment.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627
  1. import math
  2. from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
  3. import PIL.Image
  4. import torch
  5. from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
  6. from torchvision import transforms as _transforms, tv_tensors
  7. from torchvision.transforms import _functional_tensor as _FT
  8. from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
  9. from torchvision.transforms.v2.functional._geometry import _check_interpolation
  10. from torchvision.transforms.v2.functional._meta import get_size
  11. from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
  12. from ._utils import _get_fill, _setup_fill_arg, check_type, is_pure_tensor
  13. ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, tv_tensors.Image, tv_tensors.Video]  # Alias used below to annotate the single image/video payload the transforms operate on.
  14. class _AutoAugmentBase(Transform):
  15. def __init__(
  16. self,
  17. *,
  18. interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
  19. fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
  20. ) -> None:
  21. super().__init__()
  22. self.interpolation = _check_interpolation(interpolation)
  23. self.fill = fill
  24. self._fill = _setup_fill_arg(fill)
  25. def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
  26. params = super()._extract_params_for_v1_transform()
  27. if isinstance(params["fill"], dict):
  28. raise ValueError(f"{type(self).__name__}() can not be scripted for when `fill` is a dictionary.")
  29. return params
  30. def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
  31. keys = tuple(dct.keys())
  32. key = keys[int(torch.randint(len(keys), ()))]
  33. return key, dct[key]
  34. def _flatten_and_extract_image_or_video(
  35. self,
  36. inputs: Any,
  37. unsupported_types: Tuple[Type, ...] = (tv_tensors.BoundingBoxes, tv_tensors.Mask),
  38. ) -> Tuple[Tuple[List[Any], TreeSpec, int], ImageOrVideo]:
  39. flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
  40. needs_transform_list = self._needs_transform_list(flat_inputs)
  41. image_or_videos = []
  42. for idx, (inpt, needs_transform) in enumerate(zip(flat_inputs, needs_transform_list)):
  43. if needs_transform and check_type(
  44. inpt,
  45. (
  46. tv_tensors.Image,
  47. PIL.Image.Image,
  48. is_pure_tensor,
  49. tv_tensors.Video,
  50. ),
  51. ):
  52. image_or_videos.append((idx, inpt))
  53. elif isinstance(inpt, unsupported_types):
  54. raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")
  55. if not image_or_videos:
  56. raise TypeError("Found no image in the sample.")
  57. if len(image_or_videos) > 1:
  58. raise TypeError(
  59. f"Auto augment transformations are only properly defined for a single image or video, "
  60. f"but found {len(image_or_videos)}."
  61. )
  62. idx, image_or_video = image_or_videos[0]
  63. return (flat_inputs, spec, idx), image_or_video
  64. def _unflatten_and_insert_image_or_video(
  65. self,
  66. flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int],
  67. image_or_video: ImageOrVideo,
  68. ) -> Any:
  69. flat_inputs, spec, idx = flat_inputs_with_spec
  70. flat_inputs[idx] = image_or_video
  71. return tree_unflatten(flat_inputs, spec)
  72. def _apply_image_or_video_transform(
  73. self,
  74. image: ImageOrVideo,
  75. transform_id: str,
  76. magnitude: float,
  77. interpolation: Union[InterpolationMode, int],
  78. fill: Dict[Union[Type, str], _FillTypeJIT],
  79. ) -> ImageOrVideo:
  80. fill_ = _get_fill(fill, type(image))
  81. if transform_id == "Identity":
  82. return image
  83. elif transform_id == "ShearX":
  84. # magnitude should be arctan(magnitude)
  85. # official autoaug: (1, level, 0, 0, 1, 0)
  86. # https://github.com/tensorflow/models/blob/dd02069717128186b88afa8d857ce57d17957f03/research/autoaugment/augmentation_transforms.py#L290
  87. # compared to
  88. # torchvision: (1, tan(level), 0, 0, 1, 0)
  89. # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
  90. return F.affine(
  91. image,
  92. angle=0.0,
  93. translate=[0, 0],
  94. scale=1.0,
  95. shear=[math.degrees(math.atan(magnitude)), 0.0],
  96. interpolation=interpolation,
  97. fill=fill_,
  98. center=[0, 0],
  99. )
  100. elif transform_id == "ShearY":
  101. # magnitude should be arctan(magnitude)
  102. # See above
  103. return F.affine(
  104. image,
  105. angle=0.0,
  106. translate=[0, 0],
  107. scale=1.0,
  108. shear=[0.0, math.degrees(math.atan(magnitude))],
  109. interpolation=interpolation,
  110. fill=fill_,
  111. center=[0, 0],
  112. )
  113. elif transform_id == "TranslateX":
  114. return F.affine(
  115. image,
  116. angle=0.0,
  117. translate=[int(magnitude), 0],
  118. scale=1.0,
  119. interpolation=interpolation,
  120. shear=[0.0, 0.0],
  121. fill=fill_,
  122. )
  123. elif transform_id == "TranslateY":
  124. return F.affine(
  125. image,
  126. angle=0.0,
  127. translate=[0, int(magnitude)],
  128. scale=1.0,
  129. interpolation=interpolation,
  130. shear=[0.0, 0.0],
  131. fill=fill_,
  132. )
  133. elif transform_id == "Rotate":
  134. return F.rotate(image, angle=magnitude, interpolation=interpolation, fill=fill_)
  135. elif transform_id == "Brightness":
  136. return F.adjust_brightness(image, brightness_factor=1.0 + magnitude)
  137. elif transform_id == "Color":
  138. return F.adjust_saturation(image, saturation_factor=1.0 + magnitude)
  139. elif transform_id == "Contrast":
  140. return F.adjust_contrast(image, contrast_factor=1.0 + magnitude)
  141. elif transform_id == "Sharpness":
  142. return F.adjust_sharpness(image, sharpness_factor=1.0 + magnitude)
  143. elif transform_id == "Posterize":
  144. return F.posterize(image, bits=int(magnitude))
  145. elif transform_id == "Solarize":
  146. bound = _FT._max_value(image.dtype) if isinstance(image, torch.Tensor) else 255.0
  147. return F.solarize(image, threshold=bound * magnitude)
  148. elif transform_id == "AutoContrast":
  149. return F.autocontrast(image)
  150. elif transform_id == "Equalize":
  151. return F.equalize(image)
  152. elif transform_id == "Invert":
  153. return F.invert(image)
  154. else:
  155. raise ValueError(f"No transform available for {transform_id}")
  156. class AutoAugment(_AutoAugmentBase):
  157. r"""AutoAugment data augmentation method based on
  158. `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
  159. This transformation works on images and videos only.
  160. If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
  161. to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
  162. If img is PIL Image, it is expected to be in mode "L" or "RGB".
  163. Args:
  164. policy (AutoAugmentPolicy, optional): Desired policy enum defined by
  165. :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
  166. interpolation (InterpolationMode, optional): Desired interpolation enum defined by
  167. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
  168. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
  169. fill (sequence or number, optional): Pixel fill value for the area outside the transformed
  170. image. If given a number, the value is used for all bands respectively.
  171. """
  172. _v1_transform_cls = _transforms.AutoAugment
  173. _AUGMENTATION_SPACE = {
  174. "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
  175. "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
  176. "TranslateX": (
  177. lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
  178. True,
  179. ),
  180. "TranslateY": (
  181. lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
  182. True,
  183. ),
  184. "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
  185. "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
  186. "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
  187. "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
  188. "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
  189. "Posterize": (
  190. lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
  191. False,
  192. ),
  193. "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
  194. "AutoContrast": (lambda num_bins, height, width: None, False),
  195. "Equalize": (lambda num_bins, height, width: None, False),
  196. "Invert": (lambda num_bins, height, width: None, False),
  197. }
  198. def __init__(
  199. self,
  200. policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
  201. interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
  202. fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
  203. ) -> None:
  204. super().__init__(interpolation=interpolation, fill=fill)
  205. self.policy = policy
  206. self._policies = self._get_policies(policy)
  207. def _get_policies(
  208. self, policy: AutoAugmentPolicy
  209. ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
  210. if policy == AutoAugmentPolicy.IMAGENET:
  211. return [
  212. (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
  213. (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
  214. (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
  215. (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
  216. (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
  217. (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
  218. (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
  219. (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
  220. (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
  221. (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
  222. (("Rotate", 0.8, 8), ("Color", 0.4, 0)),
  223. (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
  224. (("Equalize", 0.0, None), ("Equalize", 0.8, None)),
  225. (("Invert", 0.6, None), ("Equalize", 1.0, None)),
  226. (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
  227. (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
  228. (("Color", 0.8, 8), ("Solarize", 0.8, 7)),
  229. (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
  230. (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
  231. (("Color", 0.4, 0), ("Equalize", 0.6, None)),
  232. (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
  233. (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
  234. (("Invert", 0.6, None), ("Equalize", 1.0, None)),
  235. (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
  236. (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
  237. ]
  238. elif policy == AutoAugmentPolicy.CIFAR10:
  239. return [
  240. (("Invert", 0.1, None), ("Contrast", 0.2, 6)),
  241. (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
  242. (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
  243. (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
  244. (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
  245. (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
  246. (("Color", 0.4, 3), ("Brightness", 0.6, 7)),
  247. (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
  248. (("Equalize", 0.6, None), ("Equalize", 0.5, None)),
  249. (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
  250. (("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
  251. (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
  252. (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
  253. (("Brightness", 0.9, 6), ("Color", 0.2, 8)),
  254. (("Solarize", 0.5, 2), ("Invert", 0.0, None)),
  255. (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
  256. (("Equalize", 0.2, None), ("Equalize", 0.6, None)),
  257. (("Color", 0.9, 9), ("Equalize", 0.6, None)),
  258. (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
  259. (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
  260. (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
  261. (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
  262. (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
  263. (("Equalize", 0.8, None), ("Invert", 0.1, None)),
  264. (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
  265. ]
  266. elif policy == AutoAugmentPolicy.SVHN:
  267. return [
  268. (("ShearX", 0.9, 4), ("Invert", 0.2, None)),
  269. (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
  270. (("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
  271. (("Invert", 0.9, None), ("Equalize", 0.6, None)),
  272. (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
  273. (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
  274. (("ShearY", 0.9, 8), ("Invert", 0.4, None)),
  275. (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
  276. (("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
  277. (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
  278. (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
  279. (("ShearY", 0.8, 8), ("Invert", 0.7, None)),
  280. (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
  281. (("Invert", 0.9, None), ("Equalize", 0.6, None)),
  282. (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
  283. (("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
  284. (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
  285. (("Invert", 0.6, None), ("Rotate", 0.8, 4)),
  286. (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
  287. (("ShearX", 0.1, 6), ("Invert", 0.6, None)),
  288. (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
  289. (("ShearY", 0.8, 4), ("Invert", 0.8, None)),
  290. (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
  291. (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
  292. (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
  293. ]
  294. else:
  295. raise ValueError(f"The provided policy {policy} is not recognized.")
  296. def forward(self, *inputs: Any) -> Any:
  297. flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
  298. height, width = get_size(image_or_video)
  299. policy = self._policies[int(torch.randint(len(self._policies), ()))]
  300. for transform_id, probability, magnitude_idx in policy:
  301. if not torch.rand(()) <= probability:
  302. continue
  303. magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]
  304. magnitudes = magnitudes_fn(10, height, width)
  305. if magnitudes is not None:
  306. magnitude = float(magnitudes[magnitude_idx])
  307. if signed and torch.rand(()) <= 0.5:
  308. magnitude *= -1
  309. else:
  310. magnitude = 0.0
  311. image_or_video = self._apply_image_or_video_transform(
  312. image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
  313. )
  314. return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
  315. class RandAugment(_AutoAugmentBase):
  316. r"""RandAugment data augmentation method based on
  317. `"RandAugment: Practical automated data augmentation with a reduced search space"
  318. <https://arxiv.org/abs/1909.13719>`_.
  319. This transformation works on images and videos only.
  320. If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
  321. to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
  322. If img is PIL Image, it is expected to be in mode "L" or "RGB".
  323. Args:
  324. num_ops (int, optional): Number of augmentation transformations to apply sequentially.
  325. magnitude (int, optional): Magnitude for all the transformations.
  326. num_magnitude_bins (int, optional): The number of different magnitude values.
  327. interpolation (InterpolationMode, optional): Desired interpolation enum defined by
  328. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
  329. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
  330. fill (sequence or number, optional): Pixel fill value for the area outside the transformed
  331. image. If given a number, the value is used for all bands respectively.
  332. """
  333. _v1_transform_cls = _transforms.RandAugment
  334. _AUGMENTATION_SPACE = {
  335. "Identity": (lambda num_bins, height, width: None, False),
  336. "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
  337. "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
  338. "TranslateX": (
  339. lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * width, num_bins),
  340. True,
  341. ),
  342. "TranslateY": (
  343. lambda num_bins, height, width: torch.linspace(0.0, 150.0 / 331.0 * height, num_bins),
  344. True,
  345. ),
  346. "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
  347. "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
  348. "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
  349. "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
  350. "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
  351. "Posterize": (
  352. lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
  353. False,
  354. ),
  355. "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
  356. "AutoContrast": (lambda num_bins, height, width: None, False),
  357. "Equalize": (lambda num_bins, height, width: None, False),
  358. }
  359. def __init__(
  360. self,
  361. num_ops: int = 2,
  362. magnitude: int = 9,
  363. num_magnitude_bins: int = 31,
  364. interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
  365. fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
  366. ) -> None:
  367. super().__init__(interpolation=interpolation, fill=fill)
  368. self.num_ops = num_ops
  369. self.magnitude = magnitude
  370. self.num_magnitude_bins = num_magnitude_bins
  371. def forward(self, *inputs: Any) -> Any:
  372. flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
  373. height, width = get_size(image_or_video)
  374. for _ in range(self.num_ops):
  375. transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
  376. magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
  377. if magnitudes is not None:
  378. magnitude = float(magnitudes[self.magnitude])
  379. if signed and torch.rand(()) <= 0.5:
  380. magnitude *= -1
  381. else:
  382. magnitude = 0.0
  383. image_or_video = self._apply_image_or_video_transform(
  384. image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
  385. )
  386. return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
  387. class TrivialAugmentWide(_AutoAugmentBase):
  388. r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
  389. `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.
  390. This transformation works on images and videos only.
  391. If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
  392. to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
  393. If img is PIL Image, it is expected to be in mode "L" or "RGB".
  394. Args:
  395. num_magnitude_bins (int, optional): The number of different magnitude values.
  396. interpolation (InterpolationMode, optional): Desired interpolation enum defined by
  397. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
  398. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
  399. fill (sequence or number, optional): Pixel fill value for the area outside the transformed
  400. image. If given a number, the value is used for all bands respectively.
  401. """
  402. _v1_transform_cls = _transforms.TrivialAugmentWide
  403. _AUGMENTATION_SPACE = {
  404. "Identity": (lambda num_bins, height, width: None, False),
  405. "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
  406. "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
  407. "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
  408. "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, 32.0, num_bins), True),
  409. "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 135.0, num_bins), True),
  410. "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
  411. "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
  412. "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
  413. "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
  414. "Posterize": (
  415. lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))).round().int(),
  416. False,
  417. ),
  418. "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
  419. "AutoContrast": (lambda num_bins, height, width: None, False),
  420. "Equalize": (lambda num_bins, height, width: None, False),
  421. }
  422. def __init__(
  423. self,
  424. num_magnitude_bins: int = 31,
  425. interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
  426. fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
  427. ):
  428. super().__init__(interpolation=interpolation, fill=fill)
  429. self.num_magnitude_bins = num_magnitude_bins
  430. def forward(self, *inputs: Any) -> Any:
  431. flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
  432. height, width = get_size(image_or_video)
  433. transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
  434. magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
  435. if magnitudes is not None:
  436. magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
  437. if signed and torch.rand(()) <= 0.5:
  438. magnitude *= -1
  439. else:
  440. magnitude = 0.0
  441. image_or_video = self._apply_image_or_video_transform(
  442. image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
  443. )
  444. return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
  445. class AugMix(_AutoAugmentBase):
  446. r"""AugMix data augmentation method based on
  447. `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" <https://arxiv.org/abs/1912.02781>`_.
  448. This transformation works on images and videos only.
  449. If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
  450. to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
  451. If img is PIL Image, it is expected to be in mode "L" or "RGB".
  452. Args:
  453. severity (int, optional): The severity of base augmentation operators. Default is ``3``.
  454. mixture_width (int, optional): The number of augmentation chains. Default is ``3``.
  455. chain_depth (int, optional): The depth of augmentation chains. A negative value denotes stochastic depth sampled from the interval [1, 3].
  456. Default is ``-1``.
  457. alpha (float, optional): The hyperparameter for the probability distributions. Default is ``1.0``.
  458. all_ops (bool, optional): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``.
  459. interpolation (InterpolationMode, optional): Desired interpolation enum defined by
  460. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
  461. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
  462. fill (sequence or number, optional): Pixel fill value for the area outside the transformed
  463. image. If given a number, the value is used for all bands respectively.
  464. """
  465. _v1_transform_cls = _transforms.AugMix
  466. _PARTIAL_AUGMENTATION_SPACE = {
  467. "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
  468. "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
  469. "TranslateX": (lambda num_bins, height, width: torch.linspace(0.0, width / 3.0, num_bins), True),
  470. "TranslateY": (lambda num_bins, height, width: torch.linspace(0.0, height / 3.0, num_bins), True),
  471. "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins), True),
  472. "Posterize": (
  473. lambda num_bins, height, width: (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
  474. False,
  475. ),
  476. "Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
  477. "AutoContrast": (lambda num_bins, height, width: None, False),
  478. "Equalize": (lambda num_bins, height, width: None, False),
  479. }
  480. _AUGMENTATION_SPACE: Dict[str, Tuple[Callable[[int, int, int], Optional[torch.Tensor]], bool]] = {
  481. **_PARTIAL_AUGMENTATION_SPACE,
  482. "Brightness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
  483. "Color": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
  484. "Contrast": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
  485. "Sharpness": (lambda num_bins, height, width: torch.linspace(0.0, 0.9, num_bins), True),
  486. }
  487. def __init__(
  488. self,
  489. severity: int = 3,
  490. mixture_width: int = 3,
  491. chain_depth: int = -1,
  492. alpha: float = 1.0,
  493. all_ops: bool = True,
  494. interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
  495. fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
  496. ) -> None:
  497. super().__init__(interpolation=interpolation, fill=fill)
  498. self._PARAMETER_MAX = 10
  499. if not (1 <= severity <= self._PARAMETER_MAX):
  500. raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
  501. self.severity = severity
  502. self.mixture_width = mixture_width
  503. self.chain_depth = chain_depth
  504. self.alpha = alpha
  505. self.all_ops = all_ops
  506. def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
  507. # Must be on a separate method so that we can overwrite it in tests.
  508. return torch._sample_dirichlet(params)
  509. def forward(self, *inputs: Any) -> Any:
  510. flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs)
  511. height, width = get_size(orig_image_or_video)
  512. if isinstance(orig_image_or_video, torch.Tensor):
  513. image_or_video = orig_image_or_video
  514. else: # isinstance(inpt, PIL.Image.Image):
  515. image_or_video = F.pil_to_tensor(orig_image_or_video)
  516. augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE
  517. orig_dims = list(image_or_video.shape)
  518. expected_ndim = 5 if isinstance(orig_image_or_video, tv_tensors.Video) else 4
  519. batch = image_or_video.reshape([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
  520. batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)
  521. # Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a
  522. # Dirichlet with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of
  523. # augmented image or video.
  524. m = self._sample_dirichlet(
  525. torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
  526. )
  527. # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images or videos.
  528. combined_weights = self._sample_dirichlet(
  529. torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
  530. ) * m[:, 1].reshape([batch_dims[0], -1])
  531. mix = m[:, 0].reshape(batch_dims) * batch
  532. for i in range(self.mixture_width):
  533. aug = batch
  534. depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
  535. for _ in range(depth):
  536. transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)
  537. magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
  538. if magnitudes is not None:
  539. magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
  540. if signed and torch.rand(()) <= 0.5:
  541. magnitude *= -1
  542. else:
  543. magnitude = 0.0
  544. aug = self._apply_image_or_video_transform(
  545. aug, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
  546. )
  547. mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
  548. mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)
  549. if isinstance(orig_image_or_video, (tv_tensors.Image, tv_tensors.Video)):
  550. mix = tv_tensors.wrap(mix, like=orig_image_or_video)
  551. elif isinstance(orig_image_or_video, PIL.Image.Image):
  552. mix = F.to_pil_image(mix)
  553. return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, mix)