123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176 |
- from __future__ import annotations
- import enum
- from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
- import PIL.Image
- import torch
- from torch import nn
- from torch.utils._pytree import tree_flatten, tree_unflatten
- from torchvision import tv_tensors
- from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
- from torchvision.utils import _log_api_usage_once
- from .functional._utils import _get_kernel
class Transform(nn.Module):
    """Base class for all v2 transforms.

    Subclasses implement ``_transform`` (mandatory) and may override
    ``_check_inputs`` and ``_get_params``. ``forward`` flattens arbitrarily
    nested inputs with ``tree_flatten``, decides per leaf whether it should be
    transformed, samples the parameters once per call, applies ``_transform``
    to each selected leaf, and restores the original nesting.
    """

    # Leaf types (or predicate callables, as accepted by ``check_type``) that
    # ``forward`` considers for transformation; everything else passes through.
    _transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (torch.Tensor, PIL.Image.Image)

    def __init__(self) -> None:
        super().__init__()
        _log_api_usage_once(self)

    def _check_inputs(self, flat_inputs: List[Any]) -> None:
        # Hook for subclasses to validate the flattened inputs before any
        # transformation happens. The base implementation accepts everything.
        pass

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # Hook for subclasses to sample parameters shared by all leaves of one
        # forward call (e.g. a random crop location). Called once per call with
        # only the leaves that will actually be transformed.
        return dict()

    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
        # Resolve and invoke the kernel registered for ``type(inpt)``.
        # ``allow_passthrough=True`` presumably makes unregistered types a
        # no-op instead of an error — see ``_get_kernel`` for the exact contract.
        kernel = _get_kernel(functional, type(inpt), allow_passthrough=True)
        return kernel(inpt, *args, **kwargs)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Subclasses must implement the per-leaf transformation here.
        raise NotImplementedError

    def forward(self, *inputs: Any) -> Any:
        # A single positional argument is unwrapped so both calling styles
        # ``t(sample)`` and ``t(a, b, ...)`` work. NOTE: calling with zero
        # arguments raises IndexError here.
        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])

        self._check_inputs(flat_inputs)

        needs_transform_list = self._needs_transform_list(flat_inputs)
        # Parameters are sampled once, from the to-be-transformed leaves only,
        # and then shared by every ``_transform`` call below.
        params = self._get_params(
            [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
        )

        flat_outputs = [
            self._transform(inpt, params) if needs_transform else inpt
            for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
        ]

        return tree_unflatten(flat_outputs, spec)

    def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]:
        # Compute, for each flattened leaf, whether ``_transform`` applies.
        # Heuristic for pure tensors (plain tensors that are not tv_tensors):
        # if the sample already contains a tv_tensors.Image, tv_tensors.Video,
        # or PIL image, pure tensors are treated as auxiliary data and passed
        # through; otherwise only the *first* pure tensor encountered is
        # transformed and the remaining ones are passed through.
        needs_transform_list = []
        transform_pure_tensor = not has_any(flat_inputs, tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)
        for inpt in flat_inputs:
            needs_transform = True

            if not check_type(inpt, self._transformed_types):
                # Not a type this transform handles at all.
                needs_transform = False
            elif is_pure_tensor(inpt):
                if transform_pure_tensor:
                    # First pure tensor: transform it, and ensure no later
                    # pure tensor in this sample is transformed.
                    transform_pure_tensor = False
                else:
                    needs_transform = False
            needs_transform_list.append(needs_transform)
        return needs_transform_list

    def extra_repr(self) -> str:
        # Render public, simple-valued attributes (bool/int/float/str/tuple/
        # list/enum) for ``repr``; private attributes and the ``nn.Module``
        # ``training`` flag are skipped.
        extra = []
        for name, value in self.__dict__.items():
            if name.startswith("_") or name == "training":
                continue

            if not isinstance(value, (bool, int, float, str, tuple, list, enum.Enum)):
                continue

            extra.append(f"{name}={value}")

        return ", ".join(extra)

    # Everything below exists for backward compatibility with the v1
    # (``torchvision.transforms``) transforms, in particular TorchScript support.

    # The v1 transform class this v2 transform corresponds to, if any.
    _v1_transform_cls: Optional[Type[nn.Module]] = None

    def __init_subclass__(cls) -> None:
        # When the corresponding v1 transform exposes ``get_params``, mirror it
        # on the v2 subclass as a staticmethod so the v1-style API stays available.
        if cls._v1_transform_cls is not None and hasattr(cls._v1_transform_cls, "get_params"):
            cls.get_params = staticmethod(cls._v1_transform_cls.get_params)

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        # Collect the public instance attributes that a plain ``nn.Module``
        # does not already have; these are used as constructor kwargs for the
        # v1 transform in ``__prepare_scriptable__``. Subclasses whose v1
        # constructor signature differs are expected to override this.
        common_attrs = nn.Module().__dict__.keys()
        return {
            attr: value
            for attr, value in self.__dict__.items()
            if not attr.startswith("_") and attr not in common_attrs
        }

    def __prepare_scriptable__(self) -> nn.Module:
        # Called by ``torch.jit.script``: v2 transforms are not scriptable
        # themselves, so hand over an equivalent v1 instance when one exists.
        #
        # Raises:
        #     RuntimeError: if this transform has no v1 counterpart.
        if self._v1_transform_cls is None:
            raise RuntimeError(
                f"Transform {type(self).__name__} cannot be JIT scripted. "
                "torchscript is only supported for backward compatibility with transforms "
                "which are already in torchvision.transforms. "
                "For torchscript support (on tensors only), you can use the functional API instead."
            )

        return self._v1_transform_cls(**self._extract_params_for_v1_transform())
class _RandomApplyTransform(Transform):
    """Base class for transforms that fire with probability ``p``.

    Behaves exactly like :class:`Transform`, except that the whole
    transformation is skipped (inputs returned untouched) with probability
    ``1 - p``.
    """

    def __init__(self, p: float = 0.5) -> None:
        # Validate the probability before initializing the module.
        if not 0.0 <= p <= 1.0:
            raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")

        super().__init__()
        self.p = p

    def forward(self, *inputs: Any) -> Any:
        # Unwrap a single positional argument so that both ``t(sample)`` and
        # ``t(a, b, ...)`` are supported, mirroring ``Transform.forward``.
        sample = inputs if len(inputs) > 1 else inputs[0]

        # NOTE: unlike ``Transform.forward``, flattening and input validation
        # happen *before* the coin flip, so invalid inputs are rejected even
        # on calls where the transform ends up being skipped.
        flat_sample, spec = tree_flatten(sample)
        self._check_inputs(flat_sample)

        # Apply with probability ``p``; otherwise return the inputs untouched.
        if torch.rand(1) >= self.p:
            return sample

        apply_mask = self._needs_transform_list(flat_sample)
        params = self._get_params(
            [leaf for leaf, apply in zip(flat_sample, apply_mask) if apply]
        )

        transformed = [
            self._transform(leaf, params) if apply else leaf
            for leaf, apply in zip(flat_sample, apply_mask)
        ]

        return tree_unflatten(transformed, spec)
|