浏览代码

add dependent libs(torchivision)

RenLiqiang 5 月之前
父节点
当前提交
9aed83bd3c
共有 100 个文件被更改,包括 27791 次插入0 次删除
  1. 二进制
      libs/vision_libs/_C.pyd
  2. 103 0
      libs/vision_libs/__init__.py
  3. 50 0
      libs/vision_libs/_internally_replaced_utils.py
  4. 225 0
      libs/vision_libs/_meta_registrations.py
  5. 32 0
      libs/vision_libs/_utils.py
  6. 146 0
      libs/vision_libs/datasets/__init__.py
  7. 491 0
      libs/vision_libs/datasets/_optical_flow.py
  8. 1224 0
      libs/vision_libs/datasets/_stereo_matching.py
  9. 241 0
      libs/vision_libs/datasets/caltech.py
  10. 193 0
      libs/vision_libs/datasets/celeba.py
  11. 167 0
      libs/vision_libs/datasets/cifar.py
  12. 221 0
      libs/vision_libs/datasets/cityscapes.py
  13. 88 0
      libs/vision_libs/datasets/clevr.py
  14. 104 0
      libs/vision_libs/datasets/coco.py
  15. 58 0
      libs/vision_libs/datasets/country211.py
  16. 100 0
      libs/vision_libs/datasets/dtd.py
  17. 58 0
      libs/vision_libs/datasets/eurosat.py
  18. 67 0
      libs/vision_libs/datasets/fakedata.py
  19. 75 0
      libs/vision_libs/datasets/fer2013.py
  20. 114 0
      libs/vision_libs/datasets/fgvc_aircraft.py
  21. 166 0
      libs/vision_libs/datasets/flickr.py
  22. 114 0
      libs/vision_libs/datasets/flowers102.py
  23. 317 0
      libs/vision_libs/datasets/folder.py
  24. 93 0
      libs/vision_libs/datasets/food101.py
  25. 103 0
      libs/vision_libs/datasets/gtsrb.py
  26. 151 0
      libs/vision_libs/datasets/hmdb51.py
  27. 218 0
      libs/vision_libs/datasets/imagenet.py
  28. 104 0
      libs/vision_libs/datasets/imagenette.py
  29. 241 0
      libs/vision_libs/datasets/inaturalist.py
  30. 247 0
      libs/vision_libs/datasets/kinetics.py
  31. 157 0
      libs/vision_libs/datasets/kitti.py
  32. 255 0
      libs/vision_libs/datasets/lfw.py
  33. 167 0
      libs/vision_libs/datasets/lsun.py
  34. 558 0
      libs/vision_libs/datasets/mnist.py
  35. 93 0
      libs/vision_libs/datasets/moving_mnist.py
  36. 102 0
      libs/vision_libs/datasets/omniglot.py
  37. 125 0
      libs/vision_libs/datasets/oxford_iiit_pet.py
  38. 134 0
      libs/vision_libs/datasets/pcam.py
  39. 228 0
      libs/vision_libs/datasets/phototour.py
  40. 170 0
      libs/vision_libs/datasets/places365.py
  41. 86 0
      libs/vision_libs/datasets/rendered_sst2.py
  42. 3 0
      libs/vision_libs/datasets/samplers/__init__.py
  43. 172 0
      libs/vision_libs/datasets/samplers/clip_sampler.py
  44. 123 0
      libs/vision_libs/datasets/sbd.py
  45. 109 0
      libs/vision_libs/datasets/sbu.py
  46. 91 0
      libs/vision_libs/datasets/semeion.py
  47. 121 0
      libs/vision_libs/datasets/stanford_cars.py
  48. 174 0
      libs/vision_libs/datasets/stl10.py
  49. 76 0
      libs/vision_libs/datasets/sun397.py
  50. 129 0
      libs/vision_libs/datasets/svhn.py
  51. 130 0
      libs/vision_libs/datasets/ucf101.py
  52. 95 0
      libs/vision_libs/datasets/usps.py
  53. 459 0
      libs/vision_libs/datasets/utils.py
  54. 419 0
      libs/vision_libs/datasets/video_utils.py
  55. 110 0
      libs/vision_libs/datasets/vision.py
  56. 224 0
      libs/vision_libs/datasets/voc.py
  57. 195 0
      libs/vision_libs/datasets/widerface.py
  58. 92 0
      libs/vision_libs/extension.py
  59. 二进制
      libs/vision_libs/image.pyd
  60. 69 0
      libs/vision_libs/io/__init__.py
  61. 8 0
      libs/vision_libs/io/_load_gpu_decoder.py
  62. 512 0
      libs/vision_libs/io/_video_opt.py
  63. 264 0
      libs/vision_libs/io/image.py
  64. 415 0
      libs/vision_libs/io/video.py
  65. 286 0
      libs/vision_libs/io/video_reader.py
  66. 23 0
      libs/vision_libs/models/__init__.py
  67. 277 0
      libs/vision_libs/models/_api.py
  68. 1554 0
      libs/vision_libs/models/_meta.py
  69. 256 0
      libs/vision_libs/models/_utils.py
  70. 119 0
      libs/vision_libs/models/alexnet.py
  71. 414 0
      libs/vision_libs/models/convnext.py
  72. 448 0
      libs/vision_libs/models/densenet.py
  73. 7 0
      libs/vision_libs/models/detection/__init__.py
  74. 540 0
      libs/vision_libs/models/detection/_utils.py
  75. 268 0
      libs/vision_libs/models/detection/anchor_utils.py
  76. 244 0
      libs/vision_libs/models/detection/backbone_utils.py
  77. 843 0
      libs/vision_libs/models/detection/faster_rcnn.py
  78. 771 0
      libs/vision_libs/models/detection/fcos.py
  79. 118 0
      libs/vision_libs/models/detection/generalized_rcnn.py
  80. 25 0
      libs/vision_libs/models/detection/image_list.py
  81. 473 0
      libs/vision_libs/models/detection/keypoint_rcnn.py
  82. 587 0
      libs/vision_libs/models/detection/mask_rcnn.py
  83. 899 0
      libs/vision_libs/models/detection/retinanet.py
  84. 876 0
      libs/vision_libs/models/detection/roi_heads.py
  85. 387 0
      libs/vision_libs/models/detection/rpn.py
  86. 682 0
      libs/vision_libs/models/detection/ssd.py
  87. 331 0
      libs/vision_libs/models/detection/ssdlite.py
  88. 319 0
      libs/vision_libs/models/detection/transform.py
  89. 1131 0
      libs/vision_libs/models/efficientnet.py
  90. 563 0
      libs/vision_libs/models/feature_extraction.py
  91. 345 0
      libs/vision_libs/models/googlenet.py
  92. 478 0
      libs/vision_libs/models/inception.py
  93. 832 0
      libs/vision_libs/models/maxvit.py
  94. 434 0
      libs/vision_libs/models/mnasnet.py
  95. 6 0
      libs/vision_libs/models/mobilenet.py
  96. 260 0
      libs/vision_libs/models/mobilenetv2.py
  97. 423 0
      libs/vision_libs/models/mobilenetv3.py
  98. 1 0
      libs/vision_libs/models/optical_flow/__init__.py
  99. 48 0
      libs/vision_libs/models/optical_flow/_utils.py
  100. 947 0
      libs/vision_libs/models/optical_flow/raft.py

二进制
libs/vision_libs/_C.pyd


+ 103 - 0
libs/vision_libs/__init__.py

@@ -0,0 +1,103 @@
+import os
+import warnings
+from modulefinder import Module
+
+import torch
+from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils
+
+from .extension import _HAS_OPS
+
+try:
+    from .version import __version__  # noqa: F401
+except ImportError:
+    pass
+
+
+# Check if torchvision is being imported within the root folder
+if not _HAS_OPS and os.path.dirname(os.path.realpath(__file__)) == os.path.join(
+    os.path.realpath(os.getcwd()), "torchvision"
+):
+    message = (
+        "You are importing torchvision within its own root folder ({}). "
+        "This is not expected to work and may give errors. Please exit the "
+        "torchvision project source and relaunch your python interpreter."
+    )
+    warnings.warn(message.format(os.getcwd()))
+
+_image_backend = "PIL"
+
+_video_backend = "pyav"
+
+
+def set_image_backend(backend):
+    """
+    Specifies the package used to load images.
+
+    Args:
+        backend (string): Name of the image backend. one of {'PIL', 'accimage'}.
+            The :mod:`accimage` package uses the Intel IPP library. It is
+            generally faster than PIL, but does not support as many operations.
+    """
+    global _image_backend
+    if backend not in ["PIL", "accimage"]:
+        raise ValueError(f"Invalid backend '{backend}'. Options are 'PIL' and 'accimage'")
+    _image_backend = backend
+
+
+def get_image_backend():
+    """
+    Gets the name of the package used to load images
+    """
+    return _image_backend
+
+
+def set_video_backend(backend):
+    """
+    Specifies the package used to decode videos.
+
+    Args:
+        backend (string): Name of the video backend. one of {'pyav', 'video_reader'}.
+            The :mod:`pyav` package uses the 3rd party PyAv library. It is a Pythonic
+            binding for the FFmpeg libraries.
+            The :mod:`video_reader` package includes a native C++ implementation on
+            top of FFMPEG libraries, and a python API of TorchScript custom operator.
+            It generally decodes faster than :mod:`pyav`, but is perhaps less robust.
+
+    .. note::
+        Building with FFMPEG is disabled by default in the latest `main`. If you want to use the 'video_reader'
+        backend, please compile torchvision from source.
+    """
+    global _video_backend
+    if backend not in ["pyav", "video_reader", "cuda"]:
+        raise ValueError("Invalid video backend '%s'. Options are 'pyav', 'video_reader' and 'cuda'" % backend)
+    if backend == "video_reader" and not io._HAS_VIDEO_OPT:
+        # TODO: better messages
+        message = "video_reader video backend is not available. Please compile torchvision from source and try again"
+        raise RuntimeError(message)
+    elif backend == "cuda" and not io._HAS_GPU_VIDEO_DECODER:
+        # TODO: better messages
+        message = "cuda video backend is not available."
+        raise RuntimeError(message)
+    else:
+        _video_backend = backend
+
+
+def get_video_backend():
+    """
+    Returns the currently active video backend used to decode videos.
+
+    Returns:
+        str: Name of the video backend. one of {'pyav', 'video_reader'}.
+    """
+
+    return _video_backend
+
+
+def _is_tracing():
+    return torch._C._get_tracing_state()
+
+
+def disable_beta_transforms_warning():
+    # Noop, only exists to avoid breaking existing code.
+    # See https://github.com/pytorch/vision/issues/7896
+    pass

+ 50 - 0
libs/vision_libs/_internally_replaced_utils.py

@@ -0,0 +1,50 @@
+import importlib.machinery
+import os
+
+from torch.hub import _get_torch_home
+
+
+_HOME = os.path.join(_get_torch_home(), "datasets", "vision")
+_USE_SHARDED_DATASETS = False
+
+
+def _download_file_from_remote_location(fpath: str, url: str) -> None:
+    pass
+
+
+def _is_remote_location_available() -> bool:
+    return False
+
+
+try:
+    from torch.hub import load_state_dict_from_url  # noqa: 401
+except ImportError:
+    from torch.utils.model_zoo import load_url as load_state_dict_from_url  # noqa: 401
+
+
+def _get_extension_path(lib_name):
+
+    lib_dir = os.path.dirname(__file__)
+    if os.name == "nt":
+        # Register the main torchvision library location on the default DLL path
+        import ctypes
+
+        kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
+        with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
+        prev_error_mode = kernel32.SetErrorMode(0x0001)
+
+        if with_load_library_flags:
+            kernel32.AddDllDirectory.restype = ctypes.c_void_p
+
+        os.add_dll_directory(lib_dir)
+
+        kernel32.SetErrorMode(prev_error_mode)
+
+    loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES)
+
+    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
+    ext_specs = extfinder.find_spec(lib_name)
+    if ext_specs is None:
+        raise ImportError
+
+    return ext_specs.origin

+ 225 - 0
libs/vision_libs/_meta_registrations.py

@@ -0,0 +1,225 @@
+import functools
+
+import torch
+import torch._custom_ops
+import torch.library
+
+# Ensure that torch.ops.torchvision is visible
+import torchvision.extension  # noqa: F401
+
+
+@functools.lru_cache(None)
+def get_meta_lib():
+    return torch.library.Library("torchvision", "IMPL", "Meta")
+
+
+def register_meta(op_name, overload_name="default"):
+    def wrapper(fn):
+        if torchvision.extension._has_ops():
+            get_meta_lib().impl(getattr(getattr(torch.ops.torchvision, op_name), overload_name), fn)
+        return fn
+
+    return wrapper
+
+
+@register_meta("roi_align")
+def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
+    torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
+    torch._check(
+        input.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for input to have the same type as tensor for rois; "
+            f"but type {input.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    num_rois = rois.size(0)
+    channels = input.size(1)
+    return input.new_empty((num_rois, channels, pooled_height, pooled_width))
+
+
+@register_meta("_roi_align_backward")
+def meta_roi_align_backward(
+    grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned
+):
+    torch._check(
+        grad.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for grad to have the same type as tensor for rois; "
+            f"but type {grad.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    return grad.new_empty((batch_size, channels, height, width))
+
+
+@register_meta("ps_roi_align")
+def meta_ps_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio):
+    torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
+    torch._check(
+        input.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for input to have the same type as tensor for rois; "
+            f"but type {input.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    channels = input.size(1)
+    torch._check(
+        channels % (pooled_height * pooled_width) == 0,
+        "input channels must be a multiple of pooling height * pooling width",
+    )
+
+    num_rois = rois.size(0)
+    out_size = (num_rois, channels // (pooled_height * pooled_width), pooled_height, pooled_width)
+    return input.new_empty(out_size), torch.empty(out_size, dtype=torch.int32, device="meta")
+
+
+@register_meta("_ps_roi_align_backward")
+def meta_ps_roi_align_backward(
+    grad,
+    rois,
+    channel_mapping,
+    spatial_scale,
+    pooled_height,
+    pooled_width,
+    sampling_ratio,
+    batch_size,
+    channels,
+    height,
+    width,
+):
+    torch._check(
+        grad.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for grad to have the same type as tensor for rois; "
+            f"but type {grad.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    return grad.new_empty((batch_size, channels, height, width))
+
+
+@register_meta("roi_pool")
+def meta_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width):
+    torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
+    torch._check(
+        input.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for input to have the same type as tensor for rois; "
+            f"but type {input.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    num_rois = rois.size(0)
+    channels = input.size(1)
+    out_size = (num_rois, channels, pooled_height, pooled_width)
+    return input.new_empty(out_size), torch.empty(out_size, device="meta", dtype=torch.int32)
+
+
+@register_meta("_roi_pool_backward")
+def meta_roi_pool_backward(
+    grad, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width
+):
+    torch._check(
+        grad.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for grad to have the same type as tensor for rois; "
+            f"but type {grad.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    return grad.new_empty((batch_size, channels, height, width))
+
+
+@register_meta("ps_roi_pool")
+def meta_ps_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width):
+    torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
+    torch._check(
+        input.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for input to have the same type as tensor for rois; "
+            f"but type {input.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    channels = input.size(1)
+    torch._check(
+        channels % (pooled_height * pooled_width) == 0,
+        "input channels must be a multiple of pooling height * pooling width",
+    )
+    num_rois = rois.size(0)
+    out_size = (num_rois, channels // (pooled_height * pooled_width), pooled_height, pooled_width)
+    return input.new_empty(out_size), torch.empty(out_size, device="meta", dtype=torch.int32)
+
+
+@register_meta("_ps_roi_pool_backward")
+def meta_ps_roi_pool_backward(
+    grad, rois, channel_mapping, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width
+):
+    torch._check(
+        grad.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for grad to have the same type as tensor for rois; "
+            f"but type {grad.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    return grad.new_empty((batch_size, channels, height, width))
+
+
+@torch._custom_ops.impl_abstract("torchvision::nms")
+def meta_nms(dets, scores, iou_threshold):
+    torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D")
+    torch._check(dets.size(1) == 4, lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}")
+    torch._check(scores.dim() == 1, lambda: f"scores should be a 1d tensor, got {scores.dim()}")
+    torch._check(
+        dets.size(0) == scores.size(0),
+        lambda: f"boxes and scores should have same number of elements in dimension 0, got {dets.size(0)} and {scores.size(0)}",
+    )
+    ctx = torch._custom_ops.get_ctx()
+    num_to_keep = ctx.create_unbacked_symint()
+    return dets.new_empty(num_to_keep, dtype=torch.long)
+
+
+@register_meta("deform_conv2d")
+def meta_deform_conv2d(
+    input,
+    weight,
+    offset,
+    mask,
+    bias,
+    stride_h,
+    stride_w,
+    pad_h,
+    pad_w,
+    dil_h,
+    dil_w,
+    n_weight_grps,
+    n_offset_grps,
+    use_mask,
+):
+
+    out_height, out_width = offset.shape[-2:]
+    out_channels = weight.shape[0]
+    batch_size = input.shape[0]
+    return input.new_empty((batch_size, out_channels, out_height, out_width))
+
+
+@register_meta("_deform_conv2d_backward")
+def meta_deform_conv2d_backward(
+    grad,
+    input,
+    weight,
+    offset,
+    mask,
+    bias,
+    stride_h,
+    stride_w,
+    pad_h,
+    pad_w,
+    dilation_h,
+    dilation_w,
+    groups,
+    offset_groups,
+    use_mask,
+):
+
+    grad_input = input.new_empty(input.shape)
+    grad_weight = weight.new_empty(weight.shape)
+    grad_offset = offset.new_empty(offset.shape)
+    grad_mask = mask.new_empty(mask.shape)
+    grad_bias = bias.new_empty(bias.shape)
+    return grad_input, grad_weight, grad_offset, grad_mask, grad_bias

+ 32 - 0
libs/vision_libs/_utils.py

@@ -0,0 +1,32 @@
+import enum
+from typing import Sequence, Type, TypeVar
+
+T = TypeVar("T", bound=enum.Enum)
+
+
+class StrEnumMeta(enum.EnumMeta):
+    auto = enum.auto
+
+    def from_str(self: Type[T], member: str) -> T:  # type: ignore[misc]
+        try:
+            return self[member]
+        except KeyError:
+            # TODO: use `add_suggestion` from torchvision.prototype.utils._internal to improve the error message as
+            #  soon as it is migrated.
+            raise ValueError(f"Unknown value '{member}' for {self.__name__}.") from None
+
+
+class StrEnum(enum.Enum, metaclass=StrEnumMeta):
+    pass
+
+
+def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
+    if not seq:
+        return ""
+    if len(seq) == 1:
+        return f"'{seq[0]}'"
+
+    head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"
+    tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"
+
+    return head + tail

+ 146 - 0
libs/vision_libs/datasets/__init__.py

@@ -0,0 +1,146 @@
+from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
+from ._stereo_matching import (
+    CarlaStereo,
+    CREStereo,
+    ETH3DStereo,
+    FallingThingsStereo,
+    InStereo2k,
+    Kitti2012Stereo,
+    Kitti2015Stereo,
+    Middlebury2014Stereo,
+    SceneFlowStereo,
+    SintelStereo,
+)
+from .caltech import Caltech101, Caltech256
+from .celeba import CelebA
+from .cifar import CIFAR10, CIFAR100
+from .cityscapes import Cityscapes
+from .clevr import CLEVRClassification
+from .coco import CocoCaptions, CocoDetection
+from .country211 import Country211
+from .dtd import DTD
+from .eurosat import EuroSAT
+from .fakedata import FakeData
+from .fer2013 import FER2013
+from .fgvc_aircraft import FGVCAircraft
+from .flickr import Flickr30k, Flickr8k
+from .flowers102 import Flowers102
+from .folder import DatasetFolder, ImageFolder
+from .food101 import Food101
+from .gtsrb import GTSRB
+from .hmdb51 import HMDB51
+from .imagenet import ImageNet
+from .imagenette import Imagenette
+from .inaturalist import INaturalist
+from .kinetics import Kinetics
+from .kitti import Kitti
+from .lfw import LFWPairs, LFWPeople
+from .lsun import LSUN, LSUNClass
+from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST
+from .moving_mnist import MovingMNIST
+from .omniglot import Omniglot
+from .oxford_iiit_pet import OxfordIIITPet
+from .pcam import PCAM
+from .phototour import PhotoTour
+from .places365 import Places365
+from .rendered_sst2 import RenderedSST2
+from .sbd import SBDataset
+from .sbu import SBU
+from .semeion import SEMEION
+from .stanford_cars import StanfordCars
+from .stl10 import STL10
+from .sun397 import SUN397
+from .svhn import SVHN
+from .ucf101 import UCF101
+from .usps import USPS
+from .vision import VisionDataset
+from .voc import VOCDetection, VOCSegmentation
+from .widerface import WIDERFace
+
+__all__ = (
+    "LSUN",
+    "LSUNClass",
+    "ImageFolder",
+    "DatasetFolder",
+    "FakeData",
+    "CocoCaptions",
+    "CocoDetection",
+    "CIFAR10",
+    "CIFAR100",
+    "EMNIST",
+    "FashionMNIST",
+    "QMNIST",
+    "MNIST",
+    "KMNIST",
+    "StanfordCars",
+    "STL10",
+    "SUN397",
+    "SVHN",
+    "PhotoTour",
+    "SEMEION",
+    "Omniglot",
+    "SBU",
+    "Flickr8k",
+    "Flickr30k",
+    "Flowers102",
+    "VOCSegmentation",
+    "VOCDetection",
+    "Cityscapes",
+    "ImageNet",
+    "Caltech101",
+    "Caltech256",
+    "CelebA",
+    "WIDERFace",
+    "SBDataset",
+    "VisionDataset",
+    "USPS",
+    "Kinetics",
+    "HMDB51",
+    "UCF101",
+    "Places365",
+    "Kitti",
+    "INaturalist",
+    "LFWPeople",
+    "LFWPairs",
+    "KittiFlow",
+    "Sintel",
+    "FlyingChairs",
+    "FlyingThings3D",
+    "HD1K",
+    "Food101",
+    "DTD",
+    "FER2013",
+    "GTSRB",
+    "CLEVRClassification",
+    "OxfordIIITPet",
+    "PCAM",
+    "Country211",
+    "FGVCAircraft",
+    "EuroSAT",
+    "RenderedSST2",
+    "Kitti2012Stereo",
+    "Kitti2015Stereo",
+    "CarlaStereo",
+    "Middlebury2014Stereo",
+    "CREStereo",
+    "FallingThingsStereo",
+    "SceneFlowStereo",
+    "SintelStereo",
+    "InStereo2k",
+    "ETH3DStereo",
+    "wrap_dataset_for_transforms_v2",
+    "Imagenette",
+)
+
+
+# We override current module's attributes to handle the import:
+# from torchvision.datasets import wrap_dataset_for_transforms_v2
+# without a cyclic error.
+# Ref: https://peps.python.org/pep-0562/
+def __getattr__(name):
+    if name in ("wrap_dataset_for_transforms_v2",):
+        from torchvision.tv_tensors._dataset_wrapper import wrap_dataset_for_transforms_v2
+
+        return wrap_dataset_for_transforms_v2
+
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

+ 491 - 0
libs/vision_libs/datasets/_optical_flow.py

@@ -0,0 +1,491 @@
+import itertools
+import os
+from abc import ABC, abstractmethod
+from glob import glob
+from pathlib import Path
+from typing import Callable, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from PIL import Image
+
+from ..io.image import _read_png_16
+from .utils import _read_pfm, verify_str_arg
+from .vision import VisionDataset
+
+
+T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], Optional[np.ndarray]]
+T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]
+
+
+__all__ = (
+    "KittiFlow",
+    "Sintel",
+    "FlyingThings3D",
+    "FlyingChairs",
+    "HD1K",
+)
+
+
+class FlowDataset(ABC, VisionDataset):
+    # Some datasets like Kitti have a built-in valid_flow_mask, indicating which flow values are valid
+    # For those we return (img1, img2, flow, valid_flow_mask), and for the rest we return (img1, img2, flow),
+    # and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
+    _has_builtin_flow_mask = False
+
+    def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
+
+        super().__init__(root=root)
+        self.transforms = transforms
+
+        self._flow_list: List[str] = []
+        self._image_list: List[List[str]] = []
+
+    def _read_img(self, file_name: str) -> Image.Image:
+        img = Image.open(file_name)
+        if img.mode != "RGB":
+            img = img.convert("RGB")
+        return img
+
+    @abstractmethod
+    def _read_flow(self, file_name: str):
+        # Return the flow or a tuple with the flow and the valid_flow_mask if _has_builtin_flow_mask is True
+        pass
+
+    def __getitem__(self, index: int) -> Union[T1, T2]:
+
+        img1 = self._read_img(self._image_list[index][0])
+        img2 = self._read_img(self._image_list[index][1])
+
+        if self._flow_list:  # it will be empty for some dataset when split="test"
+            flow = self._read_flow(self._flow_list[index])
+            if self._has_builtin_flow_mask:
+                flow, valid_flow_mask = flow
+            else:
+                valid_flow_mask = None
+        else:
+            flow = valid_flow_mask = None
+
+        if self.transforms is not None:
+            img1, img2, flow, valid_flow_mask = self.transforms(img1, img2, flow, valid_flow_mask)
+
+        if self._has_builtin_flow_mask or valid_flow_mask is not None:
+            # The `or valid_flow_mask is not None` part is here because the mask can be generated within a transform
+            return img1, img2, flow, valid_flow_mask
+        else:
+            return img1, img2, flow
+
+    def __len__(self) -> int:
+        return len(self._image_list)
+
+    def __rmul__(self, v: int) -> torch.utils.data.ConcatDataset:
+        return torch.utils.data.ConcatDataset([self] * v)
+
+
+class Sintel(FlowDataset):
+    """`Sintel <http://sintel.is.tue.mpg.de/>`_ Dataset for optical flow.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            Sintel
+                testing
+                    clean
+                        scene_1
+                        scene_2
+                        ...
+                    final
+                        scene_1
+                        scene_2
+                        ...
+                training
+                    clean
+                        scene_1
+                        scene_2
+                        ...
+                    final
+                        scene_1
+                        scene_2
+                        ...
+                    flow
+                        scene_1
+                        scene_2
+                        ...
+
+    Args:
+        root (string): Root directory of the Sintel Dataset.
+        split (string, optional): The dataset split, either "train" (default) or "test"
+        pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for
+            details on the different passes.
+        transforms (callable, optional): A function/transform that takes in
+            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
+            ``valid_flow_mask`` is expected for consistency with other datasets which
+            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        pass_name: str = "clean",
+        transforms: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root=root, transforms=transforms)
+
+        verify_str_arg(split, "split", valid_values=("train", "test"))
+        verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
+        passes = ["clean", "final"] if pass_name == "both" else [pass_name]
+
+        root = Path(root) / "Sintel"
+        flow_root = root / "training" / "flow"
+
+        for pass_name in passes:
+            split_dir = "training" if split == "train" else split
+            image_root = root / split_dir / pass_name
+            for scene in os.listdir(image_root):
+                image_list = sorted(glob(str(image_root / scene / "*.png")))
+                for i in range(len(image_list) - 1):
+                    self._image_list += [[image_list[i], image_list[i + 1]]]
+
+                if split == "train":
+                    self._flow_list += sorted(glob(str(flow_root / scene / "*.flo")))
+
+    def __getitem__(self, index: int) -> Union[T1, T2]:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 3-tuple with ``(img1, img2, flow)``.
+            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
+            ``flow`` is None if ``split="test"``.
+            If a valid flow mask is generated within the ``transforms`` parameter,
+            a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
+        """
+        return super().__getitem__(index)
+
+    def _read_flow(self, file_name: str) -> np.ndarray:
+        return _read_flo(file_name)
+
+
+class KittiFlow(FlowDataset):
+    """`KITTI <http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow>`__ dataset for optical flow (2015).
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            KittiFlow
+                testing
+                    image_2
+                training
+                    image_2
+                    flow_occ
+
+    Args:
+        root (string): Root directory of the KittiFlow Dataset.
+        split (string, optional): The dataset split, either "train" (default) or "test"
+        transforms (callable, optional): A function/transform that takes in
+            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
+    """
+
+    _has_builtin_flow_mask = True
+
+    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
+        super().__init__(root=root, transforms=transforms)
+
+        verify_str_arg(split, "split", valid_values=("train", "test"))
+
+        root = Path(root) / "KittiFlow" / (split + "ing")
+        images1 = sorted(glob(str(root / "image_2" / "*_10.png")))
+        images2 = sorted(glob(str(root / "image_2" / "*_11.png")))
+
+        if not images1 or not images2:
+            raise FileNotFoundError(
+                "Could not find the Kitti flow images. Please make sure the directory structure is correct."
+            )
+
+        for img1, img2 in zip(images1, images2):
+            self._image_list += [[img1, img2]]
+
+        if split == "train":
+            self._flow_list = sorted(glob(str(root / "flow_occ" / "*_10.png")))
+
+    def __getitem__(self, index: int) -> Union[T1, T2]:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)``
+            where ``valid_flow_mask`` is a numpy boolean mask of shape (H, W)
+            indicating which flow values are valid. The flow is a numpy array of
+            shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if
+            ``split="test"``.
+        """
+        return super().__getitem__(index)
+
+    def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]:
+        return _read_16bits_png_with_flow_and_valid_mask(file_name)
+
+
+class FlyingChairs(FlowDataset):
+    """`FlyingChairs <https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs>`_ Dataset for optical flow.
+
+    You will also need to download the FlyingChairs_train_val.txt file from the dataset page.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            FlyingChairs
+                data
+                    00001_flow.flo
+                    00001_img1.ppm
+                    00001_img2.ppm
+                    ...
+                FlyingChairs_train_val.txt
+
+
+    Args:
+        root (string): Root directory of the FlyingChairs Dataset.
+        split (string, optional): The dataset split, either "train" (default) or "val"
+        transforms (callable, optional): A function/transform that takes in
+            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
+            ``valid_flow_mask`` is expected for consistency with other datasets which
+            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
+    """
+
+    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
+        super().__init__(root=root, transforms=transforms)
+
+        verify_str_arg(split, "split", valid_values=("train", "val"))
+
+        root = Path(root) / "FlyingChairs"
+        images = sorted(glob(str(root / "data" / "*.ppm")))
+        flows = sorted(glob(str(root / "data" / "*.flo")))
+
+        split_file_name = "FlyingChairs_train_val.txt"
+
+        if not os.path.exists(root / split_file_name):
+            raise FileNotFoundError(
+                "The FlyingChairs_train_val.txt file was not found - please download it from the dataset page (see docstring)."
+            )
+
+        split_list = np.loadtxt(str(root / split_file_name), dtype=np.int32)
+        for i in range(len(flows)):
+            split_id = split_list[i]
+            if (split == "train" and split_id == 1) or (split == "val" and split_id == 2):
+                self._flow_list += [flows[i]]
+                self._image_list += [[images[2 * i], images[2 * i + 1]]]
+
+    def __getitem__(self, index: int) -> Union[T1, T2]:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 3-tuple with ``(img1, img2, flow)``.
+            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
+            ``flow`` is None if ``split="val"``.
+            If a valid flow mask is generated within the ``transforms`` parameter,
+            a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
+        """
+        return super().__getitem__(index)
+
+    def _read_flow(self, file_name: str) -> np.ndarray:
+        return _read_flo(file_name)
+
+
+class FlyingThings3D(FlowDataset):
+    """`FlyingThings3D <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>`_ dataset for optical flow.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            FlyingThings3D
+                frames_cleanpass
+                    TEST
+                    TRAIN
+                frames_finalpass
+                    TEST
+                    TRAIN
+                optical_flow
+                    TEST
+                    TRAIN
+
+    Args:
+        root (string): Root directory of the intel FlyingThings3D Dataset.
+        split (string, optional): The dataset split, either "train" (default) or "test"
+        pass_name (string, optional): The pass to use, either "clean" (default) or "final" or "both". See link above for
+            details on the different passes.
+        camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both".
+        transforms (callable, optional): A function/transform that takes in
+            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
+            ``valid_flow_mask`` is expected for consistency with other datasets which
+            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        pass_name: str = "clean",
+        camera: str = "left",
+        transforms: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root=root, transforms=transforms)
+
+        verify_str_arg(split, "split", valid_values=("train", "test"))
+        split = split.upper()
+
+        verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
+        passes = {
+            "clean": ["frames_cleanpass"],
+            "final": ["frames_finalpass"],
+            "both": ["frames_cleanpass", "frames_finalpass"],
+        }[pass_name]
+
+        verify_str_arg(camera, "camera", valid_values=("left", "right", "both"))
+        cameras = ["left", "right"] if camera == "both" else [camera]
+
+        root = Path(root) / "FlyingThings3D"
+
+        directions = ("into_future", "into_past")
+        for pass_name, camera, direction in itertools.product(passes, cameras, directions):
+            image_dirs = sorted(glob(str(root / pass_name / split / "*/*")))
+            image_dirs = sorted(Path(image_dir) / camera for image_dir in image_dirs)
+
+            flow_dirs = sorted(glob(str(root / "optical_flow" / split / "*/*")))
+            flow_dirs = sorted(Path(flow_dir) / direction / camera for flow_dir in flow_dirs)
+
+            if not image_dirs or not flow_dirs:
+                raise FileNotFoundError(
+                    "Could not find the FlyingThings3D flow images. "
+                    "Please make sure the directory structure is correct."
+                )
+
+            for image_dir, flow_dir in zip(image_dirs, flow_dirs):
+                images = sorted(glob(str(image_dir / "*.png")))
+                flows = sorted(glob(str(flow_dir / "*.pfm")))
+                for i in range(len(flows) - 1):
+                    if direction == "into_future":
+                        self._image_list += [[images[i], images[i + 1]]]
+                        self._flow_list += [flows[i]]
+                    elif direction == "into_past":
+                        self._image_list += [[images[i + 1], images[i]]]
+                        self._flow_list += [flows[i + 1]]
+
+    def __getitem__(self, index: int) -> Union[T1, T2]:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 3-tuple with ``(img1, img2, flow)``.
+            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
+            ``flow`` is None if ``split="test"``.
+            If a valid flow mask is generated within the ``transforms`` parameter,
+            a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
+        """
+        return super().__getitem__(index)
+
+    def _read_flow(self, file_name: str) -> np.ndarray:
+        return _read_pfm(file_name)
+
+
+class HD1K(FlowDataset):
+    """`HD1K <http://hci-benchmark.iwr.uni-heidelberg.de/>`__ dataset for optical flow.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            hd1k
+                hd1k_challenge
+                    image_2
+                hd1k_flow_gt
+                    flow_occ
+                hd1k_input
+                    image_2
+
+    Args:
+        root (string): Root directory of the HD1K Dataset.
+        split (string, optional): The dataset split, either "train" (default) or "test"
+        transforms (callable, optional): A function/transform that takes in
+            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
+    """
+
+    _has_builtin_flow_mask = True
+
+    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
+        super().__init__(root=root, transforms=transforms)
+
+        verify_str_arg(split, "split", valid_values=("train", "test"))
+
+        root = Path(root) / "hd1k"
+        if split == "train":
+            # There are 36 "sequences" and we don't want seq i to overlap with seq i + 1, so we need this for loop
+            for seq_idx in range(36):
+                flows = sorted(glob(str(root / "hd1k_flow_gt" / "flow_occ" / f"{seq_idx:06d}_*.png")))
+                images = sorted(glob(str(root / "hd1k_input" / "image_2" / f"{seq_idx:06d}_*.png")))
+                for i in range(len(flows) - 1):
+                    self._flow_list += [flows[i]]
+                    self._image_list += [[images[i], images[i + 1]]]
+        else:
+            images1 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*10.png")))
+            images2 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*11.png")))
+            for image1, image2 in zip(images1, images2):
+                self._image_list += [[image1, image2]]
+
+        if not self._image_list:
+            raise FileNotFoundError(
+                "Could not find the HD1K images. Please make sure the directory structure is correct."
+            )
+
+    def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]:
+        return _read_16bits_png_with_flow_and_valid_mask(file_name)
+
+    def __getitem__(self, index: int) -> Union[T1, T2]:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` where ``valid_flow_mask``
+            is a numpy boolean mask of shape (H, W)
+            indicating which flow values are valid. The flow is a numpy array of
+            shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if
+            ``split="test"``.
+        """
+        return super().__getitem__(index)
+
+
+def _read_flo(file_name: str) -> np.ndarray:
+    """Read .flo file in Middlebury format"""
+    # Code adapted from:
+    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
+    # Everything needs to be in little Endian according to
+    # https://vision.middlebury.edu/flow/code/flow-code/README.txt
+    with open(file_name, "rb") as f:
+        magic = np.fromfile(f, "c", count=4).tobytes()
+        if magic != b"PIEH":
+            raise ValueError("Magic number incorrect. Invalid .flo file")
+
+        w = int(np.fromfile(f, "<i4", count=1))
+        h = int(np.fromfile(f, "<i4", count=1))
+        data = np.fromfile(f, "<f4", count=2 * w * h)
+        return data.reshape(h, w, 2).transpose(2, 0, 1)
+
+
+def _read_16bits_png_with_flow_and_valid_mask(file_name: str) -> Tuple[np.ndarray, np.ndarray]:
+
+    flow_and_valid = _read_png_16(file_name).to(torch.float32)
+    flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :]
+    flow = (flow - 2**15) / 64  # This conversion is explained somewhere on the kitti archive
+    valid_flow_mask = valid_flow_mask.bool()
+
+    # For consistency with other datasets, we convert to numpy
+    return flow.numpy(), valid_flow_mask.numpy()

+ 1224 - 0
libs/vision_libs/datasets/_stereo_matching.py

@@ -0,0 +1,1224 @@
+import functools
+import json
+import os
+import random
+import shutil
+from abc import ABC, abstractmethod
+from glob import glob
+from pathlib import Path
+from typing import Callable, cast, List, Optional, Tuple, Union
+
+import numpy as np
+from PIL import Image
+
+from .utils import _read_pfm, download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], np.ndarray]
+T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]
+
+__all__ = ()
+
+_read_pfm_file = functools.partial(_read_pfm, slice_channels=1)
+
+
+class StereoMatchingDataset(ABC, VisionDataset):
+    """Base interface for Stereo matching datasets"""
+
+    _has_built_in_disparity_mask = False
+
+    def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
+        """
+        Args:
+            root(str): Root directory of the dataset.
+            transforms(callable, optional): A function/transform that takes in Tuples of
+                (images, disparities, valid_masks) and returns a transformed version of each of them.
+                images is a Tuple of (``PIL.Image``, ``PIL.Image``)
+                disparities is a Tuple of (``np.ndarray``, ``np.ndarray``) with shape (1, H, W)
+                valid_masks is a Tuple of (``np.ndarray``, ``np.ndarray``) with shape (H, W)
+                In some cases, when a dataset does not provide disparities, the ``disparities`` and
+                ``valid_masks`` can be Tuples containing None values.
+                For training splits generally the datasets provide a minimal guarantee of
+                images: (``PIL.Image``, ``PIL.Image``)
+                disparities: (``np.ndarray``, ``None``) with shape (1, H, W)
+                Optionally, based on the dataset, it can return a ``mask`` as well:
+                valid_masks: (``np.ndarray | None``, ``None``) with shape (H, W)
+                For some test splits, the datasets provides outputs that look like:
+                imgaes: (``PIL.Image``, ``PIL.Image``)
+                disparities: (``None``, ``None``)
+                Optionally, based on the dataset, it can return a ``mask`` as well:
+                valid_masks: (``None``, ``None``)
+        """
+        super().__init__(root=root)
+        self.transforms = transforms
+
+        self._images = []  # type: ignore
+        self._disparities = []  # type: ignore
+
+    def _read_img(self, file_path: Union[str, Path]) -> Image.Image:
+        img = Image.open(file_path)
+        if img.mode != "RGB":
+            img = img.convert("RGB")
+        return img
+
+    def _scan_pairs(
+        self,
+        paths_left_pattern: str,
+        paths_right_pattern: Optional[str] = None,
+    ) -> List[Tuple[str, Optional[str]]]:
+
+        left_paths = list(sorted(glob(paths_left_pattern)))
+
+        right_paths: List[Union[None, str]]
+        if paths_right_pattern:
+            right_paths = list(sorted(glob(paths_right_pattern)))
+        else:
+            right_paths = list(None for _ in left_paths)
+
+        if not left_paths:
+            raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_left_pattern}")
+
+        if not right_paths:
+            raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_right_pattern}")
+
+        if len(left_paths) != len(right_paths):
+            raise ValueError(
+                f"Found {len(left_paths)} left files but {len(right_paths)} right files using:\n "
+                f"left pattern: {paths_left_pattern}\n"
+                f"right pattern: {paths_right_pattern}\n"
+            )
+
+        paths = list((left, right) for left, right in zip(left_paths, right_paths))
+        return paths
+
+    @abstractmethod
+    def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
+        # function that returns a disparity map and an occlusion map
+        pass
+
+    def __getitem__(self, index: int) -> Union[T1, T2]:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 3 or 4-tuple with ``(img_left, img_right, disparity, Optional[valid_mask])`` where ``valid_mask``
+                can be a numpy boolean mask of shape (H, W) if the dataset provides a file
+                indicating which disparity pixels are valid. The disparity is a numpy array of
+                shape (1, H, W) and the images are PIL images. ``disparity`` is None for
+                datasets on which for ``split="test"`` the authors did not provide annotations.
+        """
+        img_left = self._read_img(self._images[index][0])
+        img_right = self._read_img(self._images[index][1])
+
+        dsp_map_left, valid_mask_left = self._read_disparity(self._disparities[index][0])
+        dsp_map_right, valid_mask_right = self._read_disparity(self._disparities[index][1])
+
+        imgs = (img_left, img_right)
+        dsp_maps = (dsp_map_left, dsp_map_right)
+        valid_masks = (valid_mask_left, valid_mask_right)
+
+        if self.transforms is not None:
+            (
+                imgs,
+                dsp_maps,
+                valid_masks,
+            ) = self.transforms(imgs, dsp_maps, valid_masks)
+
+        if self._has_built_in_disparity_mask or valid_masks[0] is not None:
+            return imgs[0], imgs[1], dsp_maps[0], cast(np.ndarray, valid_masks[0])
+        else:
+            return imgs[0], imgs[1], dsp_maps[0]
+
+    def __len__(self) -> int:
+        return len(self._images)
+
+
+class CarlaStereo(StereoMatchingDataset):
+    """
+    Carla simulator data linked in the `CREStereo github repo <https://github.com/megvii-research/CREStereo>`_.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            carla-highres
+                trainingF
+                    scene1
+                        img0.png
+                        img1.png
+                        disp0GT.pfm
+                        disp1GT.pfm
+                        calib.txt
+                    scene2
+                        img0.png
+                        img1.png
+                        disp0GT.pfm
+                        disp1GT.pfm
+                        calib.txt
+                    ...
+
+    Args:
+        root (string): Root directory where `carla-highres` is located.
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+    """
+
+    def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
+        super().__init__(root, transforms)
+
+        root = Path(root) / "carla-highres"
+
+        left_image_pattern = str(root / "trainingF" / "*" / "im0.png")
+        right_image_pattern = str(root / "trainingF" / "*" / "im1.png")
+        imgs = self._scan_pairs(left_image_pattern, right_image_pattern)
+        self._images = imgs
+
+        left_disparity_pattern = str(root / "trainingF" / "*" / "disp0GT.pfm")
+        right_disparity_pattern = str(root / "trainingF" / "*" / "disp1GT.pfm")
+        disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+        self._disparities = disparities
+
+    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
+        disparity_map = _read_pfm_file(file_path)
+        disparity_map = np.abs(disparity_map)  # ensure that the disparity is positive
+        valid_mask = None
+        return disparity_map, valid_mask
+
+    def __getitem__(self, index: int) -> T1:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            If a ``valid_mask`` is generated within the ``transforms`` parameter,
+            a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
+        """
+        return cast(T1, super().__getitem__(index))
+
+
+class Kitti2012Stereo(StereoMatchingDataset):
+    """
+    KITTI dataset from the `2012 stereo evaluation benchmark <http://www.cvlibs.net/datasets/kitti/eval_stereo_flow.php>`_.
+    Uses the RGB images for consistency with KITTI 2015.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            Kitti2012
+                testing
+                    colored_0
+                        1_10.png
+                        2_10.png
+                        ...
+                    colored_1
+                        1_10.png
+                        2_10.png
+                        ...
+                training
+                    colored_0
+                        1_10.png
+                        2_10.png
+                        ...
+                    colored_1
+                        1_10.png
+                        2_10.png
+                        ...
+                    disp_noc
+                        1.png
+                        2.png
+                        ...
+                    calib
+
+    Args:
+        root (string): Root directory where `Kitti2012` is located.
+        split (string, optional): The dataset split of scenes, either "train" (default) or "test".
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+    """
+
+    _has_built_in_disparity_mask = True
+
+    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
+        super().__init__(root, transforms)
+
+        verify_str_arg(split, "split", valid_values=("train", "test"))
+
+        root = Path(root) / "Kitti2012" / (split + "ing")
+
+        left_img_pattern = str(root / "colored_0" / "*_10.png")
+        right_img_pattern = str(root / "colored_1" / "*_10.png")
+        self._images = self._scan_pairs(left_img_pattern, right_img_pattern)
+
+        if split == "train":
+            disparity_pattern = str(root / "disp_noc" / "*.png")
+            self._disparities = self._scan_pairs(disparity_pattern, None)
+        else:
+            self._disparities = list((None, None) for _ in self._images)
+
+    def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], None]:
+        # test split has no disparity maps
+        if file_path is None:
+            return None, None
+
+        disparity_map = np.asarray(Image.open(file_path)) / 256.0
+        # unsqueeze the disparity map into (C, H, W) format
+        disparity_map = disparity_map[None, :, :]
+        valid_mask = None
+        return disparity_map, valid_mask
+
+    def __getitem__(self, index: int) -> T1:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
+            generate a valid mask.
+            Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
+        """
+        return cast(T1, super().__getitem__(index))
+
+
+class Kitti2015Stereo(StereoMatchingDataset):
+    """
+    KITTI dataset from the `2015 stereo evaluation benchmark <http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php>`_.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            Kitti2015
+                testing
+                    image_2
+                        img1.png
+                        img2.png
+                        ...
+                    image_3
+                        img1.png
+                        img2.png
+                        ...
+                training
+                    image_2
+                        img1.png
+                        img2.png
+                        ...
+                    image_3
+                        img1.png
+                        img2.png
+                        ...
+                    disp_occ_0
+                        img1.png
+                        img2.png
+                        ...
+                    disp_occ_1
+                        img1.png
+                        img2.png
+                        ...
+                    calib
+
+    Args:
+        root (string): Root directory where `Kitti2015` is located.
+        split (string, optional): The dataset split of scenes, either "train" (default) or "test".
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+    """
+
+    _has_built_in_disparity_mask = True
+
+    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
+        super().__init__(root, transforms)
+
+        verify_str_arg(split, "split", valid_values=("train", "test"))
+
+        root = Path(root) / "Kitti2015" / (split + "ing")
+        left_img_pattern = str(root / "image_2" / "*.png")
+        right_img_pattern = str(root / "image_3" / "*.png")
+        self._images = self._scan_pairs(left_img_pattern, right_img_pattern)
+
+        if split == "train":
+            left_disparity_pattern = str(root / "disp_occ_0" / "*.png")
+            right_disparity_pattern = str(root / "disp_occ_1" / "*.png")
+            self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+        else:
+            self._disparities = list((None, None) for _ in self._images)
+
+    def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], None]:
+        # test split has no disparity maps
+        if file_path is None:
+            return None, None
+
+        disparity_map = np.asarray(Image.open(file_path)) / 256.0
+        # unsqueeze the disparity map into (C, H, W) format
+        disparity_map = disparity_map[None, :, :]
+        valid_mask = None
+        return disparity_map, valid_mask
+
+    def __getitem__(self, index: int) -> T1:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
+            generate a valid mask.
+            Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
+        """
+        return cast(T1, super().__getitem__(index))
+
+
+class Middlebury2014Stereo(StereoMatchingDataset):
+    """Publicly available scenes from the Middlebury dataset `2014 version <https://vision.middlebury.edu/stereo/data/scenes2014/>`.
+
+    The dataset mostly follows the original format, without containing the ambient subdirectories.  : ::
+
+        root
+            Middlebury2014
+                train
+                    scene1-{perfect,imperfect}
+                        calib.txt
+                        im{0,1}.png
+                        im1E.png
+                        im1L.png
+                        disp{0,1}.pfm
+                        disp{0,1}-n.png
+                        disp{0,1}-sd.pfm
+                        disp{0,1}y.pfm
+                    scene2-{perfect,imperfect}
+                        calib.txt
+                        im{0,1}.png
+                        im1E.png
+                        im1L.png
+                        disp{0,1}.pfm
+                        disp{0,1}-n.png
+                        disp{0,1}-sd.pfm
+                        disp{0,1}y.pfm
+                    ...
+                additional
+                    scene1-{perfect,imperfect}
+                        calib.txt
+                        im{0,1}.png
+                        im1E.png
+                        im1L.png
+                        disp{0,1}.pfm
+                        disp{0,1}-n.png
+                        disp{0,1}-sd.pfm
+                        disp{0,1}y.pfm
+                    ...
+                test
+                    scene1
+                        calib.txt
+                        im{0,1}.png
+                    scene2
+                        calib.txt
+                        im{0,1}.png
+                    ...
+
+    Args:
+        root (string): Root directory of the Middleburry 2014 Dataset.
+        split (string, optional): The dataset split of scenes, either "train" (default), "test", or "additional"
+        use_ambient_views (boolean, optional): Whether to use different expose or lightning views when possible.
+            The dataset samples with equal probability between ``[im1.png, im1E.png, im1L.png]``.
+        calibration (string, optional): Whether or not to use the calibrated (default) or uncalibrated scenes.
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+        download (boolean, optional): Whether or not to download the dataset in the ``root`` directory.
+    """
+
+    splits = {
+        "train": [
+            "Adirondack",
+            "Jadeplant",
+            "Motorcycle",
+            "Piano",
+            "Pipes",
+            "Playroom",
+            "Playtable",
+            "Recycle",
+            "Shelves",
+            "Vintage",
+        ],
+        "additional": [
+            "Backpack",
+            "Bicycle1",
+            "Cable",
+            "Classroom1",
+            "Couch",
+            "Flowers",
+            "Mask",
+            "Shopvac",
+            "Sticks",
+            "Storage",
+            "Sword1",
+            "Sword2",
+            "Umbrella",
+        ],
+        "test": [
+            "Plants",
+            "Classroom2E",
+            "Classroom2",
+            "Australia",
+            "DjembeL",
+            "CrusadeP",
+            "Crusade",
+            "Hoops",
+            "Bicycle2",
+            "Staircase",
+            "Newkuba",
+            "AustraliaP",
+            "Djembe",
+            "Livingroom",
+            "Computer",
+        ],
+    }
+
+    _has_built_in_disparity_mask = True
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        calibration: Optional[str] = "perfect",
+        use_ambient_views: bool = False,
+        transforms: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transforms)
+
+        verify_str_arg(split, "split", valid_values=("train", "test", "additional"))
+        self.split = split
+
+        if calibration:
+            verify_str_arg(calibration, "calibration", valid_values=("perfect", "imperfect", "both", None))  # type: ignore
+            if split == "test":
+                raise ValueError("Split 'test' has only no calibration settings, please set `calibration=None`.")
+        else:
+            if split != "test":
+                raise ValueError(
+                    f"Split '{split}' has calibration settings, however None was provided as an argument."
+                    f"\nSetting calibration to 'perfect' for split '{split}'. Available calibration settings are: 'perfect', 'imperfect', 'both'.",
+                )
+
+        if download:
+            self._download_dataset(root)
+
+        root = Path(root) / "Middlebury2014"
+
+        if not os.path.exists(root / split):
+            raise FileNotFoundError(f"The {split} directory was not found in the provided root directory")
+
+        split_scenes = self.splits[split]
+        # check that the provided root folder contains the scene splits
+        if not any(
+            # using startswith to account for perfect / imperfect calibrartion
+            scene.startswith(s)
+            for scene in os.listdir(root / split)
+            for s in split_scenes
+        ):
+            raise FileNotFoundError(f"Provided root folder does not contain any scenes from the {split} split.")
+
+        calibrartion_suffixes = {
+            None: [""],
+            "perfect": ["-perfect"],
+            "imperfect": ["-imperfect"],
+            "both": ["-perfect", "-imperfect"],
+        }[calibration]
+
+        for calibration_suffix in calibrartion_suffixes:
+            scene_pattern = "*" + calibration_suffix
+            left_img_pattern = str(root / split / scene_pattern / "im0.png")
+            right_img_pattern = str(root / split / scene_pattern / "im1.png")
+            self._images += self._scan_pairs(left_img_pattern, right_img_pattern)
+
+            if split == "test":
+                self._disparities = list((None, None) for _ in self._images)
+            else:
+                left_dispartity_pattern = str(root / split / scene_pattern / "disp0.pfm")
+                right_dispartity_pattern = str(root / split / scene_pattern / "disp1.pfm")
+                self._disparities += self._scan_pairs(left_dispartity_pattern, right_dispartity_pattern)
+
+        self.use_ambient_views = use_ambient_views
+
+    def _read_img(self, file_path: Union[str, Path]) -> Image.Image:
+        """
+        Function that reads either the original right image or an augmented view when ``use_ambient_views`` is True.
+        When ``use_ambient_views`` is True, the dataset will return at random one of ``[im1.png, im1E.png, im1L.png]``
+        as the right image.
+        """
+        ambient_file_paths: List[Union[str, Path]]  # make mypy happy
+
+        if not isinstance(file_path, Path):
+            file_path = Path(file_path)
+
+        if file_path.name == "im1.png" and self.use_ambient_views:
+            base_path = file_path.parent
+            # initialize sampleable container
+            ambient_file_paths = list(base_path / view_name for view_name in ["im1E.png", "im1L.png"])
+            # double check that we're not going to try to read from an invalid file path
+            ambient_file_paths = list(filter(lambda p: os.path.exists(p), ambient_file_paths))
+            # keep the original image as an option as well for uniform sampling between base views
+            ambient_file_paths.append(file_path)
+            file_path = random.choice(ambient_file_paths)  # type: ignore
+        return super()._read_img(file_path)
+
+    def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
+        # test split has not disparity maps
+        if file_path is None:
+            return None, None
+
+        disparity_map = _read_pfm_file(file_path)
+        disparity_map = np.abs(disparity_map)  # ensure that the disparity is positive
+        disparity_map[disparity_map == np.inf] = 0  # remove infinite disparities
+        valid_mask = (disparity_map > 0).squeeze(0)  # mask out invalid disparities
+        return disparity_map, valid_mask
+
+    def _download_dataset(self, root: str) -> None:
+        base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip"
+        # train and additional splits have 2 different calibration settings
+        root = Path(root) / "Middlebury2014"
+        split_name = self.split
+
+        if split_name != "test":
+            for split_scene in self.splits[split_name]:
+                split_root = root / split_name
+                for calibration in ["perfect", "imperfect"]:
+                    scene_name = f"{split_scene}-{calibration}"
+                    scene_url = f"{base_url}/{scene_name}.zip"
+                    print(f"Downloading {scene_url}")
+                    # download the scene only if it doesn't exist
+                    if not (split_root / scene_name).exists():
+                        download_and_extract_archive(
+                            url=scene_url,
+                            filename=f"{scene_name}.zip",
+                            download_root=str(split_root),
+                            remove_finished=True,
+                        )
+        else:
+            os.makedirs(root / "test")
+            if any(s not in os.listdir(root / "test") for s in self.splits["test"]):
+                # test split is downloaded from a different location
+                test_set_url = "https://vision.middlebury.edu/stereo/submit3/zip/MiddEval3-data-F.zip"
+                # the unzip is going to produce a directory MiddEval3 with two subdirectories trainingF and testF
+                # we want to move the contents from testF into the  directory
+                download_and_extract_archive(url=test_set_url, download_root=str(root), remove_finished=True)
+                for scene_dir, scene_names, _ in os.walk(str(root / "MiddEval3/testF")):
+                    for scene in scene_names:
+                        scene_dst_dir = root / "test"
+                        scene_src_dir = Path(scene_dir) / scene
+                        os.makedirs(scene_dst_dir, exist_ok=True)
+                        shutil.move(str(scene_src_dir), str(scene_dst_dir))
+
+                # cleanup MiddEval3 directory
+                shutil.rmtree(str(root / "MiddEval3"))
+
+    def __getitem__(self, index: int) -> T2:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            ``valid_mask`` is implicitly ``None`` for `split=test`.
+        """
+        return cast(T2, super().__getitem__(index))
+
+
+class CREStereo(StereoMatchingDataset):
+    """Synthetic dataset used in training the `CREStereo <https://arxiv.org/pdf/2203.11483.pdf>`_ architecture.
+    Dataset details on the official paper `repo <https://github.com/megvii-research/CREStereo>`_.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            CREStereo
+                tree
+                    img1_left.jpg
+                    img1_right.jpg
+                    img1_left.disp.jpg
+                    img1_right.disp.jpg
+                    img2_left.jpg
+                    img2_right.jpg
+                    img2_left.disp.jpg
+                    img2_right.disp.jpg
+                    ...
+                shapenet
+                    img1_left.jpg
+                    img1_right.jpg
+                    img1_left.disp.jpg
+                    img1_right.disp.jpg
+                    ...
+                reflective
+                    img1_left.jpg
+                    img1_right.jpg
+                    img1_left.disp.jpg
+                    img1_right.disp.jpg
+                    ...
+                hole
+                    img1_left.jpg
+                    img1_right.jpg
+                    img1_left.disp.jpg
+                    img1_right.disp.jpg
+                    ...
+
+    Args:
+        root (str): Root directory of the dataset.
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+    """
+
+    _has_built_in_disparity_mask = True
+
+    def __init__(
+        self,
+        root: str,
+        transforms: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transforms)
+
+        root = Path(root) / "CREStereo"
+
+        dirs = ["shapenet", "reflective", "tree", "hole"]
+
+        for s in dirs:
+            left_image_pattern = str(root / s / "*_left.jpg")
+            right_image_pattern = str(root / s / "*_right.jpg")
+            imgs = self._scan_pairs(left_image_pattern, right_image_pattern)
+            self._images += imgs
+
+            left_disparity_pattern = str(root / s / "*_left.disp.png")
+            right_disparity_pattern = str(root / s / "*_right.disp.png")
+            disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+            self._disparities += disparities
+
+    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
+        disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
+        # unsqueeze the disparity map into (C, H, W) format
+        disparity_map = disparity_map[None, :, :] / 32.0
+        valid_mask = None
+        return disparity_map, valid_mask
+
+    def __getitem__(self, index: int) -> T1:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
+            generate a valid mask.
+        """
+        return cast(T1, super().__getitem__(index))
+
+
+class FallingThingsStereo(StereoMatchingDataset):
+    """`FallingThings <https://research.nvidia.com/publication/2018-06_falling-things-synthetic-dataset-3d-object-detection-and-pose-estimation>`_ dataset.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            FallingThings
+                single
+                    dir1
+                        scene1
+                            _object_settings.json
+                            _camera_settings.json
+                            image1.left.depth.png
+                            image1.right.depth.png
+                            image1.left.jpg
+                            image1.right.jpg
+                            image2.left.depth.png
+                            image2.right.depth.png
+                            image2.left.jpg
+                            image2.right
+                            ...
+                        scene2
+                    ...
+                mixed
+                    scene1
+                        _object_settings.json
+                        _camera_settings.json
+                        image1.left.depth.png
+                        image1.right.depth.png
+                        image1.left.jpg
+                        image1.right.jpg
+                        image2.left.depth.png
+                        image2.right.depth.png
+                        image2.left.jpg
+                        image2.right
+                        ...
+                    scene2
+                    ...
+
+    Args:
+        root (string): Root directory where FallingThings is located.
+        variant (string): Which variant to use. Either "single", "mixed", or "both".
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+    """
+
+    def __init__(self, root: str, variant: str = "single", transforms: Optional[Callable] = None) -> None:
+        super().__init__(root, transforms)
+
+        root = Path(root) / "FallingThings"
+
+        verify_str_arg(variant, "variant", valid_values=("single", "mixed", "both"))
+
+        variants = {
+            "single": ["single"],
+            "mixed": ["mixed"],
+            "both": ["single", "mixed"],
+        }[variant]
+
+        split_prefix = {
+            "single": Path("*") / "*",
+            "mixed": Path("*"),
+        }
+
+        for s in variants:
+            left_img_pattern = str(root / s / split_prefix[s] / "*.left.jpg")
+            right_img_pattern = str(root / s / split_prefix[s] / "*.right.jpg")
+            self._images += self._scan_pairs(left_img_pattern, right_img_pattern)
+
+            left_disparity_pattern = str(root / s / split_prefix[s] / "*.left.depth.png")
+            right_disparity_pattern = str(root / s / split_prefix[s] / "*.right.depth.png")
+            self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+
+    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
+        # (H, W) image
+        depth = np.asarray(Image.open(file_path))
+        # as per https://research.nvidia.com/sites/default/files/pubs/2018-06_Falling-Things/readme_0.txt
+        # in order to extract disparity from depth maps
+        camera_settings_path = Path(file_path).parent / "_camera_settings.json"
+        with open(camera_settings_path, "r") as f:
+            # inverse of depth-from-disparity equation: depth = (baseline * focal) / (disparity * pixel_constant)
+            intrinsics = json.load(f)
+            focal = intrinsics["camera_settings"][0]["intrinsic_settings"]["fx"]
+            baseline, pixel_constant = 6, 100  # pixel constant is inverted
+            disparity_map = (baseline * focal * pixel_constant) / depth.astype(np.float32)
+            # unsqueeze disparity to (C, H, W)
+            disparity_map = disparity_map[None, :, :]
+            valid_mask = None
+            return disparity_map, valid_mask
+
+    def __getitem__(self, index: int) -> T1:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            If a ``valid_mask`` is generated within the ``transforms`` parameter,
+            a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
+        """
+        return cast(T1, super().__getitem__(index))
+
+
+class SceneFlowStereo(StereoMatchingDataset):
+    """Dataset interface for `Scene Flow <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>`_ datasets.
+    This interface provides access to the `FlyingThings3D, `Monkaa` and `Driving` datasets.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            SceneFlow
+                Monkaa
+                    frames_cleanpass
+                        scene1
+                            left
+                                img1.png
+                                img2.png
+                            right
+                                img1.png
+                                img2.png
+                        scene2
+                            left
+                                img1.png
+                                img2.png
+                            right
+                                img1.png
+                                img2.png
+                    frames_finalpass
+                        scene1
+                            left
+                                img1.png
+                                img2.png
+                            right
+                                img1.png
+                                img2.png
+                        ...
+                        ...
+                    disparity
+                        scene1
+                            left
+                                img1.pfm
+                                img2.pfm
+                            right
+                                img1.pfm
+                                img2.pfm
+                FlyingThings3D
+                    ...
+                    ...
+
+    Args:
+        root (string): Root directory where SceneFlow is located.
+        variant (string): Which dataset variant to user, "FlyingThings3D" (default), "Monkaa" or "Driving".
+        pass_name (string): Which pass to use, "clean" (default), "final" or "both".
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+
+    """
+
+    def __init__(
+        self,
+        root: str,
+        variant: str = "FlyingThings3D",
+        pass_name: str = "clean",
+        transforms: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transforms)
+
+        root = Path(root) / "SceneFlow"
+
+        verify_str_arg(variant, "variant", valid_values=("FlyingThings3D", "Driving", "Monkaa"))
+        verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
+
+        passes = {
+            "clean": ["frames_cleanpass"],
+            "final": ["frames_finalpass"],
+            "both": ["frames_cleanpass", "frames_finalpass"],
+        }[pass_name]
+
+        root = root / variant
+
+        prefix_directories = {
+            "Monkaa": Path("*"),
+            "FlyingThings3D": Path("*") / "*" / "*",
+            "Driving": Path("*") / "*" / "*",
+        }
+
+        for p in passes:
+            left_image_pattern = str(root / p / prefix_directories[variant] / "left" / "*.png")
+            right_image_pattern = str(root / p / prefix_directories[variant] / "right" / "*.png")
+            self._images += self._scan_pairs(left_image_pattern, right_image_pattern)
+
+            left_disparity_pattern = str(root / "disparity" / prefix_directories[variant] / "left" / "*.pfm")
+            right_disparity_pattern = str(root / "disparity" / prefix_directories[variant] / "right" / "*.pfm")
+            self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+
+    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
+        disparity_map = _read_pfm_file(file_path)
+        disparity_map = np.abs(disparity_map)  # ensure that the disparity is positive
+        valid_mask = None
+        return disparity_map, valid_mask
+
+    def __getitem__(self, index: int) -> T1:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            If a ``valid_mask`` is generated within the ``transforms`` parameter,
+            a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
+        """
+        return cast(T1, super().__getitem__(index))
+
+
+class SintelStereo(StereoMatchingDataset):
+    """Sintel `Stereo Dataset <http://sintel.is.tue.mpg.de/stereo>`_.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            Sintel
+                training
+                    final_left
+                        scene1
+                            img1.png
+                            img2.png
+                            ...
+                        ...
+                    final_right
+                        scene2
+                            img1.png
+                            img2.png
+                            ...
+                        ...
+                    disparities
+                        scene1
+                            img1.png
+                            img2.png
+                            ...
+                        ...
+                    occlusions
+                        scene1
+                            img1.png
+                            img2.png
+                            ...
+                        ...
+                    outofframe
+                        scene1
+                            img1.png
+                            img2.png
+                            ...
+                        ...
+
+    Args:
+        root (string): Root directory where Sintel Stereo is located.
+        pass_name (string): The name of the pass to use, either "final", "clean" or "both".
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+    """
+
+    _has_built_in_disparity_mask = True
+
+    def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None) -> None:
+        super().__init__(root, transforms)
+
+        verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both"))
+
+        root = Path(root) / "Sintel"
+        pass_names = {
+            "final": ["final"],
+            "clean": ["clean"],
+            "both": ["final", "clean"],
+        }[pass_name]
+
+        for p in pass_names:
+            left_img_pattern = str(root / "training" / f"{p}_left" / "*" / "*.png")
+            right_img_pattern = str(root / "training" / f"{p}_right" / "*" / "*.png")
+            self._images += self._scan_pairs(left_img_pattern, right_img_pattern)
+
+            disparity_pattern = str(root / "training" / "disparities" / "*" / "*.png")
+            self._disparities += self._scan_pairs(disparity_pattern, None)
+
+    def _get_occlussion_mask_paths(self, file_path: str) -> Tuple[str, str]:
+        # helper function to get the occlusion mask paths
+        # a path will look like  .../.../.../training/disparities/scene1/img1.png
+        # we want to get something like .../.../.../training/occlusions/scene1/img1.png
+        fpath = Path(file_path)
+        basename = fpath.name
+        scenedir = fpath.parent
+        # the parent of the scenedir is actually the disparity dir
+        sampledir = scenedir.parent.parent
+
+        occlusion_path = str(sampledir / "occlusions" / scenedir.name / basename)
+        outofframe_path = str(sampledir / "outofframe" / scenedir.name / basename)
+
+        if not os.path.exists(occlusion_path):
+            raise FileNotFoundError(f"Occlusion mask {occlusion_path} does not exist")
+
+        if not os.path.exists(outofframe_path):
+            raise FileNotFoundError(f"Out of frame mask {outofframe_path} does not exist")
+
+        return occlusion_path, outofframe_path
+
+    def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
+        if file_path is None:
+            return None, None
+
+        # disparity decoding as per Sintel instructions in the README provided with the dataset
+        disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
+        r, g, b = np.split(disparity_map, 3, axis=-1)
+        disparity_map = r * 4 + g / (2**6) + b / (2**14)
+        # reshape into (C, H, W) format
+        disparity_map = np.transpose(disparity_map, (2, 0, 1))
+        # find the appropriate file paths
+        occlued_mask_path, out_of_frame_mask_path = self._get_occlussion_mask_paths(file_path)
+        # occlusion masks
+        valid_mask = np.asarray(Image.open(occlued_mask_path)) == 0
+        # out of frame masks
+        off_mask = np.asarray(Image.open(out_of_frame_mask_path)) == 0
+        # combine the masks together
+        valid_mask = np.logical_and(off_mask, valid_mask)
+        return disparity_map, valid_mask
+
+    def __getitem__(self, index: int) -> T2:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images whilst
+            the valid_mask is a numpy array of shape (H, W).
+        """
+        return cast(T2, super().__getitem__(index))
+
+
+class InStereo2k(StereoMatchingDataset):
+    """`InStereo2k <https://github.com/YuhuaXu/StereoDataset>`_ dataset.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            InStereo2k
+                train
+                    scene1
+                        left.png
+                        right.png
+                        left_disp.png
+                        right_disp.png
+                        ...
+                    scene2
+                    ...
+                test
+                    scene1
+                        left.png
+                        right.png
+                        left_disp.png
+                        right_disp.png
+                        ...
+                    scene2
+                    ...
+
+    Args:
+        root (string): Root directory where InStereo2k is located.
+        split (string): Either "train" or "test".
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+    """
+
+    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
+        super().__init__(root, transforms)
+
+        root = Path(root) / "InStereo2k" / split
+
+        verify_str_arg(split, "split", valid_values=("train", "test"))
+
+        left_img_pattern = str(root / "*" / "left.png")
+        right_img_pattern = str(root / "*" / "right.png")
+        self._images = self._scan_pairs(left_img_pattern, right_img_pattern)
+
+        left_disparity_pattern = str(root / "*" / "left_disp.png")
+        right_disparity_pattern = str(root / "*" / "right_disp.png")
+        self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+
+    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
+        disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
+        # unsqueeze disparity to (C, H, W)
+        disparity_map = disparity_map[None, :, :] / 1024.0
+        valid_mask = None
+        return disparity_map, valid_mask
+
+    def __getitem__(self, index: int) -> T1:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            If a ``valid_mask`` is generated within the ``transforms`` parameter,
+            a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
+        """
+        return cast(T1, super().__getitem__(index))
+
+
+class ETH3DStereo(StereoMatchingDataset):
+    """ETH3D `Low-Res Two-View <https://www.eth3d.net/datasets>`_ dataset.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            ETH3D
+                two_view_training
+                    scene1
+                        im1.png
+                        im0.png
+                        images.txt
+                        cameras.txt
+                        calib.txt
+                    scene2
+                        im1.png
+                        im0.png
+                        images.txt
+                        cameras.txt
+                        calib.txt
+                    ...
+                two_view_training_gt
+                    scene1
+                        disp0GT.pfm
+                        mask0nocc.png
+                    scene2
+                        disp0GT.pfm
+                        mask0nocc.png
+                    ...
+                two_view_testing
+                    scene1
+                        im1.png
+                        im0.png
+                        images.txt
+                        cameras.txt
+                        calib.txt
+                    scene2
+                        im1.png
+                        im0.png
+                        images.txt
+                        cameras.txt
+                        calib.txt
+                    ...
+
+    Args:
+        root (string): Root directory of the ETH3D Dataset.
+        split (string, optional): The dataset split of scenes, either "train" (default) or "test".
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+    """
+
+    _has_built_in_disparity_mask = True
+
+    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
+        super().__init__(root, transforms)
+
+        verify_str_arg(split, "split", valid_values=("train", "test"))
+
+        root = Path(root) / "ETH3D"
+
+        img_dir = "two_view_training" if split == "train" else "two_view_test"
+        anot_dir = "two_view_training_gt"
+
+        left_img_pattern = str(root / img_dir / "*" / "im0.png")
+        right_img_pattern = str(root / img_dir / "*" / "im1.png")
+        self._images = self._scan_pairs(left_img_pattern, right_img_pattern)
+
+        if split == "test":
+            self._disparities = list((None, None) for _ in self._images)
+        else:
+            disparity_pattern = str(root / anot_dir / "*" / "disp0GT.pfm")
+            self._disparities = self._scan_pairs(disparity_pattern, None)
+
+    def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
+        # test split has no disparity maps
+        if file_path is None:
+            return None, None
+
+        disparity_map = _read_pfm_file(file_path)
+        disparity_map = np.abs(disparity_map)  # ensure that the disparity is positive
+        mask_path = Path(file_path).parent / "mask0nocc.png"
+        valid_mask = Image.open(mask_path)
+        valid_mask = np.asarray(valid_mask).astype(bool)
+        return disparity_map, valid_mask
+
+    def __getitem__(self, index: int) -> T2:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
+            generate a valid mask.
+            Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
+        """
+        return cast(T2, super().__getitem__(index))

+ 241 - 0
libs/vision_libs/datasets/caltech.py

@@ -0,0 +1,241 @@
+import os
+import os.path
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class Caltech101(VisionDataset):
+    """`Caltech 101 <https://data.caltech.edu/records/20086>`_ Dataset.
+
+    .. warning::
+
+        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
+
+    Args:
+        root (string): Root directory of dataset where directory
+            ``caltech101`` exists or will be saved to if download is set to True.
+        target_type (string or list, optional): Type of target to use, ``category`` or
+            ``annotation``. Can also be a list to output a tuple with all specified
+            target types.  ``category`` represents the target class, and
+            ``annotation`` is a list of points from a hand-generated outline.
+            Defaults to ``category``.
+        transform (callable, optional): A function/transform that takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+            .. warning::
+
+                To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        target_type: Union[List[str], str] = "category",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform)
+        os.makedirs(self.root, exist_ok=True)
+        if isinstance(target_type, str):
+            target_type = [target_type]
+        self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) for t in target_type]
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
+        self.categories.remove("BACKGROUND_Google")  # this is not a real class
+
+        # For some reason, the category names in "101_ObjectCategories" and
+        # "Annotations" do not always match. This is a manual map between the
+        # two. Defaults to using same name, since most names are fine.
+        name_map = {
+            "Faces": "Faces_2",
+            "Faces_easy": "Faces_3",
+            "Motorbikes": "Motorbikes_16",
+            "airplanes": "Airplanes_Side_2",
+        }
+        self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories))
+
+        self.index: List[int] = []
+        self.y = []
+        for (i, c) in enumerate(self.categories):
+            n = len(os.listdir(os.path.join(self.root, "101_ObjectCategories", c)))
+            self.index.extend(range(1, n + 1))
+            self.y.extend(n * [i])
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where the type of target specified by target_type.
+        """
+        import scipy.io
+
+        img = Image.open(
+            os.path.join(
+                self.root,
+                "101_ObjectCategories",
+                self.categories[self.y[index]],
+                f"image_{self.index[index]:04d}.jpg",
+            )
+        )
+
+        target: Any = []
+        for t in self.target_type:
+            if t == "category":
+                target.append(self.y[index])
+            elif t == "annotation":
+                data = scipy.io.loadmat(
+                    os.path.join(
+                        self.root,
+                        "Annotations",
+                        self.annotation_categories[self.y[index]],
+                        f"annotation_{self.index[index]:04d}.mat",
+                    )
+                )
+                target.append(data["obj_contour"])
+        target = tuple(target) if len(target) > 1 else target[0]
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def _check_integrity(self) -> bool:
+        # can be more robust and check hash of files
+        return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))
+
+    def __len__(self) -> int:
+        return len(self.index)
+
+    def download(self) -> None:
+        if self._check_integrity():
+            print("Files already downloaded and verified")
+            return
+
+        download_and_extract_archive(
+            "https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp",
+            self.root,
+            filename="101_ObjectCategories.tar.gz",
+            md5="b224c7392d521a49829488ab0f1120d9",
+        )
+        download_and_extract_archive(
+            "https://drive.google.com/file/d/175kQy3UsZ0wUEHZjqkUDdNVssr7bgh_m",
+            self.root,
+            filename="Annotations.tar",
+            md5="6f83eeb1f24d99cab4eb377263132c91",
+        )
+
+    def extra_repr(self) -> str:
+        return "Target type: {target_type}".format(**self.__dict__)
+
+
+class Caltech256(VisionDataset):
+    """`Caltech 256 <https://data.caltech.edu/records/20087>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where directory
+            ``caltech256`` exists or will be saved to if download is set to True.
+        transform (callable, optional): A function/transform that takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(os.path.join(root, "caltech256"), transform=transform, target_transform=target_transform)
+        os.makedirs(self.root, exist_ok=True)
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))
+        self.index: List[int] = []
+        self.y = []
+        for (i, c) in enumerate(self.categories):
+            n = len(
+                [
+                    item
+                    for item in os.listdir(os.path.join(self.root, "256_ObjectCategories", c))
+                    if item.endswith(".jpg")
+                ]
+            )
+            self.index.extend(range(1, n + 1))
+            self.y.extend(n * [i])
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target class.
+        """
+        img = Image.open(
+            os.path.join(
+                self.root,
+                "256_ObjectCategories",
+                self.categories[self.y[index]],
+                f"{self.y[index] + 1:03d}_{self.index[index]:04d}.jpg",
+            )
+        )
+
+        target = self.y[index]
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def _check_integrity(self) -> bool:
+        # can be more robust and check hash of files
+        return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))
+
+    def __len__(self) -> int:
+        return len(self.index)
+
+    def download(self) -> None:
+        if self._check_integrity():
+            print("Files already downloaded and verified")
+            return
+
+        download_and_extract_archive(
+            "https://drive.google.com/file/d/1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK",
+            self.root,
+            filename="256_ObjectCategories.tar",
+            md5="67b4f42ca05d46448c6bb8ecd2220f6d",
+        )

+ 193 - 0
libs/vision_libs/datasets/celeba.py

@@ -0,0 +1,193 @@
+import csv
+import os
+from collections import namedtuple
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import PIL
+import torch
+
+from .utils import check_integrity, download_file_from_google_drive, extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+CSV = namedtuple("CSV", ["header", "index", "data"])
+
+
+class CelebA(VisionDataset):
+    """`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.
+
+    Args:
+        root (string): Root directory where images are downloaded to.
+        split (string): One of {'train', 'valid', 'test', 'all'}.
+            Accordingly dataset is selected.
+        target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
+            or ``landmarks``. Can also be a list to output a tuple with all specified target types.
+            The targets represent:
+
+                - ``attr`` (Tensor shape=(40,) dtype=int): binary (0, 1) labels for attributes
+                - ``identity`` (int): label for each person (data points with the same identity are the same person)
+                - ``bbox`` (Tensor shape=(4,) dtype=int): bounding box (x, y, width, height)
+                - ``landmarks`` (Tensor shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
+                  righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)
+
+            Defaults to ``attr``. If empty, ``None`` will be returned as target.
+
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.PILToTensor``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+            .. warning::
+
+                To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
+    """
+
+    base_folder = "celeba"
+    # There currently does not appear to be an easy way to extract 7z in python (without introducing additional
+    # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
+    # right now.
+    file_list = [
+        # File ID                                      MD5 Hash                            Filename
+        ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
+        # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc","b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
+        # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
+        ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
+        ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
+        ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
+        ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
+        # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
+        ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
+    ]
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        target_type: Union[List[str], str] = "attr",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self.split = split
+        if isinstance(target_type, list):
+            self.target_type = target_type
+        else:
+            self.target_type = [target_type]
+
+        if not self.target_type and self.target_transform is not None:
+            raise RuntimeError("target_transform is specified but target_type is empty")
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        split_map = {
+            "train": 0,
+            "valid": 1,
+            "test": 2,
+            "all": None,
+        }
+        split_ = split_map[verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))]
+        splits = self._load_csv("list_eval_partition.txt")
+        identity = self._load_csv("identity_CelebA.txt")
+        bbox = self._load_csv("list_bbox_celeba.txt", header=1)
+        landmarks_align = self._load_csv("list_landmarks_align_celeba.txt", header=1)
+        attr = self._load_csv("list_attr_celeba.txt", header=1)
+
+        mask = slice(None) if split_ is None else (splits.data == split_).squeeze()
+
+        if mask == slice(None):  # if split == "all"
+            self.filename = splits.index
+        else:
+            self.filename = [splits.index[i] for i in torch.squeeze(torch.nonzero(mask))]
+        self.identity = identity.data[mask]
+        self.bbox = bbox.data[mask]
+        self.landmarks_align = landmarks_align.data[mask]
+        self.attr = attr.data[mask]
+        # map from {-1, 1} to {0, 1}
+        self.attr = torch.div(self.attr + 1, 2, rounding_mode="floor")
+        self.attr_names = attr.header
+
+    def _load_csv(
+        self,
+        filename: str,
+        header: Optional[int] = None,
+    ) -> CSV:
+        with open(os.path.join(self.root, self.base_folder, filename)) as csv_file:
+            data = list(csv.reader(csv_file, delimiter=" ", skipinitialspace=True))
+
+        if header is not None:
+            headers = data[header]
+            data = data[header + 1 :]
+        else:
+            headers = []
+
+        indices = [row[0] for row in data]
+        data = [row[1:] for row in data]
+        data_int = [list(map(int, i)) for i in data]
+
+        return CSV(headers, indices, torch.tensor(data_int))
+
+    def _check_integrity(self) -> bool:
+        for (_, md5, filename) in self.file_list:
+            fpath = os.path.join(self.root, self.base_folder, filename)
+            _, ext = os.path.splitext(filename)
+            # Allow original archive to be deleted (zip and 7z)
+            # Only need the extracted images
+            if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
+                return False
+
+        # Should check a hash of the images
+        return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))
+
+    def download(self) -> None:
+        if self._check_integrity():
+            print("Files already downloaded and verified")
+            return
+
+        for (file_id, md5, filename) in self.file_list:
+            download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)
+
+        extract_archive(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"))
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))
+
+        target: Any = []
+        for t in self.target_type:
+            if t == "attr":
+                target.append(self.attr[index, :])
+            elif t == "identity":
+                target.append(self.identity[index, 0])
+            elif t == "bbox":
+                target.append(self.bbox[index, :])
+            elif t == "landmarks":
+                target.append(self.landmarks_align[index, :])
+            else:
+                # TODO: refactor with utils.verify_str_arg
+                raise ValueError(f'Target type "{t}" is not recognized.')
+
+        if self.transform is not None:
+            X = self.transform(X)
+
+        if target:
+            target = tuple(target) if len(target) > 1 else target[0]
+
+            if self.target_transform is not None:
+                target = self.target_transform(target)
+        else:
+            target = None
+
+        return X, target
+
+    def __len__(self) -> int:
+        return len(self.attr)
+
+    def extra_repr(self) -> str:
+        lines = ["Target type: {target_type}", "Split: {split}"]
+        return "\n".join(lines).format(**self.__dict__)

+ 167 - 0
libs/vision_libs/datasets/cifar.py

@@ -0,0 +1,167 @@
+import os.path
+import pickle
+from typing import Any, Callable, Optional, Tuple
+
+import numpy as np
+from PIL import Image
+
+from .utils import check_integrity, download_and_extract_archive
+from .vision import VisionDataset
+
+
+class CIFAR10(VisionDataset):
+    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where directory
+            ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
+        train (bool, optional): If True, creates dataset from training set, otherwise
+            creates from test set.
+        transform (callable, optional): A function/transform that takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+    """
+
+    base_folder = "cifar-10-batches-py"
+    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
+    filename = "cifar-10-python.tar.gz"
+    tgz_md5 = "c58f30108f718f92721af3b95e74349a"
+    train_list = [
+        ["data_batch_1", "c99cafc152244af753f735de768cd75f"],
+        ["data_batch_2", "d4bba439e000b95fd0a9bffe97cbabec"],
+        ["data_batch_3", "54ebc095f3ab1f0389bbae665268c751"],
+        ["data_batch_4", "634d18415352ddfa80567beed471001a"],
+        ["data_batch_5", "482c414d41f54cd18b22e5b47cb7c3cb"],
+    ]
+
+    test_list = [
+        ["test_batch", "40351d587109b95175f43aff81a1287e"],
+    ]
+    meta = {
+        "filename": "batches.meta",
+        "key": "label_names",
+        "md5": "5ff9c542aee3614f3951f8cda6e48888",
+    }
+
+    def __init__(
+        self,
+        root: str,
+        train: bool = True,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        self.train = train  # training set or test set
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        if self.train:
+            downloaded_list = self.train_list
+        else:
+            downloaded_list = self.test_list
+
+        self.data: Any = []
+        self.targets = []
+
+        # now load the picked numpy arrays
+        for file_name, checksum in downloaded_list:
+            file_path = os.path.join(self.root, self.base_folder, file_name)
+            with open(file_path, "rb") as f:
+                entry = pickle.load(f, encoding="latin1")
+                self.data.append(entry["data"])
+                if "labels" in entry:
+                    self.targets.extend(entry["labels"])
+                else:
+                    self.targets.extend(entry["fine_labels"])
+
+        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
+        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
+
+        self._load_meta()
+
+    def _load_meta(self) -> None:
+        path = os.path.join(self.root, self.base_folder, self.meta["filename"])
+        if not check_integrity(path, self.meta["md5"]):
+            raise RuntimeError("Dataset metadata file not found or corrupted. You can use download=True to download it")
+        with open(path, "rb") as infile:
+            data = pickle.load(infile, encoding="latin1")
+            self.classes = data[self.meta["key"]]
+        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target class.
+        """
+        img, target = self.data[index], self.targets[index]
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(img)
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def _check_integrity(self) -> bool:
+        for filename, md5 in self.train_list + self.test_list:
+            fpath = os.path.join(self.root, self.base_folder, filename)
+            if not check_integrity(fpath, md5):
+                return False
+        return True
+
+    def download(self) -> None:
+        if self._check_integrity():
+            print("Files already downloaded and verified")
+            return
+        download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
+
+    def extra_repr(self) -> str:
+        split = "Train" if self.train is True else "Test"
+        return f"Split: {split}"
+
+
+class CIFAR100(CIFAR10):
+    """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
+
+    This is a subclass of the `CIFAR10` Dataset.
+    """
+
+    base_folder = "cifar-100-python"
+    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
+    filename = "cifar-100-python.tar.gz"
+    tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85"
+    train_list = [
+        ["train", "16019d7e3df5f24257cddd939b257f8d"],
+    ]
+
+    test_list = [
+        ["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"],
+    ]
+    meta = {
+        "filename": "meta",
+        "key": "fine_label_names",
+        "md5": "7973b15100ade9c7d40fb424638fde48",
+    }

+ 221 - 0
libs/vision_libs/datasets/cityscapes.py

@@ -0,0 +1,221 @@
+import json
+import os
+from collections import namedtuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+from PIL import Image
+
+from .utils import extract_archive, iterable_to_str, verify_str_arg
+from .vision import VisionDataset
+
+
+class Cityscapes(VisionDataset):
+    """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where directory ``leftImg8bit``
+            and ``gtFine`` or ``gtCoarse`` are located.
+        split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
+            otherwise ``train``, ``train_extra`` or ``val``
+        mode (string, optional): The quality mode to use, ``fine`` or ``coarse``
+        target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
+            or ``color``. Can also be a list to output a tuple with all specified target types.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        transforms (callable, optional): A function/transform that takes input sample and its target as entry
+            and returns a transformed version.
+
+    Examples:
+
+        Get semantic segmentation target
+
+        .. code-block:: python
+
+            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
+                                 target_type='semantic')
+
+            img, smnt = dataset[0]
+
+        Get multiple targets
+
+        .. code-block:: python
+
+            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
+                                 target_type=['instance', 'color', 'polygon'])
+
+            img, (inst, col, poly) = dataset[0]
+
+        Validate on the "coarse" set
+
+        .. code-block:: python
+
+            dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
+                                 target_type='semantic')
+
+            img, smnt = dataset[0]
+    """
+
+    # Based on https://github.com/mcordts/cityscapesScripts
+    CityscapesClass = namedtuple(
+        "CityscapesClass",
+        ["name", "id", "train_id", "category", "category_id", "has_instances", "ignore_in_eval", "color"],
+    )
+
+    classes = [
+        CityscapesClass("unlabeled", 0, 255, "void", 0, False, True, (0, 0, 0)),
+        CityscapesClass("ego vehicle", 1, 255, "void", 0, False, True, (0, 0, 0)),
+        CityscapesClass("rectification border", 2, 255, "void", 0, False, True, (0, 0, 0)),
+        CityscapesClass("out of roi", 3, 255, "void", 0, False, True, (0, 0, 0)),
+        CityscapesClass("static", 4, 255, "void", 0, False, True, (0, 0, 0)),
+        CityscapesClass("dynamic", 5, 255, "void", 0, False, True, (111, 74, 0)),
+        CityscapesClass("ground", 6, 255, "void", 0, False, True, (81, 0, 81)),
+        CityscapesClass("road", 7, 0, "flat", 1, False, False, (128, 64, 128)),
+        CityscapesClass("sidewalk", 8, 1, "flat", 1, False, False, (244, 35, 232)),
+        CityscapesClass("parking", 9, 255, "flat", 1, False, True, (250, 170, 160)),
+        CityscapesClass("rail track", 10, 255, "flat", 1, False, True, (230, 150, 140)),
+        CityscapesClass("building", 11, 2, "construction", 2, False, False, (70, 70, 70)),
+        CityscapesClass("wall", 12, 3, "construction", 2, False, False, (102, 102, 156)),
+        CityscapesClass("fence", 13, 4, "construction", 2, False, False, (190, 153, 153)),
+        CityscapesClass("guard rail", 14, 255, "construction", 2, False, True, (180, 165, 180)),
+        CityscapesClass("bridge", 15, 255, "construction", 2, False, True, (150, 100, 100)),
+        CityscapesClass("tunnel", 16, 255, "construction", 2, False, True, (150, 120, 90)),
+        CityscapesClass("pole", 17, 5, "object", 3, False, False, (153, 153, 153)),
+        CityscapesClass("polegroup", 18, 255, "object", 3, False, True, (153, 153, 153)),
+        CityscapesClass("traffic light", 19, 6, "object", 3, False, False, (250, 170, 30)),
+        CityscapesClass("traffic sign", 20, 7, "object", 3, False, False, (220, 220, 0)),
+        CityscapesClass("vegetation", 21, 8, "nature", 4, False, False, (107, 142, 35)),
+        CityscapesClass("terrain", 22, 9, "nature", 4, False, False, (152, 251, 152)),
+        CityscapesClass("sky", 23, 10, "sky", 5, False, False, (70, 130, 180)),
+        CityscapesClass("person", 24, 11, "human", 6, True, False, (220, 20, 60)),
+        CityscapesClass("rider", 25, 12, "human", 6, True, False, (255, 0, 0)),
+        CityscapesClass("car", 26, 13, "vehicle", 7, True, False, (0, 0, 142)),
+        CityscapesClass("truck", 27, 14, "vehicle", 7, True, False, (0, 0, 70)),
+        CityscapesClass("bus", 28, 15, "vehicle", 7, True, False, (0, 60, 100)),
+        CityscapesClass("caravan", 29, 255, "vehicle", 7, True, True, (0, 0, 90)),
+        CityscapesClass("trailer", 30, 255, "vehicle", 7, True, True, (0, 0, 110)),
+        CityscapesClass("train", 31, 16, "vehicle", 7, True, False, (0, 80, 100)),
+        CityscapesClass("motorcycle", 32, 17, "vehicle", 7, True, False, (0, 0, 230)),
+        CityscapesClass("bicycle", 33, 18, "vehicle", 7, True, False, (119, 11, 32)),
+        CityscapesClass("license plate", -1, -1, "vehicle", 7, False, True, (0, 0, 142)),
+    ]
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        mode: str = "fine",
+        target_type: Union[List[str], str] = "instance",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        transforms: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transforms, transform, target_transform)
+        self.mode = "gtFine" if mode == "fine" else "gtCoarse"
+        self.images_dir = os.path.join(self.root, "leftImg8bit", split)
+        self.targets_dir = os.path.join(self.root, self.mode, split)
+        self.target_type = target_type
+        self.split = split
+        self.images = []
+        self.targets = []
+
+        verify_str_arg(mode, "mode", ("fine", "coarse"))
+        if mode == "fine":
+            valid_modes = ("train", "test", "val")
+        else:
+            valid_modes = ("train", "train_extra", "val")
+        msg = "Unknown value '{}' for argument split if mode is '{}'. Valid values are {{{}}}."
+        msg = msg.format(split, mode, iterable_to_str(valid_modes))
+        verify_str_arg(split, "split", valid_modes, msg)
+
+        if not isinstance(target_type, list):
+            self.target_type = [target_type]
+        [
+            verify_str_arg(value, "target_type", ("instance", "semantic", "polygon", "color"))
+            for value in self.target_type
+        ]
+
+        if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
+
+            if split == "train_extra":
+                image_dir_zip = os.path.join(self.root, "leftImg8bit_trainextra.zip")
+            else:
+                image_dir_zip = os.path.join(self.root, "leftImg8bit_trainvaltest.zip")
+
+            if self.mode == "gtFine":
+                target_dir_zip = os.path.join(self.root, f"{self.mode}_trainvaltest.zip")
+            elif self.mode == "gtCoarse":
+                target_dir_zip = os.path.join(self.root, f"{self.mode}.zip")
+
+            if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip):
+                extract_archive(from_path=image_dir_zip, to_path=self.root)
+                extract_archive(from_path=target_dir_zip, to_path=self.root)
+            else:
+                raise RuntimeError(
+                    "Dataset not found or incomplete. Please make sure all required folders for the"
+                    ' specified "split" and "mode" are inside the "root" directory'
+                )
+
+        for city in os.listdir(self.images_dir):
+            img_dir = os.path.join(self.images_dir, city)
+            target_dir = os.path.join(self.targets_dir, city)
+            for file_name in os.listdir(img_dir):
+                target_types = []
+                for t in self.target_type:
+                    target_name = "{}_{}".format(
+                        file_name.split("_leftImg8bit")[0], self._get_target_suffix(self.mode, t)
+                    )
+                    target_types.append(os.path.join(target_dir, target_name))
+
+                self.images.append(os.path.join(img_dir, file_name))
+                self.targets.append(target_types)
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+        Returns:
+            tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
+            than one item. Otherwise, target is a json object if target_type="polygon", else the image segmentation.
+        """
+
+        image = Image.open(self.images[index]).convert("RGB")
+
+        targets: Any = []
+        for i, t in enumerate(self.target_type):
+            if t == "polygon":
+                target = self._load_json(self.targets[index][i])
+            else:
+                target = Image.open(self.targets[index][i])
+
+            targets.append(target)
+
+        target = tuple(targets) if len(targets) > 1 else targets[0]
+
+        if self.transforms is not None:
+            image, target = self.transforms(image, target)
+
+        return image, target
+
+    def __len__(self) -> int:
+        return len(self.images)
+
+    def extra_repr(self) -> str:
+        lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"]
+        return "\n".join(lines).format(**self.__dict__)
+
+    def _load_json(self, path: str) -> Dict[str, Any]:
+        with open(path) as file:
+            data = json.load(file)
+        return data
+
+    def _get_target_suffix(self, mode: str, target_type: str) -> str:
+        if target_type == "instance":
+            return f"{mode}_instanceIds.png"
+        elif target_type == "semantic":
+            return f"{mode}_labelIds.png"
+        elif target_type == "color":
+            return f"{mode}_color.png"
+        else:
+            return f"{mode}_polygons.json"

+ 88 - 0
libs/vision_libs/datasets/clevr.py

@@ -0,0 +1,88 @@
+import json
+import pathlib
+from typing import Any, Callable, List, Optional, Tuple
+from urllib.parse import urlparse
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class CLEVRClassification(VisionDataset):
+    """`CLEVR <https://cs.stanford.edu/people/jcjohns/clevr/>`_  classification dataset.
+
+    The number of objects in a scene are used as label.
+
+    Args:
+        root (string): Root directory of dataset where directory ``root/clevr`` exists or will be saved to if download is
+            set to True.
+        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
+        transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
+            version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in them target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If
+            dataset is already downloaded, it is not downloaded again.
+    """
+
+    _URL = "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip"
+    _MD5 = "b11922020e72d0cd9154779b2d3d07d2"
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self._base_folder = pathlib.Path(self.root) / "clevr"
+        self._data_folder = self._base_folder / pathlib.Path(urlparse(self._URL).path).stem
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        self._image_files = sorted(self._data_folder.joinpath("images", self._split).glob("*"))
+
+        self._labels: List[Optional[int]]
+        if self._split != "test":
+            with open(self._data_folder / "scenes" / f"CLEVR_{self._split}_scenes.json") as file:
+                content = json.load(file)
+            num_objects = {scene["image_filename"]: len(scene["objects"]) for scene in content["scenes"]}
+            self._labels = [num_objects[image_file.name] for image_file in self._image_files]
+        else:
+            self._labels = [None] * len(self._image_files)
+
+    def __len__(self) -> int:
+        return len(self._image_files)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        image_file = self._image_files[idx]
+        label = self._labels[idx]
+
+        image = Image.open(image_file).convert("RGB")
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.target_transform:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def _check_exists(self) -> bool:
+        return self._data_folder.exists() and self._data_folder.is_dir()
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+
+        download_and_extract_archive(self._URL, str(self._base_folder), md5=self._MD5)
+
+    def extra_repr(self) -> str:
+        return f"split={self._split}"

+ 104 - 0
libs/vision_libs/datasets/coco.py

@@ -0,0 +1,104 @@
+import os.path
+from typing import Any, Callable, List, Optional, Tuple
+
+from PIL import Image
+
+from .vision import VisionDataset
+
+
+class CocoDetection(VisionDataset):
+    """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.
+
+    It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.
+
+    Args:
+        root (string): Root directory where images are downloaded to.
+        annFile (string): Path to json annotation file.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.PILToTensor``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        transforms (callable, optional): A function/transform that takes input sample and its target as entry
+            and returns a transformed version.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        annFile: str,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        transforms: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transforms, transform, target_transform)
+        from pycocotools.coco import COCO
+
+        self.coco = COCO(annFile)
+        self.ids = list(sorted(self.coco.imgs.keys()))
+
+    def _load_image(self, id: int) -> Image.Image:
+        path = self.coco.loadImgs(id)[0]["file_name"]
+        return Image.open(os.path.join(self.root, path)).convert("RGB")
+
+    def _load_target(self, id: int) -> List[Any]:
+        return self.coco.loadAnns(self.coco.getAnnIds(id))
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        id = self.ids[index]
+        image = self._load_image(id)
+        target = self._load_target(id)
+
+        if self.transforms is not None:
+            image, target = self.transforms(image, target)
+
+        return image, target
+
+    def __len__(self) -> int:
+        return len(self.ids)
+
+
+class CocoCaptions(CocoDetection):
+    """`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.
+
+    It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.
+
+    Args:
+        root (string): Root directory where images are downloaded to.
+        annFile (string): Path to json annotation file.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.PILToTensor``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        transforms (callable, optional): A function/transform that takes input sample and its target as entry
+            and returns a transformed version.
+
+    Example:
+
+        .. code:: python
+
+            import torchvision.datasets as dset
+            import torchvision.transforms as transforms
+            cap = dset.CocoCaptions(root = 'dir where images are',
+                                    annFile = 'json annotation file',
+                                    transform=transforms.PILToTensor())
+
+            print('Number of samples: ', len(cap))
+            img, target = cap[3] # load 4th sample
+
+            print("Image Size: ", img.size())
+            print(target)
+
+        Output: ::
+
+            Number of samples: 82783
+            Image Size: (3L, 427L, 640L)
+            [u'A plane emitting smoke stream flying over a mountain.',
+            u'A plane darts across a bright blue sky behind a mountain covered in snow',
+            u'A plane leaves a contrail above the snowy mountain top.',
+            u'A mountain that has a plane flying overheard in the distance.',
+            u'A mountain view with a plume of smoke in the background']
+
+    """
+
+    def _load_target(self, id: int) -> List[str]:
+        return [ann["caption"] for ann in super()._load_target(id)]

+ 58 - 0
libs/vision_libs/datasets/country211.py

@@ -0,0 +1,58 @@
+from pathlib import Path
+from typing import Callable, Optional
+
+from .folder import ImageFolder
+from .utils import download_and_extract_archive, verify_str_arg
+
+
+class Country211(ImageFolder):
+    """`The Country211 Data Set <https://github.com/openai/CLIP/blob/main/data/country211.md>`_ from OpenAI.
+
+    This dataset was built by filtering the images from the YFCC100m dataset
+    that have GPS coordinate corresponding to a ISO-3166 country code. The
+    dataset is balanced by sampling 150 train images, 50 validation images, and
+    100 test images for each country.
+
+    Args:
+        root (string): Root directory of the dataset.
+        split (string, optional): The dataset split, supports ``"train"`` (default), ``"valid"`` and ``"test"``.
+        transform (callable, optional): A function/transform that  takes in an PIL image and returns a transformed
+            version. E.g, ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and puts it into
+            ``root/country211/``. If dataset is already downloaded, it is not downloaded again.
+    """
+
+    _URL = "https://openaipublic.azureedge.net/clip/data/country211.tgz"
+    _MD5 = "84988d7644798601126c29e9877aab6a"
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        self._split = verify_str_arg(split, "split", ("train", "valid", "test"))
+
+        root = Path(root).expanduser()
+        self.root = str(root)
+        self._base_folder = root / "country211"
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        super().__init__(str(self._base_folder / self._split), transform=transform, target_transform=target_transform)
+        self.root = str(root)
+
+    def _check_exists(self) -> bool:
+        return self._base_folder.exists() and self._base_folder.is_dir()
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+        download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)

+ 100 - 0
libs/vision_libs/datasets/dtd.py

@@ -0,0 +1,100 @@
+import os
+import pathlib
+from typing import Any, Callable, Optional, Tuple
+
+import PIL.Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class DTD(VisionDataset):
+    """`Describable Textures Dataset (DTD) <https://www.robots.ox.ac.uk/~vgg/data/dtd/>`_.
+
+    Args:
+        root (string): Root directory of the dataset.
+        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
+        partition (int, optional): The dataset partition. Should be ``1 <= partition <= 10``. Defaults to ``1``.
+
+            .. note::
+
+                The partition only changes which split each image belongs to. Thus, regardless of the selected
+                partition, combining all splits will result in all images.
+
+        transform (callable, optional): A function/transform that  takes in a PIL image and returns a transformed
+            version. E.g, ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again. Default is False.
+    """
+
+    _URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz"
+    _MD5 = "fff73e5086ae6bdbea199a49dfb8a4c1"
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        partition: int = 1,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
+        if not isinstance(partition, int) and not (1 <= partition <= 10):
+            raise ValueError(
+                f"Parameter 'partition' should be an integer with `1 <= partition <= 10`, "
+                f"but got {partition} instead"
+            )
+        self._partition = partition
+
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self._base_folder = pathlib.Path(self.root) / type(self).__name__.lower()
+        self._data_folder = self._base_folder / "dtd"
+        self._meta_folder = self._data_folder / "labels"
+        self._images_folder = self._data_folder / "images"
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        self._image_files = []
+        classes = []
+        with open(self._meta_folder / f"{self._split}{self._partition}.txt") as file:
+            for line in file:
+                cls, name = line.strip().split("/")
+                self._image_files.append(self._images_folder.joinpath(cls, name))
+                classes.append(cls)
+
+        self.classes = sorted(set(classes))
+        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
+        self._labels = [self.class_to_idx[cls] for cls in classes]
+
+    def __len__(self) -> int:
+        return len(self._image_files)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        image_file, label = self._image_files[idx], self._labels[idx]
+        image = PIL.Image.open(image_file).convert("RGB")
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.target_transform:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def extra_repr(self) -> str:
+        return f"split={self._split}, partition={self._partition}"
+
+    def _check_exists(self) -> bool:
+        return os.path.exists(self._data_folder) and os.path.isdir(self._data_folder)
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+        download_and_extract_archive(self._URL, download_root=str(self._base_folder), md5=self._MD5)

+ 58 - 0
libs/vision_libs/datasets/eurosat.py

@@ -0,0 +1,58 @@
+import os
+from typing import Callable, Optional
+
+from .folder import ImageFolder
+from .utils import download_and_extract_archive
+
+
+class EuroSAT(ImageFolder):
+    """RGB version of the `EuroSAT <https://github.com/phelber/eurosat>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where ``root/eurosat`` exists.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again. Default is False.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        self.root = os.path.expanduser(root)
+        self._base_folder = os.path.join(self.root, "eurosat")
+        self._data_folder = os.path.join(self._base_folder, "2750")
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        super().__init__(self._data_folder, transform=transform, target_transform=target_transform)
+        self.root = os.path.expanduser(root)
+
+    def __len__(self) -> int:
+        return len(self.samples)
+
+    def _check_exists(self) -> bool:
+        return os.path.exists(self._data_folder)
+
+    def download(self) -> None:
+
+        if self._check_exists():
+            return
+
+        os.makedirs(self._base_folder, exist_ok=True)
+        download_and_extract_archive(
+            "https://madm.dfki.de/files/sentinel/EuroSAT.zip",
+            download_root=self._base_folder,
+            md5="c8fa014336c82ac7804f0398fcb19387",
+        )

+ 67 - 0
libs/vision_libs/datasets/fakedata.py

@@ -0,0 +1,67 @@
+from typing import Any, Callable, Optional, Tuple
+
+import torch
+
+from .. import transforms
+from .vision import VisionDataset
+
+
+class FakeData(VisionDataset):
+    """A fake dataset that returns randomly generated images and returns them as PIL images
+
+    Args:
+        size (int, optional): Size of the dataset. Default: 1000 images
+        image_size(tuple, optional): Size if the returned images. Default: (3, 224, 224)
+        num_classes(int, optional): Number of classes in the dataset. Default: 10
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        random_offset (int): Offsets the index-based random seed used to
+            generate each image. Default: 0
+
+    """
+
+    def __init__(
+        self,
+        size: int = 1000,
+        image_size: Tuple[int, int, int] = (3, 224, 224),
+        num_classes: int = 10,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        random_offset: int = 0,
+    ) -> None:
+        super().__init__(transform=transform, target_transform=target_transform)
+        self.size = size
+        self.num_classes = num_classes
+        self.image_size = image_size
+        self.random_offset = random_offset
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is class_index of the target class.
+        """
+        # create random image that is consistent with the index id
+        if index >= len(self):
+            raise IndexError(f"{self.__class__.__name__} index out of range")
+        rng_state = torch.get_rng_state()
+        torch.manual_seed(index + self.random_offset)
+        img = torch.randn(*self.image_size)
+        target = torch.randint(0, self.num_classes, size=(1,), dtype=torch.long)[0]
+        torch.set_rng_state(rng_state)
+
+        # convert to PIL Image
+        img = transforms.ToPILImage()(img)
+        if self.transform is not None:
+            img = self.transform(img)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target.item()
+
+    def __len__(self) -> int:
+        return self.size

+ 75 - 0
libs/vision_libs/datasets/fer2013.py

@@ -0,0 +1,75 @@
+import csv
+import pathlib
+from typing import Any, Callable, Optional, Tuple
+
+import torch
+from PIL import Image
+
+from .utils import check_integrity, verify_str_arg
+from .vision import VisionDataset
+
+
+class FER2013(VisionDataset):
+    """`FER2013
+    <https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where directory
+            ``root/fer2013`` exists.
+        split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
+        transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
+            version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+    """
+
+    _RESOURCES = {
+        "train": ("train.csv", "3f0dfb3d3fd99c811a1299cb947e3131"),
+        "test": ("test.csv", "b02c2298636a634e8c2faabbf3ea9a23"),
+    }
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+    ) -> None:
+        self._split = verify_str_arg(split, "split", self._RESOURCES.keys())
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        base_folder = pathlib.Path(self.root) / "fer2013"
+        file_name, md5 = self._RESOURCES[self._split]
+        data_file = base_folder / file_name
+        if not check_integrity(str(data_file), md5=md5):
+            raise RuntimeError(
+                f"{file_name} not found in {base_folder} or corrupted. "
+                f"You can download it from "
+                f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
+            )
+
+        with open(data_file, "r", newline="") as file:
+            self._samples = [
+                (
+                    torch.tensor([int(idx) for idx in row["pixels"].split()], dtype=torch.uint8).reshape(48, 48),
+                    int(row["emotion"]) if "emotion" in row else None,
+                )
+                for row in csv.DictReader(file)
+            ]
+
+    def __len__(self) -> int:
+        return len(self._samples)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        image_tensor, target = self._samples[idx]
+        image = Image.fromarray(image_tensor.numpy())
+
+        if self.transform is not None:
+            image = self.transform(image)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return image, target
+
+    def extra_repr(self) -> str:
+        return f"split={self._split}"

+ 114 - 0
libs/vision_libs/datasets/fgvc_aircraft.py

@@ -0,0 +1,114 @@
+from __future__ import annotations
+
+import os
+from typing import Any, Callable, Optional, Tuple
+
+import PIL.Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class FGVCAircraft(VisionDataset):
+    """`FGVC Aircraft <https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/>`_ Dataset.
+
+    The dataset contains 10,000 images of aircraft, with 100 images for each of 100
+    different aircraft model variants, most of which are airplanes.
+    Aircraft models are organized in a three-levels hierarchy. The three levels, from
+    finer to coarser, are:
+
+    - ``variant``, e.g. Boeing 737-700. A variant collapses all the models that are visually
+        indistinguishable into one class. The dataset comprises 100 different variants.
+    - ``family``, e.g. Boeing 737. The dataset comprises 70 different families.
+    - ``manufacturer``, e.g. Boeing. The dataset comprises 30 different manufacturers.
+
+    Args:
+        root (string): Root directory of the FGVC Aircraft dataset.
+        split (string, optional): The dataset split, supports ``train``, ``val``,
+            ``trainval`` and ``test``.
+        annotation_level (str, optional): The annotation level, supports ``variant``,
+            ``family`` and ``manufacturer``.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    _URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz"
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "trainval",
+        annotation_level: str = "variant",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test"))
+        self._annotation_level = verify_str_arg(
+            annotation_level, "annotation_level", ("variant", "family", "manufacturer")
+        )
+
+        self._data_path = os.path.join(self.root, "fgvc-aircraft-2013b")
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        annotation_file = os.path.join(
+            self._data_path,
+            "data",
+            {
+                "variant": "variants.txt",
+                "family": "families.txt",
+                "manufacturer": "manufacturers.txt",
+            }[self._annotation_level],
+        )
+        with open(annotation_file, "r") as f:
+            self.classes = [line.strip() for line in f]
+
+        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
+
+        image_data_folder = os.path.join(self._data_path, "data", "images")
+        labels_file = os.path.join(self._data_path, "data", f"images_{self._annotation_level}_{self._split}.txt")
+
+        self._image_files = []
+        self._labels = []
+
+        with open(labels_file, "r") as f:
+            for line in f:
+                image_name, label_name = line.strip().split(" ", 1)
+                self._image_files.append(os.path.join(image_data_folder, f"{image_name}.jpg"))
+                self._labels.append(self.class_to_idx[label_name])
+
+    def __len__(self) -> int:
+        return len(self._image_files)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        image_file, label = self._image_files[idx], self._labels[idx]
+        image = PIL.Image.open(image_file).convert("RGB")
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.target_transform:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def _download(self) -> None:
+        """
+        Download the FGVC Aircraft dataset archive and extract it under root.
+        """
+        if self._check_exists():
+            return
+        download_and_extract_archive(self._URL, self.root)
+
+    def _check_exists(self) -> bool:
+        return os.path.exists(self._data_path) and os.path.isdir(self._data_path)

+ 166 - 0
libs/vision_libs/datasets/flickr.py

@@ -0,0 +1,166 @@
+import glob
+import os
+from collections import defaultdict
+from html.parser import HTMLParser
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+from PIL import Image
+
+from .vision import VisionDataset
+
+
+class Flickr8kParser(HTMLParser):
+    """Parser for extracting captions from the Flickr8k dataset web page."""
+
+    def __init__(self, root: str) -> None:
+        super().__init__()
+
+        self.root = root
+
+        # Data structure to store captions
+        self.annotations: Dict[str, List[str]] = {}
+
+        # State variables
+        self.in_table = False
+        self.current_tag: Optional[str] = None
+        self.current_img: Optional[str] = None
+
+    def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
+        self.current_tag = tag
+
+        if tag == "table":
+            self.in_table = True
+
+    def handle_endtag(self, tag: str) -> None:
+        self.current_tag = None
+
+        if tag == "table":
+            self.in_table = False
+
+    def handle_data(self, data: str) -> None:
+        if self.in_table:
+            if data == "Image Not Found":
+                self.current_img = None
+            elif self.current_tag == "a":
+                img_id = data.split("/")[-2]
+                img_id = os.path.join(self.root, img_id + "_*.jpg")
+                img_id = glob.glob(img_id)[0]
+                self.current_img = img_id
+                self.annotations[img_id] = []
+            elif self.current_tag == "li" and self.current_img:
+                img_id = self.current_img
+                self.annotations[img_id].append(data.strip())
+
+
+class Flickr8k(VisionDataset):
+    """`Flickr8k Entities <http://hockenmaier.cs.illinois.edu/8k-pictures.html>`_ Dataset.
+
+    Args:
+        root (string): Root directory where images are downloaded to.
+        ann_file (string): Path to annotation file.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.PILToTensor``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        ann_file: str,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self.ann_file = os.path.expanduser(ann_file)
+
+        # Read annotations and store in a dict
+        parser = Flickr8kParser(self.root)
+        with open(self.ann_file) as fh:
+            parser.feed(fh.read())
+        self.annotations = parser.annotations
+
+        self.ids = list(sorted(self.annotations.keys()))
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: Tuple (image, target). target is a list of captions for the image.
+        """
+        img_id = self.ids[index]
+
+        # Image
+        img = Image.open(img_id).convert("RGB")
+        if self.transform is not None:
+            img = self.transform(img)
+
+        # Captions
+        target = self.annotations[img_id]
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.ids)
+
+
+class Flickr30k(VisionDataset):
+    """`Flickr30k Entities <https://bryanplummer.com/Flickr30kEntities/>`_ Dataset.
+
+    Args:
+        root (string): Root directory where images are downloaded to.
+        ann_file (string): Path to annotation file.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.PILToTensor``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        ann_file: str,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self.ann_file = os.path.expanduser(ann_file)
+
+        # Read annotations and store in a dict
+        self.annotations = defaultdict(list)
+        with open(self.ann_file) as fh:
+            for line in fh:
+                img_id, caption = line.strip().split("\t")
+                self.annotations[img_id[:-2]].append(caption)
+
+        self.ids = list(sorted(self.annotations.keys()))
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: Tuple (image, target). target is a list of captions for the image.
+        """
+        img_id = self.ids[index]
+
+        # Image
+        filename = os.path.join(self.root, img_id)
+        img = Image.open(filename).convert("RGB")
+        if self.transform is not None:
+            img = self.transform(img)
+
+        # Captions
+        target = self.annotations[img_id]
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.ids)

+ 114 - 0
libs/vision_libs/datasets/flowers102.py

@@ -0,0 +1,114 @@
+from pathlib import Path
+from typing import Any, Callable, Optional, Tuple
+
+import PIL.Image
+
+from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg
+from .vision import VisionDataset
+
+
+class Flowers102(VisionDataset):
+    """`Oxford 102 Flower <https://www.robots.ox.ac.uk/~vgg/data/flowers/102/>`_ Dataset.
+
+    .. warning::
+
+        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
+
+    Oxford 102 Flower is an image classification dataset consisting of 102 flower categories. The
+    flowers were chosen to be flowers commonly occurring in the United Kingdom. Each class consists of
+    between 40 and 258 images.
+
+    The images have large scale, pose and light variations. In addition, there are categories that
+    have large variations within the category, and several very similar categories.
+
+    Args:
+        root (string): Root directory of the dataset.
+        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
+        transform (callable, optional): A function/transform that takes in an PIL image and returns a
+            transformed version. E.g, ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    _download_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/"
+    _file_dict = {  # filename, md5
+        "image": ("102flowers.tgz", "52808999861908f626f3c1f4e79d11fa"),
+        "label": ("imagelabels.mat", "e0620be6f572b9609742df49c70aed4d"),
+        "setid": ("setid.mat", "a5357ecc9cb78c4bef273ce3793fc85c"),
+    }
+    _splits_map = {"train": "trnid", "val": "valid", "test": "tstid"}
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
+        self._base_folder = Path(self.root) / "flowers-102"
+        self._images_folder = self._base_folder / "jpg"
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        from scipy.io import loadmat
+
+        set_ids = loadmat(self._base_folder / self._file_dict["setid"][0], squeeze_me=True)
+        image_ids = set_ids[self._splits_map[self._split]].tolist()
+
+        labels = loadmat(self._base_folder / self._file_dict["label"][0], squeeze_me=True)
+        image_id_to_label = dict(enumerate((labels["labels"] - 1).tolist(), 1))
+
+        self._labels = []
+        self._image_files = []
+        for image_id in image_ids:
+            self._labels.append(image_id_to_label[image_id])
+            self._image_files.append(self._images_folder / f"image_{image_id:05d}.jpg")
+
+    def __len__(self) -> int:
+        return len(self._image_files)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        image_file, label = self._image_files[idx], self._labels[idx]
+        image = PIL.Image.open(image_file).convert("RGB")
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.target_transform:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def extra_repr(self) -> str:
+        return f"split={self._split}"
+
+    def _check_integrity(self):
+        if not (self._images_folder.exists() and self._images_folder.is_dir()):
+            return False
+
+        for id in ["label", "setid"]:
+            filename, md5 = self._file_dict[id]
+            if not check_integrity(str(self._base_folder / filename), md5):
+                return False
+        return True
+
+    def download(self):
+        if self._check_integrity():
+            return
+        download_and_extract_archive(
+            f"{self._download_url_prefix}{self._file_dict['image'][0]}",
+            str(self._base_folder),
+            md5=self._file_dict["image"][1],
+        )
+        for id in ["label", "setid"]:
+            filename, md5 = self._file_dict[id]
+            download_url(self._download_url_prefix + filename, str(self._base_folder), md5=md5)

+ 317 - 0
libs/vision_libs/datasets/folder.py

@@ -0,0 +1,317 @@
+import os
+import os.path
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
+
+from PIL import Image
+
+from .vision import VisionDataset
+
+
+def has_file_allowed_extension(filename: str, extensions: Union[str, Tuple[str, ...]]) -> bool:
+    """Checks if a file is an allowed extension.
+
+    Args:
+        filename (string): path to a file
+        extensions (tuple of strings): extensions to consider (lowercase)
+
+    Returns:
+        bool: True if the filename ends with one of given extensions
+    """
+    return filename.lower().endswith(extensions if isinstance(extensions, str) else tuple(extensions))
+
+
+def is_image_file(filename: str) -> bool:
+    """Checks if a file is an allowed image extension.
+
+    Args:
+        filename (string): path to a file
+
+    Returns:
+        bool: True if the filename ends with a known image extension
+    """
+    return has_file_allowed_extension(filename, IMG_EXTENSIONS)
+
+
+def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
+    """Finds the class folders in a dataset.
+
+    See :class:`DatasetFolder` for details.
+    """
+    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
+    if not classes:
+        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
+
+    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
+    return classes, class_to_idx
+
+
+def make_dataset(
+    directory: str,
+    class_to_idx: Optional[Dict[str, int]] = None,
+    extensions: Optional[Union[str, Tuple[str, ...]]] = None,
+    is_valid_file: Optional[Callable[[str], bool]] = None,
+) -> List[Tuple[str, int]]:
+    """Generates a list of samples of a form (path_to_sample, class).
+
+    See :class:`DatasetFolder` for details.
+
+    Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
+    by default.
+    """
+    directory = os.path.expanduser(directory)
+
+    if class_to_idx is None:
+        _, class_to_idx = find_classes(directory)
+    elif not class_to_idx:
+        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")
+
+    both_none = extensions is None and is_valid_file is None
+    both_something = extensions is not None and is_valid_file is not None
+    if both_none or both_something:
+        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
+
+    if extensions is not None:
+
+        def is_valid_file(x: str) -> bool:
+            return has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]
+
+    is_valid_file = cast(Callable[[str], bool], is_valid_file)
+
+    instances = []
+    available_classes = set()
+    for target_class in sorted(class_to_idx.keys()):
+        class_index = class_to_idx[target_class]
+        target_dir = os.path.join(directory, target_class)
+        if not os.path.isdir(target_dir):
+            continue
+        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
+            for fname in sorted(fnames):
+                path = os.path.join(root, fname)
+                if is_valid_file(path):
+                    item = path, class_index
+                    instances.append(item)
+
+                    if target_class not in available_classes:
+                        available_classes.add(target_class)
+
+    empty_classes = set(class_to_idx.keys()) - available_classes
+    if empty_classes:
+        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
+        if extensions is not None:
+            msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
+        raise FileNotFoundError(msg)
+
+    return instances
+
+
+class DatasetFolder(VisionDataset):
+    """A generic data loader.
+
+    This default directory structure can be customized by overriding the
+    :meth:`find_classes` method.
+
+    Args:
+        root (string): Root directory path.
+        loader (callable): A function to load a sample given its path.
+        extensions (tuple[string]): A list of allowed extensions.
+            both extensions and is_valid_file should not be passed.
+        transform (callable, optional): A function/transform that takes in
+            a sample and returns a transformed version.
+            E.g, ``transforms.RandomCrop`` for images.
+        target_transform (callable, optional): A function/transform that takes
+            in the target and transforms it.
+        is_valid_file (callable, optional): A function that takes path of a file
+            and check if the file is a valid file (used to check of corrupt files)
+            both extensions and is_valid_file should not be passed.
+
+     Attributes:
+        classes (list): List of the class names sorted alphabetically.
+        class_to_idx (dict): Dict with items (class_name, class_index).
+        samples (list): List of (sample path, class_index) tuples
+        targets (list): The class_index value for each image in the dataset
+    """
+
+    def __init__(
+        self,
+        root: str,
+        loader: Callable[[str], Any],
+        extensions: Optional[Tuple[str, ...]] = None,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        is_valid_file: Optional[Callable[[str], bool]] = None,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        classes, class_to_idx = self.find_classes(self.root)
+        samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)
+
+        self.loader = loader
+        self.extensions = extensions
+
+        self.classes = classes
+        self.class_to_idx = class_to_idx
+        self.samples = samples
+        self.targets = [s[1] for s in samples]
+
+    @staticmethod
+    def make_dataset(
+        directory: str,
+        class_to_idx: Dict[str, int],
+        extensions: Optional[Tuple[str, ...]] = None,
+        is_valid_file: Optional[Callable[[str], bool]] = None,
+    ) -> List[Tuple[str, int]]:
+        """Generates a list of samples of a form (path_to_sample, class).
+
+        This can be overridden to e.g. read files from a compressed zip file instead of from the disk.
+
+        Args:
+            directory (str): root dataset directory, corresponding to ``self.root``.
+            class_to_idx (Dict[str, int]): Dictionary mapping class name to class index.
+            extensions (optional): A list of allowed extensions.
+                Either extensions or is_valid_file should be passed. Defaults to None.
+            is_valid_file (optional): A function that takes path of a file
+                and checks if the file is a valid file
+                (used to check of corrupt files) both extensions and
+                is_valid_file should not be passed. Defaults to None.
+
+        Raises:
+            ValueError: In case ``class_to_idx`` is empty.
+            ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
+            FileNotFoundError: In case no valid file was found for any class.
+
+        Returns:
+            List[Tuple[str, int]]: samples of a form (path_to_sample, class)
+        """
+        if class_to_idx is None:
+            # prevent potential bug since make_dataset() would use the class_to_idx logic of the
+            # find_classes() function, instead of using that of the find_classes() method, which
+            # is potentially overridden and thus could have a different logic.
+            raise ValueError("The class_to_idx parameter cannot be None.")
+        return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)
+
+    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
+        """Find the class folders in a dataset structured as follows::
+
+            directory/
+            ├── class_x
+            │   ├── xxx.ext
+            │   ├── xxy.ext
+            │   └── ...
+            │       └── xxz.ext
+            └── class_y
+                ├── 123.ext
+                ├── nsdf3.ext
+                └── ...
+                └── asd932_.ext
+
+        This method can be overridden to only consider
+        a subset of classes, or to adapt to a different dataset directory structure.
+
+        Args:
+            directory(str): Root directory path, corresponding to ``self.root``
+
+        Raises:
+            FileNotFoundError: If ``dir`` has no class folders.
+
+        Returns:
+            (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
+        """
+        return find_classes(directory)
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (sample, target) where target is class_index of the target class.
+        """
+        path, target = self.samples[index]
+        sample = self.loader(path)
+        if self.transform is not None:
+            sample = self.transform(sample)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return sample, target
+
+    def __len__(self) -> int:
+        return len(self.samples)
+
+
+IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
+
+
+def pil_loader(path: str) -> Image.Image:
+    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
+    with open(path, "rb") as f:
+        img = Image.open(f)
+        return img.convert("RGB")
+
+
+# TODO: specify the return type
+def accimage_loader(path: str) -> Any:
+    import accimage
+
+    try:
+        return accimage.Image(path)
+    except OSError:
+        # Potentially a decoding problem, fall back to PIL.Image
+        return pil_loader(path)
+
+
+def default_loader(path: str) -> Any:
+    from torchvision import get_image_backend
+
+    if get_image_backend() == "accimage":
+        return accimage_loader(path)
+    else:
+        return pil_loader(path)
+
+
+class ImageFolder(DatasetFolder):
+    """A generic data loader where the images are arranged in this way by default: ::
+
+        root/dog/xxx.png
+        root/dog/xxy.png
+        root/dog/[...]/xxz.png
+
+        root/cat/123.png
+        root/cat/nsdf3.png
+        root/cat/[...]/asd932_.png
+
+    This class inherits from :class:`~torchvision.datasets.DatasetFolder` so
+    the same methods can be overridden to customize the dataset.
+
+    Args:
+        root (string): Root directory path.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        loader (callable, optional): A function to load an image given its path.
+        is_valid_file (callable, optional): A function that takes path of an Image file
+            and check if the file is a valid file (used to check of corrupt files)
+
+     Attributes:
+        classes (list): List of the class names sorted alphabetically.
+        class_to_idx (dict): Dict with items (class_name, class_index).
+        imgs (list): List of (image path, class_index) tuples
+    """
+
+    def __init__(
+        self,
+        root: str,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        loader: Callable[[str], Any] = default_loader,
+        is_valid_file: Optional[Callable[[str], bool]] = None,
+    ):
+        super().__init__(
+            root,
+            loader,
+            IMG_EXTENSIONS if is_valid_file is None else None,
+            transform=transform,
+            target_transform=target_transform,
+            is_valid_file=is_valid_file,
+        )
+        self.imgs = self.samples

+ 93 - 0
libs/vision_libs/datasets/food101.py

@@ -0,0 +1,93 @@
+import json
+from pathlib import Path
+from typing import Any, Callable, Optional, Tuple
+
+import PIL.Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class Food101(VisionDataset):
+    """`The Food-101 Data Set <https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/>`_.
+
+    The Food-101 is a challenging data set of 101 food categories with 101,000 images.
+    For each class, 250 manually reviewed test images are provided as well as 750 training images.
+    On purpose, the training images were not cleaned, and thus still contain some amount of noise.
+    This comes mostly in the form of intense colors and sometimes wrong labels. All images were
+    rescaled to have a maximum side length of 512 pixels.
+
+
+    Args:
+        root (string): Root directory of the dataset.
+        split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
+        transform (callable, optional): A function/transform that  takes in an PIL image and returns a transformed
+            version. E.g, ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again. Default is False.
+    """
+
+    _URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz"
+    _MD5 = "85eeb15f3717b99a5da872d97d918f87"
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self._split = verify_str_arg(split, "split", ("train", "test"))
+        self._base_folder = Path(self.root) / "food-101"
+        self._meta_folder = self._base_folder / "meta"
+        self._images_folder = self._base_folder / "images"
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        self._labels = []
+        self._image_files = []
+        with open(self._meta_folder / f"{split}.json") as f:
+            metadata = json.loads(f.read())
+
+        self.classes = sorted(metadata.keys())
+        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
+
+        for class_label, im_rel_paths in metadata.items():
+            self._labels += [self.class_to_idx[class_label]] * len(im_rel_paths)
+            self._image_files += [
+                self._images_folder.joinpath(*f"{im_rel_path}.jpg".split("/")) for im_rel_path in im_rel_paths
+            ]
+
+    def __len__(self) -> int:
+        return len(self._image_files)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        image_file, label = self._image_files[idx], self._labels[idx]
+        image = PIL.Image.open(image_file).convert("RGB")
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.target_transform:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def extra_repr(self) -> str:
+        return f"split={self._split}"
+
+    def _check_exists(self) -> bool:
+        return all(folder.exists() and folder.is_dir() for folder in (self._meta_folder, self._images_folder))
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+        download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)

+ 103 - 0
libs/vision_libs/datasets/gtsrb.py

@@ -0,0 +1,103 @@
+import csv
+import pathlib
+from typing import Any, Callable, Optional, Tuple
+
+import PIL
+
+from .folder import make_dataset
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class GTSRB(VisionDataset):
+    """`German Traffic Sign Recognition Benchmark (GTSRB) <https://benchmark.ini.rub.de/>`_ Dataset.
+
+    Args:
+        root (string): Root directory of the dataset.
+        split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
+        transform (callable, optional): A function/transform that  takes in an PIL image and returns a transformed
+            version. E.g, ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        self._split = verify_str_arg(split, "split", ("train", "test"))
+        self._base_folder = pathlib.Path(root) / "gtsrb"
+        self._target_folder = (
+            self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images")
+        )
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        if self._split == "train":
+            samples = make_dataset(str(self._target_folder), extensions=(".ppm",))
+        else:
+            with open(self._base_folder / "GT-final_test.csv") as csv_file:
+                samples = [
+                    (str(self._target_folder / row["Filename"]), int(row["ClassId"]))
+                    for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True)
+                ]
+
+        self._samples = samples
+        self.transform = transform
+        self.target_transform = target_transform
+
+    def __len__(self) -> int:
+        return len(self._samples)
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+
+        path, target = self._samples[index]
+        sample = PIL.Image.open(path).convert("RGB")
+
+        if self.transform is not None:
+            sample = self.transform(sample)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return sample, target
+
+    def _check_exists(self) -> bool:
+        return self._target_folder.is_dir()
+
+    def download(self) -> None:
+        if self._check_exists():
+            return
+
+        base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/"
+
+        if self._split == "train":
+            download_and_extract_archive(
+                f"{base_url}GTSRB-Training_fixed.zip",
+                download_root=str(self._base_folder),
+                md5="513f3c79a4c5141765e10e952eaa2478",
+            )
+        else:
+            download_and_extract_archive(
+                f"{base_url}GTSRB_Final_Test_Images.zip",
+                download_root=str(self._base_folder),
+                md5="c7e4e6327067d32654124b0fe9e82185",
+            )
+            download_and_extract_archive(
+                f"{base_url}GTSRB_Final_Test_GT.zip",
+                download_root=str(self._base_folder),
+                md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5",
+            )

+ 151 - 0
libs/vision_libs/datasets/hmdb51.py

@@ -0,0 +1,151 @@
+import glob
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+from torch import Tensor
+
+from .folder import find_classes, make_dataset
+from .video_utils import VideoClips
+from .vision import VisionDataset
+
+
+class HMDB51(VisionDataset):
+    """
+    `HMDB51 <https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/>`_
+    dataset.
+
+    HMDB51 is an action recognition video dataset.
+    This dataset consider every video as a collection of video clips of fixed size, specified
+    by ``frames_per_clip``, where the step in frames between each clip is given by
+    ``step_between_clips``.
+
+    To give an example, for 2 videos with 10 and 15 frames respectively, if ``frames_per_clip=5``
+    and ``step_between_clips=5``, the dataset size will be (2 + 3) = 5, where the first two
+    elements will come from video 1, and the next three elements from video 2.
+    Note that we drop clips which do not have exactly ``frames_per_clip`` elements, so not all
+    frames in a video might be present.
+
+    Internally, it uses a VideoClips object to handle clip creation.
+
+    Args:
+        root (string): Root directory of the HMDB51 Dataset.
+        annotation_path (str): Path to the folder containing the split files.
+        frames_per_clip (int): Number of frames in a clip.
+        step_between_clips (int): Number of frames between each clip.
+        fold (int, optional): Which fold to use. Should be between 1 and 3.
+        train (bool, optional): If ``True``, creates a dataset from the train split,
+            otherwise from the ``test`` split.
+        transform (callable, optional): A function/transform that takes in a TxHxWxC video
+            and returns a transformed version.
+        output_format (str, optional): The format of the output video tensors (before transforms).
+            Can be either "THWC" (default) or "TCHW".
+
+    Returns:
+        tuple: A 3-tuple with the following entries:
+
+            - video (Tensor[T, H, W, C] or Tensor[T, C, H, W]): The `T` video frames
+            - audio(Tensor[K, L]): the audio frames, where `K` is the number of channels
+              and `L` is the number of points
+            - label (int): class of the video clip
+    """
+
+    data_url = "https://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar"
+    splits = {
+        "url": "https://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar",
+        "md5": "15e67781e70dcfbdce2d7dbb9b3344b5",
+    }
+    TRAIN_TAG = 1
+    TEST_TAG = 2
+
+    def __init__(
+        self,
+        root: str,
+        annotation_path: str,
+        frames_per_clip: int,
+        step_between_clips: int = 1,
+        frame_rate: Optional[int] = None,
+        fold: int = 1,
+        train: bool = True,
+        transform: Optional[Callable] = None,
+        _precomputed_metadata: Optional[Dict[str, Any]] = None,
+        num_workers: int = 1,
+        _video_width: int = 0,
+        _video_height: int = 0,
+        _video_min_dimension: int = 0,
+        _audio_samples: int = 0,
+        output_format: str = "THWC",
+    ) -> None:
+        super().__init__(root)
+        if fold not in (1, 2, 3):
+            raise ValueError(f"fold should be between 1 and 3, got {fold}")
+
+        extensions = ("avi",)
+        self.classes, class_to_idx = find_classes(self.root)
+        self.samples = make_dataset(
+            self.root,
+            class_to_idx,
+            extensions,
+        )
+
+        video_paths = [path for (path, _) in self.samples]
+        video_clips = VideoClips(
+            video_paths,
+            frames_per_clip,
+            step_between_clips,
+            frame_rate,
+            _precomputed_metadata,
+            num_workers=num_workers,
+            _video_width=_video_width,
+            _video_height=_video_height,
+            _video_min_dimension=_video_min_dimension,
+            _audio_samples=_audio_samples,
+            output_format=output_format,
+        )
+        # we bookkeep the full version of video clips because we want to be able
+        # to return the metadata of full version rather than the subset version of
+        # video clips
+        self.full_video_clips = video_clips
+        self.fold = fold
+        self.train = train
+        self.indices = self._select_fold(video_paths, annotation_path, fold, train)
+        self.video_clips = video_clips.subset(self.indices)
+        self.transform = transform
+
+    @property
+    def metadata(self) -> Dict[str, Any]:
+        return self.full_video_clips.metadata
+
+    def _select_fold(self, video_list: List[str], annotations_dir: str, fold: int, train: bool) -> List[int]:
+        target_tag = self.TRAIN_TAG if train else self.TEST_TAG
+        split_pattern_name = f"*test_split{fold}.txt"
+        split_pattern_path = os.path.join(annotations_dir, split_pattern_name)
+        annotation_paths = glob.glob(split_pattern_path)
+        selected_files = set()
+        for filepath in annotation_paths:
+            with open(filepath) as fid:
+                lines = fid.readlines()
+            for line in lines:
+                video_filename, tag_string = line.split()
+                tag = int(tag_string)
+                if tag == target_tag:
+                    selected_files.add(video_filename)
+
+        indices = []
+        for video_index, video_path in enumerate(video_list):
+            if os.path.basename(video_path) in selected_files:
+                indices.append(video_index)
+
+        return indices
+
+    def __len__(self) -> int:
+        return self.video_clips.num_clips()
+
+    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]:
+        video, audio, _, video_idx = self.video_clips.get_clip(idx)
+        sample_index = self.indices[video_idx]
+        _, class_index = self.samples[sample_index]
+
+        if self.transform is not None:
+            video = self.transform(video)
+
+        return video, audio, class_index

+ 218 - 0
libs/vision_libs/datasets/imagenet.py

@@ -0,0 +1,218 @@
+import os
+import shutil
+import tempfile
+from contextlib import contextmanager
+from typing import Any, Dict, Iterator, List, Optional, Tuple
+
+import torch
+
+from .folder import ImageFolder
+from .utils import check_integrity, extract_archive, verify_str_arg
+
+ARCHIVE_META = {
+    "train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"),
+    "val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"),
+    "devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"),
+}
+
+META_FILE = "meta.bin"
+
+
+class ImageNet(ImageFolder):
+    """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.
+
+    .. note::
+        Before using this class, it is required to download ImageNet 2012 dataset from
+        `here <https://image-net.org/challenges/LSVRC/2012/2012-downloads.php>`_ and
+        place the files ``ILSVRC2012_devkit_t12.tar.gz`` and ``ILSVRC2012_img_train.tar``
+        or ``ILSVRC2012_img_val.tar`` based on ``split`` in the root directory.
+
+    Args:
+        root (string): Root directory of the ImageNet Dataset.
+        split (string, optional): The dataset split, supports ``train``, or ``val``.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        loader (callable, optional): A function to load an image given its path.
+
+     Attributes:
+        classes (list): List of the class name tuples.
+        class_to_idx (dict): Dict with items (class_name, class_index).
+        wnids (list): List of the WordNet IDs.
+        wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
+        imgs (list): List of (image path, class_index) tuples
+        targets (list): The class_index value for each image in the dataset
+    """
+
+    def __init__(self, root: str, split: str = "train", **kwargs: Any) -> None:
+        root = self.root = os.path.expanduser(root)
+        self.split = verify_str_arg(split, "split", ("train", "val"))
+
+        self.parse_archives()
+        wnid_to_classes = load_meta_file(self.root)[0]
+
+        super().__init__(self.split_folder, **kwargs)
+        self.root = root
+
+        self.wnids = self.classes
+        self.wnid_to_idx = self.class_to_idx
+        self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
+        self.class_to_idx = {cls: idx for idx, clss in enumerate(self.classes) for cls in clss}
+
+    def parse_archives(self) -> None:
+        if not check_integrity(os.path.join(self.root, META_FILE)):
+            parse_devkit_archive(self.root)
+
+        if not os.path.isdir(self.split_folder):
+            if self.split == "train":
+                parse_train_archive(self.root)
+            elif self.split == "val":
+                parse_val_archive(self.root)
+
+    @property
+    def split_folder(self) -> str:
+        return os.path.join(self.root, self.split)
+
+    def extra_repr(self) -> str:
+        return "Split: {split}".format(**self.__dict__)
+
+
+def load_meta_file(root: str, file: Optional[str] = None) -> Tuple[Dict[str, str], List[str]]:
+    if file is None:
+        file = META_FILE
+    file = os.path.join(root, file)
+
+    if check_integrity(file):
+        return torch.load(file)
+    else:
+        msg = (
+            "The meta file {} is not present in the root directory or is corrupted. "
+            "This file is automatically created by the ImageNet dataset."
+        )
+        raise RuntimeError(msg.format(file, root))
+
+
+def _verify_archive(root: str, file: str, md5: str) -> None:
+    if not check_integrity(os.path.join(root, file), md5):
+        msg = (
+            "The archive {} is not present in the root directory or is corrupted. "
+            "You need to download it externally and place it in {}."
+        )
+        raise RuntimeError(msg.format(file, root))
+
+
+def parse_devkit_archive(root: str, file: Optional[str] = None) -> None:
+    """Parse the devkit archive of the ImageNet2012 classification dataset and save
+    the meta information in a binary file.
+
+    Args:
+        root (str): Root directory containing the devkit archive
+        file (str, optional): Name of devkit archive. Defaults to
+            'ILSVRC2012_devkit_t12.tar.gz'
+    """
+    import scipy.io as sio
+
+    def parse_meta_mat(devkit_root: str) -> Tuple[Dict[int, str], Dict[str, Tuple[str, ...]]]:
+        metafile = os.path.join(devkit_root, "data", "meta.mat")
+        meta = sio.loadmat(metafile, squeeze_me=True)["synsets"]
+        nums_children = list(zip(*meta))[4]
+        meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]
+        idcs, wnids, classes = list(zip(*meta))[:3]
+        classes = [tuple(clss.split(", ")) for clss in classes]
+        idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
+        wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
+        return idx_to_wnid, wnid_to_classes
+
+    def parse_val_groundtruth_txt(devkit_root: str) -> List[int]:
+        file = os.path.join(devkit_root, "data", "ILSVRC2012_validation_ground_truth.txt")
+        with open(file) as txtfh:
+            val_idcs = txtfh.readlines()
+        return [int(val_idx) for val_idx in val_idcs]
+
+    @contextmanager
+    def get_tmp_dir() -> Iterator[str]:
+        tmp_dir = tempfile.mkdtemp()
+        try:
+            yield tmp_dir
+        finally:
+            shutil.rmtree(tmp_dir)
+
+    archive_meta = ARCHIVE_META["devkit"]
+    if file is None:
+        file = archive_meta[0]
+    md5 = archive_meta[1]
+
+    _verify_archive(root, file, md5)
+
+    with get_tmp_dir() as tmp_dir:
+        extract_archive(os.path.join(root, file), tmp_dir)
+
+        devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12")
+        idx_to_wnid, wnid_to_classes = parse_meta_mat(devkit_root)
+        val_idcs = parse_val_groundtruth_txt(devkit_root)
+        val_wnids = [idx_to_wnid[idx] for idx in val_idcs]
+
+        torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE))
+
+
+def parse_train_archive(root: str, file: Optional[str] = None, folder: str = "train") -> None:
+    """Parse the train images archive of the ImageNet2012 classification dataset and
+    prepare it for usage with the ImageNet dataset.
+
+    Args:
+        root (str): Root directory containing the train images archive
+        file (str, optional): Name of train images archive. Defaults to
+            'ILSVRC2012_img_train.tar'
+        folder (str, optional): Optional name for train images folder. Defaults to
+            'train'
+    """
+    archive_meta = ARCHIVE_META["train"]
+    if file is None:
+        file = archive_meta[0]
+    md5 = archive_meta[1]
+
+    _verify_archive(root, file, md5)
+
+    train_root = os.path.join(root, folder)
+    extract_archive(os.path.join(root, file), train_root)
+
+    archives = [os.path.join(train_root, archive) for archive in os.listdir(train_root)]
+    for archive in archives:
+        extract_archive(archive, os.path.splitext(archive)[0], remove_finished=True)
+
+
+def parse_val_archive(
+    root: str, file: Optional[str] = None, wnids: Optional[List[str]] = None, folder: str = "val"
+) -> None:
+    """Parse the validation images archive of the ImageNet2012 classification dataset
+    and prepare it for usage with the ImageNet dataset.
+
+    Args:
+        root (str): Root directory containing the validation images archive
+        file (str, optional): Name of validation images archive. Defaults to
+            'ILSVRC2012_img_val.tar'
+        wnids (list, optional): List of WordNet IDs of the validation images. If None
+            is given, the IDs are loaded from the meta file in the root directory
+        folder (str, optional): Optional name for validation images folder. Defaults to
+            'val'
+    """
+    archive_meta = ARCHIVE_META["val"]
+    if file is None:
+        file = archive_meta[0]
+    md5 = archive_meta[1]
+    if wnids is None:
+        wnids = load_meta_file(root)[1]
+
+    _verify_archive(root, file, md5)
+
+    val_root = os.path.join(root, folder)
+    extract_archive(os.path.join(root, file), val_root)
+
+    images = sorted(os.path.join(val_root, image) for image in os.listdir(val_root))
+
+    for wnid in set(wnids):
+        os.mkdir(os.path.join(val_root, wnid))
+
+    for wnid, img_file in zip(wnids, images):
+        shutil.move(img_file, os.path.join(val_root, wnid, os.path.basename(img_file)))

+ 104 - 0
libs/vision_libs/datasets/imagenette.py

@@ -0,0 +1,104 @@
+from pathlib import Path
+from typing import Any, Callable, Optional, Tuple
+
+from PIL import Image
+
+from .folder import find_classes, make_dataset
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class Imagenette(VisionDataset):
+    """`Imagenette <https://github.com/fastai/imagenette#imagenette-1>`_ image classification dataset.
+
+    Args:
+        root (string): Root directory of the Imagenette dataset.
+        split (string, optional): The dataset split. Supports ``"train"`` (default), and ``"val"``.
+        size (string, optional): The image size. Supports ``"full"`` (default), ``"320px"``, and ``"160px"``.
+        download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already
+            downloaded archives are not downloaded again.
+        transform (callable, optional): A function/transform that  takes in an PIL image and returns a transformed
+            version, e.g. ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+
+     Attributes:
+        classes (list): List of the class name tuples.
+        class_to_idx (dict): Dict with items (class name, class index).
+        wnids (list): List of the WordNet IDs.
+        wnid_to_idx (dict): Dict with items (WordNet ID, class index).
+    """
+
+    _ARCHIVES = {
+        "full": ("https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz", "fe2fc210e6bb7c5664d602c3cd71e612"),
+        "320px": ("https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz", "3df6f0d01a2c9592104656642f5e78a3"),
+        "160px": ("https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz", "e793b78cc4c9e9a4ccc0c1155377a412"),
+    }
+    _WNID_TO_CLASS = {
+        "n01440764": ("tench", "Tinca tinca"),
+        "n02102040": ("English springer", "English springer spaniel"),
+        "n02979186": ("cassette player",),
+        "n03000684": ("chain saw", "chainsaw"),
+        "n03028079": ("church", "church building"),
+        "n03394916": ("French horn", "horn"),
+        "n03417042": ("garbage truck", "dustcart"),
+        "n03425413": ("gas pump", "gasoline pump", "petrol pump", "island dispenser"),
+        "n03445777": ("golf ball",),
+        "n03888257": ("parachute", "chute"),
+    }
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        size: str = "full",
+        download=False,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        self._split = verify_str_arg(split, "split", ["train", "val"])
+        self._size = verify_str_arg(size, "size", ["full", "320px", "160px"])
+
+        self._url, self._md5 = self._ARCHIVES[self._size]
+        self._size_root = Path(self.root) / Path(self._url).stem
+        self._image_root = str(self._size_root / self._split)
+
+        if download:
+            self._download()
+        elif not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it.")
+
+        self.wnids, self.wnid_to_idx = find_classes(self._image_root)
+        self.classes = [self._WNID_TO_CLASS[wnid] for wnid in self.wnids]
+        self.class_to_idx = {
+            class_name: idx for wnid, idx in self.wnid_to_idx.items() for class_name in self._WNID_TO_CLASS[wnid]
+        }
+        self._samples = make_dataset(self._image_root, self.wnid_to_idx, extensions=".jpeg")
+
+    def _check_exists(self) -> bool:
+        return self._size_root.exists()
+
+    def _download(self):
+        if self._check_exists():
+            raise RuntimeError(
+                f"The directory {self._size_root} already exists. "
+                f"If you want to re-download or re-extract the images, delete the directory."
+            )
+
+        download_and_extract_archive(self._url, self.root, md5=self._md5)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        path, label = self._samples[idx]
+        image = Image.open(path).convert("RGB")
+
+        if self.transform is not None:
+            image = self.transform(image)
+
+        if self.target_transform is not None:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def __len__(self) -> int:
+        return len(self._samples)

+ 241 - 0
libs/vision_libs/datasets/inaturalist.py

@@ -0,0 +1,241 @@
+import os
+import os.path
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+CATEGORIES_2021 = ["kingdom", "phylum", "class", "order", "family", "genus"]
+
+DATASET_URLS = {
+    "2017": "https://ml-inat-competition-datasets.s3.amazonaws.com/2017/train_val_images.tar.gz",
+    "2018": "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz",
+    "2019": "https://ml-inat-competition-datasets.s3.amazonaws.com/2019/train_val2019.tar.gz",
+    "2021_train": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train.tar.gz",
+    "2021_train_mini": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train_mini.tar.gz",
+    "2021_valid": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/val.tar.gz",
+}
+
+DATASET_MD5 = {
+    "2017": "7c784ea5e424efaec655bd392f87301f",
+    "2018": "b1c6952ce38f31868cc50ea72d066cc3",
+    "2019": "c60a6e2962c9b8ccbd458d12c8582644",
+    "2021_train": "e0526d53c7f7b2e3167b2b43bb2690ed",
+    "2021_train_mini": "db6ed8330e634445efc8fec83ae81442",
+    "2021_valid": "f6f6e0e242e3d4c9569ba56400938afc",
+}
+
+
+class INaturalist(VisionDataset):
+    """`iNaturalist <https://github.com/visipedia/inat_comp>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where the image files are stored.
+            This class does not require/use annotation files.
+        version (string, optional): Which version of the dataset to download/use. One of
+            '2017', '2018', '2019', '2021_train', '2021_train_mini', '2021_valid'.
+            Default: `2021_train`.
+        target_type (string or list, optional): Type of target to use, for 2021 versions, one of:
+
+            - ``full``: the full category (species)
+            - ``kingdom``: e.g. "Animalia"
+            - ``phylum``: e.g. "Arthropoda"
+            - ``class``: e.g. "Insecta"
+            - ``order``: e.g. "Coleoptera"
+            - ``family``: e.g. "Cleridae"
+            - ``genus``: e.g. "Trichodes"
+
+            for 2017-2019 versions, one of:
+
+            - ``full``: the full (numeric) category
+            - ``super``: the super category, e.g. "Amphibians"
+
+            Can also be a list to output a tuple with all specified target types.
+            Defaults to ``full``.
+        transform (callable, optional): A function/transform that takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        version: str = "2021_train",
+        target_type: Union[List[str], str] = "full",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        self.version = verify_str_arg(version, "version", DATASET_URLS.keys())
+
+        super().__init__(os.path.join(root, version), transform=transform, target_transform=target_transform)
+
+        os.makedirs(root, exist_ok=True)
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        self.all_categories: List[str] = []
+
+        # map: category type -> name of category -> index
+        self.categories_index: Dict[str, Dict[str, int]] = {}
+
+        # list indexed by category id, containing mapping from category type -> index
+        self.categories_map: List[Dict[str, int]] = []
+
+        if not isinstance(target_type, list):
+            target_type = [target_type]
+        if self.version[:4] == "2021":
+            self.target_type = [verify_str_arg(t, "target_type", ("full", *CATEGORIES_2021)) for t in target_type]
+            self._init_2021()
+        else:
+            self.target_type = [verify_str_arg(t, "target_type", ("full", "super")) for t in target_type]
+            self._init_pre2021()
+
+        # index of all files: (full category id, filename)
+        self.index: List[Tuple[int, str]] = []
+
+        for dir_index, dir_name in enumerate(self.all_categories):
+            files = os.listdir(os.path.join(self.root, dir_name))
+            for fname in files:
+                self.index.append((dir_index, fname))
+
+    def _init_2021(self) -> None:
+        """Initialize based on 2021 layout"""
+
+        self.all_categories = sorted(os.listdir(self.root))
+
+        # map: category type -> name of category -> index
+        self.categories_index = {k: {} for k in CATEGORIES_2021}
+
+        for dir_index, dir_name in enumerate(self.all_categories):
+            pieces = dir_name.split("_")
+            if len(pieces) != 8:
+                raise RuntimeError(f"Unexpected category name {dir_name}, wrong number of pieces")
+            if pieces[0] != f"{dir_index:05d}":
+                raise RuntimeError(f"Unexpected category id {pieces[0]}, expecting {dir_index:05d}")
+            cat_map = {}
+            for cat, name in zip(CATEGORIES_2021, pieces[1:7]):
+                if name in self.categories_index[cat]:
+                    cat_id = self.categories_index[cat][name]
+                else:
+                    cat_id = len(self.categories_index[cat])
+                    self.categories_index[cat][name] = cat_id
+                cat_map[cat] = cat_id
+            self.categories_map.append(cat_map)
+
+    def _init_pre2021(self) -> None:
+        """Initialize based on 2017-2019 layout"""
+
+        # map: category type -> name of category -> index
+        self.categories_index = {"super": {}}
+
+        cat_index = 0
+        super_categories = sorted(os.listdir(self.root))
+        for sindex, scat in enumerate(super_categories):
+            self.categories_index["super"][scat] = sindex
+            subcategories = sorted(os.listdir(os.path.join(self.root, scat)))
+            for subcat in subcategories:
+                if self.version == "2017":
+                    # this version does not use ids as directory names
+                    subcat_i = cat_index
+                    cat_index += 1
+                else:
+                    try:
+                        subcat_i = int(subcat)
+                    except ValueError:
+                        raise RuntimeError(f"Unexpected non-numeric dir name: {subcat}")
+                if subcat_i >= len(self.categories_map):
+                    old_len = len(self.categories_map)
+                    self.categories_map.extend([{}] * (subcat_i - old_len + 1))
+                    self.all_categories.extend([""] * (subcat_i - old_len + 1))
+                if self.categories_map[subcat_i]:
+                    raise RuntimeError(f"Duplicate category {subcat}")
+                self.categories_map[subcat_i] = {"super": sindex}
+                self.all_categories[subcat_i] = os.path.join(scat, subcat)
+
+        # validate the dictionary
+        for cindex, c in enumerate(self.categories_map):
+            if not c:
+                raise RuntimeError(f"Missing category {cindex}")
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where the type of target specified by target_type.
+        """
+
+        cat_id, fname = self.index[index]
+        img = Image.open(os.path.join(self.root, self.all_categories[cat_id], fname))
+
+        target: Any = []
+        for t in self.target_type:
+            if t == "full":
+                target.append(cat_id)
+            else:
+                target.append(self.categories_map[cat_id][t])
+        target = tuple(target) if len(target) > 1 else target[0]
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.index)
+
+    def category_name(self, category_type: str, category_id: int) -> str:
+        """
+        Args:
+            category_type(str): one of "full", "kingdom", "phylum", "class", "order", "family", "genus" or "super"
+            category_id(int): an index (class id) from this category
+
+        Returns:
+            the name of the category
+        """
+        if category_type == "full":
+            return self.all_categories[category_id]
+        else:
+            if category_type not in self.categories_index:
+                raise ValueError(f"Invalid category type '{category_type}'")
+            else:
+                for name, id in self.categories_index[category_type].items():
+                    if id == category_id:
+                        return name
+                raise ValueError(f"Invalid category id {category_id} for {category_type}")
+
+    def _check_integrity(self) -> bool:
+        return os.path.exists(self.root) and len(os.listdir(self.root)) > 0
+
+    def download(self) -> None:
+        if self._check_integrity():
+            raise RuntimeError(
+                f"The directory {self.root} already exists. "
+                f"If you want to re-download or re-extract the images, delete the directory."
+            )
+
+        base_root = os.path.dirname(self.root)
+
+        download_and_extract_archive(
+            DATASET_URLS[self.version], base_root, filename=f"{self.version}.tgz", md5=DATASET_MD5[self.version]
+        )
+
+        orig_dir_name = os.path.join(base_root, os.path.basename(DATASET_URLS[self.version]).rstrip(".tar.gz"))
+        if not os.path.exists(orig_dir_name):
+            raise RuntimeError(f"Unable to find downloaded files at {orig_dir_name}")
+        os.rename(orig_dir_name, self.root)
+        print(f"Dataset version '{self.version}' has been downloaded and prepared for use")

+ 247 - 0
libs/vision_libs/datasets/kinetics.py

@@ -0,0 +1,247 @@
+import csv
+import os
+import time
+import urllib
+from functools import partial
+from multiprocessing import Pool
+from os import path
+from typing import Any, Callable, Dict, Optional, Tuple
+
+from torch import Tensor
+
+from .folder import find_classes, make_dataset
+from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg
+from .video_utils import VideoClips
+from .vision import VisionDataset
+
+
+def _dl_wrap(tarpath: str, videopath: str, line: str) -> None:
+    download_and_extract_archive(line, tarpath, videopath)
+
+
+class Kinetics(VisionDataset):
+    """`Generic Kinetics <https://www.deepmind.com/open-source/kinetics>`_
+    dataset.
+
+    Kinetics-400/600/700 are action recognition video datasets.
+    This dataset consider every video as a collection of video clips of fixed size, specified
+    by ``frames_per_clip``, where the step in frames between each clip is given by
+    ``step_between_clips``.
+
+    To give an example, for 2 videos with 10 and 15 frames respectively, if ``frames_per_clip=5``
+    and ``step_between_clips=5``, the dataset size will be (2 + 3) = 5, where the first two
+    elements will come from video 1, and the next three elements from video 2.
+    Note that we drop clips which do not have exactly ``frames_per_clip`` elements, so not all
+    frames in a video might be present.
+
+    Args:
+        root (string): Root directory of the Kinetics Dataset.
+            Directory should be structured as follows:
+            .. code::
+
+                root/
+                ├── split
+                │   ├──  class1
+                │   │   ├──  vid1.mp4
+                │   │   ├──  vid2.mp4
+                │   │   ├──  vid3.mp4
+                │   │   ├──  ...
+                │   ├──  class2
+                │   │   ├──   vidx.mp4
+                │   │    └── ...
+
+            Note: split is appended automatically using the split argument.
+        frames_per_clip (int): number of frames in a clip
+        num_classes (int): select between Kinetics-400 (default), Kinetics-600, and Kinetics-700
+        split (str): split of the dataset to consider; supports ``"train"`` (default) ``"val"`` ``"test"``
+        frame_rate (float): If omitted, interpolate different frame rate for each clip.
+        step_between_clips (int): number of frames between each clip
+        transform (callable, optional): A function/transform that  takes in a TxHxWxC video
+            and returns a transformed version.
+        download (bool): Download the official version of the dataset to root folder.
+        num_workers (int): Use multiple workers for VideoClips creation
+        num_download_workers (int): Use multiprocessing in order to speed up download.
+        output_format (str, optional): The format of the output video tensors (before transforms).
+            Can be either "THWC" or "TCHW" (default).
+            Note that in most other utils and datasets, the default is actually "THWC".
+
+    Returns:
+        tuple: A 3-tuple with the following entries:
+
+            - video (Tensor[T, C, H, W] or Tensor[T, H, W, C]): the `T` video frames in torch.uint8 tensor
+            - audio(Tensor[K, L]): the audio frames, where `K` is the number of channels
+              and `L` is the number of points in torch.float tensor
+            - label (int): class of the video clip
+
+    Raises:
+        RuntimeError: If ``download is True`` and the video archives are already extracted.
+    """
+
+    _TAR_URLS = {
+        "400": "https://s3.amazonaws.com/kinetics/400/{split}/k400_{split}_path.txt",
+        "600": "https://s3.amazonaws.com/kinetics/600/{split}/k600_{split}_path.txt",
+        "700": "https://s3.amazonaws.com/kinetics/700_2020/{split}/k700_2020_{split}_path.txt",
+    }
+    _ANNOTATION_URLS = {
+        "400": "https://s3.amazonaws.com/kinetics/400/annotations/{split}.csv",
+        "600": "https://s3.amazonaws.com/kinetics/600/annotations/{split}.csv",
+        "700": "https://s3.amazonaws.com/kinetics/700_2020/annotations/{split}.csv",
+    }
+
+    def __init__(
+        self,
+        root: str,
+        frames_per_clip: int,
+        num_classes: str = "400",
+        split: str = "train",
+        frame_rate: Optional[int] = None,
+        step_between_clips: int = 1,
+        transform: Optional[Callable] = None,
+        extensions: Tuple[str, ...] = ("avi", "mp4"),
+        download: bool = False,
+        num_download_workers: int = 1,
+        num_workers: int = 1,
+        _precomputed_metadata: Optional[Dict[str, Any]] = None,
+        _video_width: int = 0,
+        _video_height: int = 0,
+        _video_min_dimension: int = 0,
+        _audio_samples: int = 0,
+        _audio_channels: int = 0,
+        _legacy: bool = False,
+        output_format: str = "TCHW",
+    ) -> None:
+
+        # TODO: support test
+        self.num_classes = verify_str_arg(num_classes, arg="num_classes", valid_values=["400", "600", "700"])
+        self.extensions = extensions
+        self.num_download_workers = num_download_workers
+
+        self.root = root
+        self._legacy = _legacy
+
+        if _legacy:
+            print("Using legacy structure")
+            self.split_folder = root
+            self.split = "unknown"
+            output_format = "THWC"
+            if download:
+                raise ValueError("Cannot download the videos using legacy_structure.")
+        else:
+            self.split_folder = path.join(root, split)
+            self.split = verify_str_arg(split, arg="split", valid_values=["train", "val", "test"])
+
+        if download:
+            self.download_and_process_videos()
+
+        super().__init__(self.root)
+
+        self.classes, class_to_idx = find_classes(self.split_folder)
+        self.samples = make_dataset(self.split_folder, class_to_idx, extensions, is_valid_file=None)
+        video_list = [x[0] for x in self.samples]
+        self.video_clips = VideoClips(
+            video_list,
+            frames_per_clip,
+            step_between_clips,
+            frame_rate,
+            _precomputed_metadata,
+            num_workers=num_workers,
+            _video_width=_video_width,
+            _video_height=_video_height,
+            _video_min_dimension=_video_min_dimension,
+            _audio_samples=_audio_samples,
+            _audio_channels=_audio_channels,
+            output_format=output_format,
+        )
+        self.transform = transform
+
+    def download_and_process_videos(self) -> None:
+        """Downloads all the videos to the _root_ folder in the expected format."""
+        tic = time.time()
+        self._download_videos()
+        toc = time.time()
+        print("Elapsed time for downloading in mins ", (toc - tic) / 60)
+        self._make_ds_structure()
+        toc2 = time.time()
+        print("Elapsed time for processing in mins ", (toc2 - toc) / 60)
+        print("Elapsed time overall in mins ", (toc2 - tic) / 60)
+
+    def _download_videos(self) -> None:
+        """download tarballs containing the video to "tars" folder and extract them into the _split_ folder where
+        split is one of the official dataset splits.
+
+        Raises:
+            RuntimeError: if download folder exists, break to prevent downloading entire dataset again.
+        """
+        if path.exists(self.split_folder):
+            raise RuntimeError(
+                f"The directory {self.split_folder} already exists. "
+                f"If you want to re-download or re-extract the images, delete the directory."
+            )
+        tar_path = path.join(self.root, "tars")
+        file_list_path = path.join(self.root, "files")
+
+        split_url = self._TAR_URLS[self.num_classes].format(split=self.split)
+        split_url_filepath = path.join(file_list_path, path.basename(split_url))
+        if not check_integrity(split_url_filepath):
+            download_url(split_url, file_list_path)
+        with open(split_url_filepath) as file:
+            list_video_urls = [urllib.parse.quote(line, safe="/,:") for line in file.read().splitlines()]
+
+        if self.num_download_workers == 1:
+            for line in list_video_urls:
+                download_and_extract_archive(line, tar_path, self.split_folder)
+        else:
+            part = partial(_dl_wrap, tar_path, self.split_folder)
+            poolproc = Pool(self.num_download_workers)
+            poolproc.map(part, list_video_urls)
+
+    def _make_ds_structure(self) -> None:
+        """move videos from
+        split_folder/
+            ├── clip1.avi
+            ├── clip2.avi
+
+        to the correct format as described below:
+        split_folder/
+            ├── class1
+            │   ├── clip1.avi
+
+        """
+        annotation_path = path.join(self.root, "annotations")
+        if not check_integrity(path.join(annotation_path, f"{self.split}.csv")):
+            download_url(self._ANNOTATION_URLS[self.num_classes].format(split=self.split), annotation_path)
+        annotations = path.join(annotation_path, f"{self.split}.csv")
+
+        file_fmtstr = "{ytid}_{start:06}_{end:06}.mp4"
+        with open(annotations) as csvfile:
+            reader = csv.DictReader(csvfile)
+            for row in reader:
+                f = file_fmtstr.format(
+                    ytid=row["youtube_id"],
+                    start=int(row["time_start"]),
+                    end=int(row["time_end"]),
+                )
+                label = row["label"].replace(" ", "_").replace("'", "").replace("(", "").replace(")", "")
+                os.makedirs(path.join(self.split_folder, label), exist_ok=True)
+                downloaded_file = path.join(self.split_folder, f)
+                if path.isfile(downloaded_file):
+                    os.replace(
+                        downloaded_file,
+                        path.join(self.split_folder, label, f),
+                    )
+
+    @property
+    def metadata(self) -> Dict[str, Any]:
+        return self.video_clips.metadata
+
+    def __len__(self) -> int:
+        return self.video_clips.num_clips()
+
+    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]:
+        video, audio, info, video_idx = self.video_clips.get_clip(idx)
+        label = self.samples[video_idx][1]
+
+        if self.transform is not None:
+            video = self.transform(video)
+
+        return video, audio, label

+ 157 - 0
libs/vision_libs/datasets/kitti.py

@@ -0,0 +1,157 @@
+import csv
+import os
+from typing import Any, Callable, List, Optional, Tuple
+
+from PIL import Image
+
+from .utils import download_and_extract_archive
+from .vision import VisionDataset
+
+
+class Kitti(VisionDataset):
+    """`KITTI <http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark>`_ Dataset.
+
+    It corresponds to the "left color images of object" dataset, for object detection.
+
+    Args:
+        root (string): Root directory where images are downloaded to.
+            Expects the following folder structure if download=False:
+
+            .. code::
+
+                <root>
+                    └── Kitti
+                        └─ raw
+                            ├── training
+                            |   ├── image_2
+                            |   └── label_2
+                            └── testing
+                                └── image_2
+        train (bool, optional): Use ``train`` split if true, else ``test`` split.
+            Defaults to ``train``.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.PILToTensor``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        transforms (callable, optional): A function/transform that takes input sample
+            and its target as entry and returns a transformed version.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+    """
+
+    data_url = "https://s3.eu-central-1.amazonaws.com/avg-kitti/"
+    resources = [
+        "data_object_image_2.zip",
+        "data_object_label_2.zip",
+    ]
+    image_dir_name = "image_2"
+    labels_dir_name = "label_2"
+
+    def __init__(
+        self,
+        root: str,
+        train: bool = True,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        transforms: Optional[Callable] = None,
+        download: bool = False,
+    ):
+        super().__init__(
+            root,
+            transform=transform,
+            target_transform=target_transform,
+            transforms=transforms,
+        )
+        self.images = []
+        self.targets = []
+        self.train = train
+        self._location = "training" if self.train else "testing"
+
+        if download:
+            self.download()
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You may use download=True to download it.")
+
+        image_dir = os.path.join(self._raw_folder, self._location, self.image_dir_name)
+        if self.train:
+            labels_dir = os.path.join(self._raw_folder, self._location, self.labels_dir_name)
+        for img_file in os.listdir(image_dir):
+            self.images.append(os.path.join(image_dir, img_file))
+            if self.train:
+                self.targets.append(os.path.join(labels_dir, f"{img_file.split('.')[0]}.txt"))
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """Get item at a given index.
+
+        Args:
+            index (int): Index
+        Returns:
+            tuple: (image, target), where
+            target is a list of dictionaries with the following keys:
+
+            - type: str
+            - truncated: float
+            - occluded: int
+            - alpha: float
+            - bbox: float[4]
+            - dimensions: float[3]
+            - locations: float[3]
+            - rotation_y: float
+
+        """
+        image = Image.open(self.images[index])
+        target = self._parse_target(index) if self.train else None
+        if self.transforms:
+            image, target = self.transforms(image, target)
+        return image, target
+
+    def _parse_target(self, index: int) -> List:
+        target = []
+        with open(self.targets[index]) as inp:
+            content = csv.reader(inp, delimiter=" ")
+            for line in content:
+                target.append(
+                    {
+                        "type": line[0],
+                        "truncated": float(line[1]),
+                        "occluded": int(line[2]),
+                        "alpha": float(line[3]),
+                        "bbox": [float(x) for x in line[4:8]],
+                        "dimensions": [float(x) for x in line[8:11]],
+                        "location": [float(x) for x in line[11:14]],
+                        "rotation_y": float(line[14]),
+                    }
+                )
+        return target
+
+    def __len__(self) -> int:
+        return len(self.images)
+
+    @property
+    def _raw_folder(self) -> str:
+        return os.path.join(self.root, self.__class__.__name__, "raw")
+
+    def _check_exists(self) -> bool:
+        """Check if the data directory exists."""
+        folders = [self.image_dir_name]
+        if self.train:
+            folders.append(self.labels_dir_name)
+        return all(os.path.isdir(os.path.join(self._raw_folder, self._location, fname)) for fname in folders)
+
+    def download(self) -> None:
+        """Download the KITTI data if it doesn't exist already."""
+
+        if self._check_exists():
+            return
+
+        os.makedirs(self._raw_folder, exist_ok=True)
+
+        # download files
+        for fname in self.resources:
+            download_and_extract_archive(
+                url=f"{self.data_url}{fname}",
+                download_root=self._raw_folder,
+                filename=fname,
+            )

+ 255 - 0
libs/vision_libs/datasets/lfw.py

@@ -0,0 +1,255 @@
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+from PIL import Image
+
+from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg
+from .vision import VisionDataset
+
+
+class _LFW(VisionDataset):
+
+    base_folder = "lfw-py"
+    download_url_prefix = "http://vis-www.cs.umass.edu/lfw/"
+
+    file_dict = {
+        "original": ("lfw", "lfw.tgz", "a17d05bd522c52d84eca14327a23d494"),
+        "funneled": ("lfw_funneled", "lfw-funneled.tgz", "1b42dfed7d15c9b2dd63d5e5840c86ad"),
+        "deepfunneled": ("lfw-deepfunneled", "lfw-deepfunneled.tgz", "68331da3eb755a505a502b5aacb3c201"),
+    }
+    checksums = {
+        "pairs.txt": "9f1ba174e4e1c508ff7cdf10ac338a7d",
+        "pairsDevTest.txt": "5132f7440eb68cf58910c8a45a2ac10b",
+        "pairsDevTrain.txt": "4f27cbf15b2da4a85c1907eb4181ad21",
+        "people.txt": "450f0863dd89e85e73936a6d71a3474b",
+        "peopleDevTest.txt": "e4bf5be0a43b5dcd9dc5ccfcb8fb19c5",
+        "peopleDevTrain.txt": "54eaac34beb6d042ed3a7d883e247a21",
+        "lfw-names.txt": "a6d0a479bd074669f656265a6e693f6d",
+    }
+    annot_file = {"10fold": "", "train": "DevTrain", "test": "DevTest"}
+    names = "lfw-names.txt"
+
+    def __init__(
+        self,
+        root: str,
+        split: str,
+        image_set: str,
+        view: str,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(os.path.join(root, self.base_folder), transform=transform, target_transform=target_transform)
+
+        self.image_set = verify_str_arg(image_set.lower(), "image_set", self.file_dict.keys())
+        images_dir, self.filename, self.md5 = self.file_dict[self.image_set]
+
+        self.view = verify_str_arg(view.lower(), "view", ["people", "pairs"])
+        self.split = verify_str_arg(split.lower(), "split", ["10fold", "train", "test"])
+        self.labels_file = f"{self.view}{self.annot_file[self.split]}.txt"
+        self.data: List[Any] = []
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        self.images_dir = os.path.join(self.root, images_dir)
+
+    def _loader(self, path: str) -> Image.Image:
+        with open(path, "rb") as f:
+            img = Image.open(f)
+            return img.convert("RGB")
+
+    def _check_integrity(self) -> bool:
+        st1 = check_integrity(os.path.join(self.root, self.filename), self.md5)
+        st2 = check_integrity(os.path.join(self.root, self.labels_file), self.checksums[self.labels_file])
+        if not st1 or not st2:
+            return False
+        if self.view == "people":
+            return check_integrity(os.path.join(self.root, self.names), self.checksums[self.names])
+        return True
+
+    def download(self) -> None:
+        if self._check_integrity():
+            print("Files already downloaded and verified")
+            return
+        url = f"{self.download_url_prefix}{self.filename}"
+        download_and_extract_archive(url, self.root, filename=self.filename, md5=self.md5)
+        download_url(f"{self.download_url_prefix}{self.labels_file}", self.root)
+        if self.view == "people":
+            download_url(f"{self.download_url_prefix}{self.names}", self.root)
+
+    def _get_path(self, identity: str, no: Union[int, str]) -> str:
+        return os.path.join(self.images_dir, identity, f"{identity}_{int(no):04d}.jpg")
+
+    def extra_repr(self) -> str:
+        return f"Alignment: {self.image_set}\nSplit: {self.split}"
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+
+class LFWPeople(_LFW):
+    """`LFW <http://vis-www.cs.umass.edu/lfw/>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where directory
+            ``lfw-py`` exists or will be saved to if download is set to True.
+        split (string, optional): The image split to use. Can be one of ``train``, ``test``,
+            ``10fold`` (default).
+        image_set (str, optional): Type of image funneling to use, ``original``, ``funneled`` or
+            ``deepfunneled``. Defaults to ``funneled``.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomRotation``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+    """
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "10fold",
+        image_set: str = "funneled",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, split, image_set, "people", transform, target_transform, download)
+
+        self.class_to_idx = self._get_classes()
+        self.data, self.targets = self._get_people()
+
+    def _get_people(self) -> Tuple[List[str], List[int]]:
+        data, targets = [], []
+        with open(os.path.join(self.root, self.labels_file)) as f:
+            lines = f.readlines()
+            n_folds, s = (int(lines[0]), 1) if self.split == "10fold" else (1, 0)
+
+            for fold in range(n_folds):
+                n_lines = int(lines[s])
+                people = [line.strip().split("\t") for line in lines[s + 1 : s + n_lines + 1]]
+                s += n_lines + 1
+                for i, (identity, num_imgs) in enumerate(people):
+                    for num in range(1, int(num_imgs) + 1):
+                        img = self._get_path(identity, num)
+                        data.append(img)
+                        targets.append(self.class_to_idx[identity])
+
+        return data, targets
+
+    def _get_classes(self) -> Dict[str, int]:
+        with open(os.path.join(self.root, self.names)) as f:
+            lines = f.readlines()
+            names = [line.strip().split()[0] for line in lines]
+        class_to_idx = {name: i for i, name in enumerate(names)}
+        return class_to_idx
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: Tuple (image, target) where target is the identity of the person.
+        """
+        img = self._loader(self.data[index])
+        target = self.targets[index]
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def extra_repr(self) -> str:
+        return super().extra_repr() + f"\nClasses (identities): {len(self.class_to_idx)}"
+
+
+class LFWPairs(_LFW):
+    """`LFW <http://vis-www.cs.umass.edu/lfw/>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where directory
+            ``lfw-py`` exists or will be saved to if download is set to True.
+        split (string, optional): The image split to use. Can be one of ``train``, ``test``,
+            ``10fold``. Defaults to ``10fold``.
+        image_set (str, optional): Type of image funneling to use, ``original``, ``funneled`` or
+            ``deepfunneled``. Defaults to ``funneled``.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomRotation``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+    """
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "10fold",
+        image_set: str = "funneled",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, split, image_set, "pairs", transform, target_transform, download)
+
+        self.pair_names, self.data, self.targets = self._get_pairs(self.images_dir)
+
+    def _get_pairs(self, images_dir: str) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]], List[int]]:
+        pair_names, data, targets = [], [], []
+        with open(os.path.join(self.root, self.labels_file)) as f:
+            lines = f.readlines()
+            if self.split == "10fold":
+                n_folds, n_pairs = lines[0].split("\t")
+                n_folds, n_pairs = int(n_folds), int(n_pairs)
+            else:
+                n_folds, n_pairs = 1, int(lines[0])
+            s = 1
+
+            for fold in range(n_folds):
+                matched_pairs = [line.strip().split("\t") for line in lines[s : s + n_pairs]]
+                unmatched_pairs = [line.strip().split("\t") for line in lines[s + n_pairs : s + (2 * n_pairs)]]
+                s += 2 * n_pairs
+                for pair in matched_pairs:
+                    img1, img2, same = self._get_path(pair[0], pair[1]), self._get_path(pair[0], pair[2]), 1
+                    pair_names.append((pair[0], pair[0]))
+                    data.append((img1, img2))
+                    targets.append(same)
+                for pair in unmatched_pairs:
+                    img1, img2, same = self._get_path(pair[0], pair[1]), self._get_path(pair[2], pair[3]), 0
+                    pair_names.append((pair[0], pair[2]))
+                    data.append((img1, img2))
+                    targets.append(same)
+
+        return pair_names, data, targets
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any, int]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image1, image2, target) where target is `0` for different indentities and `1` for same identities.
+        """
+        img1, img2 = self.data[index]
+        img1, img2 = self._loader(img1), self._loader(img2)
+        target = self.targets[index]
+
+        if self.transform is not None:
+            img1, img2 = self.transform(img1), self.transform(img2)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img1, img2, target

+ 167 - 0
libs/vision_libs/datasets/lsun.py

@@ -0,0 +1,167 @@
+import io
+import os.path
+import pickle
+import string
+from collections.abc import Iterable
+from typing import Any, Callable, cast, List, Optional, Tuple, Union
+
+from PIL import Image
+
+from .utils import iterable_to_str, verify_str_arg
+from .vision import VisionDataset
+
+
+class LSUNClass(VisionDataset):
+    def __init__(
+        self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None
+    ) -> None:
+        import lmdb
+
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False)
+        with self.env.begin(write=False) as txn:
+            self.length = txn.stat()["entries"]
+        cache_file = "_cache_" + "".join(c for c in root if c in string.ascii_letters)
+        if os.path.isfile(cache_file):
+            self.keys = pickle.load(open(cache_file, "rb"))
+        else:
+            with self.env.begin(write=False) as txn:
+                self.keys = [key for key in txn.cursor().iternext(keys=True, values=False)]
+            pickle.dump(self.keys, open(cache_file, "wb"))
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        img, target = None, None
+        env = self.env
+        with env.begin(write=False) as txn:
+            imgbuf = txn.get(self.keys[index])
+
+        buf = io.BytesIO()
+        buf.write(imgbuf)
+        buf.seek(0)
+        img = Image.open(buf).convert("RGB")
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return self.length
+
+
+class LSUN(VisionDataset):
+    """`LSUN <https://www.yf.io/p/lsun>`_ dataset.
+
+    You will need to install the ``lmdb`` package to use this dataset: run
+    ``pip install lmdb``
+
+    Args:
+        root (string): Root directory for the database files.
+        classes (string or list): One of {'train', 'val', 'test'} or a list of
+            categories to load. e,g. ['bedroom_train', 'church_outdoor_train'].
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        classes: Union[str, List[str]] = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self.classes = self._verify_classes(classes)
+
+        # for each class, create an LSUNClassDataset
+        self.dbs = []
+        for c in self.classes:
+            self.dbs.append(LSUNClass(root=os.path.join(root, f"{c}_lmdb"), transform=transform))
+
+        self.indices = []
+        count = 0
+        for db in self.dbs:
+            count += len(db)
+            self.indices.append(count)
+
+        self.length = count
+
+    def _verify_classes(self, classes: Union[str, List[str]]) -> List[str]:
+        categories = [
+            "bedroom",
+            "bridge",
+            "church_outdoor",
+            "classroom",
+            "conference_room",
+            "dining_room",
+            "kitchen",
+            "living_room",
+            "restaurant",
+            "tower",
+        ]
+        dset_opts = ["train", "val", "test"]
+
+        try:
+            classes = cast(str, classes)
+            verify_str_arg(classes, "classes", dset_opts)
+            if classes == "test":
+                classes = [classes]
+            else:
+                classes = [c + "_" + classes for c in categories]
+        except ValueError:
+            if not isinstance(classes, Iterable):
+                msg = "Expected type str or Iterable for argument classes, but got type {}."
+                raise ValueError(msg.format(type(classes)))
+
+            classes = list(classes)
+            msg_fmtstr_type = "Expected type str for elements in argument classes, but got type {}."
+            for c in classes:
+                verify_str_arg(c, custom_msg=msg_fmtstr_type.format(type(c)))
+                c_short = c.split("_")
+                category, dset_opt = "_".join(c_short[:-1]), c_short[-1]
+
+                msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
+                msg = msg_fmtstr.format(category, "LSUN class", iterable_to_str(categories))
+                verify_str_arg(category, valid_values=categories, custom_msg=msg)
+
+                msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
+                verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)
+
+        return classes
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: Tuple (image, target) where target is the index of the target category.
+        """
+        target = 0
+        sub = 0
+        for ind in self.indices:
+            if index < ind:
+                break
+            target += 1
+            sub = ind
+
+        db = self.dbs[target]
+        index = index - sub
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        img, _ = db[index]
+        return img, target
+
+    def __len__(self) -> int:
+        return self.length
+
+    def extra_repr(self) -> str:
+        return "Classes: {classes}".format(**self.__dict__)

+ 558 - 0
libs/vision_libs/datasets/mnist.py

@@ -0,0 +1,558 @@
+import codecs
+import os
+import os.path
+import shutil
+import string
+import sys
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Tuple
+from urllib.error import URLError
+
+import numpy as np
+import torch
+from PIL import Image
+
+from .utils import _flip_byte_order, check_integrity, download_and_extract_archive, extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class MNIST(VisionDataset):
+    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte``
+            and  ``MNIST/raw/t10k-images-idx3-ubyte`` exist.
+        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
+            otherwise from ``t10k-images-idx3-ubyte``.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+
+    mirrors = [
+        "http://yann.lecun.com/exdb/mnist/",
+        "https://ossci-datasets.s3.amazonaws.com/mnist/",
+    ]
+
+    resources = [
+        ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
+        ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
+        ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
+        ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
+    ]
+
+    training_file = "training.pt"
+    test_file = "test.pt"
+    classes = [
+        "0 - zero",
+        "1 - one",
+        "2 - two",
+        "3 - three",
+        "4 - four",
+        "5 - five",
+        "6 - six",
+        "7 - seven",
+        "8 - eight",
+        "9 - nine",
+    ]
+
+    @property
+    def train_labels(self):
+        warnings.warn("train_labels has been renamed targets")
+        return self.targets
+
+    @property
+    def test_labels(self):
+        warnings.warn("test_labels has been renamed targets")
+        return self.targets
+
+    @property
+    def train_data(self):
+        warnings.warn("train_data has been renamed data")
+        return self.data
+
+    @property
+    def test_data(self):
+        warnings.warn("test_data has been renamed data")
+        return self.data
+
+    def __init__(
+        self,
+        root: str,
+        train: bool = True,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self.train = train  # training set or test set
+
+        if self._check_legacy_exist():
+            self.data, self.targets = self._load_legacy_data()
+            return
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        self.data, self.targets = self._load_data()
+
+    def _check_legacy_exist(self):
+        processed_folder_exists = os.path.exists(self.processed_folder)
+        if not processed_folder_exists:
+            return False
+
+        return all(
+            check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file)
+        )
+
+    def _load_legacy_data(self):
+        # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
+        # directly.
+        data_file = self.training_file if self.train else self.test_file
+        return torch.load(os.path.join(self.processed_folder, data_file))
+
+    def _load_data(self):
+        image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
+        data = read_image_file(os.path.join(self.raw_folder, image_file))
+
+        label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
+        targets = read_label_file(os.path.join(self.raw_folder, label_file))
+
+        return data, targets
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target class.
+        """
+        img, target = self.data[index], int(self.targets[index])
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(img.numpy(), mode="L")
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    @property
+    def raw_folder(self) -> str:
+        return os.path.join(self.root, self.__class__.__name__, "raw")
+
+    @property
+    def processed_folder(self) -> str:
+        return os.path.join(self.root, self.__class__.__name__, "processed")
+
+    @property
+    def class_to_idx(self) -> Dict[str, int]:
+        return {_class: i for i, _class in enumerate(self.classes)}
+
+    def _check_exists(self) -> bool:
+        return all(
+            check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]))
+            for url, _ in self.resources
+        )
+
+    def download(self) -> None:
+        """Download the MNIST data if it doesn't exist already."""
+
+        if self._check_exists():
+            return
+
+        os.makedirs(self.raw_folder, exist_ok=True)
+
+        # download files
+        for filename, md5 in self.resources:
+            for mirror in self.mirrors:
+                url = f"{mirror}{filename}"
+                try:
+                    print(f"Downloading {url}")
+                    download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
+                except URLError as error:
+                    print(f"Failed to download (trying next):\n{error}")
+                    continue
+                finally:
+                    print()
+                break
+            else:
+                raise RuntimeError(f"Error downloading {filename}")
+
+    def extra_repr(self) -> str:
+        split = "Train" if self.train is True else "Test"
+        return f"Split: {split}"
+
+
+class FashionMNIST(MNIST):
+    """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where ``FashionMNIST/raw/train-images-idx3-ubyte``
+            and  ``FashionMNIST/raw/t10k-images-idx3-ubyte`` exist.
+        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
+            otherwise from ``t10k-images-idx3-ubyte``.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+
+    mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"]
+
+    resources = [
+        ("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"),
+        ("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"),
+        ("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"),
+        ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310"),
+    ]
+    classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
+
+
+class KMNIST(MNIST):
+    """`Kuzushiji-MNIST <https://github.com/rois-codh/kmnist>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where ``KMNIST/raw/train-images-idx3-ubyte``
+            and  ``KMNIST/raw/t10k-images-idx3-ubyte`` exist.
+        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
+            otherwise from ``t10k-images-idx3-ubyte``.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+
+    mirrors = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"]
+
+    resources = [
+        ("train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"),
+        ("train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"),
+        ("t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"),
+        ("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134"),
+    ]
+    classes = ["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"]
+
+
+class EMNIST(MNIST):
+    """`EMNIST <https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where ``EMNIST/raw/train-images-idx3-ubyte``
+            and  ``EMNIST/raw/t10k-images-idx3-ubyte`` exist.
+        split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
+            ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
+            which one to use.
+        train (bool, optional): If True, creates dataset from ``training.pt``,
+            otherwise from ``test.pt``.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+
+    url = "https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip"
+    md5 = "58c8d27c78d21e728a6bc7b3cc06412e"
+    splits = ("byclass", "bymerge", "balanced", "letters", "digits", "mnist")
+    # Merged Classes assumes Same structure for both uppercase and lowercase version
+    _merged_classes = {"c", "i", "j", "k", "l", "m", "o", "p", "s", "u", "v", "w", "x", "y", "z"}
+    _all_classes = set(string.digits + string.ascii_letters)
+    classes_split_dict = {
+        "byclass": sorted(list(_all_classes)),
+        "bymerge": sorted(list(_all_classes - _merged_classes)),
+        "balanced": sorted(list(_all_classes - _merged_classes)),
+        "letters": ["N/A"] + list(string.ascii_lowercase),
+        "digits": list(string.digits),
+        "mnist": list(string.digits),
+    }
+
+    def __init__(self, root: str, split: str, **kwargs: Any) -> None:
+        self.split = verify_str_arg(split, "split", self.splits)
+        self.training_file = self._training_file(split)
+        self.test_file = self._test_file(split)
+        super().__init__(root, **kwargs)
+        self.classes = self.classes_split_dict[self.split]
+
+    @staticmethod
+    def _training_file(split) -> str:
+        return f"training_{split}.pt"
+
+    @staticmethod
+    def _test_file(split) -> str:
+        return f"test_{split}.pt"
+
+    @property
+    def _file_prefix(self) -> str:
+        return f"emnist-{self.split}-{'train' if self.train else 'test'}"
+
+    @property
+    def images_file(self) -> str:
+        return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte")
+
+    @property
+    def labels_file(self) -> str:
+        return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte")
+
+    def _load_data(self):
+        return read_image_file(self.images_file), read_label_file(self.labels_file)
+
+    def _check_exists(self) -> bool:
+        return all(check_integrity(file) for file in (self.images_file, self.labels_file))
+
+    def download(self) -> None:
+        """Download the EMNIST data if it doesn't exist already."""
+
+        if self._check_exists():
+            return
+
+        os.makedirs(self.raw_folder, exist_ok=True)
+
+        download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5)
+        gzip_folder = os.path.join(self.raw_folder, "gzip")
+        for gzip_file in os.listdir(gzip_folder):
+            if gzip_file.endswith(".gz"):
+                extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder)
+        shutil.rmtree(gzip_folder)
+
+
+class QMNIST(MNIST):
+    """`QMNIST <https://github.com/facebookresearch/qmnist>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset whose ``raw``
+            subdir contains binary files of the datasets.
+        what (string,optional): Can be 'train', 'test', 'test10k',
+            'test50k', or 'nist' for respectively the mnist compatible
+            training set, the 60k qmnist testing set, the 10k qmnist
+            examples that match the mnist testing set, the 50k
+            remaining qmnist testing examples, or all the nist
+            digits. The default is to select 'train' or 'test'
+            according to the compatibility argument 'train'.
+        compat (bool,optional): A boolean that says whether the target
+            for each example is class number (for compatibility with
+            the MNIST dataloader) or a torch vector containing the
+            full qmnist information. Default=True.
+        download (bool, optional): If True, downloads the dataset from
+            the internet and puts it in root directory. If dataset is
+            already downloaded, it is not downloaded again.
+        transform (callable, optional): A function/transform that
+            takes in an PIL image and returns a transformed
+            version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform
+            that takes in the target and transforms it.
+        train (bool,optional,compatibility): When argument 'what' is
+            not specified, this boolean decides whether to load the
+            training set or the testing set.  Default: True.
+    """
+
+    subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"}
+    resources: Dict[str, List[Tuple[str, str]]] = {  # type: ignore[assignment]
+        "train": [
+            (
+                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz",
+                "ed72d4157d28c017586c42bc6afe6370",
+            ),
+            (
+                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz",
+                "0058f8dd561b90ffdd0f734c6a30e5e4",
+            ),
+        ],
+        "test": [
+            (
+                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz",
+                "1394631089c404de565df7b7aeaf9412",
+            ),
+            (
+                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz",
+                "5b5b05890a5e13444e108efe57b788aa",
+            ),
+        ],
+        "nist": [
+            (
+                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz",
+                "7f124b3b8ab81486c9d8c2749c17f834",
+            ),
+            (
+                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz",
+                "5ed0e788978e45d4a8bd4b7caec3d79d",
+            ),
+        ],
+    }
+    classes = [
+        "0 - zero",
+        "1 - one",
+        "2 - two",
+        "3 - three",
+        "4 - four",
+        "5 - five",
+        "6 - six",
+        "7 - seven",
+        "8 - eight",
+        "9 - nine",
+    ]
+
+    def __init__(
+        self, root: str, what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any
+    ) -> None:
+        if what is None:
+            what = "train" if train else "test"
+        self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
+        self.compat = compat
+        self.data_file = what + ".pt"
+        self.training_file = self.data_file
+        self.test_file = self.data_file
+        super().__init__(root, train, **kwargs)
+
+    @property
+    def images_file(self) -> str:
+        (url, _), _ = self.resources[self.subsets[self.what]]
+        return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
+
+    @property
+    def labels_file(self) -> str:
+        _, (url, _) = self.resources[self.subsets[self.what]]
+        return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
+
+    def _check_exists(self) -> bool:
+        return all(check_integrity(file) for file in (self.images_file, self.labels_file))
+
+    def _load_data(self):
+        data = read_sn3_pascalvincent_tensor(self.images_file)
+        if data.dtype != torch.uint8:
+            raise TypeError(f"data should be of dtype torch.uint8 instead of {data.dtype}")
+        if data.ndimension() != 3:
+            raise ValueError("data should have 3 dimensions instead of {data.ndimension()}")
+
+        targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
+        if targets.ndimension() != 2:
+            raise ValueError(f"targets should have 2 dimensions instead of {targets.ndimension()}")
+
+        if self.what == "test10k":
+            data = data[0:10000, :, :].clone()
+            targets = targets[0:10000, :].clone()
+        elif self.what == "test50k":
+            data = data[10000:, :, :].clone()
+            targets = targets[10000:, :].clone()
+
+        return data, targets
+
+    def download(self) -> None:
+        """Download the QMNIST data if it doesn't exist already.
+        Note that we only download what has been asked for (argument 'what').
+        """
+        if self._check_exists():
+            return
+
+        os.makedirs(self.raw_folder, exist_ok=True)
+        split = self.resources[self.subsets[self.what]]
+
+        for url, md5 in split:
+            download_and_extract_archive(url, self.raw_folder, md5=md5)
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        # redefined to handle the compat flag
+        img, target = self.data[index], self.targets[index]
+        img = Image.fromarray(img.numpy(), mode="L")
+        if self.transform is not None:
+            img = self.transform(img)
+        if self.compat:
+            target = int(target[0])
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+        return img, target
+
+    def extra_repr(self) -> str:
+        return f"Split: {self.what}"
+
+
+def get_int(b: bytes) -> int:
+    return int(codecs.encode(b, "hex"), 16)
+
+
+SN3_PASCALVINCENT_TYPEMAP = {
+    8: torch.uint8,
+    9: torch.int8,
+    11: torch.int16,
+    12: torch.int32,
+    13: torch.float32,
+    14: torch.float64,
+}
+
+
+def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
+    """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
+    Argument may be a filename, compressed filename, or file object.
+    """
+    # read
+    with open(path, "rb") as f:
+        data = f.read()
+
+    # parse
+    if sys.byteorder == "little":
+        magic = get_int(data[0:4])
+        nd = magic % 256
+        ty = magic // 256
+    else:
+        nd = get_int(data[0:1])
+        ty = get_int(data[1:2]) + get_int(data[2:3]) * 256 + get_int(data[3:4]) * 256 * 256
+
+    assert 1 <= nd <= 3
+    assert 8 <= ty <= 14
+    torch_type = SN3_PASCALVINCENT_TYPEMAP[ty]
+    s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]
+
+    if sys.byteorder == "big":
+        for i in range(len(s)):
+            s[i] = int.from_bytes(s[i].to_bytes(4, byteorder="little"), byteorder="big", signed=False)
+
+    parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1)))
+
+    # The MNIST format uses the big endian byte order, while `torch.frombuffer` uses whatever the system uses. In case
+    # that is little endian and the dtype has more than one byte, we need to flip them.
+    if sys.byteorder == "little" and parsed.element_size() > 1:
+        parsed = _flip_byte_order(parsed)
+
+    assert parsed.shape[0] == np.prod(s) or not strict
+    return parsed.view(*s)
+
+
+def read_label_file(path: str) -> torch.Tensor:
+    x = read_sn3_pascalvincent_tensor(path, strict=False)
+    if x.dtype != torch.uint8:
+        raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
+    if x.ndimension() != 1:
+        raise ValueError(f"x should have 1 dimension instead of {x.ndimension()}")
+    return x.long()
+
+
+def read_image_file(path: str) -> torch.Tensor:
+    x = read_sn3_pascalvincent_tensor(path, strict=False)
+    if x.dtype != torch.uint8:
+        raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
+    if x.ndimension() != 3:
+        raise ValueError(f"x should have 3 dimension instead of {x.ndimension()}")
+    return x

+ 93 - 0
libs/vision_libs/datasets/moving_mnist.py

@@ -0,0 +1,93 @@
+import os.path
+from typing import Callable, Optional
+
+import numpy as np
+import torch
+from torchvision.datasets.utils import download_url, verify_str_arg
+from torchvision.datasets.vision import VisionDataset
+
+
+class MovingMNIST(VisionDataset):
+    """`MovingMNIST <http://www.cs.toronto.edu/~nitish/unsupervised_video/>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where ``MovingMNIST/mnist_test_seq.npy`` exists.
+        split (string, optional): The dataset split, supports ``None`` (default), ``"train"`` and ``"test"``.
+            If ``split=None``, the full data is returned.
+        split_ratio (int, optional): The split ratio of number of frames. If ``split="train"``, the first split
+            frames ``data[:, :split_ratio]`` is returned. If ``split="test"``, the last split frames ``data[:, split_ratio:]``
+            is returned. If ``split=None``, this parameter is ignored and the all frames data is returned.
+        transform (callable, optional): A function/transform that takes in an torch Tensor
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    _URL = "http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy"
+
+    def __init__(
+        self,
+        root: str,
+        split: Optional[str] = None,
+        split_ratio: int = 10,
+        download: bool = False,
+        transform: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transform=transform)
+
+        self._base_folder = os.path.join(self.root, self.__class__.__name__)
+        self._filename = self._URL.split("/")[-1]
+
+        if split is not None:
+            verify_str_arg(split, "split", ("train", "test"))
+        self.split = split
+
+        if not isinstance(split_ratio, int):
+            raise TypeError(f"`split_ratio` should be an integer, but got {type(split_ratio)}")
+        elif not (1 <= split_ratio <= 19):
+            raise ValueError(f"`split_ratio` should be `1 <= split_ratio <= 19`, but got {split_ratio} instead.")
+        self.split_ratio = split_ratio
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it.")
+
+        data = torch.from_numpy(np.load(os.path.join(self._base_folder, self._filename)))
+        if self.split == "train":
+            data = data[: self.split_ratio]
+        elif self.split == "test":
+            data = data[self.split_ratio :]
+        self.data = data.transpose(0, 1).unsqueeze(2).contiguous()
+
+    def __getitem__(self, idx: int) -> torch.Tensor:
+        """
+        Args:
+            index (int): Index
+        Returns:
+            torch.Tensor: Video frames (torch Tensor[T, C, H, W]). The `T` is the number of frames.
+        """
+        data = self.data[idx]
+        if self.transform is not None:
+            data = self.transform(data)
+
+        return data
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def _check_exists(self) -> bool:
+        return os.path.exists(os.path.join(self._base_folder, self._filename))
+
+    def download(self) -> None:
+        if self._check_exists():
+            return
+
+        download_url(
+            url=self._URL,
+            root=self._base_folder,
+            filename=self._filename,
+            md5="be083ec986bfe91a449d63653c411eb2",
+        )

+ 102 - 0
libs/vision_libs/datasets/omniglot.py

@@ -0,0 +1,102 @@
+from os.path import join
+from typing import Any, Callable, List, Optional, Tuple
+
+from PIL import Image
+
+from .utils import check_integrity, download_and_extract_archive, list_dir, list_files
+from .vision import VisionDataset
+
+
+class Omniglot(VisionDataset):
+    """`Omniglot <https://github.com/brendenlake/omniglot>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where directory
+            ``omniglot-py`` exists.
+        background (bool, optional): If True, creates dataset from the "background" set, otherwise
+            creates from the "evaluation" set. This terminology is defined by the authors.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset zip files from the internet and
+            puts it in root directory. If the zip files are already downloaded, they are not
+            downloaded again.
+    """
+
+    folder = "omniglot-py"
+    download_url_prefix = "https://raw.githubusercontent.com/brendenlake/omniglot/master/python"
+    zips_md5 = {
+        "images_background": "68d2efa1b9178cc56df9314c21c6e718",
+        "images_evaluation": "6b91aef0f799c5bb55b94e3f2daec811",
+    }
+
+    def __init__(
+        self,
+        root: str,
+        background: bool = True,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(join(root, self.folder), transform=transform, target_transform=target_transform)
+        self.background = background
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        self.target_folder = join(self.root, self._get_target_folder())
+        self._alphabets = list_dir(self.target_folder)
+        self._characters: List[str] = sum(
+            ([join(a, c) for c in list_dir(join(self.target_folder, a))] for a in self._alphabets), []
+        )
+        self._character_images = [
+            [(image, idx) for image in list_files(join(self.target_folder, character), ".png")]
+            for idx, character in enumerate(self._characters)
+        ]
+        self._flat_character_images: List[Tuple[str, int]] = sum(self._character_images, [])
+
+    def __len__(self) -> int:
+        return len(self._flat_character_images)
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target character class.
+        """
+        image_name, character_class = self._flat_character_images[index]
+        image_path = join(self.target_folder, self._characters[character_class], image_name)
+        image = Image.open(image_path, mode="r").convert("L")
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.target_transform:
+            character_class = self.target_transform(character_class)
+
+        return image, character_class
+
+    def _check_integrity(self) -> bool:
+        zip_filename = self._get_target_folder()
+        if not check_integrity(join(self.root, zip_filename + ".zip"), self.zips_md5[zip_filename]):
+            return False
+        return True
+
+    def download(self) -> None:
+        if self._check_integrity():
+            print("Files already downloaded and verified")
+            return
+
+        filename = self._get_target_folder()
+        zip_filename = filename + ".zip"
+        url = self.download_url_prefix + "/" + zip_filename
+        download_and_extract_archive(url, self.root, filename=zip_filename, md5=self.zips_md5[filename])
+
+    def _get_target_folder(self) -> str:
+        return "images_background" if self.background else "images_evaluation"

+ 125 - 0
libs/vision_libs/datasets/oxford_iiit_pet.py

@@ -0,0 +1,125 @@
+import os
+import os.path
+import pathlib
+from typing import Any, Callable, Optional, Sequence, Tuple, Union
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class OxfordIIITPet(VisionDataset):
+    """`Oxford-IIIT Pet Dataset   <https://www.robots.ox.ac.uk/~vgg/data/pets/>`_.
+
+    Args:
+        root (string): Root directory of the dataset.
+        split (string, optional): The dataset split, supports ``"trainval"`` (default) or ``"test"``.
+        target_types (string, sequence of strings, optional): Types of target to use. Can be ``category`` (default) or
+            ``segmentation``. Can also be a list to output a tuple with all specified target types. The types represent:
+
+                - ``category`` (int): Label for one of the 37 pet categories.
+                - ``segmentation`` (PIL image): Segmentation trimap of the image.
+
+            If empty, ``None`` will be returned as target.
+
+        transform (callable, optional): A function/transform that  takes in a PIL image and returns a transformed
+            version. E.g, ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and puts it into
+            ``root/oxford-iiit-pet``. If dataset is already downloaded, it is not downloaded again.
+    """
+
+    _RESOURCES = (
+        ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", "5c4f3ee8e5d25df40f4fd59a7f44e54c"),
+        ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz", "95a8c909bbe2e81eed6a22bccdf3f68f"),
+    )
+    _VALID_TARGET_TYPES = ("category", "segmentation")
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "trainval",
+        target_types: Union[Sequence[str], str] = "category",
+        transforms: Optional[Callable] = None,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ):
+        self._split = verify_str_arg(split, "split", ("trainval", "test"))
+        if isinstance(target_types, str):
+            target_types = [target_types]
+        self._target_types = [
+            verify_str_arg(target_type, "target_types", self._VALID_TARGET_TYPES) for target_type in target_types
+        ]
+
+        super().__init__(root, transforms=transforms, transform=transform, target_transform=target_transform)
+        self._base_folder = pathlib.Path(self.root) / "oxford-iiit-pet"
+        self._images_folder = self._base_folder / "images"
+        self._anns_folder = self._base_folder / "annotations"
+        self._segs_folder = self._anns_folder / "trimaps"
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        image_ids = []
+        self._labels = []
+        with open(self._anns_folder / f"{self._split}.txt") as file:
+            for line in file:
+                image_id, label, *_ = line.strip().split()
+                image_ids.append(image_id)
+                self._labels.append(int(label) - 1)
+
+        self.classes = [
+            " ".join(part.title() for part in raw_cls.split("_"))
+            for raw_cls, _ in sorted(
+                {(image_id.rsplit("_", 1)[0], label) for image_id, label in zip(image_ids, self._labels)},
+                key=lambda image_id_and_label: image_id_and_label[1],
+            )
+        ]
+        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
+
+        self._images = [self._images_folder / f"{image_id}.jpg" for image_id in image_ids]
+        self._segs = [self._segs_folder / f"{image_id}.png" for image_id in image_ids]
+
+    def __len__(self) -> int:
+        return len(self._images)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        image = Image.open(self._images[idx]).convert("RGB")
+
+        target: Any = []
+        for target_type in self._target_types:
+            if target_type == "category":
+                target.append(self._labels[idx])
+            else:  # target_type == "segmentation"
+                target.append(Image.open(self._segs[idx]))
+
+        if not target:
+            target = None
+        elif len(target) == 1:
+            target = target[0]
+        else:
+            target = tuple(target)
+
+        if self.transforms:
+            image, target = self.transforms(image, target)
+
+        return image, target
+
+    def _check_exists(self) -> bool:
+        for folder in (self._images_folder, self._anns_folder):
+            if not (os.path.exists(folder) and os.path.isdir(folder)):
+                return False
+        else:
+            return True
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+
+        for url, md5 in self._RESOURCES:
+            download_and_extract_archive(url, download_root=str(self._base_folder), md5=md5)

+ 134 - 0
libs/vision_libs/datasets/pcam.py

@@ -0,0 +1,134 @@
+import pathlib
+from typing import Any, Callable, Optional, Tuple
+
+from PIL import Image
+
+from .utils import _decompress, download_file_from_google_drive, verify_str_arg
+from .vision import VisionDataset
+
+
+class PCAM(VisionDataset):
+    """`PCAM Dataset   <https://github.com/basveeling/pcam>`_.
+
+    The PatchCamelyon dataset is a binary classification dataset with 327,680
+    color images (96px x 96px), extracted from histopathologic scans of lymph node
+    sections. Each image is annotated with a binary label indicating presence of
+    metastatic tissue.
+
+    This dataset requires the ``h5py`` package which you can install with ``pip install h5py``.
+
+    Args:
+         root (string): Root directory of the dataset.
+         split (string, optional): The dataset split, supports ``"train"`` (default), ``"test"`` or ``"val"``.
+         transform (callable, optional): A function/transform that  takes in a PIL image and returns a transformed
+             version. E.g, ``transforms.RandomCrop``.
+         target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+         download (bool, optional): If True, downloads the dataset from the internet and puts it into ``root/pcam``. If
+             dataset is already downloaded, it is not downloaded again.
+
+             .. warning::
+
+                To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
+    """
+
+    _FILES = {
+        "train": {
+            "images": (
+                "camelyonpatch_level_2_split_train_x.h5",  # Data file name
+                "1Ka0XfEMiwgCYPdTI-vv6eUElOBnKFKQ2",  # Google Drive ID
+                "1571f514728f59376b705fc836ff4b63",  # md5 hash
+            ),
+            "targets": (
+                "camelyonpatch_level_2_split_train_y.h5",
+                "1269yhu3pZDP8UYFQs-NYs3FPwuK-nGSG",
+                "35c2d7259d906cfc8143347bb8e05be7",
+            ),
+        },
+        "test": {
+            "images": (
+                "camelyonpatch_level_2_split_test_x.h5",
+                "1qV65ZqZvWzuIVthK8eVDhIwrbnsJdbg_",
+                "d8c2d60d490dbd479f8199bdfa0cf6ec",
+            ),
+            "targets": (
+                "camelyonpatch_level_2_split_test_y.h5",
+                "17BHrSrwWKjYsOgTMmoqrIjDy6Fa2o_gP",
+                "60a7035772fbdb7f34eb86d4420cf66a",
+            ),
+        },
+        "val": {
+            "images": (
+                "camelyonpatch_level_2_split_valid_x.h5",
+                "1hgshYGWK8V-eGRy8LToWJJgDU_rXWVJ3",
+                "d5b63470df7cfa627aeec8b9dc0c066e",
+            ),
+            "targets": (
+                "camelyonpatch_level_2_split_valid_y.h5",
+                "1bH8ZRbhSVAhScTS0p9-ZzGnX91cHT3uO",
+                "2b85f58b927af9964a4c15b8f7e8f179",
+            ),
+        },
+    }
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ):
+        try:
+            import h5py
+
+            self.h5py = h5py
+        except ImportError:
+            raise RuntimeError(
+                "h5py is not found. This dataset needs to have h5py installed: please run pip install h5py"
+            )
+
+        self._split = verify_str_arg(split, "split", ("train", "test", "val"))
+
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self._base_folder = pathlib.Path(self.root) / "pcam"
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+    def __len__(self) -> int:
+        images_file = self._FILES[self._split]["images"][0]
+        with self.h5py.File(self._base_folder / images_file) as images_data:
+            return images_data["x"].shape[0]
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        images_file = self._FILES[self._split]["images"][0]
+        with self.h5py.File(self._base_folder / images_file) as images_data:
+            image = Image.fromarray(images_data["x"][idx]).convert("RGB")
+
+        targets_file = self._FILES[self._split]["targets"][0]
+        with self.h5py.File(self._base_folder / targets_file) as targets_data:
+            target = int(targets_data["y"][idx, 0, 0, 0])  # shape is [num_images, 1, 1, 1]
+
+        if self.transform:
+            image = self.transform(image)
+        if self.target_transform:
+            target = self.target_transform(target)
+
+        return image, target
+
+    def _check_exists(self) -> bool:
+        images_file = self._FILES[self._split]["images"][0]
+        targets_file = self._FILES[self._split]["targets"][0]
+        return all(self._base_folder.joinpath(h5_file).exists() for h5_file in (images_file, targets_file))
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+
+        for file_name, file_id, md5 in self._FILES[self._split].values():
+            archive_name = file_name + ".gz"
+            download_file_from_google_drive(file_id, str(self._base_folder), filename=archive_name, md5=md5)
+            _decompress(str(self._base_folder / archive_name))

+ 228 - 0
libs/vision_libs/datasets/phototour.py

@@ -0,0 +1,228 @@
+import os
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from PIL import Image
+
+from .utils import download_url
+from .vision import VisionDataset
+
+
+class PhotoTour(VisionDataset):
+    """`Multi-view Stereo Correspondence <http://matthewalunbrown.com/patchdata/patchdata.html>`_ Dataset.
+
+    .. note::
+
+        We only provide the newer version of the dataset, since the authors state that it
+
+            is more suitable for training descriptors based on difference of Gaussian, or Harris corners, as the
+            patches are centred on real interest point detections, rather than being projections of 3D points as is the
+            case in the old dataset.
+
+        The original dataset is available under http://phototour.cs.washington.edu/patches/default.htm.
+
+
+    Args:
+        root (string): Root directory where images are.
+        name (string): Name of the dataset to load.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+    """
+
+    urls = {
+        "notredame_harris": [
+            "http://matthewalunbrown.com/patchdata/notredame_harris.zip",
+            "notredame_harris.zip",
+            "69f8c90f78e171349abdf0307afefe4d",
+        ],
+        "yosemite_harris": [
+            "http://matthewalunbrown.com/patchdata/yosemite_harris.zip",
+            "yosemite_harris.zip",
+            "a73253d1c6fbd3ba2613c45065c00d46",
+        ],
+        "liberty_harris": [
+            "http://matthewalunbrown.com/patchdata/liberty_harris.zip",
+            "liberty_harris.zip",
+            "c731fcfb3abb4091110d0ae8c7ba182c",
+        ],
+        "notredame": [
+            "http://icvl.ee.ic.ac.uk/vbalnt/notredame.zip",
+            "notredame.zip",
+            "509eda8535847b8c0a90bbb210c83484",
+        ],
+        "yosemite": ["http://icvl.ee.ic.ac.uk/vbalnt/yosemite.zip", "yosemite.zip", "533b2e8eb7ede31be40abc317b2fd4f0"],
+        "liberty": ["http://icvl.ee.ic.ac.uk/vbalnt/liberty.zip", "liberty.zip", "fdd9152f138ea5ef2091746689176414"],
+    }
+    means = {
+        "notredame": 0.4854,
+        "yosemite": 0.4844,
+        "liberty": 0.4437,
+        "notredame_harris": 0.4854,
+        "yosemite_harris": 0.4844,
+        "liberty_harris": 0.4437,
+    }
+    stds = {
+        "notredame": 0.1864,
+        "yosemite": 0.1818,
+        "liberty": 0.2019,
+        "notredame_harris": 0.1864,
+        "yosemite_harris": 0.1818,
+        "liberty_harris": 0.2019,
+    }
+    lens = {
+        "notredame": 468159,
+        "yosemite": 633587,
+        "liberty": 450092,
+        "liberty_harris": 379587,
+        "yosemite_harris": 450912,
+        "notredame_harris": 325295,
+    }
+    image_ext = "bmp"
+    info_file = "info.txt"
+    matches_files = "m50_100000_100000_0.txt"
+
+    def __init__(
+        self, root: str, name: str, train: bool = True, transform: Optional[Callable] = None, download: bool = False
+    ) -> None:
+        super().__init__(root, transform=transform)
+        self.name = name
+        self.data_dir = os.path.join(self.root, name)
+        self.data_down = os.path.join(self.root, f"{name}.zip")
+        self.data_file = os.path.join(self.root, f"{name}.pt")
+
+        self.train = train
+        self.mean = self.means[name]
+        self.std = self.stds[name]
+
+        if download:
+            self.download()
+
+        if not self._check_datafile_exists():
+            self.cache()
+
+        # load the serialized data
+        self.data, self.labels, self.matches = torch.load(self.data_file)
+
+    def __getitem__(self, index: int) -> Union[torch.Tensor, Tuple[Any, Any, torch.Tensor]]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (data1, data2, matches)
+        """
+        if self.train:
+            data = self.data[index]
+            if self.transform is not None:
+                data = self.transform(data)
+            return data
+        m = self.matches[index]
+        data1, data2 = self.data[m[0]], self.data[m[1]]
+        if self.transform is not None:
+            data1 = self.transform(data1)
+            data2 = self.transform(data2)
+        return data1, data2, m[2]
+
+    def __len__(self) -> int:
+        return len(self.data if self.train else self.matches)
+
+    def _check_datafile_exists(self) -> bool:
+        return os.path.exists(self.data_file)
+
+    def _check_downloaded(self) -> bool:
+        return os.path.exists(self.data_dir)
+
+    def download(self) -> None:
+        if self._check_datafile_exists():
+            print(f"# Found cached data {self.data_file}")
+            return
+
+        if not self._check_downloaded():
+            # download files
+            url = self.urls[self.name][0]
+            filename = self.urls[self.name][1]
+            md5 = self.urls[self.name][2]
+            fpath = os.path.join(self.root, filename)
+
+            download_url(url, self.root, filename, md5)
+
+            print(f"# Extracting data {self.data_down}\n")
+
+            import zipfile
+
+            with zipfile.ZipFile(fpath, "r") as z:
+                z.extractall(self.data_dir)
+
+            os.unlink(fpath)
+
+    def cache(self) -> None:
+        # process and save as torch files
+        print(f"# Caching data {self.data_file}")
+
+        dataset = (
+            read_image_file(self.data_dir, self.image_ext, self.lens[self.name]),
+            read_info_file(self.data_dir, self.info_file),
+            read_matches_files(self.data_dir, self.matches_files),
+        )
+
+        with open(self.data_file, "wb") as f:
+            torch.save(dataset, f)
+
+    def extra_repr(self) -> str:
+        split = "Train" if self.train is True else "Test"
+        return f"Split: {split}"
+
+
+def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor:
+    """Return a Tensor containing the patches"""
+
+    def PIL2array(_img: Image.Image) -> np.ndarray:
+        """Convert PIL image type to numpy 2D array"""
+        return np.array(_img.getdata(), dtype=np.uint8).reshape(64, 64)
+
+    def find_files(_data_dir: str, _image_ext: str) -> List[str]:
+        """Return a list with the file names of the images containing the patches"""
+        files = []
+        # find those files with the specified extension
+        for file_dir in os.listdir(_data_dir):
+            if file_dir.endswith(_image_ext):
+                files.append(os.path.join(_data_dir, file_dir))
+        return sorted(files)  # sort files in ascend order to keep relations
+
+    patches = []
+    list_files = find_files(data_dir, image_ext)
+
+    for fpath in list_files:
+        img = Image.open(fpath)
+        for y in range(0, img.height, 64):
+            for x in range(0, img.width, 64):
+                patch = img.crop((x, y, x + 64, y + 64))
+                patches.append(PIL2array(patch))
+    return torch.ByteTensor(np.array(patches[:n]))
+
+
+def read_info_file(data_dir: str, info_file: str) -> torch.Tensor:
+    """Return a Tensor containing the list of labels
+    Read the file and keep only the ID of the 3D point.
+    """
+    with open(os.path.join(data_dir, info_file)) as f:
+        labels = [int(line.split()[0]) for line in f]
+    return torch.LongTensor(labels)
+
+
+def read_matches_files(data_dir: str, matches_file: str) -> torch.Tensor:
+    """Return a Tensor containing the ground truth matches
+    Read the file and keep only 3D point ID.
+    Matches are represented with a 1, non matches with a 0.
+    """
+    matches = []
+    with open(os.path.join(data_dir, matches_file)) as f:
+        for line in f:
+            line_split = line.split()
+            matches.append([int(line_split[0]), int(line_split[3]), int(line_split[1] == line_split[4])])
+    return torch.LongTensor(matches)

+ 170 - 0
libs/vision_libs/datasets/places365.py

@@ -0,0 +1,170 @@
+import os
+from os import path
+from typing import Any, Callable, Dict, List, Optional, Tuple
+from urllib.parse import urljoin
+
+from .folder import default_loader
+from .utils import check_integrity, download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class Places365(VisionDataset):
+    r"""`Places365 <http://places2.csail.mit.edu/index.html>`_ classification dataset.
+
+    Args:
+        root (string): Root directory of the Places365 dataset.
+        split (string, optional): The dataset split. Can be one of ``train-standard`` (default), ``train-challenge``,
+            ``val``.
+        small (bool, optional): If ``True``, uses the small images, i.e. resized to 256 x 256 pixels, instead of the
+            high resolution ones.
+        download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already
+            downloaded archives are not downloaded again.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        loader (callable, optional): A function to load an image given its path.
+
+     Attributes:
+        classes (list): List of the class names.
+        class_to_idx (dict): Dict with items (class_name, class_index).
+        imgs (list): List of (image path, class_index) tuples
+        targets (list): The class_index value for each image in the dataset
+
+    Raises:
+        RuntimeError: If ``download is False`` and the meta files, i.e. the devkit, are not present or corrupted.
+        RuntimeError: If ``download is True`` and the image archive is already extracted.
+    """
+    _SPLITS = ("train-standard", "train-challenge", "val")
+    _BASE_URL = "http://data.csail.mit.edu/places/places365/"
+    # {variant: (archive, md5)}
+    _DEVKIT_META = {
+        "standard": ("filelist_places365-standard.tar", "35a0585fee1fa656440f3ab298f8479c"),
+        "challenge": ("filelist_places365-challenge.tar", "70a8307e459c3de41690a7c76c931734"),
+    }
+    # (file, md5)
+    _CATEGORIES_META = ("categories_places365.txt", "06c963b85866bd0649f97cb43dd16673")
+    # {split: (file, md5)}
+    _FILE_LIST_META = {
+        "train-standard": ("places365_train_standard.txt", "30f37515461640559006b8329efbed1a"),
+        "train-challenge": ("places365_train_challenge.txt", "b2931dc997b8c33c27e7329c073a6b57"),
+        "val": ("places365_val.txt", "e9f2fd57bfd9d07630173f4e8708e4b1"),
+    }
+    # {(split, small): (file, md5)}
+    _IMAGES_META = {
+        ("train-standard", False): ("train_large_places365standard.tar", "67e186b496a84c929568076ed01a8aa1"),
+        ("train-challenge", False): ("train_large_places365challenge.tar", "605f18e68e510c82b958664ea134545f"),
+        ("val", False): ("val_large.tar", "9b71c4993ad89d2d8bcbdc4aef38042f"),
+        ("train-standard", True): ("train_256_places365standard.tar", "53ca1c756c3d1e7809517cc47c5561c5"),
+        ("train-challenge", True): ("train_256_places365challenge.tar", "741915038a5e3471ec7332404dfb64ef"),
+        ("val", True): ("val_256.tar", "e27b17d8d44f4af9a78502beb927f808"),
+    }
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train-standard",
+        small: bool = False,
+        download: bool = False,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        loader: Callable[[str], Any] = default_loader,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        self.split = self._verify_split(split)
+        self.small = small
+        self.loader = loader
+
+        self.classes, self.class_to_idx = self.load_categories(download)
+        self.imgs, self.targets = self.load_file_list(download)
+
+        if download:
+            self.download_images()
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        file, target = self.imgs[index]
+        image = self.loader(file)
+
+        if self.transforms is not None:
+            image, target = self.transforms(image, target)
+
+        return image, target
+
+    def __len__(self) -> int:
+        return len(self.imgs)
+
+    @property
+    def variant(self) -> str:
+        return "challenge" if "challenge" in self.split else "standard"
+
+    @property
+    def images_dir(self) -> str:
+        size = "256" if self.small else "large"
+        if self.split.startswith("train"):
+            dir = f"data_{size}_{self.variant}"
+        else:
+            dir = f"{self.split}_{size}"
+        return path.join(self.root, dir)
+
+    def load_categories(self, download: bool = True) -> Tuple[List[str], Dict[str, int]]:
+        def process(line: str) -> Tuple[str, int]:
+            cls, idx = line.split()
+            return cls, int(idx)
+
+        file, md5 = self._CATEGORIES_META
+        file = path.join(self.root, file)
+        if not self._check_integrity(file, md5, download):
+            self.download_devkit()
+
+        with open(file) as fh:
+            class_to_idx = dict(process(line) for line in fh)
+
+        return sorted(class_to_idx.keys()), class_to_idx
+
+    def load_file_list(self, download: bool = True) -> Tuple[List[Tuple[str, int]], List[int]]:
+        def process(line: str, sep="/") -> Tuple[str, int]:
+            image, idx = line.split()
+            return path.join(self.images_dir, image.lstrip(sep).replace(sep, os.sep)), int(idx)
+
+        file, md5 = self._FILE_LIST_META[self.split]
+        file = path.join(self.root, file)
+        if not self._check_integrity(file, md5, download):
+            self.download_devkit()
+
+        with open(file) as fh:
+            images = [process(line) for line in fh]
+
+        _, targets = zip(*images)
+        return images, list(targets)
+
+    def download_devkit(self) -> None:
+        file, md5 = self._DEVKIT_META[self.variant]
+        download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, md5=md5)
+
+    def download_images(self) -> None:
+        if path.exists(self.images_dir):
+            raise RuntimeError(
+                f"The directory {self.images_dir} already exists. If you want to re-download or re-extract the images, "
+                f"delete the directory."
+            )
+
+        file, md5 = self._IMAGES_META[(self.split, self.small)]
+        download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, md5=md5)
+
+        if self.split.startswith("train"):
+            os.rename(self.images_dir.rsplit("_", 1)[0], self.images_dir)
+
+    def extra_repr(self) -> str:
+        return "\n".join(("Split: {split}", "Small: {small}")).format(**self.__dict__)
+
+    def _verify_split(self, split: str) -> str:
+        return verify_str_arg(split, "split", self._SPLITS)
+
+    def _check_integrity(self, file: str, md5: str, download: bool) -> bool:
+        integrity = check_integrity(file, md5=md5)
+        if not integrity and not download:
+            raise RuntimeError(
+                f"The file {file} does not exist or is corrupted. You can set download=True to download it."
+            )
+        return integrity

+ 86 - 0
libs/vision_libs/datasets/rendered_sst2.py

@@ -0,0 +1,86 @@
+from pathlib import Path
+from typing import Any, Callable, Optional, Tuple
+
+import PIL.Image
+
+from .folder import make_dataset
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class RenderedSST2(VisionDataset):
+    """`The Rendered SST2 Dataset <https://github.com/openai/CLIP/blob/main/data/rendered-sst2.md>`_.
+
+    Rendered SST2 is an image classification dataset used to evaluate the models capability on optical
+    character recognition. This dataset was generated by rendering sentences in the Standford Sentiment
+    Treebank v2 dataset.
+
+    This dataset contains two classes (positive and negative) and is divided in three splits: a  train
+    split containing 6920 images (3610 positive and 3310 negative), a validation split containing 872 images
+    (444 positive and 428 negative), and a test split containing 1821 images (909 positive and 912 negative).
+
+    Args:
+        root (string): Root directory of the dataset.
+        split (string, optional): The dataset split, supports ``"train"`` (default), `"val"` and ``"test"``.
+        transform (callable, optional): A function/transform that  takes in an PIL image and returns a transformed
+            version. E.g, ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again. Default is False.
+    """
+
+    _URL = "https://openaipublic.azureedge.net/clip/data/rendered-sst2.tgz"
+    _MD5 = "2384d08e9dcfa4bd55b324e610496ee5"
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
+        self._split_to_folder = {"train": "train", "val": "valid", "test": "test"}
+        self._base_folder = Path(self.root) / "rendered-sst2"
+        self.classes = ["negative", "positive"]
+        self.class_to_idx = {"negative": 0, "positive": 1}
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        self._samples = make_dataset(str(self._base_folder / self._split_to_folder[self._split]), extensions=("png",))
+
+    def __len__(self) -> int:
+        return len(self._samples)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        image_file, label = self._samples[idx]
+        image = PIL.Image.open(image_file).convert("RGB")
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.target_transform:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def extra_repr(self) -> str:
+        return f"split={self._split}"
+
+    def _check_exists(self) -> bool:
+        for class_label in set(self.classes):
+            if not (self._base_folder / self._split_to_folder[self._split] / class_label).is_dir():
+                return False
+        return True
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+        download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)

+ 3 - 0
libs/vision_libs/datasets/samplers/__init__.py

@@ -0,0 +1,3 @@
+from .clip_sampler import DistributedSampler, RandomClipSampler, UniformClipSampler
+
+__all__ = ("DistributedSampler", "UniformClipSampler", "RandomClipSampler")

+ 172 - 0
libs/vision_libs/datasets/samplers/clip_sampler.py

@@ -0,0 +1,172 @@
+import math
+from typing import cast, Iterator, List, Optional, Sized, Union
+
+import torch
+import torch.distributed as dist
+from torch.utils.data import Sampler
+from torchvision.datasets.video_utils import VideoClips
+
+
+class DistributedSampler(Sampler):
+    """
+    Extension of DistributedSampler, as discussed in
+    https://github.com/pytorch/pytorch/issues/23430
+
+    Example:
+        dataset: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
+        num_replicas: 4
+        shuffle: False
+
+    when group_size = 1
+            RANK    |  shard_dataset
+            =========================
+            rank_0  |  [0, 4, 8, 12]
+            rank_1  |  [1, 5, 9, 13]
+            rank_2  |  [2, 6, 10, 0]
+            rank_3  |  [3, 7, 11, 1]
+
+    when group_size = 2
+
+            RANK    |  shard_dataset
+            =========================
+            rank_0  |  [0, 1, 8, 9]
+            rank_1  |  [2, 3, 10, 11]
+            rank_2  |  [4, 5, 12, 13]
+            rank_3  |  [6, 7, 0, 1]
+
+    """
+
+    def __init__(
+        self,
+        dataset: Sized,
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        shuffle: bool = False,
+        group_size: int = 1,
+    ) -> None:
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank()
+        if len(dataset) % group_size != 0:
+            raise ValueError(
+                f"dataset length must be a multiplier of group size dataset length: {len(dataset)}, group size: {group_size}"
+            )
+        self.dataset = dataset
+        self.group_size = group_size
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        dataset_group_length = len(dataset) // group_size
+        self.num_group_samples = int(math.ceil(dataset_group_length * 1.0 / self.num_replicas))
+        self.num_samples = self.num_group_samples * group_size
+        self.total_size = self.num_samples * self.num_replicas
+        self.shuffle = shuffle
+
+    def __iter__(self) -> Iterator[int]:
+        # deterministically shuffle based on epoch
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+        indices: Union[torch.Tensor, List[int]]
+        if self.shuffle:
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = list(range(len(self.dataset)))
+
+        # add extra samples to make it evenly divisible
+        indices += indices[: (self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+
+        total_group_size = self.total_size // self.group_size
+        indices = torch.reshape(torch.LongTensor(indices), (total_group_size, self.group_size))
+
+        # subsample
+        indices = indices[self.rank : total_group_size : self.num_replicas, :]
+        indices = torch.reshape(indices, (-1,)).tolist()
+        assert len(indices) == self.num_samples
+
+        if isinstance(self.dataset, Sampler):
+            orig_indices = list(iter(self.dataset))
+            indices = [orig_indices[i] for i in indices]
+
+        return iter(indices)
+
+    def __len__(self) -> int:
+        return self.num_samples
+
+    def set_epoch(self, epoch: int) -> None:
+        self.epoch = epoch
+
+
+class UniformClipSampler(Sampler):
+    """
+    Sample `num_video_clips_per_video` clips for each video, equally spaced.
+    When number of unique clips in the video is fewer than num_video_clips_per_video,
+    repeat the clips until `num_video_clips_per_video` clips are collected
+
+    Args:
+        video_clips (VideoClips): video clips to sample from
+        num_clips_per_video (int): number of clips to be sampled per video
+    """
+
+    def __init__(self, video_clips: VideoClips, num_clips_per_video: int) -> None:
+        if not isinstance(video_clips, VideoClips):
+            raise TypeError(f"Expected video_clips to be an instance of VideoClips, got {type(video_clips)}")
+        self.video_clips = video_clips
+        self.num_clips_per_video = num_clips_per_video
+
+    def __iter__(self) -> Iterator[int]:
+        idxs = []
+        s = 0
+        # select num_clips_per_video for each video, uniformly spaced
+        for c in self.video_clips.clips:
+            length = len(c)
+            if length == 0:
+                # corner case where video decoding fails
+                continue
+
+            sampled = torch.linspace(s, s + length - 1, steps=self.num_clips_per_video).floor().to(torch.int64)
+            s += length
+            idxs.append(sampled)
+        return iter(cast(List[int], torch.cat(idxs).tolist()))
+
+    def __len__(self) -> int:
+        return sum(self.num_clips_per_video for c in self.video_clips.clips if len(c) > 0)
+
+
+class RandomClipSampler(Sampler):
+    """
+    Samples at most `max_video_clips_per_video` clips for each video randomly
+
+    Args:
+        video_clips (VideoClips): video clips to sample from
+        max_clips_per_video (int): maximum number of clips to be sampled per video
+    """
+
+    def __init__(self, video_clips: VideoClips, max_clips_per_video: int) -> None:
+        if not isinstance(video_clips, VideoClips):
+            raise TypeError(f"Expected video_clips to be an instance of VideoClips, got {type(video_clips)}")
+        self.video_clips = video_clips
+        self.max_clips_per_video = max_clips_per_video
+
+    def __iter__(self) -> Iterator[int]:
+        idxs = []
+        s = 0
+        # select at most max_clips_per_video for each video, randomly
+        for c in self.video_clips.clips:
+            length = len(c)
+            size = min(length, self.max_clips_per_video)
+            sampled = torch.randperm(length)[:size] + s
+            s += length
+            idxs.append(sampled)
+        idxs_ = torch.cat(idxs)
+        # shuffle all clips randomly
+        perm = torch.randperm(len(idxs_))
+        return iter(idxs_[perm].tolist())
+
+    def __len__(self) -> int:
+        return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)

+ 123 - 0
libs/vision_libs/datasets/sbd.py

@@ -0,0 +1,123 @@
+import os
+import shutil
+from typing import Any, Callable, Optional, Tuple
+
+import numpy as np
+from PIL import Image
+
+from .utils import download_and_extract_archive, download_url, verify_str_arg
+from .vision import VisionDataset
+
+
+class SBDataset(VisionDataset):
+    """`Semantic Boundaries Dataset <http://home.bharathh.info/pubs/codes/SBD/download.html>`_
+
+    The SBD currently contains annotations from 11355 images taken from the PASCAL VOC 2011 dataset.
+
+    .. note ::
+
+        Please note that the train and val splits included with this dataset are different from
+        the splits in the PASCAL VOC dataset. In particular some "train" images might be part of
+        VOC2012 val.
+        If you are interested in testing on VOC 2012 val, then use `image_set='train_noval'`,
+        which excludes all val images.
+
+    .. warning::
+
+        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
+
+    Args:
+        root (string): Root directory of the Semantic Boundaries Dataset
+        image_set (string, optional): Select the image_set to use, ``train``, ``val`` or ``train_noval``.
+            Image set ``train_noval`` excludes VOC 2012 val images.
+        mode (string, optional): Select target type. Possible values 'boundaries' or 'segmentation'.
+            In case of 'boundaries', the target is an array of shape `[num_classes, H, W]`,
+            where `num_classes=20`.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        transforms (callable, optional): A function/transform that takes input sample and its target as entry
+            and returns a transformed version. Input sample is PIL image and target is a numpy array
+            if `mode='boundaries'` or PIL image if `mode='segmentation'`.
+    """
+
+    url = "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz"
+    md5 = "82b4d87ceb2ed10f6038a1cba92111cb"
+    filename = "benchmark.tgz"
+
+    voc_train_url = "http://home.bharathh.info/pubs/codes/SBD/train_noval.txt"
+    voc_split_filename = "train_noval.txt"
+    voc_split_md5 = "79bff800c5f0b1ec6b21080a3c066722"
+
+    def __init__(
+        self,
+        root: str,
+        image_set: str = "train",
+        mode: str = "boundaries",
+        download: bool = False,
+        transforms: Optional[Callable] = None,
+    ) -> None:
+
+        try:
+            from scipy.io import loadmat
+
+            self._loadmat = loadmat
+        except ImportError:
+            raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")
+
+        super().__init__(root, transforms)
+        self.image_set = verify_str_arg(image_set, "image_set", ("train", "val", "train_noval"))
+        self.mode = verify_str_arg(mode, "mode", ("segmentation", "boundaries"))
+        self.num_classes = 20
+
+        sbd_root = self.root
+        image_dir = os.path.join(sbd_root, "img")
+        mask_dir = os.path.join(sbd_root, "cls")
+
+        if download:
+            download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)
+            extracted_ds_root = os.path.join(self.root, "benchmark_RELEASE", "dataset")
+            for f in ["cls", "img", "inst", "train.txt", "val.txt"]:
+                old_path = os.path.join(extracted_ds_root, f)
+                shutil.move(old_path, sbd_root)
+            download_url(self.voc_train_url, sbd_root, self.voc_split_filename, self.voc_split_md5)
+
+        if not os.path.isdir(sbd_root):
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        split_f = os.path.join(sbd_root, image_set.rstrip("\n") + ".txt")
+
+        with open(os.path.join(split_f)) as fh:
+            file_names = [x.strip() for x in fh.readlines()]
+
+        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
+        self.masks = [os.path.join(mask_dir, x + ".mat") for x in file_names]
+
+        self._get_target = self._get_segmentation_target if self.mode == "segmentation" else self._get_boundaries_target
+
+    def _get_segmentation_target(self, filepath: str) -> Image.Image:
+        mat = self._loadmat(filepath)
+        return Image.fromarray(mat["GTcls"][0]["Segmentation"][0])
+
+    def _get_boundaries_target(self, filepath: str) -> np.ndarray:
+        mat = self._loadmat(filepath)
+        return np.concatenate(
+            [np.expand_dims(mat["GTcls"][0]["Boundaries"][0][i][0].toarray(), axis=0) for i in range(self.num_classes)],
+            axis=0,
+        )
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        img = Image.open(self.images[index]).convert("RGB")
+        target = self._get_target(self.masks[index])
+
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.images)
+
+    def extra_repr(self) -> str:
+        lines = ["Image set: {image_set}", "Mode: {mode}"]
+        return "\n".join(lines).format(**self.__dict__)

+ 109 - 0
libs/vision_libs/datasets/sbu.py

@@ -0,0 +1,109 @@
+import os
+from typing import Any, Callable, Optional, Tuple
+
+from PIL import Image
+
+from .utils import check_integrity, download_and_extract_archive, download_url
+from .vision import VisionDataset
+
+
+class SBU(VisionDataset):
+    """`SBU Captioned Photo <http://www.cs.virginia.edu/~vicente/sbucaptions/>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where tarball
+            ``SBUCaptionedPhotoDataset.tar.gz`` exists.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    url = "https://www.cs.rice.edu/~vo9/sbucaptions/SBUCaptionedPhotoDataset.tar.gz"
+    filename = "SBUCaptionedPhotoDataset.tar.gz"
+    md5_checksum = "9aec147b3488753cf758b4d493422285"
+
+    def __init__(
+        self,
+        root: str,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = True,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        # Read the caption for each photo
+        self.photos = []
+        self.captions = []
+
+        file1 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")
+        file2 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_captions.txt")
+
+        for line1, line2 in zip(open(file1), open(file2)):
+            url = line1.rstrip()
+            photo = os.path.basename(url)
+            filename = os.path.join(self.root, "dataset", photo)
+            if os.path.exists(filename):
+                caption = line2.rstrip()
+                self.photos.append(photo)
+                self.captions.append(caption)
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is a caption for the photo.
+        """
+        filename = os.path.join(self.root, "dataset", self.photos[index])
+        img = Image.open(filename).convert("RGB")
+        if self.transform is not None:
+            img = self.transform(img)
+
+        target = self.captions[index]
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        """The number of photos in the dataset."""
+        return len(self.photos)
+
+    def _check_integrity(self) -> bool:
+        """Check the md5 checksum of the downloaded tarball."""
+        root = self.root
+        fpath = os.path.join(root, self.filename)
+        if not check_integrity(fpath, self.md5_checksum):
+            return False
+        return True
+
+    def download(self) -> None:
+        """Download and extract the tarball, and download each individual photo."""
+
+        if self._check_integrity():
+            print("Files already downloaded and verified")
+            return
+
+        download_and_extract_archive(self.url, self.root, self.root, self.filename, self.md5_checksum)
+
+        # Download individual photos
+        with open(os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")) as fh:
+            for line in fh:
+                url = line.rstrip()
+                try:
+                    download_url(url, os.path.join(self.root, "dataset"))
+                except OSError:
+                    # The images point to public images on Flickr.
+                    # Note: Images might be removed by users at anytime.
+                    pass

+ 91 - 0
libs/vision_libs/datasets/semeion.py

@@ -0,0 +1,91 @@
+import os.path
+from typing import Any, Callable, Optional, Tuple
+
+import numpy as np
+from PIL import Image
+
+from .utils import check_integrity, download_url
+from .vision import VisionDataset
+
+
+class SEMEION(VisionDataset):
+    r"""`SEMEION <http://archive.ics.uci.edu/ml/datasets/semeion+handwritten+digit>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where directory
+            ``semeion.py`` exists.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+    """
+    url = "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data"
+    filename = "semeion.data"
+    md5_checksum = "cb545d371d2ce14ec121470795a77432"
+
+    def __init__(
+        self,
+        root: str,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = True,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        fp = os.path.join(self.root, self.filename)
+        data = np.loadtxt(fp)
+        # convert value to 8 bit unsigned integer
+        # color (white #255) the pixels
+        self.data = (data[:, :256] * 255).astype("uint8")
+        self.data = np.reshape(self.data, (-1, 16, 16))
+        self.labels = np.nonzero(data[:, 256:])[1]
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target class.
+        """
+        img, target = self.data[index], int(self.labels[index])
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(img, mode="L")
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def _check_integrity(self) -> bool:
+        root = self.root
+        fpath = os.path.join(root, self.filename)
+        if not check_integrity(fpath, self.md5_checksum):
+            return False
+        return True
+
+    def download(self) -> None:
+        if self._check_integrity():
+            print("Files already downloaded and verified")
+            return
+
+        root = self.root
+        download_url(self.url, root, self.filename, self.md5_checksum)

+ 121 - 0
libs/vision_libs/datasets/stanford_cars.py

@@ -0,0 +1,121 @@
+import pathlib
+from typing import Any, Callable, Optional, Tuple
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, download_url, verify_str_arg
+from .vision import VisionDataset
+
+
+class StanfordCars(VisionDataset):
+    """`Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ Dataset
+
+    The Cars dataset contains 16,185 images of 196 classes of cars. The data is
+    split into 8,144 training images and 8,041 testing images, where each class
+    has been split roughly in a 50-50 split
+
+    .. note::
+
+        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
+
+    Args:
+        root (string): Root directory of dataset
+        split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again."""
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+
+        try:
+            import scipy.io as sio
+        except ImportError:
+            raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")
+
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        self._split = verify_str_arg(split, "split", ("train", "test"))
+        self._base_folder = pathlib.Path(root) / "stanford_cars"
+        devkit = self._base_folder / "devkit"
+
+        if self._split == "train":
+            self._annotations_mat_path = devkit / "cars_train_annos.mat"
+            self._images_base_path = self._base_folder / "cars_train"
+        else:
+            self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
+            self._images_base_path = self._base_folder / "cars_test"
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        self._samples = [
+            (
+                str(self._images_base_path / annotation["fname"]),
+                annotation["class"] - 1,  # Original target mapping  starts from 1, hence -1
+            )
+            for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]
+        ]
+
+        self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
+        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
+
+    def __len__(self) -> int:
+        return len(self._samples)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        """Returns pil_image and class_id for given index"""
+        image_path, target = self._samples[idx]
+        pil_image = Image.open(image_path).convert("RGB")
+
+        if self.transform is not None:
+            pil_image = self.transform(pil_image)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+        return pil_image, target
+
+    def download(self) -> None:
+        if self._check_exists():
+            return
+
+        download_and_extract_archive(
+            url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz",
+            download_root=str(self._base_folder),
+            md5="c3b158d763b6e2245038c8ad08e45376",
+        )
+        if self._split == "train":
+            download_and_extract_archive(
+                url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz",
+                download_root=str(self._base_folder),
+                md5="065e5b463ae28d29e77c1b4b166cfe61",
+            )
+        else:
+            download_and_extract_archive(
+                url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz",
+                download_root=str(self._base_folder),
+                md5="4ce7ebf6a94d07f1952d94dd34c4d501",
+            )
+            download_url(
+                url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat",
+                root=str(self._base_folder),
+                md5="b0a2b23655a3edd16d84508592a98d10",
+            )
+
+    def _check_exists(self) -> bool:
+        if not (self._base_folder / "devkit").is_dir():
+            return False
+
+        return self._annotations_mat_path.exists() and self._images_base_path.is_dir()

+ 174 - 0
libs/vision_libs/datasets/stl10.py

@@ -0,0 +1,174 @@
+import os.path
+from typing import Any, Callable, cast, Optional, Tuple
+
+import numpy as np
+from PIL import Image
+
+from .utils import check_integrity, download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class STL10(VisionDataset):
+    """`STL10 <https://cs.stanford.edu/~acoates/stl10/>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where directory
+            ``stl10_binary`` exists.
+        split (string): One of {'train', 'test', 'unlabeled', 'train+unlabeled'}.
+            Accordingly, dataset is selected.
+        folds (int, optional): One of {0-9} or None.
+            For training, loads one of the 10 pre-defined folds of 1k samples for the
+            standard evaluation procedure. If no value is passed, loads the 5k samples.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    base_folder = "stl10_binary"
+    url = "http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz"
+    filename = "stl10_binary.tar.gz"
+    tgz_md5 = "91f7769df0f17e558f3565bffb0c7dfb"
+    class_names_file = "class_names.txt"
+    folds_list_file = "fold_indices.txt"
+    train_list = [
+        ["train_X.bin", "918c2871b30a85fa023e0c44e0bee87f"],
+        ["train_y.bin", "5a34089d4802c674881badbb80307741"],
+        ["unlabeled_X.bin", "5242ba1fed5e4be9e1e742405eb56ca4"],
+    ]
+
+    test_list = [["test_X.bin", "7f263ba9f9e0b06b93213547f721ac82"], ["test_y.bin", "36f9794fa4beb8a2c72628de14fa638e"]]
+    splits = ("train", "train+unlabeled", "unlabeled", "test")
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        folds: Optional[int] = None,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self.split = verify_str_arg(split, "split", self.splits)
+        self.folds = self._verify_folds(folds)
+
+        if download:
+            self.download()
+        elif not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        # now load the picked numpy arrays
+        self.labels: Optional[np.ndarray]
+        if self.split == "train":
+            self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0])
+            self.labels = cast(np.ndarray, self.labels)
+            self.__load_folds(folds)
+
+        elif self.split == "train+unlabeled":
+            self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0])
+            self.labels = cast(np.ndarray, self.labels)
+            self.__load_folds(folds)
+            unlabeled_data, _ = self.__loadfile(self.train_list[2][0])
+            self.data = np.concatenate((self.data, unlabeled_data))
+            self.labels = np.concatenate((self.labels, np.asarray([-1] * unlabeled_data.shape[0])))
+
+        elif self.split == "unlabeled":
+            self.data, _ = self.__loadfile(self.train_list[2][0])
+            self.labels = np.asarray([-1] * self.data.shape[0])
+        else:  # self.split == 'test':
+            self.data, self.labels = self.__loadfile(self.test_list[0][0], self.test_list[1][0])
+
+        class_file = os.path.join(self.root, self.base_folder, self.class_names_file)
+        if os.path.isfile(class_file):
+            with open(class_file) as f:
+                self.classes = f.read().splitlines()
+
+    def _verify_folds(self, folds: Optional[int]) -> Optional[int]:
+        if folds is None:
+            return folds
+        elif isinstance(folds, int):
+            if folds in range(10):
+                return folds
+            msg = "Value for argument folds should be in the range [0, 10), but got {}."
+            raise ValueError(msg.format(folds))
+        else:
+            msg = "Expected type None or int for argument folds, but got type {}."
+            raise ValueError(msg.format(type(folds)))
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target class.
+        """
+        target: Optional[int]
+        if self.labels is not None:
+            img, target = self.data[index], int(self.labels[index])
+        else:
+            img, target = self.data[index], None
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(np.transpose(img, (1, 2, 0)))
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return self.data.shape[0]
+
+    def __loadfile(self, data_file: str, labels_file: Optional[str] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+        labels = None
+        if labels_file:
+            path_to_labels = os.path.join(self.root, self.base_folder, labels_file)
+            with open(path_to_labels, "rb") as f:
+                labels = np.fromfile(f, dtype=np.uint8) - 1  # 0-based
+
+        path_to_data = os.path.join(self.root, self.base_folder, data_file)
+        with open(path_to_data, "rb") as f:
+            # read whole file in uint8 chunks
+            everything = np.fromfile(f, dtype=np.uint8)
+            images = np.reshape(everything, (-1, 3, 96, 96))
+            images = np.transpose(images, (0, 1, 3, 2))
+
+        return images, labels
+
+    def _check_integrity(self) -> bool:
+        for filename, md5 in self.train_list + self.test_list:
+            fpath = os.path.join(self.root, self.base_folder, filename)
+            if not check_integrity(fpath, md5):
+                return False
+        return True
+
+    def download(self) -> None:
+        if self._check_integrity():
+            print("Files already downloaded and verified")
+            return
+        download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
+        self._check_integrity()
+
+    def extra_repr(self) -> str:
+        return "Split: {split}".format(**self.__dict__)
+
+    def __load_folds(self, folds: Optional[int]) -> None:
+        # loads one of the folds if specified
+        if folds is None:
+            return
+        path_to_folds = os.path.join(self.root, self.base_folder, self.folds_list_file)
+        with open(path_to_folds) as f:
+            str_idx = f.read().splitlines()[folds]
+            list_idx = np.fromstring(str_idx, dtype=np.int64, sep=" ")
+            self.data = self.data[list_idx, :, :, :]
+            if self.labels is not None:
+                self.labels = self.labels[list_idx]

+ 76 - 0
libs/vision_libs/datasets/sun397.py

@@ -0,0 +1,76 @@
+from pathlib import Path
+from typing import Any, Callable, Optional, Tuple
+
+import PIL.Image
+
+from .utils import download_and_extract_archive
+from .vision import VisionDataset
+
+
+class SUN397(VisionDataset):
+    """`The SUN397 Data Set <https://vision.princeton.edu/projects/2010/SUN/>`_.
+
+    The SUN397 or Scene UNderstanding (SUN) is a dataset for scene recognition consisting of
+    397 categories with 108'754 images.
+
+    Args:
+        root (string): Root directory of the dataset.
+        transform (callable, optional): A function/transform that  takes in an PIL image and returns a transformed
+            version. E.g, ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    _DATASET_URL = "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz"
+    _DATASET_MD5 = "8ca2778205c41d23104230ba66911c7a"
+
+    def __init__(
+        self,
+        root: str,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self._data_dir = Path(self.root) / "SUN397"
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        with open(self._data_dir / "ClassName.txt") as f:
+            self.classes = [c[3:].strip() for c in f]
+
+        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
+        self._image_files = list(self._data_dir.rglob("sun_*.jpg"))
+
+        self._labels = [
+            self.class_to_idx["/".join(path.relative_to(self._data_dir).parts[1:-1])] for path in self._image_files
+        ]
+
+    def __len__(self) -> int:
+        return len(self._image_files)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        image_file, label = self._image_files[idx], self._labels[idx]
+        image = PIL.Image.open(image_file).convert("RGB")
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.target_transform:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def _check_exists(self) -> bool:
+        return self._data_dir.is_dir()
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+        download_and_extract_archive(self._DATASET_URL, download_root=self.root, md5=self._DATASET_MD5)

+ 129 - 0
libs/vision_libs/datasets/svhn.py

@@ -0,0 +1,129 @@
+import os.path
+from typing import Any, Callable, Optional, Tuple
+
+import numpy as np
+from PIL import Image
+
+from .utils import check_integrity, download_url, verify_str_arg
+from .vision import VisionDataset
+
+
+class SVHN(VisionDataset):
+    """`SVHN <http://ufldl.stanford.edu/housenumbers/>`_ Dataset.
+    Note: The SVHN dataset assigns the label `10` to the digit `0`. However, in this Dataset,
+    we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which
+    expect the class labels to be in the range `[0, C-1]`
+
+    .. warning::
+
+        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load data from `.mat` format.
+
+    Args:
+        root (string): Root directory of the dataset where the data is stored.
+        split (string): One of {'train', 'test', 'extra'}.
+            Accordingly dataset is selected. 'extra' is Extra training set.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+    """
+
+    split_list = {
+        "train": [
+            "http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
+            "train_32x32.mat",
+            "e26dedcc434d2e4c54c9b2d4a06d8373",
+        ],
+        "test": [
+            "http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
+            "test_32x32.mat",
+            "eb5a983be6a315427106f1b164d9cef3",
+        ],
+        "extra": [
+            "http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
+            "extra_32x32.mat",
+            "a93ce644f1a588dc4d68dda5feec44a7",
+        ],
+    }
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self.split = verify_str_arg(split, "split", tuple(self.split_list.keys()))
+        self.url = self.split_list[split][0]
+        self.filename = self.split_list[split][1]
+        self.file_md5 = self.split_list[split][2]
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        # import here rather than at top of file because this is
+        # an optional dependency for torchvision
+        import scipy.io as sio
+
+        # reading(loading) mat file as array
+        loaded_mat = sio.loadmat(os.path.join(self.root, self.filename))
+
+        self.data = loaded_mat["X"]
+        # loading from the .mat file gives an np.ndarray of type np.uint8
+        # converting to np.int64, so that we have a LongTensor after
+        # the conversion from the numpy array
+        # the squeeze is needed to obtain a 1D tensor
+        self.labels = loaded_mat["y"].astype(np.int64).squeeze()
+
+        # the svhn dataset assigns the class label "10" to the digit 0
+        # this makes it inconsistent with several loss functions
+        # which expect the class labels to be in the range [0, C-1]
+        np.place(self.labels, self.labels == 10, 0)
+        self.data = np.transpose(self.data, (3, 2, 0, 1))
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target class.
+        """
+        img, target = self.data[index], int(self.labels[index])
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(np.transpose(img, (1, 2, 0)))
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def _check_integrity(self) -> bool:
+        root = self.root
+        md5 = self.split_list[self.split][2]
+        fpath = os.path.join(root, self.filename)
+        return check_integrity(fpath, md5)
+
+    def download(self) -> None:
+        md5 = self.split_list[self.split][2]
+        download_url(self.url, self.root, self.filename, md5)
+
+    def extra_repr(self) -> str:
+        return "Split: {split}".format(**self.__dict__)

+ 130 - 0
libs/vision_libs/datasets/ucf101.py

@@ -0,0 +1,130 @@
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+from torch import Tensor
+
+from .folder import find_classes, make_dataset
+from .video_utils import VideoClips
+from .vision import VisionDataset
+
+
+class UCF101(VisionDataset):
+    """
+    `UCF101 <https://www.crcv.ucf.edu/data/UCF101.php>`_ dataset.
+
+    UCF101 is an action recognition video dataset.
+    This dataset consider every video as a collection of video clips of fixed size, specified
+    by ``frames_per_clip``, where the step in frames between each clip is given by
+    ``step_between_clips``. The dataset itself can be downloaded from the dataset website;
+    annotations that ``annotation_path`` should be pointing to can be downloaded from `here
+    <https://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip>`_.
+
+    To give an example, for 2 videos with 10 and 15 frames respectively, if ``frames_per_clip=5``
+    and ``step_between_clips=5``, the dataset size will be (2 + 3) = 5, where the first two
+    elements will come from video 1, and the next three elements from video 2.
+    Note that we drop clips which do not have exactly ``frames_per_clip`` elements, so not all
+    frames in a video might be present.
+
+    Internally, it uses a VideoClips object to handle clip creation.
+
+    Args:
+        root (string): Root directory of the UCF101 Dataset.
+        annotation_path (str): path to the folder containing the split files;
+            see docstring above for download instructions of these files
+        frames_per_clip (int): number of frames in a clip.
+        step_between_clips (int, optional): number of frames between each clip.
+        fold (int, optional): which fold to use. Should be between 1 and 3.
+        train (bool, optional): if ``True``, creates a dataset from the train split,
+            otherwise from the ``test`` split.
+        transform (callable, optional): A function/transform that  takes in a TxHxWxC video
+            and returns a transformed version.
+        output_format (str, optional): The format of the output video tensors (before transforms).
+            Can be either "THWC" (default) or "TCHW".
+
+    Returns:
+        tuple: A 3-tuple with the following entries:
+
+            - video (Tensor[T, H, W, C] or Tensor[T, C, H, W]): The `T` video frames
+            -  audio(Tensor[K, L]): the audio frames, where `K` is the number of channels
+               and `L` is the number of points
+            - label (int): class of the video clip
+    """
+
+    def __init__(
+        self,
+        root: str,
+        annotation_path: str,
+        frames_per_clip: int,
+        step_between_clips: int = 1,
+        frame_rate: Optional[int] = None,
+        fold: int = 1,
+        train: bool = True,
+        transform: Optional[Callable] = None,
+        _precomputed_metadata: Optional[Dict[str, Any]] = None,
+        num_workers: int = 1,
+        _video_width: int = 0,
+        _video_height: int = 0,
+        _video_min_dimension: int = 0,
+        _audio_samples: int = 0,
+        output_format: str = "THWC",
+    ) -> None:
+        super().__init__(root)
+        if not 1 <= fold <= 3:
+            raise ValueError(f"fold should be between 1 and 3, got {fold}")
+
+        extensions = ("avi",)
+        self.fold = fold
+        self.train = train
+
+        self.classes, class_to_idx = find_classes(self.root)
+        self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
+        video_list = [x[0] for x in self.samples]
+        video_clips = VideoClips(
+            video_list,
+            frames_per_clip,
+            step_between_clips,
+            frame_rate,
+            _precomputed_metadata,
+            num_workers=num_workers,
+            _video_width=_video_width,
+            _video_height=_video_height,
+            _video_min_dimension=_video_min_dimension,
+            _audio_samples=_audio_samples,
+            output_format=output_format,
+        )
+        # we bookkeep the full version of video clips because we want to be able
+        # to return the metadata of full version rather than the subset version of
+        # video clips
+        self.full_video_clips = video_clips
+        self.indices = self._select_fold(video_list, annotation_path, fold, train)
+        self.video_clips = video_clips.subset(self.indices)
+        self.transform = transform
+
+    @property
+    def metadata(self) -> Dict[str, Any]:
+        return self.full_video_clips.metadata
+
+    def _select_fold(self, video_list: List[str], annotation_path: str, fold: int, train: bool) -> List[int]:
+        name = "train" if train else "test"
+        name = f"{name}list{fold:02d}.txt"
+        f = os.path.join(annotation_path, name)
+        selected_files = set()
+        with open(f) as fid:
+            data = fid.readlines()
+            data = [x.strip().split(" ")[0] for x in data]
+            data = [os.path.join(self.root, *x.split("/")) for x in data]
+            selected_files.update(data)
+        indices = [i for i in range(len(video_list)) if video_list[i] in selected_files]
+        return indices
+
+    def __len__(self) -> int:
+        return self.video_clips.num_clips()
+
+    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]:
+        video, audio, info, video_idx = self.video_clips.get_clip(idx)
+        label = self.samples[self.indices[video_idx]][1]
+
+        if self.transform is not None:
+            video = self.transform(video)
+
+        return video, audio, label

+ 95 - 0
libs/vision_libs/datasets/usps.py

@@ -0,0 +1,95 @@
+import os
+from typing import Any, Callable, Optional, Tuple
+
+import numpy as np
+from PIL import Image
+
+from .utils import download_url
+from .vision import VisionDataset
+
+
+class USPS(VisionDataset):
+    """`USPS <https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps>`_ Dataset.
+    The data-format is : [label [index:value ]*256 \\n] * num_lines, where ``label`` lies in ``[1, 10]``.
+    The value for each pixel lies in ``[-1, 1]``. Here we transform the ``label`` into ``[0, 9]``
+    and make pixel values in ``[0, 255]``.
+
+    Args:
+        root (string): Root directory of dataset to store``USPS`` data files.
+        train (bool, optional): If True, creates dataset from ``usps.bz2``,
+            otherwise from ``usps.t.bz2``.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+    """
+
+    split_list = {
+        "train": [
+            "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2",
+            "usps.bz2",
+            "ec16c51db3855ca6c91edd34d0e9b197",
+        ],
+        "test": [
+            "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2",
+            "usps.t.bz2",
+            "8ea070ee2aca1ac39742fdd1ef5ed118",
+        ],
+    }
+
+    def __init__(
+        self,
+        root: str,
+        train: bool = True,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        split = "train" if train else "test"
+        url, filename, checksum = self.split_list[split]
+        full_path = os.path.join(self.root, filename)
+
+        if download and not os.path.exists(full_path):
+            download_url(url, self.root, filename, md5=checksum)
+
+        import bz2
+
+        with bz2.open(full_path) as fp:
+            raw_data = [line.decode().split() for line in fp.readlines()]
+            tmp_list = [[x.split(":")[-1] for x in data[1:]] for data in raw_data]
+            imgs = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16))
+            imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)
+            targets = [int(d[0]) - 1 for d in raw_data]
+
+        self.data = imgs
+        self.targets = targets
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target class.
+        """
+        img, target = self.data[index], int(self.targets[index])
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(img, mode="L")
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.data)

+ 459 - 0
libs/vision_libs/datasets/utils.py

@@ -0,0 +1,459 @@
+import bz2
+import gzip
+import hashlib
+import lzma
+import os
+import os.path
+import pathlib
+import re
+import sys
+import tarfile
+import urllib
+import urllib.error
+import urllib.request
+import zipfile
+from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar
+from urllib.parse import urlparse
+
+import numpy as np
+import torch
+from torch.utils.model_zoo import tqdm
+
+from .._internally_replaced_utils import _download_file_from_remote_location, _is_remote_location_available
+
+USER_AGENT = "pytorch/vision"
+
+
+def _save_response_content(
+    content: Iterator[bytes],
+    destination: str,
+    length: Optional[int] = None,
+) -> None:
+    with open(destination, "wb") as fh, tqdm(total=length) as pbar:
+        for chunk in content:
+            # filter out keep-alive new chunks
+            if not chunk:
+                continue
+
+            fh.write(chunk)
+            pbar.update(len(chunk))
+
+
+def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
+    with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
+        _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)
+
+
+def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
+    # Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
+    # not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without
+    # it torchvision.datasets is unusable in these environments since we perform a MD5 check everywhere.
+    if sys.version_info >= (3, 9):
+        md5 = hashlib.md5(usedforsecurity=False)
+    else:
+        md5 = hashlib.md5()
+    with open(fpath, "rb") as f:
+        while chunk := f.read(chunk_size):
+            md5.update(chunk)
+    return md5.hexdigest()
+
+
+def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
+    return md5 == calculate_md5(fpath, **kwargs)
+
+
+def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
+    if not os.path.isfile(fpath):
+        return False
+    if md5 is None:
+        return True
+    return check_md5(fpath, md5)
+
+
+def _get_redirect_url(url: str, max_hops: int = 3) -> str:
+    initial_url = url
+    headers = {"Method": "HEAD", "User-Agent": USER_AGENT}
+
+    for _ in range(max_hops + 1):
+        with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response:
+            if response.url == url or response.url is None:
+                return url
+
+            url = response.url
+    else:
+        raise RecursionError(
+            f"Request to {initial_url} exceeded {max_hops} redirects. The last redirect points to {url}."
+        )
+
+
+def _get_google_drive_file_id(url: str) -> Optional[str]:
+    parts = urlparse(url)
+
+    if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
+        return None
+
+    match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
+    if match is None:
+        return None
+
+    return match.group("id")
+
+
+def download_url(
+    url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
+) -> None:
+    """Download a file from a url and place it in root.
+
+    Args:
+        url (str): URL to download file from
+        root (str): Directory to place downloaded file in
+        filename (str, optional): Name to save the file under. If None, use the basename of the URL
+        md5 (str, optional): MD5 checksum of the download. If None, do not check
+        max_redirect_hops (int, optional): Maximum number of redirect hops allowed
+    """
+    root = os.path.expanduser(root)
+    if not filename:
+        filename = os.path.basename(url)
+    fpath = os.path.join(root, filename)
+
+    os.makedirs(root, exist_ok=True)
+
+    # check if file is already present locally
+    if check_integrity(fpath, md5):
+        print("Using downloaded and verified file: " + fpath)
+        return
+
+    if _is_remote_location_available():
+        _download_file_from_remote_location(fpath, url)
+    else:
+        # expand redirect chain if needed
+        url = _get_redirect_url(url, max_hops=max_redirect_hops)
+
+        # check if file is located on Google Drive
+        file_id = _get_google_drive_file_id(url)
+        if file_id is not None:
+            return download_file_from_google_drive(file_id, root, filename, md5)
+
+        # download the file
+        try:
+            print("Downloading " + url + " to " + fpath)
+            _urlretrieve(url, fpath)
+        except (urllib.error.URLError, OSError) as e:  # type: ignore[attr-defined]
+            if url[:5] == "https":
+                url = url.replace("https:", "http:")
+                print("Failed download. Trying https -> http instead. Downloading " + url + " to " + fpath)
+                _urlretrieve(url, fpath)
+            else:
+                raise e
+
+    # check integrity of downloaded file
+    if not check_integrity(fpath, md5):
+        raise RuntimeError("File not found or corrupted.")
+
+
+def list_dir(root: str, prefix: bool = False) -> List[str]:
+    """List all directories at a given root
+
+    Args:
+        root (str): Path to directory whose folders need to be listed
+        prefix (bool, optional): If true, prepends the path to each result, otherwise
+            only returns the name of the directories found
+    """
+    root = os.path.expanduser(root)
+    directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
+    if prefix is True:
+        directories = [os.path.join(root, d) for d in directories]
+    return directories
+
+
+def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
+    """List all files ending with a suffix at a given root
+
+    Args:
+        root (str): Path to directory whose folders need to be listed
+        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
+            It uses the Python "str.endswith" method and is passed directly
+        prefix (bool, optional): If true, prepends the path to each result, otherwise
+            only returns the name of the files found
+    """
+    root = os.path.expanduser(root)
+    files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
+    if prefix is True:
+        files = [os.path.join(root, d) for d in files]
+    return files
+
+
+def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
+    """Download a Google Drive file from  and place it in root.
+
+    Args:
+        file_id (str): id of file to be downloaded
+        root (str): Directory to place downloaded file in
+        filename (str, optional): Name to save the file under. If None, use the id of the file.
+        md5 (str, optional): MD5 checksum of the download. If None, do not check
+    """
+    try:
+        import gdown
+    except ModuleNotFoundError:
+        raise RuntimeError(
+            "To download files from GDrive, 'gdown' is required. You can install it with 'pip install gdown'."
+        )
+
+    root = os.path.expanduser(root)
+    if not filename:
+        filename = file_id
+    fpath = os.path.join(root, filename)
+
+    os.makedirs(root, exist_ok=True)
+
+    if check_integrity(fpath, md5):
+        print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
+        return
+
+    gdown.download(id=file_id, output=fpath, quiet=False, user_agent=USER_AGENT)
+
+    if not check_integrity(fpath, md5):
+        raise RuntimeError("File not found or corrupted.")
+
+
+def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
+    with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
+        tar.extractall(to_path)
+
+
+_ZIP_COMPRESSION_MAP: Dict[str, int] = {
+    ".bz2": zipfile.ZIP_BZIP2,
+    ".xz": zipfile.ZIP_LZMA,
+}
+
+
+def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
+    with zipfile.ZipFile(
+        from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
+    ) as zip:
+        zip.extractall(to_path)
+
+
+_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = {
+    ".tar": _extract_tar,
+    ".zip": _extract_zip,
+}
+_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = {
+    ".bz2": bz2.open,
+    ".gz": gzip.open,
+    ".xz": lzma.open,
+}
+_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {
+    ".tbz": (".tar", ".bz2"),
+    ".tbz2": (".tar", ".bz2"),
+    ".tgz": (".tar", ".gz"),
+}
+
+
+def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
+    """Detect the archive type and/or compression of a file.
+
+    Args:
+        file (str): the filename
+
+    Returns:
+        (tuple): tuple of suffix, archive type, and compression
+
+    Raises:
+        RuntimeError: if file has no suffix or suffix is not supported
+    """
+    suffixes = pathlib.Path(file).suffixes
+    if not suffixes:
+        raise RuntimeError(
+            f"File '{file}' has no suffixes that could be used to detect the archive type and compression."
+        )
+    suffix = suffixes[-1]
+
+    # check if the suffix is a known alias
+    if suffix in _FILE_TYPE_ALIASES:
+        return (suffix, *_FILE_TYPE_ALIASES[suffix])
+
+    # check if the suffix is an archive type
+    if suffix in _ARCHIVE_EXTRACTORS:
+        return suffix, suffix, None
+
+    # check if the suffix is a compression
+    if suffix in _COMPRESSED_FILE_OPENERS:
+        # check for suffix hierarchy
+        if len(suffixes) > 1:
+            suffix2 = suffixes[-2]
+
+            # check if the suffix2 is an archive type
+            if suffix2 in _ARCHIVE_EXTRACTORS:
+                return suffix2 + suffix, suffix2, suffix
+
+        return suffix, None, suffix
+
+    valid_suffixes = sorted(set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS))
+    raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.")
+
+
+def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
+    r"""Decompress a file.
+
+    The compression is automatically detected from the file name.
+
+    Args:
+        from_path (str): Path to the file to be decompressed.
+        to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used.
+        remove_finished (bool): If ``True``, remove the file after the extraction.
+
+    Returns:
+        (str): Path to the decompressed file.
+    """
+    suffix, archive_type, compression = _detect_file_type(from_path)
+    if not compression:
+        raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")
+
+    if to_path is None:
+        to_path = from_path.replace(suffix, archive_type if archive_type is not None else "")
+
+    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
+    compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]
+
+    with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh:
+        wfh.write(rfh.read())
+
+    if remove_finished:
+        os.remove(from_path)
+
+    return to_path
+
+
+def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
+    """Extract an archive.
+
+    The archive type and a possible compression is automatically detected from the file name. If the file is compressed
+    but not an archive the call is dispatched to :func:`decompress`.
+
+    Args:
+        from_path (str): Path to the file to be extracted.
+        to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is
+            used.
+        remove_finished (bool): If ``True``, remove the file after the extraction.
+
+    Returns:
+        (str): Path to the directory the file was extracted to.
+    """
+    if to_path is None:
+        to_path = os.path.dirname(from_path)
+
+    suffix, archive_type, compression = _detect_file_type(from_path)
+    if not archive_type:
+        return _decompress(
+            from_path,
+            os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")),
+            remove_finished=remove_finished,
+        )
+
+    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
+    extractor = _ARCHIVE_EXTRACTORS[archive_type]
+
+    extractor(from_path, to_path, compression)
+    if remove_finished:
+        os.remove(from_path)
+
+    return to_path
+
+
+def download_and_extract_archive(
+    url: str,
+    download_root: str,
+    extract_root: Optional[str] = None,
+    filename: Optional[str] = None,
+    md5: Optional[str] = None,
+    remove_finished: bool = False,
+) -> None:
+    download_root = os.path.expanduser(download_root)
+    if extract_root is None:
+        extract_root = download_root
+    if not filename:
+        filename = os.path.basename(url)
+
+    download_url(url, download_root, filename, md5)
+
+    archive = os.path.join(download_root, filename)
+    print(f"Extracting {archive} to {extract_root}")
+    extract_archive(archive, extract_root, remove_finished)
+
+
+def iterable_to_str(iterable: Iterable) -> str:
+    return "'" + "', '".join([str(item) for item in iterable]) + "'"
+
+
+T = TypeVar("T", str, bytes)
+
+
+def verify_str_arg(
+    value: T,
+    arg: Optional[str] = None,
+    valid_values: Optional[Iterable[T]] = None,
+    custom_msg: Optional[str] = None,
+) -> T:
+    if not isinstance(value, str):
+        if arg is None:
+            msg = "Expected type str, but got type {type}."
+        else:
+            msg = "Expected type str for argument {arg}, but got type {type}."
+        msg = msg.format(type=type(value), arg=arg)
+        raise ValueError(msg)
+
+    if valid_values is None:
+        return value
+
+    if value not in valid_values:
+        if custom_msg is not None:
+            msg = custom_msg
+        else:
+            msg = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}."
+            msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values))
+        raise ValueError(msg)
+
+    return value
+
+
+def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray:
+    """Read file in .pfm format. Might contain either 1 or 3 channels of data.
+
+    Args:
+        file_name (str): Path to the file.
+        slice_channels (int): Number of channels to slice out of the file.
+            Useful for reading different data formats stored in .pfm files: Optical Flows, Stereo Disparity Maps, etc.
+    """
+
+    with open(file_name, "rb") as f:
+        header = f.readline().rstrip()
+        if header not in [b"PF", b"Pf"]:
+            raise ValueError("Invalid PFM file")
+
+        dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline())
+        if not dim_match:
+            raise Exception("Malformed PFM header.")
+        w, h = (int(dim) for dim in dim_match.groups())
+
+        scale = float(f.readline().rstrip())
+        if scale < 0:  # little-endian
+            endian = "<"
+            scale = -scale
+        else:
+            endian = ">"  # big-endian
+
+        data = np.fromfile(f, dtype=endian + "f")
+
+    pfm_channels = 3 if header == b"PF" else 1
+
+    data = data.reshape(h, w, pfm_channels).transpose(2, 0, 1)
+    data = np.flip(data, axis=1)  # flip on h dimension
+    data = data[:slice_channels, :, :]
+    return data.astype(np.float32)
+
+
+def _flip_byte_order(t: torch.Tensor) -> torch.Tensor:
+    return (
+        t.contiguous().view(torch.uint8).view(*t.shape, t.element_size()).flip(-1).view(*t.shape[:-1], -1).view(t.dtype)
+    )

+ 419 - 0
libs/vision_libs/datasets/video_utils.py

@@ -0,0 +1,419 @@
+import bisect
+import math
+import warnings
+from fractions import Fraction
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, TypeVar, Union
+
+import torch
+from torchvision.io import _probe_video_from_file, _read_video_from_file, read_video, read_video_timestamps
+
+from .utils import tqdm
+
+T = TypeVar("T")
+
+
+def pts_convert(pts: int, timebase_from: Fraction, timebase_to: Fraction, round_func: Callable = math.floor) -> int:
+    """convert pts between different time bases
+    Args:
+        pts: presentation timestamp, float
+        timebase_from: original timebase. Fraction
+        timebase_to: new timebase. Fraction
+        round_func: rounding function.
+    """
+    new_pts = Fraction(pts, 1) * timebase_from / timebase_to
+    return round_func(new_pts)
+
+
+def unfold(tensor: torch.Tensor, size: int, step: int, dilation: int = 1) -> torch.Tensor:
+    """
+    similar to tensor.unfold, but with the dilation
+    and specialized for 1d tensors
+
+    Returns all consecutive windows of `size` elements, with
+    `step` between windows. The distance between each element
+    in a window is given by `dilation`.
+    """
+    if tensor.dim() != 1:
+        raise ValueError(f"tensor should have 1 dimension instead of {tensor.dim()}")
+    o_stride = tensor.stride(0)
+    numel = tensor.numel()
+    new_stride = (step * o_stride, dilation * o_stride)
+    new_size = ((numel - (dilation * (size - 1) + 1)) // step + 1, size)
+    if new_size[0] < 1:
+        new_size = (0, size)
+    return torch.as_strided(tensor, new_size, new_stride)
+
+
+class _VideoTimestampsDataset:
+    """
+    Dataset used to parallelize the reading of the timestamps
+    of a list of videos, given their paths in the filesystem.
+
+    Used in VideoClips and defined at top level, so it can be
+    pickled when forking.
+    """
+
+    def __init__(self, video_paths: List[str]) -> None:
+        self.video_paths = video_paths
+
+    def __len__(self) -> int:
+        return len(self.video_paths)
+
+    def __getitem__(self, idx: int) -> Tuple[List[int], Optional[float]]:
+        return read_video_timestamps(self.video_paths[idx])
+
+
+def _collate_fn(x: T) -> T:
+    """
+    Dummy collate function to be used with _VideoTimestampsDataset
+    """
+    return x
+
+
+class VideoClips:
+    """
+    Given a list of video files, computes all consecutive subvideos of size
+    `clip_length_in_frames`, where the distance between each subvideo in the
+    same video is defined by `frames_between_clips`.
+    If `frame_rate` is specified, it will also resample all the videos to have
+    the same frame rate, and the clips will refer to this frame rate.
+
+    Creating this instance the first time is time-consuming, as it needs to
+    decode all the videos in `video_paths`. It is recommended that you
+    cache the results after instantiation of the class.
+
+    Recreating the clips for different clip lengths is fast, and can be done
+    with the `compute_clips` method.
+
+    Args:
+        video_paths (List[str]): paths to the video files
+        clip_length_in_frames (int): size of a clip in number of frames
+        frames_between_clips (int): step (in frames) between each clip
+        frame_rate (int, optional): if specified, it will resample the video
+            so that it has `frame_rate`, and then the clips will be defined
+            on the resampled video
+        num_workers (int): how many subprocesses to use for data loading.
+            0 means that the data will be loaded in the main process. (default: 0)
+        output_format (str): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
+    """
+
+    def __init__(
+        self,
+        video_paths: List[str],
+        clip_length_in_frames: int = 16,
+        frames_between_clips: int = 1,
+        frame_rate: Optional[int] = None,
+        _precomputed_metadata: Optional[Dict[str, Any]] = None,
+        num_workers: int = 0,
+        _video_width: int = 0,
+        _video_height: int = 0,
+        _video_min_dimension: int = 0,
+        _video_max_dimension: int = 0,
+        _audio_samples: int = 0,
+        _audio_channels: int = 0,
+        output_format: str = "THWC",
+    ) -> None:
+
+        self.video_paths = video_paths
+        self.num_workers = num_workers
+
+        # these options are not valid for pyav backend
+        self._video_width = _video_width
+        self._video_height = _video_height
+        self._video_min_dimension = _video_min_dimension
+        self._video_max_dimension = _video_max_dimension
+        self._audio_samples = _audio_samples
+        self._audio_channels = _audio_channels
+        self.output_format = output_format.upper()
+        if self.output_format not in ("THWC", "TCHW"):
+            raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
+
+        if _precomputed_metadata is None:
+            self._compute_frame_pts()
+        else:
+            self._init_from_metadata(_precomputed_metadata)
+        self.compute_clips(clip_length_in_frames, frames_between_clips, frame_rate)
+
+    def _compute_frame_pts(self) -> None:
+        self.video_pts = []  # len = num_videos. Each entry is a tensor of shape (num_frames_in_video,)
+        self.video_fps: List[int] = []  # len = num_videos
+
+        # strategy: use a DataLoader to parallelize read_video_timestamps
+        # so need to create a dummy dataset first
+        import torch.utils.data
+
+        dl: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
+            _VideoTimestampsDataset(self.video_paths),  # type: ignore[arg-type]
+            batch_size=16,
+            num_workers=self.num_workers,
+            collate_fn=_collate_fn,
+        )
+
+        with tqdm(total=len(dl)) as pbar:
+            for batch in dl:
+                pbar.update(1)
+                batch_pts, batch_fps = list(zip(*batch))
+                # we need to specify dtype=torch.long because for empty list,
+                # torch.as_tensor will use torch.float as default dtype. This
+                # happens when decoding fails and no pts is returned in the list.
+                batch_pts = [torch.as_tensor(pts, dtype=torch.long) for pts in batch_pts]
+                self.video_pts.extend(batch_pts)
+                self.video_fps.extend(batch_fps)
+
+    def _init_from_metadata(self, metadata: Dict[str, Any]) -> None:
+        self.video_paths = metadata["video_paths"]
+        assert len(self.video_paths) == len(metadata["video_pts"])
+        self.video_pts = metadata["video_pts"]
+        assert len(self.video_paths) == len(metadata["video_fps"])
+        self.video_fps = metadata["video_fps"]
+
+    @property
+    def metadata(self) -> Dict[str, Any]:
+        _metadata = {
+            "video_paths": self.video_paths,
+            "video_pts": self.video_pts,
+            "video_fps": self.video_fps,
+        }
+        return _metadata
+
+    def subset(self, indices: List[int]) -> "VideoClips":
+        video_paths = [self.video_paths[i] for i in indices]
+        video_pts = [self.video_pts[i] for i in indices]
+        video_fps = [self.video_fps[i] for i in indices]
+        metadata = {
+            "video_paths": video_paths,
+            "video_pts": video_pts,
+            "video_fps": video_fps,
+        }
+        return type(self)(
+            video_paths,
+            clip_length_in_frames=self.num_frames,
+            frames_between_clips=self.step,
+            frame_rate=self.frame_rate,
+            _precomputed_metadata=metadata,
+            num_workers=self.num_workers,
+            _video_width=self._video_width,
+            _video_height=self._video_height,
+            _video_min_dimension=self._video_min_dimension,
+            _video_max_dimension=self._video_max_dimension,
+            _audio_samples=self._audio_samples,
+            _audio_channels=self._audio_channels,
+            output_format=self.output_format,
+        )
+
+    @staticmethod
+    def compute_clips_for_video(
+        video_pts: torch.Tensor, num_frames: int, step: int, fps: int, frame_rate: Optional[int] = None
+    ) -> Tuple[torch.Tensor, Union[List[slice], torch.Tensor]]:
+        if fps is None:
+            # if for some reason the video doesn't have fps (because doesn't have a video stream)
+            # set the fps to 1. The value doesn't matter, because video_pts is empty anyway
+            fps = 1
+        if frame_rate is None:
+            frame_rate = fps
+        total_frames = len(video_pts) * (float(frame_rate) / fps)
+        _idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate)
+        video_pts = video_pts[_idxs]
+        clips = unfold(video_pts, num_frames, step)
+        if not clips.numel():
+            warnings.warn(
+                "There aren't enough frames in the current video to get a clip for the given clip length and "
+                "frames between clips. The video (and potentially others) will be skipped."
+            )
+        idxs: Union[List[slice], torch.Tensor]
+        if isinstance(_idxs, slice):
+            idxs = [_idxs] * len(clips)
+        else:
+            idxs = unfold(_idxs, num_frames, step)
+        return clips, idxs
+
+    def compute_clips(self, num_frames: int, step: int, frame_rate: Optional[int] = None) -> None:
+        """
+        Compute all consecutive sequences of clips from video_pts.
+        Always returns clips of size `num_frames`, meaning that the
+        last few frames in a video can potentially be dropped.
+
+        Args:
+            num_frames (int): number of frames for the clip
+            step (int): distance between two clips
+            frame_rate (int, optional): The frame rate
+        """
+        self.num_frames = num_frames
+        self.step = step
+        self.frame_rate = frame_rate
+        self.clips = []
+        self.resampling_idxs = []
+        for video_pts, fps in zip(self.video_pts, self.video_fps):
+            clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate)
+            self.clips.append(clips)
+            self.resampling_idxs.append(idxs)
+        clip_lengths = torch.as_tensor([len(v) for v in self.clips])
+        self.cumulative_sizes = clip_lengths.cumsum(0).tolist()
+
+    def __len__(self) -> int:
+        return self.num_clips()
+
+    def num_videos(self) -> int:
+        return len(self.video_paths)
+
+    def num_clips(self) -> int:
+        """
+        Number of subclips that are available in the video list.
+        """
+        return self.cumulative_sizes[-1]
+
+    def get_clip_location(self, idx: int) -> Tuple[int, int]:
+        """
+        Converts a flattened representation of the indices into a video_idx, clip_idx
+        representation.
+        """
+        video_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+        if video_idx == 0:
+            clip_idx = idx
+        else:
+            clip_idx = idx - self.cumulative_sizes[video_idx - 1]
+        return video_idx, clip_idx
+
+    @staticmethod
+    def _resample_video_idx(num_frames: int, original_fps: int, new_fps: int) -> Union[slice, torch.Tensor]:
+        step = float(original_fps) / new_fps
+        if step.is_integer():
+            # optimization: if step is integer, don't need to perform
+            # advanced indexing
+            step = int(step)
+            return slice(None, None, step)
+        idxs = torch.arange(num_frames, dtype=torch.float32) * step
+        idxs = idxs.floor().to(torch.int64)
+        return idxs
+
+    def get_clip(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any], int]:
+        """
+        Gets a subclip from a list of videos.
+
+        Args:
+            idx (int): index of the subclip. Must be between 0 and num_clips().
+
+        Returns:
+            video (Tensor)
+            audio (Tensor)
+            info (Dict)
+            video_idx (int): index of the video in `video_paths`
+        """
+        if idx >= self.num_clips():
+            raise IndexError(f"Index {idx} out of range ({self.num_clips()} number of clips)")
+        video_idx, clip_idx = self.get_clip_location(idx)
+        video_path = self.video_paths[video_idx]
+        clip_pts = self.clips[video_idx][clip_idx]
+
+        from torchvision import get_video_backend
+
+        backend = get_video_backend()
+
+        if backend == "pyav":
+            # check for invalid options
+            if self._video_width != 0:
+                raise ValueError("pyav backend doesn't support _video_width != 0")
+            if self._video_height != 0:
+                raise ValueError("pyav backend doesn't support _video_height != 0")
+            if self._video_min_dimension != 0:
+                raise ValueError("pyav backend doesn't support _video_min_dimension != 0")
+            if self._video_max_dimension != 0:
+                raise ValueError("pyav backend doesn't support _video_max_dimension != 0")
+            if self._audio_samples != 0:
+                raise ValueError("pyav backend doesn't support _audio_samples != 0")
+
+        if backend == "pyav":
+            start_pts = clip_pts[0].item()
+            end_pts = clip_pts[-1].item()
+            video, audio, info = read_video(video_path, start_pts, end_pts)
+        else:
+            _info = _probe_video_from_file(video_path)
+            video_fps = _info.video_fps
+            audio_fps = None
+
+            video_start_pts = cast(int, clip_pts[0].item())
+            video_end_pts = cast(int, clip_pts[-1].item())
+
+            audio_start_pts, audio_end_pts = 0, -1
+            audio_timebase = Fraction(0, 1)
+            video_timebase = Fraction(_info.video_timebase.numerator, _info.video_timebase.denominator)
+            if _info.has_audio:
+                audio_timebase = Fraction(_info.audio_timebase.numerator, _info.audio_timebase.denominator)
+                audio_start_pts = pts_convert(video_start_pts, video_timebase, audio_timebase, math.floor)
+                audio_end_pts = pts_convert(video_end_pts, video_timebase, audio_timebase, math.ceil)
+                audio_fps = _info.audio_sample_rate
+            video, audio, _ = _read_video_from_file(
+                video_path,
+                video_width=self._video_width,
+                video_height=self._video_height,
+                video_min_dimension=self._video_min_dimension,
+                video_max_dimension=self._video_max_dimension,
+                video_pts_range=(video_start_pts, video_end_pts),
+                video_timebase=video_timebase,
+                audio_samples=self._audio_samples,
+                audio_channels=self._audio_channels,
+                audio_pts_range=(audio_start_pts, audio_end_pts),
+                audio_timebase=audio_timebase,
+            )
+
+            info = {"video_fps": video_fps}
+            if audio_fps is not None:
+                info["audio_fps"] = audio_fps
+
+        if self.frame_rate is not None:
+            resampling_idx = self.resampling_idxs[video_idx][clip_idx]
+            if isinstance(resampling_idx, torch.Tensor):
+                resampling_idx = resampling_idx - resampling_idx[0]
+            video = video[resampling_idx]
+            info["video_fps"] = self.frame_rate
+        assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}"
+
+        if self.output_format == "TCHW":
+            # [T,H,W,C] --> [T,C,H,W]
+            video = video.permute(0, 3, 1, 2)
+
+        return video, audio, info, video_idx
+
+    def __getstate__(self) -> Dict[str, Any]:
+        video_pts_sizes = [len(v) for v in self.video_pts]
+        # To be back-compatible, we convert data to dtype torch.long as needed
+        # because for empty list, in legacy implementation, torch.as_tensor will
+        # use torch.float as default dtype. This happens when decoding fails and
+        # no pts is returned in the list.
+        video_pts = [x.to(torch.int64) for x in self.video_pts]
+        # video_pts can be an empty list if no frames have been decoded
+        if video_pts:
+            video_pts = torch.cat(video_pts)  # type: ignore[assignment]
+            # avoid bug in https://github.com/pytorch/pytorch/issues/32351
+            # TODO: Revert it once the bug is fixed.
+            video_pts = video_pts.numpy()  # type: ignore[attr-defined]
+
+        # make a copy of the fields of self
+        d = self.__dict__.copy()
+        d["video_pts_sizes"] = video_pts_sizes
+        d["video_pts"] = video_pts
+        # delete the following attributes to reduce the size of dictionary. They
+        # will be re-computed in "__setstate__()"
+        del d["clips"]
+        del d["resampling_idxs"]
+        del d["cumulative_sizes"]
+
+        # for backwards-compatibility
+        d["_version"] = 2
+        return d
+
+    def __setstate__(self, d: Dict[str, Any]) -> None:
+        # for backwards-compatibility
+        if "_version" not in d:
+            self.__dict__ = d
+            return
+
+        video_pts = torch.as_tensor(d["video_pts"], dtype=torch.int64)
+        video_pts = torch.split(video_pts, d["video_pts_sizes"], dim=0)
+        # don't need this info anymore
+        del d["video_pts_sizes"]
+
+        d["video_pts"] = video_pts
+        self.__dict__ = d
+        # recompute attributes "clips", "resampling_idxs" and other derivative ones
+        self.compute_clips(self.num_frames, self.step, self.frame_rate)

+ 110 - 0
libs/vision_libs/datasets/vision.py

@@ -0,0 +1,110 @@
+import os
+from typing import Any, Callable, List, Optional, Tuple
+
+import torch.utils.data as data
+
+from ..utils import _log_api_usage_once
+
+
+class VisionDataset(data.Dataset):
+    """
+    Base Class For making datasets which are compatible with torchvision.
+    It is necessary to override the ``__getitem__`` and ``__len__`` method.
+
+    Args:
+        root (string, optional): Root directory of dataset. Only used for `__repr__`.
+        transforms (callable, optional): A function/transforms that takes in
+            an image and a label and returns the transformed versions of both.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+
+    .. note::
+
+        :attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive.
+    """
+
+    _repr_indent = 4
+
+    def __init__(
+        self,
+        root: str = None,  # type: ignore[assignment]
+        transforms: Optional[Callable] = None,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+    ) -> None:
+        _log_api_usage_once(self)
+        if isinstance(root, str):
+            root = os.path.expanduser(root)
+        self.root = root
+
+        has_transforms = transforms is not None
+        has_separate_transform = transform is not None or target_transform is not None
+        if has_transforms and has_separate_transform:
+            raise ValueError("Only transforms or transform/target_transform can be passed as argument")
+
+        # for backwards-compatibility
+        self.transform = transform
+        self.target_transform = target_transform
+
+        if has_separate_transform:
+            transforms = StandardTransform(transform, target_transform)
+        self.transforms = transforms
+
+    def __getitem__(self, index: int) -> Any:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            (Any): Sample and meta data, optionally transformed by the respective transforms.
+        """
+        raise NotImplementedError
+
+    def __len__(self) -> int:
+        raise NotImplementedError
+
+    def __repr__(self) -> str:
+        head = "Dataset " + self.__class__.__name__
+        body = [f"Number of datapoints: {self.__len__()}"]
+        if self.root is not None:
+            body.append(f"Root location: {self.root}")
+        body += self.extra_repr().splitlines()
+        if hasattr(self, "transforms") and self.transforms is not None:
+            body += [repr(self.transforms)]
+        lines = [head] + [" " * self._repr_indent + line for line in body]
+        return "\n".join(lines)
+
+    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
+        lines = transform.__repr__().splitlines()
+        return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]
+
+    def extra_repr(self) -> str:
+        return ""
+
+
+class StandardTransform:
+    def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None:
+        self.transform = transform
+        self.target_transform = target_transform
+
+    def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]:
+        if self.transform is not None:
+            input = self.transform(input)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+        return input, target
+
+    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
+        lines = transform.__repr__().splitlines()
+        return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]
+
+    def __repr__(self) -> str:
+        body = [self.__class__.__name__]
+        if self.transform is not None:
+            body += self._format_transform_repr(self.transform, "Transform: ")
+        if self.target_transform is not None:
+            body += self._format_transform_repr(self.target_transform, "Target transform: ")
+
+        return "\n".join(body)

+ 224 - 0
libs/vision_libs/datasets/voc.py

@@ -0,0 +1,224 @@
+import collections
+import os
+from xml.etree.ElementTree import Element as ET_Element
+
+from .vision import VisionDataset
+
+try:
+    from defusedxml.ElementTree import parse as ET_parse
+except ImportError:
+    from xml.etree.ElementTree import parse as ET_parse
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+
+DATASET_YEAR_DICT = {
+    "2012": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar",
+        "filename": "VOCtrainval_11-May-2012.tar",
+        "md5": "6cd6e144f989b92b3379bac3b3de84fd",
+        "base_dir": os.path.join("VOCdevkit", "VOC2012"),
+    },
+    "2011": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar",
+        "filename": "VOCtrainval_25-May-2011.tar",
+        "md5": "6c3384ef61512963050cb5d687e5bf1e",
+        "base_dir": os.path.join("TrainVal", "VOCdevkit", "VOC2011"),
+    },
+    "2010": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar",
+        "filename": "VOCtrainval_03-May-2010.tar",
+        "md5": "da459979d0c395079b5c75ee67908abb",
+        "base_dir": os.path.join("VOCdevkit", "VOC2010"),
+    },
+    "2009": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar",
+        "filename": "VOCtrainval_11-May-2009.tar",
+        "md5": "a3e00b113cfcfebf17e343f59da3caa1",
+        "base_dir": os.path.join("VOCdevkit", "VOC2009"),
+    },
+    "2008": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar",
+        "filename": "VOCtrainval_11-May-2012.tar",
+        "md5": "2629fa636546599198acfcfbfcf1904a",
+        "base_dir": os.path.join("VOCdevkit", "VOC2008"),
+    },
+    "2007": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar",
+        "filename": "VOCtrainval_06-Nov-2007.tar",
+        "md5": "c52e279531787c972589f7e41ab4ae64",
+        "base_dir": os.path.join("VOCdevkit", "VOC2007"),
+    },
+    "2007-test": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar",
+        "filename": "VOCtest_06-Nov-2007.tar",
+        "md5": "b6e924de25625d8de591ea690078ad9f",
+        "base_dir": os.path.join("VOCdevkit", "VOC2007"),
+    },
+}
+
+
+class _VOCBase(VisionDataset):
+    _SPLITS_DIR: str
+    _TARGET_DIR: str
+    _TARGET_FILE_EXT: str
+
+    def __init__(
+        self,
+        root: str,
+        year: str = "2012",
+        image_set: str = "train",
+        download: bool = False,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        transforms: Optional[Callable] = None,
+    ):
+        super().__init__(root, transforms, transform, target_transform)
+
+        self.year = verify_str_arg(year, "year", valid_values=[str(yr) for yr in range(2007, 2013)])
+
+        valid_image_sets = ["train", "trainval", "val"]
+        if year == "2007":
+            valid_image_sets.append("test")
+        self.image_set = verify_str_arg(image_set, "image_set", valid_image_sets)
+
+        key = "2007-test" if year == "2007" and image_set == "test" else year
+        dataset_year_dict = DATASET_YEAR_DICT[key]
+
+        self.url = dataset_year_dict["url"]
+        self.filename = dataset_year_dict["filename"]
+        self.md5 = dataset_year_dict["md5"]
+
+        base_dir = dataset_year_dict["base_dir"]
+        voc_root = os.path.join(self.root, base_dir)
+
+        if download:
+            download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)
+
+        if not os.path.isdir(voc_root):
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        splits_dir = os.path.join(voc_root, "ImageSets", self._SPLITS_DIR)
+        split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt")
+        with open(os.path.join(split_f)) as f:
+            file_names = [x.strip() for x in f.readlines()]
+
+        image_dir = os.path.join(voc_root, "JPEGImages")
+        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
+
+        target_dir = os.path.join(voc_root, self._TARGET_DIR)
+        self.targets = [os.path.join(target_dir, x + self._TARGET_FILE_EXT) for x in file_names]
+
+        assert len(self.images) == len(self.targets)
+
+    def __len__(self) -> int:
+        return len(self.images)
+
+
+class VOCSegmentation(_VOCBase):
+    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
+
+    Args:
+        root (string): Root directory of the VOC Dataset.
+        year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
+        image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
+            ``year=="2007"``, can also be ``"test"``.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        transforms (callable, optional): A function/transform that takes input sample and its target as entry
+            and returns a transformed version.
+    """
+
+    _SPLITS_DIR = "Segmentation"
+    _TARGET_DIR = "SegmentationClass"
+    _TARGET_FILE_EXT = ".png"
+
+    @property
+    def masks(self) -> List[str]:
+        return self.targets
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is the image segmentation.
+        """
+        img = Image.open(self.images[index]).convert("RGB")
+        target = Image.open(self.masks[index])
+
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
+
+        return img, target
+
+
+class VOCDetection(_VOCBase):
+    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.
+
+    Args:
+        root (string): Root directory of the VOC Dataset.
+        year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
+        image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
+            ``year=="2007"``, can also be ``"test"``.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+            (default: alphabetic indexing of VOC's 20 classes).
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, required): A function/transform that takes in the
+            target and transforms it.
+        transforms (callable, optional): A function/transform that takes input sample and its target as entry
+            and returns a transformed version.
+    """
+
+    _SPLITS_DIR = "Main"
+    _TARGET_DIR = "Annotations"
+    _TARGET_FILE_EXT = ".xml"
+
+    @property
+    def annotations(self) -> List[str]:
+        return self.targets
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is a dictionary of the XML tree.
+        """
+        img = Image.open(self.images[index]).convert("RGB")
+        target = self.parse_voc_xml(ET_parse(self.annotations[index]).getroot())
+
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
+
+        return img, target
+
+    @staticmethod
+    def parse_voc_xml(node: ET_Element) -> Dict[str, Any]:
+        voc_dict: Dict[str, Any] = {}
+        children = list(node)
+        if children:
+            def_dic: Dict[str, Any] = collections.defaultdict(list)
+            for dc in map(VOCDetection.parse_voc_xml, children):
+                for ind, v in dc.items():
+                    def_dic[ind].append(v)
+            if node.tag == "annotation":
+                def_dic["object"] = [def_dic["object"]]
+            voc_dict = {node.tag: {ind: v[0] if len(v) == 1 else v for ind, v in def_dic.items()}}
+        if node.text:
+            text = node.text.strip()
+            if not children:
+                voc_dict[node.tag] = text
+        return voc_dict

+ 195 - 0
libs/vision_libs/datasets/widerface.py

@@ -0,0 +1,195 @@
+import os
+from os.path import abspath, expanduser
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from PIL import Image
+
+from .utils import download_and_extract_archive, download_file_from_google_drive, extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class WIDERFace(VisionDataset):
+    """`WIDERFace <http://shuoyang1213.me/WIDERFACE/>`_ Dataset.
+
+    Args:
+        root (string): Root directory where images and annotations are downloaded to.
+            Expects the following folder structure if download=False:
+
+            .. code::
+
+                <root>
+                    └── widerface
+                        ├── wider_face_split ('wider_face_split.zip' if compressed)
+                        ├── WIDER_train ('WIDER_train.zip' if compressed)
+                        ├── WIDER_val ('WIDER_val.zip' if compressed)
+                        └── WIDER_test ('WIDER_test.zip' if compressed)
+        split (string): The dataset split to use. One of {``train``, ``val``, ``test``}.
+            Defaults to ``train``.
+        transform (callable, optional): A function/transform that  takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+            .. warning::
+
+                To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
+
+    """
+
+    BASE_FOLDER = "widerface"
+    FILE_LIST = [
+        # File ID                             MD5 Hash                            Filename
+        ("15hGDLhsx8bLgLcIRD5DhYt5iBxnjNF1M", "3fedf70df600953d25982bcd13d91ba2", "WIDER_train.zip"),
+        ("1GUCogbp16PMGa39thoMMeWxp7Rp5oM8Q", "dfa7d7e790efa35df3788964cf0bbaea", "WIDER_val.zip"),
+        ("1HIfDbVEWKmsYKJZm4lchTBDLW5N7dY5T", "e5d8f4248ed24c334bbd12f49c29dd40", "WIDER_test.zip"),
+    ]
+    ANNOTATIONS_FILE = (
+        "http://shuoyang1213.me/WIDERFACE/support/bbx_annotation/wider_face_split.zip",
+        "0e3767bcf0e326556d407bf5bff5d27c",
+        "wider_face_split.zip",
+    )
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(
+            root=os.path.join(root, self.BASE_FOLDER), transform=transform, target_transform=target_transform
+        )
+        # check arguments
+        self.split = verify_str_arg(split, "split", ("train", "val", "test"))
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download and prepare it")
+
+        self.img_info: List[Dict[str, Union[str, Dict[str, torch.Tensor]]]] = []
+        if self.split in ("train", "val"):
+            self.parse_train_val_annotations_file()
+        else:
+            self.parse_test_annotations_file()
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is a dict of annotations for all faces in the image.
+            target=None for the test split.
+        """
+
+        # stay consistent with other datasets and return a PIL Image
+        img = Image.open(self.img_info[index]["img_path"])
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        target = None if self.split == "test" else self.img_info[index]["annotations"]
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.img_info)
+
+    def extra_repr(self) -> str:
+        lines = ["Split: {split}"]
+        return "\n".join(lines).format(**self.__dict__)
+
+    def parse_train_val_annotations_file(self) -> None:
+        filename = "wider_face_train_bbx_gt.txt" if self.split == "train" else "wider_face_val_bbx_gt.txt"
+        filepath = os.path.join(self.root, "wider_face_split", filename)
+
+        with open(filepath) as f:
+            lines = f.readlines()
+            file_name_line, num_boxes_line, box_annotation_line = True, False, False
+            num_boxes, box_counter = 0, 0
+            labels = []
+            for line in lines:
+                line = line.rstrip()
+                if file_name_line:
+                    img_path = os.path.join(self.root, "WIDER_" + self.split, "images", line)
+                    img_path = abspath(expanduser(img_path))
+                    file_name_line = False
+                    num_boxes_line = True
+                elif num_boxes_line:
+                    num_boxes = int(line)
+                    num_boxes_line = False
+                    box_annotation_line = True
+                elif box_annotation_line:
+                    box_counter += 1
+                    line_split = line.split(" ")
+                    line_values = [int(x) for x in line_split]
+                    labels.append(line_values)
+                    if box_counter >= num_boxes:
+                        box_annotation_line = False
+                        file_name_line = True
+                        labels_tensor = torch.tensor(labels)
+                        self.img_info.append(
+                            {
+                                "img_path": img_path,
+                                "annotations": {
+                                    "bbox": labels_tensor[:, 0:4].clone(),  # x, y, width, height
+                                    "blur": labels_tensor[:, 4].clone(),
+                                    "expression": labels_tensor[:, 5].clone(),
+                                    "illumination": labels_tensor[:, 6].clone(),
+                                    "occlusion": labels_tensor[:, 7].clone(),
+                                    "pose": labels_tensor[:, 8].clone(),
+                                    "invalid": labels_tensor[:, 9].clone(),
+                                },
+                            }
+                        )
+                        box_counter = 0
+                        labels.clear()
+                else:
+                    raise RuntimeError(f"Error parsing annotation file {filepath}")
+
+    def parse_test_annotations_file(self) -> None:
+        filepath = os.path.join(self.root, "wider_face_split", "wider_face_test_filelist.txt")
+        filepath = abspath(expanduser(filepath))
+        with open(filepath) as f:
+            lines = f.readlines()
+            for line in lines:
+                line = line.rstrip()
+                img_path = os.path.join(self.root, "WIDER_test", "images", line)
+                img_path = abspath(expanduser(img_path))
+                self.img_info.append({"img_path": img_path})
+
+    def _check_integrity(self) -> bool:
+        # Allow original archive to be deleted (zip). Only need the extracted images
+        all_files = self.FILE_LIST.copy()
+        all_files.append(self.ANNOTATIONS_FILE)
+        for (_, md5, filename) in all_files:
+            file, ext = os.path.splitext(filename)
+            extracted_dir = os.path.join(self.root, file)
+            if not os.path.exists(extracted_dir):
+                return False
+        return True
+
+    def download(self) -> None:
+        if self._check_integrity():
+            print("Files already downloaded and verified")
+            return
+
+        # download and extract image data
+        for (file_id, md5, filename) in self.FILE_LIST:
+            download_file_from_google_drive(file_id, self.root, filename, md5)
+            filepath = os.path.join(self.root, filename)
+            extract_archive(filepath)
+
+        # download and extract annotation files
+        download_and_extract_archive(
+            url=self.ANNOTATIONS_FILE[0], download_root=self.root, md5=self.ANNOTATIONS_FILE[1]
+        )

+ 92 - 0
libs/vision_libs/extension.py

@@ -0,0 +1,92 @@
+import os
+import sys
+
+import torch
+
+from ._internally_replaced_utils import _get_extension_path
+
+
+_HAS_OPS = False
+
+
+def _has_ops():
+    return False
+
+
+try:
+    # On Windows Python-3.8.x has `os.add_dll_directory` call,
+    # which is called to configure dll search path.
+    # To find cuda related dlls we need to make sure the
+    # conda environment/bin path is configured Please take a look:
+    # https://stackoverflow.com/questions/59330863/cant-import-dll-module-in-python
+    # Please note: if some path can't be added using add_dll_directory we simply ignore this path
+    if os.name == "nt" and sys.version_info < (3, 9):
+        env_path = os.environ["PATH"]
+        path_arr = env_path.split(";")
+        for path in path_arr:
+            if os.path.exists(path):
+                try:
+                    os.add_dll_directory(path)  # type: ignore[attr-defined]
+                except Exception:
+                    pass
+
+    lib_path = _get_extension_path("_C")
+    torch.ops.load_library(lib_path)
+    _HAS_OPS = True
+
+    def _has_ops():  # noqa: F811
+        return True
+
+except (ImportError, OSError):
+    pass
+
+
+def _assert_has_ops():
+    if not _has_ops():
+        raise RuntimeError(
+            "Couldn't load custom C++ ops. This can happen if your PyTorch and "
+            "torchvision versions are incompatible, or if you had errors while compiling "
+            "torchvision from source. For further information on the compatible versions, check "
+            "https://github.com/pytorch/vision#installation for the compatibility matrix. "
+            "Please check your PyTorch version with torch.__version__ and your torchvision "
+            "version with torchvision.__version__ and verify if they are compatible, and if not "
+            "please reinstall torchvision so that it matches your PyTorch install."
+        )
+
+
+def _check_cuda_version():
+    """
+    Make sure that CUDA versions match between the pytorch install and torchvision install
+    """
+    if not _HAS_OPS:
+        return -1
+    from torch.version import cuda as torch_version_cuda
+
+    _version = torch.ops.torchvision._cuda_version()
+    if _version != -1 and torch_version_cuda is not None:
+        tv_version = str(_version)
+        if int(tv_version) < 10000:
+            tv_major = int(tv_version[0])
+            tv_minor = int(tv_version[2])
+        else:
+            tv_major = int(tv_version[0:2])
+            tv_minor = int(tv_version[3])
+        t_version = torch_version_cuda.split(".")
+        t_major = int(t_version[0])
+        t_minor = int(t_version[1])
+        if t_major != tv_major:
+            raise RuntimeError(
+                "Detected that PyTorch and torchvision were compiled with different CUDA major versions. "
+                f"PyTorch has CUDA Version={t_major}.{t_minor} and torchvision has "
+                f"CUDA Version={tv_major}.{tv_minor}. "
+                "Please reinstall the torchvision that matches your PyTorch install."
+            )
+    return _version
+
+
+def _load_library(lib_name):
+    lib_path = _get_extension_path(lib_name)
+    torch.ops.load_library(lib_path)
+
+
+_check_cuda_version()

二进制
libs/vision_libs/image.pyd


+ 69 - 0
libs/vision_libs/io/__init__.py

@@ -0,0 +1,69 @@
+from typing import Any, Dict, Iterator
+
+import torch
+
+from ..utils import _log_api_usage_once
+
+try:
+    from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER
+except ModuleNotFoundError:
+    _HAS_GPU_VIDEO_DECODER = False
+
+from ._video_opt import (
+    _HAS_VIDEO_OPT,
+    _probe_video_from_file,
+    _probe_video_from_memory,
+    _read_video_from_file,
+    _read_video_from_memory,
+    _read_video_timestamps_from_file,
+    _read_video_timestamps_from_memory,
+    Timebase,
+    VideoMetaData,
+)
+from .image import (
+    decode_image,
+    decode_jpeg,
+    decode_png,
+    encode_jpeg,
+    encode_png,
+    ImageReadMode,
+    read_file,
+    read_image,
+    write_file,
+    write_jpeg,
+    write_png,
+)
+from .video import read_video, read_video_timestamps, write_video
+from .video_reader import VideoReader
+
+
+__all__ = [
+    "write_video",
+    "read_video",
+    "read_video_timestamps",
+    "_read_video_from_file",
+    "_read_video_timestamps_from_file",
+    "_probe_video_from_file",
+    "_read_video_from_memory",
+    "_read_video_timestamps_from_memory",
+    "_probe_video_from_memory",
+    "_HAS_VIDEO_OPT",
+    "_HAS_GPU_VIDEO_DECODER",
+    "_read_video_clip_from_memory",
+    "_read_video_meta_data",
+    "VideoMetaData",
+    "Timebase",
+    "ImageReadMode",
+    "decode_image",
+    "decode_jpeg",
+    "decode_png",
+    "encode_jpeg",
+    "encode_png",
+    "read_file",
+    "read_image",
+    "write_file",
+    "write_jpeg",
+    "write_png",
+    "Video",
+    "VideoReader",
+]

+ 8 - 0
libs/vision_libs/io/_load_gpu_decoder.py

@@ -0,0 +1,8 @@
+from ..extension import _load_library
+
+
+try:
+    _load_library("Decoder")
+    _HAS_GPU_VIDEO_DECODER = True
+except (ImportError, OSError):
+    _HAS_GPU_VIDEO_DECODER = False

+ 512 - 0
libs/vision_libs/io/_video_opt.py

@@ -0,0 +1,512 @@
+import math
+import warnings
+from fractions import Fraction
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+
+from ..extension import _load_library
+
+
+try:
+    _load_library("video_reader")
+    _HAS_VIDEO_OPT = True
+except (ImportError, OSError):
+    _HAS_VIDEO_OPT = False
+
+default_timebase = Fraction(0, 1)
+
+
+# simple class for torch scripting
+# the complex Fraction class from fractions module is not scriptable
+class Timebase:
+    __annotations__ = {"numerator": int, "denominator": int}
+    __slots__ = ["numerator", "denominator"]
+
+    def __init__(
+        self,
+        numerator: int,
+        denominator: int,
+    ) -> None:
+        self.numerator = numerator
+        self.denominator = denominator
+
+
+class VideoMetaData:
+    __annotations__ = {
+        "has_video": bool,
+        "video_timebase": Timebase,
+        "video_duration": float,
+        "video_fps": float,
+        "has_audio": bool,
+        "audio_timebase": Timebase,
+        "audio_duration": float,
+        "audio_sample_rate": float,
+    }
+    __slots__ = [
+        "has_video",
+        "video_timebase",
+        "video_duration",
+        "video_fps",
+        "has_audio",
+        "audio_timebase",
+        "audio_duration",
+        "audio_sample_rate",
+    ]
+
+    def __init__(self) -> None:
+        self.has_video = False
+        self.video_timebase = Timebase(0, 1)
+        self.video_duration = 0.0
+        self.video_fps = 0.0
+        self.has_audio = False
+        self.audio_timebase = Timebase(0, 1)
+        self.audio_duration = 0.0
+        self.audio_sample_rate = 0.0
+
+
+def _validate_pts(pts_range: Tuple[int, int]) -> None:
+
+    if pts_range[0] > pts_range[1] > 0:
+        raise ValueError(
+            f"Start pts should not be smaller than end pts, got start pts: {pts_range[0]} and end pts: {pts_range[1]}"
+        )
+
+
+def _fill_info(
+    vtimebase: torch.Tensor,
+    vfps: torch.Tensor,
+    vduration: torch.Tensor,
+    atimebase: torch.Tensor,
+    asample_rate: torch.Tensor,
+    aduration: torch.Tensor,
+) -> VideoMetaData:
+    """
+    Build update VideoMetaData struct with info about the video
+    """
+    meta = VideoMetaData()
+    if vtimebase.numel() > 0:
+        meta.video_timebase = Timebase(int(vtimebase[0].item()), int(vtimebase[1].item()))
+        timebase = vtimebase[0].item() / float(vtimebase[1].item())
+        if vduration.numel() > 0:
+            meta.has_video = True
+            meta.video_duration = float(vduration.item()) * timebase
+    if vfps.numel() > 0:
+        meta.video_fps = float(vfps.item())
+    if atimebase.numel() > 0:
+        meta.audio_timebase = Timebase(int(atimebase[0].item()), int(atimebase[1].item()))
+        timebase = atimebase[0].item() / float(atimebase[1].item())
+        if aduration.numel() > 0:
+            meta.has_audio = True
+            meta.audio_duration = float(aduration.item()) * timebase
+    if asample_rate.numel() > 0:
+        meta.audio_sample_rate = float(asample_rate.item())
+
+    return meta
+
+
+def _align_audio_frames(
+    aframes: torch.Tensor, aframe_pts: torch.Tensor, audio_pts_range: Tuple[int, int]
+) -> torch.Tensor:
+    start, end = aframe_pts[0], aframe_pts[-1]
+    num_samples = aframes.size(0)
+    step_per_aframe = float(end - start + 1) / float(num_samples)
+    s_idx = 0
+    e_idx = num_samples
+    if start < audio_pts_range[0]:
+        s_idx = int((audio_pts_range[0] - start) / step_per_aframe)
+    if audio_pts_range[1] != -1 and end > audio_pts_range[1]:
+        e_idx = int((audio_pts_range[1] - end) / step_per_aframe)
+    return aframes[s_idx:e_idx, :]
+
+
+def _read_video_from_file(
+    filename: str,
+    seek_frame_margin: float = 0.25,
+    read_video_stream: bool = True,
+    video_width: int = 0,
+    video_height: int = 0,
+    video_min_dimension: int = 0,
+    video_max_dimension: int = 0,
+    video_pts_range: Tuple[int, int] = (0, -1),
+    video_timebase: Fraction = default_timebase,
+    read_audio_stream: bool = True,
+    audio_samples: int = 0,
+    audio_channels: int = 0,
+    audio_pts_range: Tuple[int, int] = (0, -1),
+    audio_timebase: Fraction = default_timebase,
+) -> Tuple[torch.Tensor, torch.Tensor, VideoMetaData]:
+    """
+    Reads a video from a file, returning both the video frames and the audio frames
+
+    Args:
+    filename (str): path to the video file
+    seek_frame_margin (double, optional): seeking frame in the stream is imprecise. Thus,
+        when video_start_pts is specified, we seek the pts earlier by seek_frame_margin seconds
+    read_video_stream (int, optional): whether read video stream. If yes, set to 1. Otherwise, 0
+    video_width/video_height/video_min_dimension/video_max_dimension (int): together decide
+        the size of decoded frames:
+
+            - When video_width = 0, video_height = 0, video_min_dimension = 0,
+                and video_max_dimension = 0, keep the original frame resolution
+            - When video_width = 0, video_height = 0, video_min_dimension != 0,
+                and video_max_dimension = 0, keep the aspect ratio and resize the
+                frame so that shorter edge size is video_min_dimension
+            - When video_width = 0, video_height = 0, video_min_dimension = 0,
+                and video_max_dimension != 0, keep the aspect ratio and resize
+                the frame so that longer edge size is video_max_dimension
+            - When video_width = 0, video_height = 0, video_min_dimension != 0,
+                and video_max_dimension != 0, resize the frame so that shorter
+                edge size is video_min_dimension, and longer edge size is
+                video_max_dimension. The aspect ratio may not be preserved
+            - When video_width = 0, video_height != 0, video_min_dimension = 0,
+                and video_max_dimension = 0, keep the aspect ratio and resize
+                the frame so that frame video_height is $video_height
+            - When video_width != 0, video_height == 0, video_min_dimension = 0,
+                and video_max_dimension = 0, keep the aspect ratio and resize
+                the frame so that frame video_width is $video_width
+            - When video_width != 0, video_height != 0, video_min_dimension = 0,
+                and video_max_dimension = 0, resize the frame so that frame
+                video_width and  video_height are set to $video_width and
+                $video_height, respectively
+    video_pts_range (list(int), optional): the start and end presentation timestamp of video stream
+    video_timebase (Fraction, optional): a Fraction rational number which denotes timebase in video stream
+    read_audio_stream (int, optional): whether read audio stream. If yes, set to 1. Otherwise, 0
+    audio_samples (int, optional): audio sampling rate
+    audio_channels (int optional): audio channels
+    audio_pts_range (list(int), optional): the start and end presentation timestamp of audio stream
+    audio_timebase (Fraction, optional): a Fraction rational number which denotes time base in audio stream
+
+    Returns
+        vframes (Tensor[T, H, W, C]): the `T` video frames
+        aframes (Tensor[L, K]): the audio frames, where `L` is the number of points and
+            `K` is the number of audio_channels
+        info (Dict): metadata for the video and audio. Can contain the fields video_fps (float)
+            and audio_fps (int)
+    """
+    _validate_pts(video_pts_range)
+    _validate_pts(audio_pts_range)
+
+    result = torch.ops.video_reader.read_video_from_file(
+        filename,
+        seek_frame_margin,
+        0,  # getPtsOnly
+        read_video_stream,
+        video_width,
+        video_height,
+        video_min_dimension,
+        video_max_dimension,
+        video_pts_range[0],
+        video_pts_range[1],
+        video_timebase.numerator,
+        video_timebase.denominator,
+        read_audio_stream,
+        audio_samples,
+        audio_channels,
+        audio_pts_range[0],
+        audio_pts_range[1],
+        audio_timebase.numerator,
+        audio_timebase.denominator,
+    )
+    vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration = result
+    info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
+    if aframes.numel() > 0:
+        # when audio stream is found
+        aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)
+    return vframes, aframes, info
+
+
+def _read_video_timestamps_from_file(filename: str) -> Tuple[List[int], List[int], VideoMetaData]:
+    """
+    Decode all video- and audio frames in the video. Only pts
+    (presentation timestamp) is returned. The actual frame pixel data is not
+    copied. Thus, it is much faster than read_video(...)
+    """
+    result = torch.ops.video_reader.read_video_from_file(
+        filename,
+        0,  # seek_frame_margin
+        1,  # getPtsOnly
+        1,  # read_video_stream
+        0,  # video_width
+        0,  # video_height
+        0,  # video_min_dimension
+        0,  # video_max_dimension
+        0,  # video_start_pts
+        -1,  # video_end_pts
+        0,  # video_timebase_num
+        1,  # video_timebase_den
+        1,  # read_audio_stream
+        0,  # audio_samples
+        0,  # audio_channels
+        0,  # audio_start_pts
+        -1,  # audio_end_pts
+        0,  # audio_timebase_num
+        1,  # audio_timebase_den
+    )
+    _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, asample_rate, aduration = result
+    info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
+
+    vframe_pts = vframe_pts.numpy().tolist()
+    aframe_pts = aframe_pts.numpy().tolist()
+    return vframe_pts, aframe_pts, info
+
+
+def _probe_video_from_file(filename: str) -> VideoMetaData:
+    """
+    Probe a video file and return VideoMetaData with info about the video
+    """
+    result = torch.ops.video_reader.probe_video_from_file(filename)
+    vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
+    info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
+    return info
+
+
+def _read_video_from_memory(
+    video_data: torch.Tensor,
+    seek_frame_margin: float = 0.25,
+    read_video_stream: int = 1,
+    video_width: int = 0,
+    video_height: int = 0,
+    video_min_dimension: int = 0,
+    video_max_dimension: int = 0,
+    video_pts_range: Tuple[int, int] = (0, -1),
+    video_timebase_numerator: int = 0,
+    video_timebase_denominator: int = 1,
+    read_audio_stream: int = 1,
+    audio_samples: int = 0,
+    audio_channels: int = 0,
+    audio_pts_range: Tuple[int, int] = (0, -1),
+    audio_timebase_numerator: int = 0,
+    audio_timebase_denominator: int = 1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Reads a video from memory, returning both the video frames as the audio frames
+    This function is torchscriptable.
+
+    Args:
+    video_data (data type could be 1) torch.Tensor, dtype=torch.int8 or 2) python bytes):
+        compressed video content stored in either 1) torch.Tensor 2) python bytes
+    seek_frame_margin (double, optional): seeking frame in the stream is imprecise.
+        Thus, when video_start_pts is specified, we seek the pts earlier by seek_frame_margin seconds
+    read_video_stream (int, optional): whether read video stream. If yes, set to 1. Otherwise, 0
+    video_width/video_height/video_min_dimension/video_max_dimension (int): together decide
+        the size of decoded frames:
+
+            - When video_width = 0, video_height = 0, video_min_dimension = 0,
+                and video_max_dimension = 0, keep the original frame resolution
+            - When video_width = 0, video_height = 0, video_min_dimension != 0,
+                and video_max_dimension = 0, keep the aspect ratio and resize the
+                frame so that shorter edge size is video_min_dimension
+            - When video_width = 0, video_height = 0, video_min_dimension = 0,
+                and video_max_dimension != 0, keep the aspect ratio and resize
+                the frame so that longer edge size is video_max_dimension
+            - When video_width = 0, video_height = 0, video_min_dimension != 0,
+                and video_max_dimension != 0, resize the frame so that shorter
+                edge size is video_min_dimension, and longer edge size is
+                video_max_dimension. The aspect ratio may not be preserved
+            - When video_width = 0, video_height != 0, video_min_dimension = 0,
+                and video_max_dimension = 0, keep the aspect ratio and resize
+                the frame so that frame video_height is $video_height
+            - When video_width != 0, video_height == 0, video_min_dimension = 0,
+                and video_max_dimension = 0, keep the aspect ratio and resize
+                the frame so that frame video_width is $video_width
+            - When video_width != 0, video_height != 0, video_min_dimension = 0,
+                and video_max_dimension = 0, resize the frame so that frame
+                video_width and  video_height are set to $video_width and
+                $video_height, respectively
+    video_pts_range (list(int), optional): the start and end presentation timestamp of video stream
+    video_timebase_numerator / video_timebase_denominator (float, optional): a rational
+        number which denotes timebase in video stream
+    read_audio_stream (int, optional): whether read audio stream. If yes, set to 1. Otherwise, 0
+    audio_samples (int, optional): audio sampling rate
+    audio_channels (int optional): audio audio_channels
+    audio_pts_range (list(int), optional): the start and end presentation timestamp of audio stream
+    audio_timebase_numerator / audio_timebase_denominator (float, optional):
+        a rational number which denotes time base in audio stream
+
+    Returns:
+        vframes (Tensor[T, H, W, C]): the `T` video frames
+        aframes (Tensor[L, K]): the audio frames, where `L` is the number of points and
+            `K` is the number of channels
+    """
+
+    _validate_pts(video_pts_range)
+    _validate_pts(audio_pts_range)
+
+    if not isinstance(video_data, torch.Tensor):
+        with warnings.catch_warnings():
+            # Ignore the warning because we actually don't modify the buffer in this function
+            warnings.filterwarnings("ignore", message="The given buffer is not writable")
+            video_data = torch.frombuffer(video_data, dtype=torch.uint8)
+
+    result = torch.ops.video_reader.read_video_from_memory(
+        video_data,
+        seek_frame_margin,
+        0,  # getPtsOnly
+        read_video_stream,
+        video_width,
+        video_height,
+        video_min_dimension,
+        video_max_dimension,
+        video_pts_range[0],
+        video_pts_range[1],
+        video_timebase_numerator,
+        video_timebase_denominator,
+        read_audio_stream,
+        audio_samples,
+        audio_channels,
+        audio_pts_range[0],
+        audio_pts_range[1],
+        audio_timebase_numerator,
+        audio_timebase_denominator,
+    )
+
+    vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration = result
+
+    if aframes.numel() > 0:
+        # when audio stream is found
+        aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)
+
+    return vframes, aframes
+
+
+def _read_video_timestamps_from_memory(
+    video_data: torch.Tensor,
+) -> Tuple[List[int], List[int], VideoMetaData]:
+    """
+    Decode all frames in the video. Only pts (presentation timestamp) is returned.
+    The actual frame pixel data is not copied. Thus, read_video_timestamps(...)
+    is much faster than read_video(...)
+    """
+    if not isinstance(video_data, torch.Tensor):
+        with warnings.catch_warnings():
+            # Ignore the warning because we actually don't modify the buffer in this function
+            warnings.filterwarnings("ignore", message="The given buffer is not writable")
+            video_data = torch.frombuffer(video_data, dtype=torch.uint8)
+    result = torch.ops.video_reader.read_video_from_memory(
+        video_data,
+        0,  # seek_frame_margin
+        1,  # getPtsOnly
+        1,  # read_video_stream
+        0,  # video_width
+        0,  # video_height
+        0,  # video_min_dimension
+        0,  # video_max_dimension
+        0,  # video_start_pts
+        -1,  # video_end_pts
+        0,  # video_timebase_num
+        1,  # video_timebase_den
+        1,  # read_audio_stream
+        0,  # audio_samples
+        0,  # audio_channels
+        0,  # audio_start_pts
+        -1,  # audio_end_pts
+        0,  # audio_timebase_num
+        1,  # audio_timebase_den
+    )
+    _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, asample_rate, aduration = result
+    info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
+
+    vframe_pts = vframe_pts.numpy().tolist()
+    aframe_pts = aframe_pts.numpy().tolist()
+    return vframe_pts, aframe_pts, info
+
+
+def _probe_video_from_memory(
+    video_data: torch.Tensor,
+) -> VideoMetaData:
+    """
+    Probe a video in memory and return VideoMetaData with info about the video
+    This function is torchscriptable
+    """
+    if not isinstance(video_data, torch.Tensor):
+        with warnings.catch_warnings():
+            # Ignore the warning because we actually don't modify the buffer in this function
+            warnings.filterwarnings("ignore", message="The given buffer is not writable")
+            video_data = torch.frombuffer(video_data, dtype=torch.uint8)
+    result = torch.ops.video_reader.probe_video_from_memory(video_data)
+    vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
+    info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
+    return info
+
+
+def _read_video(
+    filename: str,
+    start_pts: Union[float, Fraction] = 0,
+    end_pts: Optional[Union[float, Fraction]] = None,
+    pts_unit: str = "pts",
+) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]:
+    if end_pts is None:
+        end_pts = float("inf")
+
+    if pts_unit == "pts":
+        warnings.warn(
+            "The pts_unit 'pts' gives wrong results and will be removed in a "
+            + "follow-up version. Please use pts_unit 'sec'."
+        )
+
+    info = _probe_video_from_file(filename)
+
+    has_video = info.has_video
+    has_audio = info.has_audio
+
+    def get_pts(time_base):
+        start_offset = start_pts
+        end_offset = end_pts
+        if pts_unit == "sec":
+            start_offset = int(math.floor(start_pts * (1 / time_base)))
+            if end_offset != float("inf"):
+                end_offset = int(math.ceil(end_pts * (1 / time_base)))
+        if end_offset == float("inf"):
+            end_offset = -1
+        return start_offset, end_offset
+
+    video_pts_range = (0, -1)
+    video_timebase = default_timebase
+    if has_video:
+        video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
+        video_pts_range = get_pts(video_timebase)
+
+    audio_pts_range = (0, -1)
+    audio_timebase = default_timebase
+    if has_audio:
+        audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator)
+        audio_pts_range = get_pts(audio_timebase)
+
+    vframes, aframes, info = _read_video_from_file(
+        filename,
+        read_video_stream=True,
+        video_pts_range=video_pts_range,
+        video_timebase=video_timebase,
+        read_audio_stream=True,
+        audio_pts_range=audio_pts_range,
+        audio_timebase=audio_timebase,
+    )
+    _info = {}
+    if has_video:
+        _info["video_fps"] = info.video_fps
+    if has_audio:
+        _info["audio_fps"] = info.audio_sample_rate
+
+    return vframes, aframes, _info
+
+
+def _read_video_timestamps(
+    filename: str, pts_unit: str = "pts"
+) -> Tuple[Union[List[int], List[Fraction]], Optional[float]]:
+    if pts_unit == "pts":
+        warnings.warn(
+            "The pts_unit 'pts' gives wrong results and will be removed in a "
+            + "follow-up version. Please use pts_unit 'sec'."
+        )
+
+    pts: Union[List[int], List[Fraction]]
+    pts, _, info = _read_video_timestamps_from_file(filename)
+
+    if pts_unit == "sec":
+        video_time_base = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
+        pts = [x * video_time_base for x in pts]
+
+    video_fps = info.video_fps if info.has_video else None
+
+    return pts, video_fps

+ 264 - 0
libs/vision_libs/io/image.py

@@ -0,0 +1,264 @@
+from enum import Enum
+from warnings import warn
+
+import torch
+
+from ..extension import _load_library
+from ..utils import _log_api_usage_once
+
+
+try:
+    _load_library("image")
+except (ImportError, OSError) as e:
+    warn(
+        f"Failed to load image Python extension: '{e}'"
+        f"If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. "
+        f"Otherwise, there might be something wrong with your environment. "
+        f"Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?"
+    )
+
+
+class ImageReadMode(Enum):
+    """
+    Support for various modes while reading images.
+
+    Use ``ImageReadMode.UNCHANGED`` for loading the image as-is,
+    ``ImageReadMode.GRAY`` for converting to grayscale,
+    ``ImageReadMode.GRAY_ALPHA`` for grayscale with transparency,
+    ``ImageReadMode.RGB`` for RGB and ``ImageReadMode.RGB_ALPHA`` for
+    RGB with transparency.
+    """
+
+    UNCHANGED = 0
+    GRAY = 1
+    GRAY_ALPHA = 2
+    RGB = 3
+    RGB_ALPHA = 4
+
+
+def read_file(path: str) -> torch.Tensor:
+    """
+    Reads and outputs the bytes contents of a file as a uint8 Tensor
+    with one dimension.
+
+    Args:
+        path (str): the path to the file to be read
+
+    Returns:
+        data (Tensor)
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(read_file)
+    data = torch.ops.image.read_file(path)
+    return data
+
+
+def write_file(filename: str, data: torch.Tensor) -> None:
+    """
+    Writes the contents of an uint8 tensor with one dimension to a
+    file.
+
+    Args:
+        filename (str): the path to the file to be written
+        data (Tensor): the contents to be written to the output file
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(write_file)
+    torch.ops.image.write_file(filename, data)
+
+
+def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
+    """
+    Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor.
+    Optionally converts the image to the desired format.
+    The values of the output tensor are uint8 in [0, 255].
+
+    Args:
+        input (Tensor[1]): a one dimensional uint8 tensor containing
+            the raw bytes of the PNG image.
+        mode (ImageReadMode): the read mode used for optionally
+            converting the image. Default: ``ImageReadMode.UNCHANGED``.
+            See `ImageReadMode` class for more information on various
+            available modes.
+
+    Returns:
+        output (Tensor[image_channels, image_height, image_width])
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(decode_png)
+    output = torch.ops.image.decode_png(input, mode.value, False)
+    return output
+
+
+def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
+    """
+    Takes an input tensor in CHW layout and returns a buffer with the contents
+    of its corresponding PNG file.
+
+    Args:
+        input (Tensor[channels, image_height, image_width]): int8 image tensor of
+            ``c`` channels, where ``c`` must 3 or 1.
+        compression_level (int): Compression factor for the resulting file, it must be a number
+            between 0 and 9. Default: 6
+
+    Returns:
+        Tensor[1]: A one dimensional int8 tensor that contains the raw bytes of the
+            PNG file.
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(encode_png)
+    output = torch.ops.image.encode_png(input, compression_level)
+    return output
+
+
+def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
+    """
+    Takes an input tensor in CHW layout (or HW in the case of grayscale images)
+    and saves it in a PNG file.
+
+    Args:
+        input (Tensor[channels, image_height, image_width]): int8 image tensor of
+            ``c`` channels, where ``c`` must be 1 or 3.
+        filename (str): Path to save the image.
+        compression_level (int): Compression factor for the resulting file, it must be a number
+            between 0 and 9. Default: 6
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(write_png)
+    output = encode_png(input, compression_level)
+    write_file(filename, output)
+
+
+def decode_jpeg(
+    input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, device: str = "cpu"
+) -> torch.Tensor:
+    """
+    Decodes a JPEG image into a 3 dimensional RGB or grayscale Tensor.
+    Optionally converts the image to the desired format.
+    The values of the output tensor are uint8 between 0 and 255.
+
+    Args:
+        input (Tensor[1]): a one dimensional uint8 tensor containing
+            the raw bytes of the JPEG image. This tensor must be on CPU,
+            regardless of the ``device`` parameter.
+        mode (ImageReadMode): the read mode used for optionally
+            converting the image. The supported modes are: ``ImageReadMode.UNCHANGED``,
+            ``ImageReadMode.GRAY`` and ``ImageReadMode.RGB``
+            Default: ``ImageReadMode.UNCHANGED``.
+            See ``ImageReadMode`` class for more information on various
+            available modes.
+        device (str or torch.device): The device on which the decoded image will
+            be stored. If a cuda device is specified, the image will be decoded
+            with `nvjpeg <https://developer.nvidia.com/nvjpeg>`_. This is only
+            supported for CUDA version >= 10.1
+
+            .. betastatus:: device parameter
+
+            .. warning::
+                There is a memory leak in the nvjpeg library for CUDA versions < 11.6.
+                Make sure to rely on CUDA 11.6 or above before using ``device="cuda"``.
+
+    Returns:
+        output (Tensor[image_channels, image_height, image_width])
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(decode_jpeg)
+    device = torch.device(device)
+    if device.type == "cuda":
+        output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device)
+    else:
+        output = torch.ops.image.decode_jpeg(input, mode.value)
+    return output
+
+
+def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
+    """
+    Takes an input tensor in CHW layout and returns a buffer with the contents
+    of its corresponding JPEG file.
+
+    Args:
+        input (Tensor[channels, image_height, image_width])): int8 image tensor of
+            ``c`` channels, where ``c`` must be 1 or 3.
+        quality (int): Quality of the resulting JPEG file, it must be a number between
+            1 and 100. Default: 75
+
+    Returns:
+        output (Tensor[1]): A one dimensional int8 tensor that contains the raw bytes of the
+            JPEG file.
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(encode_jpeg)
+    if quality < 1 or quality > 100:
+        raise ValueError("Image quality should be a positive number between 1 and 100")
+
+    output = torch.ops.image.encode_jpeg(input, quality)
+    return output
+
+
+def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
+    """
+    Takes an input tensor in CHW layout and saves it in a JPEG file.
+
+    Args:
+        input (Tensor[channels, image_height, image_width]): int8 image tensor of ``c``
+            channels, where ``c`` must be 1 or 3.
+        filename (str): Path to save the image.
+        quality (int): Quality of the resulting JPEG file, it must be a number
+            between 1 and 100. Default: 75
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(write_jpeg)
+    output = encode_jpeg(input, quality)
+    write_file(filename, output)
+
+
+def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
+    """
+    Detects whether an image is a JPEG or PNG and performs the appropriate
+    operation to decode the image into a 3 dimensional RGB or grayscale Tensor.
+
+    Optionally converts the image to the desired format.
+    The values of the output tensor are uint8 in [0, 255].
+
+    Args:
+        input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
+            PNG or JPEG image.
+        mode (ImageReadMode): the read mode used for optionally converting the image.
+            Default: ``ImageReadMode.UNCHANGED``.
+            See ``ImageReadMode`` class for more information on various
+            available modes.
+
+    Returns:
+        output (Tensor[image_channels, image_height, image_width])
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(decode_image)
+    output = torch.ops.image.decode_image(input, mode.value)
+    return output
+
+
+def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
+    """
+    Reads a JPEG or PNG image into a 3 dimensional RGB or grayscale Tensor.
+    Optionally converts the image to the desired format.
+    The values of the output tensor are uint8 in [0, 255].
+
+    Args:
+        path (str): path of the JPEG or PNG image.
+        mode (ImageReadMode): the read mode used for optionally converting the image.
+            Default: ``ImageReadMode.UNCHANGED``.
+            See ``ImageReadMode`` class for more information on various
+            available modes.
+
+    Returns:
+        output (Tensor[image_channels, image_height, image_width])
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(read_image)
+    data = read_file(path)
+    return decode_image(data, mode)
+
+
+def _read_png_16(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
+    data = read_file(path)
+    return torch.ops.image.decode_png(data, mode.value, True)

+ 415 - 0
libs/vision_libs/io/video.py

@@ -0,0 +1,415 @@
+import gc
+import math
+import os
+import re
+import warnings
+from fractions import Fraction
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..utils import _log_api_usage_once
+from . import _video_opt
+
+try:
+    import av
+
+    av.logging.set_level(av.logging.ERROR)
+    if not hasattr(av.video.frame.VideoFrame, "pict_type"):
+        av = ImportError(
+            """\
+Your version of PyAV is too old for the necessary video operations in torchvision.
+If you are on Python 3.5, you will have to build from source (the conda-forge
+packages are not up-to-date).  See
+https://github.com/mikeboers/PyAV#installation for instructions on how to
+install PyAV on your system.
+"""
+        )
+except ImportError:
+    av = ImportError(
+        """\
+PyAV is not installed, and is necessary for the video operations in torchvision.
+See https://github.com/mikeboers/PyAV#installation for instructions on how to
+install PyAV on your system.
+"""
+    )
+
+
+def _check_av_available() -> None:
+    if isinstance(av, Exception):
+        raise av
+
+
+def _av_available() -> bool:
+    return not isinstance(av, Exception)
+
+
+# PyAV has some reference cycles
+_CALLED_TIMES = 0
+_GC_COLLECTION_INTERVAL = 10
+
+
+def write_video(
+    filename: str,
+    video_array: torch.Tensor,
+    fps: float,
+    video_codec: str = "libx264",
+    options: Optional[Dict[str, Any]] = None,
+    audio_array: Optional[torch.Tensor] = None,
+    audio_fps: Optional[float] = None,
+    audio_codec: Optional[str] = None,
+    audio_options: Optional[Dict[str, Any]] = None,
+) -> None:
+    """
+    Writes a 4d tensor in [T, H, W, C] format in a video file
+
+    Args:
+        filename (str): path where the video will be saved
+        video_array (Tensor[T, H, W, C]): tensor containing the individual frames,
+            as a uint8 tensor in [T, H, W, C] format
+        fps (Number): video frames per second
+        video_codec (str): the name of the video codec, i.e. "libx264", "h264", etc.
+        options (Dict): dictionary containing options to be passed into the PyAV video stream
+        audio_array (Tensor[C, N]): tensor containing the audio, where C is the number of channels
+            and N is the number of samples
+        audio_fps (Number): audio sample rate, typically 44100 or 48000
+        audio_codec (str): the name of the audio codec, i.e. "mp3", "aac", etc.
+        audio_options (Dict): dictionary containing options to be passed into the PyAV audio stream
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(write_video)
+    _check_av_available()
+    video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy()
+
+    # PyAV does not support floating point numbers with decimal point
+    # and will throw OverflowException in case this is not the case
+    if isinstance(fps, float):
+        fps = np.round(fps)
+
+    with av.open(filename, mode="w") as container:
+        stream = container.add_stream(video_codec, rate=fps)
+        stream.width = video_array.shape[2]
+        stream.height = video_array.shape[1]
+        stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
+        stream.options = options or {}
+
+        if audio_array is not None:
+            audio_format_dtypes = {
+                "dbl": "<f8",
+                "dblp": "<f8",
+                "flt": "<f4",
+                "fltp": "<f4",
+                "s16": "<i2",
+                "s16p": "<i2",
+                "s32": "<i4",
+                "s32p": "<i4",
+                "u8": "u1",
+                "u8p": "u1",
+            }
+            a_stream = container.add_stream(audio_codec, rate=audio_fps)
+            a_stream.options = audio_options or {}
+
+            num_channels = audio_array.shape[0]
+            audio_layout = "stereo" if num_channels > 1 else "mono"
+            audio_sample_fmt = container.streams.audio[0].format.name
+
+            format_dtype = np.dtype(audio_format_dtypes[audio_sample_fmt])
+            audio_array = torch.as_tensor(audio_array).numpy().astype(format_dtype)
+
+            frame = av.AudioFrame.from_ndarray(audio_array, format=audio_sample_fmt, layout=audio_layout)
+
+            frame.sample_rate = audio_fps
+
+            for packet in a_stream.encode(frame):
+                container.mux(packet)
+
+            for packet in a_stream.encode():
+                container.mux(packet)
+
+        for img in video_array:
+            frame = av.VideoFrame.from_ndarray(img, format="rgb24")
+            frame.pict_type = "NONE"
+            for packet in stream.encode(frame):
+                container.mux(packet)
+
+        # Flush stream
+        for packet in stream.encode():
+            container.mux(packet)
+
+
+def _read_from_stream(
+    container: "av.container.Container",
+    start_offset: float,
+    end_offset: float,
+    pts_unit: str,
+    stream: "av.stream.Stream",
+    stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]],
+) -> List["av.frame.Frame"]:
+    global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
+    _CALLED_TIMES += 1
+    if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
+        gc.collect()
+
+    if pts_unit == "sec":
+        # TODO: we should change all of this from ground up to simply take
+        # sec and convert to MS in C++
+        start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
+        if end_offset != float("inf"):
+            end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
+    else:
+        warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")
+
+    frames = {}
+    should_buffer = True
+    max_buffer_size = 5
+    if stream.type == "video":
+        # DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
+        # so need to buffer some extra frames to sort everything
+        # properly
+        extradata = stream.codec_context.extradata
+        # overly complicated way of finding if `divx_packed` is set, following
+        # https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
+        if extradata and b"DivX" in extradata:
+            # can't use regex directly because of some weird characters sometimes...
+            pos = extradata.find(b"DivX")
+            d = extradata[pos:]
+            o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d)
+            if o is None:
+                o = re.search(rb"DivX(\d+)b(\d+)(\w)", d)
+            if o is not None:
+                should_buffer = o.group(3) == b"p"
+    seek_offset = start_offset
+    # some files don't seek to the right location, so better be safe here
+    seek_offset = max(seek_offset - 1, 0)
+    if should_buffer:
+        # FIXME this is kind of a hack, but we will jump to the previous keyframe
+        # so this will be safe
+        seek_offset = max(seek_offset - max_buffer_size, 0)
+    try:
+        # TODO check if stream needs to always be the video stream here or not
+        container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
+    except av.AVError:
+        # TODO add some warnings in this case
+        # print("Corrupted file?", container.name)
+        return []
+    buffer_count = 0
+    try:
+        for _idx, frame in enumerate(container.decode(**stream_name)):
+            frames[frame.pts] = frame
+            if frame.pts >= end_offset:
+                if should_buffer and buffer_count < max_buffer_size:
+                    buffer_count += 1
+                    continue
+                break
+    except av.AVError:
+        # TODO add a warning
+        pass
+    # ensure that the results are sorted wrt the pts
+    result = [frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset]
+    if len(frames) > 0 and start_offset > 0 and start_offset not in frames:
+        # if there is no frame that exactly matches the pts of start_offset
+        # add the last frame smaller than start_offset, to guarantee that
+        # we will have all the necessary data. This is most useful for audio
+        preceding_frames = [i for i in frames if i < start_offset]
+        if len(preceding_frames) > 0:
+            first_frame_pts = max(preceding_frames)
+            result.insert(0, frames[first_frame_pts])
+    return result
+
+
+def _align_audio_frames(
+    aframes: torch.Tensor, audio_frames: List["av.frame.Frame"], ref_start: int, ref_end: float
+) -> torch.Tensor:
+    start, end = audio_frames[0].pts, audio_frames[-1].pts
+    total_aframes = aframes.shape[1]
+    step_per_aframe = (end - start + 1) / total_aframes
+    s_idx = 0
+    e_idx = total_aframes
+    if start < ref_start:
+        s_idx = int((ref_start - start) / step_per_aframe)
+    if end > ref_end:
+        e_idx = int((ref_end - end) / step_per_aframe)
+    return aframes[:, s_idx:e_idx]
+
+
+def read_video(
+    filename: str,
+    start_pts: Union[float, Fraction] = 0,
+    end_pts: Optional[Union[float, Fraction]] = None,
+    pts_unit: str = "pts",
+    output_format: str = "THWC",
+) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
+    """
+    Reads a video from a file, returning both the video frames and the audio frames
+
+    Args:
+        filename (str): path to the video file
+        start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
+            The start presentation time of the video
+        end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
+            The end presentation time
+        pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
+            either 'pts' or 'sec'. Defaults to 'pts'.
+        output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
+
+    Returns:
+        vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
+        aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
+        info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(read_video)
+
+    output_format = output_format.upper()
+    if output_format not in ("THWC", "TCHW"):
+        raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
+
+    from torchvision import get_video_backend
+
+    if not os.path.exists(filename):
+        raise RuntimeError(f"File not found: {filename}")
+
+    if get_video_backend() != "pyav":
+        vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
+    else:
+        _check_av_available()
+
+        if end_pts is None:
+            end_pts = float("inf")
+
+        if end_pts < start_pts:
+            raise ValueError(
+                f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
+            )
+
+        info = {}
+        video_frames = []
+        audio_frames = []
+        audio_timebase = _video_opt.default_timebase
+
+        try:
+            with av.open(filename, metadata_errors="ignore") as container:
+                if container.streams.audio:
+                    audio_timebase = container.streams.audio[0].time_base
+                if container.streams.video:
+                    video_frames = _read_from_stream(
+                        container,
+                        start_pts,
+                        end_pts,
+                        pts_unit,
+                        container.streams.video[0],
+                        {"video": 0},
+                    )
+                    video_fps = container.streams.video[0].average_rate
+                    # guard against potentially corrupted files
+                    if video_fps is not None:
+                        info["video_fps"] = float(video_fps)
+
+                if container.streams.audio:
+                    audio_frames = _read_from_stream(
+                        container,
+                        start_pts,
+                        end_pts,
+                        pts_unit,
+                        container.streams.audio[0],
+                        {"audio": 0},
+                    )
+                    info["audio_fps"] = container.streams.audio[0].rate
+
+        except av.AVError:
+            # TODO raise a warning?
+            pass
+
+        vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
+        aframes_list = [frame.to_ndarray() for frame in audio_frames]
+
+        if vframes_list:
+            vframes = torch.as_tensor(np.stack(vframes_list))
+        else:
+            vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
+
+        if aframes_list:
+            aframes = np.concatenate(aframes_list, 1)
+            aframes = torch.as_tensor(aframes)
+            if pts_unit == "sec":
+                start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
+                if end_pts != float("inf"):
+                    end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
+            aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
+        else:
+            aframes = torch.empty((1, 0), dtype=torch.float32)
+
+    if output_format == "TCHW":
+        # [T,H,W,C] --> [T,C,H,W]
+        vframes = vframes.permute(0, 3, 1, 2)
+
+    return vframes, aframes, info
+
+
+def _can_read_timestamps_from_packets(container: "av.container.Container") -> bool:
+    extradata = container.streams[0].codec_context.extradata
+    if extradata is None:
+        return False
+    if b"Lavc" in extradata:
+        return True
+    return False
+
+
+def _decode_video_timestamps(container: "av.container.Container") -> List[int]:
+    if _can_read_timestamps_from_packets(container):
+        # fast path
+        return [x.pts for x in container.demux(video=0) if x.pts is not None]
+    else:
+        return [x.pts for x in container.decode(video=0) if x.pts is not None]
+
+
+def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[int], Optional[float]]:
+    """
+    List the video frames timestamps.
+
+    Note that the function decodes the whole video frame-by-frame.
+
+    Args:
+        filename (str): path to the video file
+        pts_unit (str, optional): unit in which timestamp values will be returned
+            either 'pts' or 'sec'. Defaults to 'pts'.
+
+    Returns:
+        pts (List[int] if pts_unit = 'pts', List[Fraction] if pts_unit = 'sec'):
+            presentation timestamps for each one of the frames in the video.
+        video_fps (float, optional): the frame rate for the video
+
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(read_video_timestamps)
+    from torchvision import get_video_backend
+
+    if get_video_backend() != "pyav":
+        return _video_opt._read_video_timestamps(filename, pts_unit)
+
+    _check_av_available()
+
+    video_fps = None
+    pts = []
+
+    try:
+        with av.open(filename, metadata_errors="ignore") as container:
+            if container.streams.video:
+                video_stream = container.streams.video[0]
+                video_time_base = video_stream.time_base
+                try:
+                    pts = _decode_video_timestamps(container)
+                except av.AVError:
+                    warnings.warn(f"Failed decoding frames for file {filename}")
+                video_fps = float(video_stream.average_rate)
+    except av.AVError as e:
+        msg = f"Failed to open container for {filename}; Caught error: {e}"
+        warnings.warn(msg, RuntimeWarning)
+
+    pts.sort()
+
+    if pts_unit == "sec":
+        pts = [x * video_time_base for x in pts]
+
+    return pts, video_fps

+ 286 - 0
libs/vision_libs/io/video_reader.py

@@ -0,0 +1,286 @@
+import io
+import warnings
+
+from typing import Any, Dict, Iterator
+
+import torch
+
+from ..utils import _log_api_usage_once
+
+from ._video_opt import _HAS_VIDEO_OPT
+
+if _HAS_VIDEO_OPT:
+
+    def _has_video_opt() -> bool:
+        return True
+
+else:
+
+    def _has_video_opt() -> bool:
+        return False
+
+
+try:
+    import av
+
+    av.logging.set_level(av.logging.ERROR)
+    if not hasattr(av.video.frame.VideoFrame, "pict_type"):
+        av = ImportError(
+            """\
+Your version of PyAV is too old for the necessary video operations in torchvision.
+If you are on Python 3.5, you will have to build from source (the conda-forge
+packages are not up-to-date).  See
+https://github.com/mikeboers/PyAV#installation for instructions on how to
+install PyAV on your system.
+"""
+        )
+except ImportError:
+    av = ImportError(
+        """\
+PyAV is not installed, and is necessary for the video operations in torchvision.
+See https://github.com/mikeboers/PyAV#installation for instructions on how to
+install PyAV on your system.
+"""
+    )
+
+
+class VideoReader:
+    """
+    Fine-grained video-reading API.
+    Supports frame-by-frame reading of various streams from a single video
+    container. Much like previous video_reader API it supports the following
+    backends: video_reader, pyav, and cuda.
+    Backends can be set via `torchvision.set_video_backend` function.
+
+    .. betastatus:: VideoReader class
+
+    Example:
+        The following examples creates a :mod:`VideoReader` object, seeks into 2s
+        point, and returns a single frame::
+
+            import torchvision
+            video_path = "path_to_a_test_video"
+            reader = torchvision.io.VideoReader(video_path, "video")
+            reader.seek(2.0)
+            frame = next(reader)
+
+        :mod:`VideoReader` implements the iterable API, which makes it suitable to
+        using it in conjunction with :mod:`itertools` for more advanced reading.
+        As such, we can use a :mod:`VideoReader` instance inside for loops::
+
+            reader.seek(2)
+            for frame in reader:
+                frames.append(frame['data'])
+            # additionally, `seek` implements a fluent API, so we can do
+            for frame in reader.seek(2):
+                frames.append(frame['data'])
+
+        With :mod:`itertools`, we can read all frames between 2 and 5 seconds with the
+        following code::
+
+            for frame in itertools.takewhile(lambda x: x['pts'] <= 5, reader.seek(2)):
+                frames.append(frame['data'])
+
+        and similarly, reading 10 frames after the 2s timestamp can be achieved
+        as follows::
+
+            for frame in itertools.islice(reader.seek(2), 10):
+                frames.append(frame['data'])
+
+    .. note::
+
+        Each stream descriptor consists of two parts: stream type (e.g. 'video') and
+        a unique stream id (which are determined by the video encoding).
+        In this way, if the video container contains multiple
+        streams of the same type, users can access the one they want.
+        If only stream type is passed, the decoder auto-detects first stream of that type.
+
+    Args:
+        src (string, bytes object, or tensor): The media source.
+            If string-type, it must be a file path supported by FFMPEG.
+            If bytes, should be an in-memory representation of a file supported by FFMPEG.
+            If Tensor, it is interpreted internally as byte buffer.
+            It must be one-dimensional, of type ``torch.uint8``.
+
+        stream (string, optional): descriptor of the required stream, followed by the stream id,
+            in the format ``{stream_type}:{stream_id}``. Defaults to ``"video:0"``.
+            Currently available options include ``['video', 'audio']``
+
+        num_threads (int, optional): number of threads used by the codec to decode video.
+            Default value (0) enables multithreading with codec-dependent heuristic. The performance
+            will depend on the version of FFMPEG codecs supported.
+    """
+
+    def __init__(
+        self,
+        src: str,
+        stream: str = "video",
+        num_threads: int = 0,
+    ) -> None:
+        _log_api_usage_once(self)
+        from .. import get_video_backend
+
+        self.backend = get_video_backend()
+        if isinstance(src, str):
+            if not src:
+                raise ValueError("src cannot be empty")
+        elif isinstance(src, bytes):
+            if self.backend in ["cuda"]:
+                raise RuntimeError(
+                    "VideoReader cannot be initialized from bytes object when using cuda or pyav backend."
+                )
+            elif self.backend == "pyav":
+                src = io.BytesIO(src)
+            else:
+                with warnings.catch_warnings():
+                    # Ignore the warning because we actually don't modify the buffer in this function
+                    warnings.filterwarnings("ignore", message="The given buffer is not writable")
+                    src = torch.frombuffer(src, dtype=torch.uint8)
+        elif isinstance(src, torch.Tensor):
+            if self.backend in ["cuda", "pyav"]:
+                raise RuntimeError(
+                    "VideoReader cannot be initialized from Tensor object when using cuda or pyav backend."
+                )
+        else:
+            raise ValueError(f"src must be either string, Tensor or bytes object. Got {type(src)}")
+
+        if self.backend == "cuda":
+            device = torch.device("cuda")
+            self._c = torch.classes.torchvision.GPUDecoder(src, device)
+
+        elif self.backend == "video_reader":
+            if isinstance(src, str):
+                self._c = torch.classes.torchvision.Video(src, stream, num_threads)
+            elif isinstance(src, torch.Tensor):
+                self._c = torch.classes.torchvision.Video("", "", 0)
+                self._c.init_from_memory(src, stream, num_threads)
+
+        elif self.backend == "pyav":
+            self.container = av.open(src, metadata_errors="ignore")
+            # TODO: load metadata
+            stream_type = stream.split(":")[0]
+            stream_id = 0 if len(stream.split(":")) == 1 else int(stream.split(":")[1])
+            self.pyav_stream = {stream_type: stream_id}
+            self._c = self.container.decode(**self.pyav_stream)
+
+            # TODO: add extradata exception
+
+        else:
+            raise RuntimeError("Unknown video backend: {}".format(self.backend))
+
+    def __next__(self) -> Dict[str, Any]:
+        """Decodes and returns the next frame of the current stream.
+        Frames are encoded as a dict with mandatory
+        data and pts fields, where data is a tensor, and pts is a
+        presentation timestamp of the frame expressed in seconds
+        as a float.
+
+        Returns:
+            (dict): a dictionary and containing decoded frame (``data``)
+            and corresponding timestamp (``pts``) in seconds
+
+        """
+        if self.backend == "cuda":
+            frame = self._c.next()
+            if frame.numel() == 0:
+                raise StopIteration
+            return {"data": frame, "pts": None}
+        elif self.backend == "video_reader":
+            frame, pts = self._c.next()
+        else:
+            try:
+                frame = next(self._c)
+                pts = float(frame.pts * frame.time_base)
+                if "video" in self.pyav_stream:
+                    frame = torch.tensor(frame.to_rgb().to_ndarray()).permute(2, 0, 1)
+                elif "audio" in self.pyav_stream:
+                    frame = torch.tensor(frame.to_ndarray()).permute(1, 0)
+                else:
+                    frame = None
+            except av.error.EOFError:
+                raise StopIteration
+
+        if frame.numel() == 0:
+            raise StopIteration
+
+        return {"data": frame, "pts": pts}
+
+    def __iter__(self) -> Iterator[Dict[str, Any]]:
+        return self
+
+    def seek(self, time_s: float, keyframes_only: bool = False) -> "VideoReader":
+        """Seek within current stream.
+
+        Args:
+            time_s (float): seek time in seconds
+            keyframes_only (bool): allow to seek only to keyframes
+
+        .. note::
+            Current implementation is the so-called precise seek. This
+            means following seek, call to :mod:`next()` will return the
+            frame with the exact timestamp if it exists or
+            the first frame with timestamp larger than ``time_s``.
+        """
+        if self.backend in ["cuda", "video_reader"]:
+            self._c.seek(time_s, keyframes_only)
+        else:
+            # handle special case as pyav doesn't catch it
+            if time_s < 0:
+                time_s = 0
+            temp_str = self.container.streams.get(**self.pyav_stream)[0]
+            offset = int(round(time_s / temp_str.time_base))
+            if not keyframes_only:
+                warnings.warn("Accurate seek is not implemented for pyav backend")
+            self.container.seek(offset, backward=True, any_frame=False, stream=temp_str)
+            self._c = self.container.decode(**self.pyav_stream)
+        return self
+
+    def get_metadata(self) -> Dict[str, Any]:
+        """Returns video metadata
+
+        Returns:
+            (dict): dictionary containing duration and frame rate for every stream
+        """
+        if self.backend == "pyav":
+            metadata = {}  # type:  Dict[str, Any]
+            for stream in self.container.streams:
+                if stream.type not in metadata:
+                    if stream.type == "video":
+                        rate_n = "fps"
+                    else:
+                        rate_n = "framerate"
+                    metadata[stream.type] = {rate_n: [], "duration": []}
+
+                rate = stream.average_rate if stream.average_rate is not None else stream.sample_rate
+
+                metadata[stream.type]["duration"].append(float(stream.duration * stream.time_base))
+                metadata[stream.type][rate_n].append(float(rate))
+            return metadata
+        return self._c.get_metadata()
+
+    def set_current_stream(self, stream: str) -> bool:
+        """Set current stream.
+        Explicitly define the stream we are operating on.
+
+        Args:
+            stream (string): descriptor of the required stream. Defaults to ``"video:0"``
+                Currently available stream types include ``['video', 'audio']``.
+                Each descriptor consists of two parts: stream type (e.g. 'video') and
+                a unique stream id (which are determined by video encoding).
+                In this way, if the video container contains multiple
+                streams of the same type, users can access the one they want.
+                If only stream type is passed, the decoder auto-detects first stream
+                of that type and returns it.
+
+        Returns:
+            (bool): True on success, False otherwise
+        """
+        if self.backend == "cuda":
+            warnings.warn("GPU decoding only works with video stream.")
+        if self.backend == "pyav":
+            stream_type = stream.split(":")[0]
+            stream_id = 0 if len(stream.split(":")) == 1 else int(stream.split(":")[1])
+            self.pyav_stream = {stream_type: stream_id}
+            self._c = self.container.decode(**self.pyav_stream)
+            return True
+        return self._c.set_current_stream(stream)

+ 23 - 0
libs/vision_libs/models/__init__.py

@@ -0,0 +1,23 @@
+from .alexnet import *
+from .convnext import *
+from .densenet import *
+from .efficientnet import *
+from .googlenet import *
+from .inception import *
+from .mnasnet import *
+from .mobilenet import *
+from .regnet import *
+from .resnet import *
+from .shufflenetv2 import *
+from .squeezenet import *
+from .vgg import *
+from .vision_transformer import *
+from .swin_transformer import *
+from .maxvit import *
+from . import detection, optical_flow, quantization, segmentation, video
+
+# The Weights and WeightsEnum are developer-facing utils that we make public for
+# downstream libs like torchgeo https://github.com/pytorch/vision/issues/7094
+# TODO: we could / should document them publicly, but it's not clear where, as
+# they're not intended for end users.
+from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models, Weights, WeightsEnum

+ 277 - 0
libs/vision_libs/models/_api.py

@@ -0,0 +1,277 @@
+import fnmatch
+import importlib
+import inspect
+import sys
+from dataclasses import dataclass
+from enum import Enum
+from functools import partial
+from inspect import signature
+from types import ModuleType
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union
+
+from torch import nn
+
+from .._internally_replaced_utils import load_state_dict_from_url
+
+
+__all__ = ["WeightsEnum", "Weights", "get_model", "get_model_builder", "get_model_weights", "get_weight", "list_models"]
+
+
+@dataclass
+class Weights:
+    """
+    This class is used to group important attributes associated with the pre-trained weights.
+
+    Args:
+        url (str): The location where we find the weights.
+        transforms (Callable): A callable that constructs the preprocessing method (or validation preset transforms)
+            needed to use the model. The reason we attach a constructor method rather than an already constructed
+            object is because the specific object might have memory and thus we want to delay initialization until
+            needed.
+        meta (Dict[str, Any]): Stores meta-data related to the weights of the model and its configuration. These can be
+            informative attributes (for example the number of parameters/flops, recipe link/methods used in training
+            etc), configuration parameters (for example the `num_classes`) needed to construct the model or important
+            meta-data (for example the `classes` of a classification model) needed to use the model.
+    """
+
+    url: str
+    transforms: Callable
+    meta: Dict[str, Any]
+
+    def __eq__(self, other: Any) -> bool:
+        # We need this custom implementation for correct deep-copy and deserialization behavior.
+        # TL;DR: After the definition of an enum, creating a new instance, i.e. by deep-copying or deserializing it,
+        # involves an equality check against the defined members. Unfortunately, the `transforms` attribute is often
+        # defined with `functools.partial` and `fn = partial(...); assert deepcopy(fn) != fn`. Without custom handling
+        # for it, the check against the defined members would fail and effectively prevent the weights from being
+        # deep-copied or deserialized.
+        # See https://github.com/pytorch/vision/pull/7107 for details.
+        if not isinstance(other, Weights):
+            return NotImplemented
+
+        if self.url != other.url:
+            return False
+
+        if self.meta != other.meta:
+            return False
+
+        if isinstance(self.transforms, partial) and isinstance(other.transforms, partial):
+            return (
+                self.transforms.func == other.transforms.func
+                and self.transforms.args == other.transforms.args
+                and self.transforms.keywords == other.transforms.keywords
+            )
+        else:
+            return self.transforms == other.transforms
+
+
+class WeightsEnum(Enum):
+    """
+    This class is the parent class of all model weights. Each model building method receives an optional `weights`
+    parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type
+    `Weights`.
+
+    Args:
+        value (Weights): The data class entry with the weight information.
+    """
+
+    @classmethod
+    def verify(cls, obj: Any) -> Any:
+        if obj is not None:
+            if type(obj) is str:
+                obj = cls[obj.replace(cls.__name__ + ".", "")]
+            elif not isinstance(obj, cls):
+                raise TypeError(
+                    f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
+                )
+        return obj
+
+    def get_state_dict(self, *args: Any, **kwargs: Any) -> Mapping[str, Any]:
+        return load_state_dict_from_url(self.url, *args, **kwargs)
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}.{self._name_}"
+
+    @property
+    def url(self):
+        return self.value.url
+
+    @property
+    def transforms(self):
+        return self.value.transforms
+
+    @property
+    def meta(self):
+        return self.value.meta
+
+
+def get_weight(name: str) -> WeightsEnum:
+    """
+    Gets the weights enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1"
+
+    Args:
+        name (str): The name of the weight enum entry.
+
+    Returns:
+        WeightsEnum: The requested weight enum.
+    """
+    try:
+        enum_name, value_name = name.split(".")
+    except ValueError:
+        raise ValueError(f"Invalid weight name provided: '{name}'.")
+
+    base_module_name = ".".join(sys.modules[__name__].__name__.split(".")[:-1])
+    base_module = importlib.import_module(base_module_name)
+    model_modules = [base_module] + [
+        x[1]
+        for x in inspect.getmembers(base_module, inspect.ismodule)
+        if x[1].__file__.endswith("__init__.py")  # type: ignore[union-attr]
+    ]
+
+    weights_enum = None
+    for m in model_modules:
+        potential_class = m.__dict__.get(enum_name, None)
+        if potential_class is not None and issubclass(potential_class, WeightsEnum):
+            weights_enum = potential_class
+            break
+
+    if weights_enum is None:
+        raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.")
+
+    return weights_enum[value_name]
+
+
+def get_model_weights(name: Union[Callable, str]) -> Type[WeightsEnum]:
+    """
+    Returns the weights enum class associated to the given model.
+
+    Args:
+        name (callable or str): The model builder function or the name under which it is registered.
+
+    Returns:
+        weights_enum (WeightsEnum): The weights enum class associated with the model.
+    """
+    model = get_model_builder(name) if isinstance(name, str) else name
+    return _get_enum_from_fn(model)
+
+
+def _get_enum_from_fn(fn: Callable) -> Type[WeightsEnum]:
+    """
+    Internal method that gets the weight enum of a specific model builder method.
+
+    Args:
+        fn (Callable): The builder method used to create the model.
+    Returns:
+        WeightsEnum: The requested weight enum.
+    """
+    sig = signature(fn)
+    if "weights" not in sig.parameters:
+        raise ValueError("The method is missing the 'weights' argument.")
+
+    ann = signature(fn).parameters["weights"].annotation
+    weights_enum = None
+    if isinstance(ann, type) and issubclass(ann, WeightsEnum):
+        weights_enum = ann
+    else:
+        # handle cases like Union[Optional, T]
+        # TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8
+        for t in ann.__args__:  # type: ignore[union-attr]
+            if isinstance(t, type) and issubclass(t, WeightsEnum):
+                weights_enum = t
+                break
+
+    if weights_enum is None:
+        raise ValueError(
+            "The WeightsEnum class for the specific method couldn't be retrieved. Make sure the typing info is correct."
+        )
+
+    return weights_enum
+
+
+M = TypeVar("M", bound=nn.Module)
+
+BUILTIN_MODELS = {}
+
+
+def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], Callable[..., M]]:
+    def wrapper(fn: Callable[..., M]) -> Callable[..., M]:
+        key = name if name is not None else fn.__name__
+        if key in BUILTIN_MODELS:
+            raise ValueError(f"An entry is already registered under the name '{key}'.")
+        BUILTIN_MODELS[key] = fn
+        return fn
+
+    return wrapper
+
+
+def list_models(
+    module: Optional[ModuleType] = None,
+    include: Union[Iterable[str], str, None] = None,
+    exclude: Union[Iterable[str], str, None] = None,
+) -> List[str]:
+    """
+    Returns a list with the names of registered models.
+
+    Args:
+        module (ModuleType, optional): The module from which we want to extract the available models.
+        include (str or Iterable[str], optional): Filter(s) for including the models from the set of all models.
+            Filters are passed to `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__ to match Unix shell-style
+            wildcards. In case of many filters, the results is the union of individual filters.
+        exclude (str or Iterable[str], optional): Filter(s) applied after include_filters to remove models.
+            Filter are passed to `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__ to match Unix shell-style
+            wildcards. In case of many filters, the results is removal of all the models that match any individual filter.
+
+    Returns:
+        models (list): A list with the names of available models.
+    """
+    all_models = {
+        k for k, v in BUILTIN_MODELS.items() if module is None or v.__module__.rsplit(".", 1)[0] == module.__name__
+    }
+    if include:
+        models: Set[str] = set()
+        if isinstance(include, str):
+            include = [include]
+        for include_filter in include:
+            models = models | set(fnmatch.filter(all_models, include_filter))
+    else:
+        models = all_models
+
+    if exclude:
+        if isinstance(exclude, str):
+            exclude = [exclude]
+        for exclude_filter in exclude:
+            models = models - set(fnmatch.filter(all_models, exclude_filter))
+    return sorted(models)
+
+
+def get_model_builder(name: str) -> Callable[..., nn.Module]:
+    """
+    Gets the model name and returns the model builder method.
+
+    Args:
+        name (str): The name under which the model is registered.
+
+    Returns:
+        fn (Callable): The model builder method.
+    """
+    name = name.lower()
+    try:
+        fn = BUILTIN_MODELS[name]
+    except KeyError:
+        raise ValueError(f"Unknown model {name}")
+    return fn
+
+
+def get_model(name: str, **config: Any) -> nn.Module:
+    """
+    Gets the model name and configuration and returns an instantiated model.
+
+    Args:
+        name (str): The name under which the model is registered.
+        **config (Any): parameters passed to the model builder method.
+
+    Returns:
+        model (nn.Module): The initialized model.
+    """
+    fn = get_model_builder(name)
+    return fn(**config)

+ 1554 - 0
libs/vision_libs/models/_meta.py

@@ -0,0 +1,1554 @@
+"""
+This file is part of the private API. Please do not refer to any variables defined here directly as they will be
+removed on future versions without warning.
+"""
+
+# This will eventually be replaced with a call at torchvision.datasets.info("imagenet").categories
+_IMAGENET_CATEGORIES = [
+    "tench",
+    "goldfish",
+    "great white shark",
+    "tiger shark",
+    "hammerhead",
+    "electric ray",
+    "stingray",
+    "cock",
+    "hen",
+    "ostrich",
+    "brambling",
+    "goldfinch",
+    "house finch",
+    "junco",
+    "indigo bunting",
+    "robin",
+    "bulbul",
+    "jay",
+    "magpie",
+    "chickadee",
+    "water ouzel",
+    "kite",
+    "bald eagle",
+    "vulture",
+    "great grey owl",
+    "European fire salamander",
+    "common newt",
+    "eft",
+    "spotted salamander",
+    "axolotl",
+    "bullfrog",
+    "tree frog",
+    "tailed frog",
+    "loggerhead",
+    "leatherback turtle",
+    "mud turtle",
+    "terrapin",
+    "box turtle",
+    "banded gecko",
+    "common iguana",
+    "American chameleon",
+    "whiptail",
+    "agama",
+    "frilled lizard",
+    "alligator lizard",
+    "Gila monster",
+    "green lizard",
+    "African chameleon",
+    "Komodo dragon",
+    "African crocodile",
+    "American alligator",
+    "triceratops",
+    "thunder snake",
+    "ringneck snake",
+    "hognose snake",
+    "green snake",
+    "king snake",
+    "garter snake",
+    "water snake",
+    "vine snake",
+    "night snake",
+    "boa constrictor",
+    "rock python",
+    "Indian cobra",
+    "green mamba",
+    "sea snake",
+    "horned viper",
+    "diamondback",
+    "sidewinder",
+    "trilobite",
+    "harvestman",
+    "scorpion",
+    "black and gold garden spider",
+    "barn spider",
+    "garden spider",
+    "black widow",
+    "tarantula",
+    "wolf spider",
+    "tick",
+    "centipede",
+    "black grouse",
+    "ptarmigan",
+    "ruffed grouse",
+    "prairie chicken",
+    "peacock",
+    "quail",
+    "partridge",
+    "African grey",
+    "macaw",
+    "sulphur-crested cockatoo",
+    "lorikeet",
+    "coucal",
+    "bee eater",
+    "hornbill",
+    "hummingbird",
+    "jacamar",
+    "toucan",
+    "drake",
+    "red-breasted merganser",
+    "goose",
+    "black swan",
+    "tusker",
+    "echidna",
+    "platypus",
+    "wallaby",
+    "koala",
+    "wombat",
+    "jellyfish",
+    "sea anemone",
+    "brain coral",
+    "flatworm",
+    "nematode",
+    "conch",
+    "snail",
+    "slug",
+    "sea slug",
+    "chiton",
+    "chambered nautilus",
+    "Dungeness crab",
+    "rock crab",
+    "fiddler crab",
+    "king crab",
+    "American lobster",
+    "spiny lobster",
+    "crayfish",
+    "hermit crab",
+    "isopod",
+    "white stork",
+    "black stork",
+    "spoonbill",
+    "flamingo",
+    "little blue heron",
+    "American egret",
+    "bittern",
+    "crane bird",
+    "limpkin",
+    "European gallinule",
+    "American coot",
+    "bustard",
+    "ruddy turnstone",
+    "red-backed sandpiper",
+    "redshank",
+    "dowitcher",
+    "oystercatcher",
+    "pelican",
+    "king penguin",
+    "albatross",
+    "grey whale",
+    "killer whale",
+    "dugong",
+    "sea lion",
+    "Chihuahua",
+    "Japanese spaniel",
+    "Maltese dog",
+    "Pekinese",
+    "Shih-Tzu",
+    "Blenheim spaniel",
+    "papillon",
+    "toy terrier",
+    "Rhodesian ridgeback",
+    "Afghan hound",
+    "basset",
+    "beagle",
+    "bloodhound",
+    "bluetick",
+    "black-and-tan coonhound",
+    "Walker hound",
+    "English foxhound",
+    "redbone",
+    "borzoi",
+    "Irish wolfhound",
+    "Italian greyhound",
+    "whippet",
+    "Ibizan hound",
+    "Norwegian elkhound",
+    "otterhound",
+    "Saluki",
+    "Scottish deerhound",
+    "Weimaraner",
+    "Staffordshire bullterrier",
+    "American Staffordshire terrier",
+    "Bedlington terrier",
+    "Border terrier",
+    "Kerry blue terrier",
+    "Irish terrier",
+    "Norfolk terrier",
+    "Norwich terrier",
+    "Yorkshire terrier",
+    "wire-haired fox terrier",
+    "Lakeland terrier",
+    "Sealyham terrier",
+    "Airedale",
+    "cairn",
+    "Australian terrier",
+    "Dandie Dinmont",
+    "Boston bull",
+    "miniature schnauzer",
+    "giant schnauzer",
+    "standard schnauzer",
+    "Scotch terrier",
+    "Tibetan terrier",
+    "silky terrier",
+    "soft-coated wheaten terrier",
+    "West Highland white terrier",
+    "Lhasa",
+    "flat-coated retriever",
+    "curly-coated retriever",
+    "golden retriever",
+    "Labrador retriever",
+    "Chesapeake Bay retriever",
+    "German short-haired pointer",
+    "vizsla",
+    "English setter",
+    "Irish setter",
+    "Gordon setter",
+    "Brittany spaniel",
+    "clumber",
+    "English springer",
+    "Welsh springer spaniel",
+    "cocker spaniel",
+    "Sussex spaniel",
+    "Irish water spaniel",
+    "kuvasz",
+    "schipperke",
+    "groenendael",
+    "malinois",
+    "briard",
+    "kelpie",
+    "komondor",
+    "Old English sheepdog",
+    "Shetland sheepdog",
+    "collie",
+    "Border collie",
+    "Bouvier des Flandres",
+    "Rottweiler",
+    "German shepherd",
+    "Doberman",
+    "miniature pinscher",
+    "Greater Swiss Mountain dog",
+    "Bernese mountain dog",
+    "Appenzeller",
+    "EntleBucher",
+    "boxer",
+    "bull mastiff",
+    "Tibetan mastiff",
+    "French bulldog",
+    "Great Dane",
+    "Saint Bernard",
+    "Eskimo dog",
+    "malamute",
+    "Siberian husky",
+    "dalmatian",
+    "affenpinscher",
+    "basenji",
+    "pug",
+    "Leonberg",
+    "Newfoundland",
+    "Great Pyrenees",
+    "Samoyed",
+    "Pomeranian",
+    "chow",
+    "keeshond",
+    "Brabancon griffon",
+    "Pembroke",
+    "Cardigan",
+    "toy poodle",
+    "miniature poodle",
+    "standard poodle",
+    "Mexican hairless",
+    "timber wolf",
+    "white wolf",
+    "red wolf",
+    "coyote",
+    "dingo",
+    "dhole",
+    "African hunting dog",
+    "hyena",
+    "red fox",
+    "kit fox",
+    "Arctic fox",
+    "grey fox",
+    "tabby",
+    "tiger cat",
+    "Persian cat",
+    "Siamese cat",
+    "Egyptian cat",
+    "cougar",
+    "lynx",
+    "leopard",
+    "snow leopard",
+    "jaguar",
+    "lion",
+    "tiger",
+    "cheetah",
+    "brown bear",
+    "American black bear",
+    "ice bear",
+    "sloth bear",
+    "mongoose",
+    "meerkat",
+    "tiger beetle",
+    "ladybug",
+    "ground beetle",
+    "long-horned beetle",
+    "leaf beetle",
+    "dung beetle",
+    "rhinoceros beetle",
+    "weevil",
+    "fly",
+    "bee",
+    "ant",
+    "grasshopper",
+    "cricket",
+    "walking stick",
+    "cockroach",
+    "mantis",
+    "cicada",
+    "leafhopper",
+    "lacewing",
+    "dragonfly",
+    "damselfly",
+    "admiral",
+    "ringlet",
+    "monarch",
+    "cabbage butterfly",
+    "sulphur butterfly",
+    "lycaenid",
+    "starfish",
+    "sea urchin",
+    "sea cucumber",
+    "wood rabbit",
+    "hare",
+    "Angora",
+    "hamster",
+    "porcupine",
+    "fox squirrel",
+    "marmot",
+    "beaver",
+    "guinea pig",
+    "sorrel",
+    "zebra",
+    "hog",
+    "wild boar",
+    "warthog",
+    "hippopotamus",
+    "ox",
+    "water buffalo",
+    "bison",
+    "ram",
+    "bighorn",
+    "ibex",
+    "hartebeest",
+    "impala",
+    "gazelle",
+    "Arabian camel",
+    "llama",
+    "weasel",
+    "mink",
+    "polecat",
+    "black-footed ferret",
+    "otter",
+    "skunk",
+    "badger",
+    "armadillo",
+    "three-toed sloth",
+    "orangutan",
+    "gorilla",
+    "chimpanzee",
+    "gibbon",
+    "siamang",
+    "guenon",
+    "patas",
+    "baboon",
+    "macaque",
+    "langur",
+    "colobus",
+    "proboscis monkey",
+    "marmoset",
+    "capuchin",
+    "howler monkey",
+    "titi",
+    "spider monkey",
+    "squirrel monkey",
+    "Madagascar cat",
+    "indri",
+    "Indian elephant",
+    "African elephant",
+    "lesser panda",
+    "giant panda",
+    "barracouta",
+    "eel",
+    "coho",
+    "rock beauty",
+    "anemone fish",
+    "sturgeon",
+    "gar",
+    "lionfish",
+    "puffer",
+    "abacus",
+    "abaya",
+    "academic gown",
+    "accordion",
+    "acoustic guitar",
+    "aircraft carrier",
+    "airliner",
+    "airship",
+    "altar",
+    "ambulance",
+    "amphibian",
+    "analog clock",
+    "apiary",
+    "apron",
+    "ashcan",
+    "assault rifle",
+    "backpack",
+    "bakery",
+    "balance beam",
+    "balloon",
+    "ballpoint",
+    "Band Aid",
+    "banjo",
+    "bannister",
+    "barbell",
+    "barber chair",
+    "barbershop",
+    "barn",
+    "barometer",
+    "barrel",
+    "barrow",
+    "baseball",
+    "basketball",
+    "bassinet",
+    "bassoon",
+    "bathing cap",
+    "bath towel",
+    "bathtub",
+    "beach wagon",
+    "beacon",
+    "beaker",
+    "bearskin",
+    "beer bottle",
+    "beer glass",
+    "bell cote",
+    "bib",
+    "bicycle-built-for-two",
+    "bikini",
+    "binder",
+    "binoculars",
+    "birdhouse",
+    "boathouse",
+    "bobsled",
+    "bolo tie",
+    "bonnet",
+    "bookcase",
+    "bookshop",
+    "bottlecap",
+    "bow",
+    "bow tie",
+    "brass",
+    "brassiere",
+    "breakwater",
+    "breastplate",
+    "broom",
+    "bucket",
+    "buckle",
+    "bulletproof vest",
+    "bullet train",
+    "butcher shop",
+    "cab",
+    "caldron",
+    "candle",
+    "cannon",
+    "canoe",
+    "can opener",
+    "cardigan",
+    "car mirror",
+    "carousel",
+    "carpenter's kit",
+    "carton",
+    "car wheel",
+    "cash machine",
+    "cassette",
+    "cassette player",
+    "castle",
+    "catamaran",
+    "CD player",
+    "cello",
+    "cellular telephone",
+    "chain",
+    "chainlink fence",
+    "chain mail",
+    "chain saw",
+    "chest",
+    "chiffonier",
+    "chime",
+    "china cabinet",
+    "Christmas stocking",
+    "church",
+    "cinema",
+    "cleaver",
+    "cliff dwelling",
+    "cloak",
+    "clog",
+    "cocktail shaker",
+    "coffee mug",
+    "coffeepot",
+    "coil",
+    "combination lock",
+    "computer keyboard",
+    "confectionery",
+    "container ship",
+    "convertible",
+    "corkscrew",
+    "cornet",
+    "cowboy boot",
+    "cowboy hat",
+    "cradle",
+    "crane",
+    "crash helmet",
+    "crate",
+    "crib",
+    "Crock Pot",
+    "croquet ball",
+    "crutch",
+    "cuirass",
+    "dam",
+    "desk",
+    "desktop computer",
+    "dial telephone",
+    "diaper",
+    "digital clock",
+    "digital watch",
+    "dining table",
+    "dishrag",
+    "dishwasher",
+    "disk brake",
+    "dock",
+    "dogsled",
+    "dome",
+    "doormat",
+    "drilling platform",
+    "drum",
+    "drumstick",
+    "dumbbell",
+    "Dutch oven",
+    "electric fan",
+    "electric guitar",
+    "electric locomotive",
+    "entertainment center",
+    "envelope",
+    "espresso maker",
+    "face powder",
+    "feather boa",
+    "file",
+    "fireboat",
+    "fire engine",
+    "fire screen",
+    "flagpole",
+    "flute",
+    "folding chair",
+    "football helmet",
+    "forklift",
+    "fountain",
+    "fountain pen",
+    "four-poster",
+    "freight car",
+    "French horn",
+    "frying pan",
+    "fur coat",
+    "garbage truck",
+    "gasmask",
+    "gas pump",
+    "goblet",
+    "go-kart",
+    "golf ball",
+    "golfcart",
+    "gondola",
+    "gong",
+    "gown",
+    "grand piano",
+    "greenhouse",
+    "grille",
+    "grocery store",
+    "guillotine",
+    "hair slide",
+    "hair spray",
+    "half track",
+    "hammer",
+    "hamper",
+    "hand blower",
+    "hand-held computer",
+    "handkerchief",
+    "hard disc",
+    "harmonica",
+    "harp",
+    "harvester",
+    "hatchet",
+    "holster",
+    "home theater",
+    "honeycomb",
+    "hook",
+    "hoopskirt",
+    "horizontal bar",
+    "horse cart",
+    "hourglass",
+    "iPod",
+    "iron",
+    "jack-o'-lantern",
+    "jean",
+    "jeep",
+    "jersey",
+    "jigsaw puzzle",
+    "jinrikisha",
+    "joystick",
+    "kimono",
+    "knee pad",
+    "knot",
+    "lab coat",
+    "ladle",
+    "lampshade",
+    "laptop",
+    "lawn mower",
+    "lens cap",
+    "letter opener",
+    "library",
+    "lifeboat",
+    "lighter",
+    "limousine",
+    "liner",
+    "lipstick",
+    "Loafer",
+    "lotion",
+    "loudspeaker",
+    "loupe",
+    "lumbermill",
+    "magnetic compass",
+    "mailbag",
+    "mailbox",
+    "maillot",
+    "maillot tank suit",
+    "manhole cover",
+    "maraca",
+    "marimba",
+    "mask",
+    "matchstick",
+    "maypole",
+    "maze",
+    "measuring cup",
+    "medicine chest",
+    "megalith",
+    "microphone",
+    "microwave",
+    "military uniform",
+    "milk can",
+    "minibus",
+    "miniskirt",
+    "minivan",
+    "missile",
+    "mitten",
+    "mixing bowl",
+    "mobile home",
+    "Model T",
+    "modem",
+    "monastery",
+    "monitor",
+    "moped",
+    "mortar",
+    "mortarboard",
+    "mosque",
+    "mosquito net",
+    "motor scooter",
+    "mountain bike",
+    "mountain tent",
+    "mouse",
+    "mousetrap",
+    "moving van",
+    "muzzle",
+    "nail",
+    "neck brace",
+    "necklace",
+    "nipple",
+    "notebook",
+    "obelisk",
+    "oboe",
+    "ocarina",
+    "odometer",
+    "oil filter",
+    "organ",
+    "oscilloscope",
+    "overskirt",
+    "oxcart",
+    "oxygen mask",
+    "packet",
+    "paddle",
+    "paddlewheel",
+    "padlock",
+    "paintbrush",
+    "pajama",
+    "palace",
+    "panpipe",
+    "paper towel",
+    "parachute",
+    "parallel bars",
+    "park bench",
+    "parking meter",
+    "passenger car",
+    "patio",
+    "pay-phone",
+    "pedestal",
+    "pencil box",
+    "pencil sharpener",
+    "perfume",
+    "Petri dish",
+    "photocopier",
+    "pick",
+    "pickelhaube",
+    "picket fence",
+    "pickup",
+    "pier",
+    "piggy bank",
+    "pill bottle",
+    "pillow",
+    "ping-pong ball",
+    "pinwheel",
+    "pirate",
+    "pitcher",
+    "plane",
+    "planetarium",
+    "plastic bag",
+    "plate rack",
+    "plow",
+    "plunger",
+    "Polaroid camera",
+    "pole",
+    "police van",
+    "poncho",
+    "pool table",
+    "pop bottle",
+    "pot",
+    "potter's wheel",
+    "power drill",
+    "prayer rug",
+    "printer",
+    "prison",
+    "projectile",
+    "projector",
+    "puck",
+    "punching bag",
+    "purse",
+    "quill",
+    "quilt",
+    "racer",
+    "racket",
+    "radiator",
+    "radio",
+    "radio telescope",
+    "rain barrel",
+    "recreational vehicle",
+    "reel",
+    "reflex camera",
+    "refrigerator",
+    "remote control",
+    "restaurant",
+    "revolver",
+    "rifle",
+    "rocking chair",
+    "rotisserie",
+    "rubber eraser",
+    "rugby ball",
+    "rule",
+    "running shoe",
+    "safe",
+    "safety pin",
+    "saltshaker",
+    "sandal",
+    "sarong",
+    "sax",
+    "scabbard",
+    "scale",
+    "school bus",
+    "schooner",
+    "scoreboard",
+    "screen",
+    "screw",
+    "screwdriver",
+    "seat belt",
+    "sewing machine",
+    "shield",
+    "shoe shop",
+    "shoji",
+    "shopping basket",
+    "shopping cart",
+    "shovel",
+    "shower cap",
+    "shower curtain",
+    "ski",
+    "ski mask",
+    "sleeping bag",
+    "slide rule",
+    "sliding door",
+    "slot",
+    "snorkel",
+    "snowmobile",
+    "snowplow",
+    "soap dispenser",
+    "soccer ball",
+    "sock",
+    "solar dish",
+    "sombrero",
+    "soup bowl",
+    "space bar",
+    "space heater",
+    "space shuttle",
+    "spatula",
+    "speedboat",
+    "spider web",
+    "spindle",
+    "sports car",
+    "spotlight",
+    "stage",
+    "steam locomotive",
+    "steel arch bridge",
+    "steel drum",
+    "stethoscope",
+    "stole",
+    "stone wall",
+    "stopwatch",
+    "stove",
+    "strainer",
+    "streetcar",
+    "stretcher",
+    "studio couch",
+    "stupa",
+    "submarine",
+    "suit",
+    "sundial",
+    "sunglass",
+    "sunglasses",
+    "sunscreen",
+    "suspension bridge",
+    "swab",
+    "sweatshirt",
+    "swimming trunks",
+    "swing",
+    "switch",
+    "syringe",
+    "table lamp",
+    "tank",
+    "tape player",
+    "teapot",
+    "teddy",
+    "television",
+    "tennis ball",
+    "thatch",
+    "theater curtain",
+    "thimble",
+    "thresher",
+    "throne",
+    "tile roof",
+    "toaster",
+    "tobacco shop",
+    "toilet seat",
+    "torch",
+    "totem pole",
+    "tow truck",
+    "toyshop",
+    "tractor",
+    "trailer truck",
+    "tray",
+    "trench coat",
+    "tricycle",
+    "trimaran",
+    "tripod",
+    "triumphal arch",
+    "trolleybus",
+    "trombone",
+    "tub",
+    "turnstile",
+    "typewriter keyboard",
+    "umbrella",
+    "unicycle",
+    "upright",
+    "vacuum",
+    "vase",
+    "vault",
+    "velvet",
+    "vending machine",
+    "vestment",
+    "viaduct",
+    "violin",
+    "volleyball",
+    "waffle iron",
+    "wall clock",
+    "wallet",
+    "wardrobe",
+    "warplane",
+    "washbasin",
+    "washer",
+    "water bottle",
+    "water jug",
+    "water tower",
+    "whiskey jug",
+    "whistle",
+    "wig",
+    "window screen",
+    "window shade",
+    "Windsor tie",
+    "wine bottle",
+    "wing",
+    "wok",
+    "wooden spoon",
+    "wool",
+    "worm fence",
+    "wreck",
+    "yawl",
+    "yurt",
+    "web site",
+    "comic book",
+    "crossword puzzle",
+    "street sign",
+    "traffic light",
+    "book jacket",
+    "menu",
+    "plate",
+    "guacamole",
+    "consomme",
+    "hot pot",
+    "trifle",
+    "ice cream",
+    "ice lolly",
+    "French loaf",
+    "bagel",
+    "pretzel",
+    "cheeseburger",
+    "hotdog",
+    "mashed potato",
+    "head cabbage",
+    "broccoli",
+    "cauliflower",
+    "zucchini",
+    "spaghetti squash",
+    "acorn squash",
+    "butternut squash",
+    "cucumber",
+    "artichoke",
+    "bell pepper",
+    "cardoon",
+    "mushroom",
+    "Granny Smith",
+    "strawberry",
+    "orange",
+    "lemon",
+    "fig",
+    "pineapple",
+    "banana",
+    "jackfruit",
+    "custard apple",
+    "pomegranate",
+    "hay",
+    "carbonara",
+    "chocolate sauce",
+    "dough",
+    "meat loaf",
+    "pizza",
+    "potpie",
+    "burrito",
+    "red wine",
+    "espresso",
+    "cup",
+    "eggnog",
+    "alp",
+    "bubble",
+    "cliff",
+    "coral reef",
+    "geyser",
+    "lakeside",
+    "promontory",
+    "sandbar",
+    "seashore",
+    "valley",
+    "volcano",
+    "ballplayer",
+    "groom",
+    "scuba diver",
+    "rapeseed",
+    "daisy",
+    "yellow lady's slipper",
+    "corn",
+    "acorn",
+    "hip",
+    "buckeye",
+    "coral fungus",
+    "agaric",
+    "gyromitra",
+    "stinkhorn",
+    "earthstar",
+    "hen-of-the-woods",
+    "bolete",
+    "ear",
+    "toilet tissue",
+]
+
+# To be replaced with torchvision.datasets.info("coco").categories
+_COCO_CATEGORIES = [
+    "__background__",
+    "person",
+    "bicycle",
+    "car",
+    "motorcycle",
+    "airplane",
+    "bus",
+    "train",
+    "truck",
+    "boat",
+    "traffic light",
+    "fire hydrant",
+    "N/A",
+    "stop sign",
+    "parking meter",
+    "bench",
+    "bird",
+    "cat",
+    "dog",
+    "horse",
+    "sheep",
+    "cow",
+    "elephant",
+    "bear",
+    "zebra",
+    "giraffe",
+    "N/A",
+    "backpack",
+    "umbrella",
+    "N/A",
+    "N/A",
+    "handbag",
+    "tie",
+    "suitcase",
+    "frisbee",
+    "skis",
+    "snowboard",
+    "sports ball",
+    "kite",
+    "baseball bat",
+    "baseball glove",
+    "skateboard",
+    "surfboard",
+    "tennis racket",
+    "bottle",
+    "N/A",
+    "wine glass",
+    "cup",
+    "fork",
+    "knife",
+    "spoon",
+    "bowl",
+    "banana",
+    "apple",
+    "sandwich",
+    "orange",
+    "broccoli",
+    "carrot",
+    "hot dog",
+    "pizza",
+    "donut",
+    "cake",
+    "chair",
+    "couch",
+    "potted plant",
+    "bed",
+    "N/A",
+    "dining table",
+    "N/A",
+    "N/A",
+    "toilet",
+    "N/A",
+    "tv",
+    "laptop",
+    "mouse",
+    "remote",
+    "keyboard",
+    "cell phone",
+    "microwave",
+    "oven",
+    "toaster",
+    "sink",
+    "refrigerator",
+    "N/A",
+    "book",
+    "clock",
+    "vase",
+    "scissors",
+    "teddy bear",
+    "hair drier",
+    "toothbrush",
+]
+
+# To be replaced with torchvision.datasets.info("coco_kp")
+_COCO_PERSON_CATEGORIES = ["no person", "person"]
+_COCO_PERSON_KEYPOINT_NAMES = [
+    "nose",
+    "left_eye",
+    "right_eye",
+    "left_ear",
+    "right_ear",
+    "left_shoulder",
+    "right_shoulder",
+    "left_elbow",
+    "right_elbow",
+    "left_wrist",
+    "right_wrist",
+    "left_hip",
+    "right_hip",
+    "left_knee",
+    "right_knee",
+    "left_ankle",
+    "right_ankle",
+]
+
+# To be replaced with torchvision.datasets.info("voc").categories
+_VOC_CATEGORIES = [
+    "__background__",
+    "aeroplane",
+    "bicycle",
+    "bird",
+    "boat",
+    "bottle",
+    "bus",
+    "car",
+    "cat",
+    "chair",
+    "cow",
+    "diningtable",
+    "dog",
+    "horse",
+    "motorbike",
+    "person",
+    "pottedplant",
+    "sheep",
+    "sofa",
+    "train",
+    "tvmonitor",
+]
+
+# To be replaced with torchvision.datasets.info("kinetics400").categories
+_KINETICS400_CATEGORIES = [
+    "abseiling",
+    "air drumming",
+    "answering questions",
+    "applauding",
+    "applying cream",
+    "archery",
+    "arm wrestling",
+    "arranging flowers",
+    "assembling computer",
+    "auctioning",
+    "baby waking up",
+    "baking cookies",
+    "balloon blowing",
+    "bandaging",
+    "barbequing",
+    "bartending",
+    "beatboxing",
+    "bee keeping",
+    "belly dancing",
+    "bench pressing",
+    "bending back",
+    "bending metal",
+    "biking through snow",
+    "blasting sand",
+    "blowing glass",
+    "blowing leaves",
+    "blowing nose",
+    "blowing out candles",
+    "bobsledding",
+    "bookbinding",
+    "bouncing on trampoline",
+    "bowling",
+    "braiding hair",
+    "breading or breadcrumbing",
+    "breakdancing",
+    "brush painting",
+    "brushing hair",
+    "brushing teeth",
+    "building cabinet",
+    "building shed",
+    "bungee jumping",
+    "busking",
+    "canoeing or kayaking",
+    "capoeira",
+    "carrying baby",
+    "cartwheeling",
+    "carving pumpkin",
+    "catching fish",
+    "catching or throwing baseball",
+    "catching or throwing frisbee",
+    "catching or throwing softball",
+    "celebrating",
+    "changing oil",
+    "changing wheel",
+    "checking tires",
+    "cheerleading",
+    "chopping wood",
+    "clapping",
+    "clay pottery making",
+    "clean and jerk",
+    "cleaning floor",
+    "cleaning gutters",
+    "cleaning pool",
+    "cleaning shoes",
+    "cleaning toilet",
+    "cleaning windows",
+    "climbing a rope",
+    "climbing ladder",
+    "climbing tree",
+    "contact juggling",
+    "cooking chicken",
+    "cooking egg",
+    "cooking on campfire",
+    "cooking sausages",
+    "counting money",
+    "country line dancing",
+    "cracking neck",
+    "crawling baby",
+    "crossing river",
+    "crying",
+    "curling hair",
+    "cutting nails",
+    "cutting pineapple",
+    "cutting watermelon",
+    "dancing ballet",
+    "dancing charleston",
+    "dancing gangnam style",
+    "dancing macarena",
+    "deadlifting",
+    "decorating the christmas tree",
+    "digging",
+    "dining",
+    "disc golfing",
+    "diving cliff",
+    "dodgeball",
+    "doing aerobics",
+    "doing laundry",
+    "doing nails",
+    "drawing",
+    "dribbling basketball",
+    "drinking",
+    "drinking beer",
+    "drinking shots",
+    "driving car",
+    "driving tractor",
+    "drop kicking",
+    "drumming fingers",
+    "dunking basketball",
+    "dying hair",
+    "eating burger",
+    "eating cake",
+    "eating carrots",
+    "eating chips",
+    "eating doughnuts",
+    "eating hotdog",
+    "eating ice cream",
+    "eating spaghetti",
+    "eating watermelon",
+    "egg hunting",
+    "exercising arm",
+    "exercising with an exercise ball",
+    "extinguishing fire",
+    "faceplanting",
+    "feeding birds",
+    "feeding fish",
+    "feeding goats",
+    "filling eyebrows",
+    "finger snapping",
+    "fixing hair",
+    "flipping pancake",
+    "flying kite",
+    "folding clothes",
+    "folding napkins",
+    "folding paper",
+    "front raises",
+    "frying vegetables",
+    "garbage collecting",
+    "gargling",
+    "getting a haircut",
+    "getting a tattoo",
+    "giving or receiving award",
+    "golf chipping",
+    "golf driving",
+    "golf putting",
+    "grinding meat",
+    "grooming dog",
+    "grooming horse",
+    "gymnastics tumbling",
+    "hammer throw",
+    "headbanging",
+    "headbutting",
+    "high jump",
+    "high kick",
+    "hitting baseball",
+    "hockey stop",
+    "holding snake",
+    "hopscotch",
+    "hoverboarding",
+    "hugging",
+    "hula hooping",
+    "hurdling",
+    "hurling (sport)",
+    "ice climbing",
+    "ice fishing",
+    "ice skating",
+    "ironing",
+    "javelin throw",
+    "jetskiing",
+    "jogging",
+    "juggling balls",
+    "juggling fire",
+    "juggling soccer ball",
+    "jumping into pool",
+    "jumpstyle dancing",
+    "kicking field goal",
+    "kicking soccer ball",
+    "kissing",
+    "kitesurfing",
+    "knitting",
+    "krumping",
+    "laughing",
+    "laying bricks",
+    "long jump",
+    "lunge",
+    "making a cake",
+    "making a sandwich",
+    "making bed",
+    "making jewelry",
+    "making pizza",
+    "making snowman",
+    "making sushi",
+    "making tea",
+    "marching",
+    "massaging back",
+    "massaging feet",
+    "massaging legs",
+    "massaging person's head",
+    "milking cow",
+    "mopping floor",
+    "motorcycling",
+    "moving furniture",
+    "mowing lawn",
+    "news anchoring",
+    "opening bottle",
+    "opening present",
+    "paragliding",
+    "parasailing",
+    "parkour",
+    "passing American football (in game)",
+    "passing American football (not in game)",
+    "peeling apples",
+    "peeling potatoes",
+    "petting animal (not cat)",
+    "petting cat",
+    "picking fruit",
+    "planting trees",
+    "plastering",
+    "playing accordion",
+    "playing badminton",
+    "playing bagpipes",
+    "playing basketball",
+    "playing bass guitar",
+    "playing cards",
+    "playing cello",
+    "playing chess",
+    "playing clarinet",
+    "playing controller",
+    "playing cricket",
+    "playing cymbals",
+    "playing didgeridoo",
+    "playing drums",
+    "playing flute",
+    "playing guitar",
+    "playing harmonica",
+    "playing harp",
+    "playing ice hockey",
+    "playing keyboard",
+    "playing kickball",
+    "playing monopoly",
+    "playing organ",
+    "playing paintball",
+    "playing piano",
+    "playing poker",
+    "playing recorder",
+    "playing saxophone",
+    "playing squash or racquetball",
+    "playing tennis",
+    "playing trombone",
+    "playing trumpet",
+    "playing ukulele",
+    "playing violin",
+    "playing volleyball",
+    "playing xylophone",
+    "pole vault",
+    "presenting weather forecast",
+    "pull ups",
+    "pumping fist",
+    "pumping gas",
+    "punching bag",
+    "punching person (boxing)",
+    "push up",
+    "pushing car",
+    "pushing cart",
+    "pushing wheelchair",
+    "reading book",
+    "reading newspaper",
+    "recording music",
+    "riding a bike",
+    "riding camel",
+    "riding elephant",
+    "riding mechanical bull",
+    "riding mountain bike",
+    "riding mule",
+    "riding or walking with horse",
+    "riding scooter",
+    "riding unicycle",
+    "ripping paper",
+    "robot dancing",
+    "rock climbing",
+    "rock scissors paper",
+    "roller skating",
+    "running on treadmill",
+    "sailing",
+    "salsa dancing",
+    "sanding floor",
+    "scrambling eggs",
+    "scuba diving",
+    "setting table",
+    "shaking hands",
+    "shaking head",
+    "sharpening knives",
+    "sharpening pencil",
+    "shaving head",
+    "shaving legs",
+    "shearing sheep",
+    "shining shoes",
+    "shooting basketball",
+    "shooting goal (soccer)",
+    "shot put",
+    "shoveling snow",
+    "shredding paper",
+    "shuffling cards",
+    "side kick",
+    "sign language interpreting",
+    "singing",
+    "situp",
+    "skateboarding",
+    "ski jumping",
+    "skiing (not slalom or crosscountry)",
+    "skiing crosscountry",
+    "skiing slalom",
+    "skipping rope",
+    "skydiving",
+    "slacklining",
+    "slapping",
+    "sled dog racing",
+    "smoking",
+    "smoking hookah",
+    "snatch weight lifting",
+    "sneezing",
+    "sniffing",
+    "snorkeling",
+    "snowboarding",
+    "snowkiting",
+    "snowmobiling",
+    "somersaulting",
+    "spinning poi",
+    "spray painting",
+    "spraying",
+    "springboard diving",
+    "squat",
+    "sticking tongue out",
+    "stomping grapes",
+    "stretching arm",
+    "stretching leg",
+    "strumming guitar",
+    "surfing crowd",
+    "surfing water",
+    "sweeping floor",
+    "swimming backstroke",
+    "swimming breast stroke",
+    "swimming butterfly stroke",
+    "swing dancing",
+    "swinging legs",
+    "swinging on something",
+    "sword fighting",
+    "tai chi",
+    "taking a shower",
+    "tango dancing",
+    "tap dancing",
+    "tapping guitar",
+    "tapping pen",
+    "tasting beer",
+    "tasting food",
+    "testifying",
+    "texting",
+    "throwing axe",
+    "throwing ball",
+    "throwing discus",
+    "tickling",
+    "tobogganing",
+    "tossing coin",
+    "tossing salad",
+    "training dog",
+    "trapezing",
+    "trimming or shaving beard",
+    "trimming trees",
+    "triple jump",
+    "tying bow tie",
+    "tying knot (not on a tie)",
+    "tying tie",
+    "unboxing",
+    "unloading truck",
+    "using computer",
+    "using remote controller (not gaming)",
+    "using segway",
+    "vault",
+    "waiting in line",
+    "walking the dog",
+    "washing dishes",
+    "washing feet",
+    "washing hair",
+    "washing hands",
+    "water skiing",
+    "water sliding",
+    "watering plants",
+    "waxing back",
+    "waxing chest",
+    "waxing eyebrows",
+    "waxing legs",
+    "weaving basket",
+    "welding",
+    "whistling",
+    "windsurfing",
+    "wrapping present",
+    "wrestling",
+    "writing",
+    "yawning",
+    "yoga",
+    "zumba",
+]

+ 256 - 0
libs/vision_libs/models/_utils.py

@@ -0,0 +1,256 @@
+import functools
+import inspect
+import warnings
+from collections import OrderedDict
+from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union
+
+from torch import nn
+
+from .._utils import sequence_to_str
+from ._api import WeightsEnum
+
+
+class IntermediateLayerGetter(nn.ModuleDict):
+    """
+    Module wrapper that returns intermediate layers from a model
+
+    It has a strong assumption that the modules have been registered
+    into the model in the same order as they are used.
+    This means that one should **not** reuse the same nn.Module
+    twice in the forward if you want this to work.
+
+    Additionally, it is only able to query submodules that are directly
+    assigned to the model. So if `model` is passed, `model.feature1` can
+    be returned, but not `model.feature1.layer2`.
+
+    Args:
+        model (nn.Module): model on which we will extract the features
+        return_layers (Dict[name, new_name]): a dict containing the names
+            of the modules for which the activations will be returned as
+            the key of the dict, and the value of the dict is the name
+            of the returned activation (which the user can specify).
+
+    Examples::
+
+        >>> m = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT)
+        >>> # extract layer1 and layer3, giving as names `feat1` and feat2`
+        >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
+        >>>     {'layer1': 'feat1', 'layer3': 'feat2'})
+        >>> out = new_m(torch.rand(1, 3, 224, 224))
+        >>> print([(k, v.shape) for k, v in out.items()])
+        >>>     [('feat1', torch.Size([1, 64, 56, 56])),
+        >>>      ('feat2', torch.Size([1, 256, 14, 14]))]
+    """
+
+    _version = 2
+    __annotations__ = {
+        "return_layers": Dict[str, str],
+    }
+
+    def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
+        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
+            raise ValueError("return_layers are not present in model")
+        orig_return_layers = return_layers
+        return_layers = {str(k): str(v) for k, v in return_layers.items()}
+        layers = OrderedDict()
+        for name, module in model.named_children():
+            layers[name] = module
+            if name in return_layers:
+                del return_layers[name]
+            if not return_layers:
+                break
+
+        super().__init__(layers)
+        self.return_layers = orig_return_layers
+
+    def forward(self, x):
+        out = OrderedDict()
+        for name, module in self.items():
+            x = module(x)
+            if name in self.return_layers:
+                out_name = self.return_layers[name]
+                out[out_name] = x
+        return out
+
+
+def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
+    """
+    This function is taken from the original tf repo.
+    It ensures that all layers have a channel number that is divisible by 8
+    It can be seen here:
+    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+    """
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+
+D = TypeVar("D")
+
+
+def kwonly_to_pos_or_kw(fn: Callable[..., D]) -> Callable[..., D]:
+    """Decorates a function that uses keyword only parameters to also allow them being passed as positionals.
+
+    For example, consider the use case of changing the signature of ``old_fn`` into the one from ``new_fn``:
+
+    .. code::
+
+        def old_fn(foo, bar, baz=None):
+            ...
+
+        def new_fn(foo, *, bar, baz=None):
+            ...
+
+    Calling ``old_fn("foo", "bar, "baz")`` was valid, but the same call is no longer valid with ``new_fn``. To keep BC
+    and at the same time warn the user of the deprecation, this decorator can be used:
+
+    .. code::
+
+        @kwonly_to_pos_or_kw
+        def new_fn(foo, *, bar, baz=None):
+            ...
+
+        new_fn("foo", "bar, "baz")
+    """
+    params = inspect.signature(fn).parameters
+
+    try:
+        keyword_only_start_idx = next(
+            idx for idx, param in enumerate(params.values()) if param.kind == param.KEYWORD_ONLY
+        )
+    except StopIteration:
+        raise TypeError(f"Found no keyword-only parameter on function '{fn.__name__}'") from None
+
+    keyword_only_params = tuple(inspect.signature(fn).parameters)[keyword_only_start_idx:]
+
+    @functools.wraps(fn)
+    def wrapper(*args: Any, **kwargs: Any) -> D:
+        args, keyword_only_args = args[:keyword_only_start_idx], args[keyword_only_start_idx:]
+        if keyword_only_args:
+            keyword_only_kwargs = dict(zip(keyword_only_params, keyword_only_args))
+            warnings.warn(
+                f"Using {sequence_to_str(tuple(keyword_only_kwargs.keys()), separate_last='and ')} as positional "
+                f"parameter(s) is deprecated since 0.13 and may be removed in the future. Please use keyword parameter(s) "
+                f"instead."
+            )
+            kwargs.update(keyword_only_kwargs)
+
+        return fn(*args, **kwargs)
+
+    return wrapper
+
+
+W = TypeVar("W", bound=WeightsEnum)
+M = TypeVar("M", bound=nn.Module)
+V = TypeVar("V")
+
+
+def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]):
+    """Decorates a model builder with the new interface to make it compatible with the old.
+
+    In particular this handles two things:
+
+    1. Allows positional parameters again, but emits a deprecation warning in case they are used. See
+        :func:`torchvision.prototype.utils._internal.kwonly_to_pos_or_kw` for details.
+    2. Handles the default value change from ``pretrained=False`` to ``weights=None`` and ``pretrained=True`` to
+        ``weights=Weights`` and emits a deprecation warning with instructions for the new interface.
+
+    Args:
+        **weights (Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): Deprecated parameter
+            name and default value for the legacy ``pretrained=True``. The default value can be a callable in which
+            case it will be called with a dictionary of the keyword arguments. The only key that is guaranteed to be in
+            the dictionary is the deprecated parameter name passed as first element in the tuple. All other parameters
+            should be accessed with :meth:`~dict.get`.
+    """
+
+    def outer_wrapper(builder: Callable[..., M]) -> Callable[..., M]:
+        @kwonly_to_pos_or_kw
+        @functools.wraps(builder)
+        def inner_wrapper(*args: Any, **kwargs: Any) -> M:
+            for weights_param, (pretrained_param, default) in weights.items():  # type: ignore[union-attr]
+                # If neither the weights nor the pretrained parameter as passed, or the weights argument already use
+                # the new style arguments, there is nothing to do. Note that we cannot use `None` as sentinel for the
+                # weight argument, since it is a valid value.
+                sentinel = object()
+                weights_arg = kwargs.get(weights_param, sentinel)
+                if (
+                    (weights_param not in kwargs and pretrained_param not in kwargs)
+                    or isinstance(weights_arg, WeightsEnum)
+                    or (isinstance(weights_arg, str) and weights_arg != "legacy")
+                    or weights_arg is None
+                ):
+                    continue
+
+                # If the pretrained parameter was passed as positional argument, it is now mapped to
+                # `kwargs[weights_param]`. This happens because the @kwonly_to_pos_or_kw decorator uses the current
+                # signature to infer the names of positionally passed arguments and thus has no knowledge that there
+                # used to be a pretrained parameter.
+                pretrained_positional = weights_arg is not sentinel
+                if pretrained_positional:
+                    # We put the pretrained argument under its legacy name in the keyword argument dictionary to have
+                    # unified access to the value if the default value is a callable.
+                    kwargs[pretrained_param] = pretrained_arg = kwargs.pop(weights_param)
+                else:
+                    pretrained_arg = kwargs[pretrained_param]
+
+                if pretrained_arg:
+                    default_weights_arg = default(kwargs) if callable(default) else default
+                    if not isinstance(default_weights_arg, WeightsEnum):
+                        raise ValueError(f"No weights available for model {builder.__name__}")
+                else:
+                    default_weights_arg = None
+
+                if not pretrained_positional:
+                    warnings.warn(
+                        f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "
+                        f"please use '{weights_param}' instead."
+                    )
+
+                msg = (
+                    f"Arguments other than a weight enum or `None` for '{weights_param}' are deprecated since 0.13 and "
+                    f"may be removed in the future. "
+                    f"The current behavior is equivalent to passing `{weights_param}={default_weights_arg}`."
+                )
+                if pretrained_arg:
+                    msg = (
+                        f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.DEFAULT` "
+                        f"to get the most up-to-date weights."
+                    )
+                warnings.warn(msg)
+
+                del kwargs[pretrained_param]
+                kwargs[weights_param] = default_weights_arg
+
+            return builder(*args, **kwargs)
+
+        return inner_wrapper
+
+    return outer_wrapper
+
+
+def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None:
+    if param in kwargs:
+        if kwargs[param] != new_value:
+            raise ValueError(f"The parameter '{param}' expected value {new_value} but got {kwargs[param]} instead.")
+    else:
+        kwargs[param] = new_value
+
+
+def _ovewrite_value_param(param: str, actual: Optional[V], expected: V) -> V:
+    if actual is not None:
+        if actual != expected:
+            raise ValueError(f"The parameter '{param}' expected value {expected} but got {actual} instead.")
+    return expected
+
+
+class _ModelURLs(dict):
+    def __getitem__(self, item):
+        warnings.warn(
+            "Accessing the model URLs via the internal dictionary of the module is deprecated since 0.13 and may "
+            "be removed in the future. Please access them via the appropriate Weights Enum instead."
+        )
+        return super().__getitem__(item)

+ 119 - 0
libs/vision_libs/models/alexnet.py

@@ -0,0 +1,119 @@
+from functools import partial
+from typing import Any, Optional
+
+import torch
+import torch.nn as nn
+
+from ..transforms._presets import ImageClassification
+from ..utils import _log_api_usage_once
+from ._api import register_model, Weights, WeightsEnum
+from ._meta import _IMAGENET_CATEGORIES
+from ._utils import _ovewrite_named_param, handle_legacy_interface
+
+
+__all__ = ["AlexNet", "AlexNet_Weights", "alexnet"]
+
+
+class AlexNet(nn.Module):
+    def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:
+        super().__init__()
+        _log_api_usage_once(self)
+        self.features = nn.Sequential(
+            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+            nn.Conv2d(64, 192, kernel_size=5, padding=2),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+            nn.Conv2d(192, 384, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(384, 256, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(256, 256, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+        )
+        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
+        self.classifier = nn.Sequential(
+            nn.Dropout(p=dropout),
+            nn.Linear(256 * 6 * 6, 4096),
+            nn.ReLU(inplace=True),
+            nn.Dropout(p=dropout),
+            nn.Linear(4096, 4096),
+            nn.ReLU(inplace=True),
+            nn.Linear(4096, num_classes),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.features(x)
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        x = self.classifier(x)
+        return x
+
+
+class AlexNet_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
+        transforms=partial(ImageClassification, crop_size=224),
+        meta={
+            "num_params": 61100840,
+            "min_size": (63, 63),
+            "categories": _IMAGENET_CATEGORIES,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 56.522,
+                    "acc@5": 79.066,
+                }
+            },
+            "_ops": 0.714,
+            "_file_size": 233.087,
+            "_docs": """
+                These weights reproduce closely the results of the paper using a simplified training recipe.
+            """,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1))
+def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
+    """AlexNet model architecture from `One weird trick for parallelizing convolutional neural networks <https://arxiv.org/abs/1404.5997>`__.
+
+    .. note::
+        AlexNet was originally introduced in the `ImageNet Classification with
+        Deep Convolutional Neural Networks
+        <https://papers.nips.cc/paper/2012/hash/c399862d3b9d6b76c8436e924a68c45b-Abstract.html>`__
+        paper. Our implementation is based instead on the "One weird trick"
+        paper above.
+
+    Args:
+        weights (:class:`~torchvision.models.AlexNet_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.AlexNet_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.squeezenet.AlexNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/alexnet.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.AlexNet_Weights
+        :members:
+    """
+
+    weights = AlexNet_Weights.verify(weights)
+
+    if weights is not None:
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+
+    model = AlexNet(**kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model

+ 414 - 0
libs/vision_libs/models/convnext.py

@@ -0,0 +1,414 @@
+from functools import partial
+from typing import Any, Callable, List, Optional, Sequence
+
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+
+from ..ops.misc import Conv2dNormActivation, Permute
+from ..ops.stochastic_depth import StochasticDepth
+from ..transforms._presets import ImageClassification
+from ..utils import _log_api_usage_once
+from ._api import register_model, Weights, WeightsEnum
+from ._meta import _IMAGENET_CATEGORIES
+from ._utils import _ovewrite_named_param, handle_legacy_interface
+
+
+__all__ = [
+    "ConvNeXt",
+    "ConvNeXt_Tiny_Weights",
+    "ConvNeXt_Small_Weights",
+    "ConvNeXt_Base_Weights",
+    "ConvNeXt_Large_Weights",
+    "convnext_tiny",
+    "convnext_small",
+    "convnext_base",
+    "convnext_large",
+]
+
+
+class LayerNorm2d(nn.LayerNorm):
+    def forward(self, x: Tensor) -> Tensor:
+        x = x.permute(0, 2, 3, 1)
+        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        x = x.permute(0, 3, 1, 2)
+        return x
+
+
+class CNBlock(nn.Module):
+    def __init__(
+        self,
+        dim,
+        layer_scale: float,
+        stochastic_depth_prob: float,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+    ) -> None:
+        super().__init__()
+        if norm_layer is None:
+            norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+        self.block = nn.Sequential(
+            nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
+            Permute([0, 2, 3, 1]),
+            norm_layer(dim),
+            nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
+            nn.GELU(),
+            nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
+            Permute([0, 3, 1, 2]),
+        )
+        self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
+        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
+
+    def forward(self, input: Tensor) -> Tensor:
+        result = self.layer_scale * self.block(input)
+        result = self.stochastic_depth(result)
+        result += input
+        return result
+
+
+class CNBlockConfig:
+    # Stores information listed at Section 3 of the ConvNeXt paper
+    def __init__(
+        self,
+        input_channels: int,
+        out_channels: Optional[int],
+        num_layers: int,
+    ) -> None:
+        self.input_channels = input_channels
+        self.out_channels = out_channels
+        self.num_layers = num_layers
+
+    def __repr__(self) -> str:
+        s = self.__class__.__name__ + "("
+        s += "input_channels={input_channels}"
+        s += ", out_channels={out_channels}"
+        s += ", num_layers={num_layers}"
+        s += ")"
+        return s.format(**self.__dict__)
+
+
+class ConvNeXt(nn.Module):
+    def __init__(
+        self,
+        block_setting: List[CNBlockConfig],
+        stochastic_depth_prob: float = 0.0,
+        layer_scale: float = 1e-6,
+        num_classes: int = 1000,
+        block: Optional[Callable[..., nn.Module]] = None,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__()
+        _log_api_usage_once(self)
+
+        if not block_setting:
+            raise ValueError("The block_setting should not be empty")
+        elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])):
+            raise TypeError("The block_setting should be List[CNBlockConfig]")
+
+        if block is None:
+            block = CNBlock
+
+        if norm_layer is None:
+            norm_layer = partial(LayerNorm2d, eps=1e-6)
+
+        layers: List[nn.Module] = []
+
+        # Stem
+        firstconv_output_channels = block_setting[0].input_channels
+        layers.append(
+            Conv2dNormActivation(
+                3,
+                firstconv_output_channels,
+                kernel_size=4,
+                stride=4,
+                padding=0,
+                norm_layer=norm_layer,
+                activation_layer=None,
+                bias=True,
+            )
+        )
+
+        total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
+        stage_block_id = 0
+        for cnf in block_setting:
+            # Bottlenecks
+            stage: List[nn.Module] = []
+            for _ in range(cnf.num_layers):
+                # adjust stochastic depth probability based on the depth of the stage block
+                sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
+                stage.append(block(cnf.input_channels, layer_scale, sd_prob))
+                stage_block_id += 1
+            layers.append(nn.Sequential(*stage))
+            if cnf.out_channels is not None:
+                # Downsampling
+                layers.append(
+                    nn.Sequential(
+                        norm_layer(cnf.input_channels),
+                        nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
+                    )
+                )
+
+        self.features = nn.Sequential(*layers)
+        self.avgpool = nn.AdaptiveAvgPool2d(1)
+
+        lastblock = block_setting[-1]
+        lastconv_output_channels = (
+            lastblock.out_channels if lastblock.out_channels is not None else lastblock.input_channels
+        )
+        self.classifier = nn.Sequential(
+            norm_layer(lastconv_output_channels), nn.Flatten(1), nn.Linear(lastconv_output_channels, num_classes)
+        )
+
+        for m in self.modules():
+            if isinstance(m, (nn.Conv2d, nn.Linear)):
+                nn.init.trunc_normal_(m.weight, std=0.02)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+
+    def _forward_impl(self, x: Tensor) -> Tensor:
+        x = self.features(x)
+        x = self.avgpool(x)
+        x = self.classifier(x)
+        return x
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self._forward_impl(x)
+
+
+def _convnext(
+    block_setting: List[CNBlockConfig],
+    stochastic_depth_prob: float,
+    weights: Optional[WeightsEnum],
+    progress: bool,
+    **kwargs: Any,
+) -> ConvNeXt:
+    if weights is not None:
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+
+    model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model
+
+
+_COMMON_META = {
+    "min_size": (32, 32),
+    "categories": _IMAGENET_CATEGORIES,
+    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext",
+    "_docs": """
+        These weights improve upon the results of the original paper by using a modified version of TorchVision's
+        `new training recipe
+        <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
+    """,
+}
+
+
+class ConvNeXt_Tiny_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
+        transforms=partial(ImageClassification, crop_size=224, resize_size=236),
+        meta={
+            **_COMMON_META,
+            "num_params": 28589128,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 82.520,
+                    "acc@5": 96.146,
+                }
+            },
+            "_ops": 4.456,
+            "_file_size": 109.119,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class ConvNeXt_Small_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/convnext_small-0c510722.pth",
+        transforms=partial(ImageClassification, crop_size=224, resize_size=230),
+        meta={
+            **_COMMON_META,
+            "num_params": 50223688,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 83.616,
+                    "acc@5": 96.650,
+                }
+            },
+            "_ops": 8.684,
+            "_file_size": 191.703,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class ConvNeXt_Base_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/convnext_base-6075fbad.pth",
+        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+        meta={
+            **_COMMON_META,
+            "num_params": 88591464,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 84.062,
+                    "acc@5": 96.870,
+                }
+            },
+            "_ops": 15.355,
+            "_file_size": 338.064,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class ConvNeXt_Large_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/convnext_large-ea097f82.pth",
+        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+        meta={
+            **_COMMON_META,
+            "num_params": 197767336,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 84.414,
+                    "acc@5": 96.976,
+                }
+            },
+            "_ops": 34.361,
+            "_file_size": 754.537,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
+def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
+    """ConvNeXt Tiny model architecture from the
+    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.
+
+    Args:
+        weights (:class:`~torchvision.models.convnext.ConvNeXt_Tiny_Weights`, optional): The pretrained
+            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Tiny_Weights`
+            below for more details and possible values. By default, no pre-trained weights are used.
+        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNext``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.ConvNeXt_Tiny_Weights
+        :members:
+    """
+    weights = ConvNeXt_Tiny_Weights.verify(weights)
+
+    block_setting = [
+        CNBlockConfig(96, 192, 3),
+        CNBlockConfig(192, 384, 3),
+        CNBlockConfig(384, 768, 9),
+        CNBlockConfig(768, None, 3),
+    ]
+    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1)
+    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1))
+def convnext_small(
+    *, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any
+) -> ConvNeXt:
+    """ConvNeXt Small model architecture from the
+    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.
+
+    Args:
+        weights (:class:`~torchvision.models.convnext.ConvNeXt_Small_Weights`, optional): The pretrained
+            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Small_Weights`
+            below for more details and possible values. By default, no pre-trained weights are used.
+        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNext``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.ConvNeXt_Small_Weights
+        :members:
+    """
+    weights = ConvNeXt_Small_Weights.verify(weights)
+
+    block_setting = [
+        CNBlockConfig(96, 192, 3),
+        CNBlockConfig(192, 384, 3),
+        CNBlockConfig(384, 768, 27),
+        CNBlockConfig(768, None, 3),
+    ]
+    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4)
+    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1))
+def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
+    """ConvNeXt Base model architecture from the
+    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.
+
+    Args:
+        weights (:class:`~torchvision.models.convnext.ConvNeXt_Base_Weights`, optional): The pretrained
+            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Base_Weights`
+            below for more details and possible values. By default, no pre-trained weights are used.
+        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNext``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.ConvNeXt_Base_Weights
+        :members:
+    """
+    weights = ConvNeXt_Base_Weights.verify(weights)
+
+    block_setting = [
+        CNBlockConfig(128, 256, 3),
+        CNBlockConfig(256, 512, 3),
+        CNBlockConfig(512, 1024, 27),
+        CNBlockConfig(1024, None, 3),
+    ]
+    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
+    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1))
+def convnext_large(
+    *, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any
+) -> ConvNeXt:
+    """ConvNeXt Large model architecture from the
+    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.
+
+    Args:
+        weights (:class:`~torchvision.models.convnext.ConvNeXt_Large_Weights`, optional): The pretrained
+            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Large_Weights`
+            below for more details and possible values. By default, no pre-trained weights are used.
+        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNext``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.ConvNeXt_Large_Weights
+        :members:
+    """
+    weights = ConvNeXt_Large_Weights.verify(weights)
+
+    block_setting = [
+        CNBlockConfig(192, 384, 3),
+        CNBlockConfig(384, 768, 3),
+        CNBlockConfig(768, 1536, 27),
+        CNBlockConfig(1536, None, 3),
+    ]
+    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
+    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)

+ 448 - 0
libs/vision_libs/models/densenet.py

@@ -0,0 +1,448 @@
+import re
+from collections import OrderedDict
+from functools import partial
+from typing import Any, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from torch import Tensor
+
+from ..transforms._presets import ImageClassification
+from ..utils import _log_api_usage_once
+from ._api import register_model, Weights, WeightsEnum
+from ._meta import _IMAGENET_CATEGORIES
+from ._utils import _ovewrite_named_param, handle_legacy_interface
+
+__all__ = [
+    "DenseNet",
+    "DenseNet121_Weights",
+    "DenseNet161_Weights",
+    "DenseNet169_Weights",
+    "DenseNet201_Weights",
+    "densenet121",
+    "densenet161",
+    "densenet169",
+    "densenet201",
+]
+
+
+class _DenseLayer(nn.Module):
+    def __init__(
+        self, num_input_features: int, growth_rate: int, bn_size: int, drop_rate: float, memory_efficient: bool = False
+    ) -> None:
+        super().__init__()
+        self.norm1 = nn.BatchNorm2d(num_input_features)
+        self.relu1 = nn.ReLU(inplace=True)
+        self.conv1 = nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)
+
+        self.norm2 = nn.BatchNorm2d(bn_size * growth_rate)
+        self.relu2 = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)
+
+        self.drop_rate = float(drop_rate)
+        self.memory_efficient = memory_efficient
+
+    def bn_function(self, inputs: List[Tensor]) -> Tensor:
+        concated_features = torch.cat(inputs, 1)
+        bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))  # noqa: T484
+        return bottleneck_output
+
+    # todo: rewrite when torchscript supports any
+    def any_requires_grad(self, input: List[Tensor]) -> bool:
+        for tensor in input:
+            if tensor.requires_grad:
+                return True
+        return False
+
+    @torch.jit.unused  # noqa: T484
+    def call_checkpoint_bottleneck(self, input: List[Tensor]) -> Tensor:
+        def closure(*inputs):
+            return self.bn_function(inputs)
+
+        return cp.checkpoint(closure, *input)
+
+    @torch.jit._overload_method  # noqa: F811
+    def forward(self, input: List[Tensor]) -> Tensor:  # noqa: F811
+        pass
+
+    @torch.jit._overload_method  # noqa: F811
+    def forward(self, input: Tensor) -> Tensor:  # noqa: F811
+        pass
+
+    # torchscript does not yet support *args, so we overload method
+    # allowing it to take either a List[Tensor] or single Tensor
+    def forward(self, input: Tensor) -> Tensor:  # noqa: F811
+        if isinstance(input, Tensor):
+            prev_features = [input]
+        else:
+            prev_features = input
+
+        if self.memory_efficient and self.any_requires_grad(prev_features):
+            if torch.jit.is_scripting():
+                raise Exception("Memory Efficient not supported in JIT")
+
+            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
+        else:
+            bottleneck_output = self.bn_function(prev_features)
+
+        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
+        if self.drop_rate > 0:
+            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
+        return new_features
+
+
+class _DenseBlock(nn.ModuleDict):
+    _version = 2
+
+    def __init__(
+        self,
+        num_layers: int,
+        num_input_features: int,
+        bn_size: int,
+        growth_rate: int,
+        drop_rate: float,
+        memory_efficient: bool = False,
+    ) -> None:
+        super().__init__()
+        for i in range(num_layers):
+            layer = _DenseLayer(
+                num_input_features + i * growth_rate,
+                growth_rate=growth_rate,
+                bn_size=bn_size,
+                drop_rate=drop_rate,
+                memory_efficient=memory_efficient,
+            )
+            self.add_module("denselayer%d" % (i + 1), layer)
+
+    def forward(self, init_features: Tensor) -> Tensor:
+        features = [init_features]
+        for name, layer in self.items():
+            new_features = layer(features)
+            features.append(new_features)
+        return torch.cat(features, 1)
+
+
+class _Transition(nn.Sequential):
+    def __init__(self, num_input_features: int, num_output_features: int) -> None:
+        super().__init__()
+        self.norm = nn.BatchNorm2d(num_input_features)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv = nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)
+        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
+
+
+class DenseNet(nn.Module):
+    r"""Densenet-BC model class, based on
+    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.
+
+    Args:
+        growth_rate (int) - how many filters to add each layer (`k` in paper)
+        block_config (list of 4 ints) - how many layers in each pooling block
+        num_init_features (int) - the number of filters to learn in the first convolution layer
+        bn_size (int) - multiplicative factor for number of bottle neck layers
+          (i.e. bn_size * k features in the bottleneck layer)
+        drop_rate (float) - dropout rate after each dense layer
+        num_classes (int) - number of classification classes
+        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
+          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
+    """
+
+    def __init__(
+        self,
+        growth_rate: int = 32,
+        block_config: Tuple[int, int, int, int] = (6, 12, 24, 16),
+        num_init_features: int = 64,
+        bn_size: int = 4,
+        drop_rate: float = 0,
+        num_classes: int = 1000,
+        memory_efficient: bool = False,
+    ) -> None:
+
+        super().__init__()
+        _log_api_usage_once(self)
+
+        # First convolution
+        self.features = nn.Sequential(
+            OrderedDict(
+                [
+                    ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
+                    ("norm0", nn.BatchNorm2d(num_init_features)),
+                    ("relu0", nn.ReLU(inplace=True)),
+                    ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
+                ]
+            )
+        )
+
+        # Each denseblock
+        num_features = num_init_features
+        for i, num_layers in enumerate(block_config):
+            block = _DenseBlock(
+                num_layers=num_layers,
+                num_input_features=num_features,
+                bn_size=bn_size,
+                growth_rate=growth_rate,
+                drop_rate=drop_rate,
+                memory_efficient=memory_efficient,
+            )
+            self.features.add_module("denseblock%d" % (i + 1), block)
+            num_features = num_features + num_layers * growth_rate
+            if i != len(block_config) - 1:
+                trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
+                self.features.add_module("transition%d" % (i + 1), trans)
+                num_features = num_features // 2
+
+        # Final batch norm
+        self.features.add_module("norm5", nn.BatchNorm2d(num_features))
+
+        # Linear layer
+        self.classifier = nn.Linear(num_features, num_classes)
+
+        # Official init from torch repo.
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x: Tensor) -> Tensor:
+        features = self.features(x)
+        out = F.relu(features, inplace=True)
+        out = F.adaptive_avg_pool2d(out, (1, 1))
+        out = torch.flatten(out, 1)
+        out = self.classifier(out)
+        return out
+
+
+def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) -> None:
+    # '.'s are no longer allowed in module names, but previous _DenseLayer
+    # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
+    # They are also in the checkpoints in model_urls. This pattern is used
+    # to find such keys.
+    pattern = re.compile(
+        r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
+    )
+
+    state_dict = weights.get_state_dict(progress=progress, check_hash=True)
+    for key in list(state_dict.keys()):
+        res = pattern.match(key)
+        if res:
+            new_key = res.group(1) + res.group(2)
+            state_dict[new_key] = state_dict[key]
+            del state_dict[key]
+    model.load_state_dict(state_dict)
+
+
+def _densenet(
+    growth_rate: int,
+    block_config: Tuple[int, int, int, int],
+    num_init_features: int,
+    weights: Optional[WeightsEnum],
+    progress: bool,
+    **kwargs: Any,
+) -> DenseNet:
+    if weights is not None:
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+
+    model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
+
+    if weights is not None:
+        _load_state_dict(model=model, weights=weights, progress=progress)
+
+    return model
+
+
+_COMMON_META = {
+    "min_size": (29, 29),
+    "categories": _IMAGENET_CATEGORIES,
+    "recipe": "https://github.com/pytorch/vision/pull/116",
+    "_docs": """These weights are ported from LuaTorch.""",
+}
+
+
+class DenseNet121_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
+        transforms=partial(ImageClassification, crop_size=224),
+        meta={
+            **_COMMON_META,
+            "num_params": 7978856,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 74.434,
+                    "acc@5": 91.972,
+                }
+            },
+            "_ops": 2.834,
+            "_file_size": 30.845,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class DenseNet161_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
+        transforms=partial(ImageClassification, crop_size=224),
+        meta={
+            **_COMMON_META,
+            "num_params": 28681000,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 77.138,
+                    "acc@5": 93.560,
+                }
+            },
+            "_ops": 7.728,
+            "_file_size": 110.369,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class DenseNet169_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
+        transforms=partial(ImageClassification, crop_size=224),
+        meta={
+            **_COMMON_META,
+            "num_params": 14149480,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 75.600,
+                    "acc@5": 92.806,
+                }
+            },
+            "_ops": 3.36,
+            "_file_size": 54.708,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class DenseNet201_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/densenet201-c1103571.pth",
+        transforms=partial(ImageClassification, crop_size=224),
+        meta={
+            **_COMMON_META,
+            "num_params": 20013928,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 76.896,
+                    "acc@5": 93.370,
+                }
+            },
+            "_ops": 4.291,
+            "_file_size": 77.373,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1))
+def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
+    r"""Densenet-121 model from
+    `Densely Connected Convolutional Networks <https://arxiv.org/abs/1608.06993>`_.
+
+    Args:
+        weights (:class:`~torchvision.models.DenseNet121_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.DenseNet121_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.densenet.DenseNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.DenseNet121_Weights
+        :members:
+    """
+    weights = DenseNet121_Weights.verify(weights)
+
+    return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.IMAGENET1K_V1))
+def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
+    r"""Densenet-161 model from
+    `Densely Connected Convolutional Networks <https://arxiv.org/abs/1608.06993>`_.
+
+    Args:
+        weights (:class:`~torchvision.models.DenseNet161_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.DenseNet161_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.densenet.DenseNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.DenseNet161_Weights
+        :members:
+    """
+    weights = DenseNet161_Weights.verify(weights)
+
+    return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.IMAGENET1K_V1))
+def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
+    r"""Densenet-169 model from
+    `Densely Connected Convolutional Networks <https://arxiv.org/abs/1608.06993>`_.
+
+    Args:
+        weights (:class:`~torchvision.models.DenseNet169_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.DenseNet169_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.densenet.DenseNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.DenseNet169_Weights
+        :members:
+    """
+    weights = DenseNet169_Weights.verify(weights)
+
+    return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.IMAGENET1K_V1))
+def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
+    r"""Densenet-201 model from
+    `Densely Connected Convolutional Networks <https://arxiv.org/abs/1608.06993>`_.
+
+    Args:
+        weights (:class:`~torchvision.models.DenseNet201_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.DenseNet201_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.densenet.DenseNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.DenseNet201_Weights
+        :members:
+    """
+    weights = DenseNet201_Weights.verify(weights)
+
+    return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)

+ 7 - 0
libs/vision_libs/models/detection/__init__.py

@@ -0,0 +1,7 @@
+from .faster_rcnn import *
+from .fcos import *
+from .keypoint_rcnn import *
+from .mask_rcnn import *
+from .retinanet import *
+from .ssd import *
+from .ssdlite import *

+ 540 - 0
libs/vision_libs/models/detection/_utils.py

@@ -0,0 +1,540 @@
+import math
+from collections import OrderedDict
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+from torchvision.ops import complete_box_iou_loss, distance_box_iou_loss, FrozenBatchNorm2d, generalized_box_iou_loss
+
+
+class BalancedPositiveNegativeSampler:
+    """
+    This class samples batches, ensuring that they contain a fixed proportion of positives
+    """
+
+    def __init__(self, batch_size_per_image: int, positive_fraction: float) -> None:
+        """
+        Args:
+            batch_size_per_image (int): number of elements to be selected per image
+            positive_fraction (float): percentage of positive elements per batch
+        """
+        self.batch_size_per_image = batch_size_per_image
+        self.positive_fraction = positive_fraction
+
+    def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
+        """
+        Args:
+            matched_idxs: list of tensors containing -1, 0 or positive values.
+                Each tensor corresponds to a specific image.
+                -1 values are ignored, 0 are considered as negatives and > 0 as
+                positives.
+
+        Returns:
+            pos_idx (list[tensor])
+            neg_idx (list[tensor])
+
+        Returns two lists of binary masks for each image.
+        The first list contains the positive elements that were selected,
+        and the second list the negative example.
+        """
+        pos_idx = []
+        neg_idx = []
+        for matched_idxs_per_image in matched_idxs:
+            positive = torch.where(matched_idxs_per_image >= 1)[0]
+            negative = torch.where(matched_idxs_per_image == 0)[0]
+
+            num_pos = int(self.batch_size_per_image * self.positive_fraction)
+            # protect against not enough positive examples
+            num_pos = min(positive.numel(), num_pos)
+            num_neg = self.batch_size_per_image - num_pos
+            # protect against not enough negative examples
+            num_neg = min(negative.numel(), num_neg)
+
+            # randomly select positive and negative examples
+            perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
+            perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]
+
+            pos_idx_per_image = positive[perm1]
+            neg_idx_per_image = negative[perm2]
+
+            # create binary mask from indices
+            pos_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
+            neg_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
+
+            pos_idx_per_image_mask[pos_idx_per_image] = 1
+            neg_idx_per_image_mask[neg_idx_per_image] = 1
+
+            pos_idx.append(pos_idx_per_image_mask)
+            neg_idx.append(neg_idx_per_image_mask)
+
+        return pos_idx, neg_idx
+
+
+@torch.jit._script_if_tracing
+def encode_boxes(reference_boxes: Tensor, proposals: Tensor, weights: Tensor) -> Tensor:
+    """
+    Encode a set of proposals with respect to some
+    reference boxes
+
+    Args:
+        reference_boxes (Tensor): reference boxes
+        proposals (Tensor): boxes to be encoded
+        weights (Tensor[4]): the weights for ``(x, y, w, h)``
+    """
+
+    # perform some unpacking to make it JIT-fusion friendly
+    wx = weights[0]
+    wy = weights[1]
+    ww = weights[2]
+    wh = weights[3]
+
+    proposals_x1 = proposals[:, 0].unsqueeze(1)
+    proposals_y1 = proposals[:, 1].unsqueeze(1)
+    proposals_x2 = proposals[:, 2].unsqueeze(1)
+    proposals_y2 = proposals[:, 3].unsqueeze(1)
+
+    reference_boxes_x1 = reference_boxes[:, 0].unsqueeze(1)
+    reference_boxes_y1 = reference_boxes[:, 1].unsqueeze(1)
+    reference_boxes_x2 = reference_boxes[:, 2].unsqueeze(1)
+    reference_boxes_y2 = reference_boxes[:, 3].unsqueeze(1)
+
+    # implementation starts here
+    ex_widths = proposals_x2 - proposals_x1
+    ex_heights = proposals_y2 - proposals_y1
+    ex_ctr_x = proposals_x1 + 0.5 * ex_widths
+    ex_ctr_y = proposals_y1 + 0.5 * ex_heights
+
+    gt_widths = reference_boxes_x2 - reference_boxes_x1
+    gt_heights = reference_boxes_y2 - reference_boxes_y1
+    gt_ctr_x = reference_boxes_x1 + 0.5 * gt_widths
+    gt_ctr_y = reference_boxes_y1 + 0.5 * gt_heights
+
+    targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
+    targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
+    targets_dw = ww * torch.log(gt_widths / ex_widths)
+    targets_dh = wh * torch.log(gt_heights / ex_heights)
+
+    targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
+    return targets
+
+
+class BoxCoder:
+    """
+    This class encodes and decodes a set of bounding boxes into
+    the representation used for training the regressors.
+    """
+
+    def __init__(
+        self, weights: Tuple[float, float, float, float], bbox_xform_clip: float = math.log(1000.0 / 16)
+    ) -> None:
+        """
+        Args:
+            weights (4-element tuple)
+            bbox_xform_clip (float)
+        """
+        self.weights = weights
+        self.bbox_xform_clip = bbox_xform_clip
+
+    def encode(self, reference_boxes: List[Tensor], proposals: List[Tensor]) -> List[Tensor]:
+        boxes_per_image = [len(b) for b in reference_boxes]
+        reference_boxes = torch.cat(reference_boxes, dim=0)
+        proposals = torch.cat(proposals, dim=0)
+        targets = self.encode_single(reference_boxes, proposals)
+        return targets.split(boxes_per_image, 0)
+
+    def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
+        """
+        Encode a set of proposals with respect to some
+        reference boxes
+
+        Args:
+            reference_boxes (Tensor): reference boxes
+            proposals (Tensor): boxes to be encoded
+        """
+        dtype = reference_boxes.dtype
+        device = reference_boxes.device
+        weights = torch.as_tensor(self.weights, dtype=dtype, device=device)
+        targets = encode_boxes(reference_boxes, proposals, weights)
+
+        return targets
+
+    def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
+        torch._assert(
+            isinstance(boxes, (list, tuple)),
+            "This function expects boxes of type list or tuple.",
+        )
+        torch._assert(
+            isinstance(rel_codes, torch.Tensor),
+            "This function expects rel_codes of type torch.Tensor.",
+        )
+        boxes_per_image = [b.size(0) for b in boxes]
+        concat_boxes = torch.cat(boxes, dim=0)
+        box_sum = 0
+        for val in boxes_per_image:
+            box_sum += val
+        if box_sum > 0:
+            rel_codes = rel_codes.reshape(box_sum, -1)
+        pred_boxes = self.decode_single(rel_codes, concat_boxes)
+        if box_sum > 0:
+            pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
+        return pred_boxes
+
+    def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
+        """
+        From a set of original boxes and encoded relative box offsets,
+        get the decoded boxes.
+
+        Args:
+            rel_codes (Tensor): encoded boxes
+            boxes (Tensor): reference boxes.
+        """
+
+        boxes = boxes.to(rel_codes.dtype)
+
+        widths = boxes[:, 2] - boxes[:, 0]
+        heights = boxes[:, 3] - boxes[:, 1]
+        ctr_x = boxes[:, 0] + 0.5 * widths
+        ctr_y = boxes[:, 1] + 0.5 * heights
+
+        wx, wy, ww, wh = self.weights
+        dx = rel_codes[:, 0::4] / wx
+        dy = rel_codes[:, 1::4] / wy
+        dw = rel_codes[:, 2::4] / ww
+        dh = rel_codes[:, 3::4] / wh
+
+        # Prevent sending too large values into torch.exp()
+        dw = torch.clamp(dw, max=self.bbox_xform_clip)
+        dh = torch.clamp(dh, max=self.bbox_xform_clip)
+
+        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
+        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
+        pred_w = torch.exp(dw) * widths[:, None]
+        pred_h = torch.exp(dh) * heights[:, None]
+
+        # Distance from center to box's corner.
+        c_to_c_h = torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
+        c_to_c_w = torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w
+
+        pred_boxes1 = pred_ctr_x - c_to_c_w
+        pred_boxes2 = pred_ctr_y - c_to_c_h
+        pred_boxes3 = pred_ctr_x + c_to_c_w
+        pred_boxes4 = pred_ctr_y + c_to_c_h
+        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1)
+        return pred_boxes
+
+
+class BoxLinearCoder:
+    """
+    The linear box-to-box transform defined in FCOS. The transformation is parameterized
+    by the distance from the center of (square) src box to 4 edges of the target box.
+    """
+
+    def __init__(self, normalize_by_size: bool = True) -> None:
+        """
+        Args:
+            normalize_by_size (bool): normalize deltas by the size of src (anchor) boxes.
+        """
+        self.normalize_by_size = normalize_by_size
+
+    def encode(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
+        """
+        Encode a set of proposals with respect to some reference boxes
+
+        Args:
+            reference_boxes (Tensor): reference boxes
+            proposals (Tensor): boxes to be encoded
+
+        Returns:
+            Tensor: the encoded relative box offsets that can be used to
+            decode the boxes.
+
+        """
+
+        # get the center of reference_boxes
+        reference_boxes_ctr_x = 0.5 * (reference_boxes[..., 0] + reference_boxes[..., 2])
+        reference_boxes_ctr_y = 0.5 * (reference_boxes[..., 1] + reference_boxes[..., 3])
+
+        # get box regression transformation deltas
+        target_l = reference_boxes_ctr_x - proposals[..., 0]
+        target_t = reference_boxes_ctr_y - proposals[..., 1]
+        target_r = proposals[..., 2] - reference_boxes_ctr_x
+        target_b = proposals[..., 3] - reference_boxes_ctr_y
+
+        targets = torch.stack((target_l, target_t, target_r, target_b), dim=-1)
+
+        if self.normalize_by_size:
+            reference_boxes_w = reference_boxes[..., 2] - reference_boxes[..., 0]
+            reference_boxes_h = reference_boxes[..., 3] - reference_boxes[..., 1]
+            reference_boxes_size = torch.stack(
+                (reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=-1
+            )
+            targets = targets / reference_boxes_size
+        return targets
+
+    def decode(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
+
+        """
+        From a set of original boxes and encoded relative box offsets,
+        get the decoded boxes.
+
+        Args:
+            rel_codes (Tensor): encoded boxes
+            boxes (Tensor): reference boxes.
+
+        Returns:
+            Tensor: the predicted boxes with the encoded relative box offsets.
+
+        .. note::
+            This method assumes that ``rel_codes`` and ``boxes`` have same size for 0th dimension. i.e. ``len(rel_codes) == len(boxes)``.
+
+        """
+
+        boxes = boxes.to(dtype=rel_codes.dtype)
+
+        ctr_x = 0.5 * (boxes[..., 0] + boxes[..., 2])
+        ctr_y = 0.5 * (boxes[..., 1] + boxes[..., 3])
+
+        if self.normalize_by_size:
+            boxes_w = boxes[..., 2] - boxes[..., 0]
+            boxes_h = boxes[..., 3] - boxes[..., 1]
+
+            list_box_size = torch.stack((boxes_w, boxes_h, boxes_w, boxes_h), dim=-1)
+            rel_codes = rel_codes * list_box_size
+
+        pred_boxes1 = ctr_x - rel_codes[..., 0]
+        pred_boxes2 = ctr_y - rel_codes[..., 1]
+        pred_boxes3 = ctr_x + rel_codes[..., 2]
+        pred_boxes4 = ctr_y + rel_codes[..., 3]
+
+        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=-1)
+        return pred_boxes
+
+
+class Matcher:
+    """
+    This class assigns to each predicted "element" (e.g., a box) a ground-truth
+    element. Each predicted element will have exactly zero or one matches; each
+    ground-truth element may be assigned to zero or more predicted elements.
+
+    Matching is based on the MxN match_quality_matrix, that characterizes how well
+    each (ground-truth, predicted)-pair match. For example, if the elements are
+    boxes, the matrix may contain box IoU overlap values.
+
+    The matcher returns a tensor of size N containing the index of the ground-truth
+    element m that matches to prediction n. If there is no match, a negative value
+    is returned.
+    """
+
+    BELOW_LOW_THRESHOLD = -1
+    BETWEEN_THRESHOLDS = -2
+
+    __annotations__ = {
+        "BELOW_LOW_THRESHOLD": int,
+        "BETWEEN_THRESHOLDS": int,
+    }
+
+    def __init__(self, high_threshold: float, low_threshold: float, allow_low_quality_matches: bool = False) -> None:
+        """
+        Args:
+            high_threshold (float): quality values greater than or equal to
+                this value are candidate matches.
+            low_threshold (float): a lower quality threshold used to stratify
+                matches into three levels:
+                1) matches >= high_threshold
+                2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold)
+                3) BELOW_LOW_THRESHOLD matches in [0, low_threshold)
+            allow_low_quality_matches (bool): if True, produce additional matches
+                for predictions that have only low-quality match candidates. See
+                set_low_quality_matches_ for more details.
+        """
+        self.BELOW_LOW_THRESHOLD = -1
+        self.BETWEEN_THRESHOLDS = -2
+        torch._assert(low_threshold <= high_threshold, "low_threshold should be <= high_threshold")
+        self.high_threshold = high_threshold
+        self.low_threshold = low_threshold
+        self.allow_low_quality_matches = allow_low_quality_matches
+
+    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
+        """
+        Args:
+            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
+            pairwise quality between M ground-truth elements and N predicted elements.
+
+        Returns:
+            matches (Tensor[int64]): an N tensor where N[i] is a matched gt in
+            [0, M - 1] or a negative value indicating that prediction i could not
+            be matched.
+        """
+        if match_quality_matrix.numel() == 0:
+            # empty targets or proposals not supported during training
+            if match_quality_matrix.shape[0] == 0:
+                raise ValueError("No ground-truth boxes available for one of the images during training")
+            else:
+                raise ValueError("No proposal boxes available for one of the images during training")
+
+        # match_quality_matrix is M (gt) x N (predicted)
+        # Max over gt elements (dim 0) to find best gt candidate for each prediction
+        matched_vals, matches = match_quality_matrix.max(dim=0)
+        if self.allow_low_quality_matches:
+            all_matches = matches.clone()
+        else:
+            all_matches = None  # type: ignore[assignment]
+
+        # Assign candidate matches with low quality to negative (unassigned) values
+        below_low_threshold = matched_vals < self.low_threshold
+        between_thresholds = (matched_vals >= self.low_threshold) & (matched_vals < self.high_threshold)
+        matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD
+        matches[between_thresholds] = self.BETWEEN_THRESHOLDS
+
+        if self.allow_low_quality_matches:
+            if all_matches is None:
+                torch._assert(False, "all_matches should not be None")
+            else:
+                self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)
+
+        return matches
+
+    def set_low_quality_matches_(self, matches: Tensor, all_matches: Tensor, match_quality_matrix: Tensor) -> None:
+        """
+        Produce additional matches for predictions that have only low-quality matches.
+        Specifically, for each ground-truth find the set of predictions that have
+        maximum overlap with it (including ties); for each prediction in that set, if
+        it is unmatched, then match it to the ground-truth with which it has the highest
+        quality value.
+        """
+        # For each gt, find the prediction with which it has the highest quality
+        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
+        # Find the highest quality match available, even if it is low, including ties
+        gt_pred_pairs_of_highest_quality = torch.where(match_quality_matrix == highest_quality_foreach_gt[:, None])
+        # Example gt_pred_pairs_of_highest_quality:
+        # (tensor([0, 1, 1, 2, 2, 3, 3, 4, 5, 5]),
+        #  tensor([39796, 32055, 32070, 39190, 40255, 40390, 41455, 45470, 45325, 46390]))
+        # Each element in the first tensor is a gt index, and each element in second tensor is a prediction index
+        # Note how gt items 1, 2, 3, and 5 each have two ties
+
+        pred_inds_to_update = gt_pred_pairs_of_highest_quality[1]
+        matches[pred_inds_to_update] = all_matches[pred_inds_to_update]
+
+
+class SSDMatcher(Matcher):
+    def __init__(self, threshold: float) -> None:
+        super().__init__(threshold, threshold, allow_low_quality_matches=False)
+
+    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
+        matches = super().__call__(match_quality_matrix)
+
+        # For each gt, find the prediction with which it has the highest quality
+        _, highest_quality_pred_foreach_gt = match_quality_matrix.max(dim=1)
+        matches[highest_quality_pred_foreach_gt] = torch.arange(
+            highest_quality_pred_foreach_gt.size(0), dtype=torch.int64, device=highest_quality_pred_foreach_gt.device
+        )
+
+        return matches
+
+
+def overwrite_eps(model: nn.Module, eps: float) -> None:
+    """
+    This method overwrites the default eps values of all the
+    FrozenBatchNorm2d layers of the model with the provided value.
+    This is necessary to address the BC-breaking change introduced
+    by the bug-fix at pytorch/vision#2933. The overwrite is applied
+    only when the pretrained weights are loaded to maintain compatibility
+    with previous versions.
+
+    Args:
+        model (nn.Module): The model on which we perform the overwrite.
+        eps (float): The new value of eps.
+    """
+    for module in model.modules():
+        if isinstance(module, FrozenBatchNorm2d):
+            module.eps = eps
+
+
+def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
+    """
+    This method retrieves the number of output channels of a specific model.
+
+    Args:
+        model (nn.Module): The model for which we estimate the out_channels.
+            It should return a single Tensor or an OrderedDict[Tensor].
+        size (Tuple[int, int]): The size (wxh) of the input.
+
+    Returns:
+        out_channels (List[int]): A list of the output channels of the model.
+    """
+    in_training = model.training
+    model.eval()
+
+    with torch.no_grad():
+        # Use dummy data to retrieve the feature map sizes to avoid hard-coding their values
+        device = next(model.parameters()).device
+        tmp_img = torch.zeros((1, 3, size[1], size[0]), device=device)
+        features = model(tmp_img)
+        if isinstance(features, torch.Tensor):
+            features = OrderedDict([("0", features)])
+        out_channels = [x.size(1) for x in features.values()]
+
+    if in_training:
+        model.train()
+
+    return out_channels
+
+
+@torch.jit.unused
+def _fake_cast_onnx(v: Tensor) -> int:
+    return v  # type: ignore[return-value]
+
+
+def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int:
+    """
+    ONNX spec requires the k-value to be less than or equal to the number of inputs along
+    provided dim. Certain models use the number of elements along a particular axis instead of K
+    if K exceeds the number of elements along that axis. Previously, python's min() function was
+    used to determine whether to use the provided k-value or the specified dim axis value.
+
+    However, in cases where the model is being exported in tracing mode, python min() is
+    static causing the model to be traced incorrectly and eventually fail at the topk node.
+    In order to avoid this situation, in tracing mode, torch.min() is used instead.
+
+    Args:
+        input (Tensor): The original input tensor.
+        orig_kval (int): The provided k-value.
+        axis(int): Axis along which we retrieve the input size.
+
+    Returns:
+        min_kval (int): Appropriately selected k-value.
+    """
+    if not torch.jit.is_tracing():
+        return min(orig_kval, input.size(axis))
+    axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0)
+    min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
+    return _fake_cast_onnx(min_kval)
+
+
+def _box_loss(
+    type: str,
+    box_coder: BoxCoder,
+    anchors_per_image: Tensor,
+    matched_gt_boxes_per_image: Tensor,
+    bbox_regression_per_image: Tensor,
+    cnf: Optional[Dict[str, float]] = None,
+) -> Tensor:
+    torch._assert(type in ["l1", "smooth_l1", "ciou", "diou", "giou"], f"Unsupported loss: {type}")
+
+    if type == "l1":
+        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
+        return F.l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
+    elif type == "smooth_l1":
+        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
+        beta = cnf["beta"] if cnf is not None and "beta" in cnf else 1.0
+        return F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum", beta=beta)
+    else:
+        bbox_per_image = box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
+        eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7
+        if type == "ciou":
+            return complete_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
+        if type == "diou":
+            return distance_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
+        # otherwise giou
+        return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)

+ 268 - 0
libs/vision_libs/models/detection/anchor_utils.py

@@ -0,0 +1,268 @@
+import math
+from typing import List, Optional
+
+import torch
+from torch import nn, Tensor
+
+from .image_list import ImageList
+
+
+class AnchorGenerator(nn.Module):
+    """
+    Module that generates anchors for a set of feature maps and
+    image sizes.
+
+    The module support computing anchors at multiple sizes and aspect ratios
+    per feature map. This module assumes aspect ratio = height / width for
+    each anchor.
+
+    sizes and aspect_ratios should have the same number of elements, and it should
+    correspond to the number of feature maps.
+
+    sizes[i] and aspect_ratios[i] can have an arbitrary number of elements,
+    and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors
+    per spatial location for feature map i.
+
+    Args:
+        sizes (Tuple[Tuple[int]]):
+        aspect_ratios (Tuple[Tuple[float]]):
+    """
+
+    __annotations__ = {
+        "cell_anchors": List[torch.Tensor],
+    }
+
+    def __init__(
+        self,
+        sizes=((128, 256, 512),),
+        aspect_ratios=((0.5, 1.0, 2.0),),
+    ):
+        super().__init__()
+
+        if not isinstance(sizes[0], (list, tuple)):
+            # TODO change this
+            sizes = tuple((s,) for s in sizes)
+        if not isinstance(aspect_ratios[0], (list, tuple)):
+            aspect_ratios = (aspect_ratios,) * len(sizes)
+
+        self.sizes = sizes
+        self.aspect_ratios = aspect_ratios
+        self.cell_anchors = [
+            self.generate_anchors(size, aspect_ratio) for size, aspect_ratio in zip(sizes, aspect_ratios)
+        ]
+
+    # TODO: https://github.com/pytorch/pytorch/issues/26792
+    # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
+    # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
+    # This method assumes aspect ratio = height / width for an anchor.
+    def generate_anchors(
+        self,
+        scales: List[int],
+        aspect_ratios: List[float],
+        dtype: torch.dtype = torch.float32,
+        device: torch.device = torch.device("cpu"),
+    ) -> Tensor:
+        scales = torch.as_tensor(scales, dtype=dtype, device=device)
+        aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
+        h_ratios = torch.sqrt(aspect_ratios)
+        w_ratios = 1 / h_ratios
+
+        ws = (w_ratios[:, None] * scales[None, :]).view(-1)
+        hs = (h_ratios[:, None] * scales[None, :]).view(-1)
+
+        base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
+        return base_anchors.round()
+
+    def set_cell_anchors(self, dtype: torch.dtype, device: torch.device):
+        self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) for cell_anchor in self.cell_anchors]
+
+    def num_anchors_per_location(self) -> List[int]:
+        return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]
+
+    # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
+    # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
+    def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]:
+        anchors = []
+        cell_anchors = self.cell_anchors
+        torch._assert(cell_anchors is not None, "cell_anchors should not be None")
+        torch._assert(
+            len(grid_sizes) == len(strides) == len(cell_anchors),
+            "Anchors should be Tuple[Tuple[int]] because each feature "
+            "map could potentially have different sizes and aspect ratios. "
+            "There needs to be a match between the number of "
+            "feature maps passed and the number of sizes / aspect ratios specified.",
+        )
+
+        for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors):
+            grid_height, grid_width = size
+            stride_height, stride_width = stride
+            device = base_anchors.device
+
+            # For output anchor, compute [x_center, y_center, x_center, y_center]
+            shifts_x = torch.arange(0, grid_width, dtype=torch.int32, device=device) * stride_width
+            shifts_y = torch.arange(0, grid_height, dtype=torch.int32, device=device) * stride_height
+            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
+            shift_x = shift_x.reshape(-1)
+            shift_y = shift_y.reshape(-1)
+            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
+
+            # For every (base anchor, output anchor) pair,
+            # offset each zero-centered base anchor by the center of the output anchor.
+            anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4))
+
+        return anchors
+
+    def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
+        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
+        image_size = image_list.tensors.shape[-2:]
+        dtype, device = feature_maps[0].dtype, feature_maps[0].device
+        strides = [
+            [
+                torch.empty((), dtype=torch.int64, device=device).fill_(image_size[0] // g[0]),
+                torch.empty((), dtype=torch.int64, device=device).fill_(image_size[1] // g[1]),
+            ]
+            for g in grid_sizes
+        ]
+        self.set_cell_anchors(dtype, device)
+        anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides)
+        anchors: List[List[torch.Tensor]] = []
+        for _ in range(len(image_list.image_sizes)):
+            anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps]
+            anchors.append(anchors_in_image)
+        anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
+        return anchors
+
+
+class DefaultBoxGenerator(nn.Module):
+    """
+    This module generates the default boxes of SSD for a set of feature maps and image sizes.
+
+    Args:
+        aspect_ratios (List[List[int]]): A list with all the aspect ratios used in each feature map.
+        min_ratio (float): The minimum scale :math:`\text{s}_{\text{min}}` of the default boxes used in the estimation
+            of the scales of each feature map. It is used only if the ``scales`` parameter is not provided.
+        max_ratio (float): The maximum scale :math:`\text{s}_{\text{max}}`  of the default boxes used in the estimation
+            of the scales of each feature map. It is used only if the ``scales`` parameter is not provided.
+        scales (List[float]], optional): The scales of the default boxes. If not provided it will be estimated using
+            the ``min_ratio`` and ``max_ratio`` parameters.
+        steps (List[int]], optional): It's a hyper-parameter that affects the tiling of default boxes. If not provided
+            it will be estimated from the data.
+        clip (bool): Whether the standardized values of default boxes should be clipped between 0 and 1. The clipping
+            is applied while the boxes are encoded in format ``(cx, cy, w, h)``.
+    """
+
+    def __init__(
+        self,
+        aspect_ratios: List[List[int]],
+        min_ratio: float = 0.15,
+        max_ratio: float = 0.9,
+        scales: Optional[List[float]] = None,
+        steps: Optional[List[int]] = None,
+        clip: bool = True,
+    ):
+        super().__init__()
+        if steps is not None and len(aspect_ratios) != len(steps):
+            raise ValueError("aspect_ratios and steps should have the same length")
+        self.aspect_ratios = aspect_ratios
+        self.steps = steps
+        self.clip = clip
+        num_outputs = len(aspect_ratios)
+
+        # Estimation of default boxes scales
+        if scales is None:
+            if num_outputs > 1:
+                range_ratio = max_ratio - min_ratio
+                self.scales = [min_ratio + range_ratio * k / (num_outputs - 1.0) for k in range(num_outputs)]
+                self.scales.append(1.0)
+            else:
+                self.scales = [min_ratio, max_ratio]
+        else:
+            self.scales = scales
+
+        self._wh_pairs = self._generate_wh_pairs(num_outputs)
+
+    def _generate_wh_pairs(
+        self, num_outputs: int, dtype: torch.dtype = torch.float32, device: torch.device = torch.device("cpu")
+    ) -> List[Tensor]:
+        _wh_pairs: List[Tensor] = []
+        for k in range(num_outputs):
+            # Adding the 2 default width-height pairs for aspect ratio 1 and scale s'k
+            s_k = self.scales[k]
+            s_prime_k = math.sqrt(self.scales[k] * self.scales[k + 1])
+            wh_pairs = [[s_k, s_k], [s_prime_k, s_prime_k]]
+
+            # Adding 2 pairs for each aspect ratio of the feature map k
+            for ar in self.aspect_ratios[k]:
+                sq_ar = math.sqrt(ar)
+                w = self.scales[k] * sq_ar
+                h = self.scales[k] / sq_ar
+                wh_pairs.extend([[w, h], [h, w]])
+
+            _wh_pairs.append(torch.as_tensor(wh_pairs, dtype=dtype, device=device))
+        return _wh_pairs
+
+    def num_anchors_per_location(self) -> List[int]:
+        # Estimate num of anchors based on aspect ratios: 2 default boxes + 2 * ratios of feaure map.
+        return [2 + 2 * len(r) for r in self.aspect_ratios]
+
+    # Default Boxes calculation based on page 6 of SSD paper
+    def _grid_default_boxes(
+        self, grid_sizes: List[List[int]], image_size: List[int], dtype: torch.dtype = torch.float32
+    ) -> Tensor:
+        default_boxes = []
+        for k, f_k in enumerate(grid_sizes):
+            # Now add the default boxes for each width-height pair
+            if self.steps is not None:
+                x_f_k = image_size[1] / self.steps[k]
+                y_f_k = image_size[0] / self.steps[k]
+            else:
+                y_f_k, x_f_k = f_k
+
+            shifts_x = ((torch.arange(0, f_k[1]) + 0.5) / x_f_k).to(dtype=dtype)
+            shifts_y = ((torch.arange(0, f_k[0]) + 0.5) / y_f_k).to(dtype=dtype)
+            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
+            shift_x = shift_x.reshape(-1)
+            shift_y = shift_y.reshape(-1)
+
+            shifts = torch.stack((shift_x, shift_y) * len(self._wh_pairs[k]), dim=-1).reshape(-1, 2)
+            # Clipping the default boxes while the boxes are encoded in format (cx, cy, w, h)
+            _wh_pair = self._wh_pairs[k].clamp(min=0, max=1) if self.clip else self._wh_pairs[k]
+            wh_pairs = _wh_pair.repeat((f_k[0] * f_k[1]), 1)
+
+            default_box = torch.cat((shifts, wh_pairs), dim=1)
+
+            default_boxes.append(default_box)
+
+        return torch.cat(default_boxes, dim=0)
+
+    def __repr__(self) -> str:
+        s = (
+            f"{self.__class__.__name__}("
+            f"aspect_ratios={self.aspect_ratios}"
+            f", clip={self.clip}"
+            f", scales={self.scales}"
+            f", steps={self.steps}"
+            ")"
+        )
+        return s
+
+    def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
+        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
+        image_size = image_list.tensors.shape[-2:]
+        dtype, device = feature_maps[0].dtype, feature_maps[0].device
+        default_boxes = self._grid_default_boxes(grid_sizes, image_size, dtype=dtype)
+        default_boxes = default_boxes.to(device)
+
+        dboxes = []
+        x_y_size = torch.tensor([image_size[1], image_size[0]], device=default_boxes.device)
+        for _ in image_list.image_sizes:
+            dboxes_in_image = default_boxes
+            dboxes_in_image = torch.cat(
+                [
+                    (dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:]) * x_y_size,
+                    (dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:]) * x_y_size,
+                ],
+                -1,
+            )
+            dboxes.append(dboxes_in_image)
+        return dboxes

+ 244 - 0
libs/vision_libs/models/detection/backbone_utils.py

@@ -0,0 +1,244 @@
+import warnings
+from typing import Callable, Dict, List, Optional, Union
+
+from torch import nn, Tensor
+from torchvision.ops import misc as misc_nn_ops
+from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool
+
+from .. import mobilenet, resnet
+from .._api import _get_enum_from_fn, WeightsEnum
+from .._utils import handle_legacy_interface, IntermediateLayerGetter
+
+
+class BackboneWithFPN(nn.Module):
+    """
+    Adds a FPN on top of a model.
+    Internally, it uses torchvision.models._utils.IntermediateLayerGetter to
+    extract a submodel that returns the feature maps specified in return_layers.
+    The same limitations of IntermediateLayerGetter apply here.
+    Args:
+        backbone (nn.Module)
+        return_layers (Dict[name, new_name]): a dict containing the names
+            of the modules for which the activations will be returned as
+            the key of the dict, and the value of the dict is the name
+            of the returned activation (which the user can specify).
+        in_channels_list (List[int]): number of channels for each feature map
+            that is returned, in the order they are present in the OrderedDict
+        out_channels (int): number of channels in the FPN.
+        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
+    Attributes:
+        out_channels (int): the number of channels in the FPN
+    """
+
+    def __init__(
+        self,
+        backbone: nn.Module,
+        return_layers: Dict[str, str],
+        in_channels_list: List[int],
+        out_channels: int,
+        extra_blocks: Optional[ExtraFPNBlock] = None,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+    ) -> None:
+        super().__init__()
+
+        if extra_blocks is None:
+            extra_blocks = LastLevelMaxPool()
+
+        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
+        self.fpn = FeaturePyramidNetwork(
+            in_channels_list=in_channels_list,
+            out_channels=out_channels,
+            extra_blocks=extra_blocks,
+            norm_layer=norm_layer,
+        )
+        self.out_channels = out_channels
+
+    def forward(self, x: Tensor) -> Dict[str, Tensor]:
+        x = self.body(x)
+        x = self.fpn(x)
+        return x
+
+
+@handle_legacy_interface(
+    weights=(
+        "pretrained",
+        lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
+    ),
+)
+def resnet_fpn_backbone(
+    *,
+    backbone_name: str,
+    weights: Optional[WeightsEnum],
+    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
+    trainable_layers: int = 3,
+    returned_layers: Optional[List[int]] = None,
+    extra_blocks: Optional[ExtraFPNBlock] = None,
+) -> BackboneWithFPN:
+    """
+    Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.
+
+    Examples::
+
+        >>> import torch
+        >>> from torchvision.models import ResNet50_Weights
+        >>> from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
+        >>> backbone = resnet_fpn_backbone(backbone_name='resnet50', weights=ResNet50_Weights.DEFAULT, trainable_layers=3)
+        >>> # get some dummy image
+        >>> x = torch.rand(1,3,64,64)
+        >>> # compute the output
+        >>> output = backbone(x)
+        >>> print([(k, v.shape) for k, v in output.items()])
+        >>> # returns
+        >>>   [('0', torch.Size([1, 256, 16, 16])),
+        >>>    ('1', torch.Size([1, 256, 8, 8])),
+        >>>    ('2', torch.Size([1, 256, 4, 4])),
+        >>>    ('3', torch.Size([1, 256, 2, 2])),
+        >>>    ('pool', torch.Size([1, 256, 1, 1]))]
+
+    Args:
+        backbone_name (string): resnet architecture. Possible values are 'resnet18', 'resnet34', 'resnet50',
+             'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'
+        weights (WeightsEnum, optional): The pretrained weights for the model
+        norm_layer (callable): it is recommended to use the default value. For details visit:
+            (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
+        trainable_layers (int): number of trainable (not frozen) layers starting from final block.
+            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
+        returned_layers (list of int): The layers of the network to return. Each entry must be in ``[1, 4]``.
+            By default, all layers are returned.
+        extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
+            be performed. It is expected to take the fpn features, the original
+            features and the names of the original features as input, and returns
+            a new list of feature maps and their corresponding names. By
+            default, a ``LastLevelMaxPool`` is used.
+    """
+    backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
+    return _resnet_fpn_extractor(backbone, trainable_layers, returned_layers, extra_blocks)
+
+
+def _resnet_fpn_extractor(
+    backbone: resnet.ResNet,
+    trainable_layers: int,
+    returned_layers: Optional[List[int]] = None,
+    extra_blocks: Optional[ExtraFPNBlock] = None,
+    norm_layer: Optional[Callable[..., nn.Module]] = None,
+) -> BackboneWithFPN:
+
+    # select layers that won't be frozen
+    if trainable_layers < 0 or trainable_layers > 5:
+        raise ValueError(f"Trainable layers should be in the range [0,5], got {trainable_layers}")
+    layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
+    if trainable_layers == 5:
+        layers_to_train.append("bn1")
+    for name, parameter in backbone.named_parameters():
+        if all([not name.startswith(layer) for layer in layers_to_train]):
+            parameter.requires_grad_(False)
+
+    if extra_blocks is None:
+        extra_blocks = LastLevelMaxPool()
+
+    if returned_layers is None:
+        returned_layers = [1, 2, 3, 4]
+    if min(returned_layers) <= 0 or max(returned_layers) >= 5:
+        raise ValueError(f"Each returned layer should be in the range [1,4]. Got {returned_layers}")
+    return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)}
+
+    in_channels_stage2 = backbone.inplanes // 8
+    in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
+    out_channels = 256
+    return BackboneWithFPN(
+        backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
+    )
+
+
+def _validate_trainable_layers(
+    is_trained: bool,
+    trainable_backbone_layers: Optional[int],
+    max_value: int,
+    default_value: int,
+) -> int:
+    # don't freeze any layers if pretrained model or backbone is not used
+    if not is_trained:
+        if trainable_backbone_layers is not None:
+            warnings.warn(
+                "Changing trainable_backbone_layers has no effect if "
+                "neither pretrained nor pretrained_backbone have been set to True, "
+                f"falling back to trainable_backbone_layers={max_value} so that all layers are trainable"
+            )
+        trainable_backbone_layers = max_value
+
+    # by default freeze first blocks
+    if trainable_backbone_layers is None:
+        trainable_backbone_layers = default_value
+    if trainable_backbone_layers < 0 or trainable_backbone_layers > max_value:
+        raise ValueError(
+            f"Trainable backbone layers should be in the range [0,{max_value}], got {trainable_backbone_layers} "
+        )
+    return trainable_backbone_layers
+
+
+@handle_legacy_interface(
+    weights=(
+        "pretrained",
+        lambda kwargs: _get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
+    ),
+)
+def mobilenet_backbone(
+    *,
+    backbone_name: str,
+    weights: Optional[WeightsEnum],
+    fpn: bool,
+    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
+    trainable_layers: int = 2,
+    returned_layers: Optional[List[int]] = None,
+    extra_blocks: Optional[ExtraFPNBlock] = None,
+) -> nn.Module:
+    backbone = mobilenet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
+    return _mobilenet_extractor(backbone, fpn, trainable_layers, returned_layers, extra_blocks)
+
+
+def _mobilenet_extractor(
+    backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3],
+    fpn: bool,
+    trainable_layers: int,
+    returned_layers: Optional[List[int]] = None,
+    extra_blocks: Optional[ExtraFPNBlock] = None,
+    norm_layer: Optional[Callable[..., nn.Module]] = None,
+) -> nn.Module:
+    backbone = backbone.features
+    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
+    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
+    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
+    num_stages = len(stage_indices)
+
+    # find the index of the layer from which we won't freeze
+    if trainable_layers < 0 or trainable_layers > num_stages:
+        raise ValueError(f"Trainable layers should be in the range [0,{num_stages}], got {trainable_layers} ")
+    freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
+
+    for b in backbone[:freeze_before]:
+        for parameter in b.parameters():
+            parameter.requires_grad_(False)
+
+    out_channels = 256
+    if fpn:
+        if extra_blocks is None:
+            extra_blocks = LastLevelMaxPool()
+
+        if returned_layers is None:
+            returned_layers = [num_stages - 2, num_stages - 1]
+        if min(returned_layers) < 0 or max(returned_layers) >= num_stages:
+            raise ValueError(f"Each returned layer should be in the range [0,{num_stages - 1}], got {returned_layers} ")
+        return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)}
+
+        in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
+        return BackboneWithFPN(
+            backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
+        )
+    else:
+        m = nn.Sequential(
+            backbone,
+            # depthwise linear combination of channels to reduce their size
+            nn.Conv2d(backbone[-1].out_channels, out_channels, 1),
+        )
+        m.out_channels = out_channels  # type: ignore[assignment]
+        return m

+ 843 - 0
libs/vision_libs/models/detection/faster_rcnn.py

@@ -0,0 +1,843 @@
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torchvision.ops import MultiScaleRoIAlign
+
+from ...ops import misc as misc_nn_ops
+from ...transforms._presets import ObjectDetection
+from .._api import register_model, Weights, WeightsEnum
+from .._meta import _COCO_CATEGORIES
+from .._utils import _ovewrite_value_param, handle_legacy_interface
+from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights
+from ..resnet import resnet50, ResNet50_Weights
+from ._utils import overwrite_eps
+from .anchor_utils import AnchorGenerator
+from .backbone_utils import _mobilenet_extractor, _resnet_fpn_extractor, _validate_trainable_layers
+from .generalized_rcnn import GeneralizedRCNN
+from .roi_heads import RoIHeads
+from .rpn import RegionProposalNetwork, RPNHead
+from .transform import GeneralizedRCNNTransform
+
+
+__all__ = [
+    "FasterRCNN",
+    "FasterRCNN_ResNet50_FPN_Weights",
+    "FasterRCNN_ResNet50_FPN_V2_Weights",
+    "FasterRCNN_MobileNet_V3_Large_FPN_Weights",
+    "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights",
+    "fasterrcnn_resnet50_fpn",
+    "fasterrcnn_resnet50_fpn_v2",
+    "fasterrcnn_mobilenet_v3_large_fpn",
+    "fasterrcnn_mobilenet_v3_large_320_fpn",
+]
+
+
+def _default_anchorgen():
+    anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
+    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+    return AnchorGenerator(anchor_sizes, aspect_ratios)
+
+
+class FasterRCNN(GeneralizedRCNN):
+    """
+    Implements Faster R-CNN.
+
+    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
+    image, and should be in 0-1 range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (list of dictionary),
+    containing:
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (Int64Tensor[N]): the class label for each ground-truth box
+
+    The model returns a Dict[Tensor] during training, containing the classification and regression
+    losses for both the RPN and the R-CNN.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
+    follows:
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (Int64Tensor[N]): the predicted labels for each image
+        - scores (Tensor[N]): the scores or each prediction
+
+    Args:
+        backbone (nn.Module): the network used to compute the features for the model.
+            It should contain an out_channels attribute, which indicates the number of output
+            channels that each feature map has (and it should be the same for all feature maps).
+            The backbone should return a single Tensor or and OrderedDict[Tensor].
+        num_classes (int): number of output classes of the model (including the background).
+            If box_predictor is specified, num_classes should be None.
+        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
+        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
+        image_mean (Tuple[float, float, float]): mean values used for input normalization.
+            They are generally the mean values of the dataset on which the backbone has been trained
+            on
+        image_std (Tuple[float, float, float]): std values used for input normalization.
+            They are generally the std values of the dataset on which the backbone has been trained on
+        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
+            maps.
+        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
+        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
+        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
+        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
+        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
+        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
+        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
+            considered as positive during training of the RPN.
+        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
+            considered as negative during training of the RPN.
+        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
+            for computing the loss
+        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
+            of the RPN
+        rpn_score_thresh (float): during inference, only return proposals with a classification score
+            greater than rpn_score_thresh
+        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
+            the locations indicated by the bounding boxes
+        box_head (nn.Module): module that takes the cropped feature maps as input
+        box_predictor (nn.Module): module that takes the output of box_head and returns the
+            classification logits and box regression deltas.
+        box_score_thresh (float): during inference, only return proposals with a classification score
+            greater than box_score_thresh
+        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
+        box_detections_per_img (int): maximum number of detections per image, for all classes.
+        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
+            considered as positive during training of the classification head
+        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
+            considered as negative during training of the classification head
+        box_batch_size_per_image (int): number of proposals that are sampled during training of the
+            classification head
+        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
+            of the classification head
+        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
+            bounding boxes
+
+    Example::
+
+        >>> import torch
+        >>> import torchvision
+        >>> from torchvision.models.detection import FasterRCNN
+        >>> from torchvision.models.detection.rpn import AnchorGenerator
+        >>> # load a pre-trained model for classification and return
+        >>> # only the features
+        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
+        >>> # FasterRCNN needs to know the number of
+        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
+        >>> # so we need to add it here
+        >>> backbone.out_channels = 1280
+        >>>
+        >>> # let's make the RPN generate 5 x 3 anchors per spatial
+        >>> # location, with 5 different sizes and 3 different aspect
+        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
+        >>> # map could potentially have different sizes and
+        >>> # aspect ratios
+        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
+        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
+        >>>
+        >>> # let's define what are the feature maps that we will
+        >>> # use to perform the region of interest cropping, as well as
+        >>> # the size of the crop after rescaling.
+        >>> # if your backbone returns a Tensor, featmap_names is expected to
+        >>> # be ['0']. More generally, the backbone should return an
+        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
+        >>> # feature maps to use.
+        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
+        >>>                                                 output_size=7,
+        >>>                                                 sampling_ratio=2)
+        >>>
+        >>> # put the pieces together inside a FasterRCNN model
+        >>> model = FasterRCNN(backbone,
+        >>>                    num_classes=2,
+        >>>                    rpn_anchor_generator=anchor_generator,
+        >>>                    box_roi_pool=roi_pooler)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+    """
+
+    def __init__(
+        self,
+        backbone,
+        num_classes=None,
+        # transform parameters
+        min_size=512,
+        max_size=1333,
+        image_mean=None,
+        image_std=None,
+        # RPN parameters
+        rpn_anchor_generator=None,
+        rpn_head=None,
+        rpn_pre_nms_top_n_train=2000,
+        rpn_pre_nms_top_n_test=1000,
+        rpn_post_nms_top_n_train=2000,
+        rpn_post_nms_top_n_test=1000,
+        rpn_nms_thresh=0.7,
+        rpn_fg_iou_thresh=0.7,
+        rpn_bg_iou_thresh=0.3,
+        rpn_batch_size_per_image=256,
+        rpn_positive_fraction=0.5,
+        rpn_score_thresh=0.0,
+        # Box parameters
+        box_roi_pool=None,
+        box_head=None,
+        box_predictor=None,
+        box_score_thresh=0.05,
+        box_nms_thresh=0.5,
+        box_detections_per_img=100,
+        box_fg_iou_thresh=0.5,
+        box_bg_iou_thresh=0.5,
+        box_batch_size_per_image=512,
+        box_positive_fraction=0.25,
+        bbox_reg_weights=None,
+        **kwargs,
+    ):
+
+        if not hasattr(backbone, "out_channels"):
+            raise ValueError(
+                "backbone should contain an attribute out_channels "
+                "specifying the number of output channels (assumed to be the "
+                "same for all the levels)"
+            )
+
+        if not isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))):
+            raise TypeError(
+                f"rpn_anchor_generator should be of type AnchorGenerator or None instead of {type(rpn_anchor_generator)}"
+            )
+        if not isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))):
+            raise TypeError(
+                f"box_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(box_roi_pool)}"
+            )
+
+        if num_classes is not None:
+            if box_predictor is not None:
+                raise ValueError("num_classes should be None when box_predictor is specified")
+        else:
+            if box_predictor is None:
+                raise ValueError("num_classes should not be None when box_predictor is not specified")
+
+        out_channels = backbone.out_channels
+
+        if rpn_anchor_generator is None:
+            rpn_anchor_generator = _default_anchorgen()
+        if rpn_head is None:
+            rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])
+
+        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
+        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)
+
+        rpn = RegionProposalNetwork(
+            rpn_anchor_generator,
+            rpn_head,
+            rpn_fg_iou_thresh,
+            rpn_bg_iou_thresh,
+            rpn_batch_size_per_image,
+            rpn_positive_fraction,
+            rpn_pre_nms_top_n,
+            rpn_post_nms_top_n,
+            rpn_nms_thresh,
+            score_thresh=rpn_score_thresh,
+        )
+
+        if box_roi_pool is None:
+            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)
+
+        if box_head is None:
+            resolution = box_roi_pool.output_size[0]
+            representation_size = 1024
+            box_head = TwoMLPHead(out_channels * resolution**2, representation_size)
+
+        if box_predictor is None:
+            representation_size = 1024
+            box_predictor = FastRCNNPredictor(representation_size, num_classes)
+
+        roi_heads = RoIHeads(
+            # Box
+            box_roi_pool,
+            box_head,
+            box_predictor,
+            box_fg_iou_thresh,
+            box_bg_iou_thresh,
+            box_batch_size_per_image,
+            box_positive_fraction,
+            bbox_reg_weights,
+            box_score_thresh,
+            box_nms_thresh,
+            box_detections_per_img,
+        )
+
+        if image_mean is None:
+            image_mean = [0.485, 0.456, 0.406]
+        if image_std is None:
+            image_std = [0.229, 0.224, 0.225]
+        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)
+
+        super().__init__(backbone, rpn, roi_heads, transform)
+
+
+class TwoMLPHead(nn.Module):
+    """
+    Standard heads for FPN-based models
+
+    Args:
+        in_channels (int): number of input channels
+        representation_size (int): size of the intermediate representation
+    """
+
+    def __init__(self, in_channels, representation_size):
+        super().__init__()
+
+        self.fc6 = nn.Linear(in_channels, representation_size)
+        self.fc7 = nn.Linear(representation_size, representation_size)
+
+    def forward(self, x):
+        x = x.flatten(start_dim=1)
+
+        x = F.relu(self.fc6(x))
+        x = F.relu(self.fc7(x))
+
+        return x
+
+
+class FastRCNNConvFCHead(nn.Sequential):
+    def __init__(
+        self,
+        input_size: Tuple[int, int, int],
+        conv_layers: List[int],
+        fc_layers: List[int],
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+    ):
+        """
+        Args:
+            input_size (Tuple[int, int, int]): the input size in CHW format.
+            conv_layers (list): feature dimensions of each Convolution layer
+            fc_layers (list): feature dimensions of each FCN layer
+            norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
+        """
+        in_channels, in_height, in_width = input_size
+
+        blocks = []
+        previous_channels = in_channels
+        for current_channels in conv_layers:
+            blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, current_channels, norm_layer=norm_layer))
+            previous_channels = current_channels
+        blocks.append(nn.Flatten())
+        previous_channels = previous_channels * in_height * in_width
+        for current_channels in fc_layers:
+            blocks.append(nn.Linear(previous_channels, current_channels))
+            blocks.append(nn.ReLU(inplace=True))
+            previous_channels = current_channels
+
+        super().__init__(*blocks)
+        for layer in self.modules():
+            if isinstance(layer, nn.Conv2d):
+                nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
+                if layer.bias is not None:
+                    nn.init.zeros_(layer.bias)
+
+
+class FastRCNNPredictor(nn.Module):
+    """
+    Standard classification + bounding box regression layers
+    for Fast R-CNN.
+
+    Args:
+        in_channels (int): number of input channels
+        num_classes (int): number of output classes (including background)
+    """
+
+    def __init__(self, in_channels, num_classes):
+        super().__init__()
+        self.cls_score = nn.Linear(in_channels, num_classes)
+        self.bbox_pred = nn.Linear(in_channels, num_classes * 4)
+
+    def forward(self, x):
+        if x.dim() == 4:
+            torch._assert(
+                list(x.shape[2:]) == [1, 1],
+                f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}",
+            )
+        x = x.flatten(start_dim=1)
+        scores = self.cls_score(x)
+        bbox_deltas = self.bbox_pred(x)
+
+        return scores, bbox_deltas
+
+
+_COMMON_META = {
+    "categories": _COCO_CATEGORIES,
+    "min_size": (1, 1),
+}
+
+
+class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 41755286,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 37.0,
+                }
+            },
+            "_ops": 134.38,
+            "_file_size": 159.743,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+class FasterRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 43712278,
+            "recipe": "https://github.com/pytorch/vision/pull/5763",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 46.7,
+                }
+            },
+            "_ops": 280.371,
+            "_file_size": 167.104,
+            "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 19386354,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 32.8,
+                }
+            },
+            "_ops": 4.494,
+            "_file_size": 74.239,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 19386354,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 22.8,
+                }
+            },
+            "_ops": 0.719,
+            "_file_size": 74.239,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
+)
+def fasterrcnn_resnet50_fpn(
+    *,
+    weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+    """
+    Faster R-CNN model with a ResNet-50-FPN backbone from the `Faster R-CNN: Towards Real-Time Object
+    Detection with Region Proposal Networks <https://arxiv.org/abs/1506.01497>`__
+    paper.
+
+    .. betastatus:: detection module
+
+    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
+    image, and should be in ``0-1`` range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and a targets (list of dictionary),
+    containing:
+
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
+
+    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
+    losses for both the RPN and the R-CNN.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
+    follows, where ``N`` is the number of detections:
+
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the predicted labels for each detection
+        - scores (``Tensor[N]``): the scores of each detection
+
+    For more details on the output, you may refer to :ref:`instance_seg_output`.
+
+    Faster R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.
+
+    Example::
+
+        >>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
+        >>> # For training
+        >>> images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4)
+        >>> boxes[:, :, 2:4] = boxes[:, :, 0:2] + boxes[:, :, 2:4]
+        >>> labels = torch.randint(1, 91, (4, 11))
+        >>> images = list(image for image in images)
+        >>> targets = []
+        >>> for i in range(len(images)):
+        >>>     d = {}
+        >>>     d['boxes'] = boxes[i]
+        >>>     d['labels'] = labels[i]
+        >>>     targets.append(d)
+        >>> output = model(images, targets)
+        >>> # For inference
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+        >>>
+        >>> # optionally, if you want to export the model to ONNX:
+        >>> torch.onnx.export(model, x, "faster_rcnn.onnx", opset_version = 11)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
+            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
+            trainable. If ``None`` is passed (the default) this value is set to 3.
+        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights
+        :members:
+    """
+    weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+    model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+        if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
+            overwrite_eps(model, 0.0)
+
+    return model
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
+)
+def fasterrcnn_resnet50_fpn_v2(
+    *,
+    weights: Optional[FasterRCNN_ResNet50_FPN_V2_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[ResNet50_Weights] = None,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+    """
+    Constructs an improved Faster R-CNN model with a ResNet-50-FPN backbone from `Benchmarking Detection
+    Transfer Learning with Vision Transformers <https://arxiv.org/abs/2111.11429>`__ paper.
+
+    .. betastatus:: detection module
+
+    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
+    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
+    details.
+
+    Args:
+        weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
+            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
+            trainable. If ``None`` is passed (the default) this value is set to 3.
+        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights
+        :members:
+    """
+    weights = FasterRCNN_ResNet50_FPN_V2_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+
+    backbone = resnet50(weights=weights_backbone, progress=progress)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
+    rpn_anchor_generator = _default_anchorgen()
+    rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
+    box_head = FastRCNNConvFCHead(
+        (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
+    )
+    model = FasterRCNN(
+        backbone,
+        num_classes=num_classes,
+        rpn_anchor_generator=rpn_anchor_generator,
+        rpn_head=rpn_head,
+        box_head=box_head,
+        **kwargs,
+    )
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model
+
+
+def _fasterrcnn_mobilenet_v3_large_fpn(
+    *,
+    weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]],
+    progress: bool,
+    num_classes: Optional[int],
+    weights_backbone: Optional[MobileNet_V3_Large_Weights],
+    trainable_backbone_layers: Optional[int],
+    **kwargs: Any,
+) -> FasterRCNN:
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+    backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
+    anchor_sizes = (
+        (
+            32,
+            64,
+            128,
+            256,
+            512,
+        ),
+    ) * 3
+    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+    model = FasterRCNN(
+        backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
+    )
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
+)
+def fasterrcnn_mobilenet_v3_large_320_fpn(
+    *,
+    weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+    """
+    Low resolution Faster R-CNN model with a MobileNetV3-Large backbone tuned for mobile use cases.
+
+    .. betastatus:: detection module
+
+    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
+    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
+    details.
+
+    Example::
+
+        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
+            final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
+            trainable. If ``None`` is passed (the default) this value is set to 3.
+        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights
+        :members:
+    """
+    weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights)
+    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
+
+    defaults = {
+        "min_size": 320,
+        "max_size": 640,
+        "rpn_pre_nms_top_n_test": 150,
+        "rpn_post_nms_top_n_test": 150,
+        "rpn_score_thresh": 0.05,
+    }
+
+    kwargs = {**defaults, **kwargs}
+    return _fasterrcnn_mobilenet_v3_large_fpn(
+        weights=weights,
+        progress=progress,
+        num_classes=num_classes,
+        weights_backbone=weights_backbone,
+        trainable_backbone_layers=trainable_backbone_layers,
+        **kwargs,
+    )
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
+)
+def fasterrcnn_mobilenet_v3_large_fpn(
+    *,
+    weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+    """
+    Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.
+
+    .. betastatus:: detection module
+
+    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
+    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
+    details.
+
+    Example::
+
+        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
+            final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
+            trainable. If ``None`` is passed (the default) this value is set to 3.
+        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights
+        :members:
+    """
+    weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights)
+    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
+
+    defaults = {
+        "rpn_score_thresh": 0.05,
+    }
+
+    kwargs = {**defaults, **kwargs}
+    return _fasterrcnn_mobilenet_v3_large_fpn(
+        weights=weights,
+        progress=progress,
+        num_classes=num_classes,
+        weights_backbone=weights_backbone,
+        trainable_backbone_layers=trainable_backbone_layers,
+        **kwargs,
+    )

+ 771 - 0
libs/vision_libs/models/detection/fcos.py

@@ -0,0 +1,771 @@
+import math
+import warnings
+from collections import OrderedDict
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import torch
+from torch import nn, Tensor
+
+from ...ops import boxes as box_ops, generalized_box_iou_loss, misc as misc_nn_ops, sigmoid_focal_loss
+from ...ops.feature_pyramid_network import LastLevelP6P7
+from ...transforms._presets import ObjectDetection
+from ...utils import _log_api_usage_once
+from .._api import register_model, Weights, WeightsEnum
+from .._meta import _COCO_CATEGORIES
+from .._utils import _ovewrite_value_param, handle_legacy_interface
+from ..resnet import resnet50, ResNet50_Weights
+from . import _utils as det_utils
+from .anchor_utils import AnchorGenerator
+from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
+from .transform import GeneralizedRCNNTransform
+
+
+__all__ = [
+    "FCOS",
+    "FCOS_ResNet50_FPN_Weights",
+    "fcos_resnet50_fpn",
+]
+
+
+class FCOSHead(nn.Module):
+    """
+    A regression and classification head for use in FCOS.
+
+    Args:
+        in_channels (int): number of channels of the input feature
+        num_anchors (int): number of anchors to be predicted
+        num_classes (int): number of classes to be predicted
+        num_convs (Optional[int]): number of conv layer of head. Default: 4.
+    """
+
+    __annotations__ = {
+        "box_coder": det_utils.BoxLinearCoder,
+    }
+
+    def __init__(self, in_channels: int, num_anchors: int, num_classes: int, num_convs: Optional[int] = 4) -> None:
+        super().__init__()
+        self.box_coder = det_utils.BoxLinearCoder(normalize_by_size=True)
+        self.classification_head = FCOSClassificationHead(in_channels, num_anchors, num_classes, num_convs)
+        self.regression_head = FCOSRegressionHead(in_channels, num_anchors, num_convs)
+
+    def compute_loss(
+        self,
+        targets: List[Dict[str, Tensor]],
+        head_outputs: Dict[str, Tensor],
+        anchors: List[Tensor],
+        matched_idxs: List[Tensor],
+    ) -> Dict[str, Tensor]:
+
+        cls_logits = head_outputs["cls_logits"]  # [N, HWA, C]
+        bbox_regression = head_outputs["bbox_regression"]  # [N, HWA, 4]
+        bbox_ctrness = head_outputs["bbox_ctrness"]  # [N, HWA, 1]
+
+        all_gt_classes_targets = []
+        all_gt_boxes_targets = []
+        for targets_per_image, matched_idxs_per_image in zip(targets, matched_idxs):
+            if len(targets_per_image["labels"]) == 0:
+                gt_classes_targets = targets_per_image["labels"].new_zeros((len(matched_idxs_per_image),))
+                gt_boxes_targets = targets_per_image["boxes"].new_zeros((len(matched_idxs_per_image), 4))
+            else:
+                gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)]
+                gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)]
+            gt_classes_targets[matched_idxs_per_image < 0] = -1  # background
+            all_gt_classes_targets.append(gt_classes_targets)
+            all_gt_boxes_targets.append(gt_boxes_targets)
+
+        # List[Tensor] to Tensor conversion of  `all_gt_boxes_target`, `all_gt_classes_targets` and `anchors`
+        all_gt_boxes_targets, all_gt_classes_targets, anchors = (
+            torch.stack(all_gt_boxes_targets),
+            torch.stack(all_gt_classes_targets),
+            torch.stack(anchors),
+        )
+
+        # compute foregroud
+        foregroud_mask = all_gt_classes_targets >= 0
+        num_foreground = foregroud_mask.sum().item()
+
+        # classification loss
+        gt_classes_targets = torch.zeros_like(cls_logits)
+        gt_classes_targets[foregroud_mask, all_gt_classes_targets[foregroud_mask]] = 1.0
+        loss_cls = sigmoid_focal_loss(cls_logits, gt_classes_targets, reduction="sum")
+
+        # amp issue: pred_boxes need to convert float
+        pred_boxes = self.box_coder.decode(bbox_regression, anchors)
+
+        # regression loss: GIoU loss
+        loss_bbox_reg = generalized_box_iou_loss(
+            pred_boxes[foregroud_mask],
+            all_gt_boxes_targets[foregroud_mask],
+            reduction="sum",
+        )
+
+        # ctrness loss
+
+        bbox_reg_targets = self.box_coder.encode(anchors, all_gt_boxes_targets)
+
+        if len(bbox_reg_targets) == 0:
+            gt_ctrness_targets = bbox_reg_targets.new_zeros(bbox_reg_targets.size()[:-1])
+        else:
+            left_right = bbox_reg_targets[:, :, [0, 2]]
+            top_bottom = bbox_reg_targets[:, :, [1, 3]]
+            gt_ctrness_targets = torch.sqrt(
+                (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0])
+                * (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
+            )
+        pred_centerness = bbox_ctrness.squeeze(dim=2)
+        loss_bbox_ctrness = nn.functional.binary_cross_entropy_with_logits(
+            pred_centerness[foregroud_mask], gt_ctrness_targets[foregroud_mask], reduction="sum"
+        )
+
+        return {
+            "classification": loss_cls / max(1, num_foreground),
+            "bbox_regression": loss_bbox_reg / max(1, num_foreground),
+            "bbox_ctrness": loss_bbox_ctrness / max(1, num_foreground),
+        }
+
+    def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
+        cls_logits = self.classification_head(x)
+        bbox_regression, bbox_ctrness = self.regression_head(x)
+        return {
+            "cls_logits": cls_logits,
+            "bbox_regression": bbox_regression,
+            "bbox_ctrness": bbox_ctrness,
+        }
+
+
+class FCOSClassificationHead(nn.Module):
+    """
+    A classification head for use in FCOS.
+
+    Args:
+        in_channels (int): number of channels of the input feature.
+        num_anchors (int): number of anchors to be predicted.
+        num_classes (int): number of classes to be predicted.
+        num_convs (Optional[int]): number of conv layer. Default: 4.
+        prior_probability (Optional[float]): probability of prior. Default: 0.01.
+        norm_layer: Module specifying the normalization layer to use.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        num_anchors: int,
+        num_classes: int,
+        num_convs: int = 4,
+        prior_probability: float = 0.01,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+    ) -> None:
+        super().__init__()
+
+        self.num_classes = num_classes
+        self.num_anchors = num_anchors
+
+        if norm_layer is None:
+            norm_layer = partial(nn.GroupNorm, 32)
+
+        conv = []
+        for _ in range(num_convs):
+            conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
+            conv.append(norm_layer(in_channels))
+            conv.append(nn.ReLU())
+        self.conv = nn.Sequential(*conv)
+
+        for layer in self.conv.children():
+            if isinstance(layer, nn.Conv2d):
+                torch.nn.init.normal_(layer.weight, std=0.01)
+                torch.nn.init.constant_(layer.bias, 0)
+
+        self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
+        torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
+        torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability))
+
+    def forward(self, x: List[Tensor]) -> Tensor:
+        all_cls_logits = []
+
+        for features in x:
+            cls_logits = self.conv(features)
+            cls_logits = self.cls_logits(cls_logits)
+
+            # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
+            N, _, H, W = cls_logits.shape
+            cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
+            cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
+            cls_logits = cls_logits.reshape(N, -1, self.num_classes)  # Size=(N, HWA, 4)
+
+            all_cls_logits.append(cls_logits)
+
+        return torch.cat(all_cls_logits, dim=1)
+
+
+class FCOSRegressionHead(nn.Module):
+    """
+    A regression head for use in FCOS, which combines regression branch and center-ness branch.
+    This can obtain better performance.
+
+    Reference: `FCOS: A simple and strong anchor-free object detector <https://arxiv.org/abs/2006.09214>`_.
+
+    Args:
+        in_channels (int): number of channels of the input feature
+        num_anchors (int): number of anchors to be predicted
+        num_convs (Optional[int]): number of conv layer. Default: 4.
+        norm_layer: Module specifying the normalization layer to use.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        num_anchors: int,
+        num_convs: int = 4,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+    ):
+        super().__init__()
+
+        if norm_layer is None:
+            norm_layer = partial(nn.GroupNorm, 32)
+
+        conv = []
+        for _ in range(num_convs):
+            conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
+            conv.append(norm_layer(in_channels))
+            conv.append(nn.ReLU())
+        self.conv = nn.Sequential(*conv)
+
+        self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
+        self.bbox_ctrness = nn.Conv2d(in_channels, num_anchors * 1, kernel_size=3, stride=1, padding=1)
+        for layer in [self.bbox_reg, self.bbox_ctrness]:
+            torch.nn.init.normal_(layer.weight, std=0.01)
+            torch.nn.init.zeros_(layer.bias)
+
+        for layer in self.conv.children():
+            if isinstance(layer, nn.Conv2d):
+                torch.nn.init.normal_(layer.weight, std=0.01)
+                torch.nn.init.zeros_(layer.bias)
+
+    def forward(self, x: List[Tensor]) -> Tuple[Tensor, Tensor]:
+        all_bbox_regression = []
+        all_bbox_ctrness = []
+
+        for features in x:
+            bbox_feature = self.conv(features)
+            bbox_regression = nn.functional.relu(self.bbox_reg(bbox_feature))
+            bbox_ctrness = self.bbox_ctrness(bbox_feature)
+
+            # permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
+            N, _, H, W = bbox_regression.shape
+            bbox_regression = bbox_regression.view(N, -1, 4, H, W)
+            bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
+            bbox_regression = bbox_regression.reshape(N, -1, 4)  # Size=(N, HWA, 4)
+            all_bbox_regression.append(bbox_regression)
+
+            # permute bbox ctrness output from (N, 1 * A, H, W) to (N, HWA, 1).
+            bbox_ctrness = bbox_ctrness.view(N, -1, 1, H, W)
+            bbox_ctrness = bbox_ctrness.permute(0, 3, 4, 1, 2)
+            bbox_ctrness = bbox_ctrness.reshape(N, -1, 1)
+            all_bbox_ctrness.append(bbox_ctrness)
+
+        return torch.cat(all_bbox_regression, dim=1), torch.cat(all_bbox_ctrness, dim=1)
+
+
+class FCOS(nn.Module):
+    """
+    Implements FCOS.
+
+    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
+    image, and should be in 0-1 range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (list of dictionary),
+    containing:
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (Int64Tensor[N]): the class label for each ground-truth box
+
+    The model returns a Dict[Tensor] during training, containing the classification, regression
+    and centerness losses.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
+    follows:
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (Int64Tensor[N]): the predicted labels for each image
+        - scores (Tensor[N]): the scores for each prediction
+
+    Args:
+        backbone (nn.Module): the network used to compute the features for the model.
+            It should contain an out_channels attribute, which indicates the number of output
+            channels that each feature map has (and it should be the same for all feature maps).
+            The backbone should return a single Tensor or an OrderedDict[Tensor].
+        num_classes (int): number of output classes of the model (including the background).
+        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
+        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
+        image_mean (Tuple[float, float, float]): mean values used for input normalization.
+            They are generally the mean values of the dataset on which the backbone has been trained
+            on
+        image_std (Tuple[float, float, float]): std values used for input normalization.
+            They are generally the std values of the dataset on which the backbone has been trained on
+        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
+            maps. For FCOS, only set one anchor for per position of each level, the width and height equal to
+            the stride of feature map, and set aspect ratio = 1.0, so the center of anchor is equivalent to the point
+            in FCOS paper.
+        head (nn.Module): Module run on top of the feature pyramid.
+            Defaults to a module containing a classification and regression module.
+        center_sampling_radius (int): radius of the "center" of a groundtruth box,
+            within which all anchor points are labeled positive.
+        score_thresh (float): Score threshold used for postprocessing the detections.
+        nms_thresh (float): NMS threshold used for postprocessing the detections.
+        detections_per_img (int): Number of best detections to keep after NMS.
+        topk_candidates (int): Number of best detections to keep before NMS.
+
+    Example:
+
+        >>> import torch
+        >>> import torchvision
+        >>> from torchvision.models.detection import FCOS
+        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
+        >>> # load a pre-trained model for classification and return
+        >>> # only the features
+        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
+        >>> # FCOS needs to know the number of
+        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
+        >>> # so we need to add it here
+        >>> backbone.out_channels = 1280
+        >>>
+        >>> # let's make the network generate 5 x 3 anchors per spatial
+        >>> # location, with 5 different sizes and 3 different aspect
+        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
+        >>> # map could potentially have different sizes and
+        >>> # aspect ratios
+        >>> anchor_generator = AnchorGenerator(
+        >>>     sizes=((8,), (16,), (32,), (64,), (128,)),
+        >>>     aspect_ratios=((1.0,),)
+        >>> )
+        >>>
+        >>> # put the pieces together inside a FCOS model
+        >>> model = FCOS(
+        >>>     backbone,
+        >>>     num_classes=80,
+        >>>     anchor_generator=anchor_generator,
+        >>> )
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+    """
+
+    __annotations__ = {
+        "box_coder": det_utils.BoxLinearCoder,
+    }
+
+    def __init__(
+        self,
+        backbone: nn.Module,
+        num_classes: int,
+        # transform parameters
+        min_size: int = 800,
+        max_size: int = 1333,
+        image_mean: Optional[List[float]] = None,
+        image_std: Optional[List[float]] = None,
+        # Anchor parameters
+        anchor_generator: Optional[AnchorGenerator] = None,
+        head: Optional[nn.Module] = None,
+        center_sampling_radius: float = 1.5,
+        score_thresh: float = 0.2,
+        nms_thresh: float = 0.6,
+        detections_per_img: int = 100,
+        topk_candidates: int = 1000,
+        **kwargs,
+    ):
+        super().__init__()
+        _log_api_usage_once(self)
+
+        if not hasattr(backbone, "out_channels"):
+            raise ValueError(
+                "backbone should contain an attribute out_channels "
+                "specifying the number of output channels (assumed to be the "
+                "same for all the levels)"
+            )
+        self.backbone = backbone
+
+        if not isinstance(anchor_generator, (AnchorGenerator, type(None))):
+            raise TypeError(
+                f"anchor_generator should be of type AnchorGenerator or None, instead  got {type(anchor_generator)}"
+            )
+
+        if anchor_generator is None:
+            anchor_sizes = ((8,), (16,), (32,), (64,), (128,))  # equal to strides of multi-level feature map
+            aspect_ratios = ((1.0,),) * len(anchor_sizes)  # set only one anchor
+            anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
+        self.anchor_generator = anchor_generator
+        if self.anchor_generator.num_anchors_per_location()[0] != 1:
+            raise ValueError(
+                f"anchor_generator.num_anchors_per_location()[0] should be 1 instead of {anchor_generator.num_anchors_per_location()[0]}"
+            )
+
+        if head is None:
+            head = FCOSHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes)
+        self.head = head
+
+        self.box_coder = det_utils.BoxLinearCoder(normalize_by_size=True)
+
+        if image_mean is None:
+            image_mean = [0.485, 0.456, 0.406]
+        if image_std is None:
+            image_std = [0.229, 0.224, 0.225]
+        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)
+
+        self.center_sampling_radius = center_sampling_radius
+        self.score_thresh = score_thresh
+        self.nms_thresh = nms_thresh
+        self.detections_per_img = detections_per_img
+        self.topk_candidates = topk_candidates
+
+        # used only on torchscript mode
+        self._has_warned = False
+
+    @torch.jit.unused
+    def eager_outputs(
+        self, losses: Dict[str, Tensor], detections: List[Dict[str, Tensor]]
+    ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
+        if self.training:
+            return losses
+
+        return detections
+
+    def compute_loss(
+        self,
+        targets: List[Dict[str, Tensor]],
+        head_outputs: Dict[str, Tensor],
+        anchors: List[Tensor],
+        num_anchors_per_level: List[int],
+    ) -> Dict[str, Tensor]:
+        matched_idxs = []
+        for anchors_per_image, targets_per_image in zip(anchors, targets):
+            if targets_per_image["boxes"].numel() == 0:
+                matched_idxs.append(
+                    torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device)
+                )
+                continue
+
+            gt_boxes = targets_per_image["boxes"]
+            gt_centers = (gt_boxes[:, :2] + gt_boxes[:, 2:]) / 2  # Nx2
+            anchor_centers = (anchors_per_image[:, :2] + anchors_per_image[:, 2:]) / 2  # N
+            anchor_sizes = anchors_per_image[:, 2] - anchors_per_image[:, 0]
+            # center sampling: anchor point must be close enough to gt center.
+            pairwise_match = (anchor_centers[:, None, :] - gt_centers[None, :, :]).abs_().max(
+                dim=2
+            ).values < self.center_sampling_radius * anchor_sizes[:, None]
+            # compute pairwise distance between N points and M boxes
+            x, y = anchor_centers.unsqueeze(dim=2).unbind(dim=1)  # (N, 1)
+            x0, y0, x1, y1 = gt_boxes.unsqueeze(dim=0).unbind(dim=2)  # (1, M)
+            pairwise_dist = torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2)  # (N, M)
+
+            # anchor point must be inside gt
+            pairwise_match &= pairwise_dist.min(dim=2).values > 0
+
+            # each anchor is only responsible for certain scale range.
+            lower_bound = anchor_sizes * 4
+            lower_bound[: num_anchors_per_level[0]] = 0
+            upper_bound = anchor_sizes * 8
+            upper_bound[-num_anchors_per_level[-1] :] = float("inf")
+            pairwise_dist = pairwise_dist.max(dim=2).values
+            pairwise_match &= (pairwise_dist > lower_bound[:, None]) & (pairwise_dist < upper_bound[:, None])
+
+            # match the GT box with minimum area, if there are multiple GT matches
+            gt_areas = (gt_boxes[:, 2] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1])  # N
+            pairwise_match = pairwise_match.to(torch.float32) * (1e8 - gt_areas[None, :])
+            min_values, matched_idx = pairwise_match.max(dim=1)  # R, per-anchor match
+            matched_idx[min_values < 1e-5] = -1  # unmatched anchors are assigned -1
+
+            matched_idxs.append(matched_idx)
+
+        return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)
+
+    def postprocess_detections(
+        self, head_outputs: Dict[str, List[Tensor]], anchors: List[List[Tensor]], image_shapes: List[Tuple[int, int]]
+    ) -> List[Dict[str, Tensor]]:
+        class_logits = head_outputs["cls_logits"]
+        box_regression = head_outputs["bbox_regression"]
+        box_ctrness = head_outputs["bbox_ctrness"]
+
+        num_images = len(image_shapes)
+
+        detections: List[Dict[str, Tensor]] = []
+
+        for index in range(num_images):
+            box_regression_per_image = [br[index] for br in box_regression]
+            logits_per_image = [cl[index] for cl in class_logits]
+            box_ctrness_per_image = [bc[index] for bc in box_ctrness]
+            anchors_per_image, image_shape = anchors[index], image_shapes[index]
+
+            image_boxes = []
+            image_scores = []
+            image_labels = []
+
+            for box_regression_per_level, logits_per_level, box_ctrness_per_level, anchors_per_level in zip(
+                box_regression_per_image, logits_per_image, box_ctrness_per_image, anchors_per_image
+            ):
+                num_classes = logits_per_level.shape[-1]
+
+                # remove low scoring boxes
+                scores_per_level = torch.sqrt(
+                    torch.sigmoid(logits_per_level) * torch.sigmoid(box_ctrness_per_level)
+                ).flatten()
+                keep_idxs = scores_per_level > self.score_thresh
+                scores_per_level = scores_per_level[keep_idxs]
+                topk_idxs = torch.where(keep_idxs)[0]
+
+                # keep only topk scoring predictions
+                num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0)
+                scores_per_level, idxs = scores_per_level.topk(num_topk)
+                topk_idxs = topk_idxs[idxs]
+
+                anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
+                labels_per_level = topk_idxs % num_classes
+
+                boxes_per_level = self.box_coder.decode(
+                    box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
+                )
+                boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)
+
+                image_boxes.append(boxes_per_level)
+                image_scores.append(scores_per_level)
+                image_labels.append(labels_per_level)
+
+            image_boxes = torch.cat(image_boxes, dim=0)
+            image_scores = torch.cat(image_scores, dim=0)
+            image_labels = torch.cat(image_labels, dim=0)
+
+            # non-maximum suppression
+            keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
+            keep = keep[: self.detections_per_img]
+
+            detections.append(
+                {
+                    "boxes": image_boxes[keep],
+                    "scores": image_scores[keep],
+                    "labels": image_labels[keep],
+                }
+            )
+
+        return detections
+
+    def forward(
+        self,
+        images: List[Tensor],
+        targets: Optional[List[Dict[str, Tensor]]] = None,
+    ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
+        """
+        Args:
+            images (list[Tensor]): images to be processed
+            targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
+
+        Returns:
+            result (list[BoxList] or dict[Tensor]): the output from the model.
+                During training, it returns a dict[Tensor] which contains the losses.
+                During testing, it returns list[BoxList] contains additional fields
+                like `scores`, `labels` and `mask` (for Mask R-CNN models).
+        """
+        if self.training:
+
+            if targets is None:
+                torch._assert(False, "targets should not be none when in training mode")
+            else:
+                for target in targets:
+                    boxes = target["boxes"]
+                    torch._assert(isinstance(boxes, torch.Tensor), "Expected target boxes to be of type Tensor.")
+                    torch._assert(
+                        len(boxes.shape) == 2 and boxes.shape[-1] == 4,
+                        f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
+                    )
+
+        original_image_sizes: List[Tuple[int, int]] = []
+        for img in images:
+            val = img.shape[-2:]
+            torch._assert(
+                len(val) == 2,
+                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
+            )
+            original_image_sizes.append((val[0], val[1]))
+
+        # transform the input
+        images, targets = self.transform(images, targets)
+
+        # Check for degenerate boxes
+        if targets is not None:
+            for target_idx, target in enumerate(targets):
+                boxes = target["boxes"]
+                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
+                if degenerate_boxes.any():
+                    # print the first degenerate box
+                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
+                    degen_bb: List[float] = boxes[bb_idx].tolist()
+                    torch._assert(
+                        False,
+                        f"All bounding boxes should have positive height and width. Found invalid box {degen_bb} for target at index {target_idx}.",
+                    )
+
+        # get the features from the backbone
+        features = self.backbone(images.tensors)
+        if isinstance(features, torch.Tensor):
+            features = OrderedDict([("0", features)])
+
+        features = list(features.values())
+
+        # compute the fcos heads outputs using the features
+        head_outputs = self.head(features)
+
+        # create the set of anchors
+        anchors = self.anchor_generator(images, features)
+        # recover level sizes
+        num_anchors_per_level = [x.size(2) * x.size(3) for x in features]
+
+        losses = {}
+        detections: List[Dict[str, Tensor]] = []
+        if self.training:
+            if targets is None:
+                torch._assert(False, "targets should not be none when in training mode")
+            else:
+                # compute the losses
+                losses = self.compute_loss(targets, head_outputs, anchors, num_anchors_per_level)
+        else:
+            # split outputs per level
+            split_head_outputs: Dict[str, List[Tensor]] = {}
+            for k in head_outputs:
+                split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1))
+            split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]
+
+            # compute the detections
+            detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)
+            detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
+
+        if torch.jit.is_scripting():
+            if not self._has_warned:
+                warnings.warn("FCOS always returns a (Losses, Detections) tuple in scripting")
+                self._has_warned = True
+            return losses, detections
+        return self.eager_outputs(losses, detections)
+
+
+class FCOS_ResNet50_FPN_Weights(WeightsEnum):
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth",
+        transforms=ObjectDetection,
+        meta={
+            "num_params": 32269600,
+            "categories": _COCO_CATEGORIES,
+            "min_size": (1, 1),
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#fcos-resnet-50-fpn",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 39.2,
+                }
+            },
+            "_ops": 128.207,
+            "_file_size": 123.608,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", FCOS_ResNet50_FPN_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
+)
+def fcos_resnet50_fpn(
+    *,
+    weights: Optional[FCOS_ResNet50_FPN_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FCOS:
+    """
+    Constructs a FCOS model with a ResNet-50-FPN backbone.
+
+    .. betastatus:: detection module
+
+    Reference: `FCOS: Fully Convolutional One-Stage Object Detection <https://arxiv.org/abs/1904.01355>`_.
+               `FCOS: A simple and strong anchor-free object detector <https://arxiv.org/abs/2006.09214>`_.
+
+    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
+    image, and should be in ``0-1`` range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (list of dictionary),
+    containing:
+
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
+
+    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
+    losses.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
+    follows, where ``N`` is the number of detections:
+
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the predicted labels for each detection
+        - scores (``Tensor[N]``): the scores of each detection
+
+    For more details on the output, you may refer to :ref:`instance_seg_output`.
+
+    Example:
+
+        >>> model = torchvision.models.detection.fcos_resnet50_fpn(weights=FCOS_ResNet50_FPN_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.FCOS_ResNet50_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.FCOS_ResNet50_FPN_Weights`
+            below for more details, and possible values. By default, no
+            pre-trained weights are used.
+        progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
+            the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) resnet layers starting
+            from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
+            trainable. If ``None`` is passed (the default) this value is set to 3. Default: None
+        **kwargs: parameters passed to the ``torchvision.models.detection.FCOS``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/fcos.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.FCOS_ResNet50_FPN_Weights
+        :members:
+    """
+    weights = FCOS_ResNet50_FPN_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+    backbone = _resnet_fpn_extractor(
+        backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
+    )
+    model = FCOS(backbone, num_classes, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model

+ 118 - 0
libs/vision_libs/models/detection/generalized_rcnn.py

@@ -0,0 +1,118 @@
+"""
+Implements the Generalized R-CNN framework
+"""
+
+import warnings
+from collections import OrderedDict
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn, Tensor
+
+from ...utils import _log_api_usage_once
+
+
+class GeneralizedRCNN(nn.Module):
+    """
+    Main class for Generalized R-CNN.
+
+    Args:
+        backbone (nn.Module):
+        rpn (nn.Module):
+        roi_heads (nn.Module): takes the features + the proposals from the RPN and computes
+            detections / masks from it.
+        transform (nn.Module): performs the data transformation from the inputs to feed into
+            the model
+    """
+
+    def __init__(self, backbone: nn.Module, rpn: nn.Module, roi_heads: nn.Module, transform: nn.Module) -> None:
+        super().__init__()
+        _log_api_usage_once(self)
+        self.transform = transform
+        self.backbone = backbone
+        self.rpn = rpn
+        self.roi_heads = roi_heads
+        # used only on torchscript mode
+        self._has_warned = False
+
+    @torch.jit.unused
+    def eager_outputs(self, losses, detections):
+        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
+        if self.training:
+            return losses
+
+        return detections
+
+    def forward(self, images, targets=None):
+        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
+        """
+        Args:
+            images (list[Tensor]): images to be processed
+            targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional)
+
+        Returns:
+            result (list[BoxList] or dict[Tensor]): the output from the model.
+                During training, it returns a dict[Tensor] which contains the losses.
+                During testing, it returns list[BoxList] contains additional fields
+                like `scores`, `labels` and `mask` (for Mask R-CNN models).
+
+        """
+        if self.training:
+            if targets is None:
+                torch._assert(False, "targets should not be none when in training mode")
+            else:
+                for target in targets:
+                    boxes = target["boxes"]
+                    if isinstance(boxes, torch.Tensor):
+                        torch._assert(
+                            len(boxes.shape) == 2 and boxes.shape[-1] == 4,
+                            f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
+                        )
+                    else:
+                        torch._assert(False, f"Expected target boxes to be of type Tensor, got {type(boxes)}.")
+
+        original_image_sizes: List[Tuple[int, int]] = []
+        for img in images:
+            val = img.shape[-2:]
+            torch._assert(
+                len(val) == 2,
+                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
+            )
+            original_image_sizes.append((val[0], val[1]))
+
+        images, targets = self.transform(images, targets)
+
+        # Check for degenerate boxes
+        # TODO: Move this to a function
+        if targets is not None:
+            for target_idx, target in enumerate(targets):
+                boxes = target["boxes"]
+                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
+                if degenerate_boxes.any():
+                    # print the first degenerate box
+                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
+                    degen_bb: List[float] = boxes[bb_idx].tolist()
+                    torch._assert(
+                        False,
+                        "All bounding boxes should have positive height and width."
+                        f" Found invalid box {degen_bb} for target at index {target_idx}.",
+                    )
+
+        features = self.backbone(images.tensors)
+        if isinstance(features, torch.Tensor):
+            features = OrderedDict([("0", features)])
+        proposals, proposal_losses = self.rpn(images, features, targets)
+        detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
+        detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)  # type: ignore[operator]
+
+        losses = {}
+        losses.update(detector_losses)
+        losses.update(proposal_losses)
+
+        if torch.jit.is_scripting():
+            if not self._has_warned:
+                warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
+                self._has_warned = True
+            return losses, detections
+        else:
+            return self.eager_outputs(losses, detections)

+ 25 - 0
libs/vision_libs/models/detection/image_list.py

@@ -0,0 +1,25 @@
+from typing import List, Tuple
+
+import torch
+from torch import Tensor
+
+
+class ImageList:
+    """
+    Structure that holds a list of images (of possibly
+    varying sizes) as a single tensor.
+    This works by padding the images to the same size,
+    and storing in a field the original sizes of each image
+
+    Args:
+        tensors (tensor): Tensor containing images.
+        image_sizes (list[tuple[int, int]]): List of Tuples each containing size of images.
+    """
+
+    def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]) -> None:
+        self.tensors = tensors
+        self.image_sizes = image_sizes
+
+    def to(self, device: torch.device) -> "ImageList":
+        cast_tensor = self.tensors.to(device)
+        return ImageList(cast_tensor, self.image_sizes)

+ 473 - 0
libs/vision_libs/models/detection/keypoint_rcnn.py

@@ -0,0 +1,473 @@
+from typing import Any, Optional
+
+import torch
+from torch import nn
+from torchvision.ops import MultiScaleRoIAlign
+
+from ...ops import misc as misc_nn_ops
+from ...transforms._presets import ObjectDetection
+from .._api import register_model, Weights, WeightsEnum
+from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
+from .._utils import _ovewrite_value_param, handle_legacy_interface
+from ..resnet import resnet50, ResNet50_Weights
+from ._utils import overwrite_eps
+from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
+from .faster_rcnn import FasterRCNN
+
+
+__all__ = [
+    "KeypointRCNN",
+    "KeypointRCNN_ResNet50_FPN_Weights",
+    "keypointrcnn_resnet50_fpn",
+]
+
+
+class KeypointRCNN(FasterRCNN):
+    """
+    Implements Keypoint R-CNN.
+
+    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
+    image, and should be in 0-1 range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (list of dictionary),
+    containing:
+
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+            ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (Int64Tensor[N]): the class label for each ground-truth box
+        - keypoints (FloatTensor[N, K, 3]): the K keypoints location for each of the N instances, in the
+          format [x, y, visibility], where visibility=0 means that the keypoint is not visible.
+
+    The model returns a Dict[Tensor] during training, containing the classification and regression
+    losses for both the RPN and the R-CNN, and the keypoint loss.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
+    follows:
+
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+            ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (Int64Tensor[N]): the predicted labels for each image
+        - scores (Tensor[N]): the scores or each prediction
+        - keypoints (FloatTensor[N, K, 3]): the locations of the predicted keypoints, in [x, y, v] format.
+
+    Args:
+        backbone (nn.Module): the network used to compute the features for the model.
+            It should contain an out_channels attribute, which indicates the number of output
+            channels that each feature map has (and it should be the same for all feature maps).
+            The backbone should return a single Tensor or and OrderedDict[Tensor].
+        num_classes (int): number of output classes of the model (including the background).
+            If box_predictor is specified, num_classes should be None.
+        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
+        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
+        image_mean (Tuple[float, float, float]): mean values used for input normalization.
+            They are generally the mean values of the dataset on which the backbone has been trained
+            on
+        image_std (Tuple[float, float, float]): std values used for input normalization.
+            They are generally the std values of the dataset on which the backbone has been trained on
+        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
+            maps.
+        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
+        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
+        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
+        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
+        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
+        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
+        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
+            considered as positive during training of the RPN.
+        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
+            considered as negative during training of the RPN.
+        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
+            for computing the loss
+        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
+            of the RPN
+        rpn_score_thresh (float): during inference, only return proposals with a classification score
+            greater than rpn_score_thresh
+        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
+            the locations indicated by the bounding boxes
+        box_head (nn.Module): module that takes the cropped feature maps as input
+        box_predictor (nn.Module): module that takes the output of box_head and returns the
+            classification logits and box regression deltas.
+        box_score_thresh (float): during inference, only return proposals with a classification score
+            greater than box_score_thresh
+        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
+        box_detections_per_img (int): maximum number of detections per image, for all classes.
+        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
+            considered as positive during training of the classification head
+        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
+            considered as negative during training of the classification head
+        box_batch_size_per_image (int): number of proposals that are sampled during training of the
+            classification head
+        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
+            of the classification head
+        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
+            bounding boxes
+        keypoint_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
+             the locations indicated by the bounding boxes, which will be used for the keypoint head.
+        keypoint_head (nn.Module): module that takes the cropped feature maps as input
+        keypoint_predictor (nn.Module): module that takes the output of the keypoint_head and returns the
+            heatmap logits
+
+    Example::
+
+        >>> import torch
+        >>> import torchvision
+        >>> from torchvision.models.detection import KeypointRCNN
+        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
+        >>>
+        >>> # load a pre-trained model for classification and return
+        >>> # only the features
+        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
+        >>> # KeypointRCNN needs to know the number of
+        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
+        >>> # so we need to add it here
+        >>> backbone.out_channels = 1280
+        >>>
+        >>> # let's make the RPN generate 5 x 3 anchors per spatial
+        >>> # location, with 5 different sizes and 3 different aspect
+        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
+        >>> # map could potentially have different sizes and
+        >>> # aspect ratios
+        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
+        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
+        >>>
+        >>> # let's define what are the feature maps that we will
+        >>> # use to perform the region of interest cropping, as well as
+        >>> # the size of the crop after rescaling.
+        >>> # if your backbone returns a Tensor, featmap_names is expected to
+        >>> # be ['0']. More generally, the backbone should return an
+        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
+        >>> # feature maps to use.
+        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
+        >>>                                                 output_size=7,
+        >>>                                                 sampling_ratio=2)
+        >>>
+        >>> keypoint_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
+        >>>                                                          output_size=14,
+        >>>                                                          sampling_ratio=2)
+        >>> # put the pieces together inside a KeypointRCNN model
+        >>> model = KeypointRCNN(backbone,
+        >>>                      num_classes=2,
+        >>>                      rpn_anchor_generator=anchor_generator,
+        >>>                      box_roi_pool=roi_pooler,
+        >>>                      keypoint_roi_pool=keypoint_roi_pooler)
+        >>> model.eval()
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+    """
+
+    def __init__(
+        self,
+        backbone,
+        num_classes=None,
+        # transform parameters
+        min_size=None,
+        max_size=1333,
+        image_mean=None,
+        image_std=None,
+        # RPN parameters
+        rpn_anchor_generator=None,
+        rpn_head=None,
+        rpn_pre_nms_top_n_train=2000,
+        rpn_pre_nms_top_n_test=1000,
+        rpn_post_nms_top_n_train=2000,
+        rpn_post_nms_top_n_test=1000,
+        rpn_nms_thresh=0.7,
+        rpn_fg_iou_thresh=0.7,
+        rpn_bg_iou_thresh=0.3,
+        rpn_batch_size_per_image=256,
+        rpn_positive_fraction=0.5,
+        rpn_score_thresh=0.0,
+        # Box parameters
+        box_roi_pool=None,
+        box_head=None,
+        box_predictor=None,
+        box_score_thresh=0.05,
+        box_nms_thresh=0.5,
+        box_detections_per_img=100,
+        box_fg_iou_thresh=0.5,
+        box_bg_iou_thresh=0.5,
+        box_batch_size_per_image=512,
+        box_positive_fraction=0.25,
+        bbox_reg_weights=None,
+        # keypoint parameters
+        keypoint_roi_pool=None,
+        keypoint_head=None,
+        keypoint_predictor=None,
+        num_keypoints=None,
+        **kwargs,
+    ):
+
+        if not isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))):
+            raise TypeError(
+                "keypoint_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(keypoint_roi_pool)}"
+            )
+        if min_size is None:
+            min_size = (640, 672, 704, 736, 768, 800)
+
+        if num_keypoints is not None:
+            if keypoint_predictor is not None:
+                raise ValueError("num_keypoints should be None when keypoint_predictor is specified")
+        else:
+            num_keypoints = 17
+
+        out_channels = backbone.out_channels
+
+        if keypoint_roi_pool is None:
+            keypoint_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
+
+        if keypoint_head is None:
+            keypoint_layers = tuple(512 for _ in range(8))
+            keypoint_head = KeypointRCNNHeads(out_channels, keypoint_layers)
+
+        if keypoint_predictor is None:
+            keypoint_dim_reduced = 512  # == keypoint_layers[-1]
+            keypoint_predictor = KeypointRCNNPredictor(keypoint_dim_reduced, num_keypoints)
+
+        super().__init__(
+            backbone,
+            num_classes,
+            # transform parameters
+            min_size,
+            max_size,
+            image_mean,
+            image_std,
+            # RPN-specific parameters
+            rpn_anchor_generator,
+            rpn_head,
+            rpn_pre_nms_top_n_train,
+            rpn_pre_nms_top_n_test,
+            rpn_post_nms_top_n_train,
+            rpn_post_nms_top_n_test,
+            rpn_nms_thresh,
+            rpn_fg_iou_thresh,
+            rpn_bg_iou_thresh,
+            rpn_batch_size_per_image,
+            rpn_positive_fraction,
+            rpn_score_thresh,
+            # Box parameters
+            box_roi_pool,
+            box_head,
+            box_predictor,
+            box_score_thresh,
+            box_nms_thresh,
+            box_detections_per_img,
+            box_fg_iou_thresh,
+            box_bg_iou_thresh,
+            box_batch_size_per_image,
+            box_positive_fraction,
+            bbox_reg_weights,
+            **kwargs,
+        )
+
+        self.roi_heads.keypoint_roi_pool = keypoint_roi_pool
+        self.roi_heads.keypoint_head = keypoint_head
+        self.roi_heads.keypoint_predictor = keypoint_predictor
+
+
+class KeypointRCNNHeads(nn.Sequential):
+    def __init__(self, in_channels, layers):
+        d = []
+        next_feature = in_channels
+        for out_channels in layers:
+            d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
+            d.append(nn.ReLU(inplace=True))
+            next_feature = out_channels
+        super().__init__(*d)
+        for m in self.children():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                nn.init.constant_(m.bias, 0)
+
+
+class KeypointRCNNPredictor(nn.Module):
+    def __init__(self, in_channels, num_keypoints):
+        super().__init__()
+        input_features = in_channels
+        deconv_kernel = 4
+        self.kps_score_lowres = nn.ConvTranspose2d(
+            input_features,
+            num_keypoints,
+            deconv_kernel,
+            stride=2,
+            padding=deconv_kernel // 2 - 1,
+        )
+        nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
+        nn.init.constant_(self.kps_score_lowres.bias, 0)
+        self.up_scale = 2
+        self.out_channels = num_keypoints
+
+    def forward(self, x):
+        x = self.kps_score_lowres(x)
+        return torch.nn.functional.interpolate(
+            x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
+        )
+
+
+_COMMON_META = {
+    "categories": _COCO_PERSON_CATEGORIES,
+    "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
+    "min_size": (1, 1),
+}
+
+
+class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
+    COCO_LEGACY = Weights(
+        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 59137258,
+            "recipe": "https://github.com/pytorch/vision/issues/1606",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 50.6,
+                    "kp_map": 61.1,
+                }
+            },
+            "_ops": 133.924,
+            "_file_size": 226.054,
+            "_docs": """
+                These weights were produced by following a similar training recipe as on the paper but use a checkpoint
+                from an early epoch.
+            """,
+        },
+    )
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 59137258,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 54.6,
+                    "kp_map": 65.0,
+                }
+            },
+            "_ops": 137.42,
+            "_file_size": 226.054,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=(
+        "pretrained",
+        lambda kwargs: KeypointRCNN_ResNet50_FPN_Weights.COCO_LEGACY
+        if kwargs["pretrained"] == "legacy"
+        else KeypointRCNN_ResNet50_FPN_Weights.COCO_V1,
+    ),
+    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
+)
+def keypointrcnn_resnet50_fpn(
+    *,
+    weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    num_keypoints: Optional[int] = None,
+    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> KeypointRCNN:
+    """
+    Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.
+
+    .. betastatus:: detection module
+
+    Reference: `Mask R-CNN <https://arxiv.org/abs/1703.06870>`__.
+
+    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
+    image, and should be in ``0-1`` range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (list of dictionary),
+    containing:
+
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
+        - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the
+          format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.
+
+    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
+    losses for both the RPN and the R-CNN, and the keypoint loss.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
+    follows, where ``N`` is the number of detected instances:
+
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the predicted labels for each instance
+        - scores (``Tensor[N]``): the scores or each instance
+        - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.
+
+    For more details on the output, you may refer to :ref:`instance_seg_output`.
+
+    Keypoint R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.
+
+    Example::
+
+        >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=KeypointRCNN_ResNet50_FPN_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+        >>>
+        >>> # optionally, if you want to export the model to ONNX:
+        >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`
+            below for more details, and possible values. By default, no
+            pre-trained weights are used.
+        progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int, optional): number of output classes of the model (including the background)
+        num_keypoints (int, optional): number of keypoints
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
+            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
+            passed (the default) this value is set to 3.
+
+    .. autoclass:: torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights
+        :members:
+    """
+    weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+        num_keypoints = _ovewrite_value_param("num_keypoints", num_keypoints, len(weights.meta["keypoint_names"]))
+    else:
+        if num_classes is None:
+            num_classes = 2
+        if num_keypoints is None:
+            num_keypoints = 17
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+    
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+    model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+        if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
+            overwrite_eps(model, 0.0)
+
+    return model

+ 587 - 0
libs/vision_libs/models/detection/mask_rcnn.py

@@ -0,0 +1,587 @@
+from collections import OrderedDict
+from typing import Any, Callable, Optional
+
+from torch import nn
+from torchvision.ops import MultiScaleRoIAlign
+
+from ...ops import misc as misc_nn_ops
+from ...transforms._presets import ObjectDetection
+from .._api import register_model, Weights, WeightsEnum
+from .._meta import _COCO_CATEGORIES
+from .._utils import _ovewrite_value_param, handle_legacy_interface
+from ..resnet import resnet50, ResNet50_Weights
+from ._utils import overwrite_eps
+from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
+from .faster_rcnn import _default_anchorgen, FasterRCNN, FastRCNNConvFCHead, RPNHead
+
+
+__all__ = [
+    "MaskRCNN",
+    "MaskRCNN_ResNet50_FPN_Weights",
+    "MaskRCNN_ResNet50_FPN_V2_Weights",
+    "maskrcnn_resnet50_fpn",
+    "maskrcnn_resnet50_fpn_v2",
+]
+
+
+class MaskRCNN(FasterRCNN):
+    """
+    Implements Mask R-CNN.
+
+    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
+    image, and should be in 0-1 range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (list of dictionary),
+    containing:
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (Int64Tensor[N]): the class label for each ground-truth box
+        - masks (UInt8Tensor[N, H, W]): the segmentation binary masks for each instance
+
+    The model returns a Dict[Tensor] during training, containing the classification and regression
+    losses for both the RPN and the R-CNN, and the mask loss.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
+    follows:
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (Int64Tensor[N]): the predicted labels for each image
+        - scores (Tensor[N]): the scores or each prediction
+        - masks (UInt8Tensor[N, 1, H, W]): the predicted masks for each instance, in 0-1 range. In order to
+          obtain the final segmentation masks, the soft masks can be thresholded, generally
+          with a value of 0.5 (mask >= 0.5)
+
+    Args:
+        backbone (nn.Module): the network used to compute the features for the model.
+            It should contain an out_channels attribute, which indicates the number of output
+            channels that each feature map has (and it should be the same for all feature maps).
+            The backbone should return a single Tensor or and OrderedDict[Tensor].
+        num_classes (int): number of output classes of the model (including the background).
+            If box_predictor is specified, num_classes should be None.
+        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
+        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
+        image_mean (Tuple[float, float, float]): mean values used for input normalization.
+            They are generally the mean values of the dataset on which the backbone has been trained
+            on
+        image_std (Tuple[float, float, float]): std values used for input normalization.
+            They are generally the std values of the dataset on which the backbone has been trained on
+        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
+            maps.
+        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
+        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
+        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
+        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
+        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
+        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
+        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
+            considered as positive during training of the RPN.
+        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
+            considered as negative during training of the RPN.
+        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
+            for computing the loss
+        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
+            of the RPN
+        rpn_score_thresh (float): during inference, only return proposals with a classification score
+            greater than rpn_score_thresh
+        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
+            the locations indicated by the bounding boxes
+        box_head (nn.Module): module that takes the cropped feature maps as input
+        box_predictor (nn.Module): module that takes the output of box_head and returns the
+            classification logits and box regression deltas.
+        box_score_thresh (float): during inference, only return proposals with a classification score
+            greater than box_score_thresh
+        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
+        box_detections_per_img (int): maximum number of detections per image, for all classes.
+        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
+            considered as positive during training of the classification head
+        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
+            considered as negative during training of the classification head
+        box_batch_size_per_image (int): number of proposals that are sampled during training of the
+            classification head
+        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
+            of the classification head
+        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
+            bounding boxes
+        mask_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
+             the locations indicated by the bounding boxes, which will be used for the mask head.
+        mask_head (nn.Module): module that takes the cropped feature maps as input
+        mask_predictor (nn.Module): module that takes the output of the mask_head and returns the
+            segmentation mask logits
+
+    Example::
+
+        >>> import torch
+        >>> import torchvision
+        >>> from torchvision.models.detection import MaskRCNN
+        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
+        >>>
+        >>> # load a pre-trained model for classification and return
+        >>> # only the features
+        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
+        >>> # MaskRCNN needs to know the number of
+        >>> # output channels in a backbone. For mobilenet_v2, it's 1280
+        >>> # so we need to add it here,
+        >>> backbone.out_channels = 1280
+        >>>
+        >>> # let's make the RPN generate 5 x 3 anchors per spatial
+        >>> # location, with 5 different sizes and 3 different aspect
+        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
+        >>> # map could potentially have different sizes and
+        >>> # aspect ratios
+        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
+        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
+        >>>
+        >>> # let's define what are the feature maps that we will
+        >>> # use to perform the region of interest cropping, as well as
+        >>> # the size of the crop after rescaling.
+        >>> # if your backbone returns a Tensor, featmap_names is expected to
+        >>> # be ['0']. More generally, the backbone should return an
+        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
+        >>> # feature maps to use.
+        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
+        >>>                                                 output_size=7,
+        >>>                                                 sampling_ratio=2)
+        >>>
+        >>> mask_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
+        >>>                                                      output_size=14,
+        >>>                                                      sampling_ratio=2)
+        >>> # put the pieces together inside a MaskRCNN model
+        >>> model = MaskRCNN(backbone,
+        >>>                  num_classes=2,
+        >>>                  rpn_anchor_generator=anchor_generator,
+        >>>                  box_roi_pool=roi_pooler,
+        >>>                  mask_roi_pool=mask_roi_pooler)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+    """
+
+    def __init__(
+        self,
+        backbone,
+        num_classes=None,
+        # transform parameters
+        min_size=800,
+        max_size=1333,
+        image_mean=None,
+        image_std=None,
+        # RPN parameters
+        rpn_anchor_generator=None,
+        rpn_head=None,
+        rpn_pre_nms_top_n_train=2000,
+        rpn_pre_nms_top_n_test=1000,
+        rpn_post_nms_top_n_train=2000,
+        rpn_post_nms_top_n_test=1000,
+        rpn_nms_thresh=0.7,
+        rpn_fg_iou_thresh=0.7,
+        rpn_bg_iou_thresh=0.3,
+        rpn_batch_size_per_image=256,
+        rpn_positive_fraction=0.5,
+        rpn_score_thresh=0.0,
+        # Box parameters
+        box_roi_pool=None,
+        box_head=None,
+        box_predictor=None,
+        box_score_thresh=0.05,
+        box_nms_thresh=0.5,
+        box_detections_per_img=100,
+        box_fg_iou_thresh=0.5,
+        box_bg_iou_thresh=0.5,
+        box_batch_size_per_image=512,
+        box_positive_fraction=0.25,
+        bbox_reg_weights=None,
+        # Mask parameters
+        mask_roi_pool=None,
+        mask_head=None,
+        mask_predictor=None,
+        **kwargs,
+    ):
+
+        if not isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None))):
+            raise TypeError(
+                f"mask_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(mask_roi_pool)}"
+            )
+
+        if num_classes is not None:
+            if mask_predictor is not None:
+                raise ValueError("num_classes should be None when mask_predictor is specified")
+
+        out_channels = backbone.out_channels
+
+        if mask_roi_pool is None:
+            mask_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
+
+        if mask_head is None:
+            mask_layers = (256, 256, 256, 256)
+            mask_dilation = 1
+            mask_head = MaskRCNNHeads(out_channels, mask_layers, mask_dilation)
+
+        if mask_predictor is None:
+            mask_predictor_in_channels = 256  # == mask_layers[-1]
+            mask_dim_reduced = 256
+            mask_predictor = MaskRCNNPredictor(mask_predictor_in_channels, mask_dim_reduced, num_classes)
+
+        super().__init__(
+            backbone,
+            num_classes,
+            # transform parameters
+            min_size,
+            max_size,
+            image_mean,
+            image_std,
+            # RPN-specific parameters
+            rpn_anchor_generator,
+            rpn_head,
+            rpn_pre_nms_top_n_train,
+            rpn_pre_nms_top_n_test,
+            rpn_post_nms_top_n_train,
+            rpn_post_nms_top_n_test,
+            rpn_nms_thresh,
+            rpn_fg_iou_thresh,
+            rpn_bg_iou_thresh,
+            rpn_batch_size_per_image,
+            rpn_positive_fraction,
+            rpn_score_thresh,
+            # Box parameters
+            box_roi_pool,
+            box_head,
+            box_predictor,
+            box_score_thresh,
+            box_nms_thresh,
+            box_detections_per_img,
+            box_fg_iou_thresh,
+            box_bg_iou_thresh,
+            box_batch_size_per_image,
+            box_positive_fraction,
+            bbox_reg_weights,
+            **kwargs,
+        )
+
+        self.roi_heads.mask_roi_pool = mask_roi_pool
+        self.roi_heads.mask_head = mask_head
+        self.roi_heads.mask_predictor = mask_predictor
+
+
+class MaskRCNNHeads(nn.Sequential):
+    _version = 2
+
+    def __init__(self, in_channels, layers, dilation, norm_layer: Optional[Callable[..., nn.Module]] = None):
+        """
+        Args:
+            in_channels (int): number of input channels
+            layers (list): feature dimensions of each FCN layer
+            dilation (int): dilation rate of kernel
+            norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
+        """
+        blocks = []
+        next_feature = in_channels
+        for layer_features in layers:
+            blocks.append(
+                misc_nn_ops.Conv2dNormActivation(
+                    next_feature,
+                    layer_features,
+                    kernel_size=3,
+                    stride=1,
+                    padding=dilation,
+                    dilation=dilation,
+                    norm_layer=norm_layer,
+                )
+            )
+            next_feature = layer_features
+
+        super().__init__(*blocks)
+        for layer in self.modules():
+            if isinstance(layer, nn.Conv2d):
+                nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
+                if layer.bias is not None:
+                    nn.init.zeros_(layer.bias)
+
+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        version = local_metadata.get("version", None)
+
+        if version is None or version < 2:
+            num_blocks = len(self)
+            for i in range(num_blocks):
+                for type in ["weight", "bias"]:
+                    old_key = f"{prefix}mask_fcn{i+1}.{type}"
+                    new_key = f"{prefix}{i}.0.{type}"
+                    if old_key in state_dict:
+                        state_dict[new_key] = state_dict.pop(old_key)
+
+        super()._load_from_state_dict(
+            state_dict,
+            prefix,
+            local_metadata,
+            strict,
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )
+
+
+class MaskRCNNPredictor(nn.Sequential):
+    def __init__(self, in_channels, dim_reduced, num_classes):
+        super().__init__(
+            OrderedDict(
+                [
+                    ("conv5_mask", nn.ConvTranspose2d(in_channels, dim_reduced, 2, 2, 0)),
+                    ("relu", nn.ReLU(inplace=True)),
+                    ("mask_fcn_logits", nn.Conv2d(dim_reduced, num_classes, 1, 1, 0)),
+                ]
+            )
+        )
+
+        for name, param in self.named_parameters():
+            if "weight" in name:
+                nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")
+            # elif "bias" in name:
+            #     nn.init.constant_(param, 0)
+
+
+_COMMON_META = {
+    "categories": _COCO_CATEGORIES,
+    "min_size": (1, 1),
+}
+
+
+class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 44401393,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 37.9,
+                    "mask_map": 34.6,
+                }
+            },
+            "_ops": 134.38,
+            "_file_size": 169.84,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+class MaskRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_v2_coco-73cbd019.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 46359409,
+            "recipe": "https://github.com/pytorch/vision/pull/5773",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 47.4,
+                    "mask_map": 41.8,
+                }
+            },
+            "_ops": 333.577,
+            "_file_size": 177.219,
+            "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
+)
+def maskrcnn_resnet50_fpn(
+    *,
+    weights: Optional[MaskRCNN_ResNet50_FPN_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> MaskRCNN:
+    """Mask R-CNN model with a ResNet-50-FPN backbone from the `Mask R-CNN
+    <https://arxiv.org/abs/1703.06870>`_ paper.
+
+    .. betastatus:: detection module
+
+    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
+    image, and should be in ``0-1`` range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (list of dictionary),
+    containing:
+
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
+        - masks (``UInt8Tensor[N, H, W]``): the segmentation binary masks for each instance
+
+    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
+    losses for both the RPN and the R-CNN, and the mask loss.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
+    follows, where ``N`` is the number of detected instances:
+
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the predicted labels for each instance
+        - scores (``Tensor[N]``): the scores or each instance
+        - masks (``UInt8Tensor[N, 1, H, W]``): the predicted masks for each instance, in ``0-1`` range. In order to
+          obtain the final segmentation masks, the soft masks can be thresholded, generally
+          with a value of 0.5 (``mask >= 0.5``)
+
+    For more details on the output and on how to plot the masks, you may refer to :ref:`instance_seg_output`.
+
+    Mask R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.
+
+    Example::
+
+        >>> model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+        >>>
+        >>> # optionally, if you want to export the model to ONNX:
+        >>> torch.onnx.export(model, x, "mask_rcnn.onnx", opset_version = 11)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
+            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
+            trainable. If ``None`` is passed (the default) this value is set to 3.
+        **kwargs: parameters passed to the ``torchvision.models.detection.mask_rcnn.MaskRCNN``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/mask_rcnn.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights
+        :members:
+    """
+    weights = MaskRCNN_ResNet50_FPN_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+    model = MaskRCNN(backbone, num_classes=num_classes, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+        if weights == MaskRCNN_ResNet50_FPN_Weights.COCO_V1:
+            overwrite_eps(model, 0.0)
+
+    return model
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
+)
+def maskrcnn_resnet50_fpn_v2(
+    *,
+    weights: Optional[MaskRCNN_ResNet50_FPN_V2_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[ResNet50_Weights] = None,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> MaskRCNN:
+    """Improved Mask R-CNN model with a ResNet-50-FPN backbone from the `Benchmarking Detection Transfer
+    Learning with Vision Transformers <https://arxiv.org/abs/2111.11429>`_ paper.
+
+    .. betastatus:: detection module
+
+    :func:`~torchvision.models.detection.maskrcnn_resnet50_fpn` for more details.
+
+    Args:
+        weights (:class:`~torchvision.models.detection.MaskRCNN_ResNet50_FPN_V2_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.MaskRCNN_ResNet50_FPN_V2_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
+            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
+            trainable. If ``None`` is passed (the default) this value is set to 3.
+        **kwargs: parameters passed to the ``torchvision.models.detection.mask_rcnn.MaskRCNN``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/mask_rcnn.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.MaskRCNN_ResNet50_FPN_V2_Weights
+        :members:
+    """
+    weights = MaskRCNN_ResNet50_FPN_V2_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+
+    backbone = resnet50(weights=weights_backbone, progress=progress)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
+    rpn_anchor_generator = _default_anchorgen()
+    rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
+    box_head = FastRCNNConvFCHead(
+        (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
+    )
+    mask_head = MaskRCNNHeads(backbone.out_channels, [256, 256, 256, 256], 1, norm_layer=nn.BatchNorm2d)
+    model = MaskRCNN(
+        backbone,
+        num_classes=num_classes,
+        rpn_anchor_generator=rpn_anchor_generator,
+        rpn_head=rpn_head,
+        box_head=box_head,
+        mask_head=mask_head,
+        **kwargs,
+    )
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model

+ 899 - 0
libs/vision_libs/models/detection/retinanet.py

@@ -0,0 +1,899 @@
+import math
+import warnings
+from collections import OrderedDict
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import torch
+from torch import nn, Tensor
+
+from ...ops import boxes as box_ops, misc as misc_nn_ops, sigmoid_focal_loss
+from ...ops.feature_pyramid_network import LastLevelP6P7
+from ...transforms._presets import ObjectDetection
+from ...utils import _log_api_usage_once
+from .._api import register_model, Weights, WeightsEnum
+from .._meta import _COCO_CATEGORIES
+from .._utils import _ovewrite_value_param, handle_legacy_interface
+from ..resnet import resnet50, ResNet50_Weights
+from . import _utils as det_utils
+from ._utils import _box_loss, overwrite_eps
+from .anchor_utils import AnchorGenerator
+from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
+from .transform import GeneralizedRCNNTransform
+
+
+__all__ = [
+    "RetinaNet",
+    "RetinaNet_ResNet50_FPN_Weights",
+    "RetinaNet_ResNet50_FPN_V2_Weights",
+    "retinanet_resnet50_fpn",
+    "retinanet_resnet50_fpn_v2",
+]
+
+
+def _sum(x: List[Tensor]) -> Tensor:
+    res = x[0]
+    for i in x[1:]:
+        res = res + i
+    return res
+
+
+def _v1_to_v2_weights(state_dict, prefix):
+    for i in range(4):
+        for type in ["weight", "bias"]:
+            old_key = f"{prefix}conv.{2*i}.{type}"
+            new_key = f"{prefix}conv.{i}.0.{type}"
+            if old_key in state_dict:
+                state_dict[new_key] = state_dict.pop(old_key)
+
+
+def _default_anchorgen():
+    anchor_sizes = tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512])
+    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+    anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
+    return anchor_generator
+
+
+class RetinaNetHead(nn.Module):
+    """
+    A regression and classification head for use in RetinaNet.
+
+    Args:
+        in_channels (int): number of channels of the input feature
+        num_anchors (int): number of anchors to be predicted
+        num_classes (int): number of classes to be predicted
+        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
+    """
+
+    def __init__(self, in_channels, num_anchors, num_classes, norm_layer: Optional[Callable[..., nn.Module]] = None):
+        super().__init__()
+        self.classification_head = RetinaNetClassificationHead(
+            in_channels, num_anchors, num_classes, norm_layer=norm_layer
+        )
+        self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors, norm_layer=norm_layer)
+
+    def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
+        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Dict[str, Tensor]
+        return {
+            "classification": self.classification_head.compute_loss(targets, head_outputs, matched_idxs),
+            "bbox_regression": self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
+        }
+
+    def forward(self, x):
+        # type: (List[Tensor]) -> Dict[str, Tensor]
+        return {"cls_logits": self.classification_head(x), "bbox_regression": self.regression_head(x)}
+
+
+class RetinaNetClassificationHead(nn.Module):
+    """
+    A classification head for use in RetinaNet.
+
+    Args:
+        in_channels (int): number of channels of the input feature
+        num_anchors (int): number of anchors to be predicted
+        num_classes (int): number of classes to be predicted
+        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
+    """
+
+    _version = 2
+
+    def __init__(
+        self,
+        in_channels,
+        num_anchors,
+        num_classes,
+        prior_probability=0.01,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+    ):
+        super().__init__()
+
+        conv = []
+        for _ in range(4):
+            conv.append(misc_nn_ops.Conv2dNormActivation(in_channels, in_channels, norm_layer=norm_layer))
+        self.conv = nn.Sequential(*conv)
+
+        for layer in self.conv.modules():
+            if isinstance(layer, nn.Conv2d):
+                torch.nn.init.normal_(layer.weight, std=0.01)
+                if layer.bias is not None:
+                    torch.nn.init.constant_(layer.bias, 0)
+
+        self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
+        torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
+        torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability))
+
+        self.num_classes = num_classes
+        self.num_anchors = num_anchors
+
+        # This is to fix using det_utils.Matcher.BETWEEN_THRESHOLDS in TorchScript.
+        # TorchScript doesn't support class attributes.
+        # https://github.com/pytorch/vision/pull/1697#issuecomment-630255584
+        self.BETWEEN_THRESHOLDS = det_utils.Matcher.BETWEEN_THRESHOLDS
+
+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        version = local_metadata.get("version", None)
+
+        if version is None or version < 2:
+            _v1_to_v2_weights(state_dict, prefix)
+
+        super()._load_from_state_dict(
+            state_dict,
+            prefix,
+            local_metadata,
+            strict,
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )
+
+    def compute_loss(self, targets, head_outputs, matched_idxs):
+        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor
+        losses = []
+
+        cls_logits = head_outputs["cls_logits"]
+
+        for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(targets, cls_logits, matched_idxs):
+            # determine only the foreground
+            foreground_idxs_per_image = matched_idxs_per_image >= 0
+            num_foreground = foreground_idxs_per_image.sum()
+
+            # create the target classification
+            gt_classes_target = torch.zeros_like(cls_logits_per_image)
+            gt_classes_target[
+                foreground_idxs_per_image,
+                targets_per_image["labels"][matched_idxs_per_image[foreground_idxs_per_image]],
+            ] = 1.0
+
+            # find indices for which anchors should be ignored
+            valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS
+
+            # compute the classification loss
+            losses.append(
+                sigmoid_focal_loss(
+                    cls_logits_per_image[valid_idxs_per_image],
+                    gt_classes_target[valid_idxs_per_image],
+                    reduction="sum",
+                )
+                / max(1, num_foreground)
+            )
+
+        return _sum(losses) / len(targets)
+
+    def forward(self, x):
+        # type: (List[Tensor]) -> Tensor
+        all_cls_logits = []
+
+        for features in x:
+            cls_logits = self.conv(features)
+            cls_logits = self.cls_logits(cls_logits)
+
+            # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
+            N, _, H, W = cls_logits.shape
+            cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
+            cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
+            cls_logits = cls_logits.reshape(N, -1, self.num_classes)  # Size=(N, HWA, 4)
+
+            all_cls_logits.append(cls_logits)
+
+        return torch.cat(all_cls_logits, dim=1)
+
+
+class RetinaNetRegressionHead(nn.Module):
+    """
+    A regression head for use in RetinaNet.
+
+    Args:
+        in_channels (int): number of channels of the input feature
+        num_anchors (int): number of anchors to be predicted
+        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
+    """
+
+    _version = 2
+
+    __annotations__ = {
+        "box_coder": det_utils.BoxCoder,
+    }
+
+    def __init__(self, in_channels, num_anchors, norm_layer: Optional[Callable[..., nn.Module]] = None):
+        super().__init__()
+
+        conv = []
+        for _ in range(4):
+            conv.append(misc_nn_ops.Conv2dNormActivation(in_channels, in_channels, norm_layer=norm_layer))
+        self.conv = nn.Sequential(*conv)
+
+        self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
+        torch.nn.init.normal_(self.bbox_reg.weight, std=0.01)
+        torch.nn.init.zeros_(self.bbox_reg.bias)
+
+        for layer in self.conv.modules():
+            if isinstance(layer, nn.Conv2d):
+                torch.nn.init.normal_(layer.weight, std=0.01)
+                if layer.bias is not None:
+                    torch.nn.init.zeros_(layer.bias)
+
+        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
+        self._loss_type = "l1"
+
+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        version = local_metadata.get("version", None)
+
+        if version is None or version < 2:
+            _v1_to_v2_weights(state_dict, prefix)
+
+        super()._load_from_state_dict(
+            state_dict,
+            prefix,
+            local_metadata,
+            strict,
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )
+
+    def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
+        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor
+        losses = []
+
+        bbox_regression = head_outputs["bbox_regression"]
+
+        for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(
+            targets, bbox_regression, anchors, matched_idxs
+        ):
+            # determine only the foreground indices, ignore the rest
+            foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
+            num_foreground = foreground_idxs_per_image.numel()
+
+            # select only the foreground boxes
+            matched_gt_boxes_per_image = targets_per_image["boxes"][matched_idxs_per_image[foreground_idxs_per_image]]
+            bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
+            anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
+
+            # compute the loss
+            losses.append(
+                _box_loss(
+                    self._loss_type,
+                    self.box_coder,
+                    anchors_per_image,
+                    matched_gt_boxes_per_image,
+                    bbox_regression_per_image,
+                )
+                / max(1, num_foreground)
+            )
+
+        return _sum(losses) / max(1, len(targets))
+
+    def forward(self, x):
+        # type: (List[Tensor]) -> Tensor
+        all_bbox_regression = []
+
+        for features in x:
+            bbox_regression = self.conv(features)
+            bbox_regression = self.bbox_reg(bbox_regression)
+
+            # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
+            N, _, H, W = bbox_regression.shape
+            bbox_regression = bbox_regression.view(N, -1, 4, H, W)
+            bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
+            bbox_regression = bbox_regression.reshape(N, -1, 4)  # Size=(N, HWA, 4)
+
+            all_bbox_regression.append(bbox_regression)
+
+        return torch.cat(all_bbox_regression, dim=1)
+
+
+class RetinaNet(nn.Module):
+    """
+    Implements RetinaNet.
+
+    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
+    image, and should be in 0-1 range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (list of dictionary),
+    containing:
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (Int64Tensor[N]): the class label for each ground-truth box
+
+    The model returns a Dict[Tensor] during training, containing the classification and regression
+    losses.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
+    follows:
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (Int64Tensor[N]): the predicted labels for each image
+        - scores (Tensor[N]): the scores for each prediction
+
+    Args:
+        backbone (nn.Module): the network used to compute the features for the model.
+            It should contain an out_channels attribute, which indicates the number of output
+            channels that each feature map has (and it should be the same for all feature maps).
+            The backbone should return a single Tensor or an OrderedDict[Tensor].
+        num_classes (int): number of output classes of the model (including the background).
+        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
+        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
+        image_mean (Tuple[float, float, float]): mean values used for input normalization.
+            They are generally the mean values of the dataset on which the backbone has been trained
+            on
+        image_std (Tuple[float, float, float]): std values used for input normalization.
+            They are generally the std values of the dataset on which the backbone has been trained on
+        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
+            maps.
+        head (nn.Module): Module run on top of the feature pyramid.
+            Defaults to a module containing a classification and regression module.
+        score_thresh (float): Score threshold used for postprocessing the detections.
+        nms_thresh (float): NMS threshold used for postprocessing the detections.
+        detections_per_img (int): Number of best detections to keep after NMS.
+        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
+            considered as positive during training.
+        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
+            considered as negative during training.
+        topk_candidates (int): Number of best detections to keep before NMS.
+
+    Example:
+
+        >>> import torch
+        >>> import torchvision
+        >>> from torchvision.models.detection import RetinaNet
+        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
+        >>> # load a pre-trained model for classification and return
+        >>> # only the features
+        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
+        >>> # RetinaNet needs to know the number of
+        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
+        >>> # so we need to add it here
+        >>> backbone.out_channels = 1280
+        >>>
+        >>> # let's make the network generate 5 x 3 anchors per spatial
+        >>> # location, with 5 different sizes and 3 different aspect
+        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
+        >>> # map could potentially have different sizes and
+        >>> # aspect ratios
+        >>> anchor_generator = AnchorGenerator(
+        >>>     sizes=((32, 64, 128, 256, 512),),
+        >>>     aspect_ratios=((0.5, 1.0, 2.0),)
+        >>> )
+        >>>
+        >>> # put the pieces together inside a RetinaNet model
+        >>> model = RetinaNet(backbone,
+        >>>                   num_classes=2,
+        >>>                   anchor_generator=anchor_generator)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+    """
+
+    __annotations__ = {
+        "box_coder": det_utils.BoxCoder,
+        "proposal_matcher": det_utils.Matcher,
+    }
+
+    def __init__(
+        self,
+        backbone,
+        num_classes,
+        # transform parameters
+        min_size=800,
+        max_size=1333,
+        image_mean=None,
+        image_std=None,
+        # Anchor parameters
+        anchor_generator=None,
+        head=None,
+        proposal_matcher=None,
+        score_thresh=0.05,
+        nms_thresh=0.5,
+        detections_per_img=300,
+        fg_iou_thresh=0.5,
+        bg_iou_thresh=0.4,
+        topk_candidates=1000,
+        **kwargs,
+    ):
+        super().__init__()
+        _log_api_usage_once(self)
+
+        if not hasattr(backbone, "out_channels"):
+            raise ValueError(
+                "backbone should contain an attribute out_channels "
+                "specifying the number of output channels (assumed to be the "
+                "same for all the levels)"
+            )
+        self.backbone = backbone
+
+        if not isinstance(anchor_generator, (AnchorGenerator, type(None))):
+            raise TypeError(
+                f"anchor_generator should be of type AnchorGenerator or None instead of {type(anchor_generator)}"
+            )
+
+        if anchor_generator is None:
+            anchor_generator = _default_anchorgen()
+        self.anchor_generator = anchor_generator
+
+        if head is None:
+            head = RetinaNetHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes)
+        self.head = head
+
+        if proposal_matcher is None:
+            proposal_matcher = det_utils.Matcher(
+                fg_iou_thresh,
+                bg_iou_thresh,
+                allow_low_quality_matches=True,
+            )
+        self.proposal_matcher = proposal_matcher
+
+        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
+
+        if image_mean is None:
+            image_mean = [0.485, 0.456, 0.406]
+        if image_std is None:
+            image_std = [0.229, 0.224, 0.225]
+        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)
+
+        self.score_thresh = score_thresh
+        self.nms_thresh = nms_thresh
+        self.detections_per_img = detections_per_img
+        self.topk_candidates = topk_candidates
+
+        # used only on torchscript mode
+        self._has_warned = False
+
+    @torch.jit.unused
+    def eager_outputs(self, losses, detections):
+        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
+        if self.training:
+            return losses
+
+        return detections
+
+    def compute_loss(self, targets, head_outputs, anchors):
+        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Dict[str, Tensor]
+        matched_idxs = []
+        for anchors_per_image, targets_per_image in zip(anchors, targets):
+            if targets_per_image["boxes"].numel() == 0:
+                matched_idxs.append(
+                    torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device)
+                )
+                continue
+
+            match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image)
+            matched_idxs.append(self.proposal_matcher(match_quality_matrix))
+
+        return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)
+
+    def postprocess_detections(self, head_outputs, anchors, image_shapes):
+        # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
+        class_logits = head_outputs["cls_logits"]
+        box_regression = head_outputs["bbox_regression"]
+
+        num_images = len(image_shapes)
+
+        detections: List[Dict[str, Tensor]] = []
+
+        for index in range(num_images):
+            box_regression_per_image = [br[index] for br in box_regression]
+            logits_per_image = [cl[index] for cl in class_logits]
+            anchors_per_image, image_shape = anchors[index], image_shapes[index]
+
+            image_boxes = []
+            image_scores = []
+            image_labels = []
+
+            for box_regression_per_level, logits_per_level, anchors_per_level in zip(
+                box_regression_per_image, logits_per_image, anchors_per_image
+            ):
+                num_classes = logits_per_level.shape[-1]
+
+                # remove low scoring boxes
+                scores_per_level = torch.sigmoid(logits_per_level).flatten()
+                keep_idxs = scores_per_level > self.score_thresh
+                scores_per_level = scores_per_level[keep_idxs]
+                topk_idxs = torch.where(keep_idxs)[0]
+
+                # keep only topk scoring predictions
+                num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0)
+                scores_per_level, idxs = scores_per_level.topk(num_topk)
+                topk_idxs = topk_idxs[idxs]
+
+                anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
+                labels_per_level = topk_idxs % num_classes
+
+                boxes_per_level = self.box_coder.decode_single(
+                    box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
+                )
+                boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)
+
+                image_boxes.append(boxes_per_level)
+                image_scores.append(scores_per_level)
+                image_labels.append(labels_per_level)
+
+            image_boxes = torch.cat(image_boxes, dim=0)
+            image_scores = torch.cat(image_scores, dim=0)
+            image_labels = torch.cat(image_labels, dim=0)
+
+            # non-maximum suppression
+            keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
+            keep = keep[: self.detections_per_img]
+
+            detections.append(
+                {
+                    "boxes": image_boxes[keep],
+                    "scores": image_scores[keep],
+                    "labels": image_labels[keep],
+                }
+            )
+
+        return detections
+
+    def forward(self, images, targets=None):
+        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
+        """
+        Args:
+            images (list[Tensor]): images to be processed
+            targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
+
+        Returns:
+            result (list[BoxList] or dict[Tensor]): the output from the model.
+                During training, it returns a dict[Tensor] which contains the losses.
+                During testing, it returns list[BoxList] contains additional fields
+                like `scores`, `labels` and `mask` (for Mask R-CNN models).
+
+        """
+        if self.training:
+            if targets is None:
+                torch._assert(False, "targets should not be none when in training mode")
+            else:
+                for target in targets:
+                    boxes = target["boxes"]
+                    torch._assert(isinstance(boxes, torch.Tensor), "Expected target boxes to be of type Tensor.")
+                    torch._assert(
+                        len(boxes.shape) == 2 and boxes.shape[-1] == 4,
+                        "Expected target boxes to be a tensor of shape [N, 4].",
+                    )
+
+        # get the original image sizes
+        original_image_sizes: List[Tuple[int, int]] = []
+        for img in images:
+            val = img.shape[-2:]
+            torch._assert(
+                len(val) == 2,
+                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
+            )
+            original_image_sizes.append((val[0], val[1]))
+
+        # transform the input
+        images, targets = self.transform(images, targets)
+
+        # Check for degenerate boxes
+        # TODO: Move this to a function
+        if targets is not None:
+            for target_idx, target in enumerate(targets):
+                boxes = target["boxes"]
+                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
+                if degenerate_boxes.any():
+                    # print the first degenerate box
+                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
+                    degen_bb: List[float] = boxes[bb_idx].tolist()
+                    torch._assert(
+                        False,
+                        "All bounding boxes should have positive height and width."
+                        f" Found invalid box {degen_bb} for target at index {target_idx}.",
+                    )
+
+        # get the features from the backbone
+        features = self.backbone(images.tensors)
+        if isinstance(features, torch.Tensor):
+            features = OrderedDict([("0", features)])
+
+        # TODO: Do we want a list or a dict?
+        features = list(features.values())
+
+        # compute the retinanet heads outputs using the features
+        head_outputs = self.head(features)
+
+        # create the set of anchors
+        anchors = self.anchor_generator(images, features)
+
+        losses = {}
+        detections: List[Dict[str, Tensor]] = []
+        if self.training:
+            if targets is None:
+                torch._assert(False, "targets should not be none when in training mode")
+            else:
+                # compute the losses
+                losses = self.compute_loss(targets, head_outputs, anchors)
+        else:
+            # recover level sizes
+            num_anchors_per_level = [x.size(2) * x.size(3) for x in features]
+            HW = 0
+            for v in num_anchors_per_level:
+                HW += v
+            HWA = head_outputs["cls_logits"].size(1)
+            A = HWA // HW
+            num_anchors_per_level = [hw * A for hw in num_anchors_per_level]
+
+            # split outputs per level
+            split_head_outputs: Dict[str, List[Tensor]] = {}
+            for k in head_outputs:
+                split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1))
+            split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]
+
+            # compute the detections
+            detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)
+            detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
+
+        if torch.jit.is_scripting():
+            if not self._has_warned:
+                warnings.warn("RetinaNet always returns a (Losses, Detections) tuple in scripting")
+                self._has_warned = True
+            return losses, detections
+        return self.eager_outputs(losses, detections)
+
+
+_COMMON_META = {
+    "categories": _COCO_CATEGORIES,
+    "min_size": (1, 1),
+}
+
+
+class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 34014999,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 36.4,
+                }
+            },
+            "_ops": 151.54,
+            "_file_size": 130.267,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+class RetinaNet_ResNet50_FPN_V2_Weights(WeightsEnum):
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/retinanet_resnet50_fpn_v2_coco-5905b1c5.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 38198935,
+            "recipe": "https://github.com/pytorch/vision/pull/5756",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 41.5,
+                }
+            },
+            "_ops": 152.238,
+            "_file_size": 146.037,
+            "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
+)
+def retinanet_resnet50_fpn(
+    *,
+    weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> RetinaNet:
+    """
+    Constructs a RetinaNet model with a ResNet-50-FPN backbone.
+
+    .. betastatus:: detection module
+
+    Reference: `Focal Loss for Dense Object Detection <https://arxiv.org/abs/1708.02002>`_.
+
+    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
+    image, and should be in ``0-1`` range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (list of dictionary),
+    containing:
+
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
+
+    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
+    losses.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
+    follows, where ``N`` is the number of detections:
+
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the predicted labels for each detection
+        - scores (``Tensor[N]``): the scores of each detection
+
+    For more details on the output, you may refer to :ref:`instance_seg_output`.
+
+    Example::
+
+        >>> model = torchvision.models.detection.retinanet_resnet50_fpn(weights=RetinaNet_ResNet50_FPN_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`
+            below for more details, and possible values. By default, no
+            pre-trained weights are used.
+        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
+            the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
+            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
+            passed (the default) this value is set to 3.
+        **kwargs: parameters passed to the ``torchvision.models.detection.RetinaNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights
+        :members:
+    """
+    weights = RetinaNet_ResNet50_FPN_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+    # skip P2 because it generates too many anchors (according to their paper)
+    backbone = _resnet_fpn_extractor(
+        backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
+    )
+    model = RetinaNet(backbone, num_classes, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+        if weights == RetinaNet_ResNet50_FPN_Weights.COCO_V1:
+            overwrite_eps(model, 0.0)
+
+    return model
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
+)
+def retinanet_resnet50_fpn_v2(
+    *,
+    weights: Optional[RetinaNet_ResNet50_FPN_V2_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[ResNet50_Weights] = None,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> RetinaNet:
+    """
+    Constructs an improved RetinaNet model with a ResNet-50-FPN backbone.
+
+    .. betastatus:: detection module
+
+    Reference: `Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection
+    <https://arxiv.org/abs/1912.02424>`_.
+
+    :func:`~torchvision.models.detection.retinanet_resnet50_fpn` for more details.
+
+    Args:
+        weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`
+            below for more details, and possible values. By default, no
+            pre-trained weights are used.
+        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
+            the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
+            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
+            passed (the default) this value is set to 3.
+        **kwargs: parameters passed to the ``torchvision.models.detection.RetinaNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights
+        :members:
+    """
+    weights = RetinaNet_ResNet50_FPN_V2_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+
+    backbone = resnet50(weights=weights_backbone, progress=progress)
+    backbone = _resnet_fpn_extractor(
+        backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(2048, 256)
+    )
+    anchor_generator = _default_anchorgen()
+    head = RetinaNetHead(
+        backbone.out_channels,
+        anchor_generator.num_anchors_per_location()[0],
+        num_classes,
+        norm_layer=partial(nn.GroupNorm, 32),
+    )
+    head.regression_head._loss_type = "giou"
+    model = RetinaNet(backbone, num_classes, anchor_generator=anchor_generator, head=head, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model

+ 876 - 0
libs/vision_libs/models/detection/roi_heads.py

@@ -0,0 +1,876 @@
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import nn, Tensor
+from torchvision.ops import boxes as box_ops, roi_align
+
+from . import _utils as det_utils
+
+
+def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
+    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
+    """
+    Computes the loss for Faster R-CNN.
+
+    Args:
+        class_logits (Tensor)
+        box_regression (Tensor)
+        labels (list[BoxList])
+        regression_targets (Tensor)
+
+    Returns:
+        classification_loss (Tensor)
+        box_loss (Tensor)
+    """
+
+    labels = torch.cat(labels, dim=0)
+    regression_targets = torch.cat(regression_targets, dim=0)
+
+    classification_loss = F.cross_entropy(class_logits, labels)
+
+    # get indices that correspond to the regression targets for
+    # the corresponding ground truth labels, to be used with
+    # advanced indexing
+    sampled_pos_inds_subset = torch.where(labels > 0)[0]
+    labels_pos = labels[sampled_pos_inds_subset]
+    N, num_classes = class_logits.shape
+    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
+
+    box_loss = F.smooth_l1_loss(
+        box_regression[sampled_pos_inds_subset, labels_pos],
+        regression_targets[sampled_pos_inds_subset],
+        beta=1 / 9,
+        reduction="sum",
+    )
+    box_loss = box_loss / labels.numel()
+
+    return classification_loss, box_loss
+
+
+def maskrcnn_inference(x, labels):
+    # type: (Tensor, List[Tensor]) -> List[Tensor]
+    """
+    From the results of the CNN, post process the masks
+    by taking the mask corresponding to the class with max
+    probability (which are of fixed size and directly output
+    by the CNN) and return the masks in the mask field of the BoxList.
+
+    Args:
+        x (Tensor): the mask logits
+        labels (list[BoxList]): bounding boxes that are used as
+            reference, one for ech image
+
+    Returns:
+        results (list[BoxList]): one BoxList for each image, containing
+            the extra field mask
+    """
+    mask_prob = x.sigmoid()
+
+    # select masks corresponding to the predicted classes
+    num_masks = x.shape[0]
+    boxes_per_image = [label.shape[0] for label in labels]
+    labels = torch.cat(labels)
+    index = torch.arange(num_masks, device=labels.device)
+    mask_prob = mask_prob[index, labels][:, None]
+    mask_prob = mask_prob.split(boxes_per_image, dim=0)
+
+    return mask_prob
+
+
+def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
+    # type: (Tensor, Tensor, Tensor, int) -> Tensor
+    """
+    Given segmentation masks and the bounding boxes corresponding
+    to the location of the masks in the image, this function
+    crops and resizes the masks in the position defined by the
+    boxes. This prepares the masks for them to be fed to the
+    loss computation as the targets.
+    """
+    matched_idxs = matched_idxs.to(boxes)
+    rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
+    gt_masks = gt_masks[:, None].to(rois)
+    return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
+
+
+def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    """
+    Args:
+        proposals (list[BoxList])
+        mask_logits (Tensor)
+        targets (list[BoxList])
+
+    Return:
+        mask_loss (Tensor): scalar tensor containing the loss
+    """
+
+    discretization_size = mask_logits.shape[-1]
+    labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
+    mask_targets = [
+        project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
+    ]
+
+    labels = torch.cat(labels, dim=0)
+    mask_targets = torch.cat(mask_targets, dim=0)
+
+    # torch.mean (in binary_cross_entropy_with_logits) doesn't
+    # accept empty tensors, so handle it separately
+    if mask_targets.numel() == 0:
+        return mask_logits.sum() * 0
+
+    mask_loss = F.binary_cross_entropy_with_logits(
+        mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
+    )
+    return mask_loss
+
+
+def keypoints_to_heatmap(keypoints, rois, heatmap_size):
+    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
+    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
+
+    offset_x = offset_x[:, None]
+    offset_y = offset_y[:, None]
+    scale_x = scale_x[:, None]
+    scale_y = scale_y[:, None]
+
+    x = keypoints[..., 0]
+    y = keypoints[..., 1]
+
+    x_boundary_inds = x == rois[:, 2][:, None]
+    y_boundary_inds = y == rois[:, 3][:, None]
+
+    x = (x - offset_x) * scale_x
+    x = x.floor().long()
+    y = (y - offset_y) * scale_y
+    y = y.floor().long()
+
+    x[x_boundary_inds] = heatmap_size - 1
+    y[y_boundary_inds] = heatmap_size - 1
+
+    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
+    vis = keypoints[..., 2] > 0
+    valid = (valid_loc & vis).long()
+
+    lin_ind = y * heatmap_size + x
+    heatmaps = lin_ind * valid
+
+    return heatmaps, valid
+
+
+def _onnx_heatmaps_to_keypoints(
+    maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
+):
+    num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
+
+    width_correction = widths_i / roi_map_width
+    height_correction = heights_i / roi_map_height
+
+    roi_map = F.interpolate(
+        maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
+    )[:, 0]
+
+    w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
+    pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+    x_int = pos % w
+    y_int = (pos - x_int) // w
+
+    x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
+        dtype=torch.float32
+    )
+    y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
+        dtype=torch.float32
+    )
+
+    xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
+    xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
+    xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
+    xy_preds_i = torch.stack(
+        [
+            xy_preds_i_0.to(dtype=torch.float32),
+            xy_preds_i_1.to(dtype=torch.float32),
+            xy_preds_i_2.to(dtype=torch.float32),
+        ],
+        0,
+    )
+
+    # TODO: simplify when indexing without rank will be supported by ONNX
+    base = num_keypoints * num_keypoints + num_keypoints + 1
+    ind = torch.arange(num_keypoints)
+    ind = ind.to(dtype=torch.int64) * base
+    end_scores_i = (
+        roi_map.index_select(1, y_int.to(dtype=torch.int64))
+        .index_select(2, x_int.to(dtype=torch.int64))
+        .view(-1)
+        .index_select(0, ind.to(dtype=torch.int64))
+    )
+
+    return xy_preds_i, end_scores_i
+
+
+@torch.jit._script_if_tracing
+def _onnx_heatmaps_to_keypoints_loop(
+    maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
+):
+    xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
+    end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
+
+    for i in range(int(rois.size(0))):
+        xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
+            maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
+        )
+        xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
+        end_scores = torch.cat(
+            (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
+        )
+    return xy_preds, end_scores
+
+
+def heatmaps_to_keypoints(maps, rois):
+    """Extract predicted keypoint locations from heatmaps. Output has shape
+    (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
+    for each keypoint.
+    """
+    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
+    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
+    # consistency with keypoints_to_heatmap_labels by using the conversion from
+    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
+    # continuous coordinate.
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+
+    widths = rois[:, 2] - rois[:, 0]
+    heights = rois[:, 3] - rois[:, 1]
+    widths = widths.clamp(min=1)
+    heights = heights.clamp(min=1)
+    widths_ceil = widths.ceil()
+    heights_ceil = heights.ceil()
+
+    num_keypoints = maps.shape[1]
+
+    if torchvision._is_tracing():
+        xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
+            maps,
+            rois,
+            widths_ceil,
+            heights_ceil,
+            widths,
+            heights,
+            offset_x,
+            offset_y,
+            torch.scalar_tensor(num_keypoints, dtype=torch.int64),
+        )
+        return xy_preds.permute(0, 2, 1), end_scores
+
+    xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
+    end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
+    for i in range(len(rois)):
+        roi_map_width = int(widths_ceil[i].item())
+        roi_map_height = int(heights_ceil[i].item())
+        width_correction = widths[i] / roi_map_width
+        height_correction = heights[i] / roi_map_height
+        roi_map = F.interpolate(
+            maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
+        )[:, 0]
+        # roi_map_probs = scores_to_probs(roi_map.copy())
+        w = roi_map.shape[2]
+        pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+        x_int = pos % w
+        y_int = torch.div(pos - x_int, w, rounding_mode="floor")
+        # assert (roi_map_probs[k, y_int, x_int] ==
+        #         roi_map_probs[k, :, :].max())
+        x = (x_int.float() + 0.5) * width_correction
+        y = (y_int.float() + 0.5) * height_correction
+        xy_preds[i, 0, :] = x + offset_x[i]
+        xy_preds[i, 1, :] = y + offset_y[i]
+        xy_preds[i, 2, :] = 1
+        end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
+
+    return xy_preds.permute(0, 2, 1), end_scores
+
+
+def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    N, K, H, W = keypoint_logits.shape
+    if H != W:
+        raise ValueError(
+            f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
+        )
+    discretization_size = H
+    heatmaps = []
+    valid = []
+    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
+        kp = gt_kp_in_image[midx]
+        heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
+        heatmaps.append(heatmaps_per_image.view(-1))
+        valid.append(valid_per_image.view(-1))
+
+    keypoint_targets = torch.cat(heatmaps, dim=0)
+    valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
+    valid = torch.where(valid)[0]
+
+    # torch.mean (in binary_cross_entropy_with_logits) doesn't
+    # accept empty tensors, so handle it sepaartely
+    if keypoint_targets.numel() == 0 or len(valid) == 0:
+        return keypoint_logits.sum() * 0
+
+    keypoint_logits = keypoint_logits.view(N * K, H * W)
+
+    keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
+    return keypoint_loss
+
+
+def keypointrcnn_inference(x, boxes):
+    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+    kp_probs = []
+    kp_scores = []
+
+    boxes_per_image = [box.size(0) for box in boxes]
+    x2 = x.split(boxes_per_image, dim=0)
+
+    for xx, bb in zip(x2, boxes):
+        kp_prob, scores = heatmaps_to_keypoints(xx, bb)
+        kp_probs.append(kp_prob)
+        kp_scores.append(scores)
+
+    return kp_probs, kp_scores
+
+
+def _onnx_expand_boxes(boxes, scale):
+    # type: (Tensor, float) -> Tensor
+    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+    w_half = w_half.to(dtype=torch.float32) * scale
+    h_half = h_half.to(dtype=torch.float32) * scale
+
+    boxes_exp0 = x_c - w_half
+    boxes_exp1 = y_c - h_half
+    boxes_exp2 = x_c + w_half
+    boxes_exp3 = y_c + h_half
+    boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
+    return boxes_exp
+
+
+# the next two functions should be merged inside Masker
+# but are kept here for the moment while we need them
+# temporarily for paste_mask_in_image
+def expand_boxes(boxes, scale):
+    # type: (Tensor, float) -> Tensor
+    if torchvision._is_tracing():
+        return _onnx_expand_boxes(boxes, scale)
+    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+    w_half *= scale
+    h_half *= scale
+
+    boxes_exp = torch.zeros_like(boxes)
+    boxes_exp[:, 0] = x_c - w_half
+    boxes_exp[:, 2] = x_c + w_half
+    boxes_exp[:, 1] = y_c - h_half
+    boxes_exp[:, 3] = y_c + h_half
+    return boxes_exp
+
+
+@torch.jit.unused
+def expand_masks_tracing_scale(M, padding):
+    # type: (int, int) -> float
+    return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
+
+
+def expand_masks(mask, padding):
+    # type: (Tensor, int) -> Tuple[Tensor, float]
+    M = mask.shape[-1]
+    if torch._C._get_tracing_state():  # could not import is_tracing(), not sure why
+        scale = expand_masks_tracing_scale(M, padding)
+    else:
+        scale = float(M + 2 * padding) / M
+    padded_mask = F.pad(mask, (padding,) * 4)
+    return padded_mask, scale
+
+
+def paste_mask_in_image(mask, box, im_h, im_w):
+    # type: (Tensor, Tensor, int, int) -> Tensor
+    TO_REMOVE = 1
+    w = int(box[2] - box[0] + TO_REMOVE)
+    h = int(box[3] - box[1] + TO_REMOVE)
+    w = max(w, 1)
+    h = max(h, 1)
+
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, -1, -1))
+
+    # Resize mask
+    mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
+    mask = mask[0][0]
+
+    im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
+    x_0 = max(box[0], 0)
+    x_1 = min(box[2] + 1, im_w)
+    y_0 = max(box[1], 0)
+    y_1 = min(box[3] + 1, im_h)
+
+    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
+    return im_mask
+
+
+def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
+    one = torch.ones(1, dtype=torch.int64)
+    zero = torch.zeros(1, dtype=torch.int64)
+
+    w = box[2] - box[0] + one
+    h = box[3] - box[1] + one
+    w = torch.max(torch.cat((w, one)))
+    h = torch.max(torch.cat((h, one)))
+
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
+
+    # Resize mask
+    mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
+    mask = mask[0][0]
+
+    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
+    x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
+    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
+    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
+
+    unpaded_im_mask = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
+
+    # TODO : replace below with a dynamic padding when support is added in ONNX
+
+    # pad y
+    zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
+    zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
+    concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
+    # pad x
+    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
+    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
+    im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
+    return im_mask
+
+
+@torch.jit._script_if_tracing
+def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
+    res_append = torch.zeros(0, im_h, im_w)
+    for i in range(masks.size(0)):
+        mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
+        mask_res = mask_res.unsqueeze(0)
+        res_append = torch.cat((res_append, mask_res))
+    return res_append
+
+
+def paste_masks_in_image(masks, boxes, img_shape, padding=1):
+    # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
+    masks, scale = expand_masks(masks, padding=padding)
+    boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
+    im_h, im_w = img_shape
+
+    if torchvision._is_tracing():
+        return _onnx_paste_masks_in_image_loop(
+            masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
+        )[:, None]
+    res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
+    if len(res) > 0:
+        ret = torch.stack(res, dim=0)[:, None]
+    else:
+        ret = masks.new_empty((0, 1, im_h, im_w))
+    return ret
+
+
+class RoIHeads(nn.Module):
+    __annotations__ = {
+        "box_coder": det_utils.BoxCoder,
+        "proposal_matcher": det_utils.Matcher,
+        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
+    }
+
+    def __init__(
+        self,
+        box_roi_pool,
+        box_head,
+        box_predictor,
+        # Faster R-CNN training
+        fg_iou_thresh,
+        bg_iou_thresh,
+        batch_size_per_image,
+        positive_fraction,
+        bbox_reg_weights,
+        # Faster R-CNN inference
+        score_thresh,
+        nms_thresh,
+        detections_per_img,
+        # Mask
+        mask_roi_pool=None,
+        mask_head=None,
+        mask_predictor=None,
+        keypoint_roi_pool=None,
+        keypoint_head=None,
+        keypoint_predictor=None,
+    ):
+        super().__init__()
+
+        self.box_similarity = box_ops.box_iou
+        # assign ground-truth boxes for each proposal
+        self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)
+
+        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
+
+        if bbox_reg_weights is None:
+            bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
+        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)
+
+        self.box_roi_pool = box_roi_pool
+        self.box_head = box_head
+        self.box_predictor = box_predictor
+
+        self.score_thresh = score_thresh
+        self.nms_thresh = nms_thresh
+        self.detections_per_img = detections_per_img
+
+        self.mask_roi_pool = mask_roi_pool
+        self.mask_head = mask_head
+        self.mask_predictor = mask_predictor
+
+        self.keypoint_roi_pool = keypoint_roi_pool
+        self.keypoint_head = keypoint_head
+        self.keypoint_predictor = keypoint_predictor
+
+    def has_mask(self):
+        if self.mask_roi_pool is None:
+            return False
+        if self.mask_head is None:
+            return False
+        if self.mask_predictor is None:
+            return False
+        return True
+
+    def has_keypoint(self):
+        if self.keypoint_roi_pool is None:
+            return False
+        if self.keypoint_head is None:
+            return False
+        if self.keypoint_predictor is None:
+            return False
+        return True
+
+    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
+        # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+        matched_idxs = []
+        labels = []
+        for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
+
+            if gt_boxes_in_image.numel() == 0:
+                # Background image
+                device = proposals_in_image.device
+                clamped_matched_idxs_in_image = torch.zeros(
+                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
+                )
+                labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
+            else:
+                #  set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
+                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
+                matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
+
+                clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
+
+                labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
+                labels_in_image = labels_in_image.to(dtype=torch.int64)
+
+                # Label background (below the low threshold)
+                bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
+                labels_in_image[bg_inds] = 0
+
+                # Label ignore proposals (between low and high thresholds)
+                ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
+                labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler
+
+            matched_idxs.append(clamped_matched_idxs_in_image)
+            labels.append(labels_in_image)
+        return matched_idxs, labels
+
+    def subsample(self, labels):
+        # type: (List[Tensor]) -> List[Tensor]
+        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
+        sampled_inds = []
+        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
+            img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
+            sampled_inds.append(img_sampled_inds)
+        return sampled_inds
+
+    def add_gt_proposals(self, proposals, gt_boxes):
+        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
+        proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
+
+        return proposals
+
+    def check_targets(self, targets):
+        # type: (Optional[List[Dict[str, Tensor]]]) -> None
+        if targets is None:
+            raise ValueError("targets should not be None")
+        if not all(["boxes" in t for t in targets]):
+            raise ValueError("Every element of targets should have a boxes key")
+        if not all(["labels" in t for t in targets]):
+            raise ValueError("Every element of targets should have a labels key")
+        if self.has_mask():
+            if not all(["masks" in t for t in targets]):
+                raise ValueError("Every element of targets should have a masks key")
+
+    def select_training_samples(
+        self,
+        proposals,  # type: List[Tensor]
+        targets,  # type: Optional[List[Dict[str, Tensor]]]
+    ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
+        self.check_targets(targets)
+        if targets is None:
+            raise ValueError("targets should not be None")
+        dtype = proposals[0].dtype
+        device = proposals[0].device
+
+        gt_boxes = [t["boxes"].to(dtype) for t in targets]
+        gt_labels = [t["labels"] for t in targets]
+
+        # append ground-truth bboxes to propos
+        proposals = self.add_gt_proposals(proposals, gt_boxes)
+
+        # get matching gt indices for each proposal
+        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
+        # sample a fixed proportion of positive-negative proposals
+        sampled_inds = self.subsample(labels)
+        matched_gt_boxes = []
+        num_images = len(proposals)
+        for img_id in range(num_images):
+            img_sampled_inds = sampled_inds[img_id]
+            proposals[img_id] = proposals[img_id][img_sampled_inds]
+            labels[img_id] = labels[img_id][img_sampled_inds]
+            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
+
+            gt_boxes_in_image = gt_boxes[img_id]
+            if gt_boxes_in_image.numel() == 0:
+                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
+            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
+
+        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
+        return proposals, matched_idxs, labels, regression_targets
+
+    def postprocess_detections(
+        self,
+        class_logits,  # type: Tensor
+        box_regression,  # type: Tensor
+        proposals,  # type: List[Tensor]
+        image_shapes,  # type: List[Tuple[int, int]]
+    ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
+        device = class_logits.device
+        num_classes = class_logits.shape[-1]
+
+        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
+        pred_boxes = self.box_coder.decode(box_regression, proposals)
+
+        pred_scores = F.softmax(class_logits, -1)
+
+        pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
+        pred_scores_list = pred_scores.split(boxes_per_image, 0)
+
+        all_boxes = []
+        all_scores = []
+        all_labels = []
+        for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
+            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
+
+            # create labels for each prediction
+            labels = torch.arange(num_classes, device=device)
+            labels = labels.view(1, -1).expand_as(scores)
+
+            # remove predictions with the background label
+            boxes = boxes[:, 1:]
+            scores = scores[:, 1:]
+            labels = labels[:, 1:]
+
+            # batch everything, by making every class prediction be a separate instance
+            boxes = boxes.reshape(-1, 4)
+            scores = scores.reshape(-1)
+            labels = labels.reshape(-1)
+
+            # remove low scoring boxes
+            inds = torch.where(scores > self.score_thresh)[0]
+            boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
+
+            # remove empty boxes
+            keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
+            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+            # non-maximum suppression, independently done per class
+            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
+            # keep only topk scoring predictions
+            keep = keep[: self.detections_per_img]
+            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+            all_boxes.append(boxes)
+            all_scores.append(scores)
+            all_labels.append(labels)
+
+        return all_boxes, all_scores, all_labels
+
+    def forward(
+        self,
+        features,  # type: Dict[str, Tensor]
+        proposals,  # type: List[Tensor]
+        image_shapes,  # type: List[Tuple[int, int]]
+        targets=None,  # type: Optional[List[Dict[str, Tensor]]]
+    ):
+        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
+        """
+        Args:
+            features (List[Tensor])
+            proposals (List[Tensor[N, 4]])
+            image_shapes (List[Tuple[H, W]])
+            targets (List[Dict])
+        """
+        if targets is not None:
+            for t in targets:
+                # TODO: https://github.com/pytorch/pytorch/issues/26731
+                floating_point_types = (torch.float, torch.double, torch.half)
+                if not t["boxes"].dtype in floating_point_types:
+                    raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
+                if not t["labels"].dtype == torch.int64:
+                    raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")
+                if self.has_keypoint():
+                    if not t["keypoints"].dtype == torch.float32:
+                        raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}")
+
+        if self.training:
+            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
+        else:
+            labels = None
+            regression_targets = None
+            matched_idxs = None
+
+        box_features = self.box_roi_pool(features, proposals, image_shapes)
+        box_features = self.box_head(box_features)
+        class_logits, box_regression = self.box_predictor(box_features)
+
+        result: List[Dict[str, torch.Tensor]] = []
+        losses = {}
+        if self.training:
+            if labels is None:
+                raise ValueError("labels cannot be None")
+            if regression_targets is None:
+                raise ValueError("regression_targets cannot be None")
+            loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
+            losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
+        else:
+            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
+            num_images = len(boxes)
+            for i in range(num_images):
+                result.append(
+                    {
+                        "boxes": boxes[i],
+                        "labels": labels[i],
+                        "scores": scores[i],
+                    }
+                )
+
+        if self.has_mask():
+            mask_proposals = [p["boxes"] for p in result]
+            if self.training:
+                if matched_idxs is None:
+                    raise ValueError("if in training, matched_idxs should not be None")
+
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                mask_proposals = []
+                pos_matched_idxs = []
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    mask_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
+
+            if self.mask_roi_pool is not None:
+                mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
+                mask_features = self.mask_head(mask_features)
+                mask_logits = self.mask_predictor(mask_features)
+            else:
+                raise Exception("Expected mask_roi_pool to be not None")
+
+            loss_mask = {}
+            if self.training:
+                if targets is None or pos_matched_idxs is None or mask_logits is None:
+                    raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
+
+                gt_masks = [t["masks"] for t in targets]
+                gt_labels = [t["labels"] for t in targets]
+                rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
+                loss_mask = {"loss_mask": rcnn_loss_mask}
+            else:
+                labels = [r["labels"] for r in result]
+                masks_probs = maskrcnn_inference(mask_logits, labels)
+                for mask_prob, r in zip(masks_probs, result):
+                    r["masks"] = mask_prob
+
+            losses.update(loss_mask)
+
+        # keep none checks in if conditional so torchscript will conditionally
+        # compile each branch
+        if (
+            self.keypoint_roi_pool is not None
+            and self.keypoint_head is not None
+            and self.keypoint_predictor is not None
+        ):
+            keypoint_proposals = [p["boxes"] for p in result]
+            if self.training:
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                keypoint_proposals = []
+                pos_matched_idxs = []
+                if matched_idxs is None:
+                    raise ValueError("if in trainning, matched_idxs should not be None")
+
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    keypoint_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
+
+            keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
+            keypoint_features = self.keypoint_head(keypoint_features)
+            keypoint_logits = self.keypoint_predictor(keypoint_features)
+
+            loss_keypoint = {}
+            if self.training:
+                if targets is None or pos_matched_idxs is None:
+                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
+
+                gt_keypoints = [t["keypoints"] for t in targets]
+                rcnn_loss_keypoint = keypointrcnn_loss(
+                    keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
+                )
+                loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
+            else:
+                if keypoint_logits is None or keypoint_proposals is None:
+                    raise ValueError(
+                        "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
+                    )
+
+                keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
+                for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
+                    r["keypoints"] = keypoint_prob
+                    r["keypoints_scores"] = kps
+            losses.update(loss_keypoint)
+
+        return result, losses

+ 387 - 0
libs/vision_libs/models/detection/rpn.py

@@ -0,0 +1,387 @@
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+from torchvision.ops import boxes as box_ops, Conv2dNormActivation
+
+from . import _utils as det_utils
+
+# Import AnchorGenerator to keep compatibility.
+from .anchor_utils import AnchorGenerator  # noqa: 401
+from .image_list import ImageList
+
+
+class RPNHead(nn.Module):
+    """
+    Adds a simple RPN Head with classification and regression heads
+
+    Args:
+        in_channels (int): number of channels of the input feature
+        num_anchors (int): number of anchors to be predicted
+        conv_depth (int, optional): number of convolutions
+    """
+
+    _version = 2
+
+    def __init__(self, in_channels: int, num_anchors: int, conv_depth=1) -> None:
+        super().__init__()
+        convs = []
+        for _ in range(conv_depth):
+            convs.append(Conv2dNormActivation(in_channels, in_channels, kernel_size=3, norm_layer=None))
+        self.conv = nn.Sequential(*convs)
+        self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
+        self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1)
+
+        for layer in self.modules():
+            if isinstance(layer, nn.Conv2d):
+                torch.nn.init.normal_(layer.weight, std=0.01)  # type: ignore[arg-type]
+                if layer.bias is not None:
+                    torch.nn.init.constant_(layer.bias, 0)  # type: ignore[arg-type]
+
+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        version = local_metadata.get("version", None)
+
+        if version is None or version < 2:
+            for type in ["weight", "bias"]:
+                old_key = f"{prefix}conv.{type}"
+                new_key = f"{prefix}conv.0.0.{type}"
+                if old_key in state_dict:
+                    state_dict[new_key] = state_dict.pop(old_key)
+
+        super()._load_from_state_dict(
+            state_dict,
+            prefix,
+            local_metadata,
+            strict,
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )
+
+    def forward(self, x: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
+        logits = []
+        bbox_reg = []
+        for feature in x:
+            t = self.conv(feature)
+            logits.append(self.cls_logits(t))
+            bbox_reg.append(self.bbox_pred(t))
+        return logits, bbox_reg
+
+
+def permute_and_flatten(layer: Tensor, N: int, A: int, C: int, H: int, W: int) -> Tensor:
+    layer = layer.view(N, -1, C, H, W)
+    layer = layer.permute(0, 3, 4, 1, 2)
+    layer = layer.reshape(N, -1, C)
+    return layer
+
+
+def concat_box_prediction_layers(box_cls: List[Tensor], box_regression: List[Tensor]) -> Tuple[Tensor, Tensor]:
+    box_cls_flattened = []
+    box_regression_flattened = []
+    # for each feature level, permute the outputs to make them be in the
+    # same format as the labels. Note that the labels are computed for
+    # all feature levels concatenated, so we keep the same representation
+    # for the objectness and the box_regression
+    for box_cls_per_level, box_regression_per_level in zip(box_cls, box_regression):
+        N, AxC, H, W = box_cls_per_level.shape
+        Ax4 = box_regression_per_level.shape[1]
+        A = Ax4 // 4
+        C = AxC // A
+        box_cls_per_level = permute_and_flatten(box_cls_per_level, N, A, C, H, W)
+        box_cls_flattened.append(box_cls_per_level)
+
+        box_regression_per_level = permute_and_flatten(box_regression_per_level, N, A, 4, H, W)
+        box_regression_flattened.append(box_regression_per_level)
+    # concatenate on the first dimension (representing the feature levels), to
+    # take into account the way the labels were generated (with all feature maps
+    # being concatenated as well)
+    box_cls = torch.cat(box_cls_flattened, dim=1).flatten(0, -2)
+    box_regression = torch.cat(box_regression_flattened, dim=1).reshape(-1, 4)
+    return box_cls, box_regression
+
+
+class RegionProposalNetwork(torch.nn.Module):
+    """
+    Implements Region Proposal Network (RPN).
+
+    Args:
+        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
+            maps.
+        head (nn.Module): module that computes the objectness and regression deltas
+        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
+            considered as positive during training of the RPN.
+        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
+            considered as negative during training of the RPN.
+        batch_size_per_image (int): number of anchors that are sampled during training of the RPN
+            for computing the loss
+        positive_fraction (float): proportion of positive anchors in a mini-batch during training
+            of the RPN
+        pre_nms_top_n (Dict[str, int]): number of proposals to keep before applying NMS. It should
+            contain two fields: training and testing, to allow for different values depending
+            on training or evaluation
+        post_nms_top_n (Dict[str, int]): number of proposals to keep after applying NMS. It should
+            contain two fields: training and testing, to allow for different values depending
+            on training or evaluation
+        nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
+
+    """
+
+    __annotations__ = {
+        "box_coder": det_utils.BoxCoder,
+        "proposal_matcher": det_utils.Matcher,
+        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
+    }
+
+    def __init__(
+        self,
+        anchor_generator: AnchorGenerator,
+        head: nn.Module,
+        # Faster-RCNN Training
+        fg_iou_thresh: float,
+        bg_iou_thresh: float,
+        batch_size_per_image: int,
+        positive_fraction: float,
+        # Faster-RCNN Inference
+        pre_nms_top_n: Dict[str, int],
+        post_nms_top_n: Dict[str, int],
+        nms_thresh: float,
+        score_thresh: float = 0.0,
+    ) -> None:
+        super().__init__()
+        self.anchor_generator = anchor_generator
+        self.head = head
+        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
+
+        # used during training
+        self.box_similarity = box_ops.box_iou
+
+        self.proposal_matcher = det_utils.Matcher(
+            fg_iou_thresh,
+            bg_iou_thresh,
+            allow_low_quality_matches=True,
+        )
+
+        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
+        # used during testing
+        self._pre_nms_top_n = pre_nms_top_n
+        self._post_nms_top_n = post_nms_top_n
+        self.nms_thresh = nms_thresh
+        self.score_thresh = score_thresh
+        self.min_size = 1e-3
+
+    def pre_nms_top_n(self) -> int:
+        if self.training:
+            return self._pre_nms_top_n["training"]
+        return self._pre_nms_top_n["testing"]
+
+    def post_nms_top_n(self) -> int:
+        if self.training:
+            return self._post_nms_top_n["training"]
+        return self._post_nms_top_n["testing"]
+
+    def assign_targets_to_anchors(
+        self, anchors: List[Tensor], targets: List[Dict[str, Tensor]]
+    ) -> Tuple[List[Tensor], List[Tensor]]:
+
+        labels = []
+        matched_gt_boxes = []
+        for anchors_per_image, targets_per_image in zip(anchors, targets):
+            gt_boxes = targets_per_image["boxes"]
+
+            if gt_boxes.numel() == 0:
+                # Background image (negative example)
+                device = anchors_per_image.device
+                matched_gt_boxes_per_image = torch.zeros(anchors_per_image.shape, dtype=torch.float32, device=device)
+                labels_per_image = torch.zeros((anchors_per_image.shape[0],), dtype=torch.float32, device=device)
+            else:
+                match_quality_matrix = self.box_similarity(gt_boxes, anchors_per_image)
+                matched_idxs = self.proposal_matcher(match_quality_matrix)
+                # get the targets corresponding GT for each proposal
+                # NB: need to clamp the indices because we can have a single
+                # GT in the image, and matched_idxs can be -2, which goes
+                # out of bounds
+                matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]
+
+                labels_per_image = matched_idxs >= 0
+                labels_per_image = labels_per_image.to(dtype=torch.float32)
+
+                # Background (negative examples)
+                bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD
+                labels_per_image[bg_indices] = 0.0
+
+                # discard indices that are between thresholds
+                inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS
+                labels_per_image[inds_to_discard] = -1.0
+
+            labels.append(labels_per_image)
+            matched_gt_boxes.append(matched_gt_boxes_per_image)
+        return labels, matched_gt_boxes
+
+    def _get_top_n_idx(self, objectness: Tensor, num_anchors_per_level: List[int]) -> Tensor:
+        r = []
+        offset = 0
+        for ob in objectness.split(num_anchors_per_level, 1):
+            num_anchors = ob.shape[1]
+            pre_nms_top_n = det_utils._topk_min(ob, self.pre_nms_top_n(), 1)
+            _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
+            r.append(top_n_idx + offset)
+            offset += num_anchors
+        return torch.cat(r, dim=1)
+
+    def filter_proposals(
+        self,
+        proposals: Tensor,
+        objectness: Tensor,
+        image_shapes: List[Tuple[int, int]],
+        num_anchors_per_level: List[int],
+    ) -> Tuple[List[Tensor], List[Tensor]]:
+
+        num_images = proposals.shape[0]
+        device = proposals.device
+        # do not backprop through objectness
+        objectness = objectness.detach()
+        objectness = objectness.reshape(num_images, -1)
+
+        levels = [
+            torch.full((n,), idx, dtype=torch.int64, device=device) for idx, n in enumerate(num_anchors_per_level)
+        ]
+        levels = torch.cat(levels, 0)
+        levels = levels.reshape(1, -1).expand_as(objectness)
+
+        # select top_n boxes independently per level before applying nms
+        top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level)
+
+        image_range = torch.arange(num_images, device=device)
+        batch_idx = image_range[:, None]
+
+        objectness = objectness[batch_idx, top_n_idx]
+        levels = levels[batch_idx, top_n_idx]
+        proposals = proposals[batch_idx, top_n_idx]
+
+        objectness_prob = torch.sigmoid(objectness)
+
+        final_boxes = []
+        final_scores = []
+        for boxes, scores, lvl, img_shape in zip(proposals, objectness_prob, levels, image_shapes):
+            boxes = box_ops.clip_boxes_to_image(boxes, img_shape)
+
+            # remove small boxes
+            keep = box_ops.remove_small_boxes(boxes, self.min_size)
+            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
+
+            # remove low scoring boxes
+            # use >= for Backwards compatibility
+            keep = torch.where(scores >= self.score_thresh)[0]
+            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
+
+            # non-maximum suppression, independently done per level
+            keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)
+
+            # keep only topk scoring predictions
+            keep = keep[: self.post_nms_top_n()]
+            boxes, scores = boxes[keep], scores[keep]
+
+            final_boxes.append(boxes)
+            final_scores.append(scores)
+        return final_boxes, final_scores
+
+    def compute_loss(
+        self, objectness: Tensor, pred_bbox_deltas: Tensor, labels: List[Tensor], regression_targets: List[Tensor]
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+            objectness (Tensor)
+            pred_bbox_deltas (Tensor)
+            labels (List[Tensor])
+            regression_targets (List[Tensor])
+
+        Returns:
+            objectness_loss (Tensor)
+            box_loss (Tensor)
+        """
+
+        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
+        sampled_pos_inds = torch.where(torch.cat(sampled_pos_inds, dim=0))[0]
+        sampled_neg_inds = torch.where(torch.cat(sampled_neg_inds, dim=0))[0]
+
+        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)
+
+        objectness = objectness.flatten()
+
+        labels = torch.cat(labels, dim=0)
+        regression_targets = torch.cat(regression_targets, dim=0)
+
+        box_loss = F.smooth_l1_loss(
+            pred_bbox_deltas[sampled_pos_inds],
+            regression_targets[sampled_pos_inds],
+            beta=1 / 9,
+            reduction="sum",
+        ) / (sampled_inds.numel())
+
+        objectness_loss = F.binary_cross_entropy_with_logits(objectness[sampled_inds], labels[sampled_inds])
+
+        return objectness_loss, box_loss
+
+    def forward(
+        self,
+        images: ImageList,
+        features: Dict[str, Tensor],
+        targets: Optional[List[Dict[str, Tensor]]] = None,
+    ) -> Tuple[List[Tensor], Dict[str, Tensor]]:
+
+        """
+        Args:
+            images (ImageList): images for which we want to compute the predictions
+            features (Dict[str, Tensor]): features computed from the images that are
+                used for computing the predictions. Each tensor in the list
+                correspond to different feature levels
+            targets (List[Dict[str, Tensor]]): ground-truth boxes present in the image (optional).
+                If provided, each element in the dict should contain a field `boxes`,
+                with the locations of the ground-truth boxes.
+
+        Returns:
+            boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per
+                image.
+            losses (Dict[str, Tensor]): the losses for the model during training. During
+                testing, it is an empty dict.
+        """
+        # RPN uses all feature maps that are available
+        features = list(features.values())
+        objectness, pred_bbox_deltas = self.head(features)
+        anchors = self.anchor_generator(images, features)
+
+        num_images = len(anchors)
+        num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
+        num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
+        objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness, pred_bbox_deltas)
+        # apply pred_bbox_deltas to anchors to obtain the decoded proposals
+        # note that we detach the deltas because Faster R-CNN do not backprop through
+        # the proposals
+        proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
+        proposals = proposals.view(num_images, -1, 4)
+        boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)
+
+        losses = {}
+        if self.training:
+            if targets is None:
+                raise ValueError("targets should not be None")
+            labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
+            regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
+            loss_objectness, loss_rpn_box_reg = self.compute_loss(
+                objectness, pred_bbox_deltas, labels, regression_targets
+            )
+            losses = {
+                "loss_objectness": loss_objectness,
+                "loss_rpn_box_reg": loss_rpn_box_reg,
+            }
+        return boxes, losses

+ 682 - 0
libs/vision_libs/models/detection/ssd.py

@@ -0,0 +1,682 @@
+import warnings
+from collections import OrderedDict
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+from ...ops import boxes as box_ops
+from ...transforms._presets import ObjectDetection
+from ...utils import _log_api_usage_once
+from .._api import register_model, Weights, WeightsEnum
+from .._meta import _COCO_CATEGORIES
+from .._utils import _ovewrite_value_param, handle_legacy_interface
+from ..vgg import VGG, vgg16, VGG16_Weights
+from . import _utils as det_utils
+from .anchor_utils import DefaultBoxGenerator
+from .backbone_utils import _validate_trainable_layers
+from .transform import GeneralizedRCNNTransform
+
+
+__all__ = [
+    "SSD300_VGG16_Weights",
+    "ssd300_vgg16",
+]
+
+
+class SSD300_VGG16_Weights(WeightsEnum):
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth",
+        transforms=ObjectDetection,
+        meta={
+            "num_params": 35641826,
+            "categories": _COCO_CATEGORIES,
+            "min_size": (1, 1),
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssd300-vgg16",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 25.1,
+                }
+            },
+            "_ops": 34.858,
+            "_file_size": 135.988,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+def _xavier_init(conv: nn.Module):
+    for layer in conv.modules():
+        if isinstance(layer, nn.Conv2d):
+            torch.nn.init.xavier_uniform_(layer.weight)
+            if layer.bias is not None:
+                torch.nn.init.constant_(layer.bias, 0.0)
+
+
+class SSDHead(nn.Module):
+    def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int):
+        super().__init__()
+        self.classification_head = SSDClassificationHead(in_channels, num_anchors, num_classes)
+        self.regression_head = SSDRegressionHead(in_channels, num_anchors)
+
+    def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
+        return {
+            "bbox_regression": self.regression_head(x),
+            "cls_logits": self.classification_head(x),
+        }
+
+
+class SSDScoringHead(nn.Module):
+    def __init__(self, module_list: nn.ModuleList, num_columns: int):
+        super().__init__()
+        self.module_list = module_list
+        self.num_columns = num_columns
+
+    def _get_result_from_module_list(self, x: Tensor, idx: int) -> Tensor:
+        """
+        This is equivalent to self.module_list[idx](x),
+        but torchscript doesn't support this yet
+        """
+        num_blocks = len(self.module_list)
+        if idx < 0:
+            idx += num_blocks
+        out = x
+        for i, module in enumerate(self.module_list):
+            if i == idx:
+                out = module(x)
+        return out
+
+    def forward(self, x: List[Tensor]) -> Tensor:
+        all_results = []
+
+        for i, features in enumerate(x):
+            results = self._get_result_from_module_list(features, i)
+
+            # Permute output from (N, A * K, H, W) to (N, HWA, K).
+            N, _, H, W = results.shape
+            results = results.view(N, -1, self.num_columns, H, W)
+            results = results.permute(0, 3, 4, 1, 2)
+            results = results.reshape(N, -1, self.num_columns)  # Size=(N, HWA, K)
+
+            all_results.append(results)
+
+        return torch.cat(all_results, dim=1)
+
+
+class SSDClassificationHead(SSDScoringHead):
+    def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int):
+        cls_logits = nn.ModuleList()
+        for channels, anchors in zip(in_channels, num_anchors):
+            cls_logits.append(nn.Conv2d(channels, num_classes * anchors, kernel_size=3, padding=1))
+        _xavier_init(cls_logits)
+        super().__init__(cls_logits, num_classes)
+
+
+class SSDRegressionHead(SSDScoringHead):
+    def __init__(self, in_channels: List[int], num_anchors: List[int]):
+        bbox_reg = nn.ModuleList()
+        for channels, anchors in zip(in_channels, num_anchors):
+            bbox_reg.append(nn.Conv2d(channels, 4 * anchors, kernel_size=3, padding=1))
+        _xavier_init(bbox_reg)
+        super().__init__(bbox_reg, 4)
+
+
+class SSD(nn.Module):
+    """
+    Implements SSD architecture from `"SSD: Single Shot MultiBox Detector" <https://arxiv.org/abs/1512.02325>`_.
+
+    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
+    image, and should be in 0-1 range. Different images can have different sizes, but they will be resized
+    to a fixed size before passing it to the backbone.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (list of dictionary),
+    containing:
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (Int64Tensor[N]): the class label for each ground-truth box
+
+    The model returns a Dict[Tensor] during training, containing the classification and regression
+    losses.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
+    follows, where ``N`` is the number of detections:
+
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (Int64Tensor[N]): the predicted labels for each detection
+        - scores (Tensor[N]): the scores for each detection
+
+    Args:
+        backbone (nn.Module): the network used to compute the features for the model.
+            It should contain an out_channels attribute with the list of the output channels of
+            each feature map. The backbone should return a single Tensor or an OrderedDict[Tensor].
+        anchor_generator (DefaultBoxGenerator): module that generates the default boxes for a
+            set of feature maps.
+        size (Tuple[int, int]): the width and height to which images will be rescaled before feeding them
+            to the backbone.
+        num_classes (int): number of output classes of the model (including the background).
+        image_mean (Tuple[float, float, float]): mean values used for input normalization.
+            They are generally the mean values of the dataset on which the backbone has been trained
+            on
+        image_std (Tuple[float, float, float]): std values used for input normalization.
+            They are generally the std values of the dataset on which the backbone has been trained on
+        head (nn.Module, optional): Module run on top of the backbone features. Defaults to a module containing
+            a classification and regression module.
+        score_thresh (float): Score threshold used for postprocessing the detections.
+        nms_thresh (float): NMS threshold used for postprocessing the detections.
+        detections_per_img (int): Number of best detections to keep after NMS.
+        iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
+            considered as positive during training.
+        topk_candidates (int): Number of best detections to keep before NMS.
+        positive_fraction (float): a number between 0 and 1 which indicates the proportion of positive
+            proposals used during the training of the classification head. It is used to estimate the negative to
+            positive ratio.
+    """
+
+    __annotations__ = {
+        "box_coder": det_utils.BoxCoder,
+        "proposal_matcher": det_utils.Matcher,
+    }
+
+    def __init__(
+        self,
+        backbone: nn.Module,
+        anchor_generator: DefaultBoxGenerator,
+        size: Tuple[int, int],
+        num_classes: int,
+        image_mean: Optional[List[float]] = None,
+        image_std: Optional[List[float]] = None,
+        head: Optional[nn.Module] = None,
+        score_thresh: float = 0.01,
+        nms_thresh: float = 0.45,
+        detections_per_img: int = 200,
+        iou_thresh: float = 0.5,
+        topk_candidates: int = 400,
+        positive_fraction: float = 0.25,
+        **kwargs: Any,
+    ):
+        super().__init__()
+        _log_api_usage_once(self)
+
+        self.backbone = backbone
+
+        self.anchor_generator = anchor_generator
+
+        self.box_coder = det_utils.BoxCoder(weights=(10.0, 10.0, 5.0, 5.0))
+
+        if head is None:
+            if hasattr(backbone, "out_channels"):
+                out_channels = backbone.out_channels
+            else:
+                out_channels = det_utils.retrieve_out_channels(backbone, size)
+
+            if len(out_channels) != len(anchor_generator.aspect_ratios):
+                raise ValueError(
+                    f"The length of the output channels from the backbone ({len(out_channels)}) do not match the length of the anchor generator aspect ratios ({len(anchor_generator.aspect_ratios)})"
+                )
+
+            num_anchors = self.anchor_generator.num_anchors_per_location()
+            head = SSDHead(out_channels, num_anchors, num_classes)
+        self.head = head
+
+        self.proposal_matcher = det_utils.SSDMatcher(iou_thresh)
+
+        if image_mean is None:
+            image_mean = [0.485, 0.456, 0.406]
+        if image_std is None:
+            image_std = [0.229, 0.224, 0.225]
+        self.transform = GeneralizedRCNNTransform(
+            min(size), max(size), image_mean, image_std, size_divisible=1, fixed_size=size, **kwargs
+        )
+
+        self.score_thresh = score_thresh
+        self.nms_thresh = nms_thresh
+        self.detections_per_img = detections_per_img
+        self.topk_candidates = topk_candidates
+        self.neg_to_pos_ratio = (1.0 - positive_fraction) / positive_fraction
+
+        # used only on torchscript mode
+        self._has_warned = False
+
+    @torch.jit.unused
+    def eager_outputs(
+        self, losses: Dict[str, Tensor], detections: List[Dict[str, Tensor]]
+    ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
+        if self.training:
+            return losses
+
+        return detections
+
+    def compute_loss(
+        self,
+        targets: List[Dict[str, Tensor]],
+        head_outputs: Dict[str, Tensor],
+        anchors: List[Tensor],
+        matched_idxs: List[Tensor],
+    ) -> Dict[str, Tensor]:
+        bbox_regression = head_outputs["bbox_regression"]
+        cls_logits = head_outputs["cls_logits"]
+
+        # Match original targets with default boxes
+        num_foreground = 0
+        bbox_loss = []
+        cls_targets = []
+        for (
+            targets_per_image,
+            bbox_regression_per_image,
+            cls_logits_per_image,
+            anchors_per_image,
+            matched_idxs_per_image,
+        ) in zip(targets, bbox_regression, cls_logits, anchors, matched_idxs):
+            # produce the matching between boxes and targets
+            foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
+            foreground_matched_idxs_per_image = matched_idxs_per_image[foreground_idxs_per_image]
+            num_foreground += foreground_matched_idxs_per_image.numel()
+
+            # Calculate regression loss
+            matched_gt_boxes_per_image = targets_per_image["boxes"][foreground_matched_idxs_per_image]
+            bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
+            anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
+            target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
+            bbox_loss.append(
+                torch.nn.functional.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
+            )
+
+            # Estimate ground truth for class targets
+            gt_classes_target = torch.zeros(
+                (cls_logits_per_image.size(0),),
+                dtype=targets_per_image["labels"].dtype,
+                device=targets_per_image["labels"].device,
+            )
+            gt_classes_target[foreground_idxs_per_image] = targets_per_image["labels"][
+                foreground_matched_idxs_per_image
+            ]
+            cls_targets.append(gt_classes_target)
+
+        bbox_loss = torch.stack(bbox_loss)
+        cls_targets = torch.stack(cls_targets)
+
+        # Calculate classification loss
+        num_classes = cls_logits.size(-1)
+        cls_loss = F.cross_entropy(cls_logits.view(-1, num_classes), cls_targets.view(-1), reduction="none").view(
+            cls_targets.size()
+        )
+
+        # Hard Negative Sampling
+        foreground_idxs = cls_targets > 0
+        num_negative = self.neg_to_pos_ratio * foreground_idxs.sum(1, keepdim=True)
+        # num_negative[num_negative < self.neg_to_pos_ratio] = self.neg_to_pos_ratio
+        negative_loss = cls_loss.clone()
+        negative_loss[foreground_idxs] = -float("inf")  # use -inf to detect positive values that creeped in the sample
+        values, idx = negative_loss.sort(1, descending=True)
+        # background_idxs = torch.logical_and(idx.sort(1)[1] < num_negative, torch.isfinite(values))
+        background_idxs = idx.sort(1)[1] < num_negative
+
+        N = max(1, num_foreground)
+        return {
+            "bbox_regression": bbox_loss.sum() / N,
+            "classification": (cls_loss[foreground_idxs].sum() + cls_loss[background_idxs].sum()) / N,
+        }
+
+    def forward(
+        self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
+    ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
+        if self.training:
+            if targets is None:
+                torch._assert(False, "targets should not be none when in training mode")
+            else:
+                for target in targets:
+                    boxes = target["boxes"]
+                    if isinstance(boxes, torch.Tensor):
+                        torch._assert(
+                            len(boxes.shape) == 2 and boxes.shape[-1] == 4,
+                            f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
+                        )
+                    else:
+                        torch._assert(False, f"Expected target boxes to be of type Tensor, got {type(boxes)}.")
+
+        # get the original image sizes
+        original_image_sizes: List[Tuple[int, int]] = []
+        for img in images:
+            val = img.shape[-2:]
+            torch._assert(
+                len(val) == 2,
+                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
+            )
+            original_image_sizes.append((val[0], val[1]))
+
+        # transform the input
+        images, targets = self.transform(images, targets)
+
+        # Check for degenerate boxes
+        if targets is not None:
+            for target_idx, target in enumerate(targets):
+                boxes = target["boxes"]
+                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
+                if degenerate_boxes.any():
+                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
+                    degen_bb: List[float] = boxes[bb_idx].tolist()
+                    torch._assert(
+                        False,
+                        "All bounding boxes should have positive height and width."
+                        f" Found invalid box {degen_bb} for target at index {target_idx}.",
+                    )
+
+        # get the features from the backbone
+        features = self.backbone(images.tensors)
+        if isinstance(features, torch.Tensor):
+            features = OrderedDict([("0", features)])
+
+        features = list(features.values())
+
+        # compute the ssd heads outputs using the features
+        head_outputs = self.head(features)
+
+        # create the set of anchors
+        anchors = self.anchor_generator(images, features)
+
+        losses = {}
+        detections: List[Dict[str, Tensor]] = []
+        if self.training:
+            matched_idxs = []
+            if targets is None:
+                torch._assert(False, "targets should not be none when in training mode")
+            else:
+                for anchors_per_image, targets_per_image in zip(anchors, targets):
+                    if targets_per_image["boxes"].numel() == 0:
+                        matched_idxs.append(
+                            torch.full(
+                                (anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device
+                            )
+                        )
+                        continue
+
+                    match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image)
+                    matched_idxs.append(self.proposal_matcher(match_quality_matrix))
+
+                losses = self.compute_loss(targets, head_outputs, anchors, matched_idxs)
+        else:
+            detections = self.postprocess_detections(head_outputs, anchors, images.image_sizes)
+            detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
+
+        if torch.jit.is_scripting():
+            if not self._has_warned:
+                warnings.warn("SSD always returns a (Losses, Detections) tuple in scripting")
+                self._has_warned = True
+            return losses, detections
+        return self.eager_outputs(losses, detections)
+
+    def postprocess_detections(
+        self, head_outputs: Dict[str, Tensor], image_anchors: List[Tensor], image_shapes: List[Tuple[int, int]]
+    ) -> List[Dict[str, Tensor]]:
+        bbox_regression = head_outputs["bbox_regression"]
+        pred_scores = F.softmax(head_outputs["cls_logits"], dim=-1)
+
+        num_classes = pred_scores.size(-1)
+        device = pred_scores.device
+
+        detections: List[Dict[str, Tensor]] = []
+
+        for boxes, scores, anchors, image_shape in zip(bbox_regression, pred_scores, image_anchors, image_shapes):
+            boxes = self.box_coder.decode_single(boxes, anchors)
+            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
+
+            image_boxes = []
+            image_scores = []
+            image_labels = []
+            for label in range(1, num_classes):
+                score = scores[:, label]
+
+                keep_idxs = score > self.score_thresh
+                score = score[keep_idxs]
+                box = boxes[keep_idxs]
+
+                # keep only topk scoring predictions
+                num_topk = det_utils._topk_min(score, self.topk_candidates, 0)
+                score, idxs = score.topk(num_topk)
+                box = box[idxs]
+
+                image_boxes.append(box)
+                image_scores.append(score)
+                image_labels.append(torch.full_like(score, fill_value=label, dtype=torch.int64, device=device))
+
+            image_boxes = torch.cat(image_boxes, dim=0)
+            image_scores = torch.cat(image_scores, dim=0)
+            image_labels = torch.cat(image_labels, dim=0)
+
+            # non-maximum suppression
+            keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
+            keep = keep[: self.detections_per_img]
+
+            detections.append(
+                {
+                    "boxes": image_boxes[keep],
+                    "scores": image_scores[keep],
+                    "labels": image_labels[keep],
+                }
+            )
+        return detections
+
+
+class SSDFeatureExtractorVGG(nn.Module):
+    def __init__(self, backbone: nn.Module, highres: bool):
+        super().__init__()
+
+        _, _, maxpool3_pos, maxpool4_pos, _ = (i for i, layer in enumerate(backbone) if isinstance(layer, nn.MaxPool2d))
+
+        # Patch ceil_mode for maxpool3 to get the same WxH output sizes as the paper
+        backbone[maxpool3_pos].ceil_mode = True
+
+        # parameters used for L2 regularization + rescaling
+        self.scale_weight = nn.Parameter(torch.ones(512) * 20)
+
+        # Multiple Feature maps - page 4, Fig 2 of SSD paper
+        self.features = nn.Sequential(*backbone[:maxpool4_pos])  # until conv4_3
+
+        # SSD300 case - page 4, Fig 2 of SSD paper
+        extra = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.Conv2d(1024, 256, kernel_size=1),
+                    nn.ReLU(inplace=True),
+                    nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2),  # conv8_2
+                    nn.ReLU(inplace=True),
+                ),
+                nn.Sequential(
+                    nn.Conv2d(512, 128, kernel_size=1),
+                    nn.ReLU(inplace=True),
+                    nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2),  # conv9_2
+                    nn.ReLU(inplace=True),
+                ),
+                nn.Sequential(
+                    nn.Conv2d(256, 128, kernel_size=1),
+                    nn.ReLU(inplace=True),
+                    nn.Conv2d(128, 256, kernel_size=3),  # conv10_2
+                    nn.ReLU(inplace=True),
+                ),
+                nn.Sequential(
+                    nn.Conv2d(256, 128, kernel_size=1),
+                    nn.ReLU(inplace=True),
+                    nn.Conv2d(128, 256, kernel_size=3),  # conv11_2
+                    nn.ReLU(inplace=True),
+                ),
+            ]
+        )
+        if highres:
+            # Additional layers for the SSD512 case. See page 11, footernote 5.
+            extra.append(
+                nn.Sequential(
+                    nn.Conv2d(256, 128, kernel_size=1),
+                    nn.ReLU(inplace=True),
+                    nn.Conv2d(128, 256, kernel_size=4),  # conv12_2
+                    nn.ReLU(inplace=True),
+                )
+            )
+        _xavier_init(extra)
+
+        fc = nn.Sequential(
+            nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=False),  # add modified maxpool5
+            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=6, dilation=6),  # FC6 with atrous
+            nn.ReLU(inplace=True),
+            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1),  # FC7
+            nn.ReLU(inplace=True),
+        )
+        _xavier_init(fc)
+        extra.insert(
+            0,
+            nn.Sequential(
+                *backbone[maxpool4_pos:-1],  # until conv5_3, skip maxpool5
+                fc,
+            ),
+        )
+        self.extra = extra
+
+    def forward(self, x: Tensor) -> Dict[str, Tensor]:
+        # L2 regularization + Rescaling of 1st block's feature map
+        x = self.features(x)
+        rescaled = self.scale_weight.view(1, -1, 1, 1) * F.normalize(x)
+        output = [rescaled]
+
+        # Calculating Feature maps for the rest blocks
+        for block in self.extra:
+            x = block(x)
+            output.append(x)
+
+        return OrderedDict([(str(i), v) for i, v in enumerate(output)])
+
+
+def _vgg_extractor(backbone: VGG, highres: bool, trainable_layers: int):
+    backbone = backbone.features
+    # Gather the indices of maxpools. These are the locations of output blocks.
+    stage_indices = [0] + [i for i, b in enumerate(backbone) if isinstance(b, nn.MaxPool2d)][:-1]
+    num_stages = len(stage_indices)
+
+    # find the index of the layer from which we won't freeze
+    torch._assert(
+        0 <= trainable_layers <= num_stages,
+        f"trainable_layers should be in the range [0, {num_stages}]. Instead got {trainable_layers}",
+    )
+    freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
+
+    for b in backbone[:freeze_before]:
+        for parameter in b.parameters():
+            parameter.requires_grad_(False)
+
+    return SSDFeatureExtractorVGG(backbone, highres)
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", SSD300_VGG16_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", VGG16_Weights.IMAGENET1K_FEATURES),
+)
+def ssd300_vgg16(
+    *,
+    weights: Optional[SSD300_VGG16_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[VGG16_Weights] = VGG16_Weights.IMAGENET1K_FEATURES,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> SSD:
+    """The SSD300 model is based on the `SSD: Single Shot MultiBox Detector
+    <https://arxiv.org/abs/1512.02325>`_ paper.
+
+    .. betastatus:: detection module
+
+    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
+    image, and should be in 0-1 range. Different images can have different sizes, but they will be resized
+    to a fixed size before passing it to the backbone.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (list of dictionary),
+    containing:
+
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (Int64Tensor[N]): the class label for each ground-truth box
+
+    The model returns a Dict[Tensor] during training, containing the classification and regression
+    losses.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
+    follows, where ``N`` is the number of detections:
+
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (Int64Tensor[N]): the predicted labels for each detection
+        - scores (Tensor[N]): the scores for each detection
+
+    Example:
+
+        >>> model = torchvision.models.detection.ssd300_vgg16(weights=SSD300_VGG16_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 300), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.SSD300_VGG16_Weights`, optional): The pretrained
+                weights to use. See
+                :class:`~torchvision.models.detection.SSD300_VGG16_Weights`
+                below for more details, and possible values. By default, no
+                pre-trained weights are used.
+        progress (bool, optional): If True, displays a progress bar of the download to stderr
+            Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.VGG16_Weights`, optional): The pretrained weights for the
+            backbone
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
+            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
+            passed (the default) this value is set to 4.
+        **kwargs: parameters passed to the ``torchvision.models.detection.SSD``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/ssd.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.SSD300_VGG16_Weights
+        :members:
+    """
+    weights = SSD300_VGG16_Weights.verify(weights)
+    weights_backbone = VGG16_Weights.verify(weights_backbone)
+
+    if "size" in kwargs:
+        warnings.warn("The size of the model is already fixed; ignoring the parameter.")
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    trainable_backbone_layers = _validate_trainable_layers(
+        weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 4
+    )
+
+    # Use custom backbones more appropriate for SSD
+    backbone = vgg16(weights=weights_backbone, progress=progress)
+    backbone = _vgg_extractor(backbone, False, trainable_backbone_layers)
+    anchor_generator = DefaultBoxGenerator(
+        [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
+        scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05],
+        steps=[8, 16, 32, 64, 100, 300],
+    )
+
+    defaults = {
+        # Rescale the input in a way compatible to the backbone
+        "image_mean": [0.48235, 0.45882, 0.40784],
+        "image_std": [1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0],  # undo the 0-1 scaling of toTensor
+    }
+    kwargs: Any = {**defaults, **kwargs}
+    model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model

+ 331 - 0
libs/vision_libs/models/detection/ssdlite.py

@@ -0,0 +1,331 @@
+import warnings
+from collections import OrderedDict
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from torch import nn, Tensor
+
+from ...ops.misc import Conv2dNormActivation
+from ...transforms._presets import ObjectDetection
+from ...utils import _log_api_usage_once
+from .. import mobilenet
+from .._api import register_model, Weights, WeightsEnum
+from .._meta import _COCO_CATEGORIES
+from .._utils import _ovewrite_value_param, handle_legacy_interface
+from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights
+from . import _utils as det_utils
+from .anchor_utils import DefaultBoxGenerator
+from .backbone_utils import _validate_trainable_layers
+from .ssd import SSD, SSDScoringHead
+
+
+__all__ = [
+    "SSDLite320_MobileNet_V3_Large_Weights",
+    "ssdlite320_mobilenet_v3_large",
+]
+
+
+# Building blocks of SSDlite as described in section 6.2 of MobileNetV2 paper
+def _prediction_block(
+    in_channels: int, out_channels: int, kernel_size: int, norm_layer: Callable[..., nn.Module]
+) -> nn.Sequential:
+    return nn.Sequential(
+        # 3x3 depthwise with stride 1 and padding 1
+        Conv2dNormActivation(
+            in_channels,
+            in_channels,
+            kernel_size=kernel_size,
+            groups=in_channels,
+            norm_layer=norm_layer,
+            activation_layer=nn.ReLU6,
+        ),
+        # 1x1 projetion to output channels
+        nn.Conv2d(in_channels, out_channels, 1),
+    )
+
+
+def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[..., nn.Module]) -> nn.Sequential:
+    activation = nn.ReLU6
+    intermediate_channels = out_channels // 2
+    return nn.Sequential(
+        # 1x1 projection to half output channels
+        Conv2dNormActivation(
+            in_channels, intermediate_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation
+        ),
+        # 3x3 depthwise with stride 2 and padding 1
+        Conv2dNormActivation(
+            intermediate_channels,
+            intermediate_channels,
+            kernel_size=3,
+            stride=2,
+            groups=intermediate_channels,
+            norm_layer=norm_layer,
+            activation_layer=activation,
+        ),
+        # 1x1 projetion to output channels
+        Conv2dNormActivation(
+            intermediate_channels, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation
+        ),
+    )
+
+
+def _normal_init(conv: nn.Module):
+    for layer in conv.modules():
+        if isinstance(layer, nn.Conv2d):
+            torch.nn.init.normal_(layer.weight, mean=0.0, std=0.03)
+            if layer.bias is not None:
+                torch.nn.init.constant_(layer.bias, 0.0)
+
+
+class SSDLiteHead(nn.Module):
+    def __init__(
+        self, in_channels: List[int], num_anchors: List[int], num_classes: int, norm_layer: Callable[..., nn.Module]
+    ):
+        super().__init__()
+        self.classification_head = SSDLiteClassificationHead(in_channels, num_anchors, num_classes, norm_layer)
+        self.regression_head = SSDLiteRegressionHead(in_channels, num_anchors, norm_layer)
+
+    def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
+        return {
+            "bbox_regression": self.regression_head(x),
+            "cls_logits": self.classification_head(x),
+        }
+
+
+class SSDLiteClassificationHead(SSDScoringHead):
+    def __init__(
+        self, in_channels: List[int], num_anchors: List[int], num_classes: int, norm_layer: Callable[..., nn.Module]
+    ):
+        cls_logits = nn.ModuleList()
+        for channels, anchors in zip(in_channels, num_anchors):
+            cls_logits.append(_prediction_block(channels, num_classes * anchors, 3, norm_layer))
+        _normal_init(cls_logits)
+        super().__init__(cls_logits, num_classes)
+
+
+class SSDLiteRegressionHead(SSDScoringHead):
+    def __init__(self, in_channels: List[int], num_anchors: List[int], norm_layer: Callable[..., nn.Module]):
+        bbox_reg = nn.ModuleList()
+        for channels, anchors in zip(in_channels, num_anchors):
+            bbox_reg.append(_prediction_block(channels, 4 * anchors, 3, norm_layer))
+        _normal_init(bbox_reg)
+        super().__init__(bbox_reg, 4)
+
+
+class SSDLiteFeatureExtractorMobileNet(nn.Module):
+    def __init__(
+        self,
+        backbone: nn.Module,
+        c4_pos: int,
+        norm_layer: Callable[..., nn.Module],
+        width_mult: float = 1.0,
+        min_depth: int = 16,
+    ):
+        super().__init__()
+        _log_api_usage_once(self)
+
+        if backbone[c4_pos].use_res_connect:
+            raise ValueError("backbone[c4_pos].use_res_connect should be False")
+
+        self.features = nn.Sequential(
+            # As described in section 6.3 of MobileNetV3 paper
+            nn.Sequential(*backbone[:c4_pos], backbone[c4_pos].block[0]),  # from start until C4 expansion layer
+            nn.Sequential(backbone[c4_pos].block[1:], *backbone[c4_pos + 1 :]),  # from C4 depthwise until end
+        )
+
+        get_depth = lambda d: max(min_depth, int(d * width_mult))  # noqa: E731
+        extra = nn.ModuleList(
+            [
+                _extra_block(backbone[-1].out_channels, get_depth(512), norm_layer),
+                _extra_block(get_depth(512), get_depth(256), norm_layer),
+                _extra_block(get_depth(256), get_depth(256), norm_layer),
+                _extra_block(get_depth(256), get_depth(128), norm_layer),
+            ]
+        )
+        _normal_init(extra)
+
+        self.extra = extra
+
+    def forward(self, x: Tensor) -> Dict[str, Tensor]:
+        # Get feature maps from backbone and extra. Can't be refactored due to JIT limitations.
+        output = []
+        for block in self.features:
+            x = block(x)
+            output.append(x)
+
+        for block in self.extra:
+            x = block(x)
+            output.append(x)
+
+        return OrderedDict([(str(i), v) for i, v in enumerate(output)])
+
+
+def _mobilenet_extractor(
+    backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3],
+    trainable_layers: int,
+    norm_layer: Callable[..., nn.Module],
+):
+    backbone = backbone.features
+    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
+    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
+    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
+    num_stages = len(stage_indices)
+
+    # find the index of the layer from which we won't freeze
+    if not 0 <= trainable_layers <= num_stages:
+        raise ValueError("trainable_layers should be in the range [0, {num_stages}], instead got {trainable_layers}")
+    freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
+
+    for b in backbone[:freeze_before]:
+        for parameter in b.parameters():
+            parameter.requires_grad_(False)
+
+    return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer)
+
+
+class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum):
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth",
+        transforms=ObjectDetection,
+        meta={
+            "num_params": 3440060,
+            "categories": _COCO_CATEGORIES,
+            "min_size": (1, 1),
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssdlite320-mobilenetv3-large",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 21.3,
+                }
+            },
+            "_ops": 0.583,
+            "_file_size": 13.418,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=("pretrained", SSDLite320_MobileNet_V3_Large_Weights.COCO_V1),
+    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
+)
+def ssdlite320_mobilenet_v3_large(
+    *,
+    weights: Optional[SSDLite320_MobileNet_V3_Large_Weights] = None,
+    progress: bool = True,
+    num_classes: Optional[int] = None,
+    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
+    trainable_backbone_layers: Optional[int] = None,
+    norm_layer: Optional[Callable[..., nn.Module]] = None,
+    **kwargs: Any,
+) -> SSD:
+    """SSDlite model architecture with input size 320x320 and a MobileNetV3 Large backbone, as
+    described at `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`__ and
+    `MobileNetV2: Inverted Residuals and Linear Bottlenecks <https://arxiv.org/abs/1801.04381>`__.
+
+    .. betastatus:: detection module
+
+    See :func:`~torchvision.models.detection.ssd300_vgg16` for more details.
+
+    Example:
+
+        >>> model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(weights=SSDLite320_MobileNet_V3_Large_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 320, 320), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.SSDLite320_MobileNet_V3_Large_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.SSDLite320_MobileNet_V3_Large_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model
+            (including the background).
+        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The pretrained
+            weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers
+            starting from final block. Valid values are between 0 and 6, with 6 meaning all
+            backbone layers are trainable. If ``None`` is passed (the default) this value is
+            set to 6.
+        norm_layer (callable, optional): Module specifying the normalization layer to use.
+        **kwargs: parameters passed to the ``torchvision.models.detection.ssd.SSD``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/ssd.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.SSDLite320_MobileNet_V3_Large_Weights
+        :members:
+    """
+
+    weights = SSDLite320_MobileNet_V3_Large_Weights.verify(weights)
+    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
+
+    if "size" in kwargs:
+        warnings.warn("The size of the model is already fixed; ignoring the parameter.")
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    trainable_backbone_layers = _validate_trainable_layers(
+        weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 6
+    )
+
+    # Enable reduced tail if no pretrained backbone is selected. See Table 6 of MobileNetV3 paper.
+    reduce_tail = weights_backbone is None
+
+    if norm_layer is None:
+        norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
+
+    backbone = mobilenet_v3_large(
+        weights=weights_backbone, progress=progress, norm_layer=norm_layer, reduced_tail=reduce_tail, **kwargs
+    )
+    if weights_backbone is None:
+        # Change the default initialization scheme if not pretrained
+        _normal_init(backbone)
+    backbone = _mobilenet_extractor(
+        backbone,
+        trainable_backbone_layers,
+        norm_layer,
+    )
+
+    size = (320, 320)
+    anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95)
+    out_channels = det_utils.retrieve_out_channels(backbone, size)
+    num_anchors = anchor_generator.num_anchors_per_location()
+    if len(out_channels) != len(anchor_generator.aspect_ratios):
+        raise ValueError(
+            f"The length of the output channels from the backbone {len(out_channels)} do not match the length of the anchor generator aspect ratios {len(anchor_generator.aspect_ratios)}"
+        )
+
+    defaults = {
+        "score_thresh": 0.001,
+        "nms_thresh": 0.55,
+        "detections_per_img": 300,
+        "topk_candidates": 300,
+        # Rescale the input in a way compatible to the backbone:
+        # The following mean/std rescale the data from [0, 1] to [-1, 1]
+        "image_mean": [0.5, 0.5, 0.5],
+        "image_std": [0.5, 0.5, 0.5],
+    }
+    kwargs: Any = {**defaults, **kwargs}
+    model = SSD(
+        backbone,
+        anchor_generator,
+        size,
+        num_classes,
+        head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer),
+        **kwargs,
+    )
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model

+ 319 - 0
libs/vision_libs/models/detection/transform.py

@@ -0,0 +1,319 @@
+import math
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torchvision
+from torch import nn, Tensor
+
+from .image_list import ImageList
+from .roi_heads import paste_masks_in_image
+
+
+@torch.jit.unused
+def _get_shape_onnx(image: Tensor) -> Tensor:
+    from torch.onnx import operators
+
+    return operators.shape_as_tensor(image)[-2:]
+
+
+@torch.jit.unused
+def _fake_cast_onnx(v: Tensor) -> float:
+    # ONNX requires a tensor but here we fake its type for JIT.
+    return v
+
+
+def _resize_image_and_masks(
+    image: Tensor,
+    self_min_size: int,
+    self_max_size: int,
+    target: Optional[Dict[str, Tensor]] = None,
+    fixed_size: Optional[Tuple[int, int]] = None,
+) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+    if torchvision._is_tracing():
+        im_shape = _get_shape_onnx(image)
+    elif torch.jit.is_scripting():
+        im_shape = torch.tensor(image.shape[-2:])
+    else:
+        im_shape = image.shape[-2:]
+
+    size: Optional[List[int]] = None
+    scale_factor: Optional[float] = None
+    recompute_scale_factor: Optional[bool] = None
+    if fixed_size is not None:
+        size = [fixed_size[1], fixed_size[0]]
+    else:
+        if torch.jit.is_scripting() or torchvision._is_tracing():
+            min_size = torch.min(im_shape).to(dtype=torch.float32)
+            max_size = torch.max(im_shape).to(dtype=torch.float32)
+            self_min_size_f = float(self_min_size)
+            self_max_size_f = float(self_max_size)
+            scale = torch.min(self_min_size_f / min_size, self_max_size_f / max_size)
+
+            if torchvision._is_tracing():
+                scale_factor = _fake_cast_onnx(scale)
+            else:
+                scale_factor = scale.item()
+
+        else:
+            # Do it the normal way
+            min_size = min(im_shape)
+            max_size = max(im_shape)
+            scale_factor = min(self_min_size / min_size, self_max_size / max_size)
+
+        recompute_scale_factor = True
+
+    image = torch.nn.functional.interpolate(
+        image[None],
+        size=size,
+        scale_factor=scale_factor,
+        mode="bilinear",
+        recompute_scale_factor=recompute_scale_factor,
+        align_corners=False,
+    )[0]
+
+    if target is None:
+        return image, target
+
+    if "masks" in target:
+        mask = target["masks"]
+        mask = torch.nn.functional.interpolate(
+            mask[:, None].float(), size=size, scale_factor=scale_factor, recompute_scale_factor=recompute_scale_factor
+        )[:, 0].byte()
+        target["masks"] = mask
+    return image, target
+
+
+class GeneralizedRCNNTransform(nn.Module):
+    """
+    Performs input / target transformation before feeding the data to a GeneralizedRCNN
+    model.
+
+    The transformations it performs are:
+        - input normalization (mean subtraction and std division)
+        - input / target resizing to match min_size / max_size
+
+    It returns a ImageList for the inputs, and a List[Dict[Tensor]] for the targets
+    """
+
+    def __init__(
+        self,
+        min_size: int,
+        max_size: int,
+        image_mean: List[float],
+        image_std: List[float],
+        size_divisible: int = 32,
+        fixed_size: Optional[Tuple[int, int]] = None,
+        **kwargs: Any,
+    ):
+        super().__init__()
+        if not isinstance(min_size, (list, tuple)):
+            min_size = (min_size,)
+        self.min_size = min_size
+        self.max_size = max_size
+        self.image_mean = image_mean
+        self.image_std = image_std
+        self.size_divisible = size_divisible
+        self.fixed_size = fixed_size
+        self._skip_resize = kwargs.pop("_skip_resize", False)
+
+    def forward(
+        self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
+    ) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]:
+        images = [img for img in images]
+        if targets is not None:
+            # make a copy of targets to avoid modifying it in-place
+            # once torchscript supports dict comprehension
+            # this can be simplified as follows
+            # targets = [{k: v for k,v in t.items()} for t in targets]
+            targets_copy: List[Dict[str, Tensor]] = []
+            for t in targets:
+                data: Dict[str, Tensor] = {}
+                for k, v in t.items():
+                    data[k] = v
+                targets_copy.append(data)
+            targets = targets_copy
+        for i in range(len(images)):
+            image = images[i]
+            target_index = targets[i] if targets is not None else None
+
+            if image.dim() != 3:
+                raise ValueError(f"images is expected to be a list of 3d tensors of shape [C, H, W], got {image.shape}")
+            image = self.normalize(image)
+            image, target_index = self.resize(image, target_index)
+            images[i] = image
+            if targets is not None and target_index is not None:
+                targets[i] = target_index
+
+        image_sizes = [img.shape[-2:] for img in images]
+        images = self.batch_images(images, size_divisible=self.size_divisible)
+        image_sizes_list: List[Tuple[int, int]] = []
+        for image_size in image_sizes:
+            torch._assert(
+                len(image_size) == 2,
+                f"Input tensors expected to have in the last two elements H and W, instead got {image_size}",
+            )
+            image_sizes_list.append((image_size[0], image_size[1]))
+
+        image_list = ImageList(images, image_sizes_list)
+        return image_list, targets
+
+    def normalize(self, image: Tensor) -> Tensor:
+        if not image.is_floating_point():
+            raise TypeError(
+                f"Expected input images to be of floating type (in range [0, 1]), "
+                f"but found type {image.dtype} instead"
+            )
+        dtype, device = image.dtype, image.device
+        mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
+        std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
+        return (image - mean[:, None, None]) / std[:, None, None]
+
+    def torch_choice(self, k: List[int]) -> int:
+        """
+        Implements `random.choice` via torch ops, so it can be compiled with
+        TorchScript and we use PyTorch's RNG (not native RNG)
+        """
+        index = int(torch.empty(1).uniform_(0.0, float(len(k))).item())
+        return k[index]
+
+    def resize(
+        self,
+        image: Tensor,
+        target: Optional[Dict[str, Tensor]] = None,
+    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+        h, w = image.shape[-2:]
+        if self.training:
+            if self._skip_resize:
+                return image, target
+            size = self.torch_choice(self.min_size)
+        else:
+            size = self.min_size[-1]
+        image, target = _resize_image_and_masks(image, size, self.max_size, target, self.fixed_size)
+
+        if target is None:
+            return image, target
+
+        bbox = target["boxes"]
+        bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
+        target["boxes"] = bbox
+
+        if "keypoints" in target:
+            keypoints = target["keypoints"]
+            keypoints = resize_keypoints(keypoints, (h, w), image.shape[-2:])
+            target["keypoints"] = keypoints
+        return image, target
+
+    # _onnx_batch_images() is an implementation of
+    # batch_images() that is supported by ONNX tracing.
+    @torch.jit.unused
+    def _onnx_batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
+        max_size = []
+        for i in range(images[0].dim()):
+            max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
+            max_size.append(max_size_i)
+        stride = size_divisible
+        max_size[1] = (torch.ceil((max_size[1].to(torch.float32)) / stride) * stride).to(torch.int64)
+        max_size[2] = (torch.ceil((max_size[2].to(torch.float32)) / stride) * stride).to(torch.int64)
+        max_size = tuple(max_size)
+
+        # work around for
+        # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+        # which is not yet supported in onnx
+        padded_imgs = []
+        for img in images:
+            padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
+            padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
+            padded_imgs.append(padded_img)
+
+        return torch.stack(padded_imgs)
+
+    def max_by_axis(self, the_list: List[List[int]]) -> List[int]:
+        maxes = the_list[0]
+        for sublist in the_list[1:]:
+            for index, item in enumerate(sublist):
+                maxes[index] = max(maxes[index], item)
+        return maxes
+
+    def batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
+        if torchvision._is_tracing():
+            # batch_images() does not export well to ONNX
+            # call _onnx_batch_images() instead
+            return self._onnx_batch_images(images, size_divisible)
+
+        max_size = self.max_by_axis([list(img.shape) for img in images])
+        stride = float(size_divisible)
+        max_size = list(max_size)
+        max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride)
+        max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride)
+
+        batch_shape = [len(images)] + max_size
+        batched_imgs = images[0].new_full(batch_shape, 0)
+        for i in range(batched_imgs.shape[0]):
+            img = images[i]
+            batched_imgs[i, : img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+
+        return batched_imgs
+
+    def postprocess(
+        self,
+        result: List[Dict[str, Tensor]],
+        image_shapes: List[Tuple[int, int]],
+        original_image_sizes: List[Tuple[int, int]],
+    ) -> List[Dict[str, Tensor]]:
+        if self.training:
+            return result
+        for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
+            boxes = pred["boxes"]
+            boxes = resize_boxes(boxes, im_s, o_im_s)
+            result[i]["boxes"] = boxes
+            if "masks" in pred:
+                masks = pred["masks"]
+                masks = paste_masks_in_image(masks, boxes, o_im_s)
+                result[i]["masks"] = masks
+            if "keypoints" in pred:
+                keypoints = pred["keypoints"]
+                keypoints = resize_keypoints(keypoints, im_s, o_im_s)
+                result[i]["keypoints"] = keypoints
+        return result
+
+    def __repr__(self) -> str:
+        format_string = f"{self.__class__.__name__}("
+        _indent = "\n    "
+        format_string += f"{_indent}Normalize(mean={self.image_mean}, std={self.image_std})"
+        format_string += f"{_indent}Resize(min_size={self.min_size}, max_size={self.max_size}, mode='bilinear')"
+        format_string += "\n)"
+        return format_string
+
+
+def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
+    ratios = [
+        torch.tensor(s, dtype=torch.float32, device=keypoints.device)
+        / torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device)
+        for s, s_orig in zip(new_size, original_size)
+    ]
+    ratio_h, ratio_w = ratios
+    resized_data = keypoints.clone()
+    if torch._C._get_tracing_state():
+        resized_data_0 = resized_data[:, :, 0] * ratio_w
+        resized_data_1 = resized_data[:, :, 1] * ratio_h
+        resized_data = torch.stack((resized_data_0, resized_data_1, resized_data[:, :, 2]), dim=2)
+    else:
+        resized_data[..., 0] *= ratio_w
+        resized_data[..., 1] *= ratio_h
+    return resized_data
+
+
+def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
+    ratios = [
+        torch.tensor(s, dtype=torch.float32, device=boxes.device)
+        / torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
+        for s, s_orig in zip(new_size, original_size)
+    ]
+    ratio_height, ratio_width = ratios
+    xmin, ymin, xmax, ymax = boxes.unbind(1)
+
+    xmin = xmin * ratio_width
+    xmax = xmax * ratio_width
+    ymin = ymin * ratio_height
+    ymax = ymax * ratio_height
+    return torch.stack((xmin, ymin, xmax, ymax), dim=1)

+ 1131 - 0
libs/vision_libs/models/efficientnet.py

@@ -0,0 +1,1131 @@
+import copy
+import math
+from dataclasses import dataclass
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+
+import torch
+from torch import nn, Tensor
+from torchvision.ops import StochasticDepth
+
+from ..ops.misc import Conv2dNormActivation, SqueezeExcitation
+from ..transforms._presets import ImageClassification, InterpolationMode
+from ..utils import _log_api_usage_once
+from ._api import register_model, Weights, WeightsEnum
+from ._meta import _IMAGENET_CATEGORIES
+from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface
+
+
+__all__ = [
+    "EfficientNet",
+    "EfficientNet_B0_Weights",
+    "EfficientNet_B1_Weights",
+    "EfficientNet_B2_Weights",
+    "EfficientNet_B3_Weights",
+    "EfficientNet_B4_Weights",
+    "EfficientNet_B5_Weights",
+    "EfficientNet_B6_Weights",
+    "EfficientNet_B7_Weights",
+    "EfficientNet_V2_S_Weights",
+    "EfficientNet_V2_M_Weights",
+    "EfficientNet_V2_L_Weights",
+    "efficientnet_b0",
+    "efficientnet_b1",
+    "efficientnet_b2",
+    "efficientnet_b3",
+    "efficientnet_b4",
+    "efficientnet_b5",
+    "efficientnet_b6",
+    "efficientnet_b7",
+    "efficientnet_v2_s",
+    "efficientnet_v2_m",
+    "efficientnet_v2_l",
+]
+
+
+@dataclass
+class _MBConvConfig:
+    expand_ratio: float
+    kernel: int
+    stride: int
+    input_channels: int
+    out_channels: int
+    num_layers: int
+    block: Callable[..., nn.Module]
+
+    @staticmethod
+    def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int:
+        return _make_divisible(channels * width_mult, 8, min_value)
+
+
+class MBConvConfig(_MBConvConfig):
+    # Stores information listed at Table 1 of the EfficientNet paper & Table 4 of the EfficientNetV2 paper
+    def __init__(
+        self,
+        expand_ratio: float,
+        kernel: int,
+        stride: int,
+        input_channels: int,
+        out_channels: int,
+        num_layers: int,
+        width_mult: float = 1.0,
+        depth_mult: float = 1.0,
+        block: Optional[Callable[..., nn.Module]] = None,
+    ) -> None:
+        input_channels = self.adjust_channels(input_channels, width_mult)
+        out_channels = self.adjust_channels(out_channels, width_mult)
+        num_layers = self.adjust_depth(num_layers, depth_mult)
+        if block is None:
+            block = MBConv
+        super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block)
+
+    @staticmethod
+    def adjust_depth(num_layers: int, depth_mult: float):
+        return int(math.ceil(num_layers * depth_mult))
+
+
+class FusedMBConvConfig(_MBConvConfig):
+    # Stores information listed at Table 4 of the EfficientNetV2 paper
+    def __init__(
+        self,
+        expand_ratio: float,
+        kernel: int,
+        stride: int,
+        input_channels: int,
+        out_channels: int,
+        num_layers: int,
+        block: Optional[Callable[..., nn.Module]] = None,
+    ) -> None:
+        if block is None:
+            block = FusedMBConv
+        super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block)
+
+
+class MBConv(nn.Module):
+    def __init__(
+        self,
+        cnf: MBConvConfig,
+        stochastic_depth_prob: float,
+        norm_layer: Callable[..., nn.Module],
+        se_layer: Callable[..., nn.Module] = SqueezeExcitation,
+    ) -> None:
+        super().__init__()
+
+        if not (1 <= cnf.stride <= 2):
+            raise ValueError("illegal stride value")
+
+        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
+
+        layers: List[nn.Module] = []
+        activation_layer = nn.SiLU
+
+        # expand
+        expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
+        if expanded_channels != cnf.input_channels:
+            layers.append(
+                Conv2dNormActivation(
+                    cnf.input_channels,
+                    expanded_channels,
+                    kernel_size=1,
+                    norm_layer=norm_layer,
+                    activation_layer=activation_layer,
+                )
+            )
+
+        # depthwise
+        layers.append(
+            Conv2dNormActivation(
+                expanded_channels,
+                expanded_channels,
+                kernel_size=cnf.kernel,
+                stride=cnf.stride,
+                groups=expanded_channels,
+                norm_layer=norm_layer,
+                activation_layer=activation_layer,
+            )
+        )
+
+        # squeeze and excitation
+        squeeze_channels = max(1, cnf.input_channels // 4)
+        layers.append(se_layer(expanded_channels, squeeze_channels, activation=partial(nn.SiLU, inplace=True)))
+
+        # project
+        layers.append(
+            Conv2dNormActivation(
+                expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
+            )
+        )
+
+        self.block = nn.Sequential(*layers)
+        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
+        self.out_channels = cnf.out_channels
+
+    def forward(self, input: Tensor) -> Tensor:
+        result = self.block(input)
+        if self.use_res_connect:
+            result = self.stochastic_depth(result)
+            result += input
+        return result
+
+
+class FusedMBConv(nn.Module):
+    def __init__(
+        self,
+        cnf: FusedMBConvConfig,
+        stochastic_depth_prob: float,
+        norm_layer: Callable[..., nn.Module],
+    ) -> None:
+        super().__init__()
+
+        if not (1 <= cnf.stride <= 2):
+            raise ValueError("illegal stride value")
+
+        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
+
+        layers: List[nn.Module] = []
+        activation_layer = nn.SiLU
+
+        expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
+        if expanded_channels != cnf.input_channels:
+            # fused expand
+            layers.append(
+                Conv2dNormActivation(
+                    cnf.input_channels,
+                    expanded_channels,
+                    kernel_size=cnf.kernel,
+                    stride=cnf.stride,
+                    norm_layer=norm_layer,
+                    activation_layer=activation_layer,
+                )
+            )
+
+            # project
+            layers.append(
+                Conv2dNormActivation(
+                    expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
+                )
+            )
+        else:
+            layers.append(
+                Conv2dNormActivation(
+                    cnf.input_channels,
+                    cnf.out_channels,
+                    kernel_size=cnf.kernel,
+                    stride=cnf.stride,
+                    norm_layer=norm_layer,
+                    activation_layer=activation_layer,
+                )
+            )
+
+        self.block = nn.Sequential(*layers)
+        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
+        self.out_channels = cnf.out_channels
+
+    def forward(self, input: Tensor) -> Tensor:
+        result = self.block(input)
+        if self.use_res_connect:
+            result = self.stochastic_depth(result)
+            result += input
+        return result
+
+
+class EfficientNet(nn.Module):
+    def __init__(
+        self,
+        inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
+        dropout: float,
+        stochastic_depth_prob: float = 0.2,
+        num_classes: int = 1000,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        last_channel: Optional[int] = None,
+    ) -> None:
+        """
+        EfficientNet V1 and V2 main class
+
+        Args:
+            inverted_residual_setting (Sequence[Union[MBConvConfig, FusedMBConvConfig]]): Network structure
+            dropout (float): The droupout probability
+            stochastic_depth_prob (float): The stochastic depth probability
+            num_classes (int): Number of classes
+            norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
+            last_channel (int): The number of channels on the penultimate layer
+        """
+        super().__init__()
+        _log_api_usage_once(self)
+
+        if not inverted_residual_setting:
+            raise ValueError("The inverted_residual_setting should not be empty")
+        elif not (
+            isinstance(inverted_residual_setting, Sequence)
+            and all([isinstance(s, _MBConvConfig) for s in inverted_residual_setting])
+        ):
+            raise TypeError("The inverted_residual_setting should be List[MBConvConfig]")
+
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+
+        layers: List[nn.Module] = []
+
+        # building first layer
+        firstconv_output_channels = inverted_residual_setting[0].input_channels
+        layers.append(
+            Conv2dNormActivation(
+                3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU
+            )
+        )
+
+        # building inverted residual blocks
+        total_stage_blocks = sum(cnf.num_layers for cnf in inverted_residual_setting)
+        stage_block_id = 0
+        for cnf in inverted_residual_setting:
+            stage: List[nn.Module] = []
+            for _ in range(cnf.num_layers):
+                # copy to avoid modifications. shallow copy is enough
+                block_cnf = copy.copy(cnf)
+
+                # overwrite info if not the first conv in the stage
+                if stage:
+                    block_cnf.input_channels = block_cnf.out_channels
+                    block_cnf.stride = 1
+
+                # adjust stochastic depth probability based on the depth of the stage block
+                sd_prob = stochastic_depth_prob * float(stage_block_id) / total_stage_blocks
+
+                stage.append(block_cnf.block(block_cnf, sd_prob, norm_layer))
+                stage_block_id += 1
+
+            layers.append(nn.Sequential(*stage))
+
+        # building last several layers
+        lastconv_input_channels = inverted_residual_setting[-1].out_channels
+        lastconv_output_channels = last_channel if last_channel is not None else 4 * lastconv_input_channels
+        layers.append(
+            Conv2dNormActivation(
+                lastconv_input_channels,
+                lastconv_output_channels,
+                kernel_size=1,
+                norm_layer=norm_layer,
+                activation_layer=nn.SiLU,
+            )
+        )
+
+        self.features = nn.Sequential(*layers)
+        self.avgpool = nn.AdaptiveAvgPool2d(1)
+        self.classifier = nn.Sequential(
+            nn.Dropout(p=dropout, inplace=True),
+            nn.Linear(lastconv_output_channels, num_classes),
+        )
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out")
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Linear):
+                init_range = 1.0 / math.sqrt(m.out_features)
+                nn.init.uniform_(m.weight, -init_range, init_range)
+                nn.init.zeros_(m.bias)
+
+    def _forward_impl(self, x: Tensor) -> Tensor:
+        x = self.features(x)
+
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+
+        x = self.classifier(x)
+
+        return x
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self._forward_impl(x)
+
+
+def _efficientnet(
+    inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
+    dropout: float,
+    last_channel: Optional[int],
+    weights: Optional[WeightsEnum],
+    progress: bool,
+    **kwargs: Any,
+) -> EfficientNet:
+    if weights is not None:
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+
+    model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model
+
+
+def _efficientnet_conf(
+    arch: str,
+    **kwargs: Any,
+) -> Tuple[Sequence[Union[MBConvConfig, FusedMBConvConfig]], Optional[int]]:
+    inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]]
+    if arch.startswith("efficientnet_b"):
+        bneck_conf = partial(MBConvConfig, width_mult=kwargs.pop("width_mult"), depth_mult=kwargs.pop("depth_mult"))
+        inverted_residual_setting = [
+            bneck_conf(1, 3, 1, 32, 16, 1),
+            bneck_conf(6, 3, 2, 16, 24, 2),
+            bneck_conf(6, 5, 2, 24, 40, 2),
+            bneck_conf(6, 3, 2, 40, 80, 3),
+            bneck_conf(6, 5, 1, 80, 112, 3),
+            bneck_conf(6, 5, 2, 112, 192, 4),
+            bneck_conf(6, 3, 1, 192, 320, 1),
+        ]
+        last_channel = None
+    elif arch.startswith("efficientnet_v2_s"):
+        inverted_residual_setting = [
+            FusedMBConvConfig(1, 3, 1, 24, 24, 2),
+            FusedMBConvConfig(4, 3, 2, 24, 48, 4),
+            FusedMBConvConfig(4, 3, 2, 48, 64, 4),
+            MBConvConfig(4, 3, 2, 64, 128, 6),
+            MBConvConfig(6, 3, 1, 128, 160, 9),
+            MBConvConfig(6, 3, 2, 160, 256, 15),
+        ]
+        last_channel = 1280
+    elif arch.startswith("efficientnet_v2_m"):
+        inverted_residual_setting = [
+            FusedMBConvConfig(1, 3, 1, 24, 24, 3),
+            FusedMBConvConfig(4, 3, 2, 24, 48, 5),
+            FusedMBConvConfig(4, 3, 2, 48, 80, 5),
+            MBConvConfig(4, 3, 2, 80, 160, 7),
+            MBConvConfig(6, 3, 1, 160, 176, 14),
+            MBConvConfig(6, 3, 2, 176, 304, 18),
+            MBConvConfig(6, 3, 1, 304, 512, 5),
+        ]
+        last_channel = 1280
+    elif arch.startswith("efficientnet_v2_l"):
+        inverted_residual_setting = [
+            FusedMBConvConfig(1, 3, 1, 32, 32, 4),
+            FusedMBConvConfig(4, 3, 2, 32, 64, 7),
+            FusedMBConvConfig(4, 3, 2, 64, 96, 7),
+            MBConvConfig(4, 3, 2, 96, 192, 10),
+            MBConvConfig(6, 3, 1, 192, 224, 19),
+            MBConvConfig(6, 3, 2, 224, 384, 25),
+            MBConvConfig(6, 3, 1, 384, 640, 7),
+        ]
+        last_channel = 1280
+    else:
+        raise ValueError(f"Unsupported model type {arch}")
+
+    return inverted_residual_setting, last_channel
+
+
+_COMMON_META: Dict[str, Any] = {
+    "categories": _IMAGENET_CATEGORIES,
+}
+
+
+_COMMON_META_V1 = {
+    **_COMMON_META,
+    "min_size": (1, 1),
+    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet-v1",
+}
+
+
+_COMMON_META_V2 = {
+    **_COMMON_META,
+    "min_size": (33, 33),
+    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet-v2",
+}
+
+
+class EfficientNet_B0_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        # Weights ported from https://github.com/rwightman/pytorch-image-models/
+        url="https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth",
+        transforms=partial(
+            ImageClassification, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC
+        ),
+        meta={
+            **_COMMON_META_V1,
+            "num_params": 5288548,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 77.692,
+                    "acc@5": 93.532,
+                }
+            },
+            "_ops": 0.386,
+            "_file_size": 20.451,
+            "_docs": """These weights are ported from the original paper.""",
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class EfficientNet_B1_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        # Weights ported from https://github.com/rwightman/pytorch-image-models/
+        url="https://download.pytorch.org/models/efficientnet_b1_rwightman-bac287d4.pth",
+        transforms=partial(
+            ImageClassification, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC
+        ),
+        meta={
+            **_COMMON_META_V1,
+            "num_params": 7794184,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 78.642,
+                    "acc@5": 94.186,
+                }
+            },
+            "_ops": 0.687,
+            "_file_size": 30.134,
+            "_docs": """These weights are ported from the original paper.""",
+        },
+    )
+    IMAGENET1K_V2 = Weights(
+        url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth",
+        transforms=partial(
+            ImageClassification, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR
+        ),
+        meta={
+            **_COMMON_META_V1,
+            "num_params": 7794184,
+            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-lr-wd-crop-tuning",
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 79.838,
+                    "acc@5": 94.934,
+                }
+            },
+            "_ops": 0.687,
+            "_file_size": 30.136,
+            "_docs": """
+                These weights improve upon the results of the original paper by using a modified version of TorchVision's
+                `new training recipe
+                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
+            """,
+        },
+    )
+    DEFAULT = IMAGENET1K_V2
+
+
+class EfficientNet_B2_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        # Weights ported from https://github.com/rwightman/pytorch-image-models/
+        url="https://download.pytorch.org/models/efficientnet_b2_rwightman-c35c1473.pth",
+        transforms=partial(
+            ImageClassification, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC
+        ),
+        meta={
+            **_COMMON_META_V1,
+            "num_params": 9109994,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 80.608,
+                    "acc@5": 95.310,
+                }
+            },
+            "_ops": 1.088,
+            "_file_size": 35.174,
+            "_docs": """These weights are ported from the original paper.""",
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class EfficientNet_B3_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        # Weights ported from https://github.com/rwightman/pytorch-image-models/
+        url="https://download.pytorch.org/models/efficientnet_b3_rwightman-b3899882.pth",
+        transforms=partial(
+            ImageClassification, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC
+        ),
+        meta={
+            **_COMMON_META_V1,
+            "num_params": 12233232,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 82.008,
+                    "acc@5": 96.054,
+                }
+            },
+            "_ops": 1.827,
+            "_file_size": 47.184,
+            "_docs": """These weights are ported from the original paper.""",
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class EfficientNet_B4_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        # Weights ported from https://github.com/rwightman/pytorch-image-models/
+        url="https://download.pytorch.org/models/efficientnet_b4_rwightman-23ab8bcd.pth",
+        transforms=partial(
+            ImageClassification, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC
+        ),
+        meta={
+            **_COMMON_META_V1,
+            "num_params": 19341616,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 83.384,
+                    "acc@5": 96.594,
+                }
+            },
+            "_ops": 4.394,
+            "_file_size": 74.489,
+            "_docs": """These weights are ported from the original paper.""",
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class EfficientNet_B5_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
+        url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-1a07897c.pth",
+        transforms=partial(
+            ImageClassification, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC
+        ),
+        meta={
+            **_COMMON_META_V1,
+            "num_params": 30389784,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 83.444,
+                    "acc@5": 96.628,
+                }
+            },
+            "_ops": 10.266,
+            "_file_size": 116.864,
+            "_docs": """These weights are ported from the original paper.""",
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class EfficientNet_B6_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
+        url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-24a108a5.pth",
+        transforms=partial(
+            ImageClassification, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC
+        ),
+        meta={
+            **_COMMON_META_V1,
+            "num_params": 43040704,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 84.008,
+                    "acc@5": 96.916,
+                }
+            },
+            "_ops": 19.068,
+            "_file_size": 165.362,
+            "_docs": """These weights are ported from the original paper.""",
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class EfficientNet_B7_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
+        url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-c5b4e57e.pth",
+        transforms=partial(
+            ImageClassification, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC
+        ),
+        meta={
+            **_COMMON_META_V1,
+            "num_params": 66347960,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 84.122,
+                    "acc@5": 96.908,
+                }
+            },
+            "_ops": 37.746,
+            "_file_size": 254.675,
+            "_docs": """These weights are ported from the original paper.""",
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class EfficientNet_V2_S_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth",
+        transforms=partial(
+            ImageClassification,
+            crop_size=384,
+            resize_size=384,
+            interpolation=InterpolationMode.BILINEAR,
+        ),
+        meta={
+            **_COMMON_META_V2,
+            "num_params": 21458488,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 84.228,
+                    "acc@5": 96.878,
+                }
+            },
+            "_ops": 8.366,
+            "_file_size": 82.704,
+            "_docs": """
+                These weights improve upon the results of the original paper by using a modified version of TorchVision's
+                `new training recipe
+                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
+            """,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class EfficientNet_V2_M_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth",
+        transforms=partial(
+            ImageClassification,
+            crop_size=480,
+            resize_size=480,
+            interpolation=InterpolationMode.BILINEAR,
+        ),
+        meta={
+            **_COMMON_META_V2,
+            "num_params": 54139356,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 85.112,
+                    "acc@5": 97.156,
+                }
+            },
+            "_ops": 24.582,
+            "_file_size": 208.01,
+            "_docs": """
+                These weights improve upon the results of the original paper by using a modified version of TorchVision's
+                `new training recipe
+                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
+            """,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class EfficientNet_V2_L_Weights(WeightsEnum):
+    # Weights ported from https://github.com/google/automl/tree/master/efficientnetv2
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth",
+        transforms=partial(
+            ImageClassification,
+            crop_size=480,
+            resize_size=480,
+            interpolation=InterpolationMode.BICUBIC,
+            mean=(0.5, 0.5, 0.5),
+            std=(0.5, 0.5, 0.5),
+        ),
+        meta={
+            **_COMMON_META_V2,
+            "num_params": 118515272,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 85.808,
+                    "acc@5": 97.788,
+                }
+            },
+            "_ops": 56.08,
+            "_file_size": 454.573,
+            "_docs": """These weights are ported from the original paper.""",
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.IMAGENET1K_V1))
+def efficientnet_b0(
+    *, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+    """EfficientNet B0 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
+    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.
+
+    Args:
+        weights (:class:`~torchvision.models.EfficientNet_B0_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.EfficientNet_B0_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
+            for more details about this class.
+    .. autoclass:: torchvision.models.EfficientNet_B0_Weights
+        :members:
+    """
+    weights = EfficientNet_B0_Weights.verify(weights)
+
+    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b0", width_mult=1.0, depth_mult=1.0)
+    return _efficientnet(
+        inverted_residual_setting, kwargs.pop("dropout", 0.2), last_channel, weights, progress, **kwargs
+    )
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.IMAGENET1K_V1))
+def efficientnet_b1(
+    *, weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+    """EfficientNet B1 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
+    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.
+
+    Args:
+        weights (:class:`~torchvision.models.EfficientNet_B1_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.EfficientNet_B1_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
+            for more details about this class.
+    .. autoclass:: torchvision.models.EfficientNet_B1_Weights
+        :members:
+    """
+    weights = EfficientNet_B1_Weights.verify(weights)
+
+    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b1", width_mult=1.0, depth_mult=1.1)
+    return _efficientnet(
+        inverted_residual_setting, kwargs.pop("dropout", 0.2), last_channel, weights, progress, **kwargs
+    )
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.IMAGENET1K_V1))
+def efficientnet_b2(
+    *, weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+    """EfficientNet B2 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
+    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.
+
+    Args:
+        weights (:class:`~torchvision.models.EfficientNet_B2_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.EfficientNet_B2_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
+            for more details about this class.
+    .. autoclass:: torchvision.models.EfficientNet_B2_Weights
+        :members:
+    """
+    weights = EfficientNet_B2_Weights.verify(weights)
+
+    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b2", width_mult=1.1, depth_mult=1.2)
+    return _efficientnet(
+        inverted_residual_setting, kwargs.pop("dropout", 0.3), last_channel, weights, progress, **kwargs
+    )
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.IMAGENET1K_V1))
+def efficientnet_b3(
+    *, weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+    """EfficientNet B3 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
+    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.
+
+    Args:
+        weights (:class:`~torchvision.models.EfficientNet_B3_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.EfficientNet_B3_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
+            for more details about this class.
+    .. autoclass:: torchvision.models.EfficientNet_B3_Weights
+        :members:
+    """
+    weights = EfficientNet_B3_Weights.verify(weights)
+
+    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b3", width_mult=1.2, depth_mult=1.4)
+    return _efficientnet(
+        inverted_residual_setting,
+        kwargs.pop("dropout", 0.3),
+        last_channel,
+        weights,
+        progress,
+        **kwargs,
+    )
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.IMAGENET1K_V1))
+def efficientnet_b4(
+    *, weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+    """EfficientNet B4 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
+    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.
+
+    Args:
+        weights (:class:`~torchvision.models.EfficientNet_B4_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.EfficientNet_B4_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
+            for more details about this class.
+    .. autoclass:: torchvision.models.EfficientNet_B4_Weights
+        :members:
+    """
+    weights = EfficientNet_B4_Weights.verify(weights)
+
+    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b4", width_mult=1.4, depth_mult=1.8)
+    return _efficientnet(
+        inverted_residual_setting,
+        kwargs.pop("dropout", 0.4),
+        last_channel,
+        weights,
+        progress,
+        **kwargs,
+    )
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.IMAGENET1K_V1))
+def efficientnet_b5(
+    *, weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+    """EfficientNet B5 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
+    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.
+
+    Args:
+        weights (:class:`~torchvision.models.EfficientNet_B5_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.EfficientNet_B5_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
+            for more details about this class.
+    .. autoclass:: torchvision.models.EfficientNet_B5_Weights
+        :members:
+    """
+    weights = EfficientNet_B5_Weights.verify(weights)
+
+    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b5", width_mult=1.6, depth_mult=2.2)
+    return _efficientnet(
+        inverted_residual_setting,
+        kwargs.pop("dropout", 0.4),
+        last_channel,
+        weights,
+        progress,
+        norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
+        **kwargs,
+    )
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", EfficientNet_B6_Weights.IMAGENET1K_V1))
+def efficientnet_b6(
+    *, weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+    """EfficientNet B6 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
+    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.
+
+    Args:
+        weights (:class:`~torchvision.models.EfficientNet_B6_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.EfficientNet_B6_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
+            for more details about this class.
+    .. autoclass:: torchvision.models.EfficientNet_B6_Weights
+        :members:
+    """
+    weights = EfficientNet_B6_Weights.verify(weights)
+
+    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b6", width_mult=1.8, depth_mult=2.6)
+    return _efficientnet(
+        inverted_residual_setting,
+        kwargs.pop("dropout", 0.5),
+        last_channel,
+        weights,
+        progress,
+        norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
+        **kwargs,
+    )
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", EfficientNet_B7_Weights.IMAGENET1K_V1))
+def efficientnet_b7(
+    *, weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+    """EfficientNet B7 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
+    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.
+
+    Args:
+        weights (:class:`~torchvision.models.EfficientNet_B7_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.EfficientNet_B7_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
+            for more details about this class.
+    .. autoclass:: torchvision.models.EfficientNet_B7_Weights
+        :members:
+    """
+    weights = EfficientNet_B7_Weights.verify(weights)
+
+    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b7", width_mult=2.0, depth_mult=3.1)
+    return _efficientnet(
+        inverted_residual_setting,
+        kwargs.pop("dropout", 0.5),
+        last_channel,
+        weights,
+        progress,
+        norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
+        **kwargs,
+    )
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1))
+def efficientnet_v2_s(
+    *, weights: Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+    """
+    Constructs an EfficientNetV2-S architecture from
+    `EfficientNetV2: Smaller Models and Faster Training <https://arxiv.org/abs/2104.00298>`_.
+
+    Args:
+        weights (:class:`~torchvision.models.EfficientNet_V2_S_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.EfficientNet_V2_S_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
+            for more details about this class.
+    .. autoclass:: torchvision.models.EfficientNet_V2_S_Weights
+        :members:
+    """
+    weights = EfficientNet_V2_S_Weights.verify(weights)
+
+    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_s")
+    return _efficientnet(
+        inverted_residual_setting,
+        kwargs.pop("dropout", 0.2),
+        last_channel,
+        weights,
+        progress,
+        norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
+        **kwargs,
+    )
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_M_Weights.IMAGENET1K_V1))
+def efficientnet_v2_m(
+    *, weights: Optional[EfficientNet_V2_M_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+    """
+    Constructs an EfficientNetV2-M architecture from
+    `EfficientNetV2: Smaller Models and Faster Training <https://arxiv.org/abs/2104.00298>`_.
+
+    Args:
+        weights (:class:`~torchvision.models.EfficientNet_V2_M_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.EfficientNet_V2_M_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
+            for more details about this class.
+    .. autoclass:: torchvision.models.EfficientNet_V2_M_Weights
+        :members:
+    """
+    weights = EfficientNet_V2_M_Weights.verify(weights)
+
+    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_m")
+    return _efficientnet(
+        inverted_residual_setting,
+        kwargs.pop("dropout", 0.3),
+        last_channel,
+        weights,
+        progress,
+        norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
+        **kwargs,
+    )
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_L_Weights.IMAGENET1K_V1))
+def efficientnet_v2_l(
+    *, weights: Optional[EfficientNet_V2_L_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+    """
+    Constructs an EfficientNetV2-L architecture from
+    `EfficientNetV2: Smaller Models and Faster Training <https://arxiv.org/abs/2104.00298>`_.
+
+    Args:
+        weights (:class:`~torchvision.models.EfficientNet_V2_L_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.EfficientNet_V2_L_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
+            for more details about this class.
+    .. autoclass:: torchvision.models.EfficientNet_V2_L_Weights
+        :members:
+    """
+    weights = EfficientNet_V2_L_Weights.verify(weights)
+
+    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_l")
+    return _efficientnet(
+        inverted_residual_setting,
+        kwargs.pop("dropout", 0.4),
+        last_channel,
+        weights,
+        progress,
+        norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
+        **kwargs,
+    )

+ 563 - 0
libs/vision_libs/models/feature_extraction.py

@@ -0,0 +1,563 @@
+import inspect
+import math
+import re
+import warnings
+from collections import OrderedDict
+from copy import deepcopy
+from itertools import chain
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import torchvision
+from torch import fx, nn
+from torch.fx.graph_module import _copy_attr
+
+
+__all__ = ["create_feature_extractor", "get_graph_node_names"]
+
+
+class LeafModuleAwareTracer(fx.Tracer):
+    """
+    An fx.Tracer that allows the user to specify a set of leaf modules, i.e.
+    modules that are not to be traced through. The resulting graph ends up
+    having single nodes referencing calls to the leaf modules' forward methods.
+    """
+
+    def __init__(self, *args, **kwargs):
+        self.leaf_modules = {}
+        if "leaf_modules" in kwargs:
+            leaf_modules = kwargs.pop("leaf_modules")
+            self.leaf_modules = leaf_modules
+        super().__init__(*args, **kwargs)
+
+    def is_leaf_module(self, m: nn.Module, module_qualname: str) -> bool:
+        if isinstance(m, tuple(self.leaf_modules)):
+            return True
+        return super().is_leaf_module(m, module_qualname)
+
+
+class NodePathTracer(LeafModuleAwareTracer):
+    """
+    NodePathTracer is an FX tracer that, for each operation, also records the
+    name of the Node from which the operation originated. A node name here is
+    a `.` separated path walking the hierarchy from top level module down to
+    leaf operation or leaf module. The name of the top level module is not
+    included as part of the node name. For example, if we trace a module whose
+    forward method applies a ReLU module, the name for that node will simply
+    be 'relu'.
+
+    Some notes on the specifics:
+        - Nodes are recorded to `self.node_to_qualname` which is a dictionary
+          mapping a given Node object to its node name.
+        - Nodes are recorded in the order which they are executed during
+          tracing.
+        - When a duplicate node name is encountered, a suffix of the form
+          _{int} is added. The counter starts from 1.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Track the qualified name of the Node being traced
+        self.current_module_qualname = ""
+        # A map from FX Node to the qualified name\#
+        # NOTE: This is loosely like the "qualified name" mentioned in the
+        # torch.fx docs https://pytorch.org/docs/stable/fx.html but adapted
+        # for the purposes of the torchvision feature extractor
+        self.node_to_qualname = OrderedDict()
+
+    def call_module(self, m: torch.nn.Module, forward: Callable, args, kwargs):
+        """
+        Override of `fx.Tracer.call_module`
+        This override:
+        1) Stores away the qualified name of the caller for restoration later
+        2) Adds the qualified name of the caller to
+           `current_module_qualname` for retrieval by `create_proxy`
+        3) Once a leaf module is reached, calls `create_proxy`
+        4) Restores the caller's qualified name into current_module_qualname
+        """
+        old_qualname = self.current_module_qualname
+        try:
+            module_qualname = self.path_of_module(m)
+            self.current_module_qualname = module_qualname
+            if not self.is_leaf_module(m, module_qualname):
+                out = forward(*args, **kwargs)
+                return out
+            return self.create_proxy("call_module", module_qualname, args, kwargs)
+        finally:
+            self.current_module_qualname = old_qualname
+
+    def create_proxy(
+        self, kind: str, target: fx.node.Target, args, kwargs, name=None, type_expr=None, *_
+    ) -> fx.proxy.Proxy:
+        """
+        Override of `Tracer.create_proxy`. This override intercepts the recording
+        of every operation and stores away the current traced module's qualified
+        name in `node_to_qualname`
+        """
+        proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr)
+        self.node_to_qualname[proxy.node] = self._get_node_qualname(self.current_module_qualname, proxy.node)
+        return proxy
+
+    def _get_node_qualname(self, module_qualname: str, node: fx.node.Node) -> str:
+        node_qualname = module_qualname
+
+        if node.op != "call_module":
+            # In this case module_qualname from torch.fx doesn't go all the
+            # way to the leaf function/op, so we need to append it
+            if len(node_qualname) > 0:
+                # Only append '.' if we are deeper than the top level module
+                node_qualname += "."
+            node_qualname += str(node)
+
+        # Now we need to add an _{index} postfix on any repeated node names
+        # For modules we do this from scratch
+        # But for anything else, torch.fx already has a globally scoped
+        # _{index} postfix. But we want it locally (relative to direct parent)
+        # scoped. So first we need to undo the torch.fx postfix
+        if re.match(r".+_[0-9]+$", node_qualname) is not None:
+            node_qualname = node_qualname.rsplit("_", 1)[0]
+
+        # ... and now we add on our own postfix
+        for existing_qualname in reversed(self.node_to_qualname.values()):
+            # Check to see if existing_qualname is of the form
+            # {node_qualname} or {node_qualname}_{int}
+            if re.match(rf"{node_qualname}(_[0-9]+)?$", existing_qualname) is not None:
+                postfix = existing_qualname.replace(node_qualname, "")
+                if len(postfix):
+                    # existing_qualname is of the form {node_qualname}_{int}
+                    next_index = int(postfix[1:]) + 1
+                else:
+                    # existing_qualname is of the form {node_qualname}
+                    next_index = 1
+                node_qualname += f"_{next_index}"
+                break
+
+        return node_qualname
+
+
+def _is_subseq(x, y):
+    """Check if y is a subsequence of x
+    https://stackoverflow.com/a/24017747/4391249
+    """
+    iter_x = iter(x)
+    return all(any(x_item == y_item for x_item in iter_x) for y_item in y)
+
+
+def _warn_graph_differences(train_tracer: NodePathTracer, eval_tracer: NodePathTracer):
+    """
+    Utility function for warning the user if there are differences between
+    the train graph nodes and the eval graph nodes.
+    """
+    train_nodes = list(train_tracer.node_to_qualname.values())
+    eval_nodes = list(eval_tracer.node_to_qualname.values())
+
+    if len(train_nodes) == len(eval_nodes) and all(t == e for t, e in zip(train_nodes, eval_nodes)):
+        return
+
+    suggestion_msg = (
+        "When choosing nodes for feature extraction, you may need to specify "
+        "output nodes for train and eval mode separately."
+    )
+
+    if _is_subseq(train_nodes, eval_nodes):
+        msg = (
+            "NOTE: The nodes obtained by tracing the model in eval mode "
+            "are a subsequence of those obtained in train mode. "
+        )
+    elif _is_subseq(eval_nodes, train_nodes):
+        msg = (
+            "NOTE: The nodes obtained by tracing the model in train mode "
+            "are a subsequence of those obtained in eval mode. "
+        )
+    else:
+        msg = "The nodes obtained by tracing the model in train mode are different to those obtained in eval mode. "
+    warnings.warn(msg + suggestion_msg)
+
+
+def _get_leaf_modules_for_ops() -> List[type]:
+    members = inspect.getmembers(torchvision.ops)
+    result = []
+    for _, obj in members:
+        if inspect.isclass(obj) and issubclass(obj, torch.nn.Module):
+            result.append(obj)
+    return result
+
+
+def _set_default_tracer_kwargs(original_tr_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
+    default_autowrap_modules = (math, torchvision.ops)
+    default_leaf_modules = _get_leaf_modules_for_ops()
+    result_tracer_kwargs = {} if original_tr_kwargs is None else original_tr_kwargs
+    result_tracer_kwargs["autowrap_modules"] = (
+        tuple(set(result_tracer_kwargs["autowrap_modules"] + default_autowrap_modules))
+        if "autowrap_modules" in result_tracer_kwargs
+        else default_autowrap_modules
+    )
+    result_tracer_kwargs["leaf_modules"] = (
+        list(set(result_tracer_kwargs["leaf_modules"] + default_leaf_modules))
+        if "leaf_modules" in result_tracer_kwargs
+        else default_leaf_modules
+    )
+    return result_tracer_kwargs
+
+
+def get_graph_node_names(
+    model: nn.Module,
+    tracer_kwargs: Optional[Dict[str, Any]] = None,
+    suppress_diff_warning: bool = False,
+) -> Tuple[List[str], List[str]]:
+    """
+    Dev utility to return node names in order of execution. See note on node
+    names under :func:`create_feature_extractor`. Useful for seeing which node
+    names are available for feature extraction. There are two reasons that
+    node names can't easily be read directly from the code for a model:
+
+        1. Not all submodules are traced through. Modules from ``torch.nn`` all
+           fall within this category.
+        2. Nodes representing the repeated application of the same operation
+           or leaf module get a ``_{counter}`` postfix.
+
+    The model is traced twice: once in train mode, and once in eval mode. Both
+    sets of node names are returned.
+
+    For more details on the node naming conventions used here, please see the
+    :ref:`relevant subheading <about-node-names>` in the
+    `documentation <https://pytorch.org/vision/stable/feature_extraction.html>`_.
+
+    Args:
+        model (nn.Module): model for which we'd like to print node names
+        tracer_kwargs (dict, optional): a dictionary of keyword arguments for
+            ``NodePathTracer`` (they are eventually passed onto
+            `torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
+            By default, it will be set to wrap and make leaf nodes all torchvision ops:
+            {"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),}
+            WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user
+            provided dictionary.
+
+        suppress_diff_warning (bool, optional): whether to suppress a warning
+            when there are discrepancies between the train and eval version of
+            the graph. Defaults to False.
+
+    Returns:
+        tuple(list, list): a list of node names from tracing the model in
+        train mode, and another from tracing the model in eval mode.
+
+    Examples::
+
+        >>> model = torchvision.models.resnet18()
+        >>> train_nodes, eval_nodes = get_graph_node_names(model)
+    """
+    tracer_kwargs = _set_default_tracer_kwargs(tracer_kwargs)
+    is_training = model.training
+    train_tracer = NodePathTracer(**tracer_kwargs)
+    train_tracer.trace(model.train())
+    eval_tracer = NodePathTracer(**tracer_kwargs)
+    eval_tracer.trace(model.eval())
+    train_nodes = list(train_tracer.node_to_qualname.values())
+    eval_nodes = list(eval_tracer.node_to_qualname.values())
+    if not suppress_diff_warning:
+        _warn_graph_differences(train_tracer, eval_tracer)
+    # Restore training state
+    model.train(is_training)
+    return train_nodes, eval_nodes
+
+
+class DualGraphModule(fx.GraphModule):
+    """
+    A derivative of `fx.GraphModule`. Differs in the following ways:
+    - Requires a train and eval version of the underlying graph
+    - Copies submodules according to the nodes of both train and eval graphs.
+    - Calling train(mode) switches between train graph and eval graph.
+    """
+
+    def __init__(
+        self, root: torch.nn.Module, train_graph: fx.Graph, eval_graph: fx.Graph, class_name: str = "GraphModule"
+    ):
+        """
+        Args:
+            root (nn.Module): module from which the copied module hierarchy is
+                built
+            train_graph (fx.Graph): the graph that should be used in train mode
+            eval_graph (fx.Graph): the graph that should be used in eval mode
+        """
+        super(fx.GraphModule, self).__init__()
+
+        self.__class__.__name__ = class_name
+
+        self.train_graph = train_graph
+        self.eval_graph = eval_graph
+
+        # Copy all get_attr and call_module ops (indicated by BOTH train and
+        # eval graphs)
+        for node in chain(iter(train_graph.nodes), iter(eval_graph.nodes)):
+            if node.op in ["get_attr", "call_module"]:
+                if not isinstance(node.target, str):
+                    raise TypeError(f"node.target should be of type str instead of {type(node.target)}")
+                _copy_attr(root, self, node.target)
+
+        # train mode by default
+        self.train()
+        self.graph = train_graph
+
+        # (borrowed from fx.GraphModule):
+        # Store the Tracer class responsible for creating a Graph separately as part of the
+        # GraphModule state, except when the Tracer is defined in a local namespace.
+        # Locally defined Tracers are not pickleable. This is needed because torch.package will
+        # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
+        # to re-create the Graph during deserialization.
+        if self.eval_graph._tracer_cls != self.train_graph._tracer_cls:
+            raise TypeError(
+                f"Train mode and eval mode should use the same tracer class. Instead got {self.eval_graph._tracer_cls} for eval vs {self.train_graph._tracer_cls} for train"
+            )
+        self._tracer_cls = None
+        if self.graph._tracer_cls and "<locals>" not in self.graph._tracer_cls.__qualname__:
+            self._tracer_cls = self.graph._tracer_cls
+
+    def train(self, mode=True):
+        """
+        Swap out the graph depending on the selected training mode.
+        NOTE this should be safe when calling model.eval() because that just
+        calls this with mode == False.
+        """
+        # NOTE: Only set self.graph if the current graph is not the desired
+        # one. This saves us from recompiling the graph where not necessary.
+        if mode and not self.training:
+            self.graph = self.train_graph
+        elif not mode and self.training:
+            self.graph = self.eval_graph
+        return super().train(mode=mode)
+
+
+def create_feature_extractor(
+    model: nn.Module,
+    return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
+    train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
+    eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
+    tracer_kwargs: Optional[Dict[str, Any]] = None,
+    suppress_diff_warning: bool = False,
+) -> fx.GraphModule:
+    """
+    Creates a new graph module that returns intermediate nodes from a given
+    model as dictionary with user specified keys as strings, and the requested
+    outputs as values. This is achieved by re-writing the computation graph of
+    the model via FX to return the desired nodes as outputs. All unused nodes
+    are removed, together with their corresponding parameters.
+
+    Desired output nodes must be specified as a ``.`` separated
+    path walking the module hierarchy from top level module down to leaf
+    operation or leaf module. For more details on the node naming conventions
+    used here, please see the :ref:`relevant subheading <about-node-names>`
+    in the `documentation <https://pytorch.org/vision/stable/feature_extraction.html>`_.
+
+    Not all models will be FX traceable, although with some massaging they can
+    be made to cooperate. Here's a (not exhaustive) list of tips:
+
+        - If you don't need to trace through a particular, problematic
+          sub-module, turn it into a "leaf module" by passing a list of
+          ``leaf_modules`` as one of the ``tracer_kwargs`` (see example below).
+          It will not be traced through, but rather, the resulting graph will
+          hold a reference to that module's forward method.
+        - Likewise, you may turn functions into leaf functions by passing a
+          list of ``autowrap_functions`` as one of the ``tracer_kwargs`` (see
+          example below).
+        - Some inbuilt Python functions can be problematic. For instance,
+          ``int`` will raise an error during tracing. You may wrap them in your
+          own function and then pass that in ``autowrap_functions`` as one of
+          the ``tracer_kwargs``.
+
+    For further information on FX see the
+    `torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_.
+
+    Args:
+        model (nn.Module): model on which we will extract the features
+        return_nodes (list or dict, optional): either a ``List`` or a ``Dict``
+            containing the names (or partial names - see note above)
+            of the nodes for which the activations will be returned. If it is
+            a ``Dict``, the keys are the node names, and the values
+            are the user-specified keys for the graph module's returned
+            dictionary. If it is a ``List``, it is treated as a ``Dict`` mapping
+            node specification strings directly to output names. In the case
+            that ``train_return_nodes`` and ``eval_return_nodes`` are specified,
+            this should not be specified.
+        train_return_nodes (list or dict, optional): similar to
+            ``return_nodes``. This can be used if the return nodes
+            for train mode are different than those from eval mode.
+            If this is specified, ``eval_return_nodes`` must also be specified,
+            and ``return_nodes`` should not be specified.
+        eval_return_nodes (list or dict, optional): similar to
+            ``return_nodes``. This can be used if the return nodes
+            for train mode are different than those from eval mode.
+            If this is specified, ``train_return_nodes`` must also be specified,
+            and `return_nodes` should not be specified.
+        tracer_kwargs (dict, optional): a dictionary of keyword arguments for
+            ``NodePathTracer`` (which passes them onto it's parent class
+            `torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
+            By default, it will be set to wrap and make leaf nodes all torchvision ops:
+            {"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),}
+            WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user
+            provided dictionary.
+        suppress_diff_warning (bool, optional): whether to suppress a warning
+            when there are discrepancies between the train and eval version of
+            the graph. Defaults to False.
+
+    Examples::
+
+        >>> # Feature extraction with resnet
+        >>> model = torchvision.models.resnet18()
+        >>> # extract layer1 and layer3, giving as names `feat1` and feat2`
+        >>> model = create_feature_extractor(
+        >>>     model, {'layer1': 'feat1', 'layer3': 'feat2'})
+        >>> out = model(torch.rand(1, 3, 224, 224))
+        >>> print([(k, v.shape) for k, v in out.items()])
+        >>>     [('feat1', torch.Size([1, 64, 56, 56])),
+        >>>      ('feat2', torch.Size([1, 256, 14, 14]))]
+
+        >>> # Specifying leaf modules and leaf functions
+        >>> def leaf_function(x):
+        >>>     # This would raise a TypeError if traced through
+        >>>     return int(x)
+        >>>
+        >>> class LeafModule(torch.nn.Module):
+        >>>     def forward(self, x):
+        >>>         # This would raise a TypeError if traced through
+        >>>         int(x.shape[0])
+        >>>         return torch.nn.functional.relu(x + 4)
+        >>>
+        >>> class MyModule(torch.nn.Module):
+        >>>     def __init__(self):
+        >>>         super().__init__()
+        >>>         self.conv = torch.nn.Conv2d(3, 1, 3)
+        >>>         self.leaf_module = LeafModule()
+        >>>
+        >>>     def forward(self, x):
+        >>>         leaf_function(x.shape[0])
+        >>>         x = self.conv(x)
+        >>>         return self.leaf_module(x)
+        >>>
+        >>> model = create_feature_extractor(
+        >>>     MyModule(), return_nodes=['leaf_module'],
+        >>>     tracer_kwargs={'leaf_modules': [LeafModule],
+        >>>                    'autowrap_functions': [leaf_function]})
+
+    """
+    tracer_kwargs = _set_default_tracer_kwargs(tracer_kwargs)
+    is_training = model.training
+
+    if all(arg is None for arg in [return_nodes, train_return_nodes, eval_return_nodes]):
+
+        raise ValueError(
+            "Either `return_nodes` or `train_return_nodes` and `eval_return_nodes` together, should be specified"
+        )
+
+    if (train_return_nodes is None) ^ (eval_return_nodes is None):
+        raise ValueError(
+            "If any of `train_return_nodes` and `eval_return_nodes` are specified, then both should be specified"
+        )
+
+    if not ((return_nodes is None) ^ (train_return_nodes is None)):
+        raise ValueError("If `train_return_nodes` and `eval_return_nodes` are specified, then both should be specified")
+
+    # Put *_return_nodes into Dict[str, str] format
+    def to_strdict(n) -> Dict[str, str]:
+        if isinstance(n, list):
+            return {str(i): str(i) for i in n}
+        return {str(k): str(v) for k, v in n.items()}
+
+    if train_return_nodes is None:
+        return_nodes = to_strdict(return_nodes)
+        train_return_nodes = deepcopy(return_nodes)
+        eval_return_nodes = deepcopy(return_nodes)
+    else:
+        train_return_nodes = to_strdict(train_return_nodes)
+        eval_return_nodes = to_strdict(eval_return_nodes)
+
+    # Repeat the tracing and graph rewriting for train and eval mode
+    tracers = {}
+    graphs = {}
+    mode_return_nodes: Dict[str, Dict[str, str]] = {"train": train_return_nodes, "eval": eval_return_nodes}
+    for mode in ["train", "eval"]:
+        if mode == "train":
+            model.train()
+        elif mode == "eval":
+            model.eval()
+
+        # Instantiate our NodePathTracer and use that to trace the model
+        tracer = NodePathTracer(**tracer_kwargs)
+        graph = tracer.trace(model)
+
+        name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__
+        graph_module = fx.GraphModule(tracer.root, graph, name)
+
+        available_nodes = list(tracer.node_to_qualname.values())
+        # FIXME We don't know if we should expect this to happen
+        if len(set(available_nodes)) != len(available_nodes):
+            raise ValueError(
+                "There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues"
+            )
+        # Check that all outputs in return_nodes are present in the model
+        for query in mode_return_nodes[mode].keys():
+            # To check if a query is available we need to check that at least
+            # one of the available names starts with it up to a .
+            if not any([re.match(rf"^{query}(\.|$)", n) is not None for n in available_nodes]):
+                raise ValueError(
+                    f"node: '{query}' is not present in model. Hint: use "
+                    "`get_graph_node_names` to make sure the "
+                    "`return_nodes` you specified are present. It may even "
+                    "be that you need to specify `train_return_nodes` and "
+                    "`eval_return_nodes` separately."
+                )
+
+        # Remove existing output nodes (train mode)
+        orig_output_nodes = []
+        for n in reversed(graph_module.graph.nodes):
+            if n.op == "output":
+                orig_output_nodes.append(n)
+        if not orig_output_nodes:
+            raise ValueError("No output nodes found in graph_module.graph.nodes")
+
+        for n in orig_output_nodes:
+            graph_module.graph.erase_node(n)
+
+        # Find nodes corresponding to return_nodes and make them into output_nodes
+        nodes = [n for n in graph_module.graph.nodes]
+        output_nodes = OrderedDict()
+        for n in reversed(nodes):
+            module_qualname = tracer.node_to_qualname.get(n)
+            if module_qualname is None:
+                # NOTE - Know cases where this happens:
+                # - Node representing creation of a tensor constant - probably
+                #   not interesting as a return node
+                # - When packing outputs into a named tuple like in InceptionV3
+                continue
+            for query in mode_return_nodes[mode]:
+                depth = query.count(".")
+                if ".".join(module_qualname.split(".")[: depth + 1]) == query:
+                    output_nodes[mode_return_nodes[mode][query]] = n
+                    mode_return_nodes[mode].pop(query)
+                    break
+        output_nodes = OrderedDict(reversed(list(output_nodes.items())))
+
+        # And add them in the end of the graph
+        with graph_module.graph.inserting_after(nodes[-1]):
+            graph_module.graph.output(output_nodes)
+
+        # Remove unused modules / parameters
+        graph_module.graph.eliminate_dead_code()
+        graph_module.recompile()
+
+        # Keep track of the tracer and graph, so we can choose the main one
+        tracers[mode] = tracer
+        graphs[mode] = graph
+
+    # Warn user if there are any discrepancies between the graphs of the
+    # train and eval modes
+    if not suppress_diff_warning:
+        _warn_graph_differences(tracers["train"], tracers["eval"])
+
+    # Build the final graph module
+    graph_module = DualGraphModule(model, graphs["train"], graphs["eval"], class_name=name)
+
+    # Restore original training mode
+    model.train(is_training)
+    graph_module.train(is_training)
+
+    return graph_module

+ 345 - 0
libs/vision_libs/models/googlenet.py

@@ -0,0 +1,345 @@
+import warnings
+from collections import namedtuple
+from functools import partial
+from typing import Any, Callable, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+from ..transforms._presets import ImageClassification
+from ..utils import _log_api_usage_once
+from ._api import register_model, Weights, WeightsEnum
+from ._meta import _IMAGENET_CATEGORIES
+from ._utils import _ovewrite_named_param, handle_legacy_interface
+
+
+__all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weights", "googlenet"]
+
+
+GoogLeNetOutputs = namedtuple("GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"])
+GoogLeNetOutputs.__annotations__ = {"logits": Tensor, "aux_logits2": Optional[Tensor], "aux_logits1": Optional[Tensor]}
+
+# Script annotations failed with _GoogleNetOutputs = namedtuple ...
+# _GoogLeNetOutputs set here for backwards compat
+_GoogLeNetOutputs = GoogLeNetOutputs
+
+
+class GoogLeNet(nn.Module):
+    __constants__ = ["aux_logits", "transform_input"]
+
+    def __init__(
+        self,
+        num_classes: int = 1000,
+        aux_logits: bool = True,
+        transform_input: bool = False,
+        init_weights: Optional[bool] = None,
+        blocks: Optional[List[Callable[..., nn.Module]]] = None,
+        dropout: float = 0.2,
+        dropout_aux: float = 0.7,
+    ) -> None:
+        super().__init__()
+        _log_api_usage_once(self)
+        if blocks is None:
+            blocks = [BasicConv2d, Inception, InceptionAux]
+        if init_weights is None:
+            warnings.warn(
+                "The default weight initialization of GoogleNet will be changed in future releases of "
+                "torchvision. If you wish to keep the old behavior (which leads to long initialization times"
+                " due to scipy/scipy#11299), please set init_weights=True.",
+                FutureWarning,
+            )
+            init_weights = True
+        if len(blocks) != 3:
+            raise ValueError(f"blocks length should be 3 instead of {len(blocks)}")
+        conv_block = blocks[0]
+        inception_block = blocks[1]
+        inception_aux_block = blocks[2]
+
+        self.aux_logits = aux_logits
+        self.transform_input = transform_input
+
+        self.conv1 = conv_block(3, 64, kernel_size=7, stride=2, padding=3)
+        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
+        self.conv2 = conv_block(64, 64, kernel_size=1)
+        self.conv3 = conv_block(64, 192, kernel_size=3, padding=1)
+        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
+
+        self.inception3a = inception_block(192, 64, 96, 128, 16, 32, 32)
+        self.inception3b = inception_block(256, 128, 128, 192, 32, 96, 64)
+        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
+
+        self.inception4a = inception_block(480, 192, 96, 208, 16, 48, 64)
+        self.inception4b = inception_block(512, 160, 112, 224, 24, 64, 64)
+        self.inception4c = inception_block(512, 128, 128, 256, 24, 64, 64)
+        self.inception4d = inception_block(512, 112, 144, 288, 32, 64, 64)
+        self.inception4e = inception_block(528, 256, 160, 320, 32, 128, 128)
+        self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.inception5a = inception_block(832, 256, 160, 320, 32, 128, 128)
+        self.inception5b = inception_block(832, 384, 192, 384, 48, 128, 128)
+
+        if aux_logits:
+            self.aux1 = inception_aux_block(512, num_classes, dropout=dropout_aux)
+            self.aux2 = inception_aux_block(528, num_classes, dropout=dropout_aux)
+        else:
+            self.aux1 = None  # type: ignore[assignment]
+            self.aux2 = None  # type: ignore[assignment]
+
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.dropout = nn.Dropout(p=dropout)
+        self.fc = nn.Linear(1024, num_classes)
+
+        if init_weights:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
+                    torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=0.01, a=-2, b=2)
+                elif isinstance(m, nn.BatchNorm2d):
+                    nn.init.constant_(m.weight, 1)
+                    nn.init.constant_(m.bias, 0)
+
+    def _transform_input(self, x: Tensor) -> Tensor:
+        if self.transform_input:
+            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
+            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
+            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
+            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
+        return x
+
+    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+        # N x 3 x 224 x 224
+        x = self.conv1(x)
+        # N x 64 x 112 x 112
+        x = self.maxpool1(x)
+        # N x 64 x 56 x 56
+        x = self.conv2(x)
+        # N x 64 x 56 x 56
+        x = self.conv3(x)
+        # N x 192 x 56 x 56
+        x = self.maxpool2(x)
+
+        # N x 192 x 28 x 28
+        x = self.inception3a(x)
+        # N x 256 x 28 x 28
+        x = self.inception3b(x)
+        # N x 480 x 28 x 28
+        x = self.maxpool3(x)
+        # N x 480 x 14 x 14
+        x = self.inception4a(x)
+        # N x 512 x 14 x 14
+        aux1: Optional[Tensor] = None
+        if self.aux1 is not None:
+            if self.training:
+                aux1 = self.aux1(x)
+
+        x = self.inception4b(x)
+        # N x 512 x 14 x 14
+        x = self.inception4c(x)
+        # N x 512 x 14 x 14
+        x = self.inception4d(x)
+        # N x 528 x 14 x 14
+        aux2: Optional[Tensor] = None
+        if self.aux2 is not None:
+            if self.training:
+                aux2 = self.aux2(x)
+
+        x = self.inception4e(x)
+        # N x 832 x 14 x 14
+        x = self.maxpool4(x)
+        # N x 832 x 7 x 7
+        x = self.inception5a(x)
+        # N x 832 x 7 x 7
+        x = self.inception5b(x)
+        # N x 1024 x 7 x 7
+
+        x = self.avgpool(x)
+        # N x 1024 x 1 x 1
+        x = torch.flatten(x, 1)
+        # N x 1024
+        x = self.dropout(x)
+        x = self.fc(x)
+        # N x 1000 (num_classes)
+        return x, aux2, aux1
+
+    @torch.jit.unused
+    def eager_outputs(self, x: Tensor, aux2: Tensor, aux1: Optional[Tensor]) -> GoogLeNetOutputs:
+        if self.training and self.aux_logits:
+            return _GoogLeNetOutputs(x, aux2, aux1)
+        else:
+            return x  # type: ignore[return-value]
+
+    def forward(self, x: Tensor) -> GoogLeNetOutputs:
+        x = self._transform_input(x)
+        x, aux1, aux2 = self._forward(x)
+        aux_defined = self.training and self.aux_logits
+        if torch.jit.is_scripting():
+            if not aux_defined:
+                warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple")
+            return GoogLeNetOutputs(x, aux2, aux1)
+        else:
+            return self.eager_outputs(x, aux2, aux1)
+
+
+class Inception(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        ch1x1: int,
+        ch3x3red: int,
+        ch3x3: int,
+        ch5x5red: int,
+        ch5x5: int,
+        pool_proj: int,
+        conv_block: Optional[Callable[..., nn.Module]] = None,
+    ) -> None:
+        super().__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.branch1 = conv_block(in_channels, ch1x1, kernel_size=1)
+
+        self.branch2 = nn.Sequential(
+            conv_block(in_channels, ch3x3red, kernel_size=1), conv_block(ch3x3red, ch3x3, kernel_size=3, padding=1)
+        )
+
+        self.branch3 = nn.Sequential(
+            conv_block(in_channels, ch5x5red, kernel_size=1),
+            # Here, kernel_size=3 instead of kernel_size=5 is a known bug.
+            # Please see https://github.com/pytorch/vision/issues/906 for details.
+            conv_block(ch5x5red, ch5x5, kernel_size=3, padding=1),
+        )
+
+        self.branch4 = nn.Sequential(
+            nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
+            conv_block(in_channels, pool_proj, kernel_size=1),
+        )
+
+    def _forward(self, x: Tensor) -> List[Tensor]:
+        branch1 = self.branch1(x)
+        branch2 = self.branch2(x)
+        branch3 = self.branch3(x)
+        branch4 = self.branch4(x)
+
+        outputs = [branch1, branch2, branch3, branch4]
+        return outputs
+
+    def forward(self, x: Tensor) -> Tensor:
+        outputs = self._forward(x)
+        return torch.cat(outputs, 1)
+
+
+class InceptionAux(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        num_classes: int,
+        conv_block: Optional[Callable[..., nn.Module]] = None,
+        dropout: float = 0.7,
+    ) -> None:
+        super().__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.conv = conv_block(in_channels, 128, kernel_size=1)
+
+        self.fc1 = nn.Linear(2048, 1024)
+        self.fc2 = nn.Linear(1024, num_classes)
+        self.dropout = nn.Dropout(p=dropout)
+
+    def forward(self, x: Tensor) -> Tensor:
+        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
+        x = F.adaptive_avg_pool2d(x, (4, 4))
+        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
+        x = self.conv(x)
+        # N x 128 x 4 x 4
+        x = torch.flatten(x, 1)
+        # N x 2048
+        x = F.relu(self.fc1(x), inplace=True)
+        # N x 1024
+        x = self.dropout(x)
+        # N x 1024
+        x = self.fc2(x)
+        # N x 1000 (num_classes)
+
+        return x
+
+
+class BasicConv2d(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
+        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.conv(x)
+        x = self.bn(x)
+        return F.relu(x, inplace=True)
+
+
+class GoogLeNet_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/googlenet-1378be20.pth",
+        transforms=partial(ImageClassification, crop_size=224),
+        meta={
+            "num_params": 6624904,
+            "min_size": (15, 15),
+            "categories": _IMAGENET_CATEGORIES,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#googlenet",
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 69.778,
+                    "acc@5": 89.530,
+                }
+            },
+            "_ops": 1.498,
+            "_file_size": 49.731,
+            "_docs": """These weights are ported from the original paper.""",
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", GoogLeNet_Weights.IMAGENET1K_V1))
+def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
+    """GoogLeNet (Inception v1) model architecture from
+    `Going Deeper with Convolutions <http://arxiv.org/abs/1409.4842>`_.
+
+    Args:
+        weights (:class:`~torchvision.models.GoogLeNet_Weights`, optional): The
+            pretrained weights for the model. See
+            :class:`~torchvision.models.GoogLeNet_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.GoogLeNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/googlenet.py>`_
+            for more details about this class.
+    .. autoclass:: torchvision.models.GoogLeNet_Weights
+        :members:
+    """
+    weights = GoogLeNet_Weights.verify(weights)
+
+    original_aux_logits = kwargs.get("aux_logits", False)
+    if weights is not None:
+        if "transform_input" not in kwargs:
+            _ovewrite_named_param(kwargs, "transform_input", True)
+        _ovewrite_named_param(kwargs, "aux_logits", True)
+        _ovewrite_named_param(kwargs, "init_weights", False)
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+
+    model = GoogLeNet(**kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+        if not original_aux_logits:
+            model.aux_logits = False
+            model.aux1 = None  # type: ignore[assignment]
+            model.aux2 = None  # type: ignore[assignment]
+        else:
+            warnings.warn(
+                "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
+            )
+
+    return model

+ 478 - 0
libs/vision_libs/models/inception.py

@@ -0,0 +1,478 @@
+import warnings
+from collections import namedtuple
+from functools import partial
+from typing import Any, Callable, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+from ..transforms._presets import ImageClassification
+from ..utils import _log_api_usage_once
+from ._api import register_model, Weights, WeightsEnum
+from ._meta import _IMAGENET_CATEGORIES
+from ._utils import _ovewrite_named_param, handle_legacy_interface
+
+
+__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_Weights", "inception_v3"]
+
+
+InceptionOutputs = namedtuple("InceptionOutputs", ["logits", "aux_logits"])
+InceptionOutputs.__annotations__ = {"logits": Tensor, "aux_logits": Optional[Tensor]}
+
+# Script annotations failed with _GoogleNetOutputs = namedtuple ...
+# _InceptionOutputs set here for backwards compat
+_InceptionOutputs = InceptionOutputs
+
+
+class Inception3(nn.Module):
+    def __init__(
+        self,
+        num_classes: int = 1000,
+        aux_logits: bool = True,
+        transform_input: bool = False,
+        inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
+        init_weights: Optional[bool] = None,
+        dropout: float = 0.5,
+    ) -> None:
+        super().__init__()
+        _log_api_usage_once(self)
+        if inception_blocks is None:
+            inception_blocks = [BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux]
+        if init_weights is None:
+            warnings.warn(
+                "The default weight initialization of inception_v3 will be changed in future releases of "
+                "torchvision. If you wish to keep the old behavior (which leads to long initialization times"
+                " due to scipy/scipy#11299), please set init_weights=True.",
+                FutureWarning,
+            )
+            init_weights = True
+        if len(inception_blocks) != 7:
+            raise ValueError(f"length of inception_blocks should be 7 instead of {len(inception_blocks)}")
+        conv_block = inception_blocks[0]
+        inception_a = inception_blocks[1]
+        inception_b = inception_blocks[2]
+        inception_c = inception_blocks[3]
+        inception_d = inception_blocks[4]
+        inception_e = inception_blocks[5]
+        inception_aux = inception_blocks[6]
+
+        self.aux_logits = aux_logits
+        self.transform_input = transform_input
+        self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
+        self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
+        self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
+        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
+        self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
+        self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
+        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
+        self.Mixed_5b = inception_a(192, pool_features=32)
+        self.Mixed_5c = inception_a(256, pool_features=64)
+        self.Mixed_5d = inception_a(288, pool_features=64)
+        self.Mixed_6a = inception_b(288)
+        self.Mixed_6b = inception_c(768, channels_7x7=128)
+        self.Mixed_6c = inception_c(768, channels_7x7=160)
+        self.Mixed_6d = inception_c(768, channels_7x7=160)
+        self.Mixed_6e = inception_c(768, channels_7x7=192)
+        self.AuxLogits: Optional[nn.Module] = None
+        if aux_logits:
+            self.AuxLogits = inception_aux(768, num_classes)
+        self.Mixed_7a = inception_d(768)
+        self.Mixed_7b = inception_e(1280)
+        self.Mixed_7c = inception_e(2048)
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.dropout = nn.Dropout(p=dropout)
+        self.fc = nn.Linear(2048, num_classes)
+        if init_weights:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
+                    stddev = float(m.stddev) if hasattr(m, "stddev") else 0.1  # type: ignore
+                    torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=stddev, a=-2, b=2)
+                elif isinstance(m, nn.BatchNorm2d):
+                    nn.init.constant_(m.weight, 1)
+                    nn.init.constant_(m.bias, 0)
+
+    def _transform_input(self, x: Tensor) -> Tensor:
+        if self.transform_input:
+            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
+            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
+            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
+            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
+        return x
+
+    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
+        # N x 3 x 299 x 299
+        x = self.Conv2d_1a_3x3(x)
+        # N x 32 x 149 x 149
+        x = self.Conv2d_2a_3x3(x)
+        # N x 32 x 147 x 147
+        x = self.Conv2d_2b_3x3(x)
+        # N x 64 x 147 x 147
+        x = self.maxpool1(x)
+        # N x 64 x 73 x 73
+        x = self.Conv2d_3b_1x1(x)
+        # N x 80 x 73 x 73
+        x = self.Conv2d_4a_3x3(x)
+        # N x 192 x 71 x 71
+        x = self.maxpool2(x)
+        # N x 192 x 35 x 35
+        x = self.Mixed_5b(x)
+        # N x 256 x 35 x 35
+        x = self.Mixed_5c(x)
+        # N x 288 x 35 x 35
+        x = self.Mixed_5d(x)
+        # N x 288 x 35 x 35
+        x = self.Mixed_6a(x)
+        # N x 768 x 17 x 17
+        x = self.Mixed_6b(x)
+        # N x 768 x 17 x 17
+        x = self.Mixed_6c(x)
+        # N x 768 x 17 x 17
+        x = self.Mixed_6d(x)
+        # N x 768 x 17 x 17
+        x = self.Mixed_6e(x)
+        # N x 768 x 17 x 17
+        aux: Optional[Tensor] = None
+        if self.AuxLogits is not None:
+            if self.training:
+                aux = self.AuxLogits(x)
+        # N x 768 x 17 x 17
+        x = self.Mixed_7a(x)
+        # N x 1280 x 8 x 8
+        x = self.Mixed_7b(x)
+        # N x 2048 x 8 x 8
+        x = self.Mixed_7c(x)
+        # N x 2048 x 8 x 8
+        # Adaptive average pooling
+        x = self.avgpool(x)
+        # N x 2048 x 1 x 1
+        x = self.dropout(x)
+        # N x 2048 x 1 x 1
+        x = torch.flatten(x, 1)
+        # N x 2048
+        x = self.fc(x)
+        # N x 1000 (num_classes)
+        return x, aux
+
+    @torch.jit.unused
+    def eager_outputs(self, x: Tensor, aux: Optional[Tensor]) -> InceptionOutputs:
+        if self.training and self.aux_logits:
+            return InceptionOutputs(x, aux)
+        else:
+            return x  # type: ignore[return-value]
+
+    def forward(self, x: Tensor) -> InceptionOutputs:
+        x = self._transform_input(x)
+        x, aux = self._forward(x)
+        aux_defined = self.training and self.aux_logits
+        if torch.jit.is_scripting():
+            if not aux_defined:
+                warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
+            return InceptionOutputs(x, aux)
+        else:
+            return self.eager_outputs(x, aux)
+
+
+class InceptionA(nn.Module):
+    def __init__(
+        self, in_channels: int, pool_features: int, conv_block: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
+        super().__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)
+
+        self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
+        self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)
+
+        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
+        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
+        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)
+
+        self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)
+
+    def _forward(self, x: Tensor) -> List[Tensor]:
+        branch1x1 = self.branch1x1(x)
+
+        branch5x5 = self.branch5x5_1(x)
+        branch5x5 = self.branch5x5_2(branch5x5)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
+        return outputs
+
+    def forward(self, x: Tensor) -> Tensor:
+        outputs = self._forward(x)
+        return torch.cat(outputs, 1)
+
+
+class InceptionB(nn.Module):
+    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
+        super().__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)
+
+        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
+        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
+        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)
+
+    def _forward(self, x: Tensor) -> List[Tensor]:
+        branch3x3 = self.branch3x3(x)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
+
+        outputs = [branch3x3, branch3x3dbl, branch_pool]
+        return outputs
+
+    def forward(self, x: Tensor) -> Tensor:
+        outputs = self._forward(x)
+        return torch.cat(outputs, 1)
+
+
+class InceptionC(nn.Module):
+    def __init__(
+        self, in_channels: int, channels_7x7: int, conv_block: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
+        super().__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)
+
+        c7 = channels_7x7
+        self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
+        self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
+        self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))
+
+        self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
+        self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
+        self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
+        self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
+        self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))
+
+        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
+
+    def _forward(self, x: Tensor) -> List[Tensor]:
+        branch1x1 = self.branch1x1(x)
+
+        branch7x7 = self.branch7x7_1(x)
+        branch7x7 = self.branch7x7_2(branch7x7)
+        branch7x7 = self.branch7x7_3(branch7x7)
+
+        branch7x7dbl = self.branch7x7dbl_1(x)
+        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
+
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
+        return outputs
+
+    def forward(self, x: Tensor) -> Tensor:
+        outputs = self._forward(x)
+        return torch.cat(outputs, 1)
+
+
+class InceptionD(nn.Module):
+    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
+        super().__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
+        self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)
+
+        self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
+        self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
+        self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
+        self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)
+
+    def _forward(self, x: Tensor) -> List[Tensor]:
+        branch3x3 = self.branch3x3_1(x)
+        branch3x3 = self.branch3x3_2(branch3x3)
+
+        branch7x7x3 = self.branch7x7x3_1(x)
+        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
+        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
+        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
+
+        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
+        outputs = [branch3x3, branch7x7x3, branch_pool]
+        return outputs
+
+    def forward(self, x: Tensor) -> Tensor:
+        outputs = self._forward(x)
+        return torch.cat(outputs, 1)
+
+
+class InceptionE(nn.Module):
+    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
+        super().__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)
+
+        self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
+        self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
+        self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
+
+        self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
+        self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
+        self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
+        self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
+
+        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
+
+    def _forward(self, x: Tensor) -> List[Tensor]:
+        branch1x1 = self.branch1x1(x)
+
+        branch3x3 = self.branch3x3_1(x)
+        branch3x3 = [
+            self.branch3x3_2a(branch3x3),
+            self.branch3x3_2b(branch3x3),
+        ]
+        branch3x3 = torch.cat(branch3x3, 1)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = [
+            self.branch3x3dbl_3a(branch3x3dbl),
+            self.branch3x3dbl_3b(branch3x3dbl),
+        ]
+        branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+        return outputs
+
+    def forward(self, x: Tensor) -> Tensor:
+        outputs = self._forward(x)
+        return torch.cat(outputs, 1)
+
+
+class InceptionAux(nn.Module):
+    def __init__(
+        self, in_channels: int, num_classes: int, conv_block: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
+        super().__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.conv0 = conv_block(in_channels, 128, kernel_size=1)
+        self.conv1 = conv_block(128, 768, kernel_size=5)
+        self.conv1.stddev = 0.01  # type: ignore[assignment]
+        self.fc = nn.Linear(768, num_classes)
+        self.fc.stddev = 0.001  # type: ignore[assignment]
+
+    def forward(self, x: Tensor) -> Tensor:
+        # N x 768 x 17 x 17
+        x = F.avg_pool2d(x, kernel_size=5, stride=3)
+        # N x 768 x 5 x 5
+        x = self.conv0(x)
+        # N x 128 x 5 x 5
+        x = self.conv1(x)
+        # N x 768 x 1 x 1
+        # Adaptive average pooling
+        x = F.adaptive_avg_pool2d(x, (1, 1))
+        # N x 768 x 1 x 1
+        x = torch.flatten(x, 1)
+        # N x 768
+        x = self.fc(x)
+        # N x 1000
+        return x
+
+
+class BasicConv2d(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
+        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.conv(x)
+        x = self.bn(x)
+        return F.relu(x, inplace=True)
+
+
+class Inception_V3_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
+        transforms=partial(ImageClassification, crop_size=299, resize_size=342),
+        meta={
+            "num_params": 27161264,
+            "min_size": (75, 75),
+            "categories": _IMAGENET_CATEGORIES,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#inception-v3",
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 77.294,
+                    "acc@5": 93.450,
+                }
+            },
+            "_ops": 5.713,
+            "_file_size": 103.903,
+            "_docs": """These weights are ported from the original paper.""",
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", Inception_V3_Weights.IMAGENET1K_V1))
+def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
+    """
+    Inception v3 model architecture from
+    `Rethinking the Inception Architecture for Computer Vision <http://arxiv.org/abs/1512.00567>`_.
+
+    .. note::
+        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
+        N x 3 x 299 x 299, so ensure your images are sized accordingly.
+
+    Args:
+        weights (:class:`~torchvision.models.Inception_V3_Weights`, optional): The
+            pretrained weights for the model. See
+            :class:`~torchvision.models.Inception_V3_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.Inception3``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/inception.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.Inception_V3_Weights
+        :members:
+    """
+    weights = Inception_V3_Weights.verify(weights)
+
+    original_aux_logits = kwargs.get("aux_logits", True)
+    if weights is not None:
+        if "transform_input" not in kwargs:
+            _ovewrite_named_param(kwargs, "transform_input", True)
+        _ovewrite_named_param(kwargs, "aux_logits", True)
+        _ovewrite_named_param(kwargs, "init_weights", False)
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+
+    model = Inception3(**kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+        if not original_aux_logits:
+            model.aux_logits = False
+            model.AuxLogits = None
+
+    return model

+ 832 - 0
libs/vision_libs/models/maxvit.py

@@ -0,0 +1,832 @@
+import math
+from collections import OrderedDict
+from functools import partial
+from typing import Any, Callable, List, Optional, Sequence, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+from torchvision.models._api import register_model, Weights, WeightsEnum
+from torchvision.models._meta import _IMAGENET_CATEGORIES
+from torchvision.models._utils import _ovewrite_named_param, handle_legacy_interface
+from torchvision.ops.misc import Conv2dNormActivation, SqueezeExcitation
+from torchvision.ops.stochastic_depth import StochasticDepth
+from torchvision.transforms._presets import ImageClassification, InterpolationMode
+from torchvision.utils import _log_api_usage_once
+
+__all__ = [
+    "MaxVit",
+    "MaxVit_T_Weights",
+    "maxvit_t",
+]
+
+
+def _get_conv_output_shape(input_size: Tuple[int, int], kernel_size: int, stride: int, padding: int) -> Tuple[int, int]:
+    return (
+        (input_size[0] - kernel_size + 2 * padding) // stride + 1,
+        (input_size[1] - kernel_size + 2 * padding) // stride + 1,
+    )
+
+
+def _make_block_input_shapes(input_size: Tuple[int, int], n_blocks: int) -> List[Tuple[int, int]]:
+    """Util function to check that the input size is correct for a MaxVit configuration."""
+    shapes = []
+    block_input_shape = _get_conv_output_shape(input_size, 3, 2, 1)
+    for _ in range(n_blocks):
+        block_input_shape = _get_conv_output_shape(block_input_shape, 3, 2, 1)
+        shapes.append(block_input_shape)
+    return shapes
+
+
+def _get_relative_position_index(height: int, width: int) -> torch.Tensor:
+    coords = torch.stack(torch.meshgrid([torch.arange(height), torch.arange(width)]))
+    coords_flat = torch.flatten(coords, 1)
+    relative_coords = coords_flat[:, :, None] - coords_flat[:, None, :]
+    relative_coords = relative_coords.permute(1, 2, 0).contiguous()
+    relative_coords[:, :, 0] += height - 1
+    relative_coords[:, :, 1] += width - 1
+    relative_coords[:, :, 0] *= 2 * width - 1
+    return relative_coords.sum(-1)
+
+
+class MBConv(nn.Module):
+    """MBConv: Mobile Inverted Residual Bottleneck.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        expansion_ratio (float): Expansion ratio in the bottleneck.
+        squeeze_ratio (float): Squeeze ratio in the SE Layer.
+        stride (int): Stride of the depthwise convolution.
+        activation_layer (Callable[..., nn.Module]): Activation function.
+        norm_layer (Callable[..., nn.Module]): Normalization function.
+        p_stochastic_dropout (float): Probability of stochastic depth.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        expansion_ratio: float,
+        squeeze_ratio: float,
+        stride: int,
+        activation_layer: Callable[..., nn.Module],
+        norm_layer: Callable[..., nn.Module],
+        p_stochastic_dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+
+        proj: Sequence[nn.Module]
+        self.proj: nn.Module
+
+        should_proj = stride != 1 or in_channels != out_channels
+        if should_proj:
+            proj = [nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=True)]
+            if stride == 2:
+                proj = [nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)] + proj  # type: ignore
+            self.proj = nn.Sequential(*proj)
+        else:
+            self.proj = nn.Identity()  # type: ignore
+
+        mid_channels = int(out_channels * expansion_ratio)
+        sqz_channels = int(out_channels * squeeze_ratio)
+
+        if p_stochastic_dropout:
+            self.stochastic_depth = StochasticDepth(p_stochastic_dropout, mode="row")  # type: ignore
+        else:
+            self.stochastic_depth = nn.Identity()  # type: ignore
+
+        _layers = OrderedDict()
+        _layers["pre_norm"] = norm_layer(in_channels)
+        _layers["conv_a"] = Conv2dNormActivation(
+            in_channels,
+            mid_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            activation_layer=activation_layer,
+            norm_layer=norm_layer,
+            inplace=None,
+        )
+        _layers["conv_b"] = Conv2dNormActivation(
+            mid_channels,
+            mid_channels,
+            kernel_size=3,
+            stride=stride,
+            padding=1,
+            activation_layer=activation_layer,
+            norm_layer=norm_layer,
+            groups=mid_channels,
+            inplace=None,
+        )
+        _layers["squeeze_excitation"] = SqueezeExcitation(mid_channels, sqz_channels, activation=nn.SiLU)
+        _layers["conv_c"] = nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1, bias=True)
+
+        self.layers = nn.Sequential(_layers)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Args:
+            x (Tensor): Input tensor with expected layout of [B, C, H, W].
+        Returns:
+            Tensor: Output tensor with expected layout of [B, C, H / stride, W / stride].
+        """
+        res = self.proj(x)
+        x = self.stochastic_depth(self.layers(x))
+        return res + x
+
+
+class RelativePositionalMultiHeadAttention(nn.Module):
+    """Relative Positional Multi-Head Attention.
+
+    Args:
+        feat_dim (int): Number of input features.
+        head_dim (int): Number of features per head.
+        max_seq_len (int): Maximum sequence length.
+    """
+
+    def __init__(
+        self,
+        feat_dim: int,
+        head_dim: int,
+        max_seq_len: int,
+    ) -> None:
+        super().__init__()
+
+        if feat_dim % head_dim != 0:
+            raise ValueError(f"feat_dim: {feat_dim} must be divisible by head_dim: {head_dim}")
+
+        self.n_heads = feat_dim // head_dim
+        self.head_dim = head_dim
+        self.size = int(math.sqrt(max_seq_len))
+        self.max_seq_len = max_seq_len
+
+        self.to_qkv = nn.Linear(feat_dim, self.n_heads * self.head_dim * 3)
+        self.scale_factor = feat_dim**-0.5
+
+        self.merge = nn.Linear(self.head_dim * self.n_heads, feat_dim)
+        self.relative_position_bias_table = nn.parameter.Parameter(
+            torch.empty(((2 * self.size - 1) * (2 * self.size - 1), self.n_heads), dtype=torch.float32),
+        )
+
+        self.register_buffer("relative_position_index", _get_relative_position_index(self.size, self.size))
+        # initialize with truncated normal the bias
+        torch.nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
+
+    def get_relative_positional_bias(self) -> torch.Tensor:
+        bias_index = self.relative_position_index.view(-1)  # type: ignore
+        relative_bias = self.relative_position_bias_table[bias_index].view(self.max_seq_len, self.max_seq_len, -1)  # type: ignore
+        relative_bias = relative_bias.permute(2, 0, 1).contiguous()
+        return relative_bias.unsqueeze(0)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Args:
+            x (Tensor): Input tensor with expected layout of [B, G, P, D].
+        Returns:
+            Tensor: Output tensor with expected layout of [B, G, P, D].
+        """
+        B, G, P, D = x.shape
+        H, DH = self.n_heads, self.head_dim
+
+        qkv = self.to_qkv(x)
+        q, k, v = torch.chunk(qkv, 3, dim=-1)
+
+        q = q.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4)
+        k = k.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4)
+        v = v.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4)
+
+        k = k * self.scale_factor
+        dot_prod = torch.einsum("B G H I D, B G H J D -> B G H I J", q, k)
+        pos_bias = self.get_relative_positional_bias()
+
+        dot_prod = F.softmax(dot_prod + pos_bias, dim=-1)
+
+        out = torch.einsum("B G H I J, B G H J D -> B G H I D", dot_prod, v)
+        out = out.permute(0, 1, 3, 2, 4).reshape(B, G, P, D)
+
+        out = self.merge(out)
+        return out
+
+
+class SwapAxes(nn.Module):
+    """Permute the axes of a tensor."""
+
+    def __init__(self, a: int, b: int) -> None:
+        super().__init__()
+        self.a = a
+        self.b = b
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        res = torch.swapaxes(x, self.a, self.b)
+        return res
+
+
+class WindowPartition(nn.Module):
+    """
+    Partition the input tensor into non-overlapping windows.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, x: Tensor, p: int) -> Tensor:
+        """
+        Args:
+            x (Tensor): Input tensor with expected layout of [B, C, H, W].
+            p (int): Number of partitions.
+        Returns:
+            Tensor: Output tensor with expected layout of [B, H/P, W/P, P*P, C].
+        """
+        B, C, H, W = x.shape
+        P = p
+        # chunk up H and W dimensions
+        x = x.reshape(B, C, H // P, P, W // P, P)
+        x = x.permute(0, 2, 4, 3, 5, 1)
+        # colapse P * P dimension
+        x = x.reshape(B, (H // P) * (W // P), P * P, C)
+        return x
+
+
+class WindowDepartition(nn.Module):
+    """
+    Departition the input tensor of non-overlapping windows into a feature volume of layout [B, C, H, W].
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, x: Tensor, p: int, h_partitions: int, w_partitions: int) -> Tensor:
+        """
+        Args:
+            x (Tensor): Input tensor with expected layout of [B, (H/P * W/P), P*P, C].
+            p (int): Number of partitions.
+            h_partitions (int): Number of vertical partitions.
+            w_partitions (int): Number of horizontal partitions.
+        Returns:
+            Tensor: Output tensor with expected layout of [B, C, H, W].
+        """
+        B, G, PP, C = x.shape
+        P = p
+        HP, WP = h_partitions, w_partitions
+        # split P * P dimension into 2 P tile dimensionsa
+        x = x.reshape(B, HP, WP, P, P, C)
+        # permute into B, C, HP, P, WP, P
+        x = x.permute(0, 5, 1, 3, 2, 4)
+        # reshape into B, C, H, W
+        x = x.reshape(B, C, HP * P, WP * P)
+        return x
+
+
+class PartitionAttentionLayer(nn.Module):
+    """
+    Layer for partitioning the input tensor into non-overlapping windows and applying attention to each window.
+
+    Args:
+        in_channels (int): Number of input channels.
+        head_dim (int): Dimension of each attention head.
+        partition_size (int): Size of the partitions.
+        partition_type (str): Type of partitioning to use. Can be either "grid" or "window".
+        grid_size (Tuple[int, int]): Size of the grid to partition the input tensor into.
+        mlp_ratio (int): Ratio of the  feature size expansion in the MLP layer.
+        activation_layer (Callable[..., nn.Module]): Activation function to use.
+        norm_layer (Callable[..., nn.Module]): Normalization function to use.
+        attention_dropout (float): Dropout probability for the attention layer.
+        mlp_dropout (float): Dropout probability for the MLP layer.
+        p_stochastic_dropout (float): Probability of dropping out a partition.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        head_dim: int,
+        # partitioning parameters
+        partition_size: int,
+        partition_type: str,
+        # grid size needs to be known at initialization time
+        # because we need to know hamy relative offsets there are in the grid
+        grid_size: Tuple[int, int],
+        mlp_ratio: int,
+        activation_layer: Callable[..., nn.Module],
+        norm_layer: Callable[..., nn.Module],
+        attention_dropout: float,
+        mlp_dropout: float,
+        p_stochastic_dropout: float,
+    ) -> None:
+        super().__init__()
+
+        self.n_heads = in_channels // head_dim
+        self.head_dim = head_dim
+        self.n_partitions = grid_size[0] // partition_size
+        self.partition_type = partition_type
+        self.grid_size = grid_size
+
+        if partition_type not in ["grid", "window"]:
+            raise ValueError("partition_type must be either 'grid' or 'window'")
+
+        if partition_type == "window":
+            self.p, self.g = partition_size, self.n_partitions
+        else:
+            self.p, self.g = self.n_partitions, partition_size
+
+        self.partition_op = WindowPartition()
+        self.departition_op = WindowDepartition()
+        self.partition_swap = SwapAxes(-2, -3) if partition_type == "grid" else nn.Identity()
+        self.departition_swap = SwapAxes(-2, -3) if partition_type == "grid" else nn.Identity()
+
+        self.attn_layer = nn.Sequential(
+            norm_layer(in_channels),
+            # it's always going to be partition_size ** 2 because
+            # of the axis swap in the case of grid partitioning
+            RelativePositionalMultiHeadAttention(in_channels, head_dim, partition_size**2),
+            nn.Dropout(attention_dropout),
+        )
+
+        # pre-normalization similar to transformer layers
+        self.mlp_layer = nn.Sequential(
+            nn.LayerNorm(in_channels),
+            nn.Linear(in_channels, in_channels * mlp_ratio),
+            activation_layer(),
+            nn.Linear(in_channels * mlp_ratio, in_channels),
+            nn.Dropout(mlp_dropout),
+        )
+
+        # layer scale factors
+        self.stochastic_dropout = StochasticDepth(p_stochastic_dropout, mode="row")
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Args:
+            x (Tensor): Input tensor with expected layout of [B, C, H, W].
+        Returns:
+            Tensor: Output tensor with expected layout of [B, C, H, W].
+        """
+
+        # Undefined behavior if H or W are not divisible by p
+        # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L766
+        gh, gw = self.grid_size[0] // self.p, self.grid_size[1] // self.p
+        torch._assert(
+            self.grid_size[0] % self.p == 0 and self.grid_size[1] % self.p == 0,
+            "Grid size must be divisible by partition size. Got grid size of {} and partition size of {}".format(
+                self.grid_size, self.p
+            ),
+        )
+
+        x = self.partition_op(x, self.p)
+        x = self.partition_swap(x)
+        x = x + self.stochastic_dropout(self.attn_layer(x))
+        x = x + self.stochastic_dropout(self.mlp_layer(x))
+        x = self.departition_swap(x)
+        x = self.departition_op(x, self.p, gh, gw)
+
+        return x
+
+
+class MaxVitLayer(nn.Module):
+    """
+    MaxVit layer consisting of a MBConv layer followed by a PartitionAttentionLayer with `window` and a PartitionAttentionLayer with `grid`.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        expansion_ratio (float): Expansion ratio in the bottleneck.
+        squeeze_ratio (float): Squeeze ratio in the SE Layer.
+        stride (int): Stride of the depthwise convolution.
+        activation_layer (Callable[..., nn.Module]): Activation function.
+        norm_layer (Callable[..., nn.Module]): Normalization function.
+        head_dim (int): Dimension of the attention heads.
+        mlp_ratio (int): Ratio of the MLP layer.
+        mlp_dropout (float): Dropout probability for the MLP layer.
+        attention_dropout (float): Dropout probability for the attention layer.
+        p_stochastic_dropout (float): Probability of stochastic depth.
+        partition_size (int): Size of the partitions.
+        grid_size (Tuple[int, int]): Size of the input feature grid.
+    """
+
+    def __init__(
+        self,
+        # conv parameters
+        in_channels: int,
+        out_channels: int,
+        squeeze_ratio: float,
+        expansion_ratio: float,
+        stride: int,
+        # conv + transformer parameters
+        norm_layer: Callable[..., nn.Module],
+        activation_layer: Callable[..., nn.Module],
+        # transformer parameters
+        head_dim: int,
+        mlp_ratio: int,
+        mlp_dropout: float,
+        attention_dropout: float,
+        p_stochastic_dropout: float,
+        # partitioning parameters
+        partition_size: int,
+        grid_size: Tuple[int, int],
+    ) -> None:
+        super().__init__()
+
+        layers: OrderedDict = OrderedDict()
+
+        # convolutional layer
+        layers["MBconv"] = MBConv(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            expansion_ratio=expansion_ratio,
+            squeeze_ratio=squeeze_ratio,
+            stride=stride,
+            activation_layer=activation_layer,
+            norm_layer=norm_layer,
+            p_stochastic_dropout=p_stochastic_dropout,
+        )
+        # attention layers, block -> grid
+        layers["window_attention"] = PartitionAttentionLayer(
+            in_channels=out_channels,
+            head_dim=head_dim,
+            partition_size=partition_size,
+            partition_type="window",
+            grid_size=grid_size,
+            mlp_ratio=mlp_ratio,
+            activation_layer=activation_layer,
+            norm_layer=nn.LayerNorm,
+            attention_dropout=attention_dropout,
+            mlp_dropout=mlp_dropout,
+            p_stochastic_dropout=p_stochastic_dropout,
+        )
+        layers["grid_attention"] = PartitionAttentionLayer(
+            in_channels=out_channels,
+            head_dim=head_dim,
+            partition_size=partition_size,
+            partition_type="grid",
+            grid_size=grid_size,
+            mlp_ratio=mlp_ratio,
+            activation_layer=activation_layer,
+            norm_layer=nn.LayerNorm,
+            attention_dropout=attention_dropout,
+            mlp_dropout=mlp_dropout,
+            p_stochastic_dropout=p_stochastic_dropout,
+        )
+        self.layers = nn.Sequential(layers)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Args:
+            x (Tensor): Input tensor of shape (B, C, H, W).
+        Returns:
+            Tensor: Output tensor of shape (B, C, H, W).
+        """
+        x = self.layers(x)
+        return x
+
+
+class MaxVitBlock(nn.Module):
+    """
+    A MaxVit block consisting of `n_layers` MaxVit layers.
+
+     Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        expansion_ratio (float): Expansion ratio in the bottleneck.
+        squeeze_ratio (float): Squeeze ratio in the SE Layer.
+        activation_layer (Callable[..., nn.Module]): Activation function.
+        norm_layer (Callable[..., nn.Module]): Normalization function.
+        head_dim (int): Dimension of the attention heads.
+        mlp_ratio (int): Ratio of the MLP layer.
+        mlp_dropout (float): Dropout probability for the MLP layer.
+        attention_dropout (float): Dropout probability for the attention layer.
+        p_stochastic_dropout (float): Probability of stochastic depth.
+        partition_size (int): Size of the partitions.
+        input_grid_size (Tuple[int, int]): Size of the input feature grid.
+        n_layers (int): Number of layers in the block.
+        p_stochastic (List[float]): List of probabilities for stochastic depth for each layer.
+    """
+
+    def __init__(
+        self,
+        # conv parameters
+        in_channels: int,
+        out_channels: int,
+        squeeze_ratio: float,
+        expansion_ratio: float,
+        # conv + transformer parameters
+        norm_layer: Callable[..., nn.Module],
+        activation_layer: Callable[..., nn.Module],
+        # transformer parameters
+        head_dim: int,
+        mlp_ratio: int,
+        mlp_dropout: float,
+        attention_dropout: float,
+        # partitioning parameters
+        partition_size: int,
+        input_grid_size: Tuple[int, int],
+        # number of layers
+        n_layers: int,
+        p_stochastic: List[float],
+    ) -> None:
+        super().__init__()
+        if not len(p_stochastic) == n_layers:
+            raise ValueError(f"p_stochastic must have length n_layers={n_layers}, got p_stochastic={p_stochastic}.")
+
+        self.layers = nn.ModuleList()
+        # account for the first stride of the first layer
+        self.grid_size = _get_conv_output_shape(input_grid_size, kernel_size=3, stride=2, padding=1)
+
+        for idx, p in enumerate(p_stochastic):
+            stride = 2 if idx == 0 else 1
+            self.layers += [
+                MaxVitLayer(
+                    in_channels=in_channels if idx == 0 else out_channels,
+                    out_channels=out_channels,
+                    squeeze_ratio=squeeze_ratio,
+                    expansion_ratio=expansion_ratio,
+                    stride=stride,
+                    norm_layer=norm_layer,
+                    activation_layer=activation_layer,
+                    head_dim=head_dim,
+                    mlp_ratio=mlp_ratio,
+                    mlp_dropout=mlp_dropout,
+                    attention_dropout=attention_dropout,
+                    partition_size=partition_size,
+                    grid_size=self.grid_size,
+                    p_stochastic_dropout=p,
+                ),
+            ]
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Args:
+            x (Tensor): Input tensor of shape (B, C, H, W).
+        Returns:
+            Tensor: Output tensor of shape (B, C, H, W).
+        """
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+
+class MaxVit(nn.Module):
+    """
+    Implements MaxVit Transformer from the `MaxViT: Multi-Axis Vision Transformer <https://arxiv.org/abs/2204.01697>`_ paper.
+    Args:
+        input_size (Tuple[int, int]): Size of the input image.
+        stem_channels (int): Number of channels in the stem.
+        partition_size (int): Size of the partitions.
+        block_channels (List[int]): Number of channels in each block.
+        block_layers (List[int]): Number of layers in each block.
+        stochastic_depth_prob (float): Probability of stochastic depth. Expands to a list of probabilities for each layer that scales linearly to the specified value.
+        squeeze_ratio (float): Squeeze ratio in the SE Layer. Default: 0.25.
+        expansion_ratio (float): Expansion ratio in the MBConv bottleneck. Default: 4.
+        norm_layer (Callable[..., nn.Module]): Normalization function. Default: None (setting to None will produce a `BatchNorm2d(eps=1e-3, momentum=0.99)`).
+        activation_layer (Callable[..., nn.Module]): Activation function Default: nn.GELU.
+        head_dim (int): Dimension of the attention heads.
+        mlp_ratio (int): Expansion ratio of the MLP layer. Default: 4.
+        mlp_dropout (float): Dropout probability for the MLP layer. Default: 0.0.
+        attention_dropout (float): Dropout probability for the attention layer. Default: 0.0.
+        num_classes (int): Number of classes. Default: 1000.
+    """
+
+    def __init__(
+        self,
+        # input size parameters
+        input_size: Tuple[int, int],
+        # stem and task parameters
+        stem_channels: int,
+        # partitioning parameters
+        partition_size: int,
+        # block parameters
+        block_channels: List[int],
+        block_layers: List[int],
+        # attention head dimensions
+        head_dim: int,
+        stochastic_depth_prob: float,
+        # conv + transformer parameters
+        # norm_layer is applied only to the conv layers
+        # activation_layer is applied both to conv and transformer layers
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        activation_layer: Callable[..., nn.Module] = nn.GELU,
+        # conv parameters
+        squeeze_ratio: float = 0.25,
+        expansion_ratio: float = 4,
+        # transformer parameters
+        mlp_ratio: int = 4,
+        mlp_dropout: float = 0.0,
+        attention_dropout: float = 0.0,
+        # task parameters
+        num_classes: int = 1000,
+    ) -> None:
+        super().__init__()
+        _log_api_usage_once(self)
+
+        input_channels = 3
+
+        # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L1029-L1030
+        # for the exact parameters used in batchnorm
+        if norm_layer is None:
+            norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.99)
+
+        # Make sure input size will be divisible by the partition size in all blocks
+        # Undefined behavior if H or W are not divisible by p
+        # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L766
+        block_input_sizes = _make_block_input_shapes(input_size, len(block_channels))
+        for idx, block_input_size in enumerate(block_input_sizes):
+            if block_input_size[0] % partition_size != 0 or block_input_size[1] % partition_size != 0:
+                raise ValueError(
+                    f"Input size {block_input_size} of block {idx} is not divisible by partition size {partition_size}. "
+                    f"Consider changing the partition size or the input size.\n"
+                    f"Current configuration yields the following block input sizes: {block_input_sizes}."
+                )
+
+        # stem
+        self.stem = nn.Sequential(
+            Conv2dNormActivation(
+                input_channels,
+                stem_channels,
+                3,
+                stride=2,
+                norm_layer=norm_layer,
+                activation_layer=activation_layer,
+                bias=False,
+                inplace=None,
+            ),
+            Conv2dNormActivation(
+                stem_channels, stem_channels, 3, stride=1, norm_layer=None, activation_layer=None, bias=True
+            ),
+        )
+
+        # account for stem stride
+        input_size = _get_conv_output_shape(input_size, kernel_size=3, stride=2, padding=1)
+        self.partition_size = partition_size
+
+        # blocks
+        self.blocks = nn.ModuleList()
+        in_channels = [stem_channels] + block_channels[:-1]
+        out_channels = block_channels
+
+        # precompute the stochastich depth probabilities from 0 to stochastic_depth_prob
+        # since we have N blocks with L layers, we will have N * L probabilities uniformly distributed
+        # over the range [0, stochastic_depth_prob]
+        p_stochastic = np.linspace(0, stochastic_depth_prob, sum(block_layers)).tolist()
+
+        p_idx = 0
+        for in_channel, out_channel, num_layers in zip(in_channels, out_channels, block_layers):
+            self.blocks.append(
+                MaxVitBlock(
+                    in_channels=in_channel,
+                    out_channels=out_channel,
+                    squeeze_ratio=squeeze_ratio,
+                    expansion_ratio=expansion_ratio,
+                    norm_layer=norm_layer,
+                    activation_layer=activation_layer,
+                    head_dim=head_dim,
+                    mlp_ratio=mlp_ratio,
+                    mlp_dropout=mlp_dropout,
+                    attention_dropout=attention_dropout,
+                    partition_size=partition_size,
+                    input_grid_size=input_size,
+                    n_layers=num_layers,
+                    p_stochastic=p_stochastic[p_idx : p_idx + num_layers],
+                ),
+            )
+            input_size = self.blocks[-1].grid_size
+            p_idx += num_layers
+
+        # see https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L1137-L1158
+        # for why there is Linear -> Tanh -> Linear
+        self.classifier = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Flatten(),
+            nn.LayerNorm(block_channels[-1]),
+            nn.Linear(block_channels[-1], block_channels[-1]),
+            nn.Tanh(),
+            nn.Linear(block_channels[-1], num_classes, bias=False),
+        )
+
+        self._init_weights()
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.stem(x)
+        for block in self.blocks:
+            x = block(x)
+        x = self.classifier(x)
+        return x
+
+    def _init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.normal_(m.weight, std=0.02)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, std=0.02)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+
+
+def _maxvit(
+    # stem parameters
+    stem_channels: int,
+    # block parameters
+    block_channels: List[int],
+    block_layers: List[int],
+    stochastic_depth_prob: float,
+    # partitioning parameters
+    partition_size: int,
+    # transformer parameters
+    head_dim: int,
+    # Weights API
+    weights: Optional[WeightsEnum] = None,
+    progress: bool = False,
+    # kwargs,
+    **kwargs: Any,
+) -> MaxVit:
+
+    if weights is not None:
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+        assert weights.meta["min_size"][0] == weights.meta["min_size"][1]
+        _ovewrite_named_param(kwargs, "input_size", weights.meta["min_size"])
+
+    input_size = kwargs.pop("input_size", (224, 224))
+
+    model = MaxVit(
+        stem_channels=stem_channels,
+        block_channels=block_channels,
+        block_layers=block_layers,
+        stochastic_depth_prob=stochastic_depth_prob,
+        head_dim=head_dim,
+        partition_size=partition_size,
+        input_size=input_size,
+        **kwargs,
+    )
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model
+
+
+class MaxVit_T_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        # URL empty until official release
+        url="https://download.pytorch.org/models/maxvit_t-bc5ab103.pth",
+        transforms=partial(
+            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
+        ),
+        meta={
+            "categories": _IMAGENET_CATEGORIES,
+            "num_params": 30919624,
+            "min_size": (224, 224),
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#maxvit",
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 83.700,
+                    "acc@5": 96.722,
+                }
+            },
+            "_ops": 5.558,
+            "_file_size": 118.769,
+            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", MaxVit_T_Weights.IMAGENET1K_V1))
+def maxvit_t(*, weights: Optional[MaxVit_T_Weights] = None, progress: bool = True, **kwargs: Any) -> MaxVit:
+    """
+    Constructs a maxvit_t architecture from
+    `MaxViT: Multi-Axis Vision Transformer <https://arxiv.org/abs/2204.01697>`_.
+
+    Args:
+        weights (:class:`~torchvision.models.MaxVit_T_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.MaxVit_T_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.maxvit.MaxVit``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/maxvit.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.MaxVit_T_Weights
+        :members:
+    """
+    weights = MaxVit_T_Weights.verify(weights)
+
+    return _maxvit(
+        stem_channels=64,
+        block_channels=[64, 128, 256, 512],
+        block_layers=[2, 2, 5, 2],
+        head_dim=32,
+        stochastic_depth_prob=0.2,
+        partition_size=7,
+        weights=weights,
+        progress=progress,
+        **kwargs,
+    )

+ 434 - 0
libs/vision_libs/models/mnasnet.py

@@ -0,0 +1,434 @@
+import warnings
+from functools import partial
+from typing import Any, Dict, List, Optional
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from ..transforms._presets import ImageClassification
+from ..utils import _log_api_usage_once
+from ._api import register_model, Weights, WeightsEnum
+from ._meta import _IMAGENET_CATEGORIES
+from ._utils import _ovewrite_named_param, handle_legacy_interface
+
+
+__all__ = [
+    "MNASNet",
+    "MNASNet0_5_Weights",
+    "MNASNet0_75_Weights",
+    "MNASNet1_0_Weights",
+    "MNASNet1_3_Weights",
+    "mnasnet0_5",
+    "mnasnet0_75",
+    "mnasnet1_0",
+    "mnasnet1_3",
+]
+
+
+# Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is
+# 1.0 - tensorflow.
+_BN_MOMENTUM = 1 - 0.9997
+
+
+class _InvertedResidual(nn.Module):
+    def __init__(
+        self, in_ch: int, out_ch: int, kernel_size: int, stride: int, expansion_factor: int, bn_momentum: float = 0.1
+    ) -> None:
+        super().__init__()
+        if stride not in [1, 2]:
+            raise ValueError(f"stride should be 1 or 2 instead of {stride}")
+        if kernel_size not in [3, 5]:
+            raise ValueError(f"kernel_size should be 3 or 5 instead of {kernel_size}")
+        mid_ch = in_ch * expansion_factor
+        self.apply_residual = in_ch == out_ch and stride == 1
+        self.layers = nn.Sequential(
+            # Pointwise
+            nn.Conv2d(in_ch, mid_ch, 1, bias=False),
+            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
+            nn.ReLU(inplace=True),
+            # Depthwise
+            nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2, stride=stride, groups=mid_ch, bias=False),
+            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
+            nn.ReLU(inplace=True),
+            # Linear pointwise. Note that there's no activation.
+            nn.Conv2d(mid_ch, out_ch, 1, bias=False),
+            nn.BatchNorm2d(out_ch, momentum=bn_momentum),
+        )
+
+    def forward(self, input: Tensor) -> Tensor:
+        if self.apply_residual:
+            return self.layers(input) + input
+        else:
+            return self.layers(input)
+
+
+def _stack(
+    in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int, bn_momentum: float
+) -> nn.Sequential:
+    """Creates a stack of inverted residuals."""
+    if repeats < 1:
+        raise ValueError(f"repeats should be >= 1, instead got {repeats}")
+    # First one has no skip, because feature map size changes.
+    first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, bn_momentum=bn_momentum)
+    remaining = []
+    for _ in range(1, repeats):
+        remaining.append(_InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, bn_momentum=bn_momentum))
+    return nn.Sequential(first, *remaining)
+
+
+def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int:
+    """Asymmetric rounding to make `val` divisible by `divisor`. With default
+    bias, will round up, unless the number is no more than 10% greater than the
+    smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88."""
+    if not 0.0 < round_up_bias < 1.0:
+        raise ValueError(f"round_up_bias should be greater than 0.0 and smaller than 1.0 instead of {round_up_bias}")
+    new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
+    return new_val if new_val >= round_up_bias * val else new_val + divisor
+
+
+def _get_depths(alpha: float) -> List[int]:
+    """Scales tensor depths as in reference MobileNet code, prefers rounding up
+    rather than down."""
+    depths = [32, 16, 24, 40, 80, 96, 192, 320]
+    return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
+
+
+class MNASNet(torch.nn.Module):
+    """MNASNet, as described in https://arxiv.org/abs/1807.11626. This
+    implements the B1 variant of the model.
+    >>> model = MNASNet(1.0, num_classes=1000)
+    >>> x = torch.rand(1, 3, 224, 224)
+    >>> y = model(x)
+    >>> y.dim()
+    2
+    >>> y.nelement()
+    1000
+    """
+
+    # Version 2 adds depth scaling in the initial stages of the network.
+    _version = 2
+
+    def __init__(self, alpha: float, num_classes: int = 1000, dropout: float = 0.2) -> None:
+        super().__init__()
+        _log_api_usage_once(self)
+        if alpha <= 0.0:
+            raise ValueError(f"alpha should be greater than 0.0 instead of {alpha}")
+        self.alpha = alpha
+        self.num_classes = num_classes
+        depths = _get_depths(alpha)
+        layers = [
+            # First layer: regular conv.
+            nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False),
+            nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
+            nn.ReLU(inplace=True),
+            # Depthwise separable, no skip.
+            nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1, groups=depths[0], bias=False),
+            nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False),
+            nn.BatchNorm2d(depths[1], momentum=_BN_MOMENTUM),
+            # MNASNet blocks: stacks of inverted residuals.
+            _stack(depths[1], depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
+            _stack(depths[2], depths[3], 5, 2, 3, 3, _BN_MOMENTUM),
+            _stack(depths[3], depths[4], 5, 2, 6, 3, _BN_MOMENTUM),
+            _stack(depths[4], depths[5], 3, 1, 6, 2, _BN_MOMENTUM),
+            _stack(depths[5], depths[6], 5, 2, 6, 4, _BN_MOMENTUM),
+            _stack(depths[6], depths[7], 3, 1, 6, 1, _BN_MOMENTUM),
+            # Final mapping to classifier input.
+            nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False),
+            nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
+            nn.ReLU(inplace=True),
+        ]
+        self.layers = nn.Sequential(*layers)
+        self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), nn.Linear(1280, num_classes))
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Linear):
+                nn.init.kaiming_uniform_(m.weight, mode="fan_out", nonlinearity="sigmoid")
+                nn.init.zeros_(m.bias)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.layers(x)
+        # Equivalent to global avgpool and removing H and W dimensions.
+        x = x.mean([2, 3])
+        return self.classifier(x)
+
+    def _load_from_state_dict(
+        self,
+        state_dict: Dict,
+        prefix: str,
+        local_metadata: Dict,
+        strict: bool,
+        missing_keys: List[str],
+        unexpected_keys: List[str],
+        error_msgs: List[str],
+    ) -> None:
+        version = local_metadata.get("version", None)
+        if version not in [1, 2]:
+            raise ValueError(f"version shluld be set to 1 or 2 instead of {version}")
+
+        if version == 1 and not self.alpha == 1.0:
+            # In the initial version of the model (v1), stem was fixed-size.
+            # All other layer configurations were the same. This will patch
+            # the model so that it's identical to v1. Model with alpha 1.0 is
+            # unaffected.
+            depths = _get_depths(self.alpha)
+            v1_stem = [
+                nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
+                nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False),
+                nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
+                nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
+                _stack(16, depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
+            ]
+            for idx, layer in enumerate(v1_stem):
+                self.layers[idx] = layer
+
+            # The model is now identical to v1, and must be saved as such.
+            self._version = 1
+            warnings.warn(
+                "A new version of MNASNet model has been implemented. "
+                "Your checkpoint was saved using the previous version. "
+                "This checkpoint will load and work as before, but "
+                "you may want to upgrade by training a newer model or "
+                "transfer learning from an updated ImageNet checkpoint.",
+                UserWarning,
+            )
+
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
+
+
+_COMMON_META = {
+    "min_size": (1, 1),
+    "categories": _IMAGENET_CATEGORIES,
+    "recipe": "https://github.com/1e100/mnasnet_trainer",
+}
+
+
+class MNASNet0_5_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
+        transforms=partial(ImageClassification, crop_size=224),
+        meta={
+            **_COMMON_META,
+            "num_params": 2218512,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 67.734,
+                    "acc@5": 87.490,
+                }
+            },
+            "_ops": 0.104,
+            "_file_size": 8.591,
+            "_docs": """These weights reproduce closely the results of the paper.""",
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class MNASNet0_75_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/mnasnet0_75-7090bc5f.pth",
+        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+        meta={
+            **_COMMON_META,
+            "recipe": "https://github.com/pytorch/vision/pull/6019",
+            "num_params": 3170208,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 71.180,
+                    "acc@5": 90.496,
+                }
+            },
+            "_ops": 0.215,
+            "_file_size": 12.303,
+            "_docs": """
+                These weights were trained from scratch by using TorchVision's `new training recipe
+                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
+            """,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class MNASNet1_0_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
+        transforms=partial(ImageClassification, crop_size=224),
+        meta={
+            **_COMMON_META,
+            "num_params": 4383312,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 73.456,
+                    "acc@5": 91.510,
+                }
+            },
+            "_ops": 0.314,
+            "_file_size": 16.915,
+            "_docs": """These weights reproduce closely the results of the paper.""",
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class MNASNet1_3_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/mnasnet1_3-a4c69d6f.pth",
+        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+        meta={
+            **_COMMON_META,
+            "recipe": "https://github.com/pytorch/vision/pull/6019",
+            "num_params": 6282256,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 76.506,
+                    "acc@5": 93.522,
+                }
+            },
+            "_ops": 0.526,
+            "_file_size": 24.246,
+            "_docs": """
+                These weights were trained from scratch by using TorchVision's `new training recipe
+                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
+            """,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet:
+    if weights is not None:
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+
+    model = MNASNet(alpha, **kwargs)
+
+    if weights:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", MNASNet0_5_Weights.IMAGENET1K_V1))
+def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
+    """MNASNet with depth multiplier of 0.5 from
+    `MnasNet: Platform-Aware Neural Architecture Search for Mobile
+    <https://arxiv.org/abs/1807.11626>`_ paper.
+
+    Args:
+        weights (:class:`~torchvision.models.MNASNet0_5_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.MNASNet0_5_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.MNASNet0_5_Weights
+        :members:
+    """
+    weights = MNASNet0_5_Weights.verify(weights)
+
+    return _mnasnet(0.5, weights, progress, **kwargs)
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", MNASNet0_75_Weights.IMAGENET1K_V1))
+def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
+    """MNASNet with depth multiplier of 0.75 from
+    `MnasNet: Platform-Aware Neural Architecture Search for Mobile
+    <https://arxiv.org/abs/1807.11626>`_ paper.
+
+    Args:
+        weights (:class:`~torchvision.models.MNASNet0_75_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.MNASNet0_75_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.MNASNet0_75_Weights
+        :members:
+    """
+    weights = MNASNet0_75_Weights.verify(weights)
+
+    return _mnasnet(0.75, weights, progress, **kwargs)
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", MNASNet1_0_Weights.IMAGENET1K_V1))
+def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
+    """MNASNet with depth multiplier of 1.0 from
+    `MnasNet: Platform-Aware Neural Architecture Search for Mobile
+    <https://arxiv.org/abs/1807.11626>`_ paper.
+
+    Args:
+        weights (:class:`~torchvision.models.MNASNet1_0_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.MNASNet1_0_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.MNASNet1_0_Weights
+        :members:
+    """
+    weights = MNASNet1_0_Weights.verify(weights)
+
+    return _mnasnet(1.0, weights, progress, **kwargs)
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", MNASNet1_3_Weights.IMAGENET1K_V1))
+def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
+    """MNASNet with depth multiplier of 1.3 from
+    `MnasNet: Platform-Aware Neural Architecture Search for Mobile
+    <https://arxiv.org/abs/1807.11626>`_ paper.
+
+    Args:
+        weights (:class:`~torchvision.models.MNASNet1_3_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.MNASNet1_3_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.MNASNet1_3_Weights
+        :members:
+    """
+    weights = MNASNet1_3_Weights.verify(weights)
+
+    return _mnasnet(1.3, weights, progress, **kwargs)

+ 6 - 0
libs/vision_libs/models/mobilenet.py

@@ -0,0 +1,6 @@
+from .mobilenetv2 import *  # noqa: F401, F403
+from .mobilenetv3 import *  # noqa: F401, F403
+from .mobilenetv2 import __all__ as mv2_all
+from .mobilenetv3 import __all__ as mv3_all
+
+__all__ = mv2_all + mv3_all

+ 260 - 0
libs/vision_libs/models/mobilenetv2.py

@@ -0,0 +1,260 @@
+from functools import partial
+from typing import Any, Callable, List, Optional
+
+import torch
+from torch import nn, Tensor
+
+from ..ops.misc import Conv2dNormActivation
+from ..transforms._presets import ImageClassification
+from ..utils import _log_api_usage_once
+from ._api import register_model, Weights, WeightsEnum
+from ._meta import _IMAGENET_CATEGORIES
+from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface
+
+
+__all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"]
+
+
+# necessary for backwards compatibility
+class InvertedResidual(nn.Module):
+    def __init__(
+        self, inp: int, oup: int, stride: int, expand_ratio: int, norm_layer: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
+        super().__init__()
+        self.stride = stride
+        if stride not in [1, 2]:
+            raise ValueError(f"stride should be 1 or 2 instead of {stride}")
+
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+
+        hidden_dim = int(round(inp * expand_ratio))
+        self.use_res_connect = self.stride == 1 and inp == oup
+
+        layers: List[nn.Module] = []
+        if expand_ratio != 1:
+            # pw
+            layers.append(
+                Conv2dNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6)
+            )
+        layers.extend(
+            [
+                # dw
+                Conv2dNormActivation(
+                    hidden_dim,
+                    hidden_dim,
+                    stride=stride,
+                    groups=hidden_dim,
+                    norm_layer=norm_layer,
+                    activation_layer=nn.ReLU6,
+                ),
+                # pw-linear
+                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+                norm_layer(oup),
+            ]
+        )
+        self.conv = nn.Sequential(*layers)
+        self.out_channels = oup
+        self._is_cn = stride > 1
+
+    def forward(self, x: Tensor) -> Tensor:
+        if self.use_res_connect:
+            return x + self.conv(x)
+        else:
+            return self.conv(x)
+
+
+class MobileNetV2(nn.Module):
+    def __init__(
+        self,
+        num_classes: int = 1000,
+        width_mult: float = 1.0,
+        inverted_residual_setting: Optional[List[List[int]]] = None,
+        round_nearest: int = 8,
+        block: Optional[Callable[..., nn.Module]] = None,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        dropout: float = 0.2,
+    ) -> None:
+        """
+        MobileNet V2 main class
+
+        Args:
+            num_classes (int): Number of classes
+            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
+            inverted_residual_setting: Network structure
+            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
+            Set to 1 to turn off rounding
+            block: Module specifying inverted residual building block for mobilenet
+            norm_layer: Module specifying the normalization layer to use
+            dropout (float): The droupout probability
+
+        """
+        super().__init__()
+        _log_api_usage_once(self)
+
+        if block is None:
+            block = InvertedResidual
+
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+
+        input_channel = 32
+        last_channel = 1280
+
+        if inverted_residual_setting is None:
+            inverted_residual_setting = [
+                # t, c, n, s
+                [1, 16, 1, 1],
+                [6, 24, 2, 2],
+                [6, 32, 3, 2],
+                [6, 64, 4, 2],
+                [6, 96, 3, 1],
+                [6, 160, 3, 2],
+                [6, 320, 1, 1],
+            ]
+
+        # only check the first element, assuming user knows t,c,n,s are required
+        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
+            raise ValueError(
+                f"inverted_residual_setting should be non-empty or a 4-element list, got {inverted_residual_setting}"
+            )
+
+        # building first layer
+        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
+        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
+        features: List[nn.Module] = [
+            Conv2dNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6)
+        ]
+        # building inverted residual blocks
+        for t, c, n, s in inverted_residual_setting:
+            output_channel = _make_divisible(c * width_mult, round_nearest)
+            for i in range(n):
+                stride = s if i == 0 else 1
+                features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
+                input_channel = output_channel
+        # building last several layers
+        features.append(
+            Conv2dNormActivation(
+                input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6
+            )
+        )
+        # make it nn.Sequential
+        self.features = nn.Sequential(*features)
+
+        # building classifier
+        self.classifier = nn.Sequential(
+            nn.Dropout(p=dropout),
+            nn.Linear(self.last_channel, num_classes),
+        )
+
+        # weight initialization
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out")
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.zeros_(m.bias)
+
+    def _forward_impl(self, x: Tensor) -> Tensor:
+        # This exists since TorchScript doesn't support inheritance, so the superclass method
+        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
+        x = self.features(x)
+        # Cannot use "squeeze" as batch-size can be 1
+        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
+        x = torch.flatten(x, 1)
+        x = self.classifier(x)
+        return x
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self._forward_impl(x)
+
+
+_COMMON_META = {
+    "num_params": 3504872,
+    "min_size": (1, 1),
+    "categories": _IMAGENET_CATEGORIES,
+}
+
+
+class MobileNet_V2_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
+        transforms=partial(ImageClassification, crop_size=224),
+        meta={
+            **_COMMON_META,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2",
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 71.878,
+                    "acc@5": 90.286,
+                }
+            },
+            "_ops": 0.301,
+            "_file_size": 13.555,
+            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
+        },
+    )
+    IMAGENET1K_V2 = Weights(
+        url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth",
+        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+        meta={
+            **_COMMON_META,
+            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning",
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 72.154,
+                    "acc@5": 90.822,
+                }
+            },
+            "_ops": 0.301,
+            "_file_size": 13.598,
+            "_docs": """
+                These weights improve upon the results of the original paper by using a modified version of TorchVision's
+                `new training recipe
+                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
+            """,
+        },
+    )
+    DEFAULT = IMAGENET1K_V2
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1))
+def mobilenet_v2(
+    *, weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any
+) -> MobileNetV2:
+    """MobileNetV2 architecture from the `MobileNetV2: Inverted Residuals and Linear
+    Bottlenecks <https://arxiv.org/abs/1801.04381>`_ paper.
+
+    Args:
+        weights (:class:`~torchvision.models.MobileNet_V2_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.MobileNet_V2_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.mobilenetv2.MobileNetV2``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv2.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.MobileNet_V2_Weights
+        :members:
+    """
+    weights = MobileNet_V2_Weights.verify(weights)
+
+    if weights is not None:
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+
+    model = MobileNetV2(**kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model

+ 423 - 0
libs/vision_libs/models/mobilenetv3.py

@@ -0,0 +1,423 @@
+from functools import partial
+from typing import Any, Callable, List, Optional, Sequence
+
+import torch
+from torch import nn, Tensor
+
+from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer
+from ..transforms._presets import ImageClassification
+from ..utils import _log_api_usage_once
+from ._api import register_model, Weights, WeightsEnum
+from ._meta import _IMAGENET_CATEGORIES
+from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface
+
+
+__all__ = [
+    "MobileNetV3",
+    "MobileNet_V3_Large_Weights",
+    "MobileNet_V3_Small_Weights",
+    "mobilenet_v3_large",
+    "mobilenet_v3_small",
+]
+
+
+class InvertedResidualConfig:
+    # Stores information listed at Tables 1 and 2 of the MobileNetV3 paper
+    def __init__(
+        self,
+        input_channels: int,
+        kernel: int,
+        expanded_channels: int,
+        out_channels: int,
+        use_se: bool,
+        activation: str,
+        stride: int,
+        dilation: int,
+        width_mult: float,
+    ):
+        self.input_channels = self.adjust_channels(input_channels, width_mult)
+        self.kernel = kernel
+        self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
+        self.out_channels = self.adjust_channels(out_channels, width_mult)
+        self.use_se = use_se
+        self.use_hs = activation == "HS"
+        self.stride = stride
+        self.dilation = dilation
+
+    @staticmethod
+    def adjust_channels(channels: int, width_mult: float):
+        return _make_divisible(channels * width_mult, 8)
+
+
+class InvertedResidual(nn.Module):
+    # Implemented as described at section 5 of MobileNetV3 paper
+    def __init__(
+        self,
+        cnf: InvertedResidualConfig,
+        norm_layer: Callable[..., nn.Module],
+        se_layer: Callable[..., nn.Module] = partial(SElayer, scale_activation=nn.Hardsigmoid),
+    ):
+        super().__init__()
+        if not (1 <= cnf.stride <= 2):
+            raise ValueError("illegal stride value")
+
+        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
+
+        layers: List[nn.Module] = []
+        activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU
+
+        # expand
+        if cnf.expanded_channels != cnf.input_channels:
+            layers.append(
+                Conv2dNormActivation(
+                    cnf.input_channels,
+                    cnf.expanded_channels,
+                    kernel_size=1,
+                    norm_layer=norm_layer,
+                    activation_layer=activation_layer,
+                )
+            )
+
+        # depthwise
+        stride = 1 if cnf.dilation > 1 else cnf.stride
+        layers.append(
+            Conv2dNormActivation(
+                cnf.expanded_channels,
+                cnf.expanded_channels,
+                kernel_size=cnf.kernel,
+                stride=stride,
+                dilation=cnf.dilation,
+                groups=cnf.expanded_channels,
+                norm_layer=norm_layer,
+                activation_layer=activation_layer,
+            )
+        )
+        if cnf.use_se:
+            squeeze_channels = _make_divisible(cnf.expanded_channels // 4, 8)
+            layers.append(se_layer(cnf.expanded_channels, squeeze_channels))
+
+        # project
+        layers.append(
+            Conv2dNormActivation(
+                cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
+            )
+        )
+
+        self.block = nn.Sequential(*layers)
+        self.out_channels = cnf.out_channels
+        self._is_cn = cnf.stride > 1
+
+    def forward(self, input: Tensor) -> Tensor:
+        result = self.block(input)
+        if self.use_res_connect:
+            result += input
+        return result
+
+
+class MobileNetV3(nn.Module):
+    def __init__(
+        self,
+        inverted_residual_setting: List[InvertedResidualConfig],
+        last_channel: int,
+        num_classes: int = 1000,
+        block: Optional[Callable[..., nn.Module]] = None,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        dropout: float = 0.2,
+        **kwargs: Any,
+    ) -> None:
+        """
+        MobileNet V3 main class
+
+        Args:
+            inverted_residual_setting (List[InvertedResidualConfig]): Network structure
+            last_channel (int): The number of channels on the penultimate layer
+            num_classes (int): Number of classes
+            block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet
+            norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
+            dropout (float): The droupout probability
+        """
+        super().__init__()
+        _log_api_usage_once(self)
+
+        if not inverted_residual_setting:
+            raise ValueError("The inverted_residual_setting should not be empty")
+        elif not (
+            isinstance(inverted_residual_setting, Sequence)
+            and all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])
+        ):
+            raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]")
+
+        if block is None:
+            block = InvertedResidual
+
+        if norm_layer is None:
+            norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)
+
+        layers: List[nn.Module] = []
+
+        # building first layer
+        firstconv_output_channels = inverted_residual_setting[0].input_channels
+        layers.append(
+            Conv2dNormActivation(
+                3,
+                firstconv_output_channels,
+                kernel_size=3,
+                stride=2,
+                norm_layer=norm_layer,
+                activation_layer=nn.Hardswish,
+            )
+        )
+
+        # building inverted residual blocks
+        for cnf in inverted_residual_setting:
+            layers.append(block(cnf, norm_layer))
+
+        # building last several layers
+        lastconv_input_channels = inverted_residual_setting[-1].out_channels
+        lastconv_output_channels = 6 * lastconv_input_channels
+        layers.append(
+            Conv2dNormActivation(
+                lastconv_input_channels,
+                lastconv_output_channels,
+                kernel_size=1,
+                norm_layer=norm_layer,
+                activation_layer=nn.Hardswish,
+            )
+        )
+
+        self.features = nn.Sequential(*layers)
+        self.avgpool = nn.AdaptiveAvgPool2d(1)
+        self.classifier = nn.Sequential(
+            nn.Linear(lastconv_output_channels, last_channel),
+            nn.Hardswish(inplace=True),
+            nn.Dropout(p=dropout, inplace=True),
+            nn.Linear(last_channel, num_classes),
+        )
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out")
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.zeros_(m.bias)
+
+    def _forward_impl(self, x: Tensor) -> Tensor:
+        x = self.features(x)
+
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+
+        x = self.classifier(x)
+
+        return x
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self._forward_impl(x)
+
+
+def _mobilenet_v3_conf(
+    arch: str, width_mult: float = 1.0, reduced_tail: bool = False, dilated: bool = False, **kwargs: Any
+):
+    reduce_divider = 2 if reduced_tail else 1
+    dilation = 2 if dilated else 1
+
+    bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
+    adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
+
+    if arch == "mobilenet_v3_large":
+        inverted_residual_setting = [
+            bneck_conf(16, 3, 16, 16, False, "RE", 1, 1),
+            bneck_conf(16, 3, 64, 24, False, "RE", 2, 1),  # C1
+            bneck_conf(24, 3, 72, 24, False, "RE", 1, 1),
+            bneck_conf(24, 5, 72, 40, True, "RE", 2, 1),  # C2
+            bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
+            bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
+            bneck_conf(40, 3, 240, 80, False, "HS", 2, 1),  # C3
+            bneck_conf(80, 3, 200, 80, False, "HS", 1, 1),
+            bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
+            bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
+            bneck_conf(80, 3, 480, 112, True, "HS", 1, 1),
+            bneck_conf(112, 3, 672, 112, True, "HS", 1, 1),
+            bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2, dilation),  # C4
+            bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
+            bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
+        ]
+        last_channel = adjust_channels(1280 // reduce_divider)  # C5
+    elif arch == "mobilenet_v3_small":
+        inverted_residual_setting = [
+            bneck_conf(16, 3, 16, 16, True, "RE", 2, 1),  # C1
+            bneck_conf(16, 3, 72, 24, False, "RE", 2, 1),  # C2
+            bneck_conf(24, 3, 88, 24, False, "RE", 1, 1),
+            bneck_conf(24, 5, 96, 40, True, "HS", 2, 1),  # C3
+            bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
+            bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
+            bneck_conf(40, 5, 120, 48, True, "HS", 1, 1),
+            bneck_conf(48, 5, 144, 48, True, "HS", 1, 1),
+            bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation),  # C4
+            bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
+            bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
+        ]
+        last_channel = adjust_channels(1024 // reduce_divider)  # C5
+    else:
+        raise ValueError(f"Unsupported model type {arch}")
+
+    return inverted_residual_setting, last_channel
+
+
+def _mobilenet_v3(
+    inverted_residual_setting: List[InvertedResidualConfig],
+    last_channel: int,
+    weights: Optional[WeightsEnum],
+    progress: bool,
+    **kwargs: Any,
+) -> MobileNetV3:
+    if weights is not None:
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+
+    model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model
+
+
+_COMMON_META = {
+    "min_size": (1, 1),
+    "categories": _IMAGENET_CATEGORIES,
+}
+
+
+class MobileNet_V3_Large_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
+        transforms=partial(ImageClassification, crop_size=224),
+        meta={
+            **_COMMON_META,
+            "num_params": 5483032,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small",
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 74.042,
+                    "acc@5": 91.340,
+                }
+            },
+            "_ops": 0.217,
+            "_file_size": 21.114,
+            "_docs": """These weights were trained from scratch by using a simple training recipe.""",
+        },
+    )
+    IMAGENET1K_V2 = Weights(
+        url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth",
+        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+        meta={
+            **_COMMON_META,
+            "num_params": 5483032,
+            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning",
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 75.274,
+                    "acc@5": 92.566,
+                }
+            },
+            "_ops": 0.217,
+            "_file_size": 21.107,
+            "_docs": """
+                These weights improve marginally upon the results of the original paper by using a modified version of
+                TorchVision's `new training recipe
+                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
+            """,
+        },
+    )
+    DEFAULT = IMAGENET1K_V2
+
+
+class MobileNet_V3_Small_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
+        transforms=partial(ImageClassification, crop_size=224),
+        meta={
+            **_COMMON_META,
+            "num_params": 2542856,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small",
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 67.668,
+                    "acc@5": 87.402,
+                }
+            },
+            "_ops": 0.057,
+            "_file_size": 9.829,
+            "_docs": """
+                These weights improve upon the results of the original paper by using a simple training recipe.
+            """,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Large_Weights.IMAGENET1K_V1))
+def mobilenet_v3_large(
+    *, weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any
+) -> MobileNetV3:
+    """
+    Constructs a large MobileNetV3 architecture from
+    `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`__.
+
+    Args:
+        weights (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.MobileNet_V3_Large_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.mobilenet.MobileNetV3``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv3.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.MobileNet_V3_Large_Weights
+        :members:
+    """
+    weights = MobileNet_V3_Large_Weights.verify(weights)
+
+    inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
+    return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs)
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Small_Weights.IMAGENET1K_V1))
+def mobilenet_v3_small(
+    *, weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any
+) -> MobileNetV3:
+    """
+    Constructs a small MobileNetV3 architecture from
+    `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`__.
+
+    Args:
+        weights (:class:`~torchvision.models.MobileNet_V3_Small_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.MobileNet_V3_Small_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.mobilenet.MobileNetV3``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv3.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.MobileNet_V3_Small_Weights
+        :members:
+    """
+    weights = MobileNet_V3_Small_Weights.verify(weights)
+
+    inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs)
+    return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs)

+ 1 - 0
libs/vision_libs/models/optical_flow/__init__.py

@@ -0,0 +1 @@
+from .raft import *

+ 48 - 0
libs/vision_libs/models/optical_flow/_utils.py

@@ -0,0 +1,48 @@
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+
+def grid_sample(img: Tensor, absolute_grid: Tensor, mode: str = "bilinear", align_corners: Optional[bool] = None):
+    """Same as torch's grid_sample, with absolute pixel coordinates instead of normalized coordinates."""
+    h, w = img.shape[-2:]
+
+    xgrid, ygrid = absolute_grid.split([1, 1], dim=-1)
+    xgrid = 2 * xgrid / (w - 1) - 1
+    # Adding condition if h > 1 to enable this function be reused in raft-stereo
+    if h > 1:
+        ygrid = 2 * ygrid / (h - 1) - 1
+    normalized_grid = torch.cat([xgrid, ygrid], dim=-1)
+
+    return F.grid_sample(img, normalized_grid, mode=mode, align_corners=align_corners)
+
+
+def make_coords_grid(batch_size: int, h: int, w: int, device: str = "cpu"):
+    device = torch.device(device)
+    coords = torch.meshgrid(torch.arange(h, device=device), torch.arange(w, device=device), indexing="ij")
+    coords = torch.stack(coords[::-1], dim=0).float()
+    return coords[None].repeat(batch_size, 1, 1, 1)
+
+
+def upsample_flow(flow, up_mask: Optional[Tensor] = None, factor: int = 8):
+    """Upsample flow by the input factor (default 8).
+
+    If up_mask is None we just interpolate.
+    If up_mask is specified, we upsample using a convex combination of its weights. See paper page 8 and appendix B.
+    Note that in appendix B the picture assumes a downsample factor of 4 instead of 8.
+    """
+    batch_size, num_channels, h, w = flow.shape
+    new_h, new_w = h * factor, w * factor
+
+    if up_mask is None:
+        return factor * F.interpolate(flow, size=(new_h, new_w), mode="bilinear", align_corners=True)
+
+    up_mask = up_mask.view(batch_size, 1, 9, factor, factor, h, w)
+    up_mask = torch.softmax(up_mask, dim=2)  # "convex" == weights sum to 1
+
+    upsampled_flow = F.unfold(factor * flow, kernel_size=3, padding=1).view(batch_size, num_channels, 9, 1, 1, h, w)
+    upsampled_flow = torch.sum(up_mask * upsampled_flow, dim=2)
+
+    return upsampled_flow.permute(0, 1, 4, 2, 5, 3).reshape(batch_size, num_channels, new_h, new_w)

+ 947 - 0
libs/vision_libs/models/optical_flow/raft.py

@@ -0,0 +1,947 @@
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.nn.modules.batchnorm import BatchNorm2d
+from torch.nn.modules.instancenorm import InstanceNorm2d
+from torchvision.ops import Conv2dNormActivation
+
+from ...transforms._presets import OpticalFlow
+from ...utils import _log_api_usage_once
+from .._api import register_model, Weights, WeightsEnum
+from .._utils import handle_legacy_interface
+from ._utils import grid_sample, make_coords_grid, upsample_flow
+
+
+__all__ = (
+    "RAFT",
+    "raft_large",
+    "raft_small",
+    "Raft_Large_Weights",
+    "Raft_Small_Weights",
+)
+
+
+class ResidualBlock(nn.Module):
+    """Slightly modified Residual block with extra relu and biases."""
+
+    def __init__(self, in_channels, out_channels, *, norm_layer, stride=1, always_project: bool = False):
+        super().__init__()
+
+        # Note regarding bias=True:
+        # Usually we can pass bias=False in conv layers followed by a norm layer.
+        # But in the RAFT training reference, the BatchNorm2d layers are only activated for the first dataset,
+        # and frozen for the rest of the training process (i.e. set as eval()). The bias term is thus still useful
+        # for the rest of the datasets. Technically, we could remove the bias for other norm layers like Instance norm
+        # because these aren't frozen, but we don't bother (also, we wouldn't be able to load the original weights).
+        self.convnormrelu1 = Conv2dNormActivation(
+            in_channels, out_channels, norm_layer=norm_layer, kernel_size=3, stride=stride, bias=True
+        )
+        self.convnormrelu2 = Conv2dNormActivation(
+            out_channels, out_channels, norm_layer=norm_layer, kernel_size=3, bias=True
+        )
+
+        # make mypy happy
+        self.downsample: nn.Module
+
+        if stride == 1 and not always_project:
+            self.downsample = nn.Identity()
+        else:
+            self.downsample = Conv2dNormActivation(
+                in_channels,
+                out_channels,
+                norm_layer=norm_layer,
+                kernel_size=1,
+                stride=stride,
+                bias=True,
+                activation_layer=None,
+            )
+
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        y = x
+        y = self.convnormrelu1(y)
+        y = self.convnormrelu2(y)
+
+        x = self.downsample(x)
+
+        return self.relu(x + y)
+
+
+class BottleneckBlock(nn.Module):
+    """Slightly modified BottleNeck block (extra relu and biases)"""
+
+    def __init__(self, in_channels, out_channels, *, norm_layer, stride=1):
+        super().__init__()
+
+        # See note in ResidualBlock for the reason behind bias=True
+        self.convnormrelu1 = Conv2dNormActivation(
+            in_channels, out_channels // 4, norm_layer=norm_layer, kernel_size=1, bias=True
+        )
+        self.convnormrelu2 = Conv2dNormActivation(
+            out_channels // 4, out_channels // 4, norm_layer=norm_layer, kernel_size=3, stride=stride, bias=True
+        )
+        self.convnormrelu3 = Conv2dNormActivation(
+            out_channels // 4, out_channels, norm_layer=norm_layer, kernel_size=1, bias=True
+        )
+        self.relu = nn.ReLU(inplace=True)
+
+        if stride == 1:
+            self.downsample = nn.Identity()
+        else:
+            self.downsample = Conv2dNormActivation(
+                in_channels,
+                out_channels,
+                norm_layer=norm_layer,
+                kernel_size=1,
+                stride=stride,
+                bias=True,
+                activation_layer=None,
+            )
+
+    def forward(self, x):
+        y = x
+        y = self.convnormrelu1(y)
+        y = self.convnormrelu2(y)
+        y = self.convnormrelu3(y)
+
+        x = self.downsample(x)
+
+        return self.relu(x + y)
+
+
+class FeatureEncoder(nn.Module):
+    """The feature encoder, used both as the actual feature encoder, and as the context encoder.
+
+    It must downsample its input by 8.
+    """
+
+    def __init__(
+        self, *, block=ResidualBlock, layers=(64, 64, 96, 128, 256), strides=(2, 1, 2, 2), norm_layer=nn.BatchNorm2d
+    ):
+        super().__init__()
+
+        if len(layers) != 5:
+            raise ValueError(f"The expected number of layers is 5, instead got {len(layers)}")
+
+        # See note in ResidualBlock for the reason behind bias=True
+        self.convnormrelu = Conv2dNormActivation(
+            3, layers[0], norm_layer=norm_layer, kernel_size=7, stride=strides[0], bias=True
+        )
+
+        self.layer1 = self._make_2_blocks(block, layers[0], layers[1], norm_layer=norm_layer, first_stride=strides[1])
+        self.layer2 = self._make_2_blocks(block, layers[1], layers[2], norm_layer=norm_layer, first_stride=strides[2])
+        self.layer3 = self._make_2_blocks(block, layers[2], layers[3], norm_layer=norm_layer, first_stride=strides[3])
+
+        self.conv = nn.Conv2d(layers[3], layers[4], kernel_size=1)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
+                if m.weight is not None:
+                    nn.init.constant_(m.weight, 1)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+        num_downsamples = len(list(filter(lambda s: s == 2, strides)))
+        self.output_dim = layers[-1]
+        self.downsample_factor = 2**num_downsamples
+
+    def _make_2_blocks(self, block, in_channels, out_channels, norm_layer, first_stride):
+        block1 = block(in_channels, out_channels, norm_layer=norm_layer, stride=first_stride)
+        block2 = block(out_channels, out_channels, norm_layer=norm_layer, stride=1)
+        return nn.Sequential(block1, block2)
+
+    def forward(self, x):
+        x = self.convnormrelu(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+
+        x = self.conv(x)
+
+        return x
+
+
+class MotionEncoder(nn.Module):
+    """The motion encoder, part of the update block.
+
+    Takes the current predicted flow and the correlation features as input and returns an encoded version of these.
+    """
+
+    def __init__(self, *, in_channels_corr, corr_layers=(256, 192), flow_layers=(128, 64), out_channels=128):
+        super().__init__()
+
+        if len(flow_layers) != 2:
+            raise ValueError(f"The expected number of flow_layers is 2, instead got {len(flow_layers)}")
+        if len(corr_layers) not in (1, 2):
+            raise ValueError(f"The number of corr_layers should be 1 or 2, instead got {len(corr_layers)}")
+
+        self.convcorr1 = Conv2dNormActivation(in_channels_corr, corr_layers[0], norm_layer=None, kernel_size=1)
+        if len(corr_layers) == 2:
+            self.convcorr2 = Conv2dNormActivation(corr_layers[0], corr_layers[1], norm_layer=None, kernel_size=3)
+        else:
+            self.convcorr2 = nn.Identity()
+
+        self.convflow1 = Conv2dNormActivation(2, flow_layers[0], norm_layer=None, kernel_size=7)
+        self.convflow2 = Conv2dNormActivation(flow_layers[0], flow_layers[1], norm_layer=None, kernel_size=3)
+
+        # out_channels - 2 because we cat the flow (2 channels) at the end
+        self.conv = Conv2dNormActivation(
+            corr_layers[-1] + flow_layers[-1], out_channels - 2, norm_layer=None, kernel_size=3
+        )
+
+        self.out_channels = out_channels
+
+    def forward(self, flow, corr_features):
+        corr = self.convcorr1(corr_features)
+        corr = self.convcorr2(corr)
+
+        flow_orig = flow
+        flow = self.convflow1(flow)
+        flow = self.convflow2(flow)
+
+        corr_flow = torch.cat([corr, flow], dim=1)
+        corr_flow = self.conv(corr_flow)
+        return torch.cat([corr_flow, flow_orig], dim=1)
+
+
+class ConvGRU(nn.Module):
+    """Convolutional Gru unit."""
+
+    def __init__(self, *, input_size, hidden_size, kernel_size, padding):
+        super().__init__()
+        self.convz = nn.Conv2d(hidden_size + input_size, hidden_size, kernel_size=kernel_size, padding=padding)
+        self.convr = nn.Conv2d(hidden_size + input_size, hidden_size, kernel_size=kernel_size, padding=padding)
+        self.convq = nn.Conv2d(hidden_size + input_size, hidden_size, kernel_size=kernel_size, padding=padding)
+
+    def forward(self, h, x):
+        hx = torch.cat([h, x], dim=1)
+        z = torch.sigmoid(self.convz(hx))
+        r = torch.sigmoid(self.convr(hx))
+        q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))
+        h = (1 - z) * h + z * q
+        return h
+
+
+def _pass_through_h(h, _):
+    # Declared here for torchscript
+    return h
+
+
+class RecurrentBlock(nn.Module):
+    """Recurrent block, part of the update block.
+
+    Takes the current hidden state and the concatenation of (motion encoder output, context) as input.
+    Returns an updated hidden state.
+    """
+
+    def __init__(self, *, input_size, hidden_size, kernel_size=((1, 5), (5, 1)), padding=((0, 2), (2, 0))):
+        super().__init__()
+
+        if len(kernel_size) != len(padding):
+            raise ValueError(
+                f"kernel_size should have the same length as padding, instead got len(kernel_size) = {len(kernel_size)} and len(padding) = {len(padding)}"
+            )
+        if len(kernel_size) not in (1, 2):
+            raise ValueError(f"kernel_size should either 1 or 2, instead got {len(kernel_size)}")
+
+        self.convgru1 = ConvGRU(
+            input_size=input_size, hidden_size=hidden_size, kernel_size=kernel_size[0], padding=padding[0]
+        )
+        if len(kernel_size) == 2:
+            self.convgru2 = ConvGRU(
+                input_size=input_size, hidden_size=hidden_size, kernel_size=kernel_size[1], padding=padding[1]
+            )
+        else:
+            self.convgru2 = _pass_through_h
+
+        self.hidden_size = hidden_size
+
+    def forward(self, h, x):
+        h = self.convgru1(h, x)
+        h = self.convgru2(h, x)
+        return h
+
+
+class FlowHead(nn.Module):
+    """Flow head, part of the update block.
+
+    Takes the hidden state of the recurrent unit as input, and outputs the predicted "delta flow".
+    """
+
+    def __init__(self, *, in_channels, hidden_size):
+        super().__init__()
+        self.conv1 = nn.Conv2d(in_channels, hidden_size, 3, padding=1)
+        self.conv2 = nn.Conv2d(hidden_size, 2, 3, padding=1)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        return self.conv2(self.relu(self.conv1(x)))
+
+
+class UpdateBlock(nn.Module):
+    """The update block which contains the motion encoder, the recurrent block, and the flow head.
+
+    It must expose a ``hidden_state_size`` attribute which is the hidden state size of its recurrent block.
+    """
+
+    def __init__(self, *, motion_encoder, recurrent_block, flow_head):
+        super().__init__()
+        self.motion_encoder = motion_encoder
+        self.recurrent_block = recurrent_block
+        self.flow_head = flow_head
+
+        self.hidden_state_size = recurrent_block.hidden_size
+
+    def forward(self, hidden_state, context, corr_features, flow):
+        motion_features = self.motion_encoder(flow, corr_features)
+        x = torch.cat([context, motion_features], dim=1)
+
+        hidden_state = self.recurrent_block(hidden_state, x)
+        delta_flow = self.flow_head(hidden_state)
+        return hidden_state, delta_flow
+
+
+class MaskPredictor(nn.Module):
+    """Mask predictor to be used when upsampling the predicted flow.
+
+    It takes the hidden state of the recurrent unit as input and outputs the mask.
+    This is not used in the raft-small model.
+    """
+
+    def __init__(self, *, in_channels, hidden_size, multiplier=0.25):
+        super().__init__()
+        self.convrelu = Conv2dNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3)
+        # 8 * 8 * 9 because the predicted flow is downsampled by 8, from the downsampling of the initial FeatureEncoder,
+        # and we interpolate with all 9 surrounding neighbors. See paper and appendix B.
+        self.conv = nn.Conv2d(hidden_size, 8 * 8 * 9, 1, padding=0)
+
+        # In the original code, they use a factor of 0.25 to "downweight the gradients" of that branch.
+        # See e.g. https://github.com/princeton-vl/RAFT/issues/119#issuecomment-953950419
+        # or https://github.com/princeton-vl/RAFT/issues/24.
+        # It doesn't seem to affect epe significantly and can likely be set to 1.
+        self.multiplier = multiplier
+
+    def forward(self, x):
+        x = self.convrelu(x)
+        x = self.conv(x)
+        return self.multiplier * x
+
+
+class CorrBlock(nn.Module):
+    """The correlation block.
+
+    Creates a correlation pyramid with ``num_levels`` levels from the outputs of the feature encoder,
+    and then indexes from this pyramid to create correlation features.
+    The "indexing" of a given centroid pixel x' is done by concatenating its surrounding neighbors that
+    are within a ``radius``, according to the infinity norm (see paper section 3.2).
+    Note: typo in the paper, it should be infinity norm, not 1-norm.
+    """
+
+    def __init__(self, *, num_levels: int = 4, radius: int = 4):
+        super().__init__()
+        self.num_levels = num_levels
+        self.radius = radius
+
+        self.corr_pyramid: List[Tensor] = [torch.tensor(0)]  # useless, but torchscript is otherwise confused :')
+
+        # The neighborhood of a centroid pixel x' is {x' + delta, ||delta||_inf <= radius}
+        # so it's a square surrounding x', and its sides have a length of 2 * radius + 1
+        # The paper claims that it's ||.||_1 instead of ||.||_inf but it's a typo:
+        # https://github.com/princeton-vl/RAFT/issues/122
+        self.out_channels = num_levels * (2 * radius + 1) ** 2
+
+    def build_pyramid(self, fmap1, fmap2):
+        """Build the correlation pyramid from two feature maps.
+
+        The correlation volume is first computed as the dot product of each pair (pixel_in_fmap1, pixel_in_fmap2)
+        The last 2 dimensions of the correlation volume are then pooled num_levels times at different resolutions
+        to build the correlation pyramid.
+        """
+
+        if fmap1.shape != fmap2.shape:
+            raise ValueError(
+                f"Input feature maps should have the same shape, instead got {fmap1.shape} (fmap1.shape) != {fmap2.shape} (fmap2.shape)"
+            )
+
+        # Explaining min_fmap_size below: the fmaps are down-sampled (num_levels - 1) times by a factor of 2.
+        # The last corr_volume most have at least 2 values (hence the 2* factor), otherwise grid_sample() would
+        # produce nans in its output.
+        min_fmap_size = 2 * (2 ** (self.num_levels - 1))
+        if any(fmap_size < min_fmap_size for fmap_size in fmap1.shape[-2:]):
+            raise ValueError(
+                "Feature maps are too small to be down-sampled by the correlation pyramid. "
+                f"H and W of feature maps should be at least {min_fmap_size}; got: {fmap1.shape[-2:]}. "
+                "Remember that input images to the model are downsampled by 8, so that means their "
+                f"dimensions should be at least 8 * {min_fmap_size} = {8 * min_fmap_size}."
+            )
+
+        corr_volume = self._compute_corr_volume(fmap1, fmap2)
+
+        batch_size, h, w, num_channels, _, _ = corr_volume.shape  # _, _ = h, w
+        corr_volume = corr_volume.reshape(batch_size * h * w, num_channels, h, w)
+        self.corr_pyramid = [corr_volume]
+        for _ in range(self.num_levels - 1):
+            corr_volume = F.avg_pool2d(corr_volume, kernel_size=2, stride=2)
+            self.corr_pyramid.append(corr_volume)
+
+    def index_pyramid(self, centroids_coords):
+        """Return correlation features by indexing from the pyramid."""
+        neighborhood_side_len = 2 * self.radius + 1  # see note in __init__ about out_channels
+        di = torch.linspace(-self.radius, self.radius, neighborhood_side_len)
+        dj = torch.linspace(-self.radius, self.radius, neighborhood_side_len)
+        delta = torch.stack(torch.meshgrid(di, dj, indexing="ij"), dim=-1).to(centroids_coords.device)
+        delta = delta.view(1, neighborhood_side_len, neighborhood_side_len, 2)
+
+        batch_size, _, h, w = centroids_coords.shape  # _ = 2
+        centroids_coords = centroids_coords.permute(0, 2, 3, 1).reshape(batch_size * h * w, 1, 1, 2)
+
+        indexed_pyramid = []
+        for corr_volume in self.corr_pyramid:
+            sampling_coords = centroids_coords + delta  # end shape is (batch_size * h * w, side_len, side_len, 2)
+            indexed_corr_volume = grid_sample(corr_volume, sampling_coords, align_corners=True, mode="bilinear").view(
+                batch_size, h, w, -1
+            )
+            indexed_pyramid.append(indexed_corr_volume)
+            centroids_coords = centroids_coords / 2
+
+        corr_features = torch.cat(indexed_pyramid, dim=-1).permute(0, 3, 1, 2).contiguous()
+
+        expected_output_shape = (batch_size, self.out_channels, h, w)
+        if corr_features.shape != expected_output_shape:
+            raise ValueError(
+                f"Output shape of index pyramid is incorrect. Should be {expected_output_shape}, got {corr_features.shape}"
+            )
+
+        return corr_features
+
+    def _compute_corr_volume(self, fmap1, fmap2):
+        batch_size, num_channels, h, w = fmap1.shape
+        fmap1 = fmap1.view(batch_size, num_channels, h * w)
+        fmap2 = fmap2.view(batch_size, num_channels, h * w)
+
+        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
+        corr = corr.view(batch_size, h, w, 1, h, w)
+        return corr / torch.sqrt(torch.tensor(num_channels))
+
+
+class RAFT(nn.Module):
+    def __init__(self, *, feature_encoder, context_encoder, corr_block, update_block, mask_predictor=None):
+        """RAFT model from
+        `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
+
+        args:
+            feature_encoder (nn.Module): The feature encoder. It must downsample the input by 8.
+                Its input is the concatenation of ``image1`` and ``image2``.
+            context_encoder (nn.Module): The context encoder. It must downsample the input by 8.
+                Its input is ``image1``. As in the original implementation, its output will be split into 2 parts:
+
+                - one part will be used as the actual "context", passed to the recurrent unit of the ``update_block``
+                - one part will be used to initialize the hidden state of the recurrent unit of
+                  the ``update_block``
+
+                These 2 parts are split according to the ``hidden_state_size`` of the ``update_block``, so the output
+                of the ``context_encoder`` must be strictly greater than ``hidden_state_size``.
+
+            corr_block (nn.Module): The correlation block, which creates a correlation pyramid from the output of the
+                ``feature_encoder``, and then indexes from this pyramid to create correlation features. It must expose
+                2 methods:
+
+                - a ``build_pyramid`` method that takes ``feature_map_1`` and ``feature_map_2`` as input (these are the
+                  output of the ``feature_encoder``).
+                - a ``index_pyramid`` method that takes the coordinates of the centroid pixels as input, and returns
+                  the correlation features. See paper section 3.2.
+
+                It must expose an ``out_channels`` attribute.
+
+            update_block (nn.Module): The update block, which contains the motion encoder, the recurrent unit, and the
+                flow head. It takes as input the hidden state of its recurrent unit, the context, the correlation
+                features, and the current predicted flow. It outputs an updated hidden state, and the ``delta_flow``
+                prediction (see paper appendix A). It must expose a ``hidden_state_size`` attribute.
+            mask_predictor (nn.Module, optional): Predicts the mask that will be used to upsample the predicted flow.
+                The output channel must be 8 * 8 * 9 - see paper section 3.3, and Appendix B.
+                If ``None`` (default), the flow is upsampled using interpolation.
+        """
+        super().__init__()
+        _log_api_usage_once(self)
+
+        self.feature_encoder = feature_encoder
+        self.context_encoder = context_encoder
+        self.corr_block = corr_block
+        self.update_block = update_block
+
+        self.mask_predictor = mask_predictor
+
+        if not hasattr(self.update_block, "hidden_state_size"):
+            raise ValueError("The update_block parameter should expose a 'hidden_state_size' attribute.")
+
+    def forward(self, image1, image2, num_flow_updates: int = 12):
+
+        batch_size, _, h, w = image1.shape
+        if (h, w) != image2.shape[-2:]:
+            raise ValueError(f"input images should have the same shape, instead got ({h}, {w}) != {image2.shape[-2:]}")
+        if not (h % 8 == 0) and (w % 8 == 0):
+            raise ValueError(f"input image H and W should be divisible by 8, instead got {h} (h) and {w} (w)")
+
+        fmaps = self.feature_encoder(torch.cat([image1, image2], dim=0))
+        fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0)
+        if fmap1.shape[-2:] != (h // 8, w // 8):
+            raise ValueError("The feature encoder should downsample H and W by 8")
+
+        self.corr_block.build_pyramid(fmap1, fmap2)
+
+        context_out = self.context_encoder(image1)
+        if context_out.shape[-2:] != (h // 8, w // 8):
+            raise ValueError("The context encoder should downsample H and W by 8")
+
+        # As in the original paper, the actual output of the context encoder is split in 2 parts:
+        # - one part is used to initialize the hidden state of the recurent units of the update block
+        # - the rest is the "actual" context.
+        hidden_state_size = self.update_block.hidden_state_size
+        out_channels_context = context_out.shape[1] - hidden_state_size
+        if out_channels_context <= 0:
+            raise ValueError(
+                f"The context encoder outputs {context_out.shape[1]} channels, but it should have at strictly more than hidden_state={hidden_state_size} channels"
+            )
+        hidden_state, context = torch.split(context_out, [hidden_state_size, out_channels_context], dim=1)
+        hidden_state = torch.tanh(hidden_state)
+        context = F.relu(context)
+
+        coords0 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device)
+        coords1 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device)
+
+        flow_predictions = []
+        for _ in range(num_flow_updates):
+            coords1 = coords1.detach()  # Don't backpropagate gradients through this branch, see paper
+            corr_features = self.corr_block.index_pyramid(centroids_coords=coords1)
+
+            flow = coords1 - coords0
+            hidden_state, delta_flow = self.update_block(hidden_state, context, corr_features, flow)
+
+            coords1 = coords1 + delta_flow
+
+            up_mask = None if self.mask_predictor is None else self.mask_predictor(hidden_state)
+            upsampled_flow = upsample_flow(flow=(coords1 - coords0), up_mask=up_mask)
+            flow_predictions.append(upsampled_flow)
+
+        return flow_predictions
+
+
+_COMMON_META = {
+    "min_size": (128, 128),
+}
+
+
+class Raft_Large_Weights(WeightsEnum):
+    """The metrics reported here are as follows.
+
+    ``epe`` is the "end-point-error" and indicates how far (in pixels) the
+    predicted flow is from its true value. This is averaged over all pixels
+    of all images. ``per_image_epe`` is similar, but the average is different:
+    the epe is first computed on each image independently, and then averaged
+    over all images. This corresponds to "Fl-epe" (sometimes written "F1-epe")
+    in the original paper, and it's only used on Kitti. ``fl-all`` is also a
+    Kitti-specific metric, defined by the author of the dataset and used for the
+    Kitti leaderboard. It corresponds to the average of pixels whose epe is
+    either <3px, or <5% of flow's 2-norm.
+    """
+
+    C_T_V1 = Weights(
+        # Weights ported from https://github.com/princeton-vl/RAFT
+        url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth",
+        transforms=OpticalFlow,
+        meta={
+            **_COMMON_META,
+            "num_params": 5257536,
+            "recipe": "https://github.com/princeton-vl/RAFT",
+            "_metrics": {
+                "Sintel-Train-Cleanpass": {"epe": 1.4411},
+                "Sintel-Train-Finalpass": {"epe": 2.7894},
+                "Kitti-Train": {"per_image_epe": 5.0172, "fl_all": 17.4506},
+            },
+            "_ops": 211.007,
+            "_file_size": 20.129,
+            "_docs": """These weights were ported from the original paper. They
+            are trained on :class:`~torchvision.datasets.FlyingChairs` +
+            :class:`~torchvision.datasets.FlyingThings3D`.""",
+        },
+    )
+
+    C_T_V2 = Weights(
+        url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth",
+        transforms=OpticalFlow,
+        meta={
+            **_COMMON_META,
+            "num_params": 5257536,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
+            "_metrics": {
+                "Sintel-Train-Cleanpass": {"epe": 1.3822},
+                "Sintel-Train-Finalpass": {"epe": 2.7161},
+                "Kitti-Train": {"per_image_epe": 4.5118, "fl_all": 16.0679},
+            },
+            "_ops": 211.007,
+            "_file_size": 20.129,
+            "_docs": """These weights were trained from scratch on
+            :class:`~torchvision.datasets.FlyingChairs` +
+            :class:`~torchvision.datasets.FlyingThings3D`.""",
+        },
+    )
+
+    C_T_SKHT_V1 = Weights(
+        # Weights ported from https://github.com/princeton-vl/RAFT
+        url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth",
+        transforms=OpticalFlow,
+        meta={
+            **_COMMON_META,
+            "num_params": 5257536,
+            "recipe": "https://github.com/princeton-vl/RAFT",
+            "_metrics": {
+                "Sintel-Test-Cleanpass": {"epe": 1.94},
+                "Sintel-Test-Finalpass": {"epe": 3.18},
+            },
+            "_ops": 211.007,
+            "_file_size": 20.129,
+            "_docs": """
+                These weights were ported from the original paper. They are
+                trained on :class:`~torchvision.datasets.FlyingChairs` +
+                :class:`~torchvision.datasets.FlyingThings3D` and fine-tuned on
+                Sintel. The Sintel fine-tuning step is a combination of
+                :class:`~torchvision.datasets.Sintel`,
+                :class:`~torchvision.datasets.KittiFlow`,
+                :class:`~torchvision.datasets.HD1K`, and
+                :class:`~torchvision.datasets.FlyingThings3D` (clean pass).
+            """,
+        },
+    )
+
+    C_T_SKHT_V2 = Weights(
+        url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth",
+        transforms=OpticalFlow,
+        meta={
+            **_COMMON_META,
+            "num_params": 5257536,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
+            "_metrics": {
+                "Sintel-Test-Cleanpass": {"epe": 1.819},
+                "Sintel-Test-Finalpass": {"epe": 3.067},
+            },
+            "_ops": 211.007,
+            "_file_size": 20.129,
+            "_docs": """
+                These weights were trained from scratch. They are
+                pre-trained on :class:`~torchvision.datasets.FlyingChairs` +
+                :class:`~torchvision.datasets.FlyingThings3D` and then
+                fine-tuned on Sintel. The Sintel fine-tuning step is a
+                combination of :class:`~torchvision.datasets.Sintel`,
+                :class:`~torchvision.datasets.KittiFlow`,
+                :class:`~torchvision.datasets.HD1K`, and
+                :class:`~torchvision.datasets.FlyingThings3D` (clean pass).
+            """,
+        },
+    )
+
+    C_T_SKHT_K_V1 = Weights(
+        # Weights ported from https://github.com/princeton-vl/RAFT
+        url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V1-4a6a5039.pth",
+        transforms=OpticalFlow,
+        meta={
+            **_COMMON_META,
+            "num_params": 5257536,
+            "recipe": "https://github.com/princeton-vl/RAFT",
+            "_metrics": {
+                "Kitti-Test": {"fl_all": 5.10},
+            },
+            "_ops": 211.007,
+            "_file_size": 20.129,
+            "_docs": """
+                These weights were ported from the original paper. They are
+                pre-trained on :class:`~torchvision.datasets.FlyingChairs` +
+                :class:`~torchvision.datasets.FlyingThings3D`,
+                fine-tuned on Sintel, and then fine-tuned on
+                :class:`~torchvision.datasets.KittiFlow`. The Sintel fine-tuning
+                step was described above.
+            """,
+        },
+    )
+
+    C_T_SKHT_K_V2 = Weights(
+        url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth",
+        transforms=OpticalFlow,
+        meta={
+            **_COMMON_META,
+            "num_params": 5257536,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
+            "_metrics": {
+                "Kitti-Test": {"fl_all": 5.19},
+            },
+            "_ops": 211.007,
+            "_file_size": 20.129,
+            "_docs": """
+                These weights were trained from scratch. They are
+                pre-trained on :class:`~torchvision.datasets.FlyingChairs` +
+                :class:`~torchvision.datasets.FlyingThings3D`,
+                fine-tuned on Sintel, and then fine-tuned on
+                :class:`~torchvision.datasets.KittiFlow`. The Sintel fine-tuning
+                step was described above.
+            """,
+        },
+    )
+
+    DEFAULT = C_T_SKHT_V2
+
+
+class Raft_Small_Weights(WeightsEnum):
+    """The metrics reported here are as follows.
+
+    ``epe`` is the "end-point-error" and indicates how far (in pixels) the
+    predicted flow is from its true value. This is averaged over all pixels
+    of all images. ``per_image_epe`` is similar, but the average is different:
+    the epe is first computed on each image independently, and then averaged
+    over all images. This corresponds to "Fl-epe" (sometimes written "F1-epe")
+    in the original paper, and it's only used on Kitti. ``fl-all`` is also a
+    Kitti-specific metric, defined by the author of the dataset and used for the
+    Kitti leaderboard. It corresponds to the average of pixels whose epe is
+    either <3px, or <5% of flow's 2-norm.
+    """
+
+    C_T_V1 = Weights(
+        # Weights ported from https://github.com/princeton-vl/RAFT
+        url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth",
+        transforms=OpticalFlow,
+        meta={
+            **_COMMON_META,
+            "num_params": 990162,
+            "recipe": "https://github.com/princeton-vl/RAFT",
+            "_metrics": {
+                "Sintel-Train-Cleanpass": {"epe": 2.1231},
+                "Sintel-Train-Finalpass": {"epe": 3.2790},
+                "Kitti-Train": {"per_image_epe": 7.6557, "fl_all": 25.2801},
+            },
+            "_ops": 47.655,
+            "_file_size": 3.821,
+            "_docs": """These weights were ported from the original paper. They
+            are trained on :class:`~torchvision.datasets.FlyingChairs` +
+            :class:`~torchvision.datasets.FlyingThings3D`.""",
+        },
+    )
+    C_T_V2 = Weights(
+        url="https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth",
+        transforms=OpticalFlow,
+        meta={
+            **_COMMON_META,
+            "num_params": 990162,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
+            "_metrics": {
+                "Sintel-Train-Cleanpass": {"epe": 1.9901},
+                "Sintel-Train-Finalpass": {"epe": 3.2831},
+                "Kitti-Train": {"per_image_epe": 7.5978, "fl_all": 25.2369},
+            },
+            "_ops": 47.655,
+            "_file_size": 3.821,
+            "_docs": """These weights were trained from scratch on
+            :class:`~torchvision.datasets.FlyingChairs` +
+            :class:`~torchvision.datasets.FlyingThings3D`.""",
+        },
+    )
+
+    DEFAULT = C_T_V2
+
+
+def _raft(
+    *,
+    weights=None,
+    progress=False,
+    # Feature encoder
+    feature_encoder_layers,
+    feature_encoder_block,
+    feature_encoder_norm_layer,
+    # Context encoder
+    context_encoder_layers,
+    context_encoder_block,
+    context_encoder_norm_layer,
+    # Correlation block
+    corr_block_num_levels,
+    corr_block_radius,
+    # Motion encoder
+    motion_encoder_corr_layers,
+    motion_encoder_flow_layers,
+    motion_encoder_out_channels,
+    # Recurrent block
+    recurrent_block_hidden_state_size,
+    recurrent_block_kernel_size,
+    recurrent_block_padding,
+    # Flow Head
+    flow_head_hidden_size,
+    # Mask predictor
+    use_mask_predictor,
+    **kwargs,
+):
+    feature_encoder = kwargs.pop("feature_encoder", None) or FeatureEncoder(
+        block=feature_encoder_block, layers=feature_encoder_layers, norm_layer=feature_encoder_norm_layer
+    )
+    context_encoder = kwargs.pop("context_encoder", None) or FeatureEncoder(
+        block=context_encoder_block, layers=context_encoder_layers, norm_layer=context_encoder_norm_layer
+    )
+
+    corr_block = kwargs.pop("corr_block", None) or CorrBlock(num_levels=corr_block_num_levels, radius=corr_block_radius)
+
+    update_block = kwargs.pop("update_block", None)
+    if update_block is None:
+        motion_encoder = MotionEncoder(
+            in_channels_corr=corr_block.out_channels,
+            corr_layers=motion_encoder_corr_layers,
+            flow_layers=motion_encoder_flow_layers,
+            out_channels=motion_encoder_out_channels,
+        )
+
+        # See comments in forward pass of RAFT class about why we split the output of the context encoder
+        out_channels_context = context_encoder_layers[-1] - recurrent_block_hidden_state_size
+        recurrent_block = RecurrentBlock(
+            input_size=motion_encoder.out_channels + out_channels_context,
+            hidden_size=recurrent_block_hidden_state_size,
+            kernel_size=recurrent_block_kernel_size,
+            padding=recurrent_block_padding,
+        )
+
+        flow_head = FlowHead(in_channels=recurrent_block_hidden_state_size, hidden_size=flow_head_hidden_size)
+
+        update_block = UpdateBlock(motion_encoder=motion_encoder, recurrent_block=recurrent_block, flow_head=flow_head)
+
+    mask_predictor = kwargs.pop("mask_predictor", None)
+    if mask_predictor is None and use_mask_predictor:
+        mask_predictor = MaskPredictor(
+            in_channels=recurrent_block_hidden_state_size,
+            hidden_size=256,
+            multiplier=0.25,  # See comment in MaskPredictor about this
+        )
+
+    model = RAFT(
+        feature_encoder=feature_encoder,
+        context_encoder=context_encoder,
+        corr_block=corr_block,
+        update_block=update_block,
+        mask_predictor=mask_predictor,
+        **kwargs,  # not really needed, all params should be consumed by now
+    )
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_SKHT_V2))
+def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs) -> RAFT:
+    """RAFT model from
+    `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
+
+    Please see the example below for a tutorial on how to use this model.
+
+    Args:
+        weights(:class:`~torchvision.models.optical_flow.Raft_Large_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.optical_flow.Raft_Large_Weights`
+            below for more details, and possible values. By default, no
+            pre-trained weights are used.
+        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.optical_flow.RAFT``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.optical_flow.Raft_Large_Weights
+        :members:
+    """
+
+    weights = Raft_Large_Weights.verify(weights)
+
+    return _raft(
+        weights=weights,
+        progress=progress,
+        # Feature encoder
+        feature_encoder_layers=(64, 64, 96, 128, 256),
+        feature_encoder_block=ResidualBlock,
+        feature_encoder_norm_layer=InstanceNorm2d,
+        # Context encoder
+        context_encoder_layers=(64, 64, 96, 128, 256),
+        context_encoder_block=ResidualBlock,
+        context_encoder_norm_layer=BatchNorm2d,
+        # Correlation block
+        corr_block_num_levels=4,
+        corr_block_radius=4,
+        # Motion encoder
+        motion_encoder_corr_layers=(256, 192),
+        motion_encoder_flow_layers=(128, 64),
+        motion_encoder_out_channels=128,
+        # Recurrent block
+        recurrent_block_hidden_state_size=128,
+        recurrent_block_kernel_size=((1, 5), (5, 1)),
+        recurrent_block_padding=((0, 2), (2, 0)),
+        # Flow head
+        flow_head_hidden_size=256,
+        # Mask predictor
+        use_mask_predictor=True,
+        **kwargs,
+    )
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V2))
+def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs) -> RAFT:
+    """RAFT "small" model from
+    `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`__.
+
+    Please see the example below for a tutorial on how to use this model.
+
+    Args:
+        weights(:class:`~torchvision.models.optical_flow.Raft_Small_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.optical_flow.Raft_Small_Weights`
+            below for more details, and possible values. By default, no
+            pre-trained weights are used.
+        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.optical_flow.RAFT``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.optical_flow.Raft_Small_Weights
+        :members:
+    """
+    weights = Raft_Small_Weights.verify(weights)
+
+    return _raft(
+        weights=weights,
+        progress=progress,
+        # Feature encoder
+        feature_encoder_layers=(32, 32, 64, 96, 128),
+        feature_encoder_block=BottleneckBlock,
+        feature_encoder_norm_layer=InstanceNorm2d,
+        # Context encoder
+        context_encoder_layers=(32, 32, 64, 96, 160),
+        context_encoder_block=BottleneckBlock,
+        context_encoder_norm_layer=None,
+        # Correlation block
+        corr_block_num_levels=4,
+        corr_block_radius=3,
+        # Motion encoder
+        motion_encoder_corr_layers=(96,),
+        motion_encoder_flow_layers=(64, 32),
+        motion_encoder_out_channels=82,
+        # Recurrent block
+        recurrent_block_hidden_state_size=96,
+        recurrent_block_kernel_size=(3,),
+        recurrent_block_padding=(1,),
+        # Flow head
+        flow_head_hidden_size=128,
+        # Mask predictor
+        use_mask_predictor=False,
+        **kwargs,
+    )

部分文件因为文件数量过多而无法显示