_tv_tensor.py

from __future__ import annotations

from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union

import torch
from torch._C import DisableTorchFunctionSubclass
from torch.types import _device, _dtype, _size
from torchvision.tv_tensors._torch_function_helpers import _FORCE_TORCHFUNCTION_SUBCLASS, _must_return_subclass


D = TypeVar("D", bound="TVTensor")


class TVTensor(torch.Tensor):
    """Base class for all TVTensors.

    You probably don't want to use this class unless you're defining your own
    custom TVTensors. See
    :ref:`sphx_glr_auto_examples_transforms_plot_custom_tv_tensors.py` for details.
    """

    @staticmethod
    def _to_tensor(
        data: Any,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> torch.Tensor:
        if requires_grad is None:
            requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
        return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)
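
    # Illustrative sketch (not part of the original file) of `_to_tensor`'s
    # `requires_grad` handling: it is inherited from tensor inputs and defaults
    # to False for everything else.
    #
    #     >>> TVTensor._to_tensor([1.0, 2.0], dtype=torch.float32).requires_grad
    #     False
    #     >>> src = torch.rand(3, requires_grad=True)
    #     >>> TVTensor._to_tensor(src).requires_grad
    #     True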

    @classmethod
    def _wrap_output(
        cls,
        output: torch.Tensor,
        args: Sequence[Any] = (),
        kwargs: Optional[Mapping[str, Any]] = None,
    ) -> torch.Tensor:
        # Same as torch._tensor._convert
        if isinstance(output, torch.Tensor) and not isinstance(output, cls):
            output = output.as_subclass(cls)

        if isinstance(output, (tuple, list)):
            # Also handles things like namedtuples
            output = type(output)(cls._wrap_output(part, args, kwargs) for part in output)
        return output
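
    # Illustrative sketch (not part of the original file): `_wrap_output`
    # re-wraps plain tensors into the calling subclass and handles tuple/list
    # outputs (e.g. from `torch.unbind`) element-wise. `Image` from
    # `torchvision.tv_tensors` is used here as a concrete subclass.
    #
    #     >>> from torchvision import tv_tensors
    #     >>> out = tv_tensors.Image._wrap_output(torch.rand(3, 4, 4))
    #     >>> isinstance(out, tv_tensors.Image)
    #     True
    #     >>> parts = tv_tensors.Image._wrap_output(torch.rand(2, 4).unbind())
    #     >>> all(isinstance(p, tv_tensors.Image) for p in parts)
    #     True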

    @classmethod
    def __torch_function__(
        cls,
        func: Callable[..., torch.Tensor],
        types: Tuple[Type[torch.Tensor], ...],
        args: Sequence[Any] = (),
        kwargs: Optional[Mapping[str, Any]] = None,
    ) -> torch.Tensor:
        """For general information about how the ``__torch_function__`` protocol works, see
        https://pytorch.org/docs/stable/notes/extending.html#extending-torch

        TL;DR: Every time a PyTorch operator is called, it goes through the inputs and looks for the
        ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
        ``args`` and ``kwargs`` of the original call.

        Why do we override this? Because the base implementation in ``torch.Tensor`` would preserve the TVTensor type
        of the output. In our case, we want to return pure tensors instead (with a few exceptions). Refer to the
        "TVTensors FAQ" gallery example for a rationale of this behaviour (TL;DR: perf + no silver bullet).

        Our implementation below is very similar to the base implementation in ``torch.Tensor`` - go check it out.
        """
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        # Like in the base Tensor.__torch_function__ implementation, it's easier to always use
        # DisableTorchFunctionSubclass and then manually re-wrap the output if necessary
        with DisableTorchFunctionSubclass():
            output = func(*args, **kwargs or dict())

        must_return_subclass = _must_return_subclass()
        if must_return_subclass or (func in _FORCE_TORCHFUNCTION_SUBCLASS and isinstance(args[0], cls)):
            # If you're wondering why we need the `isinstance(args[0], cls)` check, remove it and see what fails
            # in test_to_tv_tensor_reference().
            # The __torch_function__ protocol will invoke the __torch_function__ method on *all* types involved in
            # the computation by walking the MRO upwards. For example,
            # `out = a_pure_tensor.to(an_image)` will invoke `Image.__torch_function__` with
            # `args = (a_pure_tensor, an_image)` first. Without this guard, `out` would
            # be wrapped into an `Image`.
            return cls._wrap_output(output, args, kwargs)

        if not must_return_subclass and isinstance(output, cls):
            # DisableTorchFunctionSubclass is ignored by inplace ops like `.add_(...)`,
            # so for those, the output is still a TVTensor. Thus, we need to manually unwrap.
            return output.as_subclass(torch.Tensor)

        return output
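
    # Illustrative sketch (not part of the original file) of the dispatch
    # behaviour implemented above: regular ops on a TVTensor return a pure
    # tensor, and inplace ops are unwrapped on return while the original
    # object keeps its subclass type.
    #
    #     >>> from torchvision import tv_tensors
    #     >>> img = tv_tensors.Image(torch.rand(3, 4, 4))
    #     >>> type(img + 1) is torch.Tensor
    #     True
    #     >>> type(img.add_(1)) is torch.Tensor
    #     True
    #     >>> type(img) is tv_tensors.Image
    #     True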

    def _make_repr(self, **kwargs: Any) -> str:
        # This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532.
        # If that ever gets implemented, remove this in favor of the solution on the `torch.Tensor` class.
        extra_repr = ", ".join(f"{key}={value}" for key, value in kwargs.items())
        return f"{super().__repr__()[:-1]}, {extra_repr})"
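
    # Illustrative sketch (not part of the original file): subclasses that
    # carry extra metadata can build their `__repr__` on top of `_make_repr`.
    # The exact output below is hypothetical, but it would look roughly like
    #
    #     >>> boxes  # a BoundingBoxes instance using _make_repr under the hood
    #     BoundingBoxes([[ 0.,  0., 10., 10.]], format=..., canvas_size=(20, 20))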

    # Add properties for common attributes like shape, dtype, device, ndim, etc.
    # This way we return the result without going through __torch_function__.
    @property
    def shape(self) -> _size:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().shape

    @property
    def ndim(self) -> int:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().ndim

    @property
    def device(self, *args: Any, **kwargs: Any) -> _device:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().device

    @property
    def dtype(self) -> _dtype:  # type: ignore[override]
        with DisableTorchFunctionSubclass():
            return super().dtype
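
    # Illustrative note (not part of the original file): thanks to the property
    # overrides above, plain attribute reads skip the __torch_function__
    # round-trip entirely, e.g.
    #
    #     >>> from torchvision import tv_tensors
    #     >>> img = tv_tensors.Image(torch.rand(3, 4, 4))
    #     >>> img.shape
    #     torch.Size([3, 4, 4])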

    def __deepcopy__(self: D, memo: Dict[int, Any]) -> D:
        # We need to detach first, since a plain `Tensor.clone` will be part of the computation graph, which does
        # *not* happen for `deepcopy(Tensor)`. A side-effect from detaching is that the `Tensor.requires_grad`
        # attribute is cleared, so we need to refill it before we return.
        # Note: We don't explicitly handle deep-copying of the metadata here. The only metadata we currently have is
        # `BoundingBoxes.format` and `BoundingBoxes.canvas_size`, which are immutable and thus implicitly deep-copied
        # by `BoundingBoxes.clone()`.
        return self.detach().clone().requires_grad_(self.requires_grad)  # type: ignore[return-value]
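

# A minimal, runnable sketch (not part of the original file) checking the
# `__deepcopy__` contract described above: the copy keeps `requires_grad` but
# is detached from the original's computation graph. `Image` from
# `torchvision.tv_tensors` is assumed as a concrete TVTensor subclass.
if __name__ == "__main__":
    from copy import deepcopy

    from torchvision import tv_tensors

    img = tv_tensors.Image(torch.rand(3, 4, 4), requires_grad=True)
    img_copy = deepcopy(img)

    assert type(img_copy) is tv_tensors.Image      # subclass type is preserved
    assert img_copy.requires_grad                  # refilled after the detach()
    assert img_copy.grad_fn is None                # not part of img's graph
    assert img_copy.data_ptr() != img.data_ptr()   # clone() made a real copy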