_geometry.py

import math
import numbers
import warnings
from typing import Any, List, Optional, Sequence, Tuple, Union

import PIL.Image
import torch
from torch.nn.functional import grid_sample, interpolate, pad as torch_pad

from torchvision import tv_tensors
from torchvision.transforms import _functional_pil as _FP
from torchvision.transforms._functional_tensor import _pad_symmetric
from torchvision.transforms.functional import (
    _compute_resized_output_size as __compute_resized_output_size,
    _get_perspective_coeffs,
    _interpolation_modes_from_int,
    InterpolationMode,
    pil_modes_mapping,
    pil_to_tensor,
    to_pil_image,
)

from torchvision.utils import _log_api_usage_once

from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format

from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal


def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
    if isinstance(interpolation, int):
        interpolation = _interpolation_modes_from_int(interpolation)
    elif not isinstance(interpolation, InterpolationMode):
        raise ValueError(
            f"Argument interpolation should be an `InterpolationMode` or a corresponding Pillow integer constant, "
            f"but got {interpolation}."
        )
    return interpolation


def horizontal_flip(inpt: torch.Tensor) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.RandomHorizontalFlip` for details."""
    if torch.jit.is_scripting():
        return horizontal_flip_image(inpt)

    _log_api_usage_once(horizontal_flip)

    kernel = _get_kernel(horizontal_flip, type(inpt))
    return kernel(inpt)
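# Illustrative usage sketch (not part of the original module; example values are assumed):
# `horizontal_flip` dispatches on the input type, so plain tensors, PIL images, and tv_tensors all
# share this single entry point.
#
#   >>> img = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)
#   >>> flipped = horizontal_flip(img)              # routes to horizontal_flip_image for plain tensors
#   >>> torch.equal(flipped, img.flip(-1))
#   True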
@_register_kernel_internal(horizontal_flip, torch.Tensor)
@_register_kernel_internal(horizontal_flip, tv_tensors.Image)
def horizontal_flip_image(image: torch.Tensor) -> torch.Tensor:
    return image.flip(-1)


@_register_kernel_internal(horizontal_flip, PIL.Image.Image)
def _horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
    return _FP.hflip(image)


@_register_kernel_internal(horizontal_flip, tv_tensors.Mask)
def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
    return horizontal_flip_image(mask)


def horizontal_flip_bounding_boxes(
    bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, canvas_size: Tuple[int, int]
) -> torch.Tensor:
    shape = bounding_boxes.shape

    bounding_boxes = bounding_boxes.clone().reshape(-1, 4)

    if format == tv_tensors.BoundingBoxFormat.XYXY:
        bounding_boxes[:, [2, 0]] = bounding_boxes[:, [0, 2]].sub_(canvas_size[1]).neg_()
    elif format == tv_tensors.BoundingBoxFormat.XYWH:
        bounding_boxes[:, 0].add_(bounding_boxes[:, 2]).sub_(canvas_size[1]).neg_()
    else:  # format == tv_tensors.BoundingBoxFormat.CXCYWH:
        bounding_boxes[:, 0].sub_(canvas_size[1]).neg_()

    return bounding_boxes.reshape(shape)
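# Worked example (illustrative, assumed values): flipping an XYXY box on a canvas of width W maps
# (x1, y1, x2, y2) to (W - x2, y1, W - x1, y2), which is exactly what the in-place index swap above computes.
#
#   >>> boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0]])
#   >>> horizontal_flip_bounding_boxes(boxes, tv_tensors.BoundingBoxFormat.XYXY, canvas_size=(100, 100))
#   tensor([[70., 20., 90., 40.]])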
@_register_kernel_internal(horizontal_flip, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _horizontal_flip_bounding_boxes_dispatch(inpt: tv_tensors.BoundingBoxes) -> tv_tensors.BoundingBoxes:
    output = horizontal_flip_bounding_boxes(
        inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
    )
    return tv_tensors.wrap(output, like=inpt)


@_register_kernel_internal(horizontal_flip, tv_tensors.Video)
def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
    return horizontal_flip_image(video)


def vertical_flip(inpt: torch.Tensor) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.RandomVerticalFlip` for details."""
    if torch.jit.is_scripting():
        return vertical_flip_image(inpt)

    _log_api_usage_once(vertical_flip)

    kernel = _get_kernel(vertical_flip, type(inpt))
    return kernel(inpt)


@_register_kernel_internal(vertical_flip, torch.Tensor)
@_register_kernel_internal(vertical_flip, tv_tensors.Image)
def vertical_flip_image(image: torch.Tensor) -> torch.Tensor:
    return image.flip(-2)


@_register_kernel_internal(vertical_flip, PIL.Image.Image)
def _vertical_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
    return _FP.vflip(image)


@_register_kernel_internal(vertical_flip, tv_tensors.Mask)
def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
    return vertical_flip_image(mask)


def vertical_flip_bounding_boxes(
    bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, canvas_size: Tuple[int, int]
) -> torch.Tensor:
    shape = bounding_boxes.shape

    bounding_boxes = bounding_boxes.clone().reshape(-1, 4)

    if format == tv_tensors.BoundingBoxFormat.XYXY:
        bounding_boxes[:, [1, 3]] = bounding_boxes[:, [3, 1]].sub_(canvas_size[0]).neg_()
    elif format == tv_tensors.BoundingBoxFormat.XYWH:
        bounding_boxes[:, 1].add_(bounding_boxes[:, 3]).sub_(canvas_size[0]).neg_()
    else:  # format == tv_tensors.BoundingBoxFormat.CXCYWH:
        bounding_boxes[:, 1].sub_(canvas_size[0]).neg_()

    return bounding_boxes.reshape(shape)


@_register_kernel_internal(vertical_flip, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _vertical_flip_bounding_boxes_dispatch(inpt: tv_tensors.BoundingBoxes) -> tv_tensors.BoundingBoxes:
    output = vertical_flip_bounding_boxes(
        inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
    )
    return tv_tensors.wrap(output, like=inpt)


@_register_kernel_internal(vertical_flip, tv_tensors.Video)
def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
    return vertical_flip_image(video)


# We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are
# prevalent and well understood. Thus, we just alias them without deprecating the old names.
hflip = horizontal_flip
vflip = vertical_flip


def _compute_resized_output_size(
    canvas_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> List[int]:
    if isinstance(size, int):
        size = [size]
    elif max_size is not None and len(size) != 1:
        raise ValueError(
            "max_size should only be passed if size specifies the length of the smaller edge, "
            "i.e. size should be an int or a sequence of length 1 in torchscript mode."
        )
    return __compute_resized_output_size(canvas_size, size=size, max_size=max_size)
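# Worked example (illustrative, assumed values): with a single-element `size`, the smaller edge is
# matched and the aspect ratio is preserved; `max_size`, if given, then caps the longer edge and both
# edges are rescaled accordingly by the underlying torchvision helper.
#
#   >>> _compute_resized_output_size((400, 600), size=[200])   # canvas is (H, W)
#   [200, 300]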
def resize(
    inpt: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[bool] = True,
) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.Resize` for details."""
    if torch.jit.is_scripting():
        return resize_image(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)

    _log_api_usage_once(resize)

    kernel = _get_kernel(resize, type(inpt))
    return kernel(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)


# This is an internal helper method for resize_image. We should put it here instead of keeping it
# inside resize_image due to torchscript.
# uint8 dtype support for bilinear and bicubic is limited to cpu and
# according to our benchmarks on eager, non-AVX CPUs should still prefer u8->f32->interpolate->u8 path for bilinear
def _do_native_uint8_resize_on_cpu(interpolation: InterpolationMode) -> bool:
    if interpolation == InterpolationMode.BILINEAR:
        if torch._dynamo.is_compiling():
            return True
        else:
            return "AVX2" in torch.backends.cpu.get_cpu_capability()

    return interpolation == InterpolationMode.BICUBIC


@_register_kernel_internal(resize, torch.Tensor)
@_register_kernel_internal(resize, tv_tensors.Image)
def resize_image(
    image: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[bool] = True,
) -> torch.Tensor:
    interpolation = _check_interpolation(interpolation)
    antialias = False if antialias is None else antialias
    align_corners: Optional[bool] = None
    if interpolation == InterpolationMode.BILINEAR or interpolation == InterpolationMode.BICUBIC:
        align_corners = False
    else:
        # The default of antialias is True from 0.17, so we don't warn or
        # error if other interpolation modes are used. This is documented.
        antialias = False

    shape = image.shape
    numel = image.numel()
    num_channels, old_height, old_width = shape[-3:]
    new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)

    if (new_height, new_width) == (old_height, old_width):
        return image
    elif numel > 0:
        dtype = image.dtype
        acceptable_dtypes = [torch.float32, torch.float64]
        if interpolation == InterpolationMode.NEAREST or interpolation == InterpolationMode.NEAREST_EXACT:
            # uint8 dtype can be included for cpu and cuda input if nearest mode
            acceptable_dtypes.append(torch.uint8)
        elif image.device.type == "cpu":
            if _do_native_uint8_resize_on_cpu(interpolation):
                acceptable_dtypes.append(torch.uint8)

        image = image.reshape(-1, num_channels, old_height, old_width)
        strides = image.stride()
        if image.is_contiguous(memory_format=torch.channels_last) and image.shape[0] == 1 and numel != strides[0]:
            # There is a weird behaviour in torch core where the output tensor of `interpolate()` can be allocated as
            # contiguous even though the input is un-ambiguously channels_last (https://github.com/pytorch/pytorch/issues/68430).
            # In particular this happens for the typical torchvision use-case of single CHW images where we fake the batch dim
            # to become 1CHW. Below, we restride those tensors to trick torch core into properly allocating the output as
            # channels_last, thus preserving the memory format of the input. This is not just for format consistency:
            # for uint8 bilinear images, this also avoids an extra copy (re-packing) of the output and saves time.
            # TODO: when https://github.com/pytorch/pytorch/issues/68430 is fixed (possibly by https://github.com/pytorch/pytorch/pull/100373),
            # we should be able to remove this hack.
            new_strides = list(strides)
            new_strides[0] = numel
            image = image.as_strided((1, num_channels, old_height, old_width), new_strides)

        need_cast = dtype not in acceptable_dtypes
        if need_cast:
            image = image.to(dtype=torch.float32)

        image = interpolate(
            image,
            size=[new_height, new_width],
            mode=interpolation.value,
            align_corners=align_corners,
            antialias=antialias,
        )

        if need_cast:
            if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8:
                # This path is hit on non-AVX archs, or on GPU.
                image = image.clamp_(min=0, max=255)
            if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
                image = image.round_()
            image = image.to(dtype=dtype)

    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
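# Illustrative usage sketch (assumed shapes and values, not part of the original module): a uint8 CHW
# input stays uint8 on the way out; the float32 round trip above only happens when the dtype is not
# natively supported for the chosen interpolation mode.
#
#   >>> img = torch.randint(0, 256, (3, 400, 600), dtype=torch.uint8)
#   >>> out = resize_image(img, size=[200], antialias=True)
#   >>> out.shape, out.dtype
#   (torch.Size([3, 200, 300]), torch.uint8)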
def _resize_image_pil(
    image: PIL.Image.Image,
    size: Union[Sequence[int], int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
) -> PIL.Image.Image:
    old_height, old_width = image.height, image.width
    new_height, new_width = _compute_resized_output_size(
        (old_height, old_width),
        size=size,  # type: ignore[arg-type]
        max_size=max_size,
    )
    interpolation = _check_interpolation(interpolation)
    if (new_height, new_width) == (old_height, old_width):
        return image
    return image.resize((new_width, new_height), resample=pil_modes_mapping[interpolation])


@_register_kernel_internal(resize, PIL.Image.Image)
def __resize_image_pil_dispatch(
    image: PIL.Image.Image,
    size: Union[Sequence[int], int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[bool] = True,
) -> PIL.Image.Image:
    if antialias is False:
        warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
    return _resize_image_pil(image, size=size, interpolation=interpolation, max_size=max_size)


def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = resize_image(mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)

    if needs_squeeze:
        output = output.squeeze(0)

    return output


@_register_kernel_internal(resize, tv_tensors.Mask, tv_tensor_wrapper=False)
def _resize_mask_dispatch(
    inpt: tv_tensors.Mask, size: List[int], max_size: Optional[int] = None, **kwargs: Any
) -> tv_tensors.Mask:
    output = resize_mask(inpt.as_subclass(torch.Tensor), size, max_size=max_size)
    return tv_tensors.wrap(output, like=inpt)


def resize_bounding_boxes(
    bounding_boxes: torch.Tensor, canvas_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    old_height, old_width = canvas_size
    new_height, new_width = _compute_resized_output_size(canvas_size, size=size, max_size=max_size)

    if (new_height, new_width) == (old_height, old_width):
        return bounding_boxes, canvas_size

    w_ratio = new_width / old_width
    h_ratio = new_height / old_height
    ratios = torch.tensor([w_ratio, h_ratio, w_ratio, h_ratio], device=bounding_boxes.device)
    return (
        bounding_boxes.mul(ratios).to(bounding_boxes.dtype),
        (new_height, new_width),
    )
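# Worked example (illustrative, assumed values): boxes are scaled by the per-axis resize ratios and
# the new canvas size is returned alongside them.
#
#   >>> boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0]])      # XYXY on a 400x600 canvas
#   >>> resize_bounding_boxes(boxes, canvas_size=(400, 600), size=[200])
#   (tensor([[ 5., 10., 15., 20.]]), (200, 300))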
@_register_kernel_internal(resize, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _resize_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, size: List[int], max_size: Optional[int] = None, **kwargs: Any
) -> tv_tensors.BoundingBoxes:
    output, canvas_size = resize_bounding_boxes(
        inpt.as_subclass(torch.Tensor), inpt.canvas_size, size, max_size=max_size
    )
    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)


@_register_kernel_internal(resize, tv_tensors.Video)
def resize_video(
    video: torch.Tensor,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[bool] = True,
) -> torch.Tensor:
    return resize_image(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)


def affine(
    inpt: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.RandomAffine` for details."""
    if torch.jit.is_scripting():
        return affine_image(
            inpt,
            angle=angle,
            translate=translate,
            scale=scale,
            shear=shear,
            interpolation=interpolation,
            fill=fill,
            center=center,
        )

    _log_api_usage_once(affine)

    kernel = _get_kernel(affine, type(inpt))
    return kernel(
        inpt,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        interpolation=interpolation,
        fill=fill,
        center=center,
    )
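# Illustrative usage sketch (assumed values, not part of the original module): the same dispatcher
# signature works for tensors, PIL images, bounding boxes, masks and videos. The spatial size is
# preserved by `affine` (there is no `expand` option here, unlike `rotate`).
#
#   >>> img = torch.zeros(3, 64, 64)
#   >>> out = affine(img, angle=30.0, translate=[5.0, 0.0], scale=1.2, shear=[0.0, 0.0],
#   ...              interpolation=InterpolationMode.BILINEAR, fill=[0.0])
#   >>> out.shape
#   torch.Size([3, 64, 64])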
def _affine_parse_args(
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    center: Optional[List[float]] = None,
) -> Tuple[float, List[float], List[float], Optional[List[float]]]:
    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if not isinstance(translate, (list, tuple)):
        raise TypeError("Argument translate should be a sequence")

    if len(translate) != 2:
        raise ValueError("Argument translate should be a sequence of length 2")

    if scale <= 0.0:
        raise ValueError("Argument scale should be positive")

    if not isinstance(shear, (numbers.Number, (list, tuple))):
        raise TypeError("Shear should be either a single value or a sequence of two values")

    if not isinstance(interpolation, InterpolationMode):
        raise TypeError("Argument interpolation should be an InterpolationMode")

    if isinstance(angle, int):
        angle = float(angle)

    if isinstance(translate, tuple):
        translate = list(translate)

    if isinstance(shear, numbers.Number):
        shear = [shear, 0.0]

    if isinstance(shear, tuple):
        shear = list(shear)

    if len(shear) == 1:
        shear = [shear[0], shear[0]]

    if len(shear) != 2:
        raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")

    if center is not None:
        if not isinstance(center, (list, tuple)):
            raise TypeError("Argument center should be a sequence")
        else:
            center = [float(c) for c in center]

    return angle, translate, shear, center


def _get_inverse_affine_matrix(
    center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True
) -> List[float]:
    # Helper method to compute inverse matrix for affine transformation
    # Pillow requires inverse affine transformation matrix:
    # Affine matrix is : M = T * C * RotateScaleShear * C^-1
    #
    # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
    #       C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
    #       RotateScaleShear is rotation with scale and shear matrix
    #
    # RotateScaleShear(a, s, (sx, sy)) =
    #     = R(a) * S(s) * SHy(sy) * SHx(sx)
    #     = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
    #       [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
    #       [ 0                    , 0                                        , 1 ]
    # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
    # SHx(s) = [1, -tan(s)]     and SHy(s) = [1      , 0]
    #          [0, 1      ]                  [-tan(s), 1]
    #
    # Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1
    rot = math.radians(angle)
    sx = math.radians(shear[0])
    sy = math.radians(shear[1])

    cx, cy = center
    tx, ty = translate

    # Cached results
    cos_sy = math.cos(sy)
    tan_sx = math.tan(sx)
    rot_minus_sy = rot - sy
    cx_plus_tx = cx + tx
    cy_plus_ty = cy + ty

    # Rotate Scale Shear (RSS) without scaling
    a = math.cos(rot_minus_sy) / cos_sy
    b = -(a * tan_sx + math.sin(rot))
    c = math.sin(rot_minus_sy) / cos_sy
    d = math.cos(rot) - c * tan_sx

    if inverted:
        # Inverted rotation matrix with scale and shear
        # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
        matrix = [d / scale, -b / scale, 0.0, -c / scale, a / scale, 0.0]
        # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
        # and then apply center translation: C * RSS^-1 * C^-1 * T^-1
        matrix[2] += cx - matrix[0] * cx_plus_tx - matrix[1] * cy_plus_ty
        matrix[5] += cy - matrix[3] * cx_plus_tx - matrix[4] * cy_plus_ty
    else:
        matrix = [a * scale, b * scale, 0.0, c * scale, d * scale, 0.0]
        # Apply inverse of center translation: RSS * C^-1
        # and then apply translation and center : T * C * RSS * C^-1
        matrix[2] += cx_plus_tx - matrix[0] * cx - matrix[1] * cy
        matrix[5] += cy_plus_ty - matrix[3] * cx - matrix[4] * cy

    return matrix
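# Sanity check (illustrative, assumed values): with no rotation, translation, scaling or shear the
# helper returns the identity 2x3 matrix in row-major order, and a pure rotation by `a` degrees about
# the origin gives [cos a, sin a, 0, -sin a, cos a, 0] for the inverted (PIL-style) matrix.
#
#   >>> m = _get_inverse_affine_matrix([0.0, 0.0], 0.0, [0.0, 0.0], 1.0, [0.0, 0.0])
#   >>> torch.allclose(torch.tensor(m), torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]))
#   True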
def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
    # Inspired by the PIL implementation:
    # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054

    # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
    # Points are shifted due to affine matrix torch convention about
    # the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5)
    half_w = 0.5 * w
    half_h = 0.5 * h
    pts = torch.tensor(
        [
            [-half_w, -half_h, 1.0],
            [-half_w, half_h, 1.0],
            [half_w, half_h, 1.0],
            [half_w, -half_h, 1.0],
        ]
    )
    theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
    new_pts = torch.matmul(pts, theta.T)
    min_vals, max_vals = new_pts.aminmax(dim=0)

    # shift points to [0, w] and [0, h] interval to match PIL results
    halfs = torch.tensor((half_w, half_h))
    min_vals.add_(halfs)
    max_vals.add_(halfs)

    # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
    tol = 1e-4
    inv_tol = 1.0 / tol
    cmax = max_vals.mul_(inv_tol).trunc_().mul_(tol).ceil_()
    cmin = min_vals.mul_(inv_tol).trunc_().mul_(tol).floor_()
    size = cmax.sub_(cmin)
    return int(size[0]), int(size[1])  # w, h
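# Worked example (illustrative, assumed values): a 90 degree rotation simply swaps the width and
# height of the enclosing output canvas.
#
#   >>> rot90 = _get_inverse_affine_matrix([0.0, 0.0], 90.0, [0.0, 0.0], 1.0, [0.0, 0.0])
#   >>> _compute_affine_output_size(rot90, w=100, h=50)   # returns (w, h)
#   (50, 100)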
def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill: _FillTypeJIT) -> torch.Tensor:
    input_shape = img.shape
    output_height, output_width = grid.shape[1], grid.shape[2]
    num_channels, input_height, input_width = input_shape[-3:]
    output_shape = input_shape[:-3] + (num_channels, output_height, output_width)

    if img.numel() == 0:
        return img.reshape(output_shape)

    img = img.reshape(-1, num_channels, input_height, input_width)
    squashed_batch_size = img.shape[0]

    # We are using context knowledge that grid should have float dtype
    fp = img.dtype == grid.dtype
    float_img = img if fp else img.to(grid.dtype)

    if squashed_batch_size > 1:
        # Apply same grid to a batch of images
        grid = grid.expand(squashed_batch_size, -1, -1, -1)

    # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
    if fill is not None:
        mask = torch.ones(
            (squashed_batch_size, 1, input_height, input_width), dtype=float_img.dtype, device=float_img.device
        )
        float_img = torch.cat((float_img, mask), dim=1)

    float_img = grid_sample(float_img, grid, mode=mode, padding_mode="zeros", align_corners=False)

    # Fill with required color
    if fill is not None:
        float_img, mask = torch.tensor_split(float_img, indices=(-1,), dim=-3)
        mask = mask.expand_as(float_img)
        fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)]  # type: ignore[arg-type]
        fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1)
        if mode == "nearest":
            bool_mask = mask < 0.5
            float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask]
        else:  # 'bilinear'
            # The following is mathematically equivalent to:
            # img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill
            float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img)

    img = float_img.round_().to(img.dtype) if not fp else float_img

    return img.reshape(output_shape)


def _assert_grid_transform_inputs(
    image: torch.Tensor,
    matrix: Optional[List[float]],
    interpolation: str,
    fill: _FillTypeJIT,
    supported_interpolation_modes: List[str],
    coeffs: Optional[List[float]] = None,
) -> None:
    if matrix is not None:
        if not isinstance(matrix, list):
            raise TypeError("Argument matrix should be a list")
        elif len(matrix) != 6:
            raise ValueError("Argument matrix should have 6 float values")

    if coeffs is not None and len(coeffs) != 8:
        raise ValueError("Argument coeffs should have 8 float values")

    if fill is not None:
        if isinstance(fill, (tuple, list)):
            length = len(fill)
            num_channels = image.shape[-3]
            if length > 1 and length != num_channels:
                raise ValueError(
                    "The number of elements in 'fill' cannot broadcast to match the number of "
                    f"channels of the image ({length} != {num_channels})"
                )
        elif not isinstance(fill, (int, float)):
            raise ValueError("Argument fill should be either int, float, tuple or list")

    if interpolation not in supported_interpolation_modes:
        raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")


def _affine_grid(
    theta: torch.Tensor,
    w: int,
    h: int,
    ow: int,
    oh: int,
) -> torch.Tensor:
    # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
    # AffineGridGenerator.cpp#L18
    # Difference with AffineGridGenerator is that:
    # 1) we normalize grid values after applying theta
    # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate
    dtype = theta.dtype
    device = theta.device

    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
    x_grid = torch.linspace((1.0 - ow) * 0.5, (ow - 1.0) * 0.5, steps=ow, device=device)
    base_grid[..., 0].copy_(x_grid)
    y_grid = torch.linspace((1.0 - oh) * 0.5, (oh - 1.0) * 0.5, steps=oh, device=device).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta = theta.transpose(1, 2).div_(torch.tensor([0.5 * w, 0.5 * h], dtype=dtype, device=device))
    output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
    return output_grid.view(1, oh, ow, 2)


@_register_kernel_internal(affine, torch.Tensor)
@_register_kernel_internal(affine, tv_tensors.Image)
def affine_image(
    image: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    interpolation = _check_interpolation(interpolation)

    angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

    height, width = image.shape[-2:]

    center_f = [0.0, 0.0]
    if center is not None:
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        center_f = [(c - s * 0.5) for c, s in zip(center, [width, height])]

    translate_f = [float(t) for t in translate]
    matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)

    _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])

    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
    theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
    grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height)
    return _apply_grid_transform(image, grid, interpolation.value, fill=fill)


@_register_kernel_internal(affine, PIL.Image.Image)
def _affine_image_pil(
    image: PIL.Image.Image,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> PIL.Image.Image:
    interpolation = _check_interpolation(interpolation)
    angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

    # center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
    # it is visually better to estimate the center without 0.5 offset
    # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
    if center is None:
        height, width = _get_size_image_pil(image)
        center = [width * 0.5, height * 0.5]

    matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
    return _FP.affine(image, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill)


def _affine_bounding_boxes_with_expand(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
    expand: bool = False,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    if bounding_boxes.numel() == 0:
        return bounding_boxes, canvas_size

    original_shape = bounding_boxes.shape
    original_dtype = bounding_boxes.dtype
    bounding_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float()
    dtype = bounding_boxes.dtype
    device = bounding_boxes.device
    bounding_boxes = (
        convert_bounding_box_format(
            bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY, inplace=True
        )
    ).reshape(-1, 4)

    angle, translate, shear, center = _affine_parse_args(
        angle, translate, scale, shear, InterpolationMode.NEAREST, center
    )

    if center is None:
        height, width = canvas_size
        center = [width * 0.5, height * 0.5]

    affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False)
    transposed_affine_matrix = (
        torch.tensor(
            affine_vector,
            dtype=dtype,
            device=device,
        )
        .reshape(2, 3)
        .T
    )
    # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners).
    # Tensor of points has shape (N * 4, 3), where N is the number of bboxes
    # Single point structure is similar to
    # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
    points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
    points = torch.cat([points, torch.ones(points.shape[0], 1, device=device, dtype=dtype)], dim=-1)
    # 2) Now let's transform the points using affine matrix
    transformed_points = torch.matmul(points, transposed_affine_matrix)
    # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
    # and compute bounding box from 4 transformed points:
    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
    out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)

    if expand:
        # Compute minimum point for transformed image frame:
        # Points are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
        height, width = canvas_size
        points = torch.tensor(
            [
                [0.0, 0.0, 1.0],
                [0.0, float(height), 1.0],
                [float(width), float(height), 1.0],
                [float(width), 0.0, 1.0],
            ],
            dtype=dtype,
            device=device,
        )
        new_points = torch.matmul(points, transposed_affine_matrix)
        tr = torch.amin(new_points, dim=0, keepdim=True)
        # Translate bounding boxes
        out_bboxes.sub_(tr.repeat((1, 2)))
        # Estimate meta-data for image with inverted=True
        affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
        new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
        canvas_size = (new_height, new_width)

    out_bboxes = clamp_bounding_boxes(out_bboxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size)
    out_bboxes = convert_bounding_box_format(
        out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)

    out_bboxes = out_bboxes.to(original_dtype)
    return out_bboxes, canvas_size


def affine_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    out_box, _ = _affine_bounding_boxes_with_expand(
        bounding_boxes,
        format=format,
        canvas_size=canvas_size,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        center=center,
        expand=False,
    )
    return out_box


@_register_kernel_internal(affine, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _affine_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
    **kwargs,
) -> tv_tensors.BoundingBoxes:
    output = affine_bounding_boxes(
        inpt.as_subclass(torch.Tensor),
        format=inpt.format,
        canvas_size=inpt.canvas_size,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        center=center,
    )
    return tv_tensors.wrap(output, like=inpt)


def affine_mask(
    mask: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = affine_image(
        mask,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        interpolation=InterpolationMode.NEAREST,
        fill=fill,
        center=center,
    )

    if needs_squeeze:
        output = output.squeeze(0)

    return output


@_register_kernel_internal(affine, tv_tensors.Mask, tv_tensor_wrapper=False)
def _affine_mask_dispatch(
    inpt: tv_tensors.Mask,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
    **kwargs,
) -> tv_tensors.Mask:
    output = affine_mask(
        inpt.as_subclass(torch.Tensor),
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        fill=fill,
        center=center,
    )
    return tv_tensors.wrap(output, like=inpt)


@_register_kernel_internal(affine, tv_tensors.Video)
def affine_video(
    video: torch.Tensor,
    angle: Union[int, float],
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    fill: _FillTypeJIT = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    return affine_image(
        video,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        interpolation=interpolation,
        fill=fill,
        center=center,
    )


def rotate(
    inpt: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.RandomRotation` for details."""
    if torch.jit.is_scripting():
        return rotate_image(inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center)

    _log_api_usage_once(rotate)

    kernel = _get_kernel(rotate, type(inpt))
    return kernel(inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
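# Illustrative usage sketch (assumed values, not part of the original module): with `expand=False`
# the spatial size is preserved and rotated corners may be cut off; with `expand=True` the output
# canvas grows to contain the whole rotated image.
#
#   >>> img = torch.zeros(3, 50, 100)
#   >>> rotate(img, angle=90.0).shape
#   torch.Size([3, 50, 100])
#   >>> rotate(img, angle=90.0, expand=True).shape
#   torch.Size([3, 100, 50])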
@_register_kernel_internal(rotate, torch.Tensor)
@_register_kernel_internal(rotate, tv_tensors.Image)
def rotate_image(
    image: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    interpolation = _check_interpolation(interpolation)

    input_height, input_width = image.shape[-2:]

    center_f = [0.0, 0.0]
    if center is not None:
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        center_f = [(c - s * 0.5) for c, s in zip(center, [input_width, input_height])]

    # due to current incoherence of rotation angle direction between affine and rotate implementations
    # we need to set -angle.
    matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])

    _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])

    output_width, output_height = (
        _compute_affine_output_size(matrix, input_width, input_height) if expand else (input_width, input_height)
    )
    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
    theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
    grid = _affine_grid(theta, w=input_width, h=input_height, ow=output_width, oh=output_height)
    return _apply_grid_transform(image, grid, interpolation.value, fill=fill)


@_register_kernel_internal(rotate, PIL.Image.Image)
def _rotate_image_pil(
    image: PIL.Image.Image,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> PIL.Image.Image:
    interpolation = _check_interpolation(interpolation)

    return _FP.rotate(
        image, angle, interpolation=pil_modes_mapping[interpolation], expand=expand, fill=fill, center=center
    )


def rotate_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    return _affine_bounding_boxes_with_expand(
        bounding_boxes,
        format=format,
        canvas_size=canvas_size,
        angle=-angle,
        translate=[0.0, 0.0],
        scale=1.0,
        shear=[0.0, 0.0],
        center=center,
        expand=expand,
    )


@_register_kernel_internal(rotate, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _rotate_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, angle: float, expand: bool = False, center: Optional[List[float]] = None, **kwargs
) -> tv_tensors.BoundingBoxes:
    output, canvas_size = rotate_bounding_boxes(
        inpt.as_subclass(torch.Tensor),
        format=inpt.format,
        canvas_size=inpt.canvas_size,
        angle=angle,
        expand=expand,
        center=center,
    )
    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)


def rotate_mask(
    mask: torch.Tensor,
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = rotate_image(
        mask,
        angle=angle,
        expand=expand,
        interpolation=InterpolationMode.NEAREST,
        fill=fill,
        center=center,
    )

    if needs_squeeze:
        output = output.squeeze(0)

    return output


@_register_kernel_internal(rotate, tv_tensors.Mask, tv_tensor_wrapper=False)
def _rotate_mask_dispatch(
    inpt: tv_tensors.Mask,
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
    **kwargs,
) -> tv_tensors.Mask:
    output = rotate_mask(inpt.as_subclass(torch.Tensor), angle=angle, expand=expand, fill=fill, center=center)
    return tv_tensors.wrap(output, like=inpt)


@_register_kernel_internal(rotate, tv_tensors.Video)
def rotate_video(
    video: torch.Tensor,
    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
    expand: bool = False,
    center: Optional[List[float]] = None,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    return rotate_image(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)


def pad(
    inpt: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.Pad` for details."""
    if torch.jit.is_scripting():
        return pad_image(inpt, padding=padding, fill=fill, padding_mode=padding_mode)

    _log_api_usage_once(pad)

    kernel = _get_kernel(pad, type(inpt))
    return kernel(inpt, padding=padding, fill=fill, padding_mode=padding_mode)


def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
    if isinstance(padding, int):
        pad_left = pad_right = pad_top = pad_bottom = padding
    elif isinstance(padding, (tuple, list)):
        if len(padding) == 1:
            pad_left = pad_right = pad_top = pad_bottom = padding[0]
        elif len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        elif len(padding) == 4:
            pad_left = padding[0]
            pad_top = padding[1]
            pad_right = padding[2]
            pad_bottom = padding[3]
        else:
            raise ValueError(
                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
            )
    else:
        raise TypeError(f"`padding` should be an integer or tuple or list of integers, but got {padding}")

    return [pad_left, pad_right, pad_top, pad_bottom]
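# Worked example (illustrative, assumed values): the public `padding` argument follows the PIL order
# [left, top, right, bottom], while the returned list follows the `torch_pad` order
# [left, right, top, bottom].
#
#   >>> _parse_pad_padding([1, 2, 3, 4])
#   [1, 3, 2, 4]
#   >>> _parse_pad_padding(5)
#   [5, 5, 5, 5]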
  930. @_register_kernel_internal(pad, torch.Tensor)
  931. @_register_kernel_internal(pad, tv_tensors.Image)
  932. def pad_image(
  933. image: torch.Tensor,
  934. padding: List[int],
  935. fill: Optional[Union[int, float, List[float]]] = None,
  936. padding_mode: str = "constant",
  937. ) -> torch.Tensor:
  938. # Be aware that while `padding` has order `[left, top, right, bottom]`, `torch_padding` uses
  939. # `[left, right, top, bottom]`. This stems from the fact that we align our API with PIL, but need to use `torch_pad`
  940. # internally.
  941. torch_padding = _parse_pad_padding(padding)
  942. if padding_mode not in ("constant", "edge", "reflect", "symmetric"):
  943. raise ValueError(
  944. f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "
  945. f"but got `'{padding_mode}'`."
  946. )
  947. if fill is None:
  948. fill = 0
  949. if isinstance(fill, (int, float)):
  950. return _pad_with_scalar_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
  951. elif len(fill) == 1:
  952. return _pad_with_scalar_fill(image, torch_padding, fill=fill[0], padding_mode=padding_mode)
  953. else:
  954. return _pad_with_vector_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
  955. def _pad_with_scalar_fill(
  956. image: torch.Tensor,
  957. torch_padding: List[int],
  958. fill: Union[int, float],
  959. padding_mode: str,
  960. ) -> torch.Tensor:
  961. shape = image.shape
  962. num_channels, height, width = shape[-3:]
  963. batch_size = 1
  964. for s in shape[:-3]:
  965. batch_size *= s
  966. image = image.reshape(batch_size, num_channels, height, width)
  967. if padding_mode == "edge":
  968. # Similar to the padding order, `torch_pad`'s PIL's padding modes don't have the same names. Thus, we map
  969. # the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
  970. # name.
  971. padding_mode = "replicate"
  972. if padding_mode == "constant":
  973. image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
  974. elif padding_mode in ("reflect", "replicate"):
  975. # `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
  976. # TODO: See https://github.com/pytorch/pytorch/issues/40763
  977. dtype = image.dtype
  978. if not image.is_floating_point():
  979. needs_cast = True
  980. image = image.to(torch.float32)
  981. else:
  982. needs_cast = False
  983. image = torch_pad(image, torch_padding, mode=padding_mode)
  984. if needs_cast:
  985. image = image.to(dtype)
  986. else: # padding_mode == "symmetric"
  987. image = _pad_symmetric(image, torch_padding)
  988. new_height, new_width = image.shape[-2:]
  989. return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
  990. # TODO: This should be removed once torch_pad supports non-scalar padding values
  991. def _pad_with_vector_fill(
  992. image: torch.Tensor,
  993. torch_padding: List[int],
  994. fill: List[float],
  995. padding_mode: str,
  996. ) -> torch.Tensor:
  997. if padding_mode != "constant":
  998. raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")
  999. output = _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
  1000. left, right, top, bottom = torch_padding
  1001. # We are creating the tensor in the autodetected dtype first and convert to the right one after to avoid an implicit
  1002. # float -> int conversion. That happens for example for the valid input of a uint8 image with floating point fill
  1003. # value.
  1004. fill = torch.tensor(fill, device=image.device).to(dtype=image.dtype).reshape(-1, 1, 1)
  1005. if top > 0:
  1006. output[..., :top, :] = fill
  1007. if left > 0:
  1008. output[..., :, :left] = fill
  1009. if bottom > 0:
  1010. output[..., -bottom:, :] = fill
  1011. if right > 0:
  1012. output[..., :, -right:] = fill
  1013. return output
_pad_image_pil = _register_kernel_internal(pad, PIL.Image.Image)(_FP.pad)


@_register_kernel_internal(pad, tv_tensors.Mask)
def pad_mask(
    mask: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    if fill is None:
        fill = 0

    if isinstance(fill, (tuple, list)):
        raise ValueError("Non-scalar fill value is not supported")

    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = pad_image(mask, padding=padding, fill=fill, padding_mode=padding_mode)

    if needs_squeeze:
        output = output.squeeze(0)

    return output


def pad_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    padding: List[int],
    padding_mode: str = "constant",
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    if padding_mode not in ["constant"]:
        # TODO: add support of other padding modes
        raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes")

    left, right, top, bottom = _parse_pad_padding(padding)

    if format == tv_tensors.BoundingBoxFormat.XYXY:
        pad = [left, top, left, top]
    else:
        pad = [left, top, 0, 0]
    bounding_boxes = bounding_boxes + torch.tensor(pad, dtype=bounding_boxes.dtype, device=bounding_boxes.device)

    height, width = canvas_size
    height += top + bottom
    width += left + right
    canvas_size = (height, width)

    return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size
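
# Editor's note: illustrative sketch (not part of the upstream module); the numbers below are
# assumptions chosen for the example. Padding shifts XYXY coordinates by (left, top) and enlarges
# the canvas by the total horizontal/vertical padding:
#
#   boxes = torch.tensor([[10.0, 10.0, 20.0, 20.0]])
#   out, canvas = pad_bounding_boxes(
#       boxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=(32, 32), padding=[2, 3]
#   )
#   # padding=[2, 3] parses to left=right=2, top=bottom=3, so
#   # out == [[12., 13., 22., 23.]] and canvas == (32 + 3 + 3, 32 + 2 + 2) == (38, 36).
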
@_register_kernel_internal(pad, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _pad_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, padding: List[int], padding_mode: str = "constant", **kwargs
) -> tv_tensors.BoundingBoxes:
    output, canvas_size = pad_bounding_boxes(
        inpt.as_subclass(torch.Tensor),
        format=inpt.format,
        canvas_size=inpt.canvas_size,
        padding=padding,
        padding_mode=padding_mode,
    )
    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)


@_register_kernel_internal(pad, tv_tensors.Video)
def pad_video(
    video: torch.Tensor,
    padding: List[int],
    fill: Optional[Union[int, float, List[float]]] = None,
    padding_mode: str = "constant",
) -> torch.Tensor:
    return pad_image(video, padding, fill=fill, padding_mode=padding_mode)


def crop(inpt: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.RandomCrop` for details."""
    if torch.jit.is_scripting():
        return crop_image(inpt, top=top, left=left, height=height, width=width)

    _log_api_usage_once(crop)

    kernel = _get_kernel(crop, type(inpt))
    return kernel(inpt, top=top, left=left, height=height, width=width)


@_register_kernel_internal(crop, torch.Tensor)
@_register_kernel_internal(crop, tv_tensors.Image)
def crop_image(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    h, w = image.shape[-2:]

    right = left + width
    bottom = top + height

    if left < 0 or top < 0 or right > w or bottom > h:
        image = image[..., max(top, 0) : bottom, max(left, 0) : right]
        torch_padding = [
            max(min(right, 0) - left, 0),
            max(right - max(w, left), 0),
            max(min(bottom, 0) - top, 0),
            max(bottom - max(h, top), 0),
        ]
        return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
    return image[..., top:bottom, left:right]
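
# Editor's note: illustrative sketch (not part of the upstream module) of the implicit zero
# padding performed by `crop_image` when the crop window extends past the image borders:
#
#   img = torch.ones(1, 4, 4)
#   out = crop_image(img, top=-1, left=-1, height=3, width=3)
#   # out.shape == (1, 3, 3): the first row and first column come from the implicit constant
#   # (zero) padding, the remaining 2x2 block is copied from the original image.
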
_crop_image_pil = _FP.crop
_register_kernel_internal(crop, PIL.Image.Image)(_crop_image_pil)


def crop_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    top: int,
    left: int,
    height: int,
    width: int,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    # Crop or implicit pad if left and/or top have negative values:
    if format == tv_tensors.BoundingBoxFormat.XYXY:
        sub = [left, top, left, top]
    else:
        sub = [left, top, 0, 0]

    bounding_boxes = bounding_boxes - torch.tensor(sub, dtype=bounding_boxes.dtype, device=bounding_boxes.device)
    canvas_size = (height, width)

    return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size


@_register_kernel_internal(crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _crop_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int
) -> tv_tensors.BoundingBoxes:
    output, canvas_size = crop_bounding_boxes(
        inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width
    )
    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)


@_register_kernel_internal(crop, tv_tensors.Mask)
def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = crop_image(mask, top, left, height, width)

    if needs_squeeze:
        output = output.squeeze(0)

    return output


@_register_kernel_internal(crop, tv_tensors.Video)
def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    return crop_image(video, top, left, height, width)
def perspective(
    inpt: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.RandomPerspective` for details."""
    if torch.jit.is_scripting():
        return perspective_image(
            inpt,
            startpoints=startpoints,
            endpoints=endpoints,
            interpolation=interpolation,
            fill=fill,
            coefficients=coefficients,
        )

    _log_api_usage_once(perspective)

    kernel = _get_kernel(perspective, type(inpt))
    return kernel(
        inpt,
        startpoints=startpoints,
        endpoints=endpoints,
        interpolation=interpolation,
        fill=fill,
        coefficients=coefficients,
    )


def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
    # src/libImaging/Geometry.c#L394
    #
    # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #
    theta1 = torch.tensor(
        [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
    )
    theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)

    d = 0.5
    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
    x_grid = torch.linspace(d, ow + d - 1.0, steps=ow, device=device, dtype=dtype)
    base_grid[..., 0].copy_(x_grid)
    y_grid = torch.linspace(d, oh + d - 1.0, steps=oh, device=device, dtype=dtype).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)
    base_grid[..., 2].fill_(1)

    rescaled_theta1 = theta1.transpose(1, 2).div_(torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device))
    shape = (1, oh * ow, 3)
    output_grid1 = base_grid.view(shape).bmm(rescaled_theta1)
    output_grid2 = base_grid.view(shape).bmm(theta2.transpose(1, 2))

    output_grid = output_grid1.div_(output_grid2).sub_(1.0)
    return output_grid.view(1, oh, ow, 2)


def _perspective_coefficients(
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    coefficients: Optional[List[float]],
) -> List[float]:
    if coefficients is not None:
        if startpoints is not None and endpoints is not None:
            raise ValueError("The startpoints/endpoints and the coefficients shouldn't be defined concurrently.")
        elif len(coefficients) != 8:
            raise ValueError("Argument coefficients should have 8 float values")
        return coefficients
    elif startpoints is not None and endpoints is not None:
        return _get_perspective_coeffs(startpoints, endpoints)
    else:
        raise ValueError("Either the startpoints/endpoints or the coefficients must have non `None` values.")
@_register_kernel_internal(perspective, torch.Tensor)
@_register_kernel_internal(perspective, tv_tensors.Image)
def perspective_image(
    image: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
    interpolation = _check_interpolation(interpolation)

    _assert_grid_transform_inputs(
        image,
        matrix=None,
        interpolation=interpolation.value,
        fill=fill,
        supported_interpolation_modes=["nearest", "bilinear"],
        coeffs=perspective_coeffs,
    )

    oh, ow = image.shape[-2:]
    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
    grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device)
    return _apply_grid_transform(image, grid, interpolation.value, fill=fill)


@_register_kernel_internal(perspective, PIL.Image.Image)
def _perspective_image_pil(
    image: PIL.Image.Image,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> PIL.Image.Image:
    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
    interpolation = _check_interpolation(interpolation)
    return _FP.perspective(image, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill)
def perspective_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    if bounding_boxes.numel() == 0:
        return bounding_boxes

    perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
    original_shape = bounding_boxes.shape
    # TODO: first cast to float if bbox is int64 before convert_bounding_box_format
    bounding_boxes = (
        convert_bounding_box_format(bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY)
    ).reshape(-1, 4)

    dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32
    device = bounding_boxes.device
    # The perspective_coeffs are computed for the mapping endpoint -> startpoint, so we have to invert them for
    # bboxes. With (x, y) the end point and (x_out, y_out) the start point:
    #   x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #   y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    # we would like to get:
    #   x = (inv_coeffs[0] * x_out + inv_coeffs[1] * y_out + inv_coeffs[2])
    #       / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
    #   y = (inv_coeffs[3] * x_out + inv_coeffs[4] * y_out + inv_coeffs[5])
    #       / (inv_coeffs[6] * x_out + inv_coeffs[7] * y_out + 1)
    # and therefore compute inv_coeffs in terms of coeffs.
    denom = perspective_coeffs[0] * perspective_coeffs[4] - perspective_coeffs[1] * perspective_coeffs[3]
    if denom == 0:
        raise RuntimeError(
            f"Provided perspective_coeffs {perspective_coeffs} can not be inverted to transform bounding boxes. "
            f"Denominator is zero, denom={denom}"
        )

    inv_coeffs = [
        (perspective_coeffs[4] - perspective_coeffs[5] * perspective_coeffs[7]) / denom,
        (-perspective_coeffs[1] + perspective_coeffs[2] * perspective_coeffs[7]) / denom,
        (perspective_coeffs[1] * perspective_coeffs[5] - perspective_coeffs[2] * perspective_coeffs[4]) / denom,
        (-perspective_coeffs[3] + perspective_coeffs[5] * perspective_coeffs[6]) / denom,
        (perspective_coeffs[0] - perspective_coeffs[2] * perspective_coeffs[6]) / denom,
        (-perspective_coeffs[0] * perspective_coeffs[5] + perspective_coeffs[2] * perspective_coeffs[3]) / denom,
        (-perspective_coeffs[4] * perspective_coeffs[6] + perspective_coeffs[3] * perspective_coeffs[7]) / denom,
        (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
    ]

    theta1 = torch.tensor(
        [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
        dtype=dtype,
        device=device,
    )

    theta2 = torch.tensor(
        [[inv_coeffs[6], inv_coeffs[7], 1.0], [inv_coeffs[6], inv_coeffs[7], 1.0]], dtype=dtype, device=device
    )
    # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-right, bottom-left corners).
    # Tensor of points has shape (N * 4, 3), where N is the number of bboxes
    # The four points of a single bbox are laid out as
    # [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
    points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
    points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)

    # 2) Now let's transform the points using perspective matrices
    #   x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
    #   y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
    numer_points = torch.matmul(points, theta1.T)
    denom_points = torch.matmul(points, theta2.T)
    transformed_points = numer_points.div_(denom_points)

    # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
    # and compute bounding box from 4 transformed points:
    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
    out_bboxes = clamp_bounding_boxes(
        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=canvas_size,
    )

    # out_bboxes should be of shape [N boxes, 4]
    return convert_bounding_box_format(
        out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)
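
# Editor's note: illustrative sanity check (not part of the upstream module) for the coefficient
# inversion above, using a pure translation as the example:
#
#   coeffs = [1.0, 0.0, 2.0, 0.0, 1.0, 3.0, 0.0, 0.0]   # x_out = x + 2, y_out = y + 3
#   # With denom == 1, the formulas above give inv_coeffs == [1, 0, -2, 0, 1, -3, 0, 0],
#   # i.e. the translation by (-2, -3). Box corners are therefore shifted by (-2, -3) before
#   # being re-assembled with aminmax and clamped to the canvas.
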
@_register_kernel_internal(perspective, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _perspective_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    coefficients: Optional[List[float]] = None,
    **kwargs,
) -> tv_tensors.BoundingBoxes:
    output = perspective_bounding_boxes(
        inpt.as_subclass(torch.Tensor),
        format=inpt.format,
        canvas_size=inpt.canvas_size,
        startpoints=startpoints,
        endpoints=endpoints,
        coefficients=coefficients,
    )
    return tv_tensors.wrap(output, like=inpt)


def perspective_mask(
    mask: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = perspective_image(
        mask, startpoints, endpoints, interpolation=InterpolationMode.NEAREST, fill=fill, coefficients=coefficients
    )

    if needs_squeeze:
        output = output.squeeze(0)

    return output
@_register_kernel_internal(perspective, tv_tensors.Mask, tv_tensor_wrapper=False)
def _perspective_mask_dispatch(
    inpt: tv_tensors.Mask,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
    **kwargs,
) -> tv_tensors.Mask:
    output = perspective_mask(
        inpt.as_subclass(torch.Tensor),
        startpoints=startpoints,
        endpoints=endpoints,
        fill=fill,
        coefficients=coefficients,
    )
    return tv_tensors.wrap(output, like=inpt)


@_register_kernel_internal(perspective, tv_tensors.Video)
def perspective_video(
    video: torch.Tensor,
    startpoints: Optional[List[List[int]]],
    endpoints: Optional[List[List[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
    coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
    return perspective_image(
        video, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
    )


def elastic(
    inpt: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.ElasticTransform` for details."""
    if torch.jit.is_scripting():
        return elastic_image(inpt, displacement=displacement, interpolation=interpolation, fill=fill)

    _log_api_usage_once(elastic)

    kernel = _get_kernel(elastic, type(inpt))
    return kernel(inpt, displacement=displacement, interpolation=interpolation, fill=fill)


elastic_transform = elastic


@_register_kernel_internal(elastic, torch.Tensor)
@_register_kernel_internal(elastic, tv_tensors.Image)
def elastic_image(
    image: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    if not isinstance(displacement, torch.Tensor):
        raise TypeError("Argument displacement should be a Tensor")

    interpolation = _check_interpolation(interpolation)

    height, width = image.shape[-2:]
    device = image.device
    dtype = image.dtype if torch.is_floating_point(image) else torch.float32

    # Patch: elastic transform should support (cpu,f16) input
    is_cpu_half = device.type == "cpu" and dtype == torch.float16
    if is_cpu_half:
        image = image.to(torch.float32)
        dtype = torch.float32
    # Note: if the input image dtype is uint8 and the displacement is float64, the displacement will be cast to
    # float32 and all computations will be done in float32. We can fix this later if needed.
    expected_shape = (1, height, width, 2)
    if expected_shape != displacement.shape:
        raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")

    grid = _create_identity_grid((height, width), device=device, dtype=dtype).add_(
        displacement.to(dtype=dtype, device=device)
    )
    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

    if is_cpu_half:
        output = output.to(torch.float16)

    return output
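
# Editor's note: illustrative usage sketch (not part of the upstream module); the image size and
# displacement below are assumptions for the example. A zero displacement leaves the identity
# sampling grid untouched:
#
#   img = torch.rand(3, 32, 32)
#   displacement = torch.zeros(1, 32, 32, 2)  # must have shape (1, H, W, 2)
#   out = elastic_image(img, displacement)
#   # `out` matches `img` up to floating point error; a non-zero, smoothly varying displacement
#   # field produces the usual elastic warp.
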
@_register_kernel_internal(elastic, PIL.Image.Image)
def _elastic_image_pil(
    image: PIL.Image.Image,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> PIL.Image.Image:
    t_img = pil_to_tensor(image)
    output = elastic_image(t_img, displacement, interpolation=interpolation, fill=fill)
    return to_pil_image(output, mode=image.mode)


def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    sy, sx = size
    base_grid = torch.empty(1, sy, sx, 2, device=device, dtype=dtype)
    x_grid = torch.linspace((-sx + 1) / sx, (sx - 1) / sx, sx, device=device, dtype=dtype)
    base_grid[..., 0].copy_(x_grid)
    y_grid = torch.linspace((-sy + 1) / sy, (sy - 1) / sy, sy, device=device, dtype=dtype).unsqueeze_(-1)
    base_grid[..., 1].copy_(y_grid)
    return base_grid
def elastic_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    displacement: torch.Tensor,
) -> torch.Tensor:
    expected_shape = (1, canvas_size[0], canvas_size[1], 2)
    if not isinstance(displacement, torch.Tensor):
        raise TypeError("Argument displacement should be a Tensor")
    elif displacement.shape != expected_shape:
        raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")

    if bounding_boxes.numel() == 0:
        return bounding_boxes

    # TODO: add in docstring about approximation we are doing for grid inversion
    device = bounding_boxes.device
    dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32

    if displacement.dtype != dtype or displacement.device != device:
        displacement = displacement.to(dtype=dtype, device=device)

    original_shape = bounding_boxes.shape
    # TODO: first cast to float if bbox is int64 before convert_bounding_box_format
    bounding_boxes = (
        convert_bounding_box_format(bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY)
    ).reshape(-1, 4)

    id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype)
    # We construct an approximation of the inverse grid as inv_grid = id_grid - displacement.
    # This is not an exact inverse of the grid.
    inv_grid = id_grid.sub_(displacement)

    # Get points from bboxes
    points = bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
    if points.is_floating_point():
        points = points.ceil_()
    index_xy = points.to(dtype=torch.long)
    index_x, index_y = index_xy[:, 0], index_xy[:, 1]

    # Transform points:
    t_size = torch.tensor(canvas_size[::-1], device=displacement.device, dtype=displacement.dtype)
    transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)

    transformed_points = transformed_points.reshape(-1, 4, 2)
    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
    out_bboxes = clamp_bounding_boxes(
        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=canvas_size,
    )

    return convert_bounding_box_format(
        out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
    ).reshape(original_shape)
@_register_kernel_internal(elastic, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _elastic_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, displacement: torch.Tensor, **kwargs
) -> tv_tensors.BoundingBoxes:
    output = elastic_bounding_boxes(
        inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, displacement=displacement
    )
    return tv_tensors.wrap(output, like=inpt)


def elastic_mask(
    mask: torch.Tensor,
    displacement: torch.Tensor,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = elastic_image(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST, fill=fill)

    if needs_squeeze:
        output = output.squeeze(0)

    return output


@_register_kernel_internal(elastic, tv_tensors.Mask, tv_tensor_wrapper=False)
def _elastic_mask_dispatch(
    inpt: tv_tensors.Mask, displacement: torch.Tensor, fill: _FillTypeJIT = None, **kwargs
) -> tv_tensors.Mask:
    output = elastic_mask(inpt.as_subclass(torch.Tensor), displacement=displacement, fill=fill)
    return tv_tensors.wrap(output, like=inpt)


@_register_kernel_internal(elastic, tv_tensors.Video)
def elastic_video(
    video: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
) -> torch.Tensor:
    return elastic_image(video, displacement, interpolation=interpolation, fill=fill)


def center_crop(inpt: torch.Tensor, output_size: List[int]) -> torch.Tensor:
  1530. """See :class:`~torchvision.transforms.v2.RandomCrop` for details."""
    if torch.jit.is_scripting():
        return center_crop_image(inpt, output_size=output_size)

    _log_api_usage_once(center_crop)

    kernel = _get_kernel(center_crop, type(inpt))
    return kernel(inpt, output_size=output_size)


def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
    if isinstance(output_size, numbers.Number):
        s = int(output_size)
        return [s, s]
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        return [output_size[0], output_size[0]]
    else:
        return list(output_size)


def _center_crop_compute_padding(crop_height: int, crop_width: int, image_height: int, image_width: int) -> List[int]:
    return [
        (crop_width - image_width) // 2 if crop_width > image_width else 0,
        (crop_height - image_height) // 2 if crop_height > image_height else 0,
        (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
        (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
    ]


def _center_crop_compute_crop_anchor(
    crop_height: int, crop_width: int, image_height: int, image_width: int
) -> Tuple[int, int]:
    crop_top = int(round((image_height - crop_height) / 2.0))
    crop_left = int(round((image_width - crop_width) / 2.0))
    return crop_top, crop_left
@_register_kernel_internal(center_crop, torch.Tensor)
@_register_kernel_internal(center_crop, tv_tensors.Image)
def center_crop_image(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    shape = image.shape
    if image.numel() == 0:
        return image.reshape(shape[:-2] + (crop_height, crop_width))
    image_height, image_width = shape[-2:]

    if crop_height > image_height or crop_width > image_width:
        padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
        image = torch_pad(image, _parse_pad_padding(padding_ltrb), value=0.0)

        image_height, image_width = image.shape[-2:]
        if crop_width == image_width and crop_height == image_height:
            return image

    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
    return image[..., crop_top : (crop_top + crop_height), crop_left : (crop_left + crop_width)]
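
# Editor's note: illustrative sketch (not part of the upstream module) of the two branches above;
# the 4x4 input is an assumption for the example:
#
#   img = torch.arange(16.0).reshape(1, 4, 4)
#   out = center_crop_image(img, output_size=[2, 2])
#   # crop_top == crop_left == round((4 - 2) / 2) == 1, so out == img[..., 1:3, 1:3].
#   big = center_crop_image(img, output_size=[6, 6])
#   # The requested size exceeds the input, so the image is first zero-padded to 6x6 and, since it
#   # then matches the crop size exactly, returned as-is.
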
@_register_kernel_internal(center_crop, PIL.Image.Image)
def _center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    image_height, image_width = _get_size_image_pil(image)

    if crop_height > image_height or crop_width > image_width:
        padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
        image = _pad_image_pil(image, padding_ltrb, fill=0)

        image_height, image_width = _get_size_image_pil(image)
        if crop_width == image_width and crop_height == image_height:
            return image

    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
    return _crop_image_pil(image, crop_top, crop_left, crop_height, crop_width)


def center_crop_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    canvas_size: Tuple[int, int],
    output_size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *canvas_size)
    return crop_bounding_boxes(
        bounding_boxes, format, top=crop_top, left=crop_left, height=crop_height, width=crop_width
    )


@_register_kernel_internal(center_crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _center_crop_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, output_size: List[int]
) -> tv_tensors.BoundingBoxes:
    output, canvas_size = center_crop_bounding_boxes(
        inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, output_size=output_size
    )
    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)


@_register_kernel_internal(center_crop, tv_tensors.Mask)
def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
        needs_squeeze = True
    else:
        needs_squeeze = False

    output = center_crop_image(image=mask, output_size=output_size)

    if needs_squeeze:
        output = output.squeeze(0)

    return output


@_register_kernel_internal(center_crop, tv_tensors.Video)
def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    return center_crop_image(video, output_size)
def resized_crop(
    inpt: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[bool] = True,
) -> torch.Tensor:
    """See :class:`~torchvision.transforms.v2.RandomResizedCrop` for details."""
    if torch.jit.is_scripting():
        return resized_crop_image(
            inpt,
            top=top,
            left=left,
            height=height,
            width=width,
            size=size,
            interpolation=interpolation,
            antialias=antialias,
        )

    _log_api_usage_once(resized_crop)

    kernel = _get_kernel(resized_crop, type(inpt))
    return kernel(
        inpt,
        top=top,
        left=left,
        height=height,
        width=width,
        size=size,
        interpolation=interpolation,
        antialias=antialias,
    )


@_register_kernel_internal(resized_crop, torch.Tensor)
@_register_kernel_internal(resized_crop, tv_tensors.Image)
def resized_crop_image(
    image: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[bool] = True,
) -> torch.Tensor:
    image = crop_image(image, top, left, height, width)
    return resize_image(image, size, interpolation=interpolation, antialias=antialias)
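
# Editor's note: illustrative sketch (not part of the upstream module); the sizes below are
# assumptions. `resized_crop_image` is simply crop followed by resize:
#
#   img = torch.rand(3, 64, 64)
#   out = resized_crop_image(img, top=8, left=8, height=48, width=48, size=[32, 32])
#   # Equivalent to resize_image(crop_image(img, 8, 8, 48, 48), [32, 32]); out.shape == (3, 32, 32).
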
def _resized_crop_image_pil(
    image: PIL.Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
) -> PIL.Image.Image:
    image = _crop_image_pil(image, top, left, height, width)
    return _resize_image_pil(image, size, interpolation=interpolation)


@_register_kernel_internal(resized_crop, PIL.Image.Image)
def _resized_crop_image_pil_dispatch(
    image: PIL.Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[bool] = True,
) -> PIL.Image.Image:
    if antialias is False:
        warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
    return _resized_crop_image_pil(
        image,
        top=top,
        left=left,
        height=height,
        width=width,
        size=size,
        interpolation=interpolation,
    )


def resized_crop_bounding_boxes(
    bounding_boxes: torch.Tensor,
    format: tv_tensors.BoundingBoxFormat,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    bounding_boxes, canvas_size = crop_bounding_boxes(bounding_boxes, format, top, left, height, width)
    return resize_bounding_boxes(bounding_boxes, canvas_size=canvas_size, size=size)


@_register_kernel_internal(resized_crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
def _resized_crop_bounding_boxes_dispatch(
    inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int, size: List[int], **kwargs
) -> tv_tensors.BoundingBoxes:
    output, canvas_size = resized_crop_bounding_boxes(
        inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width, size=size
    )
    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)


def resized_crop_mask(
    mask: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
) -> torch.Tensor:
    mask = crop_mask(mask, top, left, height, width)
    return resize_mask(mask, size)


@_register_kernel_internal(resized_crop, tv_tensors.Mask, tv_tensor_wrapper=False)
def _resized_crop_mask_dispatch(
    inpt: tv_tensors.Mask, top: int, left: int, height: int, width: int, size: List[int], **kwargs
) -> tv_tensors.Mask:
    output = resized_crop_mask(
        inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, size=size
    )
    return tv_tensors.wrap(output, like=inpt)


@_register_kernel_internal(resized_crop, tv_tensors.Video)
def resized_crop_video(
    video: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[bool] = True,
) -> torch.Tensor:
    return resized_crop_image(
        video, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
    )
def five_crop(
    inpt: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """See :class:`~torchvision.transforms.v2.FiveCrop` for details."""
    if torch.jit.is_scripting():
        return five_crop_image(inpt, size=size)

    _log_api_usage_once(five_crop)

    kernel = _get_kernel(five_crop, type(inpt))
    return kernel(inpt, size=size)


def _parse_five_crop_size(size: List[int]) -> List[int]:
    if isinstance(size, numbers.Number):
        s = int(size)
        size = [s, s]
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        s = size[0]
        size = [s, s]

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    return size


@_register_five_ten_crop_kernel_internal(five_crop, torch.Tensor)
@_register_five_ten_crop_kernel_internal(five_crop, tv_tensors.Image)
def five_crop_image(
    image: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    crop_height, crop_width = _parse_five_crop_size(size)
    image_height, image_width = image.shape[-2:]

    if crop_width > image_width or crop_height > image_height:
        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")

    tl = crop_image(image, 0, 0, crop_height, crop_width)
    tr = crop_image(image, 0, image_width - crop_width, crop_height, crop_width)
    bl = crop_image(image, image_height - crop_height, 0, crop_height, crop_width)
    br = crop_image(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
    center = center_crop_image(image, [crop_height, crop_width])

    return tl, tr, bl, br, center
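
# Editor's note: illustrative usage sketch (not part of the upstream module); the input size is an
# assumption for the example:
#
#   img = torch.rand(3, 32, 32)
#   tl, tr, bl, br, center = five_crop_image(img, size=[16, 16])
#   # Each of the five crops has shape (3, 16, 16); `center` equals center_crop_image(img, [16, 16]).
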
@_register_five_ten_crop_kernel_internal(five_crop, PIL.Image.Image)
def _five_crop_image_pil(
    image: PIL.Image.Image, size: List[int]
) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
    crop_height, crop_width = _parse_five_crop_size(size)
    image_height, image_width = _get_size_image_pil(image)

    if crop_width > image_width or crop_height > image_height:
        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")

    tl = _crop_image_pil(image, 0, 0, crop_height, crop_width)
    tr = _crop_image_pil(image, 0, image_width - crop_width, crop_height, crop_width)
    bl = _crop_image_pil(image, image_height - crop_height, 0, crop_height, crop_width)
    br = _crop_image_pil(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
    center = _center_crop_image_pil(image, [crop_height, crop_width])

    return tl, tr, bl, br, center


@_register_five_ten_crop_kernel_internal(five_crop, tv_tensors.Video)
def five_crop_video(
    video: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    return five_crop_image(video, size)


def ten_crop(
    inpt: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """See :class:`~torchvision.transforms.v2.TenCrop` for details."""
    if torch.jit.is_scripting():
        return ten_crop_image(inpt, size=size, vertical_flip=vertical_flip)

    _log_api_usage_once(ten_crop)

    kernel = _get_kernel(ten_crop, type(inpt))
    return kernel(inpt, size=size, vertical_flip=vertical_flip)


@_register_five_ten_crop_kernel_internal(ten_crop, torch.Tensor)
@_register_five_ten_crop_kernel_internal(ten_crop, tv_tensors.Image)
def ten_crop_image(
    image: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    non_flipped = five_crop_image(image, size)

    if vertical_flip:
        image = vertical_flip_image(image)
    else:
        image = horizontal_flip_image(image)

    flipped = five_crop_image(image, size)
    return non_flipped + flipped
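
# Editor's note: illustrative usage sketch (not part of the upstream module); the input size is an
# assumption for the example:
#
#   crops = ten_crop_image(torch.rand(3, 32, 32), size=[16, 16], vertical_flip=False)
#   # len(crops) == 10: the five crops of the original image followed by the five crops of its
#   # horizontally flipped copy (vertically flipped when vertical_flip=True).
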
@_register_five_ten_crop_kernel_internal(ten_crop, PIL.Image.Image)
def _ten_crop_image_pil(
    image: PIL.Image.Image, size: List[int], vertical_flip: bool = False
) -> Tuple[
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
    PIL.Image.Image,
]:
    non_flipped = _five_crop_image_pil(image, size)

    if vertical_flip:
        image = _vertical_flip_image_pil(image)
    else:
        image = _horizontal_flip_image_pil(image)

    flipped = _five_crop_image_pil(image, size)
    return non_flipped + flipped


@_register_five_ten_crop_kernel_internal(ten_crop, tv_tensors.Video)
def ten_crop_video(
    video: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    return ten_crop_image(video, size, vertical_flip=vertical_flip)