_temporal.py 1.1 KB

123456789101112131415161718192021222324252627
  1. import torch
  2. from torchvision import tv_tensors
  3. from torchvision.utils import _log_api_usage_once
  4. from ._utils import _get_kernel, _register_kernel_internal
  5. def uniform_temporal_subsample(inpt: torch.Tensor, num_samples: int) -> torch.Tensor:
  6. """See :class:`~torchvision.transforms.v2.UniformTemporalSubsample` for details."""
  7. if torch.jit.is_scripting():
  8. return uniform_temporal_subsample_video(inpt, num_samples=num_samples)
  9. _log_api_usage_once(uniform_temporal_subsample)
  10. kernel = _get_kernel(uniform_temporal_subsample, type(inpt))
  11. return kernel(inpt, num_samples=num_samples)
  12. @_register_kernel_internal(uniform_temporal_subsample, torch.Tensor)
  13. @_register_kernel_internal(uniform_temporal_subsample, tv_tensors.Video)
  14. def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int) -> torch.Tensor:
  15. # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
  16. t_max = video.shape[-4] - 1
  17. indices = torch.linspace(0, t_max, num_samples, device=video.device).long()
  18. return torch.index_select(video, -4, indices)