__init__.py

import torch

from ._bounding_boxes import BoundingBoxes, BoundingBoxFormat
from ._image import Image
from ._mask import Mask
from ._torch_function_helpers import set_return_type
from ._tv_tensor import TVTensor
from ._video import Video


# TODO: Fix this. We skip this function as compiling it leads to
# RecursionError: maximum recursion depth exceeded while calling a Python object
# Until `disable` is removed, there will be graph breaks after all calls to functional transforms
@torch.compiler.disable
def wrap(wrappee, *, like, **kwargs):
    """Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.tv_tensors.TVTensor` subclass as ``like``.

    If ``like`` is a :class:`~torchvision.tv_tensors.BoundingBoxes`, the ``format`` and ``canvas_size`` of
    ``like`` are assigned to ``wrappee``, unless they are passed as ``kwargs``.

    Args:
        wrappee (Tensor): The tensor to convert.
        like (:class:`~torchvision.tv_tensors.TVTensor`): The reference.
            ``wrappee`` will be converted into the same subclass as ``like``.
        kwargs: Can contain "format" and "canvas_size" if ``like`` is a
            :class:`~torchvision.tv_tensors.BoundingBoxes`. Ignored otherwise.
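
    Example (a minimal usage sketch; the tensors, sizes, and format below are
    illustrative, not part of this function's contract):

        >>> import torch
        >>> from torchvision import tv_tensors
        >>> boxes = tv_tensors.BoundingBoxes(
        ...     torch.tensor([[0, 0, 10, 10]]), format="XYXY", canvas_size=(32, 32)
        ... )
        >>> # Plain tensor ops return a plain torch.Tensor by default ...
        >>> shifted = boxes.as_subclass(torch.Tensor) + 2
        >>> # ... and wrap() re-attaches the subclass, copying the box metadata from ``like``.
        >>> wrapped = tv_tensors.wrap(shifted, like=boxes)
        >>> isinstance(wrapped, tv_tensors.BoundingBoxes)
        True
        >>> wrapped.canvas_size
        (32, 32)
        >>> # For non-box TVTensors, only the subclass is re-attached.
        >>> mask = tv_tensors.Mask(torch.zeros(1, 32, 32, dtype=torch.uint8))
        >>> isinstance(tv_tensors.wrap(mask.as_subclass(torch.Tensor), like=mask), tv_tensors.Mask)
        True
    """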
    if isinstance(like, BoundingBoxes):
        return BoundingBoxes._wrap(
            wrappee,
            format=kwargs.get("format", like.format),
            canvas_size=kwargs.get("canvas_size", like.canvas_size),
        )
    else:
        return wrappee.as_subclass(type(like))