_functional_pil.py

import numbers
from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from PIL import Image, ImageEnhance, ImageOps

try:
    import accimage
except ImportError:
    accimage = None


@torch.jit.unused
def _is_pil_image(img: Any) -> bool:
    if accimage is not None:
        return isinstance(img, (Image.Image, accimage.Image))
    else:
        return isinstance(img, Image.Image)


@torch.jit.unused
def get_dimensions(img: Any) -> List[int]:
    if _is_pil_image(img):
        if hasattr(img, "getbands"):
            channels = len(img.getbands())
        else:
            channels = img.channels
        width, height = img.size
        return [channels, height, width]
    raise TypeError(f"Unexpected type {type(img)}")


@torch.jit.unused
def get_image_size(img: Any) -> List[int]:
    if _is_pil_image(img):
        return list(img.size)
    raise TypeError(f"Unexpected type {type(img)}")


@torch.jit.unused
def get_image_num_channels(img: Any) -> int:
    if _is_pil_image(img):
        if hasattr(img, "getbands"):
            return len(img.getbands())
        else:
            return img.channels
    raise TypeError(f"Unexpected type {type(img)}")


@torch.jit.unused
def hflip(img: Image.Image) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    return img.transpose(Image.FLIP_LEFT_RIGHT)


@torch.jit.unused
def vflip(img: Image.Image) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    return img.transpose(Image.FLIP_TOP_BOTTOM)


@torch.jit.unused
def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    enhancer = ImageEnhance.Brightness(img)
    img = enhancer.enhance(brightness_factor)
    return img
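

# Note on PIL's enhancer semantics, which adjust_brightness, adjust_contrast,
# adjust_saturation, and adjust_sharpness all rely on (per Pillow's
# ImageEnhance docs): a factor of 1.0 returns the original image;
# Brightness(0.0) yields a black image, Contrast(0.0) a solid gray image,
# Color(0.0) a grayscale image, and Sharpness(0.0) a blurred one. Factors
# above 1.0 amplify the effect. Illustrative sketch only; "example.jpg" is a
# placeholder path:
#
#     img = Image.open("example.jpg")
#     brighter = adjust_brightness(img, 1.5)  # 1.5x brighter
#     darker = adjust_brightness(img, 0.5)    # half brightness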


@torch.jit.unused
def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    enhancer = ImageEnhance.Contrast(img)
    img = enhancer.enhance(contrast_factor)
    return img


@torch.jit.unused
def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    enhancer = ImageEnhance.Color(img)
    img = enhancer.enhance(saturation_factor)
    return img


@torch.jit.unused
def adjust_hue(img: Image.Image, hue_factor: float) -> Image.Image:
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")

    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    input_mode = img.mode
    if input_mode in {"L", "1", "I", "F"}:
        return img

    h, s, v = img.convert("HSV").split()

    np_h = np.array(h, dtype=np.uint8)
    # Shift the hue channel in int16 to avoid overflow, wrap with mod 256
    # (hue is cyclic), then convert back to uint8.
    np_h = (np_h.astype(np.int16) + int(hue_factor * 255)) % 256
    h = Image.fromarray(np_h.astype(np.uint8), "L")

    img = Image.merge("HSV", (h, s, v)).convert(input_mode)
    return img
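

# Hue is cyclic, so the shift above must wrap around rather than clip: a hue
# of 250 shifted by +20 should land on 14 (mod 256), not saturate at 255.
# Minimal standalone sketch of the wrap-around arithmetic (values assumed):
#
#     np_h = np.array([0, 128, 250], dtype=np.uint8)
#     shifted = (np_h.astype(np.int16) + 20) % 256  # -> [20, 148, 14]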


@torch.jit.unused
def adjust_gamma(
    img: Image.Image,
    gamma: float,
    gain: float = 1.0,
) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    if gamma < 0:
        raise ValueError("Gamma should be a non-negative real number")

    input_mode = img.mode
    img = img.convert("RGB")
    gamma_map = [int((255 + 1 - 1e-3) * gain * pow(ele / 255.0, gamma)) for ele in range(256)] * 3
    img = img.point(gamma_map)  # use PIL's point-function to accelerate this part

    img = img.convert(input_mode)
    return img
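

# adjust_gamma implements O = 255 * gain * (I / 255) ** gamma as a 256-entry
# lookup table applied with Image.point, avoiding per-pixel Python math. The
# (255 + 1 - 1e-3) scale maps an input of 255 to exactly 255 after int()
# truncation. Worked entry for gamma=2.2, gain=1.0 (value approximate):
#
#     int((255 + 1 - 1e-3) * pow(128 / 255.0, 2.2))  # mid-gray 128 -> ~56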


@torch.jit.unused
def pad(
    img: Image.Image,
    padding: Union[int, List[int], Tuple[int, ...]],
    fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
    padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    if not isinstance(padding, (numbers.Number, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, list):
        padding = tuple(padding)

    if isinstance(padding, tuple) and len(padding) not in [1, 2, 4]:
        raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")

    if isinstance(padding, tuple) and len(padding) == 1:
        # Compatibility with `functional_tensor.pad`
        padding = padding[0]

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

    if padding_mode == "constant":
        opts = _parse_fill(fill, img, name="fill")
        if img.mode == "P":
            palette = img.getpalette()
            image = ImageOps.expand(img, border=padding, **opts)
            image.putpalette(palette)
            return image
        return ImageOps.expand(img, border=padding, **opts)
    else:
        if isinstance(padding, int):
            pad_left = pad_right = pad_top = pad_bottom = padding
        if isinstance(padding, tuple) and len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        if isinstance(padding, tuple) and len(padding) == 4:
            pad_left = padding[0]
            pad_top = padding[1]
            pad_right = padding[2]
            pad_bottom = padding[3]

        p = [pad_left, pad_top, pad_right, pad_bottom]
        cropping = -np.minimum(p, 0)

        if cropping.any():
            crop_left, crop_top, crop_right, crop_bottom = cropping
            img = img.crop((crop_left, crop_top, img.width - crop_right, img.height - crop_bottom))

        pad_left, pad_top, pad_right, pad_bottom = np.maximum(p, 0)

        if img.mode == "P":
            palette = img.getpalette()
            img = np.asarray(img)
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), mode=padding_mode)
            img = Image.fromarray(img)
            img.putpalette(palette)
            return img

        img = np.asarray(img)
        # RGB image
        if len(img.shape) == 3:
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode)
        # Grayscale image
        if len(img.shape) == 2:
            img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)

        return Image.fromarray(img)
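

# `padding` may be an int (all sides), a 2-tuple (left/right, top/bottom), or
# a 4-tuple (left, top, right, bottom). In the non-constant branch, negative
# entries crop instead of pad: np.minimum/np.maximum split the request into a
# crop box plus the remaining non-negative padding. Illustrative call:
#
#     out = pad(img, (10, -5), padding_mode="edge")
#     # pads 10 px on the left/right, crops 5 px off the top and bottom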


@torch.jit.unused
def crop(
    img: Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    return img.crop((left, top, left + width, top + height))
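

# PIL's Image.crop takes a (left, upper, right, lower) box, so the
# (top, left, height, width) arguments are converted above; for example,
# crop(img, top=10, left=20, height=100, width=50) uses the box (20, 10, 70, 110).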


@torch.jit.unused
def resize(
    img: Image.Image,
    size: Union[List[int], int],
    interpolation: int = Image.BILINEAR,
) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    if not (isinstance(size, list) and len(size) == 2):
        raise TypeError(f"Got inappropriate size arg: {size}")

    return img.resize(tuple(size[::-1]), interpolation)
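

# `size` follows the torchvision (height, width) convention, while PIL's
# Image.resize expects (width, height); hence the size[::-1] reversal:
#
#     resized = resize(img, [240, 320])  # 240 px tall, 320 px wide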


@torch.jit.unused
def _parse_fill(
    fill: Optional[Union[float, List[float], Tuple[float, ...]]],
    img: Image.Image,
    name: str = "fillcolor",
) -> Dict[str, Optional[Union[float, List[float], Tuple[float, ...]]]]:
    # Process fill color for affine transforms
    num_channels = get_image_num_channels(img)
    if fill is None:
        fill = 0
    if isinstance(fill, (int, float)) and num_channels > 1:
        fill = tuple([fill] * num_channels)
    if isinstance(fill, (list, tuple)):
        if len(fill) == 1:
            fill = fill * num_channels
        elif len(fill) != num_channels:
            msg = "The number of elements in 'fill' does not match the number of channels of the image ({} != {})"
            raise ValueError(msg.format(len(fill), num_channels))

        fill = tuple(fill)  # type: ignore[arg-type]

    if img.mode != "F":
        if isinstance(fill, (list, tuple)):
            fill = tuple(int(x) for x in fill)
        else:
            fill = int(fill)

    return {name: fill}
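

# _parse_fill broadcasts a scalar fill to one value per channel, validates
# sequence lengths, and returns a kwargs dict keyed by `name` so callers can
# splat it into PIL calls. Sketch of the broadcasting, assuming `rgb_img` is
# a 3-channel image:
#
#     _parse_fill(0, rgb_img)      # -> {"fillcolor": (0, 0, 0)}
#     _parse_fill([255], rgb_img)  # -> {"fillcolor": (255, 255, 255)}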


@torch.jit.unused
def affine(
    img: Image.Image,
    matrix: List[float],
    interpolation: int = Image.NEAREST,
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    output_size = img.size
    opts = _parse_fill(fill, img)
    return img.transform(output_size, Image.AFFINE, matrix, interpolation, **opts)
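

# `matrix` holds the six coefficients (a, b, c, d, e, f) of PIL's inverse
# affine map: each output pixel (x, y) is sampled from the input position
# (a*x + b*y + c, d*x + e*y + f). The identity transform is therefore:
#
#     same = affine(img, [1.0, 0.0, 0.0, 0.0, 1.0, 0.0])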


@torch.jit.unused
def rotate(
    img: Image.Image,
    angle: float,
    interpolation: int = Image.NEAREST,
    expand: bool = False,
    center: Optional[Tuple[int, int]] = None,
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    opts = _parse_fill(fill, img)
    return img.rotate(angle, interpolation, expand, center, **opts)


@torch.jit.unused
def perspective(
    img: Image.Image,
    perspective_coeffs: List[float],
    interpolation: int = Image.BICUBIC,
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    opts = _parse_fill(fill, img)
    return img.transform(img.size, Image.PERSPECTIVE, perspective_coeffs, interpolation, **opts)
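

# `perspective_coeffs` holds the eight coefficients (a, b, c, d, e, f, g, h)
# of PIL's Image.PERSPECTIVE transform: output pixel (x, y) is sampled from
# ((a*x + b*y + c) / (g*x + h*y + 1), (d*x + e*y + f) / (g*x + h*y + 1)).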


@torch.jit.unused
def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    if num_output_channels == 1:
        img = img.convert("L")
    elif num_output_channels == 3:
        img = img.convert("L")
        np_img = np.array(img, dtype=np.uint8)
        np_img = np.dstack([np_img, np_img, np_img])
        img = Image.fromarray(np_img, "RGB")
    else:
        raise ValueError("num_output_channels should be either 1 or 3")

    return img
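

# With num_output_channels == 3, the luminance channel is replicated into R,
# G, and B via np.dstack: the image still looks gray but keeps an RGB layout,
# which suits models that expect 3-channel input.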


@torch.jit.unused
def invert(img: Image.Image) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return ImageOps.invert(img)


@torch.jit.unused
def posterize(img: Image.Image, bits: int) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return ImageOps.posterize(img, bits)


@torch.jit.unused
def solarize(img: Image.Image, threshold: int) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return ImageOps.solarize(img, threshold)


@torch.jit.unused
def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")

    enhancer = ImageEnhance.Sharpness(img)
    img = enhancer.enhance(sharpness_factor)
    return img


@torch.jit.unused
def autocontrast(img: Image.Image) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return ImageOps.autocontrast(img)


@torch.jit.unused
def equalize(img: Image.Image) -> Image.Image:
    if not _is_pil_image(img):
        raise TypeError(f"img should be PIL Image. Got {type(img)}")
    return ImageOps.equalize(img)