# _container.py
  1. from typing import Any, Callable, Dict, List, Optional, Sequence, Union
  2. import torch
  3. from torch import nn
  4. from torchvision import transforms as _transforms
  5. from torchvision.transforms.v2 import Transform
  6. class Compose(Transform):
  7. """Composes several transforms together.
  8. This transform does not support torchscript.
  9. Please, see the note below.
  10. Args:
  11. transforms (list of ``Transform`` objects): list of transforms to compose.
  12. Example:
  13. >>> transforms.Compose([
  14. >>> transforms.CenterCrop(10),
  15. >>> transforms.PILToTensor(),
  16. >>> transforms.ConvertImageDtype(torch.float),
  17. >>> ])
  18. .. note::
  19. In order to script the transformations, please use ``torch.nn.Sequential`` as below.
  20. >>> transforms = torch.nn.Sequential(
  21. >>> transforms.CenterCrop(10),
  22. >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
  23. >>> )
  24. >>> scripted_transforms = torch.jit.script(transforms)
  25. Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
  26. `lambda` functions or ``PIL.Image``.
  27. """
  28. def __init__(self, transforms: Sequence[Callable]) -> None:
  29. super().__init__()
  30. if not isinstance(transforms, Sequence):
  31. raise TypeError("Argument transforms should be a sequence of callables")
  32. elif not transforms:
  33. raise ValueError("Pass at least one transform")
  34. self.transforms = transforms
  35. def forward(self, *inputs: Any) -> Any:
  36. needs_unpacking = len(inputs) > 1
  37. for transform in self.transforms:
  38. outputs = transform(*inputs)
  39. inputs = outputs if needs_unpacking else (outputs,)
  40. return outputs
  41. def extra_repr(self) -> str:
  42. format_string = []
  43. for t in self.transforms:
  44. format_string.append(f" {t}")
  45. return "\n".join(format_string)
  46. class RandomApply(Transform):
  47. """Apply randomly a list of transformations with a given probability.
  48. .. note::
  49. In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
  50. transforms as shown below:
  51. >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
  52. >>> transforms.ColorJitter(),
  53. >>> ]), p=0.3)
  54. >>> scripted_transforms = torch.jit.script(transforms)
  55. Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
  56. `lambda` functions or ``PIL.Image``.
  57. Args:
  58. transforms (sequence or torch.nn.Module): list of transformations
  59. p (float): probability of applying the list of transforms
  60. """
  61. _v1_transform_cls = _transforms.RandomApply
  62. def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None:
  63. super().__init__()
  64. if not isinstance(transforms, (Sequence, nn.ModuleList)):
  65. raise TypeError("Argument transforms should be a sequence of callables or a `nn.ModuleList`")
  66. self.transforms = transforms
  67. if not (0.0 <= p <= 1.0):
  68. raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
  69. self.p = p
  70. def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
  71. return {"transforms": self.transforms, "p": self.p}
  72. def forward(self, *inputs: Any) -> Any:
  73. needs_unpacking = len(inputs) > 1
  74. if torch.rand(1) >= self.p:
  75. return inputs if needs_unpacking else inputs[0]
  76. for transform in self.transforms:
  77. outputs = transform(*inputs)
  78. inputs = outputs if needs_unpacking else (outputs,)
  79. return outputs
  80. def extra_repr(self) -> str:
  81. format_string = []
  82. for t in self.transforms:
  83. format_string.append(f" {t}")
  84. return "\n".join(format_string)
  85. class RandomChoice(Transform):
  86. """Apply single transformation randomly picked from a list.
  87. This transform does not support torchscript.
  88. Args:
  89. transforms (sequence or torch.nn.Module): list of transformations
  90. p (list of floats or None, optional): probability of each transform being picked.
  91. If ``p`` doesn't sum to 1, it is automatically normalized. If ``None``
  92. (default), all transforms have the same probability.
  93. """
  94. def __init__(
  95. self,
  96. transforms: Sequence[Callable],
  97. p: Optional[List[float]] = None,
  98. ) -> None:
  99. if not isinstance(transforms, Sequence):
  100. raise TypeError("Argument transforms should be a sequence of callables")
  101. if p is None:
  102. p = [1] * len(transforms)
  103. elif len(p) != len(transforms):
  104. raise ValueError(f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}")
  105. super().__init__()
  106. self.transforms = transforms
  107. total = sum(p)
  108. self.p = [prob / total for prob in p]
  109. def forward(self, *inputs: Any) -> Any:
  110. idx = int(torch.multinomial(torch.tensor(self.p), 1))
  111. transform = self.transforms[idx]
  112. return transform(*inputs)
  113. class RandomOrder(Transform):
  114. """Apply a list of transformations in a random order.
  115. This transform does not support torchscript.
  116. Args:
  117. transforms (sequence or torch.nn.Module): list of transformations
  118. """
  119. def __init__(self, transforms: Sequence[Callable]) -> None:
  120. if not isinstance(transforms, Sequence):
  121. raise TypeError("Argument transforms should be a sequence of callables")
  122. super().__init__()
  123. self.transforms = transforms
  124. def forward(self, *inputs: Any) -> Any:
  125. needs_unpacking = len(inputs) > 1
  126. for idx in torch.randperm(len(self.transforms)):
  127. transform = self.transforms[idx]
  128. outputs = transform(*inputs)
  129. inputs = outputs if needs_unpacking else (outputs,)
  130. return outputs