
Merge branch 'faster_lcnn' into dev

RenLiqiang, 3 months ago
parent commit 80280e38ec
100 files changed, 19863 additions and 0 deletions
  1. BIN  assets/dog1.jpg
  2. BIN  assets/dog1.png
  3. BIN  assets/dog2.jpg
  4. BIN  assets/dog2.png
  5. + 73 - 0  config/wireframe.yaml
  6. + 4 - 0  lcnn/__init__.py
  7. + 1110 - 0  lcnn/box.py
  8. + 9 - 0  lcnn/config.py
  9. + 378 - 0  lcnn/dataset_tool.py
  10. + 294 - 0  lcnn/datasets.py
  11. + 209 - 0  lcnn/metric.py
  12. + 9 - 0  lcnn/models/__init__.py
  13. + 0 - 0  lcnn/models/base/__init__.py
  14. + 48 - 0  lcnn/models/base/base_dataset.py
  15. + 120 - 0  lcnn/models/fasterrcnn_resnet50.py
  16. + 201 - 0  lcnn/models/hourglass_pose.py
  17. + 276 - 0  lcnn/models/line_vectorizer.py
  18. + 118 - 0  lcnn/models/multitask_learner.py
  19. + 182 - 0  lcnn/models/resnet50.py
  20. + 87 - 0  lcnn/models/resnet50_pose.py
  21. + 126 - 0  lcnn/models/unet.py
  22. + 77 - 0  lcnn/postprocess.py
  23. + 429 - 0  lcnn/trainer.py
  24. + 101 - 0  lcnn/utils.py
  25. BIN  libs/vision_libs/_C.pyd
  26. + 103 - 0  libs/vision_libs/__init__.py
  27. + 50 - 0  libs/vision_libs/_internally_replaced_utils.py
  28. + 225 - 0  libs/vision_libs/_meta_registrations.py
  29. + 32 - 0  libs/vision_libs/_utils.py
  30. + 146 - 0  libs/vision_libs/datasets/__init__.py
  31. + 491 - 0  libs/vision_libs/datasets/_optical_flow.py
  32. + 1224 - 0  libs/vision_libs/datasets/_stereo_matching.py
  33. + 241 - 0  libs/vision_libs/datasets/caltech.py
  34. + 193 - 0  libs/vision_libs/datasets/celeba.py
  35. + 167 - 0  libs/vision_libs/datasets/cifar.py
  36. + 221 - 0  libs/vision_libs/datasets/cityscapes.py
  37. + 88 - 0  libs/vision_libs/datasets/clevr.py
  38. + 104 - 0  libs/vision_libs/datasets/coco.py
  39. + 58 - 0  libs/vision_libs/datasets/country211.py
  40. + 100 - 0  libs/vision_libs/datasets/dtd.py
  41. + 58 - 0  libs/vision_libs/datasets/eurosat.py
  42. + 67 - 0  libs/vision_libs/datasets/fakedata.py
  43. + 75 - 0  libs/vision_libs/datasets/fer2013.py
  44. + 114 - 0  libs/vision_libs/datasets/fgvc_aircraft.py
  45. + 166 - 0  libs/vision_libs/datasets/flickr.py
  46. + 114 - 0  libs/vision_libs/datasets/flowers102.py
  47. + 317 - 0  libs/vision_libs/datasets/folder.py
  48. + 93 - 0  libs/vision_libs/datasets/food101.py
  49. + 103 - 0  libs/vision_libs/datasets/gtsrb.py
  50. + 151 - 0  libs/vision_libs/datasets/hmdb51.py
  51. + 218 - 0  libs/vision_libs/datasets/imagenet.py
  52. + 104 - 0  libs/vision_libs/datasets/imagenette.py
  53. + 241 - 0  libs/vision_libs/datasets/inaturalist.py
  54. + 247 - 0  libs/vision_libs/datasets/kinetics.py
  55. + 157 - 0  libs/vision_libs/datasets/kitti.py
  56. + 255 - 0  libs/vision_libs/datasets/lfw.py
  57. + 167 - 0  libs/vision_libs/datasets/lsun.py
  58. + 558 - 0  libs/vision_libs/datasets/mnist.py
  59. + 93 - 0  libs/vision_libs/datasets/moving_mnist.py
  60. + 102 - 0  libs/vision_libs/datasets/omniglot.py
  61. + 125 - 0  libs/vision_libs/datasets/oxford_iiit_pet.py
  62. + 134 - 0  libs/vision_libs/datasets/pcam.py
  63. + 228 - 0  libs/vision_libs/datasets/phototour.py
  64. + 170 - 0  libs/vision_libs/datasets/places365.py
  65. + 86 - 0  libs/vision_libs/datasets/rendered_sst2.py
  66. + 3 - 0  libs/vision_libs/datasets/samplers/__init__.py
  67. + 172 - 0  libs/vision_libs/datasets/samplers/clip_sampler.py
  68. + 123 - 0  libs/vision_libs/datasets/sbd.py
  69. + 109 - 0  libs/vision_libs/datasets/sbu.py
  70. + 91 - 0  libs/vision_libs/datasets/semeion.py
  71. + 121 - 0  libs/vision_libs/datasets/stanford_cars.py
  72. + 174 - 0  libs/vision_libs/datasets/stl10.py
  73. + 76 - 0  libs/vision_libs/datasets/sun397.py
  74. + 129 - 0  libs/vision_libs/datasets/svhn.py
  75. + 130 - 0  libs/vision_libs/datasets/ucf101.py
  76. + 95 - 0  libs/vision_libs/datasets/usps.py
  77. + 459 - 0  libs/vision_libs/datasets/utils.py
  78. + 419 - 0  libs/vision_libs/datasets/video_utils.py
  79. + 110 - 0  libs/vision_libs/datasets/vision.py
  80. + 224 - 0  libs/vision_libs/datasets/voc.py
  81. + 195 - 0  libs/vision_libs/datasets/widerface.py
  82. + 92 - 0  libs/vision_libs/extension.py
  83. BIN  libs/vision_libs/image.pyd
  84. + 69 - 0  libs/vision_libs/io/__init__.py
  85. + 8 - 0  libs/vision_libs/io/_load_gpu_decoder.py
  86. + 512 - 0  libs/vision_libs/io/_video_opt.py
  87. + 264 - 0  libs/vision_libs/io/image.py
  88. + 415 - 0  libs/vision_libs/io/video.py
  89. + 286 - 0  libs/vision_libs/io/video_reader.py
  90. + 23 - 0  libs/vision_libs/models/__init__.py
  91. + 277 - 0  libs/vision_libs/models/_api.py
  92. + 1554 - 0  libs/vision_libs/models/_meta.py
  93. + 256 - 0  libs/vision_libs/models/_utils.py
  94. + 119 - 0  libs/vision_libs/models/alexnet.py
  95. + 414 - 0  libs/vision_libs/models/convnext.py
  96. + 448 - 0  libs/vision_libs/models/densenet.py
  97. + 7 - 0  libs/vision_libs/models/detection/__init__.py
  98. + 540 - 0  libs/vision_libs/models/detection/_utils.py
  99. + 268 - 0  libs/vision_libs/models/detection/anchor_utils.py
  100. + 244 - 0  libs/vision_libs/models/detection/backbone_utils.py

BIN
assets/dog1.jpg


BIN
assets/dog1.png


BIN
assets/dog2.jpg


BIN
assets/dog2.png


+ 73 - 0
config/wireframe.yaml

@@ -0,0 +1,73 @@
+io:
+  logdir: logs/
+  datadir: D:\python\PycharmProjects\data
+#  datadir: /home/dieu/lcnn/dataset/line_data_104
+  resume_from:
+#  resume_from: /home/dieu/lcnn/logs/241112-163302-175fb79-my_data_104_resume
+  num_workers: 0
+  tensorboard_port: 0
+  validation_interval: 300    # evaluation interval
+
+model:
+  image:
+      mean: [109.730, 103.832, 98.681]
+      stddev: [22.275, 22.124, 23.229]
+
+  batch_size: 4
+  batch_size_eval: 2
+
+  # backbone multi-task parameters
+  head_size: [[2], [1], [2], [4]]
+  loss_weight:
+    jmap: 8.0
+    lmap: 0.5
+    joff: 0.25
+    lpos: 1
+    lneg: 1
+    boxes: 1.0  # newly added box loss weight
+
+  # backbone parameters
+  backbone: fasterrcnn_resnet50
+#  backbone: unet
+  depth: 4
+  num_stacks: 1
+  num_blocks: 1
+
+  # sampler parameters
+  ## static sampler
+  n_stc_posl: 300
+  n_stc_negl: 40
+
+  ## dynamic sampler
+  n_dyn_junc: 300
+  n_dyn_posl: 300
+  n_dyn_negl: 80
+  n_dyn_othr: 600
+
+  # LOIPool layer parameters
+  n_pts0: 32
+  n_pts1: 8
+
+  # line verification network parameters
+  dim_loi: 128
+  dim_fc: 1024
+
+  # maximum junction and line outputs
+  n_out_junc: 250
+  n_out_line: 2500
+
+  # additional ablation study parameters
+  use_cood: 0
+  use_slop: 0
+  use_conv: 0
+
+  # junction threshold for evaluation (See #5)
+  eval_junc_thres: 0.008
+
+optim:
+  name: Adam
+  lr: 4.0e-4
+  amsgrad: True
+  weight_decay: 1.0e-4
+  max_epoch: 1000
+  lr_decay_epoch: 10
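
For orientation, a minimal sketch of reading this new config with PyYAML (an assumption of this note, not part of the diff; the path is illustrative and taken relative to the repository root):

    import yaml

    with open("config/wireframe.yaml", "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    print(cfg["model"]["backbone"])     # fasterrcnn_resnet50
    print(cfg["model"]["loss_weight"])  # includes the newly added 'boxes' weight
    print(cfg["optim"]["lr"])           # 0.0004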

+ 4 - 0
lcnn/__init__.py

@@ -0,0 +1,4 @@
+import lcnn.models
+import lcnn.trainer
+import lcnn.datasets
+import lcnn.config

+ 1110 - 0
lcnn/box.py

@@ -0,0 +1,1110 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+#
+# Copyright (c) 2017-2019 - Chris Griffith - MIT License
+"""
+Improved dictionary access through dot notation with additional tools.
+"""
+import string
+import sys
+import json
+import re
+import copy
+from keyword import kwlist
+import warnings
+
+try:
+    from collections.abc import Iterable, Mapping, Callable
+except ImportError:
+    from collections import Iterable, Mapping, Callable
+
+yaml_support = True
+
+try:
+    import yaml
+except ImportError:
+    try:
+        import ruamel.yaml as yaml
+    except ImportError:
+        yaml = None
+        yaml_support = False
+
+if sys.version_info >= (3, 0):
+    basestring = str
+else:
+    from io import open
+
+__all__ = ['Box', 'ConfigBox', 'BoxList', 'SBox',
+           'BoxError', 'BoxKeyError']
+__author__ = 'Chris Griffith'
+__version__ = '3.2.4'
+
+BOX_PARAMETERS = ('default_box', 'default_box_attr', 'conversion_box',
+                  'frozen_box', 'camel_killer_box', 'box_it_up',
+                  'box_safe_prefix', 'box_duplicates', 'ordered_box')
+
+_first_cap_re = re.compile('(.)([A-Z][a-z]+)')
+_all_cap_re = re.compile('([a-z0-9])([A-Z])')
+
+
+class BoxError(Exception):
+    """Non standard dictionary exceptions"""
+
+
+class BoxKeyError(BoxError, KeyError, AttributeError):
+    """Key does not exist"""
+
+
+# Abstract converter functions for use in any Box class
+
+
+def _to_json(obj, filename=None,
+             encoding="utf-8", errors="strict", **json_kwargs):
+    json_dump = json.dumps(obj,
+                           ensure_ascii=False, **json_kwargs)
+    if filename:
+        with open(filename, 'w', encoding=encoding, errors=errors) as f:
+            f.write(json_dump if sys.version_info >= (3, 0) else
+                    json_dump.decode("utf-8"))
+    else:
+        return json_dump
+
+
+def _from_json(json_string=None, filename=None,
+               encoding="utf-8", errors="strict", multiline=False, **kwargs):
+    if filename:
+        with open(filename, 'r', encoding=encoding, errors=errors) as f:
+            if multiline:
+                data = [json.loads(line.strip(), **kwargs) for line in f
+                        if line.strip() and not line.strip().startswith("#")]
+            else:
+                data = json.load(f, **kwargs)
+    elif json_string:
+        data = json.loads(json_string, **kwargs)
+    else:
+        raise BoxError('from_json requires a string or filename')
+    return data
+
+
+def _to_yaml(obj, filename=None, default_flow_style=False,
+             encoding="utf-8", errors="strict",
+             **yaml_kwargs):
+    if filename:
+        with open(filename, 'w',
+                  encoding=encoding, errors=errors) as f:
+            yaml.dump(obj, stream=f,
+                      default_flow_style=default_flow_style,
+                      **yaml_kwargs)
+    else:
+        return yaml.dump(obj,
+                         default_flow_style=default_flow_style,
+                         **yaml_kwargs)
+
+
+def _from_yaml(yaml_string=None, filename=None,
+               encoding="utf-8", errors="strict",
+               **kwargs):
+    if filename:
+        with open(filename, 'r',
+                  encoding=encoding, errors=errors) as f:
+            data = yaml.load(f, **kwargs)
+    elif yaml_string:
+        data = yaml.load(yaml_string, **kwargs)
+    else:
+        raise BoxError('from_yaml requires a string or filename')
+    return data
+
+
+# Helper functions
+
+
+def _safe_key(key):
+    try:
+        return str(key)
+    except UnicodeEncodeError:
+        return key.encode("utf-8", "ignore")
+
+
+def _safe_attr(attr, camel_killer=False, replacement_char='x'):
+    """Convert a key into something that is accessible as an attribute"""
+    allowed = string.ascii_letters + string.digits + '_'
+
+    attr = _safe_key(attr)
+
+    if camel_killer:
+        attr = _camel_killer(attr)
+
+    attr = attr.replace(' ', '_')
+
+    out = ''
+    for character in attr:
+        out += character if character in allowed else "_"
+    out = out.strip("_")
+
+    try:
+        int(out[0])
+    except (ValueError, IndexError):
+        pass
+    else:
+        out = '{0}{1}'.format(replacement_char, out)
+
+    if out in kwlist:
+        out = '{0}{1}'.format(replacement_char, out)
+
+    return re.sub('_+', '_', out)
+
+
+def _camel_killer(attr):
+    """
+    CamelKiller, qu'est-ce que c'est?
+
+    Taken from http://stackoverflow.com/a/1176023/3244542
+    """
+    try:
+        attr = str(attr)
+    except UnicodeEncodeError:
+        attr = attr.encode("utf-8", "ignore")
+
+    s1 = _first_cap_re.sub(r'\1_\2', attr)
+    s2 = _all_cap_re.sub(r'\1_\2', s1)
+    return re.sub('_+', '_', s2.casefold() if hasattr(s2, 'casefold') else
+                  s2.lower())
+
+
+def _recursive_tuples(iterable, box_class, recreate_tuples=False, **kwargs):
+    out_list = []
+    for i in iterable:
+        if isinstance(i, dict):
+            out_list.append(box_class(i, **kwargs))
+        elif isinstance(i, list) or (recreate_tuples and isinstance(i, tuple)):
+            out_list.append(_recursive_tuples(i, box_class,
+                                              recreate_tuples, **kwargs))
+        else:
+            out_list.append(i)
+    return tuple(out_list)
+
+
+def _conversion_checks(item, keys, box_config, check_only=False,
+                       pre_check=False):
+    """
+    Internal use for checking if a duplicate safe attribute already exists
+
+    :param item: Item to see if a dup exists
+    :param keys: Keys to check against
+    :param box_config: Easier to pass in than ask for specific items
+    :param check_only: Don't bother doing the conversion work
+    :param pre_check: Need to add the item to the list of keys to check
+    :return: the original unmodified key, if exists and not check_only
+    """
+    if box_config['box_duplicates'] != 'ignore':
+        if pre_check:
+            keys = list(keys) + [item]
+
+        key_list = [(k,
+                     _safe_attr(k, camel_killer=box_config['camel_killer_box'],
+                                replacement_char=box_config['box_safe_prefix']
+                                )) for k in keys]
+        if len(key_list) > len(set(x[1] for x in key_list)):
+            seen = set()
+            dups = set()
+            for x in key_list:
+                if x[1] in seen:
+                    dups.add("{0}({1})".format(x[0], x[1]))
+                seen.add(x[1])
+            if box_config['box_duplicates'].startswith("warn"):
+                warnings.warn('Duplicate conversion attributes exist: '
+                              '{0}'.format(dups))
+            else:
+                raise BoxError('Duplicate conversion attributes exist: '
+                               '{0}'.format(dups))
+    if check_only:
+        return
+    # This way will be slower for warnings, as it will have double work
+    # But faster for the default 'ignore'
+    for k in keys:
+        if item == _safe_attr(k, camel_killer=box_config['camel_killer_box'],
+                              replacement_char=box_config['box_safe_prefix']):
+            return k
+
+
+def _get_box_config(cls, kwargs):
+    return {
+        # Internal use only
+        '__converted': set(),
+        '__box_heritage': kwargs.pop('__box_heritage', None),
+        '__created': False,
+        '__ordered_box_values': [],
+        # Can be changed by user after box creation
+        'default_box': kwargs.pop('default_box', False),
+        'default_box_attr': kwargs.pop('default_box_attr', cls),
+        'conversion_box': kwargs.pop('conversion_box', True),
+        'box_safe_prefix': kwargs.pop('box_safe_prefix', 'x'),
+        'frozen_box': kwargs.pop('frozen_box', False),
+        'camel_killer_box': kwargs.pop('camel_killer_box', False),
+        'modify_tuples_box': kwargs.pop('modify_tuples_box', False),
+        'box_duplicates': kwargs.pop('box_duplicates', 'ignore'),
+        'ordered_box': kwargs.pop('ordered_box', False)
+    }
+
+
+class Box(dict):
+    """
+    Improved dictionary access through dot notation with additional tools.
+
+    :param default_box: Similar to defaultdict, return a default value
+    :param default_box_attr: Specify the default replacement.
+        WARNING: If this is not the default 'Box', it will not be recursive
+    :param frozen_box: After creation, the box cannot be modified
+    :param camel_killer_box: Convert CamelCase to snake_case
+    :param conversion_box: Check for near matching keys as attributes
+    :param modify_tuples_box: Recreate incoming tuples with dicts into Boxes
+    :param box_it_up: Recursively create all Boxes from the start
+    :param box_safe_prefix: Conversion box prefix for unsafe attributes
+    :param box_duplicates: "ignore", "error" or "warn" when duplicates exist
+        in a conversion_box
+    :param ordered_box: Preserve the order of keys entered into the box
+    """
+
+    _protected_keys = dir({}) + ['to_dict', 'tree_view', 'to_json', 'to_yaml',
+                                 'from_yaml', 'from_json']
+
+    def __new__(cls, *args, **kwargs):
+        """
+        Due to the way pickling works in python 3, we need to make sure
+        the box config is created as early as possible.
+        """
+        obj = super(Box, cls).__new__(cls, *args, **kwargs)
+        obj._box_config = _get_box_config(cls, kwargs)
+        return obj
+
+    def __init__(self, *args, **kwargs):
+        self._box_config = _get_box_config(self.__class__, kwargs)
+        if self._box_config['ordered_box']:
+            self._box_config['__ordered_box_values'] = []
+        if (not self._box_config['conversion_box'] and
+                self._box_config['box_duplicates'] != "ignore"):
+            raise BoxError('box_duplicates are only for conversion_boxes')
+        if len(args) == 1:
+            if isinstance(args[0], basestring):
+                raise ValueError('Cannot extrapolate Box from string')
+            if isinstance(args[0], Mapping):
+                for k, v in args[0].items():
+                    if v is args[0]:
+                        v = self
+                    self[k] = v
+                    self.__add_ordered(k)
+            elif isinstance(args[0], Iterable):
+                for k, v in args[0]:
+                    self[k] = v
+                    self.__add_ordered(k)
+
+            else:
+                raise ValueError('First argument must be mapping or iterable')
+        elif args:
+            raise TypeError('Box expected at most 1 argument, '
+                            'got {0}'.format(len(args)))
+
+        box_it = kwargs.pop('box_it_up', False)
+        for k, v in kwargs.items():
+            if args and isinstance(args[0], Mapping) and v is args[0]:
+                v = self
+            self[k] = v
+            self.__add_ordered(k)
+
+        if (self._box_config['frozen_box'] or box_it or
+                self._box_config['box_duplicates'] != 'ignore'):
+            self.box_it_up()
+
+        self._box_config['__created'] = True
+
+    def __add_ordered(self, key):
+        if (self._box_config['ordered_box'] and
+                key not in self._box_config['__ordered_box_values']):
+            self._box_config['__ordered_box_values'].append(key)
+
+    def box_it_up(self):
+        """
+        Perform value lookup for all items in current dictionary,
+        generating all sub Box objects, while also running `box_it_up` on
+        any of those sub box objects.
+        """
+        for k in self:
+            _conversion_checks(k, self.keys(), self._box_config,
+                               check_only=True)
+            if self[k] is not self and hasattr(self[k], 'box_it_up'):
+                self[k].box_it_up()
+
+    def __hash__(self):
+        if self._box_config['frozen_box']:
+            hashing = 54321
+            for item in self.items():
+                hashing ^= hash(item)
+            return hashing
+        raise TypeError("unhashable type: 'Box'")
+
+    def __dir__(self):
+        allowed = string.ascii_letters + string.digits + '_'
+        kill_camel = self._box_config['camel_killer_box']
+        items = set(dir(dict) + ['to_dict', 'to_json',
+                                 'from_json', 'box_it_up'])
+        # Only show items accessible by dot notation
+        for key in self.keys():
+            key = _safe_key(key)
+            if (' ' not in key and key[0] not in string.digits and
+                    key not in kwlist):
+                for letter in key:
+                    if letter not in allowed:
+                        break
+                else:
+                    items.add(key)
+
+        for key in self.keys():
+            key = _safe_key(key)
+            if key not in items:
+                if self._box_config['conversion_box']:
+                    key = _safe_attr(key, camel_killer=kill_camel,
+                                     replacement_char=self._box_config[
+                                         'box_safe_prefix'])
+                    if key:
+                        items.add(key)
+            if kill_camel:
+                snake_key = _camel_killer(key)
+                if snake_key:
+                    items.remove(key)
+                    items.add(snake_key)
+
+        if yaml_support:
+            items.add('to_yaml')
+            items.add('from_yaml')
+
+        return list(items)
+
+    def get(self, key, default=None):
+        try:
+            return self[key]
+        except KeyError:
+            if isinstance(default, dict) and not isinstance(default, Box):
+                return Box(default)
+            if isinstance(default, list) and not isinstance(default, BoxList):
+                return BoxList(default)
+            return default
+
+    def copy(self):
+        return self.__class__(super(self.__class__, self).copy())
+
+    def __copy__(self):
+        return self.__class__(super(self.__class__, self).copy())
+
+    def __deepcopy__(self, memodict=None):
+        out = self.__class__()
+        memodict = memodict or {}
+        memodict[id(self)] = out
+        for k, v in self.items():
+            out[copy.deepcopy(k, memodict)] = copy.deepcopy(v, memodict)
+        return out
+
+    def __setstate__(self, state):
+        self._box_config = state['_box_config']
+        self.__dict__.update(state)
+
+    def __getitem__(self, item, _ignore_default=False):
+        try:
+            value = super(Box, self).__getitem__(item)
+        except KeyError as err:
+            if item == '_box_config':
+                raise BoxKeyError('_box_config should only exist as an '
+                                  'attribute and is never defaulted')
+            if self._box_config['default_box'] and not _ignore_default:
+                return self.__get_default(item)
+            raise BoxKeyError(str(err))
+        else:
+            return self.__convert_and_store(item, value)
+
+    def keys(self):
+        if self._box_config['ordered_box']:
+            return self._box_config['__ordered_box_values']
+        return super(Box, self).keys()
+
+    def values(self):
+        return [self[x] for x in self.keys()]
+
+    def items(self):
+        return [(x, self[x]) for x in self.keys()]
+
+    def __get_default(self, item):
+        default_value = self._box_config['default_box_attr']
+        if default_value is self.__class__:
+            return self.__class__(__box_heritage=(self, item),
+                                  **self.__box_config())
+        elif isinstance(default_value, Callable):
+            return default_value()
+        elif hasattr(default_value, 'copy'):
+            return default_value.copy()
+        return default_value
+
+    def __box_config(self):
+        out = {}
+        for k, v in self._box_config.copy().items():
+            if not k.startswith("__"):
+                out[k] = v
+        return out
+
+    def __convert_and_store(self, item, value):
+        if item in self._box_config['__converted']:
+            return value
+        if isinstance(value, dict) and not isinstance(value, Box):
+            value = self.__class__(value, __box_heritage=(self, item),
+                                   **self.__box_config())
+            self[item] = value
+        elif isinstance(value, list) and not isinstance(value, BoxList):
+            if self._box_config['frozen_box']:
+                value = _recursive_tuples(value, self.__class__,
+                                          recreate_tuples=self._box_config[
+                                              'modify_tuples_box'],
+                                          __box_heritage=(self, item),
+                                          **self.__box_config())
+            else:
+                value = BoxList(value, __box_heritage=(self, item),
+                                box_class=self.__class__,
+                                **self.__box_config())
+            self[item] = value
+        elif (self._box_config['modify_tuples_box'] and
+              isinstance(value, tuple)):
+            value = _recursive_tuples(value, self.__class__,
+                                      recreate_tuples=True,
+                                      __box_heritage=(self, item),
+                                      **self.__box_config())
+            self[item] = value
+        self._box_config['__converted'].add(item)
+        return value
+
+    def __create_lineage(self):
+        if (self._box_config['__box_heritage'] and
+                self._box_config['__created']):
+            past, item = self._box_config['__box_heritage']
+            if not past[item]:
+                past[item] = self
+            self._box_config['__box_heritage'] = None
+
+    def __getattr__(self, item):
+        try:
+            try:
+                value = self.__getitem__(item, _ignore_default=True)
+            except KeyError:
+                value = object.__getattribute__(self, item)
+        except AttributeError as err:
+            if item == "__getstate__":
+                raise AttributeError(item)
+            if item == '_box_config':
+                raise BoxError('_box_config key must exist')
+            kill_camel = self._box_config['camel_killer_box']
+            if self._box_config['conversion_box'] and item:
+                k = _conversion_checks(item, self.keys(), self._box_config)
+                if k:
+                    return self.__getitem__(k)
+            if kill_camel:
+                for k in self.keys():
+                    if item == _camel_killer(k):
+                        return self.__getitem__(k)
+            if self._box_config['default_box']:
+                return self.__get_default(item)
+            raise BoxKeyError(str(err))
+        else:
+            if item == '_box_config':
+                return value
+            return self.__convert_and_store(item, value)
+
+    def __setitem__(self, key, value):
+        if (key != '_box_config' and self._box_config['__created'] and
+                self._box_config['frozen_box']):
+            raise BoxError('Box is frozen')
+        if self._box_config['conversion_box']:
+            _conversion_checks(key, self.keys(), self._box_config,
+                               check_only=True, pre_check=True)
+        super(Box, self).__setitem__(key, value)
+        self.__add_ordered(key)
+        self.__create_lineage()
+
+    def __setattr__(self, key, value):
+        if (key != '_box_config' and self._box_config['frozen_box'] and
+                self._box_config['__created']):
+            raise BoxError('Box is frozen')
+        if key in self._protected_keys:
+            raise AttributeError("Key name '{0}' is protected".format(key))
+        if key == '_box_config':
+            return object.__setattr__(self, key, value)
+        try:
+            object.__getattribute__(self, key)
+        except (AttributeError, UnicodeEncodeError):
+            if (key not in self.keys() and
+                    (self._box_config['conversion_box'] or
+                     self._box_config['camel_killer_box'])):
+                if self._box_config['conversion_box']:
+                    k = _conversion_checks(key, self.keys(),
+                                           self._box_config)
+                    self[key if not k else k] = value
+                elif self._box_config['camel_killer_box']:
+                    for each_key in self:
+                        if key == _camel_killer(each_key):
+                            self[each_key] = value
+                            break
+            else:
+                self[key] = value
+        else:
+            object.__setattr__(self, key, value)
+        self.__add_ordered(key)
+        self.__create_lineage()
+
+    def __delitem__(self, key):
+        if self._box_config['frozen_box']:
+            raise BoxError('Box is frozen')
+        super(Box, self).__delitem__(key)
+        if (self._box_config['ordered_box'] and
+                key in self._box_config['__ordered_box_values']):
+            self._box_config['__ordered_box_values'].remove(key)
+
+    def __delattr__(self, item):
+        if self._box_config['frozen_box']:
+            raise BoxError('Box is frozen')
+        if item == '_box_config':
+            raise BoxError('"_box_config" is protected')
+        if item in self._protected_keys:
+            raise AttributeError("Key name '{0}' is protected".format(item))
+        try:
+            object.__getattribute__(self, item)
+        except AttributeError:
+            del self[item]
+        else:
+            object.__delattr__(self, item)
+        if (self._box_config['ordered_box'] and
+                item in self._box_config['__ordered_box_values']):
+            self._box_config['__ordered_box_values'].remove(item)
+
+    def pop(self, key, *args):
+        if args:
+            if len(args) != 1:
+                raise BoxError('pop() takes only one optional'
+                               ' argument "default"')
+            try:
+                item = self[key]
+            except KeyError:
+                return args[0]
+            else:
+                del self[key]
+                return item
+        try:
+            item = self[key]
+        except KeyError:
+            raise BoxKeyError('{0}'.format(key))
+        else:
+            del self[key]
+            return item
+
+    def clear(self):
+        self._box_config['__ordered_box_values'] = []
+        super(Box, self).clear()
+
+    def popitem(self):
+        try:
+            key = next(self.__iter__())
+        except StopIteration:
+            raise BoxKeyError('Empty box')
+        return key, self.pop(key)
+
+    def __repr__(self):
+        return '<Box: {0}>'.format(str(self.to_dict()))
+
+    def __str__(self):
+        return str(self.to_dict())
+
+    def __iter__(self):
+        for key in self.keys():
+            yield key
+
+    def __reversed__(self):
+        for key in reversed(list(self.keys())):
+            yield key
+
+    def to_dict(self):
+        """
+        Turn the Box and sub Boxes back into a native
+        python dictionary.
+
+        :return: python dictionary of this Box
+        """
+        out_dict = dict(self)
+        for k, v in out_dict.items():
+            if v is self:
+                out_dict[k] = out_dict
+            elif hasattr(v, 'to_dict'):
+                out_dict[k] = v.to_dict()
+            elif hasattr(v, 'to_list'):
+                out_dict[k] = v.to_list()
+        return out_dict
+
+    def update(self, item=None, **kwargs):
+        if not item:
+            item = kwargs
+        iter_over = item.items() if hasattr(item, 'items') else item
+        for k, v in iter_over:
+            if isinstance(v, dict):
+                # Box objects must be created in case they are already
+                # in the `converted` box_config set
+                v = self.__class__(v)
+                if k in self and isinstance(self[k], dict):
+                    self[k].update(v)
+                    continue
+            if isinstance(v, list):
+                v = BoxList(v)
+            try:
+                self.__setattr__(k, v)
+            except (AttributeError, TypeError):
+                self.__setitem__(k, v)
+
+    def setdefault(self, item, default=None):
+        if item in self:
+            return self[item]
+
+        if isinstance(default, dict):
+            default = self.__class__(default)
+        if isinstance(default, list):
+            default = BoxList(default)
+        self[item] = default
+        return default
+
+    def to_json(self, filename=None,
+                encoding="utf-8", errors="strict", **json_kwargs):
+        """
+        Transform the Box object into a JSON string.
+
+        :param filename: If provided will save to file
+        :param encoding: File encoding
+        :param errors: How to handle encoding errors
+        :param json_kwargs: additional arguments to pass to json.dump(s)
+        :return: string of JSON or return of `json.dump`
+        """
+        return _to_json(self.to_dict(), filename=filename,
+                        encoding=encoding, errors=errors, **json_kwargs)
+
+    @classmethod
+    def from_json(cls, json_string=None, filename=None,
+                  encoding="utf-8", errors="strict", **kwargs):
+        """
+        Transform a json object string into a Box object. If the incoming
+        json is a list, you must use BoxList.from_json.
+
+        :param json_string: string to pass to `json.loads`
+        :param filename: filename to open and pass to `json.load`
+        :param encoding: File encoding
+        :param errors: How to handle encoding errors
+        :param kwargs: parameters to pass to `Box()` or `json.loads`
+        :return: Box object from json data
+        """
+        bx_args = {}
+        for arg in kwargs.copy():
+            if arg in BOX_PARAMETERS:
+                bx_args[arg] = kwargs.pop(arg)
+
+        data = _from_json(json_string, filename=filename,
+                          encoding=encoding, errors=errors, **kwargs)
+
+        if not isinstance(data, dict):
+            raise BoxError('json data not returned as a dictionary, '
+                           'but rather a {0}'.format(type(data).__name__))
+        return cls(data, **bx_args)
+
+    if yaml_support:
+        def to_yaml(self, filename=None, default_flow_style=False,
+                    encoding="utf-8", errors="strict",
+                    **yaml_kwargs):
+            """
+            Transform the Box object into a YAML string.
+
+            :param filename:  If provided will save to file
+            :param default_flow_style: False will recursively dump dicts
+            :param encoding: File encoding
+            :param errors: How to handle encoding errors
+            :param yaml_kwargs: additional arguments to pass to yaml.dump
+            :return: string of YAML or return of `yaml.dump`
+            """
+            return _to_yaml(self.to_dict(), filename=filename,
+                            default_flow_style=default_flow_style,
+                            encoding=encoding, errors=errors, **yaml_kwargs)
+
+        @classmethod
+        def from_yaml(cls, yaml_string=None, filename=None,
+                      encoding="utf-8", errors="strict",
+                      loader=yaml.SafeLoader, **kwargs):
+            """
+            Transform a yaml object string into a Box object.
+
+            :param yaml_string: string to pass to `yaml.load`
+            :param filename: filename to open and pass to `yaml.load`
+            :param encoding: File encoding
+            :param errors: How to handle encoding errors
+            :param loader: YAML Loader, defaults to SafeLoader
+            :param kwargs: parameters to pass to `Box()` or `yaml.load`
+            :return: Box object from yaml data
+            """
+            bx_args = {}
+            for arg in kwargs.copy():
+                if arg in BOX_PARAMETERS:
+                    bx_args[arg] = kwargs.pop(arg)
+
+            data = _from_yaml(yaml_string=yaml_string, filename=filename,
+                              encoding=encoding, errors=errors,
+                              Loader=loader, **kwargs)
+            if not isinstance(data, dict):
+                raise BoxError('yaml data not returned as a dictionary, '
+                               'but rather a {0}'.format(type(data).__name__))
+            return cls(data, **bx_args)
+
+
+class BoxList(list):
+    """
+    Drop in replacement of list, that converts added objects to Box or BoxList
+    objects as necessary.
+    """
+
+    def __init__(self, iterable=None, box_class=Box, **box_options):
+        self.box_class = box_class
+        self.box_options = box_options
+        self.box_org_ref = id(iterable) if iterable else 0
+        if iterable:
+            for x in iterable:
+                self.append(x)
+        if box_options.get('frozen_box'):
+            def frozen(*args, **kwargs):
+                raise BoxError('BoxList is frozen')
+
+            for method in ['append', 'extend', 'insert', 'pop',
+                           'remove', 'reverse', 'sort']:
+                self.__setattr__(method, frozen)
+
+    def __delitem__(self, key):
+        if self.box_options.get('frozen_box'):
+            raise BoxError('BoxList is frozen')
+        super(BoxList, self).__delitem__(key)
+
+    def __setitem__(self, key, value):
+        if self.box_options.get('frozen_box'):
+            raise BoxError('BoxList is frozen')
+        super(BoxList, self).__setitem__(key, value)
+
+    def append(self, p_object):
+        if isinstance(p_object, dict):
+            try:
+                p_object = self.box_class(p_object, **self.box_options)
+            except AttributeError as err:
+                if 'box_class' in self.__dict__:
+                    raise err
+        elif isinstance(p_object, list):
+            try:
+                p_object = (self if id(p_object) == self.box_org_ref else
+                            BoxList(p_object))
+            except AttributeError as err:
+                if 'box_org_ref' in self.__dict__:
+                    raise err
+        super(BoxList, self).append(p_object)
+
+    def extend(self, iterable):
+        for item in iterable:
+            self.append(item)
+
+    def insert(self, index, p_object):
+        if isinstance(p_object, dict):
+            p_object = self.box_class(p_object, **self.box_options)
+        elif isinstance(p_object, list):
+            p_object = (self if id(p_object) == self.box_org_ref else
+                        BoxList(p_object))
+        super(BoxList, self).insert(index, p_object)
+
+    def __repr__(self):
+        return "<BoxList: {0}>".format(self.to_list())
+
+    def __str__(self):
+        return str(self.to_list())
+
+    def __copy__(self):
+        return BoxList((x for x in self),
+                       self.box_class,
+                       **self.box_options)
+
+    def __deepcopy__(self, memodict=None):
+        out = self.__class__()
+        memodict = memodict or {}
+        memodict[id(self)] = out
+        for k in self:
+            out.append(copy.deepcopy(k))
+        return out
+
+    def __hash__(self):
+        if self.box_options.get('frozen_box'):
+            hashing = 98765
+            hashing ^= hash(tuple(self))
+            return hashing
+        raise TypeError("unhashable type: 'BoxList'")
+
+    def to_list(self):
+        new_list = []
+        for x in self:
+            if x is self:
+                new_list.append(new_list)
+            elif isinstance(x, Box):
+                new_list.append(x.to_dict())
+            elif isinstance(x, BoxList):
+                new_list.append(x.to_list())
+            else:
+                new_list.append(x)
+        return new_list
+
+    def to_json(self, filename=None,
+                encoding="utf-8", errors="strict",
+                multiline=False, **json_kwargs):
+        """
+        Transform the BoxList object into a JSON string.
+
+        :param filename: If provided will save to file
+        :param encoding: File encoding
+        :param errors: How to handle encoding errors
+        :param multiline: Put each item in the list onto its own line
+        :param json_kwargs: additional arguments to pass to json.dump(s)
+        :return: string of JSON or return of `json.dump`
+        """
+        if filename and multiline:
+            lines = [_to_json(item, filename=False, encoding=encoding,
+                              errors=errors, **json_kwargs) for item in self]
+            with open(filename, 'w', encoding=encoding, errors=errors) as f:
+                f.write("\n".join(lines).decode('utf-8') if
+                        sys.version_info < (3, 0) else "\n".join(lines))
+        else:
+            return _to_json(self.to_list(), filename=filename,
+                            encoding=encoding, errors=errors, **json_kwargs)
+
+    @classmethod
+    def from_json(cls, json_string=None, filename=None, encoding="utf-8",
+                  errors="strict", multiline=False, **kwargs):
+        """
+        Transform a json object string into a BoxList object. If the incoming
+        json is a dict, you must use Box.from_json.
+
+        :param json_string: string to pass to `json.loads`
+        :param filename: filename to open and pass to `json.load`
+        :param encoding: File encoding
+        :param errors: How to handle encoding errors
+        :param multiline: One object per line
+        :param kwargs: parameters to pass to `Box()` or `json.loads`
+        :return: BoxList object from json data
+        """
+        bx_args = {}
+        for arg in kwargs.copy():
+            if arg in BOX_PARAMETERS:
+                bx_args[arg] = kwargs.pop(arg)
+
+        data = _from_json(json_string, filename=filename, encoding=encoding,
+                          errors=errors, multiline=multiline, **kwargs)
+
+        if not isinstance(data, list):
+            raise BoxError('json data not returned as a list, '
+                           'but rather a {0}'.format(type(data).__name__))
+        return cls(data, **bx_args)
+
+    if yaml_support:
+        def to_yaml(self, filename=None, default_flow_style=False,
+                    encoding="utf-8", errors="strict",
+                    **yaml_kwargs):
+            """
+            Transform the BoxList object into a YAML string.
+
+            :param filename:  If provided will save to file
+            :param default_flow_style: False will recursively dump dicts
+            :param encoding: File encoding
+            :param errors: How to handle encoding errors
+            :param yaml_kwargs: additional arguments to pass to yaml.dump
+            :return: string of YAML or return of `yaml.dump`
+            """
+            return _to_yaml(self.to_list(), filename=filename,
+                            default_flow_style=default_flow_style,
+                            encoding=encoding, errors=errors, **yaml_kwargs)
+
+        @classmethod
+        def from_yaml(cls, yaml_string=None, filename=None,
+                      encoding="utf-8", errors="strict",
+                      loader=yaml.SafeLoader,
+                      **kwargs):
+            """
+            Transform a yaml object string into a BoxList object.
+
+            :param yaml_string: string to pass to `yaml.load`
+            :param filename: filename to open and pass to `yaml.load`
+            :param encoding: File encoding
+            :param errors: How to handle encoding errors
+            :param loader: YAML Loader, defaults to SafeLoader
+            :param kwargs: parameters to pass to `BoxList()` or `yaml.load`
+            :return: BoxList object from yaml data
+            """
+            bx_args = {}
+            for arg in kwargs.copy():
+                if arg in BOX_PARAMETERS:
+                    bx_args[arg] = kwargs.pop(arg)
+
+            data = _from_yaml(yaml_string=yaml_string, filename=filename,
+                              encoding=encoding, errors=errors,
+                              Loader=loader, **kwargs)
+            if not isinstance(data, list):
+                raise BoxError('yaml data not returned as a list, '
+                               'but rather a {0}'.format(type(data).__name__))
+            return cls(data, **bx_args)
+
+    def box_it_up(self):
+        for v in self:
+            if hasattr(v, 'box_it_up') and v is not self:
+                v.box_it_up()
+
+
+class ConfigBox(Box):
+    """
+    Modified box object to add object transforms.
+
+    Allows for built-in transforms like:
+
+    cns = ConfigBox(my_bool='yes', my_int='5', my_list='5,4,3,3,2')
+
+    cns.bool('my_bool') # True
+    cns.int('my_int') # 5
+    cns.list('my_list', mod=lambda x: int(x)) # [5, 4, 3, 3, 2]
+    """
+
+    _protected_keys = dir({}) + ['to_dict', 'bool', 'int', 'float',
+                                 'list', 'getboolean', 'to_json', 'to_yaml',
+                                 'getfloat', 'getint',
+                                 'from_json', 'from_yaml']
+
+    def __getattr__(self, item):
+        """Config file keys are stored in lower case, be a little more
+        loosey goosey"""
+        try:
+            return super(ConfigBox, self).__getattr__(item)
+        except AttributeError:
+            return super(ConfigBox, self).__getattr__(item.lower())
+
+    def __dir__(self):
+        return super(ConfigBox, self).__dir__() + ['bool', 'int', 'float',
+                                                   'list', 'getboolean',
+                                                   'getfloat', 'getint']
+
+    def bool(self, item, default=None):
+        """ Return value of key as a boolean
+
+        :param item: key of value to transform
+        :param default: value to return if item does not exist
+        :return: approximated bool of value
+        """
+        try:
+            item = self.__getattr__(item)
+        except AttributeError as err:
+            if default is not None:
+                return default
+            raise err
+
+        if isinstance(item, (bool, int)):
+            return bool(item)
+
+        if (isinstance(item, str) and
+                item.lower() in ('n', 'no', 'false', 'f', '0')):
+            return False
+
+        return True if item else False
+
+    def int(self, item, default=None):
+        """ Return value of key as an int
+
+        :param item: key of value to transform
+        :param default: value to return if item does not exist
+        :return: int of value
+        """
+        try:
+            item = self.__getattr__(item)
+        except AttributeError as err:
+            if default is not None:
+                return default
+            raise err
+        return int(item)
+
+    def float(self, item, default=None):
+        """ Return value of key as a float
+
+        :param item: key of value to transform
+        :param default: value to return if item does not exist
+        :return: float of value
+        """
+        try:
+            item = self.__getattr__(item)
+        except AttributeError as err:
+            if default is not None:
+                return default
+            raise err
+        return float(item)
+
+    def list(self, item, default=None, spliter=",", strip=True, mod=None):
+        """ Return value of key as a list
+
+        :param item: key of value to transform
+        :param mod: function to map against list
+        :param default: value to return if item does not exist
+        :param spliter: character to split str on
+        :param strip: clean the list with the `strip`
+        :return: list of items
+        """
+        try:
+            item = self.__getattr__(item)
+        except AttributeError as err:
+            if default is not None:
+                return default
+            raise err
+        if strip:
+            item = item.lstrip('[').rstrip(']')
+        out = [x.strip() if strip else x for x in item.split(spliter)]
+        if mod:
+            return list(map(mod, out))
+        return out
+
+    # loose configparser compatibility
+
+    def getboolean(self, item, default=None):
+        return self.bool(item, default)
+
+    def getint(self, item, default=None):
+        return self.int(item, default)
+
+    def getfloat(self, item, default=None):
+        return self.float(item, default)
+
+    def __repr__(self):
+        return '<ConfigBox: {0}>'.format(str(self.to_dict()))
+
+
+class SBox(Box):
+    """
+    ShorthandBox (SBox) allows for
+    property access of `dict` `json` and `yaml`
+    """
+    _protected_keys = dir({}) + ['to_dict', 'tree_view', 'to_json', 'to_yaml',
+                                 'json', 'yaml', 'from_yaml', 'from_json',
+                                 'dict']
+
+    @property
+    def dict(self):
+        return self.to_dict()
+
+    @property
+    def json(self):
+        return self.to_json()
+
+    if yaml_support:
+        @property
+        def yaml(self):
+            return self.to_yaml()
+
+    def __repr__(self):
+        return '<ShorthandBox: {0}>'.format(str(self.to_dict()))
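
The file above vendors the python-box 3.2.4 `Box` implementation. A brief usage sketch of the dot-notation access it provides, based only on the methods shown in this diff:

    from lcnn.box import Box

    b = Box({"model": {"batch_size": 4, "head_size": [[2], [1], [2], [4]]}})
    print(b.model.batch_size)    # 4, attribute-style access to nested dicts
    print(b.model.head_size[0])  # [2], lists are wrapped as BoxList
    print(b.to_dict())           # back to a plain nested dict

    frozen = Box({"a": 1}, frozen_box=True)
    # frozen.a = 2  # would raise BoxError('Box is frozen')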

+ 9 - 0
lcnn/config.py

@@ -0,0 +1,9 @@
+import numpy as np
+
+from lcnn.box import Box
+
+# C is a dict storing all the configuration
+C = Box()
+
+# shortcut for C.model
+M = Box()
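
`C` and `M` are module-level Box instances meant to hold the global configuration and its model sub-section. A hedged sketch of how they could be populated from config/wireframe.yaml (the actual call site is not part of this diff):

    from lcnn.config import C, M

    C.update(C.from_yaml(filename="config/wireframe.yaml"))
    M.update(C.model)                # M mirrors the model section of C
    print(M.batch_size, M.backbone)  # e.g. 4 fasterrcnn_resnet50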

+ 378 - 0
lcnn/dataset_tool.py

@@ -0,0 +1,378 @@
+import cv2
+import numpy as np
+import torch
+import torchvision
+from matplotlib import pyplot as plt
+import tools.transforms as reference_transforms
+from collections import defaultdict
+
+from tools import presets
+
+import json
+
+
+def get_modules(use_v2):
+    # We need a protected import to avoid the V2 warning in case just V1 is used
+    if use_v2:
+        import torchvision.transforms.v2
+        import torchvision.tv_tensors
+
+        return torchvision.transforms.v2, torchvision.tv_tensors
+    else:
+        return reference_transforms, None
+
+
+class Augmentation:
+    # Note: this transform assumes that the input to forward() are always PIL
+    # images, regardless of the backend parameter.
+    def __init__(
+            self,
+            *,
+            data_augmentation,
+            hflip_prob=0.5,
+            mean=(123.0, 117.0, 104.0),
+            backend="pil",
+            use_v2=False,
+    ):
+
+        T, tv_tensors = get_modules(use_v2)
+
+        transforms = []
+        backend = backend.lower()
+        if backend == "tv_tensor":
+            transforms.append(T.ToImage())
+        elif backend == "tensor":
+            transforms.append(T.PILToTensor())
+        elif backend != "pil":
+            raise ValueError(f"backend can be 'tv_tensor', 'tensor' or 'pil', but got {backend}")
+
+        if data_augmentation == "hflip":
+            transforms += [T.RandomHorizontalFlip(p=hflip_prob)]
+        elif data_augmentation == "lsj":
+            transforms += [
+                T.ScaleJitter(target_size=(1024, 1024), antialias=True),
+                # TODO: FixedSizeCrop below doesn't work on tensors!
+                reference_transforms.FixedSizeCrop(size=(1024, 1024), fill=mean),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
+        elif data_augmentation == "multiscale":
+            transforms += [
+                T.RandomShortestSize(min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
+        elif data_augmentation == "ssd":
+            fill = defaultdict(lambda: mean, {tv_tensors.Mask: 0}) if use_v2 else list(mean)
+            transforms += [
+                T.RandomPhotometricDistort(),
+                T.RandomZoomOut(fill=fill),
+                T.RandomIoUCrop(),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
+        elif data_augmentation == "ssdlite":
+            transforms += [
+                T.RandomIoUCrop(),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
+        else:
+            raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')
+
+        if backend == "pil":
+            # Note: we could just convert to pure tensors even in v2.
+            transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
+
+        transforms += [T.ToDtype(torch.float, scale=True)]
+
+        if use_v2:
+            transforms += [
+                T.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.XYXY),
+                T.SanitizeBoundingBoxes(),
+                T.ToPureTensor(),
+            ]
+
+        self.transforms = T.Compose(transforms)
+
+    def __call__(self, img, target):
+        return self.transforms(img, target)
+
+
+def read_polygon_points(lbl_path, shape):
+    """读取 YOLOv8 格式的标注文件并解析多边形轮廓"""
+    polygon_points = []
+    w, h = shape[:2]
+    with open(lbl_path, 'r') as f:
+        lines = f.readlines()
+
+    for line in lines:
+        parts = line.strip().split()
+        class_id = int(parts[0])
+        points = np.array(parts[1:], dtype=np.float32).reshape(-1, 2)  # read the point coordinates
+        points[:, 0] *= h
+        points[:, 1] *= w
+
+        polygon_points.append((class_id, points))
+
+    return polygon_points
+
+
+def read_masks_from_pixels(lbl_path, shape):
+    """读取纯像素点格式的文件,不是轮廓像素点"""
+    h, w = shape
+    masks = []
+    labels = []
+
+    with open(lbl_path, 'r') as reader:
+        lines = reader.readlines()
+        mask_points = []
+        for line in lines:
+            mask = torch.zeros((h, w), dtype=torch.uint8)
+            parts = line.strip().split()
+            # print(f'parts:{parts}')
+            cls = torch.tensor(int(parts[0]), dtype=torch.int64)
+            labels.append(cls)
+            x_array = parts[1::2]
+            y_array = parts[2::2]
+
+            for x, y in zip(x_array, y_array):
+                x = float(x)
+                y = float(y)
+                mask_points.append((int(y * h), int(x * w)))
+
+            for p in mask_points:
+                mask[p] = 1
+            masks.append(mask)
+    reader.close()
+    return labels, masks
+
+
+def create_masks_from_polygons(polygons, image_shape):
+    """创建一个与图像尺寸相同的掩码,并填充多边形轮廓"""
+    colors = np.array([plt.cm.Spectral(each) for each in np.linspace(0, 1, len(polygons))])
+    masks = []
+
+    for polygon_data, col in zip(polygons, colors):
+        mask = np.zeros(image_shape[:2], dtype=np.uint8)
+        # convert the polygon vertices to a NumPy array
+        _, polygon = polygon_data
+        pts = np.array(polygon, np.int32).reshape((-1, 1, 2))
+
+        # fill the polygon using OpenCV's fillPoly
+        # print(f'color:{col[:3]}')
+        cv2.fillPoly(mask, [pts], np.array(col[:3]) * 255)
+        mask = torch.from_numpy(mask)
+        mask[mask != 0] = 1
+        masks.append(mask)
+
+    return masks
+
+
+def read_masks_from_txt(label_path, shape):
+    polygon_points = read_polygon_points(label_path, shape)
+    masks = create_masks_from_polygons(polygon_points, shape)
+    labels = [torch.tensor(item[0]) for item in polygon_points]
+
+    return labels, masks
+
+
+def masks_to_boxes(masks: torch.Tensor, ) -> torch.Tensor:
+    """
+    Compute the bounding boxes around the provided masks.
+
+    Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with
+    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
+
+    Args:
+        masks (Tensor[N, H, W]): masks to transform where N is the number of masks
+            and (H, W) are the spatial dimensions.
+
+    Returns:
+        Tensor[N, 4]: bounding boxes
+    """
+    # if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+    #     _log_api_usage_once(masks_to_boxes)
+    if masks.numel() == 0:
+        return torch.zeros((0, 4), device=masks.device, dtype=torch.float)
+
+    n = masks.shape[0]
+
+    bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.float)
+
+    for index, mask in enumerate(masks):
+        y, x = torch.where(mask != 0)
+        bounding_boxes[index, 0] = torch.min(x)
+        bounding_boxes[index, 1] = torch.min(y)
+        bounding_boxes[index, 2] = torch.max(x)
+        bounding_boxes[index, 3] = torch.max(y)
+        # debug to pixel datasets
+
+        if bounding_boxes[index, 0] == bounding_boxes[index, 2]:
+            bounding_boxes[index, 2] = bounding_boxes[index, 2] + 1
+            bounding_boxes[index, 0] = bounding_boxes[index, 0] - 1
+
+        if bounding_boxes[index, 1] == bounding_boxes[index, 3]:
+            bounding_boxes[index, 3] = bounding_boxes[index, 3] + 1
+            bounding_boxes[index, 1] = bounding_boxes[index, 1] - 1
+
+    return bounding_boxes
+
+
+def line_boxes_faster(target):
+    boxs = []
+    lpre = target["lpre"].cpu().numpy() * 4
+    vecl_target = target["lpre_label"].cpu().numpy()
+    lpre = lpre[vecl_target == 1]
+
+    lines = lpre
+    sline = np.ones(lpre.shape[0])
+
+    if len(lines) > 0 and not (lines[0] == 0).all():
+        for i, ((a, b), s) in enumerate(zip(lines, sline)):
+            if i > 0 and (lines[i] == lines[0]).all():
+                break
+            # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1] have no fixed ordering
+
+            if a[1] > b[1]:
+                ymax = a[1] + 10
+                ymin = b[1] - 10
+            else:
+                ymin = a[1] - 10
+                ymax = b[1] + 10
+            if a[0] > b[0]:
+                xmax = a[0] + 10
+                xmin = b[0] - 10
+            else:
+                xmin = a[0] - 10
+                xmax = b[0] + 10
+            boxs.append([max(0, ymin), max(0, xmin), min(512, ymax), min(512, xmax)])
+
+    return torch.tensor(boxs)
+
+
+# Reorder each line segment as [start, end]; input shape is [num, 2, 2]
+def a_to_b(line):
+    result_pairs = []
+    for a, b in line:
+        min_x = min(a[0], b[0])
+        min_y = min(a[1], b[1])
+        new_top_left_x = max((min_x - 10), 0)
+        new_top_left_y = max((min_y - 10), 0)
+        dist_a = (a[0] - new_top_left_x) ** 2 + (a[1] - new_top_left_y) ** 2
+        dist_b = (b[0] - new_top_left_x) ** 2 + (b[1] - new_top_left_y) ** 2
+
+        # Choose the start point by its distance to the shifted top-left corner
+        if dist_a <= dist_b:  # a is closer to (or as close to) the new top-left corner
+            result_pairs.append([a, b])  # a becomes the start, b the end
+        else:  # b is closer to the new top-left corner
+            result_pairs.append([b, a])  # b becomes the start, a the end
+    result_tensor = torch.stack([torch.stack(row) for row in result_pairs])
+    return result_tensor
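+
+# Sketch of the expected input/output of a_to_b (illustrative values only):
+#
+#     lines = torch.tensor([[[100., 50.], [10., 20.]],
+#                           [[5., 5.], [60., 80.]]])      # [num, 2, 2]
+#     a_to_b(lines)   # each pair reordered so the endpoint nearer the padded
+#                     # top-left corner comes first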
+
+
+def read_polygon_points_wire(lbl_path, shape):
+    """读取 YOLOv8 格式的标注文件并解析多边形轮廓"""
+    polygon_points = []
+    w, h = shape[:2]
+    with open(lbl_path, 'r') as f:
+        lines = json.load(f)
+
+    for line in lines["segmentations"]:
+        parts = line["data"]
+        class_id = int(line["cls_id"])
+        points = np.array(parts, dtype=np.float32).reshape(-1, 2)  # point coordinates
+        points[:, 0] *= h
+        points[:, 1] *= w
+
+        polygon_points.append((class_id, points))
+
+    return polygon_points
+
+
+def read_masks_from_txt_wire(label_path, shape):
+    polygon_points = read_polygon_points_wire(label_path, shape)
+    masks = create_masks_from_polygons(polygon_points, shape)
+    labels = [torch.tensor(item[0]) for item in polygon_points]
+
+    return labels, masks
+
+
+def read_masks_from_pixels_wire(lbl_path, shape):
+    """读取纯像素点格式的文件,不是轮廓像素点"""
+    h, w = shape
+    masks = []
+    labels = []
+
+    with open(lbl_path, 'r') as reader:
+        lines = json.load(reader)
+        mask_points = []
+        for line in lines["segmentations"]:
+            # mask = torch.zeros((h, w), dtype=torch.uint8)
+            # parts = line["data"]
+            # print(f'parts:{parts}')
+            cls = torch.tensor(int(line["cls_id"]), dtype=torch.int64)
+            labels.append(cls)
+            # x_array = parts[0::2]
+            # y_array = parts[1::2]
+            # 
+            # for x, y in zip(x_array, y_array):
+            #     x = float(x)
+            #     y = float(y)
+            #     mask_points.append((int(y * h), int(x * w)))
+
+            # for p in mask_points:
+            #     mask[p] = 1
+            # masks.append(mask)
+    return labels
+
+
+def read_lable_keypoint(lbl_path):
+    """判断线段的起点终点, 起点 lable=0, 终点 lable=1"""
+    labels = []
+
+    with open(lbl_path, 'r') as reader:
+        lines = json.load(reader)
+        aa = lines["wires"][0]["line_pos_coords"]["content"]
+
+        result_pairs = []
+        for a, b in aa:
+            min_x = min(a[0], b[0])
+            min_y = min(a[1], b[1])
+
+            # Define the shifted top-left corner
+            new_top_left_x = max((min_x - 10), 0)
+            new_top_left_y = max((min_y - 10), 0)
+
+            # Step 2: squared distance of each endpoint to the new top-left corner (avoids float error)
+            dist_a = (a[0] - new_top_left_x) ** 2 + (a[1] - new_top_left_y) ** 2
+            dist_b = (b[0] - new_top_left_x) ** 2 + (b[1] - new_top_left_y) ** 2
+
+            # Steps 3 & 4: choose the start point by distance and set the label
+            if dist_a <= dist_b:  # a is closer to (or as close to) the new top-left corner
+                result_pairs.append([a, b])  # a becomes the start, b the end
+            else:  # b is closer to the new top-left corner
+                result_pairs.append([b, a])  # b becomes the start, a the end
+
+            # x_ = abs(a[0] - b[0])
+            # y_ = abs(a[1] - b[1])
+            # if x_ > y_:  # the point with the smaller x is closer to the top-left
+            #     if a[0] < b[0]:  # treated as the start, label=0
+            #         label = [0, 1]
+            #     else:
+            #         label = [1, 0]
+            # else:  # the point with the larger x is the start
+            #     if a[0] > b[0]:  # treated as the start, label=0
+            #         label = [0, 1]
+            #     else:
+            #         label = [1, 0]
+            # labels.append(label)
+        # print(result_pairs )
+    return labels
+
+
+def adjacency_matrix(n, link):  # adjacency matrix over junction indices
+    mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8)
+    link = torch.tensor(link)
+    if len(link) > 0:
+        mat[link[:, 0], link[:, 1]] = 1
+        mat[link[:, 1], link[:, 0]] = 1
+    return mat
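+
+# Sketch: adjacency_matrix builds a symmetric (n + 1) x (n + 1) 0/1 matrix from junction
+# index pairs; the extra last row/column is the "no match" slot used by sample_lines in
+# line_vectorizer (illustrative values only):
+#
+#     adjacency_matrix(3, [[0, 1], [1, 2]])   # shape [4, 4], ones at (0,1),(1,0),(1,2),(2,1)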

+ 294 - 0
lcnn/datasets.py

@@ -0,0 +1,294 @@
+# import glob
+# import json
+# import math
+# import os
+# import random
+#
+# import numpy as np
+# import numpy.linalg as LA
+# import torch
+# from skimage import io
+# from torch.utils.data import Dataset
+# from torch.utils.data.dataloader import default_collate
+#
+# from lcnn.config import M
+#
+# from .dataset_tool import line_boxes_faster, read_masks_from_txt_wire, read_masks_from_pixels_wire
+#
+#
+# class WireframeDataset(Dataset):
+#     def __init__(self, rootdir, split):
+#         self.rootdir = rootdir
+#         filelist = glob.glob(f"{rootdir}/{split}/*_label.npz")
+#         filelist.sort()
+#
+#         # print(f"n{split}:", len(filelist))
+#         self.split = split
+#         self.filelist = filelist
+#
+#     def __len__(self):
+#         return len(self.filelist)
+#
+#     def __getitem__(self, idx):
+#         iname = self.filelist[idx][:-10].replace("_a0", "").replace("_a1", "") + ".png"
+#         image = io.imread(iname).astype(float)[:, :, :3]
+#         if "a1" in self.filelist[idx]:
+#             image = image[:, ::-1, :]
+#         image = (image - M.image.mean) / M.image.stddev
+#         image = np.rollaxis(image, 2).copy()
+#
+#         with np.load(self.filelist[idx]) as npz:
+#             target = {
+#                 name: torch.from_numpy(npz[name]).float()
+#                 for name in ["jmap", "joff", "lmap"]
+#             }
+#             lpos = np.random.permutation(npz["lpos"])[: M.n_stc_posl]
+#             lneg = np.random.permutation(npz["lneg"])[: M.n_stc_negl]
+#             npos, nneg = len(lpos), len(lneg)
+#             lpre = np.concatenate([lpos, lneg], 0)
+#             for i in range(len(lpre)):
+#                 if random.random() > 0.5:
+#                     lpre[i] = lpre[i, ::-1]
+#             ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
+#             ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
+#             feat = [
+#                 lpre[:, :, :2].reshape(-1, 4) / 128 * M.use_cood,
+#                 ldir * M.use_slop,
+#                 lpre[:, :, 2],
+#             ]
+#             feat = np.concatenate(feat, 1)
+#             meta = {
+#                 "junc": torch.from_numpy(npz["junc"][:, :2]),
+#                 "jtyp": torch.from_numpy(npz["junc"][:, 2]).byte(),
+#                 "Lpos": self.adjacency_matrix(len(npz["junc"]), npz["Lpos"]),
+#                 "Lneg": self.adjacency_matrix(len(npz["junc"]), npz["Lneg"]),
+#                 "lpre": torch.from_numpy(lpre[:, :, :2]),
+#                 "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),
+#                 "lpre_feat": torch.from_numpy(feat),
+#             }
+#
+#         labels = []
+#         labels = read_masks_from_pixels_wire(iname, (512, 512))
+#         # if self.target_type == 'polygon':
+#         #     labels, masks = read_masks_from_txt_wire(iname, (512, 512))
+#         # elif self.target_type == 'pixel':
+#         #     labels = read_masks_from_pixels_wire(iname, (512, 512))
+#
+#         target["labels"] = torch.stack(labels)
+#         target["boxes"] = line_boxes_faster(meta)
+#
+#
+#         return torch.from_numpy(image).float(), meta, target
+#
+#     def adjacency_matrix(self, n, link):
+#         mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8)
+#         link = torch.from_numpy(link)
+#         if len(link) > 0:
+#             mat[link[:, 0], link[:, 1]] = 1
+#             mat[link[:, 1], link[:, 0]] = 1
+#         return mat
+#
+#
+# def collate(batch):
+#     return (
+#         default_collate([b[0] for b in batch]),
+#         [b[1] for b in batch],
+#         default_collate([b[2] for b in batch]),
+#     )
+
+from torch.utils.data.dataset import T_co
+
+from .models.base.base_dataset import BaseDataset
+
+import glob
+import json
+import math
+import os
+import random
+import cv2
+import PIL
+
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from torchvision.utils import draw_bounding_boxes
+
+import numpy as np
+import numpy.linalg as LA
+import torch
+from skimage import io
+from torch.utils.data import Dataset
+from torch.utils.data.dataloader import default_collate
+
+from .dataset_tool import line_boxes_faster, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
+
+
+
+class WireframeDataset(BaseDataset):
+    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
+        super().__init__(dataset_path)
+
+        self.data_path = dataset_path
+        print(f'data_path:{dataset_path}')
+        self.transforms = transforms
+        self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
+        self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
+        self.imgs = os.listdir(self.img_path)
+        self.lbls = os.listdir(self.lbl_path)
+        self.target_type = target_type
+        # self.default_transform = DefaultTransform()
+
+    def __getitem__(self, index) -> T_co:
+        img_path = os.path.join(self.img_path, self.imgs[index])
+        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
+
+        img = PIL.Image.open(img_path).convert('RGB')
+        w, h = img.size
+
+        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        meta, target, target_b = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+
+        img = self.default_transform(img)
+
+        # print(f'img:{img}')
+        return img, meta, target, target_b
+
+    def __len__(self):
+        return len(self.imgs)
+
+    def read_target(self, item, lbl_path, shape, extra=None):
+        # print(f'shape:{shape}')
+        # print(f'lbl_path:{lbl_path}')
+        with open(lbl_path, 'r') as file:
+            label_all = json.load(file)
+
+        n_stc_posl = 300
+        n_stc_negl = 40
+        use_cood = 0
+        use_slop = 0
+
+        wire = lable_all["wires"][0]  # 字典
+        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # if fewer are available, take them all
+        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
+        npos, nneg = len(line_pos_coords), len(line_neg_coords)
+        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # positive and negative sample coordinates concatenated
+        for i in range(len(lpre)):
+            if random.random() > 0.5:
+                lpre[i] = lpre[i, ::-1]
+        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
+        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
+        feat = [
+            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
+            ldir * use_slop,
+            lpre[:, :, 2],
+        ]
+        feat = np.concatenate(feat, 1)
+
+        meta = {
+            "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
+            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
+            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
+            # adjacency matrix of lines that actually exist
+            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
+
+            "lpre": torch.tensor(lpre)[:, :, :2],
+            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # 样本对应标签 1,0
+            "lpre_feat": torch.from_numpy(feat),
+
+        }
+
+        target = {
+            "junc_map": torch.tensor(wire['junc_map']["content"]),
+            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
+            "line_map": torch.tensor(wire['line_map']["content"]),
+        }
+
+        labels = []
+        if self.target_type == 'polygon':
+            labels, masks = read_masks_from_txt_wire(lbl_path, shape)
+        elif self.target_type == 'pixel':
+            labels = read_masks_from_pixels_wire(lbl_path, shape)
+
+        # print(torch.stack(masks).shape)    # [num_segments, 512, 512]
+        target_b = {}
+
+        # target_b["image_id"] = torch.tensor(item)
+
+        target_b["labels"] = torch.stack(labels)
+        target_b["boxes"] = line_boxes_faster(meta)
+
+        return meta, target, target_b
+
+
+    def show(self, idx):
+        img, meta, target, target_b = self.__getitem__(idx)
+
+        cmap = plt.get_cmap("jet")
+        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+        sm.set_array([])
+
+        def imshow(im):
+            plt.close()
+            plt.tight_layout()
+            plt.imshow(im)
+            plt.colorbar(sm, fraction=0.046)
+            plt.xlim([0, im.shape[0]])
+            plt.ylim([im.shape[0], 0])
+
+        def draw_vecl(lines, sline, juncs, junts, fn=None):
+            img_path = os.path.join(self.img_path, self.imgs[idx])
+            imshow(io.imread(img_path))
+            if len(lines) > 0 and not (lines[0] == 0).all():
+                for i, ((a, b), s) in enumerate(zip(lines, sline)):
+                    if i > 0 and (lines[i] == lines[0]).all():
+                        break
+                    plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1] have no guaranteed ordering
+            if not (juncs[0] == 0).all():
+                for i, j in enumerate(juncs):
+                    if i > 0 and (i == juncs[0]).all():
+                        break
+                    plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # originally s=64
+
+            img_path = os.path.join(self.img_path, self.imgs[idx])
+            img = PIL.Image.open(img_path).convert('RGB')
+            boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target_b["boxes"],
+                                              colors="yellow", width=1)
+            plt.imshow(boxed_image.permute(1, 2, 0).numpy())
+            if fn is not None:
+                plt.savefig(fn)
+            plt.show()
+
+        junc = meta['junc_coords'].cpu().numpy() * 4
+        jtyp = meta['jtyp'].cpu().numpy()
+        juncs = junc[jtyp == 0]
+        junts = junc[jtyp == 1]
+
+        lpre = meta["lpre"].cpu().numpy() * 4
+        vecl_target = meta["lpre_label"].cpu().numpy()
+        lpre = lpre[vecl_target == 1]
+
+        # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
+        draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
+
+    def show_img(self, img_path):
+        pass
+
+
+def collate(batch):
+    return (
+        default_collate([b[0] for b in batch]),
+        [b[1] for b in batch],
+        default_collate([b[2] for b in batch]),
+        [b[3] for b in batch],
+    )
+
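+# Loading sketch (assumes the images/<split> and labels/<split> layout used above;
+# the path is a placeholder):
+#
+#     from torch.utils.data import DataLoader
+#     dataset = WireframeDataset(dataset_path="path/to/data", dataset_type="train")
+#     loader = DataLoader(dataset, batch_size=2, collate_fn=collate)
+#     imgs, metas, targets, targets_b = next(iter(loader))
+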
+
+# if __name__ == '__main__':
+#     path = r"D:\python\PycharmProjects\data"
+#     dataset = WireframeDataset(dataset_path=path, dataset_type='train')
+#     dataset.show(0)
+
+

+ 209 - 0
lcnn/metric.py

@@ -0,0 +1,209 @@
+import numpy as np
+import numpy.linalg as LA
+import matplotlib.pyplot as plt
+
+from lcnn.utils import argsort2d
+
+DX = [0, 0, 1, -1, 1, 1, -1, -1]
+DY = [1, -1, 0, 0, 1, -1, 1, -1]
+
+
+def ap(tp, fp):
+    recall = tp
+    precision = tp / np.maximum(tp + fp, 1e-9)
+
+    recall = np.concatenate(([0.0], recall, [1.0]))
+    precision = np.concatenate(([0.0], precision, [0.0]))
+
+    for i in range(precision.size - 1, 0, -1):
+        precision[i - 1] = max(precision[i - 1], precision[i])
+    i = np.where(recall[1:] != recall[:-1])[0]
+    return np.sum((recall[i + 1] - recall[i]) * precision[i + 1])
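+
+# Sketch: ap() expects *cumulative* tp (already normalized by the number of ground
+# truths) and cumulative fp counts, e.g.
+#
+#     ap(np.array([0.5, 0.5, 1.0]), np.array([0.0, 1.0, 1.0]))   # -> 0.75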
+
+
+def APJ(vert_pred, vert_gt, max_distance, im_ids):
+    if len(vert_pred) == 0:
+        return 0
+
+    vert_pred = np.array(vert_pred)
+    vert_gt = np.array(vert_gt)
+
+    confidence = vert_pred[:, -1]
+    idx = np.argsort(-confidence)
+    vert_pred = vert_pred[idx, :]
+    im_ids = im_ids[idx]
+    n_gt = sum(len(gt) for gt in vert_gt)
+
+    nd = len(im_ids)
+    tp, fp = np.zeros(nd, dtype=float), np.zeros(nd, dtype=float)
+    hit = [[False for _ in j] for j in vert_gt]
+
+    for i in range(nd):
+        gt_juns = vert_gt[im_ids[i]]
+        pred_juns = vert_pred[i][:-1]
+        if len(gt_juns) == 0:
+            continue
+        dists = np.linalg.norm((pred_juns[None, :] - gt_juns), axis=1)
+        choice = np.argmin(dists)
+        dist = np.min(dists)
+        if dist < max_distance and not hit[im_ids[i]][choice]:
+            tp[i] = 1
+            hit[im_ids[i]][choice] = True
+        else:
+            fp[i] = 1
+
+    tp = np.cumsum(tp) / n_gt
+    fp = np.cumsum(fp) / n_gt
+    return ap(tp, fp)
+
+
+def nms_j(heatmap, delta=1):
+    heatmap = heatmap.copy()
+    disable = np.zeros_like(heatmap, dtype=bool)
+    for x, y in argsort2d(heatmap):
+        for dx, dy in zip(DX, DY):
+            xp, yp = x + dx, y + dy
+            if not (0 <= xp < heatmap.shape[0] and 0 <= yp < heatmap.shape[1]):
+                continue
+            if heatmap[x, y] >= heatmap[xp, yp]:
+                disable[xp, yp] = True
+    heatmap[disable] *= 0.6
+    return heatmap
+
+
+def mAPJ(pred, truth, distances, im_ids):
+    return sum(APJ(pred, truth, d, im_ids) for d in distances) / len(distances) * 100
+
+
+def post_jheatmap(heatmap, offset=None, delta=1):
+    heatmap = nms_j(heatmap, delta=delta)
+    # only select the best 1000 junctions for efficiency
+    v0 = argsort2d(-heatmap)[:1000]
+    confidence = -np.sort(-heatmap.ravel())[:1000]
+    keep_id = np.where(confidence >= 1e-2)[0]
+    if len(keep_id) == 0:
+        return np.zeros((0, 3))
+
+    confidence = confidence[keep_id]
+    if offset is not None:
+        v0 = np.array([v + offset[:, v[0], v[1]] for v in v0])
+    v0 = v0[keep_id] + 0.5
+    v0 = np.hstack((v0, confidence[:, np.newaxis]))
+    return v0
+
+
+def vectorized_wireframe_2d_metric(
+    vert_pred, dpth_pred, edge_pred, vert_gt, dpth_gt, edge_gt, threshold
+):
+    # staging 1: matching
+    nd = len(vert_pred)
+    sorted_confidence = np.argsort(-vert_pred[:, -1])
+    vert_pred = vert_pred[sorted_confidence, :-1]
+    dpth_pred = dpth_pred[sorted_confidence]
+    d = np.sqrt(
+        np.sum(vert_pred ** 2, 1)[:, None]
+        + np.sum(vert_gt ** 2, 1)[None, :]
+        - 2 * vert_pred @ vert_gt.T
+    )
+    choice = np.argmin(d, 1)
+    dist = np.min(d, 1)
+
+    # staging 2: compute depth metric: SIL/L2
+    loss_L1 = loss_L2 = 0
+    hit = np.zeros_like(dpth_gt, bool)
+    SIL = np.zeros_like(dpth_pred)
+    for i in range(nd):
+        if dist[i] < threshold and not hit[choice[i]]:
+            hit[choice[i]] = True
+            loss_L1 += abs(dpth_gt[choice[i]] - dpth_pred[i])
+            loss_L2 += (dpth_gt[choice[i]] - dpth_pred[i]) ** 2
+            a = np.maximum(-dpth_pred[i], 1e-10)
+            b = -dpth_gt[choice[i]]
+            SIL[i] = np.log(a) - np.log(b)
+        else:
+            choice[i] = -1
+
+    n = max(np.sum(hit), 1)
+    loss_L1 /= n
+    loss_L2 /= n
+    loss_SIL = np.sum(SIL ** 2) / n - np.sum(SIL) ** 2 / (n * n)
+
+    # staging 3: compute mAP for edge matching
+    edgeset = set([frozenset(e) for e in edge_gt])
+    tp = np.zeros(len(edge_pred), dtype=float)
+    fp = np.zeros(len(edge_pred), dtype=float)
+    for i, (v0, v1, score) in enumerate(sorted(edge_pred, key=lambda e: -e[2])):
+        length = LA.norm(vert_gt[v0] - vert_gt[v1], axis=1)
+        if frozenset([choice[v0], choice[v1]]) in edgeset:
+            tp[i] = length
+        else:
+            fp[i] = length
+    total_length = LA.norm(
+        vert_gt[edge_gt[:, 0]] - vert_gt[edge_gt[:, 1]], axis=1
+    ).sum()
+    return ap(tp / total_length, fp / total_length), (loss_SIL, loss_L1, loss_L2)
+
+
+def vectorized_wireframe_3d_metric(
+    vert_pred, dpth_pred, edge_pred, vert_gt, dpth_gt, edge_gt, threshold
+):
+    # staging 1: matching
+    nd = len(vert_pred)
+    sorted_confidence = np.argsort(-vert_pred[:, -1])
+    vert_pred = np.hstack([vert_pred[:, :-1], dpth_pred[:, None]])[sorted_confidence]
+    vert_gt = np.hstack([vert_gt[:, :-1], dpth_gt[:, None]])
+    d = np.sqrt(
+        np.sum(vert_pred ** 2, 1)[:, None]
+        + np.sum(vert_gt ** 2, 1)[None, :]
+        - 2 * vert_pred @ vert_gt.T
+    )
+    choice = np.argmin(d, 1)
+    dist = np.min(d, 1)
+    hit = np.zeros_like(dpth_gt, bool)
+    for i in range(nd):
+        if dist[i] < threshold and not hit[choice[i]]:
+            hit[choice[i]] = True
+        else:
+            choice[i] = -1
+
+    # staging 2: compute mAP for edge matching
+    edgeset = set([frozenset(e) for e in edge_gt])
+    tp = np.zeros(len(edge_pred), dtype=float)
+    fp = np.zeros(len(edge_pred), dtype=float)
+    for i, (v0, v1, score) in enumerate(sorted(edge_pred, key=lambda e: -e[2])):
+        length = LA.norm(vert_gt[v0] - vert_gt[v1], axis=1)
+        if frozenset([choice[v0], choice[v1]]) in edgeset:
+            tp[i] = length
+        else:
+            fp[i] = length
+    total_length = LA.norm(
+        vert_gt[edge_gt[:, 0]] - vert_gt[edge_gt[:, 1]], axis=1
+    ).sum()
+
+    return ap(tp / total_length, fp / total_length)
+
+
+def msTPFP(line_pred, line_gt, threshold):
+    diff = ((line_pred[:, None, :, None] - line_gt[:, None]) ** 2).sum(-1)
+    diff = np.minimum(
+        diff[:, :, 0, 0] + diff[:, :, 1, 1], diff[:, :, 0, 1] + diff[:, :, 1, 0]
+    )
+    choice = np.argmin(diff, 1)
+    dist = np.min(diff, 1)
+    hit = np.zeros(len(line_gt), bool)
+    tp = np.zeros(len(line_pred), float)
+    fp = np.zeros(len(line_pred), float)
+    for i in range(len(line_pred)):
+        if dist[i] < threshold and not hit[choice[i]]:
+            hit[choice[i]] = True
+            tp[i] = 1
+        else:
+            fp[i] = 1
+    return tp, fp
+
+
+def msAP(line_pred, line_gt, threshold):
+    tp, fp = msTPFP(line_pred, line_gt, threshold)
+    tp = np.cumsum(tp) / len(line_gt)
+    fp = np.cumsum(fp) / len(line_gt)
+    return ap(tp, fp)
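+
+# Sketch: msAP takes line endpoints of shape [N, 2, 2]; threshold is compared against the
+# summed squared endpoint distances (random values here, for shapes only):
+#
+#     line_pred = np.random.rand(10, 2, 2) * 128
+#     line_gt = np.random.rand(5, 2, 2) * 128
+#     msAP(line_pred, line_gt, threshold=5)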

+ 9 - 0
lcnn/models/__init__.py

@@ -0,0 +1,9 @@
+# flake8: noqa
+from .hourglass_pose import hg
+from .unet import unet
+from .resnet50_pose import resnet50
+from .resnet50 import resnet501
+from .fasterrcnn_resnet50 import fasterrcnn_resnet50
+# from .dla import dla169, dla102x, dla102x2
+
+__all__ = ['hg', 'unet', 'resnet50', 'resnet501', 'fasterrcnn_resnet50']

+ 0 - 0
lcnn/models/base/__init__.py


+ 48 - 0
lcnn/models/base/base_dataset.py

@@ -0,0 +1,48 @@
+from abc import ABC, abstractmethod
+
+import torch
+from torch import nn, Tensor
+from torch.utils.data import Dataset
+from torch.utils.data.dataset import T_co
+
+from torchvision.transforms import functional as F
+
+class BaseDataset(Dataset, ABC):
+    def __init__(self, dataset_path):
+        self.default_transform = DefaultTransform()
+
+    def __getitem__(self, index) -> T_co:
+        pass
+
+    @abstractmethod
+    def read_target(self, item, lbl_path, extra=None):
+        pass
+
+    """显示数据集指定图片"""
+    @abstractmethod
+    def show(self,idx):
+        pass
+
+    """
+    Show the image with the given file name.
+    """
+
+    @abstractmethod
+    def show_img(self, img_path):
+        pass
+
+class DefaultTransform(nn.Module):
+    def forward(self, img: Tensor) -> Tensor:
+        if not isinstance(img, Tensor):
+            img = F.pil_to_tensor(img)
+        return F.convert_image_dtype(img, torch.float)
+
+    def __repr__(self) -> str:
+        return self.__class__.__name__ + "()"
+
+    def describe(self) -> str:
+        return (
+            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
+            "The images are rescaled to ``[0.0, 1.0]``."
+        )
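+
+# Sketch: DefaultTransform maps a PIL image or an integer tensor to a float tensor in [0, 1]:
+#
+#     t = DefaultTransform()
+#     img = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)
+#     t(img).dtype   # torch.float32, values rescaled into [0, 1]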

+ 120 - 0
lcnn/models/fasterrcnn_resnet50.py

@@ -0,0 +1,120 @@
+import torch
+import torch.nn as nn
+import torchvision
+from typing import Dict, List, Optional, Tuple
+import torch.nn.functional as F
+from torchvision.ops import MultiScaleRoIAlign
+from torchvision.models.detection.faster_rcnn import TwoMLPHead, FastRCNNPredictor
+from torchvision.models.detection.transform import GeneralizedRCNNTransform
+
+
+def get_model(num_classes):
+    # Load Faster R-CNN with a pretrained ResNet-50 FPN backbone
+    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
+
+    # Number of input features of the box classifier
+    in_features = model.roi_heads.box_predictor.cls_score.in_features
+
+    # Replace the predictor head to match the new number of classes
+    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
+
+    return model
+
+
+def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
+    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
+    """
+    Computes the loss for Faster R-CNN.
+
+    Args:
+        class_logits (Tensor)
+        box_regression (Tensor)
+        labels (list[BoxList])
+        regression_targets (Tensor)
+
+    Returns:
+        classification_loss (Tensor)
+        box_loss (Tensor)
+    """
+
+    labels = torch.cat(labels, dim=0)
+    regression_targets = torch.cat(regression_targets, dim=0)
+
+    classification_loss = F.cross_entropy(class_logits, labels)
+
+    # get indices that correspond to the regression targets for
+    # the corresponding ground truth labels, to be used with
+    # advanced indexing
+    sampled_pos_inds_subset = torch.where(labels > 0)[0]
+    labels_pos = labels[sampled_pos_inds_subset]
+    N, num_classes = class_logits.shape
+    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
+
+    box_loss = F.smooth_l1_loss(
+        box_regression[sampled_pos_inds_subset, labels_pos],
+        regression_targets[sampled_pos_inds_subset],
+        beta=1 / 9,
+        reduction="sum",
+    )
+    box_loss = box_loss / labels.numel()
+
+    return classification_loss, box_loss
+
+
+class Fasterrcnn_resnet50(nn.Module):
+    def __init__(self, num_classes=5, num_stacks=1):
+        super(Fasterrcnn_resnet50, self).__init__()
+
+        self.model = get_model(num_classes=num_classes)
+        self.backbone = self.model.backbone
+
+        self.box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=16, sampling_ratio=2)
+
+        out_channels = self.backbone.out_channels
+        resolution = self.box_roi_pool.output_size[0]
+        representation_size = 1024
+        self.box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
+
+        self.box_predictor = FastRCNNPredictor(representation_size, num_classes)
+
+        # Multi-task output heads
+        self.score_layers = nn.ModuleList([
+            nn.Sequential(
+                nn.Conv2d(256, 128, kernel_size=3, padding=1),
+                nn.BatchNorm2d(128),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(128, num_classes, kernel_size=1)
+            )
+            for _ in range(num_stacks)
+        ])
+
+    def forward(self, x, target1, train_or_val, image_shapes=(512, 512)):
+
+        transform = GeneralizedRCNNTransform(min_size=512, max_size=1333, image_mean=[0.485, 0.456, 0.406],
+                                             image_std=[0.229, 0.224, 0.225])
+        images, targets = transform(x, target1)
+        x_ = self.backbone(images.tensors)
+
+        # x_ = self.backbone(x)  # '0'  '1'  '2'  '3'   'pool'
+        # print(f'backbone:{self.backbone}')
+        # print(f'Fasterrcnn_resnet50 x_:{x_}')
+        feature_ = x_['0']  # image feature map
+        outputs = []
+        for score_layer in self.score_layers:
+            output = score_layer(feature_)
+            outputs.append(output)  # one output per stack (multi-head)
+
+        if train_or_val == "training":
+            loss_box = self.model(x, target1)
+            return outputs, feature_, loss_box
+        else:
+            box_all = self.model(x, target1)
+            return outputs, feature_, box_all
+
+
+def fasterrcnn_resnet50(**kwargs):
+    model = Fasterrcnn_resnet50(
+        num_classes=kwargs.get("num_classes", 5),
+        num_stacks=kwargs.get("num_stacks", 1)
+    )
+    return model
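+
+# Construction sketch (the torchvision call above downloads pretrained weights on first use):
+#
+#     model = fasterrcnn_resnet50(num_classes=5, num_stacks=1)
+#     # forward takes (images, targets, mode); MultitaskLearner drives it with
+#     # mode == "training" to get detection losses and returns boxes otherwise.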

+ 201 - 0
lcnn/models/hourglass_pose.py

@@ -0,0 +1,201 @@
+"""
+Hourglass network inserted in the pre-activated Resnet
+Use lr=0.01 for current version
+(c) Yichao Zhou (LCNN)
+(c) YANG, Wei
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = ["HourglassNet", "hg"]
+
+
+class Bottleneck2D(nn.Module):
+    expansion = 2  # channel expansion factor
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck2D, self).__init__()
+
+        self.bn1 = nn.BatchNorm2d(inplanes)
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1)
+        self.bn3 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.bn1(x)
+        out = self.relu(out)
+        out = self.conv1(out)
+
+        out = self.bn2(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+
+        out = self.bn3(out)
+        out = self.relu(out)
+        out = self.conv3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+
+        return out
+
+
+class Hourglass(nn.Module):
+    def __init__(self, block, num_blocks, planes, depth):
+        super(Hourglass, self).__init__()
+        self.depth = depth
+        self.block = block
+        self.hg = self._make_hour_glass(block, num_blocks, planes, depth)
+
+    def _make_residual(self, block, num_blocks, planes):
+        layers = []
+        for i in range(0, num_blocks):
+            layers.append(block(planes * block.expansion, planes))
+        return nn.Sequential(*layers)
+
+    def _make_hour_glass(self, block, num_blocks, planes, depth):
+        hg = []
+        for i in range(depth):
+            res = []
+            for j in range(3):
+                res.append(self._make_residual(block, num_blocks, planes))
+            if i == 0:
+                res.append(self._make_residual(block, num_blocks, planes))
+            hg.append(nn.ModuleList(res))
+        return nn.ModuleList(hg)
+
+    def _hour_glass_forward(self, n, x):
+        up1 = self.hg[n - 1][0](x)
+        low1 = F.max_pool2d(x, 2, stride=2)
+        low1 = self.hg[n - 1][1](low1)
+
+        if n > 1:
+            low2 = self._hour_glass_forward(n - 1, low1)
+        else:
+            low2 = self.hg[n - 1][3](low1)
+        low3 = self.hg[n - 1][2](low2)
+        up2 = F.interpolate(low3, scale_factor=2)
+        out = up1 + up2
+        return out
+
+    def forward(self, x):
+        return self._hour_glass_forward(self.depth, x)
+
+
+class HourglassNet(nn.Module):
+    """Hourglass model from Newell et al ECCV 2016"""
+
+    def __init__(self, block, head, depth, num_stacks, num_blocks, num_classes):
+        super(HourglassNet, self).__init__()
+
+        self.inplanes = 64
+        self.num_feats = 128
+        self.num_stacks = num_stacks
+        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3)
+        self.bn1 = nn.BatchNorm2d(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.layer1 = self._make_residual(block, self.inplanes, 1)
+        self.layer2 = self._make_residual(block, self.inplanes, 1)
+        self.layer3 = self._make_residual(block, self.num_feats, 1)
+        self.maxpool = nn.MaxPool2d(2, stride=2)
+
+        # build hourglass modules
+        ch = self.num_feats * block.expansion
+        # vpts = []
+        hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
+        for i in range(num_stacks):
+            hg.append(Hourglass(block, num_blocks, self.num_feats, depth))
+            res.append(self._make_residual(block, self.num_feats, num_blocks))
+            fc.append(self._make_fc(ch, ch))
+            score.append(head(ch, num_classes))
+            # vpts.append(VptsHead(ch))
+            # vpts.append(nn.Linear(ch, 9))
+            # score.append(nn.Conv2d(ch, num_classes, kernel_size=1))
+            # score[i].bias.data[0] += 4.6
+            # score[i].bias.data[2] += 4.6
+            if i < num_stacks - 1:
+                fc_.append(nn.Conv2d(ch, ch, kernel_size=1))
+                score_.append(nn.Conv2d(num_classes, ch, kernel_size=1))
+        self.hg = nn.ModuleList(hg)
+        self.res = nn.ModuleList(res)
+        self.fc = nn.ModuleList(fc)
+        self.score = nn.ModuleList(score)
+        # self.vpts = nn.ModuleList(vpts)
+        self.fc_ = nn.ModuleList(fc_)
+        self.score_ = nn.ModuleList(score_)
+
+    def _make_residual(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(
+                    self.inplanes,
+                    planes * block.expansion,
+                    kernel_size=1,
+                    stride=stride,
+                )
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def _make_fc(self, inplanes, outplanes):
+        bn = nn.BatchNorm2d(inplanes)
+        conv = nn.Conv2d(inplanes, outplanes, kernel_size=1)
+        return nn.Sequential(conv, bn, self.relu)
+
+    def forward(self, x):
+        out = []
+        # out_vps = []
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+
+        x = self.layer1(x)
+        x = self.maxpool(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+
+        for i in range(self.num_stacks):
+            y = self.hg[i](x)
+            y = self.res[i](y)
+            y = self.fc[i](y)
+            score = self.score[i](y)
+            # pre_vpts = F.adaptive_avg_pool2d(x, (1, 1))
+            # pre_vpts = pre_vpts.reshape(-1, 256)
+            # vpts = self.vpts[i](x)
+            out.append(score)
+            # out_vps.append(vpts)
+            if i < self.num_stacks - 1:
+                fc_ = self.fc_[i](y)
+                score_ = self.score_[i](score)
+                x = x + fc_ + score_
+
+        return out[::-1], y  # , out_vps[::-1]
+
+
+def hg(**kwargs):
+    model = HourglassNet(
+        Bottleneck2D,
+        head=kwargs.get("head", lambda c_in, c_out: nn.Conv2D(c_in, c_out, 1)),
+        depth=kwargs["depth"],
+        num_stacks=kwargs["num_stacks"],
+        num_blocks=kwargs["num_blocks"],
+        num_classes=kwargs["num_classes"],
+    )
+    return model
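+
+# Construction sketch with illustrative hyper-parameters; the head lambda shown is just
+# the built-in default from hg() above:
+#
+#     model = hg(depth=4, num_stacks=2, num_blocks=1, num_classes=9,
+#                head=lambda c_in, c_out: nn.Conv2d(c_in, c_out, kernel_size=1))
+#     outputs, feature = model(torch.randn(1, 3, 512, 512))   # len(outputs) == 2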

+ 276 - 0
lcnn/models/line_vectorizer.py

@@ -0,0 +1,276 @@
+import itertools
+import random
+from collections import defaultdict
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from lcnn.config import M
+
+FEATURE_DIM = 8
+
+
+class LineVectorizer(nn.Module):
+    def __init__(self, backbone):
+        super().__init__()
+        self.backbone = backbone
+
+        lambda_ = torch.linspace(0, 1, M.n_pts0)[:, None]
+        self.register_buffer("lambda_", lambda_)
+        self.do_static_sampling = M.n_stc_posl + M.n_stc_negl > 0
+
+        self.fc1 = nn.Conv2d(256, M.dim_loi, 1)
+        scale_factor = M.n_pts0 // M.n_pts1
+        if M.use_conv:
+            self.pooling = nn.Sequential(
+                nn.MaxPool1d(scale_factor, scale_factor),
+                Bottleneck1D(M.dim_loi, M.dim_loi),
+            )
+            self.fc2 = nn.Sequential(
+                nn.ReLU(inplace=True), nn.Linear(M.dim_loi * M.n_pts1 + FEATURE_DIM, 1)
+            )
+        else:
+            self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
+            self.fc2 = nn.Sequential(
+                nn.Linear(M.dim_loi * M.n_pts1 + FEATURE_DIM, M.dim_fc),
+                nn.ReLU(inplace=True),
+                nn.Linear(M.dim_fc, M.dim_fc),
+                nn.ReLU(inplace=True),
+                nn.Linear(M.dim_fc, 1),
+            )
+        self.loss = nn.BCEWithLogitsLoss(reduction="none")
+
+    def forward(self, input_dict):
+        result = self.backbone(input_dict)
+        h = result["preds"]
+        x = self.fc1(result["feature"])
+        n_batch, n_channel, row, col = x.shape
+
+        xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
+        for i, meta in enumerate(input_dict["meta"]):
+            p, label, feat, jc = self.sample_lines(
+                meta, h["jmap"][i], h["joff"][i], input_dict["mode"]
+            )
+            # print("p.shape:", p.shape)
+            ys.append(label)
+            if input_dict["mode"] == "training" and self.do_static_sampling:
+                p = torch.cat([p, meta["lpre"]])
+                feat = torch.cat([feat, meta["lpre_feat"]])
+                ys.append(meta["lpre_label"])
+                del jc
+            else:
+                jcs.append(jc)
+                ps.append(p)
+            fs.append(feat)
+
+            p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
+            p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2_XY]
+            px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
+            px0 = px.floor().clamp(min=0, max=127)
+            py0 = py.floor().clamp(min=0, max=127)
+            px1 = (px0 + 1).clamp(min=0, max=127)
+            py1 = (py0 + 1).clamp(min=0, max=127)
+            px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()
+
+            # xp: [N_LINE, N_CHANNEL, N_POINT]
+            xp = (
+                (
+                        x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
+                        + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
+                        + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
+                        + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
+                )
+                    .reshape(n_channel, -1, M.n_pts0)
+                    .permute(1, 0, 2)
+            )
+            xp = self.pooling(xp)
+            xs.append(xp)
+            idx.append(idx[-1] + xp.shape[0])
+
+        x, y = torch.cat(xs), torch.cat(ys)
+        f = torch.cat(fs)
+        x = x.reshape(-1, M.n_pts1 * M.dim_loi)
+        x = torch.cat([x, f], 1)
+        x = self.fc2(x.float()).flatten()
+
+        if input_dict["mode"] != "training":
+            p = torch.cat(ps)
+            s = torch.sigmoid(x)
+            b = s > 0.5
+            lines = []
+            score = []
+            for i in range(n_batch):
+                p0 = p[idx[i]: idx[i + 1]]
+                s0 = s[idx[i]: idx[i + 1]]
+                mask = b[idx[i]: idx[i + 1]]
+                p0 = p0[mask]
+                s0 = s0[mask]
+                if len(p0) == 0:
+                    lines.append(torch.zeros([1, M.n_out_line, 2, 2], device=p.device))
+                    score.append(torch.zeros([1, M.n_out_line], device=p.device))
+                else:
+                    arg = torch.argsort(s0, descending=True)
+                    p0, s0 = p0[arg], s0[arg]
+                    lines.append(p0[None, torch.arange(M.n_out_line) % len(p0)])
+                    score.append(s0[None, torch.arange(M.n_out_line) % len(s0)])
+                for j in range(len(jcs[i])):
+                    if len(jcs[i][j]) == 0:
+                        jcs[i][j] = torch.zeros([M.n_out_junc, 2], device=p.device)
+                    jcs[i][j] = jcs[i][j][
+                        None, torch.arange(M.n_out_junc) % len(jcs[i][j])
+                    ]
+            result["preds"]["lines"] = torch.cat(lines)
+            result["preds"]["score"] = torch.cat(score)
+            result["preds"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
+            result["box"] = result['aaa']
+            del result['aaa']
+            if len(jcs[i]) > 1:
+                result["preds"]["junts"] = torch.cat(
+                    [jcs[i][1] for i in range(n_batch)]
+                )
+
+        if input_dict["mode"] != "testing":
+            y = torch.cat(ys)
+            loss = self.loss(x, y)
+            lpos_mask, lneg_mask = y, 1 - y
+            loss_lpos, loss_lneg = loss * lpos_mask, loss * lneg_mask
+
+            def sum_batch(x):
+                xs = [x[idx[i]: idx[i + 1]].sum()[None] for i in range(n_batch)]
+                return torch.cat(xs)
+
+            lpos = sum_batch(loss_lpos) / sum_batch(lpos_mask).clamp(min=1)
+            lneg = sum_batch(loss_lneg) / sum_batch(lneg_mask).clamp(min=1)
+            result["losses"][0]["lpos"] = lpos * M.loss_weight["lpos"]
+            result["losses"][0]["lneg"] = lneg * M.loss_weight["lneg"]
+
+        if input_dict["mode"] == "training":
+            for i in result["aaa"].keys():
+                result["losses"][0][i] = result["aaa"][i]
+            del result["preds"]
+
+        return result
+
+    def sample_lines(self, meta, jmap, joff, mode):
+        with torch.no_grad():
+            junc = meta["junc_coords"]  # [N, 2]
+            jtyp = meta["jtyp"]  # [N]
+            Lpos = meta["line_pos_idx"]
+            Lneg = meta["line_neg_idx"]
+
+            n_type = jmap.shape[0]
+            jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
+            joff = joff.reshape(n_type, 2, -1)
+            max_K = M.n_dyn_junc // n_type
+            N = len(junc)
+            if mode != "training":
+                K = min(int((jmap > M.eval_junc_thres).float().sum().item()), max_K)
+            else:
+                K = min(int(N * 2 + 2), max_K)
+            if K < 2:
+                K = 2
+            device = jmap.device
+
+            # index: [N_TYPE, K]
+            score, index = torch.topk(jmap, k=K)
+            y = (index // 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
+            x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5
+
+            # xy: [N_TYPE, K, 2]
+            xy = torch.cat([y[..., None], x[..., None]], dim=-1)
+            xy_ = xy[..., None, :]
+            del x, y, index
+
+            # dist: [N_TYPE, K, N]
+            dist = torch.sum((xy_ - junc) ** 2, -1)
+            cost, match = torch.min(dist, -1)
+
+            # xy: [N_TYPE * K, 2]
+            # match: [N_TYPE, K]
+            for t in range(n_type):
+                match[t, jtyp[match[t]] != t] = N
+            match[cost > 1.5 * 1.5] = N
+            match = match.flatten()
+
+            _ = torch.arange(n_type * K, device=device)
+            u, v = torch.meshgrid(_, _)
+            u, v = u.flatten(), v.flatten()
+            up, vp = match[u], match[v]
+            label = Lpos[up, vp]
+
+            if mode == "training":
+                c = torch.zeros_like(label, dtype=torch.bool)
+
+                # sample positive lines
+                cdx = label.nonzero().flatten()
+                if len(cdx) > M.n_dyn_posl:
+                    # print("too many positive lines")
+                    perm = torch.randperm(len(cdx), device=device)[: M.n_dyn_posl]
+                    cdx = cdx[perm]
+                c[cdx] = 1
+
+                # sample negative lines
+                cdx = Lneg[up, vp].nonzero().flatten()
+                if len(cdx) > M.n_dyn_negl:
+                    # print("too many negative lines")
+                    perm = torch.randperm(len(cdx), device=device)[: M.n_dyn_negl]
+                    cdx = cdx[perm]
+                c[cdx] = 1
+
+                # sample other (unmatched) lines
+                cdx = torch.randint(len(c), (M.n_dyn_othr,), device=device)
+                c[cdx] = 1
+            else:
+                c = (u < v).flatten()
+
+            # sample lines
+            u, v, label = u[c], v[c], label[c]
+            xy = xy.reshape(n_type * K, 2)
+            xyu, xyv = xy[u], xy[v]
+
+            u2v = xyu - xyv
+            u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
+            feat = torch.cat(
+                [
+                    xyu / 128 * M.use_cood,
+                    xyv / 128 * M.use_cood,
+                    u2v * M.use_slop,
+                    (u[:, None] > K).float(),
+                    (v[:, None] > K).float(),
+                ],
+                1,
+            )
+            line = torch.cat([xyu[:, None], xyv[:, None]], 1)
+
+            xy = xy.reshape(n_type, K, 2)
+            jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
+            return line, label.float(), feat, jcs
+
+
+def non_maximum_suppression(a):
+    ap = F.max_pool2d(a, 3, stride=1, padding=1)
+    mask = (a == ap).float().clamp(min=0.0)
+    return a * mask
+
+
+class Bottleneck1D(nn.Module):
+    def __init__(self, inplanes, outplanes):
+        super(Bottleneck1D, self).__init__()
+
+        planes = outplanes // 2
+        self.op = nn.Sequential(
+            nn.BatchNorm1d(inplanes),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(inplanes, planes, kernel_size=1),
+            nn.BatchNorm1d(planes),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(planes, planes, kernel_size=3, padding=1),
+            nn.BatchNorm1d(planes),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(planes, outplanes, kernel_size=1),
+        )
+
+    def forward(self, x):
+        return x + self.op(x)
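+
+# Shape sketch for Bottleneck1D: it is a residual block, so the input and output channel
+# counts must match for the skip connection (illustrative sizes):
+#
+#     blk = Bottleneck1D(128, 128)
+#     blk(torch.randn(4, 128, 32)).shape   # torch.Size([4, 128, 32])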

+ 118 - 0
lcnn/models/multitask_learner.py

@@ -0,0 +1,118 @@
+from collections import OrderedDict, defaultdict
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from lcnn.config import M
+
+
+class MultitaskHead(nn.Module):
+    def __init__(self, input_channels, num_class):
+        super(MultitaskHead, self).__init__()
+        # print("输入的维度是:", input_channels)
+        m = int(input_channels / 4)
+        heads = []
+        for output_channels in sum(M.head_size, []):
+            heads.append(
+                nn.Sequential(
+                    nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
+                    nn.ReLU(inplace=True),
+                    nn.Conv2d(m, output_channels, kernel_size=1),
+                )
+            )
+        self.heads = nn.ModuleList(heads)
+        assert num_class == sum(sum(M.head_size, []))
+
+    def forward(self, x):
+        return torch.cat([head(x) for head in self.heads], dim=1)
+
+
+class MultitaskLearner(nn.Module):
+    def __init__(self, backbone):
+        super(MultitaskLearner, self).__init__()
+        self.backbone = backbone
+        head_size = M.head_size
+        self.num_class = sum(sum(head_size, []))
+        self.head_off = np.cumsum([sum(h) for h in head_size])
+
+    def forward(self, input_dict):
+        image = input_dict["image"]
+        target_b = input_dict["target_b"]
+        outputs, feature, aaa = self.backbone(image, target_b, input_dict["mode"])  # during training aaa holds the detection losses; during validation it holds the predicted boxes
+
+        result = {"feature": feature}
+        batch, channel, row, col = outputs[0].shape
+        # print(f"batch:{batch}")
+        # print(f"channel:{channel}")
+        # print(f"row:{row}")
+        # print(f"col:{col}")
+
+        T = input_dict["target"].copy()
+        n_jtyp = T["junc_map"].shape[1]
+
+        # switch to CNHW
+        for task in ["junc_map"]:
+            T[task] = T[task].permute(1, 0, 2, 3)
+        for task in ["junc_offset"]:
+            T[task] = T[task].permute(1, 2, 0, 3, 4)
+
+        offset = self.head_off
+        loss_weight = M.loss_weight
+        losses = []
+        for stack, output in enumerate(outputs):
+            output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
+            jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
+            lmap = output[offset[0]: offset[1]].squeeze(0)
+            # print(f"lmap:{lmap.shape}")
+            joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
+            if stack == 0:
+                result["preds"] = {
+                    "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
+                    "lmap": lmap.sigmoid(),
+                    "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
+                }
+                if input_dict["mode"] == "testing":
+                    return result
+
+            L = OrderedDict()
+            L["jmap"] = sum(
+                cross_entropy_loss(jmap[i], T["junc_map"][i]) for i in range(n_jtyp)
+            )
+            L["lmap"] = (
+                F.binary_cross_entropy_with_logits(lmap, T["line_map"], reduction="none")
+                    .mean(2)
+                    .mean(1)
+            )
+            L["joff"] = sum(
+                sigmoid_l1_loss(joff[i, j], T["junc_offset"][i, j], -0.5, T["junc_map"][i])
+                for i in range(n_jtyp)
+                for j in range(2)
+            )
+            for loss_name in L:
+                L[loss_name].mul_(loss_weight[loss_name])
+            losses.append(L)
+        result["losses"] = losses
+        result["aaa"] = aaa
+        return result
+
+
+def l2loss(input, target):
+    return ((target - input) ** 2).mean(2).mean(1)
+
+
+def cross_entropy_loss(logits, positive):
+    nlogp = -F.log_softmax(logits, dim=0)
+    return (positive * nlogp[1] + (1 - positive) * nlogp[0]).mean(2).mean(1)
+
+
+def sigmoid_l1_loss(logits, target, offset=0.0, mask=None):
+    logp = torch.sigmoid(logits) + offset
+    loss = torch.abs(logp - target)
+    if mask is not None:
+        w = mask.mean(2, True).mean(1, True)
+        w[w == 0] = 1
+        loss = loss * (mask / w)
+
+    return loss.mean(2).mean(1)
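+
+# Shape sketch for cross_entropy_loss above: logits are [2, N, H, W] (background /
+# foreground along dim 0), positive is the [N, H, W] ground-truth map, and the result
+# is one loss value per image (illustrative sizes):
+#
+#     logits = torch.randn(2, 4, 128, 128)
+#     positive = torch.randint(0, 2, (4, 128, 128)).float()
+#     cross_entropy_loss(logits, positive).shape   # torch.Size([4])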

+ 182 - 0
lcnn/models/resnet50.py

@@ -0,0 +1,182 @@
+# import torch
+# import torch.nn as nn
+# import torchvision.models as models
+#
+#
+# class ResNet50Backbone(nn.Module):
+#     def __init__(self, num_classes=5, num_stacks=1, pretrained=True):
+#         super(ResNet50Backbone, self).__init__()
+#
+#         # Load pretrained ResNet-50
+#         if pretrained:
+#             resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
+#         else:
+#             resnet = models.resnet50(weights=None)
+#
+#         # Drop the final fully connected layers
+#         self.backbone = nn.Sequential(
+#             resnet.conv1,    # resolution /2, channels 3 -> 64
+#             resnet.bn1,
+#             resnet.relu,
+#             resnet.maxpool,  # resolution /4, channels stay at 64
+#             resnet.layer1,   # stride 1, resolution stays 1/4, channels 64 -> 64*4 = 256
+#             # resnet.layer2, # stride 2, resolution 1/8, channels 256 -> 128*4 = 512
+#             # resnet.layer3, # stride 2, resolution 1/16, channels 512 -> 256*4 = 1024
+#             # resnet.layer4, # stride 2, resolution 1/32, channels 1024 -> 512*4 = 2048
+#         )
+#
+#         # Multi-task output heads
+#         self.score_layers = nn.ModuleList([
+#             nn.Sequential(
+#                 nn.Conv2d(256, 128, kernel_size=3, padding=1),
+#                 nn.BatchNorm2d(128),
+#                 nn.ReLU(inplace=True),
+#                 nn.Conv2d(128, num_classes, kernel_size=1)
+#             )
+#             for _ in range(num_stacks)
+#         ])
+#
+#         # Upsampling layer to keep the output at 128x128
+#         self.upsample = nn.Upsample(
+#             scale_factor=0.25,
+#             mode='bilinear',
+#             align_corners=True
+#         )
+#
+#     def forward(self, x):
+#         # Backbone feature extraction
+#         x = self.backbone(x)
+#
+#         # # Adjust the number of channels
+#         # x = self.channel_adjust(x)
+#         #
+#         # # Upsample to 128x128
+#         # x = self.upsample(x)
+#
+#         # Multi-stack outputs
+#         outputs = []
+#         for score_layer in self.score_layers:
+#             output = score_layer(x)
+#             outputs.append(output)
+#
+#         # Return the first output (if there are multiple stacks)
+#         return outputs, x
+#
+#
+# def resnet50(**kwargs):
+#     model = ResNet50Backbone(
+#         num_classes=kwargs.get("num_classes", 5),
+#         num_stacks=kwargs.get("num_stacks", 1),
+#         pretrained=kwargs.get("pretrained", True)
+#     )
+#     return model
+#
+#
+# __all__ = ["ResNet50Backbone", "resnet50"]
+#
+#
+# # Test the network outputs
+# model = resnet50(num_classes=5, num_stacks=1)
+#
+#
+# # Method 1: pass an image tensor directly
+# x = torch.randn(2, 3, 512, 512)
+# outputs, feature = model(x)
+# print("Outputs length:", len(outputs))
+# print("Output[0] shape:", outputs[0].shape)
+# print("Feature shape:", feature.shape)
+
+import torch
+import torch.nn as nn
+import torchvision
+import torchvision.models as models
+from torchvision.models.detection.backbone_utils import _validate_trainable_layers, _resnet_fpn_extractor
+
+
+class ResNet50Backbone1(nn.Module):
+    def __init__(self, num_classes=5, num_stacks=1, pretrained=True):
+        super(ResNet50Backbone1, self).__init__()
+
+        # Load pretrained ResNet-50 and wrap it with an FPN feature extractor
+        if pretrained:
+            # self.resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
+            trainable_backbone_layers = _validate_trainable_layers(True, None, 5, 3)
+            backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
+            self.resnet = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+        else:
+            self.resnet = models.resnet50(weights=None)
+
+    def forward(self, x):
+        features = self.resnet(x)
+
+        # if self.training:
+        #     proposals, matched_idxs, labels, regression_targets = self.select_training_samples20(proposals, targets)
+        # else:
+        #     labels = None
+        #     regression_targets = None
+        #     matched_idxs = None
+        # box_features = self.box_roi_pool(features, proposals, image_shapes)  # map each ROI to a fixed-size feature
+        # # print(f"box_features:{box_features.shape}")  # [2048, 256, 7, 7] proposals resized to a common shape
+        # box_features = self.box_head(box_features)  #
+        # # print(f"box_features:{box_features.shape}")  # [N, 1024] feature vectors after the box head
+        # class_logits, box_regression = self.box_predictor(box_features)
+        #
+        # result: List[Dict[str, torch.Tensor]] = []
+        # losses = {}
+        # if self.training:
+        #     if labels is None:
+        #         raise ValueError("labels cannot be None")
+        #     if regression_targets is None:
+        #         raise ValueError("regression_targets cannot be None")
+        #     loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
+        #     losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
+        # else:
+        #     boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
+        #     num_images = len(boxes)
+        #     for i in range(num_images):
+        #         result.append(
+        #             {
+        #                 "boxes": boxes[i],
+        #                 "labels": labels[i],
+        #                 "scores": scores[i],
+        #             }
+        #         )
+        #     print(f"boxes:{boxes[0].shape}")
+
+        return features['0']
+
+        # # Multi-stack outputs
+        # outputs = []
+        # for score_layer in self.score_layers:
+        #     output = score_layer(x)
+        #     outputs.append(output)
+        #
+        # # Return the first output (if there are multiple stacks)
+        # return outputs, x
+
+
+def resnet501(**kwargs):
+    model = ResNet50Backbone1(
+        num_classes=kwargs.get("num_classes", 5),
+        num_stacks=kwargs.get("num_stacks", 1),
+        pretrained=kwargs.get("pretrained", True)
+    )
+    return model
+
+
+# __all__ = ["ResNet50Backbone1", "resnet501"]
+#
+# # 测试网络输出
+# model = resnet501(num_classes=5, num_stacks=1)
+#
+# # 方法1:直接传入图像张量
+# x = torch.randn(2, 3, 512, 512)
+# # outputs, feature = model(x)
+# # print("Outputs length:", len(outputs))
+# # print("Output[0] shape:", outputs[0].shape)
+# # print("Feature shape:", feature.shape)
+# feature = model(x)
+# print("Feature shape:", feature.keys())
+
+
+

+ 87 - 0
lcnn/models/resnet50_pose.py

@@ -0,0 +1,87 @@
+import torch
+import torch.nn as nn
+import torchvision.models as models
+
+
+class ResNet50Backbone(nn.Module):
+    def __init__(self, num_classes=5, num_stacks=1, pretrained=True):
+        super(ResNet50Backbone, self).__init__()
+
+        # Load pretrained ResNet-50
+        if pretrained:
+            resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
+        else:
+            resnet = models.resnet50(weights=None)
+
+        # Keep the early stages only; drop the later layers and the final fully connected head
+        self.backbone = nn.Sequential(
+            resnet.conv1,    # resolution 1/2, channels 3 -> 64
+            resnet.bn1,
+            resnet.relu,
+            resnet.maxpool,  # resolution 1/4, channels stay at 64
+            resnet.layer1,   # stride 1, resolution stays 1/4, channels 64 -> 64*4 = 256
+            # resnet.layer2,  # stride 2, resolution 1/8,  channels 256 -> 128*4 = 512
+            # resnet.layer3,  # stride 2, resolution 1/16, channels 512 -> 256*4 = 1024
+            # resnet.layer4,  # stride 2, resolution 1/32, channels 1024 -> 512*4 = 2048
+        )
+
+        # Multi-task output heads
+        self.score_layers = nn.ModuleList([
+            nn.Sequential(
+                nn.Conv2d(256, 128, kernel_size=3, padding=1),
+                nn.BatchNorm2d(128),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(128, num_classes, kernel_size=1)
+            )
+            for _ in range(num_stacks)
+        ])
+
+        # Resampling layer so the output is 128x128 (scale_factor=0.25 downsamples a 512x512 input)
+        self.upsample = nn.Upsample(
+            scale_factor=0.25,
+            mode='bilinear',
+            align_corners=True
+        )
+
+    def forward(self, x):
+        # Backbone feature extraction
+        x = self.backbone(x)
+
+        # # adjust the channel count
+        # x = self.channel_adjust(x)
+        #
+        # # resample to 128x128
+        # x = self.upsample(x)
+
+        # Multi-stack outputs
+        outputs = []
+        for score_layer in self.score_layers:
+            output = score_layer(x)
+            outputs.append(output)
+
+        # Return every stack's output along with the backbone feature map
+        return outputs, x
+
+
+def resnet50(**kwargs):
+    model = ResNet50Backbone(
+        num_classes=kwargs.get("num_classes", 5),
+        num_stacks=kwargs.get("num_stacks", 1),
+        pretrained=kwargs.get("pretrained", True)
+    )
+    return model
+
+
+__all__ = ["ResNet50Backbone", "resnet50"]
+
+
+# Smoke test of the network output; guarded so importing this module has no side effects
+if __name__ == "__main__":
+    model = resnet50(num_classes=5, num_stacks=1)
+
+    # Option 1: pass an image tensor directly
+    x = torch.randn(2, 3, 512, 512)
+    outputs, feature = model(x)
+    print("Outputs length:", len(outputs))
+    print("Output[0] shape:", outputs[0].shape)
+    print("Feature shape:", feature.shape)

+ 126 - 0
lcnn/models/unet.py

@@ -0,0 +1,126 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = ["UNetWithMultipleStacks", "unet"]
+
+
+class DoubleConv(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.double_conv = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        return self.double_conv(x)
+
+
+class UNetWithMultipleStacks(nn.Module):
+    def __init__(self, num_classes, num_stacks=2, base_channels=64):
+        super().__init__()
+        self.num_stacks = num_stacks
+        # Encoder
+        self.enc1 = DoubleConv(3, base_channels)
+        self.pool1 = nn.MaxPool2d(2)
+        self.enc2 = DoubleConv(base_channels, base_channels * 2)
+        self.pool2 = nn.MaxPool2d(2)
+        self.enc3 = DoubleConv(base_channels * 2, base_channels * 4)
+        self.pool3 = nn.MaxPool2d(2)
+        self.enc4 = DoubleConv(base_channels * 4, base_channels * 8)
+        self.pool4 = nn.MaxPool2d(2)
+        self.enc5 = DoubleConv(base_channels * 8, base_channels * 16)
+        self.pool5 = nn.MaxPool2d(2)
+
+        # bottleneck
+        self.bottleneck = DoubleConv(base_channels * 16, base_channels * 32)
+
+        # Decoder
+        self.upconv5 = nn.ConvTranspose2d(base_channels * 32, base_channels * 16, kernel_size=2, stride=2)
+        self.dec5 = DoubleConv(base_channels * 32, base_channels * 16)
+        self.upconv4 = nn.ConvTranspose2d(base_channels * 16, base_channels * 8, kernel_size=2, stride=2)
+        self.dec4 = DoubleConv(base_channels * 16, base_channels * 8)
+        self.upconv3 = nn.ConvTranspose2d(base_channels * 8, base_channels * 4, kernel_size=2, stride=2)
+        self.dec3 = DoubleConv(base_channels * 8, base_channels * 4)
+        self.upconv2 = nn.ConvTranspose2d(base_channels * 4, base_channels * 2, kernel_size=2, stride=2)
+        self.dec2 = DoubleConv(base_channels * 4, base_channels * 2)
+        self.upconv1 = nn.ConvTranspose2d(base_channels * 2, base_channels, kernel_size=2, stride=2)
+        self.dec1 = DoubleConv(base_channels * 2, base_channels)
+
+        # Extra resampling stage: bring 512x512 down to 128x128 (scale_factor=0.25)
+        self.final_upsample = nn.Sequential(
+            nn.Conv2d(base_channels, base_channels, kernel_size=3, padding=1),
+            nn.BatchNorm2d(base_channels),
+            nn.ReLU(inplace=True),
+            nn.Upsample(scale_factor=0.25, mode='bilinear', align_corners=True)
+        )
+
+        # score_layers expect the 256 channels produced by channel_adjust
+        self.score_layers = nn.ModuleList([
+            nn.Sequential(
+                nn.Conv2d(256, 128, kernel_size=3, padding=1),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(128, num_classes, kernel_size=1)
+            )
+            for _ in range(num_stacks)
+        ])
+
+        self.channel_adjust = nn.Conv2d(base_channels, 256, kernel_size=1)
+
+    def forward(self, x):
+        # Encoding path
+        enc1 = self.enc1(x)
+        enc2 = self.enc2(self.pool1(enc1))
+        enc3 = self.enc3(self.pool2(enc2))
+        enc4 = self.enc4(self.pool3(enc3))
+        enc5 = self.enc5(self.pool4(enc4))
+
+        # Bottleneck
+        bottleneck = self.bottleneck(self.pool5(enc5))
+
+        # Decoding path
+        dec5 = self.upconv5(bottleneck)
+        dec5 = torch.cat([dec5, enc5], dim=1)
+        dec5 = self.dec5(dec5)
+
+        dec4 = self.upconv4(dec5)
+        dec4 = torch.cat([dec4, enc4], dim=1)
+        dec4 = self.dec4(dec4)
+
+        dec3 = self.upconv3(dec4)
+        dec3 = torch.cat([dec3, enc3], dim=1)
+        dec3 = self.dec3(dec3)
+
+        dec2 = self.upconv2(dec3)
+        dec2 = torch.cat([dec2, enc2], dim=1)
+        dec2 = self.dec2(dec2)
+
+        dec1 = self.upconv1(dec2)
+        dec1 = torch.cat([dec1, enc1], dim=1)
+        dec1 = self.dec1(dec1)
+
+        # Extra resampling so the output resolution is 128x128
+        dec1 = self.final_upsample(dec1)
+        # Adjust the channel count to 256
+        dec1 = self.channel_adjust(dec1)
+        # Multi-stack outputs
+        outputs = []
+        for score_layer in self.score_layers:
+            output = score_layer(dec1)
+            outputs.append(output)
+
+        return outputs[::-1], dec1
+
+
+def unet(**kwargs):
+    model = UNetWithMultipleStacks(
+        num_classes=kwargs["num_classes"],
+        num_stacks=kwargs.get("num_stacks", 2),
+        base_channels=kwargs.get("base_channels", 64)
+    )
+    return model
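A minimal usage sketch (not part of the commit) checking the shapes the UNet factory produces for a 512x512 input, assuming the module path lcnn/models/unet.py from the diff: each stack head emits num_classes maps at 128x128 after the final 0.25 resample, and the shared feature map has 256 channels.

import torch
from lcnn.models.unet import unet

model = unet(num_classes=5, num_stacks=2)
model.eval()
x = torch.randn(1, 3, 512, 512)
with torch.no_grad():
    outputs, feature = model(x)
print(len(outputs))         # 2 (one head per stack)
print(outputs[0].shape)     # torch.Size([1, 5, 128, 128])
print(feature.shape)        # torch.Size([1, 256, 128, 128])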

+ 77 - 0
lcnn/postprocess.py

@@ -0,0 +1,77 @@
+import numpy as np
+
+
+def pline(x1, y1, x2, y2, x, y):
+    px = x2 - x1
+    py = y2 - y1
+    dd = px * px + py * py
+    u = ((x - x1) * px + (y - y1) * py) / max(1e-9, float(dd))
+    dx = x1 + u * px - x
+    dy = y1 + u * py - y
+    return dx * dx + dy * dy
+
+
+def psegment(x1, y1, x2, y2, x, y):
+    px = x2 - x1
+    py = y2 - y1
+    dd = px * px + py * py
+    u = max(min(((x - x1) * px + (y - y1) * py) / max(1e-9, float(dd)), 1), 0)
+    dx = x1 + u * px - x
+    dy = y1 + u * py - y
+    return dx * dx + dy * dy
+
+
+def plambda(x1, y1, x2, y2, x, y):
+    px = x2 - x1
+    py = y2 - y1
+    dd = px * px + py * py
+    return ((x - x1) * px + (y - y1) * py) / max(1e-9, float(dd))
+
+
+def postprocess(lines, scores, threshold=0.01, tol=1e9, do_clip=False):
+    nlines, nscores = [], []
+    for (p, q), score in zip(lines, scores):
+        start, end = 0, 1
+        for a, b in nlines:
+            if (
+                min(
+                    max(pline(*p, *q, *a), pline(*p, *q, *b)),
+                    max(pline(*a, *b, *p), pline(*a, *b, *q)),
+                )
+                > threshold ** 2
+            ):
+                continue
+            lambda_a = plambda(*p, *q, *a)
+            lambda_b = plambda(*p, *q, *b)
+            if lambda_a > lambda_b:
+                lambda_a, lambda_b = lambda_b, lambda_a
+            lambda_a -= tol
+            lambda_b += tol
+
+            # case 1: skip (if not do_clip)
+            if start < lambda_a and lambda_b < end:
+                continue
+
+            # not intersect
+            if lambda_b < start or lambda_a > end:
+                continue
+
+            # cover
+            if lambda_a <= start and end <= lambda_b:
+                start = 10  # force start >= end so the covered segment is dropped
+                break
+
+            # case 2 & 3:
+            if lambda_a <= start and start <= lambda_b:
+                start = lambda_b
+            if lambda_a <= end and end <= lambda_b:
+                end = lambda_a
+
+            if start >= end:
+                break
+
+        if start >= end:
+            continue
+        nlines.append(np.array([p + (q - p) * start, p + (q - p) * end]))
+        nscores.append(score)
+    return np.array(nlines), np.array(nscores)
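A minimal sketch (not part of the commit) showing how postprocess() drops a shorter segment that is covered by a longer, nearly collinear one. tol=0 mirrors the call in trainer.show_line(); threshold is in the same pixel units as the endpoints.

import numpy as np
from lcnn.postprocess import postprocess

lines = np.array([
    [[0.0, 0.0], [0.0, 100.0]],   # long vertical segment
    [[0.0, 10.0], [0.0, 90.0]],   # shorter segment fully covered by the first
])
scores = np.array([0.9, 0.5])

nlines, nscores = postprocess(lines, scores, threshold=2.0, tol=0, do_clip=False)
print(nlines.shape, nscores)      # expected: (1, 2, 2) [0.9] - only the long segment survives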

+ 429 - 0
lcnn/trainer.py

@@ -0,0 +1,429 @@
+import atexit
+import os
+import os.path as osp
+import shutil
+import signal
+import subprocess
+import threading
+import time
+from timeit import default_timer as timer
+
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn.functional as F
+from skimage import io
+from tensorboardX import SummaryWriter
+
+from lcnn.config import C, M
+from lcnn.utils import recursive_to
+import matplotlib
+
+from 冻结参数训练 import verify_freeze_params
+
+from torchvision.utils import draw_bounding_boxes
+from torchvision import transforms
+from .postprocess import postprocess
+
+os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+
+
+# matplotlib.use('Agg')  # use a headless (non-interactive) backend
+
+
+# Plot the input image, predicted boxes and post-processed lines to TensorBoard
+def show_line(img, pred, epoch, writer):
+    im = img.permute(1, 2, 0)
+    writer.add_image("ori", im, epoch, dataformats="HWC")
+
+    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred["box"][0]["boxes"],
+                                      colors="yellow", width=1)
+    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
+
+    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+    H = pred["preds"]
+    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
+    scores = H["score"][0].cpu().numpy()
+    for i in range(1, len(lines)):
+        if (lines[i] == lines[0]).all():
+            lines = lines[:i]
+            scores = scores[:i]
+            break
+
+    # postprocess lines to remove overlapped lines
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+
+    for i, t in enumerate([0.7]):
+        plt.gca().set_axis_off()
+        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+        plt.margins(0, 0)
+        for (a, b), s in zip(nlines, nscores):
+            if s < t:
+                continue
+            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+            plt.scatter(a[1], a[0], **PLTOPTS)
+            plt.scatter(b[1], b[0], **PLTOPTS)
+        plt.gca().xaxis.set_major_locator(plt.NullLocator())
+        plt.gca().yaxis.set_major_locator(plt.NullLocator())
+        plt.imshow(im)
+        plt.tight_layout()
+        fig = plt.gcf()
+        fig.canvas.draw()
+        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
+            fig.canvas.get_width_height()[::-1] + (3,))
+        plt.close()
+        img2 = transforms.ToTensor()(image_from_plot)
+
+        writer.add_image("output", img2, epoch)
+
+class Trainer(object):
+    def __init__(self, device, model, optimizer, train_loader, val_loader, out):
+        self.device = device
+
+        self.model = model
+        self.optim = optimizer
+
+        self.train_loader = train_loader
+        self.val_loader = val_loader
+        self.batch_size = C.model.batch_size
+
+        self.validation_interval = C.io.validation_interval
+
+        self.out = out
+        if not osp.exists(self.out):
+            os.makedirs(self.out)
+
+        # self.run_tensorboard()
+        self.writer = SummaryWriter('logs/')
+        time.sleep(1)
+
+        self.epoch = 0
+        self.iteration = 0
+        self.max_epoch = C.optim.max_epoch
+        self.lr_decay_epoch = C.optim.lr_decay_epoch
+        self.num_stacks = C.model.num_stacks
+        self.mean_loss = self.best_mean_loss = 1e1000
+
+        self.loss_labels = None
+        self.avg_metrics = None
+        self.metrics = np.zeros(0)
+
+        self.show_line = show_line
+
+    # def run_tensorboard(self):
+    #     board_out = osp.join(self.out, "tensorboard")
+    #     if not osp.exists(board_out):
+    #         os.makedirs(board_out)
+    #     self.writer = SummaryWriter(board_out)
+    #     os.environ["CUDA_VISIBLE_DEVICES"] = ""
+    #     p = subprocess.Popen(
+    #         ["tensorboard", f"--logdir={board_out}", f"--port={C.io.tensorboard_port}"]
+    #     )
+    #
+    #     def killme():
+    #         os.kill(p.pid, signal.SIGTERM)
+    #
+    #     atexit.register(killme)
+
+    def _loss(self, result):
+        losses = result["losses"]
+        # Keep the loss labels defined here: changing the loss only requires
+        # editing this function.
+        if self.loss_labels is None:
+            self.loss_labels = ["sum"] + list(losses[0].keys())
+            self.metrics = np.zeros([self.num_stacks, len(self.loss_labels)])
+            print()
+            print(
+                "| ".join(
+                    ["progress "]
+                    + list(map("{:7}".format, self.loss_labels))
+                    + ["speed"]
+                )
+            )
+            with open(f"{self.out}/loss.csv", "a") as fout:
+                print(",".join(["progress"] + self.loss_labels), file=fout)
+
+        total_loss = 0
+        for i in range(self.num_stacks):
+            for j, name in enumerate(self.loss_labels):
+                if name == "sum":
+                    continue
+                if name not in losses[i]:
+                    assert i != 0
+                    continue
+                loss = losses[i][name].mean()
+                self.metrics[i, 0] += loss.item()
+                self.metrics[i, j] += loss.item()
+                total_loss += loss
+        return total_loss
+
+
+    def validate(self):
+        tprint("Running validation...", " " * 75)
+        training = self.model.training
+        self.model.eval()
+
+        # viz = osp.join(self.out, "viz", f"{self.iteration * M.batch_size_eval:09d}")
+        # npz = osp.join(self.out, "npz", f"{self.iteration * M.batch_size_eval:09d}")
+        # osp.exists(viz) or os.makedirs(viz)
+        # osp.exists(npz) or os.makedirs(npz)
+
+        total_loss = 0
+        self.metrics[...] = 0
+        with torch.no_grad():
+            for batch_idx, (image, meta, target, target_b) in enumerate(self.val_loader):
+                input_dict = {
+                    "image": recursive_to(image, self.device),
+                    "meta": recursive_to(meta, self.device),
+                    "target": recursive_to(target, self.device),
+                    "target_b": recursive_to(target_b, self.device),
+                    "mode": "validation",
+                }
+                result = self.model(input_dict)
+                # print(f'image:{image.shape}')
+                # print(result["box"])
+
+                # total_loss += self._loss(result)
+
+                print(f'self.epoch:{self.epoch}')
+                # print(result.keys())
+                self.show_line(image[0], result, self.epoch, self.writer)
+
+
+
+                # H = result["preds"]
+                # for i in range(H["jmap"].shape[0]):
+                #     index = batch_idx * M.batch_size_eval + i
+                #     np.savez(
+                #         f"{npz}/{index:06}.npz",
+                #         **{k: v[i].cpu().numpy() for k, v in H.items()},
+                #     )
+                #     if index >= 20:
+                #         continue
+                #     self._plot_samples(i, index, H, meta, target, f"{viz}/{index:06}")
+
+        # self._write_metrics(len(self.val_loader), total_loss, "validation", True)
+        # self.mean_loss = total_loss / len(self.val_loader)
+
+        torch.save(
+            {
+                "iteration": self.iteration,
+                "arch": self.model.__class__.__name__,
+                "optim_state_dict": self.optim.state_dict(),
+                "model_state_dict": self.model.state_dict(),
+                "best_mean_loss": self.best_mean_loss,
+            },
+            osp.join(self.out, "checkpoint_latest.pth"),
+        )
+        # shutil.copy(
+        #     osp.join(self.out, "checkpoint_latest.pth"),
+        #     osp.join(npz, "checkpoint.pth"),
+        # )
+        if self.mean_loss < self.best_mean_loss:
+            self.best_mean_loss = self.mean_loss
+            shutil.copy(
+                osp.join(self.out, "checkpoint_latest.pth"),
+                osp.join(self.out, "checkpoint_best.pth"),
+            )
+
+        if training:
+            self.model.train()
+
+    @staticmethod
+    def verify_freeze_params(model, freeze_config):
+        """
+        Verify that parameter freezing actually took effect.
+        """
+        print("\n===== Verifying Parameter Freezing =====")
+
+        for name, module in model.named_children():
+            if name in freeze_config:
+                if freeze_config[name]:
+                    print(f"\nChecking module: {name}")
+                    for param_name, param in module.named_parameters():
+                        print(f"  {param_name}: requires_grad = {param.requires_grad}")
+
+            # Special handling for fc2 submodules
+            if name == 'fc2' and 'fc2_submodules' in freeze_config:
+                for subname, submodule in module.named_children():
+                    if subname in freeze_config['fc2_submodules']:
+                        if freeze_config['fc2_submodules'][subname]:
+                            print(f"\nChecking fc2 submodule: {subname}")
+                            for param_name, param in submodule.named_parameters():
+                                print(f"  {param_name}: requires_grad = {param.requires_grad}")
+
+    def train_epoch(self):
+        self.model.train()
+
+        time = timer()
+        for batch_idx, (image, meta, target, target_b) in enumerate(self.train_loader):
+            self.optim.zero_grad()
+            self.metrics[...] = 0
+
+            input_dict = {
+                "image": recursive_to(image, self.device),
+                "meta": recursive_to(meta, self.device),
+                "target": recursive_to(target, self.device),
+                "target_b": recursive_to(target_b, self.device),
+                "mode": "training",
+            }
+            result = self.model(input_dict)
+
+            loss = self._loss(result)
+            if np.isnan(loss.item()):
+                raise ValueError("loss is nan while training")
+            loss.backward()
+            self.optim.step()
+
+            if self.avg_metrics is None:
+                self.avg_metrics = self.metrics
+            else:
+                self.avg_metrics = self.avg_metrics * 0.9 + self.metrics * 0.1
+            self.iteration += 1
+            self._write_metrics(1, loss.item(), "training", do_print=False)
+
+            if self.iteration % 4 == 0:
+                tprint(
+                    f"{self.epoch:03}/{self.iteration * self.batch_size // 1000:04}k| "
+                    + "| ".join(map("{:.5f}".format, self.avg_metrics[0]))
+                    + f"| {4 * self.batch_size / (timer() - time):04.1f} "
+                )
+                time = timer()
+            num_images = self.batch_size * self.iteration
+            # if num_images % self.validation_interval == 0 or num_images == 4:
+            #     self.validate()
+            #     time = timer()
+        self.validate()
+        # verify_freeze_params()
+
+    def _write_metrics(self, size, total_loss, prefix, do_print=False):
+        for i, metrics in enumerate(self.metrics):
+            for label, metric in zip(self.loss_labels, metrics):
+                self.writer.add_scalar(
+                    f"{prefix}/{i}/{label}", metric / size, self.iteration
+                )
+            if i == 0 and do_print:
+                csv_str = (
+                        f"{self.epoch:03}/{self.iteration * self.batch_size:07},"
+                        + ",".join(map("{:.11f}".format, metrics / size))
+                )
+                prt_str = (
+                        f"{self.epoch:03}/{self.iteration * self.batch_size // 1000:04}k| "
+                        + "| ".join(map("{:.5f}".format, metrics / size))
+                )
+                with open(f"{self.out}/loss.csv", "a") as fout:
+                    print(csv_str, file=fout)
+                pprint(prt_str, " " * 7)
+        self.writer.add_scalar(
+            f"{prefix}/total_loss", total_loss / size, self.iteration
+        )
+        return total_loss
+
+    def _plot_samples(self, i, index, result, meta, target, prefix):
+        fn = self.val_loader.dataset.filelist[index][:-10].replace("_a0", "") + ".png"
+        img = io.imread(fn)
+        imshow(img), plt.savefig(f"{prefix}_img.jpg"), plt.close()
+
+        mask_result = result["jmap"][i].cpu().numpy()
+        mask_target = target["jmap"][i].cpu().numpy()
+        for ch, (ia, ib) in enumerate(zip(mask_target, mask_result)):
+            imshow(ia), plt.savefig(f"{prefix}_mask_{ch}a.jpg"), plt.close()
+            imshow(ib), plt.savefig(f"{prefix}_mask_{ch}b.jpg"), plt.close()
+
+        line_result = result["lmap"][i].cpu().numpy()
+        line_target = target["lmap"][i].cpu().numpy()
+        imshow(line_target), plt.savefig(f"{prefix}_line_a.jpg"), plt.close()
+        imshow(line_result), plt.savefig(f"{prefix}_line_b.jpg"), plt.close()
+
+        def draw_vecl(lines, sline, juncs, junts, fn):
+            imshow(img)
+            if len(lines) > 0 and not (lines[0] == 0).all():
+                for i, ((a, b), s) in enumerate(zip(lines, sline)):
+                    if i > 0 and (lines[i] == lines[0]).all():
+                        break
+                    plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4)
+            if not (juncs[0] == 0).all():
+                for i, j in enumerate(juncs):
+                    if i > 0 and (i == juncs[0]).all():
+                        break
+                    plt.scatter(j[1], j[0], c="red", s=64, zorder=100)
+            if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
+                for i, j in enumerate(junts):
+                    if i > 0 and (i == junts[0]).all():
+                        break
+                    plt.scatter(j[1], j[0], c="blue", s=64, zorder=100)
+            plt.savefig(fn), plt.close()
+
+        junc = meta[i]["junc"].cpu().numpy() * 4
+        jtyp = meta[i]["jtyp"].cpu().numpy()
+        juncs = junc[jtyp == 0]
+        junts = junc[jtyp == 1]
+        rjuncs = result["juncs"][i].cpu().numpy() * 4
+        rjunts = None
+        if "junts" in result:
+            rjunts = result["junts"][i].cpu().numpy() * 4
+
+        lpre = meta[i]["lpre"].cpu().numpy() * 4
+        vecl_target = meta[i]["lpre_label"].cpu().numpy()
+        vecl_result = result["lines"][i].cpu().numpy() * 4
+        score = result["score"][i].cpu().numpy()
+        lpre = lpre[vecl_target == 1]
+
+        draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, f"{prefix}_vecl_a.jpg")
+        draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg")
+
+    def train(self):
+        plt.rcParams["figure.figsize"] = (24, 24)
+        # if self.iteration == 0:
+        #     self.validate()
+        epoch_size = len(self.train_loader)
+        start_epoch = self.iteration // epoch_size
+
+        for self.epoch in range(start_epoch, self.max_epoch):
+            print(f"Epoch {self.epoch}/{C.optim.max_epoch} - Iteration {self.iteration}/{epoch_size}")
+            if self.epoch == self.lr_decay_epoch:
+                self.optim.param_groups[0]["lr"] /= 10
+            self.train_epoch()
+
+
+cmap = plt.get_cmap("jet")
+norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+sm.set_array([])
+
+
+def c(x):
+    return sm.to_rgba(x)
+
+
+def imshow(im):
+    plt.close()
+    plt.tight_layout()
+    plt.imshow(im)
+    plt.colorbar(sm, fraction=0.046)
+    plt.xlim([0, im.shape[0]])
+    plt.ylim([im.shape[0], 0])
+
+
+def tprint(*args):
+    """Temporarily prints things on the screen"""
+    print("\r", end="")
+    print(*args, end="")
+
+
+def pprint(*args):
+    """Permanently prints things on the screen"""
+    print("\r", end="")
+    print(*args)
+
+
+def _launch_tensorboard(board_out, port, out):
+    os.environ["CUDA_VISIBLE_DEVICES"] = ""
+    p = subprocess.Popen(["tensorboard", f"--logdir={board_out}", f"--port={port}"])
+
+    def kill():
+        os.kill(p.pid, signal.SIGTERM)
+
+    atexit.register(kill)
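A minimal sketch (not part of the commit) of how a checkpoint written by Trainer.validate() could be restored; the dict keys mirror the torch.save() call above, while the path and the resume wiring are assumptions.

import torch

def load_checkpoint(path, model, optimizer, device="cpu"):
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optim_state_dict"])
    # iteration / best_mean_loss can be fed back into a Trainer to resume training
    return ckpt["iteration"], ckpt["best_mean_loss"]

# usage (model/optimizer assumed to be built the same way as during training):
# iteration, best = load_checkpoint(osp.join(out_dir, "checkpoint_latest.pth"), model, optim)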

+ 101 - 0
lcnn/utils.py

@@ -0,0 +1,101 @@
+import math
+import os.path as osp
+import multiprocessing
+from timeit import default_timer as timer
+
+import numpy as np
+import torch
+import matplotlib.pyplot as plt
+
+
+class benchmark(object):
+    def __init__(self, msg, enable=True, fmt="%0.3g"):
+        self.msg = msg
+        self.fmt = fmt
+        self.enable = enable
+
+    def __enter__(self):
+        if self.enable:
+            self.start = timer()
+        return self
+
+    def __exit__(self, *args):
+        if self.enable:
+            t = timer() - self.start
+            print(("%s : " + self.fmt + " seconds") % (self.msg, t))
+            self.time = t
+
+
+def quiver(x, y, ax):
+    ax.set_xlim(0, x.shape[1])
+    ax.set_ylim(x.shape[0], 0)
+    ax.quiver(
+        x,
+        y,
+        units="xy",
+        angles="xy",
+        scale_units="xy",
+        scale=1,
+        minlength=0.01,
+        width=0.1,
+        color="b",
+    )
+
+
+def recursive_to(input, device):
+    if isinstance(input, torch.Tensor):
+        return input.to(device)
+    if isinstance(input, dict):
+        for name in input:
+            if isinstance(input[name], torch.Tensor):
+                input[name] = input[name].to(device)
+        return input
+    if isinstance(input, list):
+        for i, item in enumerate(input):
+            input[i] = recursive_to(item, device)
+        return input
+    assert False
+
+
+def np_softmax(x, axis=0):
+    """Compute softmax values for each sets of scores in x."""
+    e_x = np.exp(x - np.max(x))
+    return e_x / e_x.sum(axis=axis, keepdims=True)
+
+
+def argsort2d(arr):
+    return np.dstack(np.unravel_index(np.argsort(arr.ravel()), arr.shape))[0]
+
+
+def __parallel_handle(f, q_in, q_out):
+    while True:
+        i, x = q_in.get()
+        if i is None:
+            break
+        q_out.put((i, f(x)))
+
+
+def parmap(f, X, nprocs=multiprocessing.cpu_count(), progress_bar=lambda x: x):
+    if nprocs == 0:
+        nprocs = multiprocessing.cpu_count()
+    q_in = multiprocessing.Queue(1)
+    q_out = multiprocessing.Queue()
+
+    proc = [
+        multiprocessing.Process(target=__parallel_handle, args=(f, q_in, q_out))
+        for _ in range(nprocs)
+    ]
+    for p in proc:
+        p.daemon = True
+        p.start()
+
+    try:
+        sent = [q_in.put((i, x)) for i, x in enumerate(X)]
+        [q_in.put((None, None)) for _ in range(nprocs)]
+        res = [q_out.get() for _ in progress_bar(range(len(sent)))]
+        [p.join() for p in proc]
+    except KeyboardInterrupt:
+        q_in.close()
+        q_out.close()
+        raise
+    return [x for i, x in sorted(res)]
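A minimal sketch (not part of the commit) using two of the helpers above, assuming the package is importable: benchmark times a block, and parmap maps a function over inputs with worker processes. The __main__ guard matters when the multiprocessing start method is spawn (Windows/macOS).

from lcnn.utils import benchmark, parmap

def square(x):
    # must be a top-level function so worker processes can pickle it
    return x * x

if __name__ == "__main__":
    with benchmark("parmap over 100 ints"):
        out = parmap(square, range(100), nprocs=4)
    print(out[:5])   # [0, 1, 4, 9, 16]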

BIN
libs/vision_libs/_C.pyd


+ 103 - 0
libs/vision_libs/__init__.py

@@ -0,0 +1,103 @@
+import os
+import warnings
+from modulefinder import Module
+
+import torch
+from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils
+
+from .extension import _HAS_OPS
+
+try:
+    from .version import __version__  # noqa: F401
+except ImportError:
+    pass
+
+
+# Check if torchvision is being imported within the root folder
+if not _HAS_OPS and os.path.dirname(os.path.realpath(__file__)) == os.path.join(
+    os.path.realpath(os.getcwd()), "torchvision"
+):
+    message = (
+        "You are importing torchvision within its own root folder ({}). "
+        "This is not expected to work and may give errors. Please exit the "
+        "torchvision project source and relaunch your python interpreter."
+    )
+    warnings.warn(message.format(os.getcwd()))
+
+_image_backend = "PIL"
+
+_video_backend = "pyav"
+
+
+def set_image_backend(backend):
+    """
+    Specifies the package used to load images.
+
+    Args:
+        backend (string): Name of the image backend. one of {'PIL', 'accimage'}.
+            The :mod:`accimage` package uses the Intel IPP library. It is
+            generally faster than PIL, but does not support as many operations.
+    """
+    global _image_backend
+    if backend not in ["PIL", "accimage"]:
+        raise ValueError(f"Invalid backend '{backend}'. Options are 'PIL' and 'accimage'")
+    _image_backend = backend
+
+
+def get_image_backend():
+    """
+    Gets the name of the package used to load images
+    """
+    return _image_backend
+
+
+def set_video_backend(backend):
+    """
+    Specifies the package used to decode videos.
+
+    Args:
+        backend (string): Name of the video backend. one of {'pyav', 'video_reader'}.
+            The :mod:`pyav` package uses the 3rd party PyAv library. It is a Pythonic
+            binding for the FFmpeg libraries.
+            The :mod:`video_reader` package includes a native C++ implementation on
+            top of FFMPEG libraries, and a python API of TorchScript custom operator.
+            It generally decodes faster than :mod:`pyav`, but is perhaps less robust.
+
+    .. note::
+        Building with FFMPEG is disabled by default in the latest `main`. If you want to use the 'video_reader'
+        backend, please compile torchvision from source.
+    """
+    global _video_backend
+    if backend not in ["pyav", "video_reader", "cuda"]:
+        raise ValueError("Invalid video backend '%s'. Options are 'pyav', 'video_reader' and 'cuda'" % backend)
+    if backend == "video_reader" and not io._HAS_VIDEO_OPT:
+        # TODO: better messages
+        message = "video_reader video backend is not available. Please compile torchvision from source and try again"
+        raise RuntimeError(message)
+    elif backend == "cuda" and not io._HAS_GPU_VIDEO_DECODER:
+        # TODO: better messages
+        message = "cuda video backend is not available."
+        raise RuntimeError(message)
+    else:
+        _video_backend = backend
+
+
+def get_video_backend():
+    """
+    Returns the currently active video backend used to decode videos.
+
+    Returns:
+        str: Name of the video backend. one of {'pyav', 'video_reader'}.
+    """
+
+    return _video_backend
+
+
+def _is_tracing():
+    return torch._C._get_tracing_state()
+
+
+def disable_beta_transforms_warning():
+    # Noop, only exists to avoid breaking existing code.
+    # See https://github.com/pytorch/vision/issues/7896
+    pass

+ 50 - 0
libs/vision_libs/_internally_replaced_utils.py

@@ -0,0 +1,50 @@
+import importlib.machinery
+import os
+
+from torch.hub import _get_torch_home
+
+
+_HOME = os.path.join(_get_torch_home(), "datasets", "vision")
+_USE_SHARDED_DATASETS = False
+
+
+def _download_file_from_remote_location(fpath: str, url: str) -> None:
+    pass
+
+
+def _is_remote_location_available() -> bool:
+    return False
+
+
+try:
+    from torch.hub import load_state_dict_from_url  # noqa: 401
+except ImportError:
+    from torch.utils.model_zoo import load_url as load_state_dict_from_url  # noqa: 401
+
+
+def _get_extension_path(lib_name):
+
+    lib_dir = os.path.dirname(__file__)
+    if os.name == "nt":
+        # Register the main torchvision library location on the default DLL path
+        import ctypes
+
+        kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
+        with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
+        prev_error_mode = kernel32.SetErrorMode(0x0001)
+
+        if with_load_library_flags:
+            kernel32.AddDllDirectory.restype = ctypes.c_void_p
+
+        os.add_dll_directory(lib_dir)
+
+        kernel32.SetErrorMode(prev_error_mode)
+
+    loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES)
+
+    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
+    ext_specs = extfinder.find_spec(lib_name)
+    if ext_specs is None:
+        raise ImportError
+
+    return ext_specs.origin

+ 225 - 0
libs/vision_libs/_meta_registrations.py

@@ -0,0 +1,225 @@
+import functools
+
+import torch
+import torch._custom_ops
+import torch.library
+
+# Ensure that torch.ops.torchvision is visible
+import torchvision.extension  # noqa: F401
+
+
+@functools.lru_cache(None)
+def get_meta_lib():
+    return torch.library.Library("torchvision", "IMPL", "Meta")
+
+
+def register_meta(op_name, overload_name="default"):
+    def wrapper(fn):
+        if torchvision.extension._has_ops():
+            get_meta_lib().impl(getattr(getattr(torch.ops.torchvision, op_name), overload_name), fn)
+        return fn
+
+    return wrapper
+
+
+@register_meta("roi_align")
+def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
+    torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
+    torch._check(
+        input.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for input to have the same type as tensor for rois; "
+            f"but type {input.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    num_rois = rois.size(0)
+    channels = input.size(1)
+    return input.new_empty((num_rois, channels, pooled_height, pooled_width))
+
+
+@register_meta("_roi_align_backward")
+def meta_roi_align_backward(
+    grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned
+):
+    torch._check(
+        grad.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for grad to have the same type as tensor for rois; "
+            f"but type {grad.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    return grad.new_empty((batch_size, channels, height, width))
+
+
+@register_meta("ps_roi_align")
+def meta_ps_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio):
+    torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
+    torch._check(
+        input.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for input to have the same type as tensor for rois; "
+            f"but type {input.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    channels = input.size(1)
+    torch._check(
+        channels % (pooled_height * pooled_width) == 0,
+        "input channels must be a multiple of pooling height * pooling width",
+    )
+
+    num_rois = rois.size(0)
+    out_size = (num_rois, channels // (pooled_height * pooled_width), pooled_height, pooled_width)
+    return input.new_empty(out_size), torch.empty(out_size, dtype=torch.int32, device="meta")
+
+
+@register_meta("_ps_roi_align_backward")
+def meta_ps_roi_align_backward(
+    grad,
+    rois,
+    channel_mapping,
+    spatial_scale,
+    pooled_height,
+    pooled_width,
+    sampling_ratio,
+    batch_size,
+    channels,
+    height,
+    width,
+):
+    torch._check(
+        grad.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for grad to have the same type as tensor for rois; "
+            f"but type {grad.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    return grad.new_empty((batch_size, channels, height, width))
+
+
+@register_meta("roi_pool")
+def meta_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width):
+    torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
+    torch._check(
+        input.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for input to have the same type as tensor for rois; "
+            f"but type {input.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    num_rois = rois.size(0)
+    channels = input.size(1)
+    out_size = (num_rois, channels, pooled_height, pooled_width)
+    return input.new_empty(out_size), torch.empty(out_size, device="meta", dtype=torch.int32)
+
+
+@register_meta("_roi_pool_backward")
+def meta_roi_pool_backward(
+    grad, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width
+):
+    torch._check(
+        grad.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for grad to have the same type as tensor for rois; "
+            f"but type {grad.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    return grad.new_empty((batch_size, channels, height, width))
+
+
+@register_meta("ps_roi_pool")
+def meta_ps_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width):
+    torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
+    torch._check(
+        input.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for input to have the same type as tensor for rois; "
+            f"but type {input.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    channels = input.size(1)
+    torch._check(
+        channels % (pooled_height * pooled_width) == 0,
+        "input channels must be a multiple of pooling height * pooling width",
+    )
+    num_rois = rois.size(0)
+    out_size = (num_rois, channels // (pooled_height * pooled_width), pooled_height, pooled_width)
+    return input.new_empty(out_size), torch.empty(out_size, device="meta", dtype=torch.int32)
+
+
+@register_meta("_ps_roi_pool_backward")
+def meta_ps_roi_pool_backward(
+    grad, rois, channel_mapping, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width
+):
+    torch._check(
+        grad.dtype == rois.dtype,
+        lambda: (
+            "Expected tensor for grad to have the same type as tensor for rois; "
+            f"but type {grad.dtype} does not equal {rois.dtype}"
+        ),
+    )
+    return grad.new_empty((batch_size, channels, height, width))
+
+
+@torch._custom_ops.impl_abstract("torchvision::nms")
+def meta_nms(dets, scores, iou_threshold):
+    torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D")
+    torch._check(dets.size(1) == 4, lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}")
+    torch._check(scores.dim() == 1, lambda: f"scores should be a 1d tensor, got {scores.dim()}")
+    torch._check(
+        dets.size(0) == scores.size(0),
+        lambda: f"boxes and scores should have same number of elements in dimension 0, got {dets.size(0)} and {scores.size(0)}",
+    )
+    ctx = torch._custom_ops.get_ctx()
+    num_to_keep = ctx.create_unbacked_symint()
+    return dets.new_empty(num_to_keep, dtype=torch.long)
+
+
+@register_meta("deform_conv2d")
+def meta_deform_conv2d(
+    input,
+    weight,
+    offset,
+    mask,
+    bias,
+    stride_h,
+    stride_w,
+    pad_h,
+    pad_w,
+    dil_h,
+    dil_w,
+    n_weight_grps,
+    n_offset_grps,
+    use_mask,
+):
+
+    out_height, out_width = offset.shape[-2:]
+    out_channels = weight.shape[0]
+    batch_size = input.shape[0]
+    return input.new_empty((batch_size, out_channels, out_height, out_width))
+
+
+@register_meta("_deform_conv2d_backward")
+def meta_deform_conv2d_backward(
+    grad,
+    input,
+    weight,
+    offset,
+    mask,
+    bias,
+    stride_h,
+    stride_w,
+    pad_h,
+    pad_w,
+    dilation_h,
+    dilation_w,
+    groups,
+    offset_groups,
+    use_mask,
+):
+
+    grad_input = input.new_empty(input.shape)
+    grad_weight = weight.new_empty(weight.shape)
+    grad_offset = offset.new_empty(offset.shape)
+    grad_mask = mask.new_empty(mask.shape)
+    grad_bias = bias.new_empty(bias.shape)
+    return grad_input, grad_weight, grad_offset, grad_mask, grad_bias

+ 32 - 0
libs/vision_libs/_utils.py

@@ -0,0 +1,32 @@
+import enum
+from typing import Sequence, Type, TypeVar
+
+T = TypeVar("T", bound=enum.Enum)
+
+
+class StrEnumMeta(enum.EnumMeta):
+    auto = enum.auto
+
+    def from_str(self: Type[T], member: str) -> T:  # type: ignore[misc]
+        try:
+            return self[member]
+        except KeyError:
+            # TODO: use `add_suggestion` from torchvision.prototype.utils._internal to improve the error message as
+            #  soon as it is migrated.
+            raise ValueError(f"Unknown value '{member}' for {self.__name__}.") from None
+
+
+class StrEnum(enum.Enum, metaclass=StrEnumMeta):
+    pass
+
+
+def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
+    if not seq:
+        return ""
+    if len(seq) == 1:
+        return f"'{seq[0]}'"
+
+    head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"
+    tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"
+
+    return head + tail

+ 146 - 0
libs/vision_libs/datasets/__init__.py

@@ -0,0 +1,146 @@
+from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
+from ._stereo_matching import (
+    CarlaStereo,
+    CREStereo,
+    ETH3DStereo,
+    FallingThingsStereo,
+    InStereo2k,
+    Kitti2012Stereo,
+    Kitti2015Stereo,
+    Middlebury2014Stereo,
+    SceneFlowStereo,
+    SintelStereo,
+)
+from .caltech import Caltech101, Caltech256
+from .celeba import CelebA
+from .cifar import CIFAR10, CIFAR100
+from .cityscapes import Cityscapes
+from .clevr import CLEVRClassification
+from .coco import CocoCaptions, CocoDetection
+from .country211 import Country211
+from .dtd import DTD
+from .eurosat import EuroSAT
+from .fakedata import FakeData
+from .fer2013 import FER2013
+from .fgvc_aircraft import FGVCAircraft
+from .flickr import Flickr30k, Flickr8k
+from .flowers102 import Flowers102
+from .folder import DatasetFolder, ImageFolder
+from .food101 import Food101
+from .gtsrb import GTSRB
+from .hmdb51 import HMDB51
+from .imagenet import ImageNet
+from .imagenette import Imagenette
+from .inaturalist import INaturalist
+from .kinetics import Kinetics
+from .kitti import Kitti
+from .lfw import LFWPairs, LFWPeople
+from .lsun import LSUN, LSUNClass
+from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST
+from .moving_mnist import MovingMNIST
+from .omniglot import Omniglot
+from .oxford_iiit_pet import OxfordIIITPet
+from .pcam import PCAM
+from .phototour import PhotoTour
+from .places365 import Places365
+from .rendered_sst2 import RenderedSST2
+from .sbd import SBDataset
+from .sbu import SBU
+from .semeion import SEMEION
+from .stanford_cars import StanfordCars
+from .stl10 import STL10
+from .sun397 import SUN397
+from .svhn import SVHN
+from .ucf101 import UCF101
+from .usps import USPS
+from .vision import VisionDataset
+from .voc import VOCDetection, VOCSegmentation
+from .widerface import WIDERFace
+
+__all__ = (
+    "LSUN",
+    "LSUNClass",
+    "ImageFolder",
+    "DatasetFolder",
+    "FakeData",
+    "CocoCaptions",
+    "CocoDetection",
+    "CIFAR10",
+    "CIFAR100",
+    "EMNIST",
+    "FashionMNIST",
+    "QMNIST",
+    "MNIST",
+    "KMNIST",
+    "StanfordCars",
+    "STL10",
+    "SUN397",
+    "SVHN",
+    "PhotoTour",
+    "SEMEION",
+    "Omniglot",
+    "SBU",
+    "Flickr8k",
+    "Flickr30k",
+    "Flowers102",
+    "VOCSegmentation",
+    "VOCDetection",
+    "Cityscapes",
+    "ImageNet",
+    "Caltech101",
+    "Caltech256",
+    "CelebA",
+    "WIDERFace",
+    "SBDataset",
+    "VisionDataset",
+    "USPS",
+    "Kinetics",
+    "HMDB51",
+    "UCF101",
+    "Places365",
+    "Kitti",
+    "INaturalist",
+    "LFWPeople",
+    "LFWPairs",
+    "KittiFlow",
+    "Sintel",
+    "FlyingChairs",
+    "FlyingThings3D",
+    "HD1K",
+    "Food101",
+    "DTD",
+    "FER2013",
+    "GTSRB",
+    "CLEVRClassification",
+    "OxfordIIITPet",
+    "PCAM",
+    "Country211",
+    "FGVCAircraft",
+    "EuroSAT",
+    "RenderedSST2",
+    "Kitti2012Stereo",
+    "Kitti2015Stereo",
+    "CarlaStereo",
+    "Middlebury2014Stereo",
+    "CREStereo",
+    "FallingThingsStereo",
+    "SceneFlowStereo",
+    "SintelStereo",
+    "InStereo2k",
+    "ETH3DStereo",
+    "wrap_dataset_for_transforms_v2",
+    "Imagenette",
+)
+
+
+# We override current module's attributes to handle the import:
+# from torchvision.datasets import wrap_dataset_for_transforms_v2
+# without a cyclic error.
+# Ref: https://peps.python.org/pep-0562/
+def __getattr__(name):
+    if name in ("wrap_dataset_for_transforms_v2",):
+        from torchvision.tv_tensors._dataset_wrapper import wrap_dataset_for_transforms_v2
+
+        return wrap_dataset_for_transforms_v2
+
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

+ 491 - 0
libs/vision_libs/datasets/_optical_flow.py

@@ -0,0 +1,491 @@
+import itertools
+import os
+from abc import ABC, abstractmethod
+from glob import glob
+from pathlib import Path
+from typing import Callable, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from PIL import Image
+
+from ..io.image import _read_png_16
+from .utils import _read_pfm, verify_str_arg
+from .vision import VisionDataset
+
+
+T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], Optional[np.ndarray]]
+T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]
+
+
+__all__ = (
+    "KittiFlow",
+    "Sintel",
+    "FlyingThings3D",
+    "FlyingChairs",
+    "HD1K",
+)
+
+
+class FlowDataset(ABC, VisionDataset):
+    # Some datasets like Kitti have a built-in valid_flow_mask, indicating which flow values are valid
+    # For those we return (img1, img2, flow, valid_flow_mask), and for the rest we return (img1, img2, flow),
+    # and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
+    _has_builtin_flow_mask = False
+
+    def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
+
+        super().__init__(root=root)
+        self.transforms = transforms
+
+        self._flow_list: List[str] = []
+        self._image_list: List[List[str]] = []
+
+    def _read_img(self, file_name: str) -> Image.Image:
+        img = Image.open(file_name)
+        if img.mode != "RGB":
+            img = img.convert("RGB")
+        return img
+
+    @abstractmethod
+    def _read_flow(self, file_name: str):
+        # Return the flow or a tuple with the flow and the valid_flow_mask if _has_builtin_flow_mask is True
+        pass
+
+    def __getitem__(self, index: int) -> Union[T1, T2]:
+
+        img1 = self._read_img(self._image_list[index][0])
+        img2 = self._read_img(self._image_list[index][1])
+
+        if self._flow_list:  # it will be empty for some dataset when split="test"
+            flow = self._read_flow(self._flow_list[index])
+            if self._has_builtin_flow_mask:
+                flow, valid_flow_mask = flow
+            else:
+                valid_flow_mask = None
+        else:
+            flow = valid_flow_mask = None
+
+        if self.transforms is not None:
+            img1, img2, flow, valid_flow_mask = self.transforms(img1, img2, flow, valid_flow_mask)
+
+        if self._has_builtin_flow_mask or valid_flow_mask is not None:
+            # The `or valid_flow_mask is not None` part is here because the mask can be generated within a transform
+            return img1, img2, flow, valid_flow_mask
+        else:
+            return img1, img2, flow
+
+    def __len__(self) -> int:
+        return len(self._image_list)
+
+    def __rmul__(self, v: int) -> torch.utils.data.ConcatDataset:
+        return torch.utils.data.ConcatDataset([self] * v)
+
+
+class Sintel(FlowDataset):
+    """`Sintel <http://sintel.is.tue.mpg.de/>`_ Dataset for optical flow.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            Sintel
+                testing
+                    clean
+                        scene_1
+                        scene_2
+                        ...
+                    final
+                        scene_1
+                        scene_2
+                        ...
+                training
+                    clean
+                        scene_1
+                        scene_2
+                        ...
+                    final
+                        scene_1
+                        scene_2
+                        ...
+                    flow
+                        scene_1
+                        scene_2
+                        ...
+
+    Args:
+        root (string): Root directory of the Sintel Dataset.
+        split (string, optional): The dataset split, either "train" (default) or "test"
+        pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for
+            details on the different passes.
+        transforms (callable, optional): A function/transform that takes in
+            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
+            ``valid_flow_mask`` is expected for consistency with other datasets which
+            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        pass_name: str = "clean",
+        transforms: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root=root, transforms=transforms)
+
+        verify_str_arg(split, "split", valid_values=("train", "test"))
+        verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
+        passes = ["clean", "final"] if pass_name == "both" else [pass_name]
+
+        root = Path(root) / "Sintel"
+        flow_root = root / "training" / "flow"
+
+        for pass_name in passes:
+            split_dir = "training" if split == "train" else split
+            image_root = root / split_dir / pass_name
+            for scene in os.listdir(image_root):
+                image_list = sorted(glob(str(image_root / scene / "*.png")))
+                for i in range(len(image_list) - 1):
+                    self._image_list += [[image_list[i], image_list[i + 1]]]
+
+                if split == "train":
+                    self._flow_list += sorted(glob(str(flow_root / scene / "*.flo")))
+
+    def __getitem__(self, index: int) -> Union[T1, T2]:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 3-tuple with ``(img1, img2, flow)``.
+            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
+            ``flow`` is None if ``split="test"``.
+            If a valid flow mask is generated within the ``transforms`` parameter,
+            a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
+        """
+        return super().__getitem__(index)
+
+    def _read_flow(self, file_name: str) -> np.ndarray:
+        return _read_flo(file_name)
+
+
+class KittiFlow(FlowDataset):
+    """`KITTI <http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow>`__ dataset for optical flow (2015).
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            KittiFlow
+                testing
+                    image_2
+                training
+                    image_2
+                    flow_occ
+
+    Args:
+        root (string): Root directory of the KittiFlow Dataset.
+        split (string, optional): The dataset split, either "train" (default) or "test"
+        transforms (callable, optional): A function/transform that takes in
+            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
+    """
+
+    _has_builtin_flow_mask = True
+
+    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
+        super().__init__(root=root, transforms=transforms)
+
+        verify_str_arg(split, "split", valid_values=("train", "test"))
+
+        root = Path(root) / "KittiFlow" / (split + "ing")
+        images1 = sorted(glob(str(root / "image_2" / "*_10.png")))
+        images2 = sorted(glob(str(root / "image_2" / "*_11.png")))
+
+        if not images1 or not images2:
+            raise FileNotFoundError(
+                "Could not find the Kitti flow images. Please make sure the directory structure is correct."
+            )
+
+        for img1, img2 in zip(images1, images2):
+            self._image_list += [[img1, img2]]
+
+        if split == "train":
+            self._flow_list = sorted(glob(str(root / "flow_occ" / "*_10.png")))
+
+    def __getitem__(self, index: int) -> Union[T1, T2]:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)``
+            where ``valid_flow_mask`` is a numpy boolean mask of shape (H, W)
+            indicating which flow values are valid. The flow is a numpy array of
+            shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if
+            ``split="test"``.
+        """
+        return super().__getitem__(index)
+
+    def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]:
+        return _read_16bits_png_with_flow_and_valid_mask(file_name)
+
+
+class FlyingChairs(FlowDataset):
+    """`FlyingChairs <https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs>`_ Dataset for optical flow.
+
+    You will also need to download the FlyingChairs_train_val.txt file from the dataset page.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            FlyingChairs
+                data
+                    00001_flow.flo
+                    00001_img1.ppm
+                    00001_img2.ppm
+                    ...
+                FlyingChairs_train_val.txt
+
+
+    Args:
+        root (string): Root directory of the FlyingChairs Dataset.
+        split (string, optional): The dataset split, either "train" (default) or "val"
+        transforms (callable, optional): A function/transform that takes in
+            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
+            ``valid_flow_mask`` is expected for consistency with other datasets which
+            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
+    """
+
+    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
+        super().__init__(root=root, transforms=transforms)
+
+        verify_str_arg(split, "split", valid_values=("train", "val"))
+
+        root = Path(root) / "FlyingChairs"
+        images = sorted(glob(str(root / "data" / "*.ppm")))
+        flows = sorted(glob(str(root / "data" / "*.flo")))
+
+        split_file_name = "FlyingChairs_train_val.txt"
+
+        if not os.path.exists(root / split_file_name):
+            raise FileNotFoundError(
+                "The FlyingChairs_train_val.txt file was not found - please download it from the dataset page (see docstring)."
+            )
+
+        split_list = np.loadtxt(str(root / split_file_name), dtype=np.int32)
+        for i in range(len(flows)):
+            split_id = split_list[i]
+            if (split == "train" and split_id == 1) or (split == "val" and split_id == 2):
+                self._flow_list += [flows[i]]
+                self._image_list += [[images[2 * i], images[2 * i + 1]]]
+
+    def __getitem__(self, index: int) -> Union[T1, T2]:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 3-tuple with ``(img1, img2, flow)``.
+            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
+            ``flow`` is None if ``split="val"``.
+            If a valid flow mask is generated within the ``transforms`` parameter,
+            a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
+        """
+        return super().__getitem__(index)
+
+    def _read_flow(self, file_name: str) -> np.ndarray:
+        return _read_flo(file_name)
+
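A minimal usage sketch for the FlyingChairs class above. The "./data" root and the assumption that the vendored package exposes the class from libs.vision_libs.datasets (as upstream torchvision.datasets does) are illustrative and not part of this diff:

# Illustrative usage; the import path and "./data" root are assumptions.
from libs.vision_libs.datasets import FlyingChairs  # upstream: from torchvision.datasets import FlyingChairs

dataset = FlyingChairs(root="./data", split="train")  # expects ./data/FlyingChairs/... as in the tree above
img1, img2, flow = dataset[0]
print(len(dataset))           # number of image pairs selected by FlyingChairs_train_val.txt
print(img1.size, flow.shape)  # PIL size (W, H) and numpy flow of shape (2, H, W)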
+
+class FlyingThings3D(FlowDataset):
+    """`FlyingThings3D <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>`_ dataset for optical flow.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            FlyingThings3D
+                frames_cleanpass
+                    TEST
+                    TRAIN
+                frames_finalpass
+                    TEST
+                    TRAIN
+                optical_flow
+                    TEST
+                    TRAIN
+
+    Args:
+        root (string): Root directory of the FlyingThings3D Dataset.
+        split (string, optional): The dataset split, either "train" (default) or "test"
+        pass_name (string, optional): The pass to use, either "clean" (default) or "final" or "both". See link above for
+            details on the different passes.
+        camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both".
+        transforms (callable, optional): A function/transform that takes in
+            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
+            ``valid_flow_mask`` is expected for consistency with other datasets which
+            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        pass_name: str = "clean",
+        camera: str = "left",
+        transforms: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root=root, transforms=transforms)
+
+        verify_str_arg(split, "split", valid_values=("train", "test"))
+        split = split.upper()
+
+        verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
+        passes = {
+            "clean": ["frames_cleanpass"],
+            "final": ["frames_finalpass"],
+            "both": ["frames_cleanpass", "frames_finalpass"],
+        }[pass_name]
+
+        verify_str_arg(camera, "camera", valid_values=("left", "right", "both"))
+        cameras = ["left", "right"] if camera == "both" else [camera]
+
+        root = Path(root) / "FlyingThings3D"
+
+        directions = ("into_future", "into_past")
+        for pass_name, camera, direction in itertools.product(passes, cameras, directions):
+            image_dirs = sorted(glob(str(root / pass_name / split / "*/*")))
+            image_dirs = sorted(Path(image_dir) / camera for image_dir in image_dirs)
+
+            flow_dirs = sorted(glob(str(root / "optical_flow" / split / "*/*")))
+            flow_dirs = sorted(Path(flow_dir) / direction / camera for flow_dir in flow_dirs)
+
+            if not image_dirs or not flow_dirs:
+                raise FileNotFoundError(
+                    "Could not find the FlyingThings3D flow images. "
+                    "Please make sure the directory structure is correct."
+                )
+
+            for image_dir, flow_dir in zip(image_dirs, flow_dirs):
+                images = sorted(glob(str(image_dir / "*.png")))
+                flows = sorted(glob(str(flow_dir / "*.pfm")))
+                for i in range(len(flows) - 1):
+                    if direction == "into_future":
+                        self._image_list += [[images[i], images[i + 1]]]
+                        self._flow_list += [flows[i]]
+                    elif direction == "into_past":
+                        self._image_list += [[images[i + 1], images[i]]]
+                        self._flow_list += [flows[i + 1]]
+
+    def __getitem__(self, index: int) -> Union[T1, T2]:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 3-tuple with ``(img1, img2, flow)``.
+            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
+            ``flow`` is None if ``split="test"``.
+            If a valid flow mask is generated within the ``transforms`` parameter,
+            a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
+        """
+        return super().__getitem__(index)
+
+    def _read_flow(self, file_name: str) -> np.ndarray:
+        return _read_pfm(file_name)
+
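To make the into_future / into_past pairing in the constructor above concrete, here is a small self-contained sketch that mirrors the loop over sorted frames and .pfm flows; the file names are made up for illustration:

# Toy illustration of the pairing logic above; file names are hypothetical.
frames = ["0006.png", "0007.png", "0008.png"]
flows = ["0006.pfm", "0007.pfm", "0008.pfm"]

# into_future: (frame_i, frame_{i+1}) is supervised by the flow stored at frame_i
into_future = [([frames[i], frames[i + 1]], flows[i]) for i in range(len(flows) - 1)]
# into_past: (frame_{i+1}, frame_i) is supervised by the flow stored at frame_{i+1}
into_past = [([frames[i + 1], frames[i]], flows[i + 1]) for i in range(len(flows) - 1)]

print(into_future)  # [(['0006.png', '0007.png'], '0006.pfm'), (['0007.png', '0008.png'], '0007.pfm')]
print(into_past)    # [(['0007.png', '0006.png'], '0007.pfm'), (['0008.png', '0007.png'], '0008.pfm')]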
+
+class HD1K(FlowDataset):
+    """`HD1K <http://hci-benchmark.iwr.uni-heidelberg.de/>`__ dataset for optical flow.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            hd1k
+                hd1k_challenge
+                    image_2
+                hd1k_flow_gt
+                    flow_occ
+                hd1k_input
+                    image_2
+
+    Args:
+        root (string): Root directory of the HD1K Dataset.
+        split (string, optional): The dataset split, either "train" (default) or "test"
+        transforms (callable, optional): A function/transform that takes in
+            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
+    """
+
+    _has_builtin_flow_mask = True
+
+    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
+        super().__init__(root=root, transforms=transforms)
+
+        verify_str_arg(split, "split", valid_values=("train", "test"))
+
+        root = Path(root) / "hd1k"
+        if split == "train":
+            # There are 36 "sequences" and we don't want seq i to overlap with seq i + 1, so we need this for loop
+            for seq_idx in range(36):
+                flows = sorted(glob(str(root / "hd1k_flow_gt" / "flow_occ" / f"{seq_idx:06d}_*.png")))
+                images = sorted(glob(str(root / "hd1k_input" / "image_2" / f"{seq_idx:06d}_*.png")))
+                for i in range(len(flows) - 1):
+                    self._flow_list += [flows[i]]
+                    self._image_list += [[images[i], images[i + 1]]]
+        else:
+            images1 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*10.png")))
+            images2 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*11.png")))
+            for image1, image2 in zip(images1, images2):
+                self._image_list += [[image1, image2]]
+
+        if not self._image_list:
+            raise FileNotFoundError(
+                "Could not find the HD1K images. Please make sure the directory structure is correct."
+            )
+
+    def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]:
+        return _read_16bits_png_with_flow_and_valid_mask(file_name)
+
+    def __getitem__(self, index: int) -> Union[T1, T2]:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` where ``valid_flow_mask``
+            is a numpy boolean mask of shape (H, W)
+            indicating which flow values are valid. The flow is a numpy array of
+            shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if
+            ``split="test"``.
+        """
+        return super().__getitem__(index)
+
+
+def _read_flo(file_name: str) -> np.ndarray:
+    """Read .flo file in Middlebury format"""
+    # Code adapted from:
+    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
+    # Everything needs to be in little Endian according to
+    # https://vision.middlebury.edu/flow/code/flow-code/README.txt
+    with open(file_name, "rb") as f:
+        magic = np.fromfile(f, "c", count=4).tobytes()
+        if magic != b"PIEH":
+            raise ValueError("Magic number incorrect. Invalid .flo file")
+
+        w = int(np.fromfile(f, "<i4", count=1))
+        h = int(np.fromfile(f, "<i4", count=1))
+        data = np.fromfile(f, "<f4", count=2 * w * h)
+        return data.reshape(h, w, 2).transpose(2, 0, 1)
+
+
+def _read_16bits_png_with_flow_and_valid_mask(file_name: str) -> Tuple[np.ndarray, np.ndarray]:
+
+    flow_and_valid = _read_png_16(file_name).to(torch.float32)
+    flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :]
+    flow = (flow - 2**15) / 64  # KITTI stores flow as uint16: subtract 2**15 and divide by 64.0 (see the KITTI flow devkit README)
+    valid_flow_mask = valid_flow_mask.bool()
+
+    # For consistency with other datasets, we convert to numpy
+    return flow.numpy(), valid_flow_mask.numpy()
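A small self-contained check of the (value - 2**15) / 64 conversion used above, with made-up 16-bit pixel values, so no dataset files are needed:

import torch

# Hypothetical raw values for one pixel: u and v channels plus the validity channel, shape (3, 1, 1).
raw = torch.tensor([[[32768 + 64]], [[32768 - 128]], [[1]]], dtype=torch.float32)
flow, valid = raw[:2], raw[2]
flow = (flow - 2**15) / 64  # u = 1.0 px, v = -2.0 px
valid = valid.bool()
print(flow.numpy().reshape(2), valid.numpy())  # [ 1. -2.] [[ True]]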

+ 1224 - 0
libs/vision_libs/datasets/_stereo_matching.py

@@ -0,0 +1,1224 @@
+import functools
+import json
+import os
+import random
+import shutil
+from abc import ABC, abstractmethod
+from glob import glob
+from pathlib import Path
+from typing import Callable, cast, List, Optional, Tuple, Union
+
+import numpy as np
+from PIL import Image
+
+from .utils import _read_pfm, download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], np.ndarray]
+T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]
+
+__all__ = ()
+
+_read_pfm_file = functools.partial(_read_pfm, slice_channels=1)
+
+
+class StereoMatchingDataset(ABC, VisionDataset):
+    """Base interface for Stereo matching datasets"""
+
+    _has_built_in_disparity_mask = False
+
+    def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
+        """
+        Args:
+            root(str): Root directory of the dataset.
+            transforms(callable, optional): A function/transform that takes in Tuples of
+                (images, disparities, valid_masks) and returns a transformed version of each of them.
+                images is a Tuple of (``PIL.Image``, ``PIL.Image``)
+                disparities is a Tuple of (``np.ndarray``, ``np.ndarray``) with shape (1, H, W)
+                valid_masks is a Tuple of (``np.ndarray``, ``np.ndarray``) with shape (H, W)
+                When a dataset does not provide disparities, the ``disparities`` and
+                ``valid_masks`` Tuples may contain None values.
+                For training splits, the datasets generally guarantee at least:
+                images: (``PIL.Image``, ``PIL.Image``)
+                disparities: (``np.ndarray``, ``None``) with shape (1, H, W)
+                and, depending on the dataset, may also return a mask:
+                valid_masks: (``np.ndarray | None``, ``None``) with shape (H, W)
+                For some test splits, the datasets provide outputs that look like:
+                images: (``PIL.Image``, ``PIL.Image``)
+                disparities: (``None``, ``None``)
+                and, depending on the dataset, may also return a mask:
+                valid_masks: (``None``, ``None``)
+        """
+        super().__init__(root=root)
+        self.transforms = transforms
+
+        self._images = []  # type: ignore
+        self._disparities = []  # type: ignore
+
+    def _read_img(self, file_path: Union[str, Path]) -> Image.Image:
+        img = Image.open(file_path)
+        if img.mode != "RGB":
+            img = img.convert("RGB")
+        return img
+
+    def _scan_pairs(
+        self,
+        paths_left_pattern: str,
+        paths_right_pattern: Optional[str] = None,
+    ) -> List[Tuple[str, Optional[str]]]:
+
+        left_paths = list(sorted(glob(paths_left_pattern)))
+
+        right_paths: List[Union[None, str]]
+        if paths_right_pattern:
+            right_paths = list(sorted(glob(paths_right_pattern)))
+        else:
+            right_paths = list(None for _ in left_paths)
+
+        if not left_paths:
+            raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_left_pattern}")
+
+        if not right_paths:
+            raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_right_pattern}")
+
+        if len(left_paths) != len(right_paths):
+            raise ValueError(
+                f"Found {len(left_paths)} left files but {len(right_paths)} right files using:\n "
+                f"left pattern: {paths_left_pattern}\n"
+                f"right pattern: {paths_right_pattern}\n"
+            )
+
+        paths = list((left, right) for left, right in zip(left_paths, right_paths))
+        return paths
+
+    @abstractmethod
+    def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
+        # function that returns a disparity map and an occlusion map
+        pass
+
+    def __getitem__(self, index: int) -> Union[T1, T2]:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 3- or 4-tuple with ``(img_left, img_right, disparity, Optional[valid_mask])`` where ``valid_mask``
+                can be a numpy boolean mask of shape (H, W) if the dataset provides a file
+                indicating which disparity pixels are valid. The disparity is a numpy array of
+                shape (1, H, W) and the images are PIL images. ``disparity`` is None for
+                datasets whose ``split="test"`` does not ship ground-truth annotations.
+        """
+        img_left = self._read_img(self._images[index][0])
+        img_right = self._read_img(self._images[index][1])
+
+        dsp_map_left, valid_mask_left = self._read_disparity(self._disparities[index][0])
+        dsp_map_right, valid_mask_right = self._read_disparity(self._disparities[index][1])
+
+        imgs = (img_left, img_right)
+        dsp_maps = (dsp_map_left, dsp_map_right)
+        valid_masks = (valid_mask_left, valid_mask_right)
+
+        if self.transforms is not None:
+            (
+                imgs,
+                dsp_maps,
+                valid_masks,
+            ) = self.transforms(imgs, dsp_maps, valid_masks)
+
+        if self._has_built_in_disparity_mask or valid_masks[0] is not None:
+            return imgs[0], imgs[1], dsp_maps[0], cast(np.ndarray, valid_masks[0])
+        else:
+            return imgs[0], imgs[1], dsp_maps[0]
+
+    def __len__(self) -> int:
+        return len(self._images)
+
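A minimal sketch of a transforms callable compatible with the base class above: it converts the PIL images to numpy arrays and passes disparities and valid masks through unchanged. The function name and the dataset class in the final comment are placeholders, not part of this diff:

import numpy as np

def arrays_only_transform(images, disparities, valid_masks):
    # images: (PIL.Image, PIL.Image); disparities / valid_masks: 2-tuples that may contain None
    images = tuple(np.asarray(img) for img in images)
    return images, disparities, valid_masks

# e.g. dataset = SomeStereoDataset(root="./data", transforms=arrays_only_transform)  # placeholder class name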
+
+class CarlaStereo(StereoMatchingDataset):
+    """
+    Carla simulator data linked in the `CREStereo github repo <https://github.com/megvii-research/CREStereo>`_.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            carla-highres
+                trainingF
+                    scene1
+                        im0.png
+                        im1.png
+                        disp0GT.pfm
+                        disp1GT.pfm
+                        calib.txt
+                    scene2
+                        im0.png
+                        im1.png
+                        disp0GT.pfm
+                        disp1GT.pfm
+                        calib.txt
+                    ...
+
+    Args:
+        root (string): Root directory where `carla-highres` is located.
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+    """
+
+    def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
+        super().__init__(root, transforms)
+
+        root = Path(root) / "carla-highres"
+
+        left_image_pattern = str(root / "trainingF" / "*" / "im0.png")
+        right_image_pattern = str(root / "trainingF" / "*" / "im1.png")
+        imgs = self._scan_pairs(left_image_pattern, right_image_pattern)
+        self._images = imgs
+
+        left_disparity_pattern = str(root / "trainingF" / "*" / "disp0GT.pfm")
+        right_disparity_pattern = str(root / "trainingF" / "*" / "disp1GT.pfm")
+        disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+        self._disparities = disparities
+
+    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
+        disparity_map = _read_pfm_file(file_path)
+        disparity_map = np.abs(disparity_map)  # ensure that the disparity is positive
+        valid_mask = None
+        return disparity_map, valid_mask
+
+    def __getitem__(self, index: int) -> T1:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            If a ``valid_mask`` is generated within the ``transforms`` parameter,
+            a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
+        """
+        return cast(T1, super().__getitem__(index))
+
+
+class Kitti2012Stereo(StereoMatchingDataset):
+    """
+    KITTI dataset from the `2012 stereo evaluation benchmark <http://www.cvlibs.net/datasets/kitti/eval_stereo_flow.php>`_.
+    Uses the RGB images for consistency with KITTI 2015.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            Kitti2012
+                testing
+                    colored_0
+                        1_10.png
+                        2_10.png
+                        ...
+                    colored_1
+                        1_10.png
+                        2_10.png
+                        ...
+                training
+                    colored_0
+                        1_10.png
+                        2_10.png
+                        ...
+                    colored_1
+                        1_10.png
+                        2_10.png
+                        ...
+                    disp_noc
+                        1.png
+                        2.png
+                        ...
+                    calib
+
+    Args:
+        root (string): Root directory where `Kitti2012` is located.
+        split (string, optional): The dataset split of scenes, either "train" (default) or "test".
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+    """
+
+    _has_built_in_disparity_mask = True
+
+    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
+        super().__init__(root, transforms)
+
+        verify_str_arg(split, "split", valid_values=("train", "test"))
+
+        root = Path(root) / "Kitti2012" / (split + "ing")
+
+        left_img_pattern = str(root / "colored_0" / "*_10.png")
+        right_img_pattern = str(root / "colored_1" / "*_10.png")
+        self._images = self._scan_pairs(left_img_pattern, right_img_pattern)
+
+        if split == "train":
+            disparity_pattern = str(root / "disp_noc" / "*.png")
+            self._disparities = self._scan_pairs(disparity_pattern, None)
+        else:
+            self._disparities = list((None, None) for _ in self._images)
+
+    def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], None]:
+        # test split has no disparity maps
+        if file_path is None:
+            return None, None
+
+        disparity_map = np.asarray(Image.open(file_path)) / 256.0
+        # unsqueeze the disparity map into (C, H, W) format
+        disparity_map = disparity_map[None, :, :]
+        valid_mask = None
+        return disparity_map, valid_mask
+
+    def __getitem__(self, index: int) -> T1:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
+            generate a valid mask.
+            Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
+        """
+        return cast(T1, super().__getitem__(index))
+
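A usage sketch for Kitti2012Stereo above, assuming the data was extracted under ./data/Kitti2012 and that the class is exposed by the vendored package as in upstream torchvision (both are assumptions):

from libs.vision_libs.datasets import Kitti2012Stereo  # upstream: torchvision.datasets.Kitti2012Stereo

dataset = Kitti2012Stereo(root="./data", split="train")
img_left, img_right, disparity, valid_mask = dataset[0]
print(img_left.size)    # PIL size (W, H)
print(disparity.shape)  # (1, H, W); values already divided by 256 in _read_disparity
print(valid_mask)       # None here, unless the transforms generate a mask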
+
+class Kitti2015Stereo(StereoMatchingDataset):
+    """
+    KITTI dataset from the `2015 stereo evaluation benchmark <http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php>`_.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            Kitti2015
+                testing
+                    image_2
+                        img1.png
+                        img2.png
+                        ...
+                    image_3
+                        img1.png
+                        img2.png
+                        ...
+                training
+                    image_2
+                        img1.png
+                        img2.png
+                        ...
+                    image_3
+                        img1.png
+                        img2.png
+                        ...
+                    disp_occ_0
+                        img1.png
+                        img2.png
+                        ...
+                    disp_occ_1
+                        img1.png
+                        img2.png
+                        ...
+                    calib
+
+    Args:
+        root (string): Root directory where `Kitti2015` is located.
+        split (string, optional): The dataset split of scenes, either "train" (default) or "test".
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+    """
+
+    _has_built_in_disparity_mask = True
+
+    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
+        super().__init__(root, transforms)
+
+        verify_str_arg(split, "split", valid_values=("train", "test"))
+
+        root = Path(root) / "Kitti2015" / (split + "ing")
+        left_img_pattern = str(root / "image_2" / "*.png")
+        right_img_pattern = str(root / "image_3" / "*.png")
+        self._images = self._scan_pairs(left_img_pattern, right_img_pattern)
+
+        if split == "train":
+            left_disparity_pattern = str(root / "disp_occ_0" / "*.png")
+            right_disparity_pattern = str(root / "disp_occ_1" / "*.png")
+            self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+        else:
+            self._disparities = list((None, None) for _ in self._images)
+
+    def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], None]:
+        # test split has no disparity maps
+        if file_path is None:
+            return None, None
+
+        disparity_map = np.asarray(Image.open(file_path)) / 256.0
+        # unsqueeze the disparity map into (C, H, W) format
+        disparity_map = disparity_map[None, :, :]
+        valid_mask = None
+        return disparity_map, valid_mask
+
+    def __getitem__(self, index: int) -> T1:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
+            generate a valid mask.
+            Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
+        """
+        return cast(T1, super().__getitem__(index))
+
+
+class Middlebury2014Stereo(StereoMatchingDataset):
+    """Publicly available scenes from the Middlebury dataset `2014 version <https://vision.middlebury.edu/stereo/data/scenes2014/>`.
+
+    The dataset mostly follows the original format, without containing the ambient subdirectories: ::
+
+        root
+            Middlebury2014
+                train
+                    scene1-{perfect,imperfect}
+                        calib.txt
+                        im{0,1}.png
+                        im1E.png
+                        im1L.png
+                        disp{0,1}.pfm
+                        disp{0,1}-n.png
+                        disp{0,1}-sd.pfm
+                        disp{0,1}y.pfm
+                    scene2-{perfect,imperfect}
+                        calib.txt
+                        im{0,1}.png
+                        im1E.png
+                        im1L.png
+                        disp{0,1}.pfm
+                        disp{0,1}-n.png
+                        disp{0,1}-sd.pfm
+                        disp{0,1}y.pfm
+                    ...
+                additional
+                    scene1-{perfect,imperfect}
+                        calib.txt
+                        im{0,1}.png
+                        im1E.png
+                        im1L.png
+                        disp{0,1}.pfm
+                        disp{0,1}-n.png
+                        disp{0,1}-sd.pfm
+                        disp{0,1}y.pfm
+                    ...
+                test
+                    scene1
+                        calib.txt
+                        im{0,1}.png
+                    scene2
+                        calib.txt
+                        im{0,1}.png
+                    ...
+
+    Args:
+        root (string): Root directory of the Middlebury 2014 Dataset.
+        split (string, optional): The dataset split of scenes, either "train" (default), "test", or "additional"
+        use_ambient_views (boolean, optional): Whether to use different exposure or lighting views when possible.
+            The dataset samples with equal probability between ``[im1.png, im1E.png, im1L.png]``.
+        calibration (string, optional): The calibration setting to use, either "perfect" (default), "imperfect", "both", or None (required for split="test").
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+        download (boolean, optional): Whether or not to download the dataset in the ``root`` directory.
+    """
+
+    splits = {
+        "train": [
+            "Adirondack",
+            "Jadeplant",
+            "Motorcycle",
+            "Piano",
+            "Pipes",
+            "Playroom",
+            "Playtable",
+            "Recycle",
+            "Shelves",
+            "Vintage",
+        ],
+        "additional": [
+            "Backpack",
+            "Bicycle1",
+            "Cable",
+            "Classroom1",
+            "Couch",
+            "Flowers",
+            "Mask",
+            "Shopvac",
+            "Sticks",
+            "Storage",
+            "Sword1",
+            "Sword2",
+            "Umbrella",
+        ],
+        "test": [
+            "Plants",
+            "Classroom2E",
+            "Classroom2",
+            "Australia",
+            "DjembeL",
+            "CrusadeP",
+            "Crusade",
+            "Hoops",
+            "Bicycle2",
+            "Staircase",
+            "Newkuba",
+            "AustraliaP",
+            "Djembe",
+            "Livingroom",
+            "Computer",
+        ],
+    }
+
+    _has_built_in_disparity_mask = True
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        calibration: Optional[str] = "perfect",
+        use_ambient_views: bool = False,
+        transforms: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transforms)
+
+        verify_str_arg(split, "split", valid_values=("train", "test", "additional"))
+        self.split = split
+
+        if calibration:
+            verify_str_arg(calibration, "calibration", valid_values=("perfect", "imperfect", "both", None))  # type: ignore
+            if split == "test":
+                raise ValueError("Split 'test' has only no calibration settings, please set `calibration=None`.")
+        else:
+            if split != "test":
+                raise ValueError(
+                    f"Split '{split}' has calibration settings, however None was provided as an argument."
+                    f"\nSetting calibration to 'perfect' for split '{split}'. Available calibration settings are: 'perfect', 'imperfect', 'both'.",
+                )
+
+        if download:
+            self._download_dataset(root)
+
+        root = Path(root) / "Middlebury2014"
+
+        if not os.path.exists(root / split):
+            raise FileNotFoundError(f"The {split} directory was not found in the provided root directory")
+
+        split_scenes = self.splits[split]
+        # check that the provided root folder contains the scene splits
+        if not any(
+            # using startswith to account for perfect / imperfect calibration
+            scene.startswith(s)
+            for scene in os.listdir(root / split)
+            for s in split_scenes
+        ):
+            raise FileNotFoundError(f"Provided root folder does not contain any scenes from the {split} split.")
+
+        calibration_suffixes = {
+            None: [""],
+            "perfect": ["-perfect"],
+            "imperfect": ["-imperfect"],
+            "both": ["-perfect", "-imperfect"],
+        }[calibration]
+
+        for calibration_suffix in calibration_suffixes:
+            scene_pattern = "*" + calibration_suffix
+            left_img_pattern = str(root / split / scene_pattern / "im0.png")
+            right_img_pattern = str(root / split / scene_pattern / "im1.png")
+            self._images += self._scan_pairs(left_img_pattern, right_img_pattern)
+
+            if split == "test":
+                self._disparities = list((None, None) for _ in self._images)
+            else:
+                left_disparity_pattern = str(root / split / scene_pattern / "disp0.pfm")
+                right_disparity_pattern = str(root / split / scene_pattern / "disp1.pfm")
+                self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+
+        self.use_ambient_views = use_ambient_views
+
+    def _read_img(self, file_path: Union[str, Path]) -> Image.Image:
+        """
+        Function that reads either the original right image or an augmented view when ``use_ambient_views`` is True.
+        When ``use_ambient_views`` is True, the dataset will return at random one of ``[im1.png, im1E.png, im1L.png]``
+        as the right image.
+        """
+        ambient_file_paths: List[Union[str, Path]]  # make mypy happy
+
+        if not isinstance(file_path, Path):
+            file_path = Path(file_path)
+
+        if file_path.name == "im1.png" and self.use_ambient_views:
+            base_path = file_path.parent
+            # initialize sampleable container
+            ambient_file_paths = list(base_path / view_name for view_name in ["im1E.png", "im1L.png"])
+            # double check that we're not going to try to read from an invalid file path
+            ambient_file_paths = list(filter(lambda p: os.path.exists(p), ambient_file_paths))
+            # keep the original image as an option as well for uniform sampling between base views
+            ambient_file_paths.append(file_path)
+            file_path = random.choice(ambient_file_paths)  # type: ignore
+        return super()._read_img(file_path)
+
+    def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
+        # test split has no disparity maps
+        if file_path is None:
+            return None, None
+
+        disparity_map = _read_pfm_file(file_path)
+        disparity_map = np.abs(disparity_map)  # ensure that the disparity is positive
+        disparity_map[disparity_map == np.inf] = 0  # remove infinite disparities
+        valid_mask = (disparity_map > 0).squeeze(0)  # mask out invalid disparities
+        return disparity_map, valid_mask
+
+    def _download_dataset(self, root: str) -> None:
+        base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip"
+        # train and additional splits have 2 different calibration settings
+        root = Path(root) / "Middlebury2014"
+        split_name = self.split
+
+        if split_name != "test":
+            for split_scene in self.splits[split_name]:
+                split_root = root / split_name
+                for calibration in ["perfect", "imperfect"]:
+                    scene_name = f"{split_scene}-{calibration}"
+                    scene_url = f"{base_url}/{scene_name}.zip"
+                    print(f"Downloading {scene_url}")
+                    # download the scene only if it doesn't exist
+                    if not (split_root / scene_name).exists():
+                        download_and_extract_archive(
+                            url=scene_url,
+                            filename=f"{scene_name}.zip",
+                            download_root=str(split_root),
+                            remove_finished=True,
+                        )
+        else:
+            os.makedirs(root / "test")
+            if any(s not in os.listdir(root / "test") for s in self.splits["test"]):
+                # test split is downloaded from a different location
+                test_set_url = "https://vision.middlebury.edu/stereo/submit3/zip/MiddEval3-data-F.zip"
+                # the unzip is going to produce a directory MiddEval3 with two subdirectories trainingF and testF
+                # we want to move the contents of testF into the test directory
+                download_and_extract_archive(url=test_set_url, download_root=str(root), remove_finished=True)
+                for scene_dir, scene_names, _ in os.walk(str(root / "MiddEval3/testF")):
+                    for scene in scene_names:
+                        scene_dst_dir = root / "test"
+                        scene_src_dir = Path(scene_dir) / scene
+                        os.makedirs(scene_dst_dir, exist_ok=True)
+                        shutil.move(str(scene_src_dir), str(scene_dst_dir))
+
+                # cleanup MiddEval3 directory
+                shutil.rmtree(str(root / "MiddEval3"))
+
+    def __getitem__(self, index: int) -> T2:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            ``valid_mask`` is implicitly ``None`` for ``split="test"``.
+        """
+        return cast(T2, super().__getitem__(index))
+
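A usage sketch for Middlebury2014Stereo above; the "./data" root path is an assumption, and download=True would fetch each scene zip via _download_dataset:

from libs.vision_libs.datasets import Middlebury2014Stereo  # upstream: torchvision.datasets.Middlebury2014Stereo

# "both" loads the -perfect and -imperfect variants of every training scene.
dataset = Middlebury2014Stereo(root="./data", split="train", calibration="both", use_ambient_views=True)
img_left, img_right, disparity, valid_mask = dataset[0]
print(disparity.shape, valid_mask.shape)  # (1, H, W) and (H, W)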
+
+class CREStereo(StereoMatchingDataset):
+    """Synthetic dataset used in training the `CREStereo <https://arxiv.org/pdf/2203.11483.pdf>`_ architecture.
+    Dataset details on the official paper `repo <https://github.com/megvii-research/CREStereo>`_.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            CREStereo
+                tree
+                    img1_left.jpg
+                    img1_right.jpg
+                    img1_left.disp.png
+                    img1_right.disp.png
+                    img2_left.jpg
+                    img2_right.jpg
+                    img2_left.disp.png
+                    img2_right.disp.png
+                    ...
+                shapenet
+                    img1_left.jpg
+                    img1_right.jpg
+                    img1_left.disp.png
+                    img1_right.disp.png
+                    ...
+                reflective
+                    img1_left.jpg
+                    img1_right.jpg
+                    img1_left.disp.png
+                    img1_right.disp.png
+                    ...
+                hole
+                    img1_left.jpg
+                    img1_right.jpg
+                    img1_left.disp.png
+                    img1_right.disp.png
+                    ...
+
+    Args:
+        root (str): Root directory of the dataset.
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+    """
+
+    _has_built_in_disparity_mask = True
+
+    def __init__(
+        self,
+        root: str,
+        transforms: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transforms)
+
+        root = Path(root) / "CREStereo"
+
+        dirs = ["shapenet", "reflective", "tree", "hole"]
+
+        for s in dirs:
+            left_image_pattern = str(root / s / "*_left.jpg")
+            right_image_pattern = str(root / s / "*_right.jpg")
+            imgs = self._scan_pairs(left_image_pattern, right_image_pattern)
+            self._images += imgs
+
+            left_disparity_pattern = str(root / s / "*_left.disp.png")
+            right_disparity_pattern = str(root / s / "*_right.disp.png")
+            disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+            self._disparities += disparities
+
+    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
+        disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
+        # unsqueeze the disparity map into (C, H, W) format
+        disparity_map = disparity_map[None, :, :] / 32.0
+        valid_mask = None
+        return disparity_map, valid_mask
+
+    def __getitem__(self, index: int) -> T1:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
+            generate a valid mask.
+        """
+        return cast(T1, super().__getitem__(index))
+
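A tiny self-contained check of the /32 disparity decoding performed in _read_disparity above, with made-up raw values:

import numpy as np

raw = np.array([[640, 1024]], dtype=np.uint16)         # hypothetical stored pixel values
disparity = raw.astype(np.float32)[None, :, :] / 32.0  # unsqueeze to (1, H, W) and decode
print(disparity)  # [[[20. 32.]]]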
+
+class FallingThingsStereo(StereoMatchingDataset):
+    """`FallingThings <https://research.nvidia.com/publication/2018-06_falling-things-synthetic-dataset-3d-object-detection-and-pose-estimation>`_ dataset.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            FallingThings
+                single
+                    dir1
+                        scene1
+                            _object_settings.json
+                            _camera_settings.json
+                            image1.left.depth.png
+                            image1.right.depth.png
+                            image1.left.jpg
+                            image1.right.jpg
+                            image2.left.depth.png
+                            image2.right.depth.png
+                            image2.left.jpg
+                            image2.right.jpg
+                            ...
+                        scene2
+                    ...
+                mixed
+                    scene1
+                        _object_settings.json
+                        _camera_settings.json
+                        image1.left.depth.png
+                        image1.right.depth.png
+                        image1.left.jpg
+                        image1.right.jpg
+                        image2.left.depth.png
+                        image2.right.depth.png
+                        image2.left.jpg
+                        image2.right.jpg
+                        ...
+                    scene2
+                    ...
+
+    Args:
+        root (string): Root directory where FallingThings is located.
+        variant (string): Which variant to use. Either "single", "mixed", or "both".
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+    """
+
+    def __init__(self, root: str, variant: str = "single", transforms: Optional[Callable] = None) -> None:
+        super().__init__(root, transforms)
+
+        root = Path(root) / "FallingThings"
+
+        verify_str_arg(variant, "variant", valid_values=("single", "mixed", "both"))
+
+        variants = {
+            "single": ["single"],
+            "mixed": ["mixed"],
+            "both": ["single", "mixed"],
+        }[variant]
+
+        split_prefix = {
+            "single": Path("*") / "*",
+            "mixed": Path("*"),
+        }
+
+        for s in variants:
+            left_img_pattern = str(root / s / split_prefix[s] / "*.left.jpg")
+            right_img_pattern = str(root / s / split_prefix[s] / "*.right.jpg")
+            self._images += self._scan_pairs(left_img_pattern, right_img_pattern)
+
+            left_disparity_pattern = str(root / s / split_prefix[s] / "*.left.depth.png")
+            right_disparity_pattern = str(root / s / split_prefix[s] / "*.right.depth.png")
+            self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+
+    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
+        # (H, W) image
+        depth = np.asarray(Image.open(file_path))
+        # as per https://research.nvidia.com/sites/default/files/pubs/2018-06_Falling-Things/readme_0.txt
+        # in order to extract disparity from depth maps
+        camera_settings_path = Path(file_path).parent / "_camera_settings.json"
+        with open(camera_settings_path, "r") as f:
+            # inverse of depth-from-disparity equation: depth = (baseline * focal) / (disparity * pixel_constant)
+            intrinsics = json.load(f)
+            focal = intrinsics["camera_settings"][0]["intrinsic_settings"]["fx"]
+            baseline, pixel_constant = 6, 100  # pixel constant is inverted
+            disparity_map = (baseline * focal * pixel_constant) / depth.astype(np.float32)
+            # unsqueeze disparity to (C, H, W)
+            disparity_map = disparity_map[None, :, :]
+            valid_mask = None
+            return disparity_map, valid_mask
+
+    def __getitem__(self, index: int) -> T1:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            If a ``valid_mask`` is generated within the ``transforms`` parameter,
+            a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
+        """
+        return cast(T1, super().__getitem__(index))
+
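A self-contained sketch of the depth-to-disparity conversion used in _read_disparity above. The focal length and depth values are placeholders (the real fx comes from _camera_settings.json), so the numbers only illustrate the algebra:

import numpy as np

focal = 768.0                      # placeholder fx; read from _camera_settings.json in practice
baseline, pixel_constant = 6, 100  # same constants as in _read_disparity
depth = np.array([[4608.0, 9216.0]], dtype=np.float32)  # hypothetical raw depth-image values

disparity = (baseline * focal * pixel_constant) / depth  # larger depth -> smaller disparity
disparity = disparity[None, :, :]                        # unsqueeze to (1, H, W)
print(disparity)  # [[[100.  50.]]]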
+
+class SceneFlowStereo(StereoMatchingDataset):
+    """Dataset interface for `Scene Flow <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>`_ datasets.
+    This interface provides access to the `FlyingThings3D`, `Monkaa` and `Driving` datasets.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            SceneFlow
+                Monkaa
+                    frames_cleanpass
+                        scene1
+                            left
+                                img1.png
+                                img2.png
+                            right
+                                img1.png
+                                img2.png
+                        scene2
+                            left
+                                img1.png
+                                img2.png
+                            right
+                                img1.png
+                                img2.png
+                    frames_finalpass
+                        scene1
+                            left
+                                img1.png
+                                img2.png
+                            right
+                                img1.png
+                                img2.png
+                        ...
+                        ...
+                    disparity
+                        scene1
+                            left
+                                img1.pfm
+                                img2.pfm
+                            right
+                                img1.pfm
+                                img2.pfm
+                FlyingThings3D
+                    ...
+                    ...
+
+    Args:
+        root (string): Root directory where SceneFlow is located.
+        variant (string): Which dataset variant to use, "FlyingThings3D" (default), "Monkaa" or "Driving".
+        pass_name (string): Which pass to use, "clean" (default), "final" or "both".
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+
+    """
+
+    def __init__(
+        self,
+        root: str,
+        variant: str = "FlyingThings3D",
+        pass_name: str = "clean",
+        transforms: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transforms)
+
+        root = Path(root) / "SceneFlow"
+
+        verify_str_arg(variant, "variant", valid_values=("FlyingThings3D", "Driving", "Monkaa"))
+        verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
+
+        passes = {
+            "clean": ["frames_cleanpass"],
+            "final": ["frames_finalpass"],
+            "both": ["frames_cleanpass", "frames_finalpass"],
+        }[pass_name]
+
+        root = root / variant
+
+        prefix_directories = {
+            "Monkaa": Path("*"),
+            "FlyingThings3D": Path("*") / "*" / "*",
+            "Driving": Path("*") / "*" / "*",
+        }
+
+        for p in passes:
+            left_image_pattern = str(root / p / prefix_directories[variant] / "left" / "*.png")
+            right_image_pattern = str(root / p / prefix_directories[variant] / "right" / "*.png")
+            self._images += self._scan_pairs(left_image_pattern, right_image_pattern)
+
+            left_disparity_pattern = str(root / "disparity" / prefix_directories[variant] / "left" / "*.pfm")
+            right_disparity_pattern = str(root / "disparity" / prefix_directories[variant] / "right" / "*.pfm")
+            self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+
+    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
+        disparity_map = _read_pfm_file(file_path)
+        disparity_map = np.abs(disparity_map)  # ensure that the disparity is positive
+        valid_mask = None
+        return disparity_map, valid_mask
+
+    def __getitem__(self, index: int) -> T1:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            If a ``valid_mask`` is generated within the ``transforms`` parameter,
+            a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
+        """
+        return cast(T1, super().__getitem__(index))
+
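A usage sketch for SceneFlowStereo above, assuming the Monkaa subset was extracted under ./data/SceneFlow/Monkaa and that the class is importable as in upstream torchvision (both assumptions):

from libs.vision_libs.datasets import SceneFlowStereo  # upstream: torchvision.datasets.SceneFlowStereo

dataset = SceneFlowStereo(root="./data", variant="Monkaa", pass_name="both")
img_left, img_right, disparity = dataset[0]
print(disparity.shape, disparity.min() >= 0)  # (1, H, W); disparities are made non-negative via np.abs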
+
+class SintelStereo(StereoMatchingDataset):
+    """Sintel `Stereo Dataset <http://sintel.is.tue.mpg.de/stereo>`_.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            Sintel
+                training
+                    final_left
+                        scene1
+                            img1.png
+                            img2.png
+                            ...
+                        ...
+                    final_right
+                        scene2
+                            img1.png
+                            img2.png
+                            ...
+                        ...
+                    disparities
+                        scene1
+                            img1.png
+                            img2.png
+                            ...
+                        ...
+                    occlusions
+                        scene1
+                            img1.png
+                            img2.png
+                            ...
+                        ...
+                    outofframe
+                        scene1
+                            img1.png
+                            img2.png
+                            ...
+                        ...
+
+    Args:
+        root (string): Root directory where Sintel Stereo is located.
+        pass_name (string): The name of the pass to use, either "final", "clean" or "both".
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+    """
+
+    _has_built_in_disparity_mask = True
+
+    def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None) -> None:
+        super().__init__(root, transforms)
+
+        verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both"))
+
+        root = Path(root) / "Sintel"
+        pass_names = {
+            "final": ["final"],
+            "clean": ["clean"],
+            "both": ["final", "clean"],
+        }[pass_name]
+
+        for p in pass_names:
+            left_img_pattern = str(root / "training" / f"{p}_left" / "*" / "*.png")
+            right_img_pattern = str(root / "training" / f"{p}_right" / "*" / "*.png")
+            self._images += self._scan_pairs(left_img_pattern, right_img_pattern)
+
+            disparity_pattern = str(root / "training" / "disparities" / "*" / "*.png")
+            self._disparities += self._scan_pairs(disparity_pattern, None)
+
+    def _get_occlusion_mask_paths(self, file_path: str) -> Tuple[str, str]:
+        # helper function to get the occlusion mask paths
+        # a path will look like  .../.../.../training/disparities/scene1/img1.png
+        # we want to get something like .../.../.../training/occlusions/scene1/img1.png
+        fpath = Path(file_path)
+        basename = fpath.name
+        scenedir = fpath.parent
+        # the parent of the scenedir is actually the disparity dir
+        sampledir = scenedir.parent.parent
+
+        occlusion_path = str(sampledir / "occlusions" / scenedir.name / basename)
+        outofframe_path = str(sampledir / "outofframe" / scenedir.name / basename)
+
+        if not os.path.exists(occlusion_path):
+            raise FileNotFoundError(f"Occlusion mask {occlusion_path} does not exist")
+
+        if not os.path.exists(outofframe_path):
+            raise FileNotFoundError(f"Out of frame mask {outofframe_path} does not exist")
+
+        return occlusion_path, outofframe_path
+
+    def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
+        if file_path is None:
+            return None, None
+
+        # disparity decoding as per Sintel instructions in the README provided with the dataset
+        disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
+        r, g, b = np.split(disparity_map, 3, axis=-1)
+        disparity_map = r * 4 + g / (2**6) + b / (2**14)
+        # reshape into (C, H, W) format
+        disparity_map = np.transpose(disparity_map, (2, 0, 1))
+        # find the appropriate file paths
+        occluded_mask_path, out_of_frame_mask_path = self._get_occlusion_mask_paths(file_path)
+        # occlusion masks
+        valid_mask = np.asarray(Image.open(occluded_mask_path)) == 0
+        # out of frame masks
+        off_mask = np.asarray(Image.open(out_of_frame_mask_path)) == 0
+        # combine the masks together
+        valid_mask = np.logical_and(off_mask, valid_mask)
+        return disparity_map, valid_mask
+
+    def __getitem__(self, index: int) -> T2:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images, while
+            ``valid_mask`` is a numpy boolean array of shape (H, W).
+        """
+        return cast(T2, super().__getitem__(index))
+
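A self-contained check of the RGB disparity decoding used in _read_disparity above (d = R*4 + G/2**6 + B/2**14), with a single made-up pixel:

import numpy as np

rgb = np.array([[[10, 32, 0]]], dtype=np.float32)  # one hypothetical pixel, shape (1, 1, 3)
r, g, b = np.split(rgb, 3, axis=-1)
disparity = r * 4 + g / (2**6) + b / (2**14)
disparity = np.transpose(disparity, (2, 0, 1))     # to (1, H, W)
print(disparity)  # [[[40.5]]] -> 10*4 + 32/64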
+
+class InStereo2k(StereoMatchingDataset):
+    """`InStereo2k <https://github.com/YuhuaXu/StereoDataset>`_ dataset.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            InStereo2k
+                train
+                    scene1
+                        left.png
+                        right.png
+                        left_disp.png
+                        right_disp.png
+                        ...
+                    scene2
+                    ...
+                test
+                    scene1
+                        left.png
+                        right.png
+                        left_disp.png
+                        right_disp.png
+                        ...
+                    scene2
+                    ...
+
+    Args:
+        root (string): Root directory where InStereo2k is located.
+        split (string): Either "train" or "test".
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+    """
+
+    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
+        super().__init__(root, transforms)
+
+        root = Path(root) / "InStereo2k" / split
+
+        verify_str_arg(split, "split", valid_values=("train", "test"))
+
+        left_img_pattern = str(root / "*" / "left.png")
+        right_img_pattern = str(root / "*" / "right.png")
+        self._images = self._scan_pairs(left_img_pattern, right_img_pattern)
+
+        left_disparity_pattern = str(root / "*" / "left_disp.png")
+        right_disparity_pattern = str(root / "*" / "right_disp.png")
+        self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
+
+    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
+        disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
+        # unsqueeze disparity to (C, H, W)
+        disparity_map = disparity_map[None, :, :] / 1024.0
+        valid_mask = None
+        return disparity_map, valid_mask
+
+    def __getitem__(self, index: int) -> T1:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            If a ``valid_mask`` is generated within the ``transforms`` parameter,
+            a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
+        """
+        return cast(T1, super().__getitem__(index))
+
+
+class ETH3DStereo(StereoMatchingDataset):
+    """ETH3D `Low-Res Two-View <https://www.eth3d.net/datasets>`_ dataset.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            ETH3D
+                two_view_training
+                    scene1
+                        im1.png
+                        im0.png
+                        images.txt
+                        cameras.txt
+                        calib.txt
+                    scene2
+                        im1.png
+                        im0.png
+                        images.txt
+                        cameras.txt
+                        calib.txt
+                    ...
+                two_view_training_gt
+                    scene1
+                        disp0GT.pfm
+                        mask0nocc.png
+                    scene2
+                        disp0GT.pfm
+                        mask0nocc.png
+                    ...
+                two_view_testing
+                    scene1
+                        im1.png
+                        im0.png
+                        images.txt
+                        cameras.txt
+                        calib.txt
+                    scene2
+                        im1.png
+                        im0.png
+                        images.txt
+                        cameras.txt
+                        calib.txt
+                    ...
+
+    Args:
+        root (string): Root directory of the ETH3D Dataset.
+        split (string, optional): The dataset split of scenes, either "train" (default) or "test".
+        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
+    """
+
+    _has_built_in_disparity_mask = True
+
+    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
+        super().__init__(root, transforms)
+
+        verify_str_arg(split, "split", valid_values=("train", "test"))
+
+        root = Path(root) / "ETH3D"
+
+        img_dir = "two_view_training" if split == "train" else "two_view_test"
+        anot_dir = "two_view_training_gt"
+
+        left_img_pattern = str(root / img_dir / "*" / "im0.png")
+        right_img_pattern = str(root / img_dir / "*" / "im1.png")
+        self._images = self._scan_pairs(left_img_pattern, right_img_pattern)
+
+        if split == "test":
+            self._disparities = list((None, None) for _ in self._images)
+        else:
+            disparity_pattern = str(root / anot_dir / "*" / "disp0GT.pfm")
+            self._disparities = self._scan_pairs(disparity_pattern, None)
+
+    def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
+        # test split has no disparity maps
+        if file_path is None:
+            return None, None
+
+        disparity_map = _read_pfm_file(file_path)
+        disparity_map = np.abs(disparity_map)  # ensure that the disparity is positive
+        mask_path = Path(file_path).parent / "mask0nocc.png"
+        valid_mask = Image.open(mask_path)
+        valid_mask = np.asarray(valid_mask).astype(bool)
+        return disparity_map, valid_mask
+
+    def __getitem__(self, index: int) -> T2:
+        """Return example at given index.
+
+        Args:
+            index(int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
+            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
+            ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
+            generate a valid mask.
+            Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
+        """
+        return cast(T2, super().__getitem__(index))
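
A minimal usage sketch for ETH3DStereo; the ``libs.vision_libs.datasets`` import path and the ``./data`` root are assumptions about how this vendored package is laid out, not part of the patch:

    # Usage sketch; adjust the import to wherever this vendored module lives.
    from libs.vision_libs.datasets import ETH3DStereo

    # Expects root/ETH3D/two_view_training/<scene>/im0.png|im1.png plus the
    # ground truth under root/ETH3D/two_view_training_gt/<scene>/.
    dataset = ETH3DStereo(root="./data", split="train")

    img_left, img_right, disparity, valid_mask = dataset[0]
    print(disparity.shape)   # (1, H, W) float32 disparity map
    print(valid_mask.dtype)  # boolean non-occlusion mask from mask0nocc.png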

+ 241 - 0
libs/vision_libs/datasets/caltech.py

@@ -0,0 +1,241 @@
+import os
+import os.path
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class Caltech101(VisionDataset):
+    """`Caltech 101 <https://data.caltech.edu/records/20086>`_ Dataset.
+
+    .. warning::
+
+        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
+
+    Args:
+        root (string): Root directory of dataset where directory
+            ``caltech101`` exists or will be saved to if download is set to True.
+        target_type (string or list, optional): Type of target to use, ``category`` or
+            ``annotation``. Can also be a list to output a tuple with all specified
+            target types.  ``category`` represents the target class, and
+            ``annotation`` is a list of points from a hand-generated outline.
+            Defaults to ``category``.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+            .. warning::
+
+                To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        target_type: Union[List[str], str] = "category",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform)
+        os.makedirs(self.root, exist_ok=True)
+        if isinstance(target_type, str):
+            target_type = [target_type]
+        self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) for t in target_type]
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
+        self.categories.remove("BACKGROUND_Google")  # this is not a real class
+
+        # For some reason, the category names in "101_ObjectCategories" and
+        # "Annotations" do not always match. This is a manual map between the
+        # two. Defaults to using same name, since most names are fine.
+        name_map = {
+            "Faces": "Faces_2",
+            "Faces_easy": "Faces_3",
+            "Motorbikes": "Motorbikes_16",
+            "airplanes": "Airplanes_Side_2",
+        }
+        self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories))
+
+        self.index: List[int] = []
+        self.y = []
+        for (i, c) in enumerate(self.categories):
+            n = len(os.listdir(os.path.join(self.root, "101_ObjectCategories", c)))
+            self.index.extend(range(1, n + 1))
+            self.y.extend(n * [i])
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where the type of target specified by target_type.
+        """
+        import scipy.io
+
+        img = Image.open(
+            os.path.join(
+                self.root,
+                "101_ObjectCategories",
+                self.categories[self.y[index]],
+                f"image_{self.index[index]:04d}.jpg",
+            )
+        )
+
+        target: Any = []
+        for t in self.target_type:
+            if t == "category":
+                target.append(self.y[index])
+            elif t == "annotation":
+                data = scipy.io.loadmat(
+                    os.path.join(
+                        self.root,
+                        "Annotations",
+                        self.annotation_categories[self.y[index]],
+                        f"annotation_{self.index[index]:04d}.mat",
+                    )
+                )
+                target.append(data["obj_contour"])
+        target = tuple(target) if len(target) > 1 else target[0]
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def _check_integrity(self) -> bool:
+        # can be more robust and check hash of files
+        return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))
+
+    def __len__(self) -> int:
+        return len(self.index)
+
+    def download(self) -> None:
+        if self._check_integrity():
+            print("Files already downloaded and verified")
+            return
+
+        download_and_extract_archive(
+            "https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp",
+            self.root,
+            filename="101_ObjectCategories.tar.gz",
+            md5="b224c7392d521a49829488ab0f1120d9",
+        )
+        download_and_extract_archive(
+            "https://drive.google.com/file/d/175kQy3UsZ0wUEHZjqkUDdNVssr7bgh_m",
+            self.root,
+            filename="Annotations.tar",
+            md5="6f83eeb1f24d99cab4eb377263132c91",
+        )
+
+    def extra_repr(self) -> str:
+        return "Target type: {target_type}".format(**self.__dict__)
+
+
+class Caltech256(VisionDataset):
+    """`Caltech 256 <https://data.caltech.edu/records/20087>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where directory
+            ``caltech256`` exists or will be saved to if download is set to True.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(os.path.join(root, "caltech256"), transform=transform, target_transform=target_transform)
+        os.makedirs(self.root, exist_ok=True)
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))
+        self.index: List[int] = []
+        self.y = []
+        for (i, c) in enumerate(self.categories):
+            n = len(
+                [
+                    item
+                    for item in os.listdir(os.path.join(self.root, "256_ObjectCategories", c))
+                    if item.endswith(".jpg")
+                ]
+            )
+            self.index.extend(range(1, n + 1))
+            self.y.extend(n * [i])
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target class.
+        """
+        img = Image.open(
+            os.path.join(
+                self.root,
+                "256_ObjectCategories",
+                self.categories[self.y[index]],
+                f"{self.y[index] + 1:03d}_{self.index[index]:04d}.jpg",
+            )
+        )
+
+        target = self.y[index]
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def _check_integrity(self) -> bool:
+        # can be more robust and check hash of files
+        return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))
+
+    def __len__(self) -> int:
+        return len(self.index)
+
+    def download(self) -> None:
+        if self._check_integrity():
+            print("Files already downloaded and verified")
+            return
+
+        download_and_extract_archive(
+            "https://drive.google.com/file/d/1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK",
+            self.root,
+            filename="256_ObjectCategories.tar",
+            md5="67b4f42ca05d46448c6bb8ecd2220f6d",
+        )
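
A usage sketch for the two Caltech classes above (the ``libs.vision_libs.datasets`` import path and ``./data`` root are assumptions):

    from libs.vision_libs.datasets import Caltech101, Caltech256

    # Caltech101 with both target types; scipy is needed for the .mat
    # annotation files, and gdown would be needed if download=True.
    c101 = Caltech101(root="./data", target_type=["category", "annotation"])
    img, (category, contour) = c101[0]

    # Caltech256 simply returns (PIL image, class index).
    c256 = Caltech256(root="./data")
    img256, label = c256[0]
    print(len(c101), len(c256), c256.categories[label])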

+ 193 - 0
libs/vision_libs/datasets/celeba.py

@@ -0,0 +1,193 @@
+import csv
+import os
+from collections import namedtuple
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import PIL
+import torch
+
+from .utils import check_integrity, download_file_from_google_drive, extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+CSV = namedtuple("CSV", ["header", "index", "data"])
+
+
+class CelebA(VisionDataset):
+    """`Large-scale CelebFaces Attributes (CelebA) <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.
+
+    Args:
+        root (string): Root directory where images are downloaded to.
+        split (string): One of {'train', 'valid', 'test', 'all'}.
+            The corresponding subset of the dataset is selected.
+        target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
+            or ``landmarks``. Can also be a list to output a tuple with all specified target types.
+            The targets represent:
+
+                - ``attr`` (Tensor shape=(40,) dtype=int): binary (0, 1) labels for attributes
+                - ``identity`` (int): label for each person (data points with the same identity are the same person)
+                - ``bbox`` (Tensor shape=(4,) dtype=int): bounding box (x, y, width, height)
+                - ``landmarks`` (Tensor shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
+                  righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)
+
+            Defaults to ``attr``. If empty, ``None`` will be returned as target.
+
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.PILToTensor``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+            .. warning::
+
+                To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
+    """
+
+    base_folder = "celeba"
+    # There currently does not appear to be an easy way to extract 7z in python (without introducing additional
+    # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
+    # right now.
+    file_list = [
+        # File ID                                      MD5 Hash                            Filename
+        ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
+        # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc","b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
+        # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
+        ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
+        ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
+        ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
+        ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
+        # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
+        ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
+    ]
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        target_type: Union[List[str], str] = "attr",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self.split = split
+        if isinstance(target_type, list):
+            self.target_type = target_type
+        else:
+            self.target_type = [target_type]
+
+        if not self.target_type and self.target_transform is not None:
+            raise RuntimeError("target_transform is specified but target_type is empty")
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        split_map = {
+            "train": 0,
+            "valid": 1,
+            "test": 2,
+            "all": None,
+        }
+        split_ = split_map[verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))]
+        splits = self._load_csv("list_eval_partition.txt")
+        identity = self._load_csv("identity_CelebA.txt")
+        bbox = self._load_csv("list_bbox_celeba.txt", header=1)
+        landmarks_align = self._load_csv("list_landmarks_align_celeba.txt", header=1)
+        attr = self._load_csv("list_attr_celeba.txt", header=1)
+
+        mask = slice(None) if split_ is None else (splits.data == split_).squeeze()
+
+        if mask == slice(None):  # if split == "all"
+            self.filename = splits.index
+        else:
+            self.filename = [splits.index[i] for i in torch.squeeze(torch.nonzero(mask))]
+        self.identity = identity.data[mask]
+        self.bbox = bbox.data[mask]
+        self.landmarks_align = landmarks_align.data[mask]
+        self.attr = attr.data[mask]
+        # map from {-1, 1} to {0, 1}
+        self.attr = torch.div(self.attr + 1, 2, rounding_mode="floor")
+        self.attr_names = attr.header
+
+    def _load_csv(
+        self,
+        filename: str,
+        header: Optional[int] = None,
+    ) -> CSV:
+        with open(os.path.join(self.root, self.base_folder, filename)) as csv_file:
+            data = list(csv.reader(csv_file, delimiter=" ", skipinitialspace=True))
+
+        if header is not None:
+            headers = data[header]
+            data = data[header + 1 :]
+        else:
+            headers = []
+
+        indices = [row[0] for row in data]
+        data = [row[1:] for row in data]
+        data_int = [list(map(int, i)) for i in data]
+
+        return CSV(headers, indices, torch.tensor(data_int))
+
+    def _check_integrity(self) -> bool:
+        for (_, md5, filename) in self.file_list:
+            fpath = os.path.join(self.root, self.base_folder, filename)
+            _, ext = os.path.splitext(filename)
+            # Allow original archive to be deleted (zip and 7z)
+            # Only need the extracted images
+            if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
+                return False
+
+        # Should check a hash of the images
+        return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))
+
+    def download(self) -> None:
+        if self._check_integrity():
+            print("Files already downloaded and verified")
+            return
+
+        for (file_id, md5, filename) in self.file_list:
+            download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)
+
+        extract_archive(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"))
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))
+
+        target: Any = []
+        for t in self.target_type:
+            if t == "attr":
+                target.append(self.attr[index, :])
+            elif t == "identity":
+                target.append(self.identity[index, 0])
+            elif t == "bbox":
+                target.append(self.bbox[index, :])
+            elif t == "landmarks":
+                target.append(self.landmarks_align[index, :])
+            else:
+                # TODO: refactor with utils.verify_str_arg
+                raise ValueError(f'Target type "{t}" is not recognized.')
+
+        if self.transform is not None:
+            X = self.transform(X)
+
+        if target:
+            target = tuple(target) if len(target) > 1 else target[0]
+
+            if self.target_transform is not None:
+                target = self.target_transform(target)
+        else:
+            target = None
+
+        return X, target
+
+    def __len__(self) -> int:
+        return len(self.attr)
+
+    def extra_repr(self) -> str:
+        lines = ["Target type: {target_type}", "Split: {split}"]
+        return "\n".join(lines).format(**self.__dict__)
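
A usage sketch showing how multiple CelebA target types are returned (assumed ``libs.vision_libs.datasets`` import path and ``./data`` root):

    from libs.vision_libs.datasets import CelebA

    celeba = CelebA(root="./data", split="valid",
                    target_type=["attr", "identity", "bbox"])
    img, (attr, identity, bbox) = celeba[0]
    print(attr.shape)     # torch.Size([40]); attributes remapped from {-1, 1} to {0, 1}
    print(int(identity))  # person id
    print(bbox.tolist())  # [x, y, width, height]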

+ 167 - 0
libs/vision_libs/datasets/cifar.py

@@ -0,0 +1,167 @@
+import os.path
+import pickle
+from typing import Any, Callable, Optional, Tuple
+
+import numpy as np
+from PIL import Image
+
+from .utils import check_integrity, download_and_extract_archive
+from .vision import VisionDataset
+
+
+class CIFAR10(VisionDataset):
+    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where directory
+            ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
+        train (bool, optional): If True, creates dataset from training set, otherwise
+            creates from test set.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+    """
+
+    base_folder = "cifar-10-batches-py"
+    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
+    filename = "cifar-10-python.tar.gz"
+    tgz_md5 = "c58f30108f718f92721af3b95e74349a"
+    train_list = [
+        ["data_batch_1", "c99cafc152244af753f735de768cd75f"],
+        ["data_batch_2", "d4bba439e000b95fd0a9bffe97cbabec"],
+        ["data_batch_3", "54ebc095f3ab1f0389bbae665268c751"],
+        ["data_batch_4", "634d18415352ddfa80567beed471001a"],
+        ["data_batch_5", "482c414d41f54cd18b22e5b47cb7c3cb"],
+    ]
+
+    test_list = [
+        ["test_batch", "40351d587109b95175f43aff81a1287e"],
+    ]
+    meta = {
+        "filename": "batches.meta",
+        "key": "label_names",
+        "md5": "5ff9c542aee3614f3951f8cda6e48888",
+    }
+
+    def __init__(
+        self,
+        root: str,
+        train: bool = True,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        self.train = train  # training set or test set
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        if self.train:
+            downloaded_list = self.train_list
+        else:
+            downloaded_list = self.test_list
+
+        self.data: Any = []
+        self.targets = []
+
+        # now load the picked numpy arrays
+        for file_name, checksum in downloaded_list:
+            file_path = os.path.join(self.root, self.base_folder, file_name)
+            with open(file_path, "rb") as f:
+                entry = pickle.load(f, encoding="latin1")
+                self.data.append(entry["data"])
+                if "labels" in entry:
+                    self.targets.extend(entry["labels"])
+                else:
+                    self.targets.extend(entry["fine_labels"])
+
+        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
+        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
+
+        self._load_meta()
+
+    def _load_meta(self) -> None:
+        path = os.path.join(self.root, self.base_folder, self.meta["filename"])
+        if not check_integrity(path, self.meta["md5"]):
+            raise RuntimeError("Dataset metadata file not found or corrupted. You can use download=True to download it")
+        with open(path, "rb") as infile:
+            data = pickle.load(infile, encoding="latin1")
+            self.classes = data[self.meta["key"]]
+        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target class.
+        """
+        img, target = self.data[index], self.targets[index]
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(img)
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def _check_integrity(self) -> bool:
+        for filename, md5 in self.train_list + self.test_list:
+            fpath = os.path.join(self.root, self.base_folder, filename)
+            if not check_integrity(fpath, md5):
+                return False
+        return True
+
+    def download(self) -> None:
+        if self._check_integrity():
+            print("Files already downloaded and verified")
+            return
+        download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
+
+    def extra_repr(self) -> str:
+        split = "Train" if self.train is True else "Test"
+        return f"Split: {split}"
+
+
+class CIFAR100(CIFAR10):
+    """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
+
+    This is a subclass of the `CIFAR10` Dataset.
+    """
+
+    base_folder = "cifar-100-python"
+    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
+    filename = "cifar-100-python.tar.gz"
+    tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85"
+    train_list = [
+        ["train", "16019d7e3df5f24257cddd939b257f8d"],
+    ]
+
+    test_list = [
+        ["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"],
+    ]
+    meta = {
+        "filename": "meta",
+        "key": "fine_label_names",
+        "md5": "7973b15100ade9c7d40fb424638fde48",
+    }
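
A usage sketch for the CIFAR classes (assumed ``libs.vision_libs.datasets`` import path and ``./data`` root):

    from libs.vision_libs.datasets import CIFAR10, CIFAR100

    train_set = CIFAR10(root="./data", train=True, download=True)
    test_set = CIFAR100(root="./data", train=False, download=True)

    img, label = train_set[0]        # 32x32 RGB PIL image and class index
    print(train_set.classes[label])  # readable name loaded from batches.meta
    print(len(train_set), len(test_set))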

+ 221 - 0
libs/vision_libs/datasets/cityscapes.py

@@ -0,0 +1,221 @@
+import json
+import os
+from collections import namedtuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+from PIL import Image
+
+from .utils import extract_archive, iterable_to_str, verify_str_arg
+from .vision import VisionDataset
+
+
+class Cityscapes(VisionDataset):
+    """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where directory ``leftImg8bit``
+            and ``gtFine`` or ``gtCoarse`` are located.
+        split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
+            otherwise ``train``, ``train_extra`` or ``val``
+        mode (string, optional): The quality mode to use, ``fine`` or ``coarse``
+        target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
+            or ``color``. Can also be a list to output a tuple with all specified target types.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        transforms (callable, optional): A function/transform that takes input sample and its target as entry
+            and returns a transformed version.
+
+    Examples:
+
+        Get semantic segmentation target
+
+        .. code-block:: python
+
+            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
+                                 target_type='semantic')
+
+            img, smnt = dataset[0]
+
+        Get multiple targets
+
+        .. code-block:: python
+
+            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
+                                 target_type=['instance', 'color', 'polygon'])
+
+            img, (inst, col, poly) = dataset[0]
+
+        Validate on the "coarse" set
+
+        .. code-block:: python
+
+            dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
+                                 target_type='semantic')
+
+            img, smnt = dataset[0]
+    """
+
+    # Based on https://github.com/mcordts/cityscapesScripts
+    CityscapesClass = namedtuple(
+        "CityscapesClass",
+        ["name", "id", "train_id", "category", "category_id", "has_instances", "ignore_in_eval", "color"],
+    )
+
+    classes = [
+        CityscapesClass("unlabeled", 0, 255, "void", 0, False, True, (0, 0, 0)),
+        CityscapesClass("ego vehicle", 1, 255, "void", 0, False, True, (0, 0, 0)),
+        CityscapesClass("rectification border", 2, 255, "void", 0, False, True, (0, 0, 0)),
+        CityscapesClass("out of roi", 3, 255, "void", 0, False, True, (0, 0, 0)),
+        CityscapesClass("static", 4, 255, "void", 0, False, True, (0, 0, 0)),
+        CityscapesClass("dynamic", 5, 255, "void", 0, False, True, (111, 74, 0)),
+        CityscapesClass("ground", 6, 255, "void", 0, False, True, (81, 0, 81)),
+        CityscapesClass("road", 7, 0, "flat", 1, False, False, (128, 64, 128)),
+        CityscapesClass("sidewalk", 8, 1, "flat", 1, False, False, (244, 35, 232)),
+        CityscapesClass("parking", 9, 255, "flat", 1, False, True, (250, 170, 160)),
+        CityscapesClass("rail track", 10, 255, "flat", 1, False, True, (230, 150, 140)),
+        CityscapesClass("building", 11, 2, "construction", 2, False, False, (70, 70, 70)),
+        CityscapesClass("wall", 12, 3, "construction", 2, False, False, (102, 102, 156)),
+        CityscapesClass("fence", 13, 4, "construction", 2, False, False, (190, 153, 153)),
+        CityscapesClass("guard rail", 14, 255, "construction", 2, False, True, (180, 165, 180)),
+        CityscapesClass("bridge", 15, 255, "construction", 2, False, True, (150, 100, 100)),
+        CityscapesClass("tunnel", 16, 255, "construction", 2, False, True, (150, 120, 90)),
+        CityscapesClass("pole", 17, 5, "object", 3, False, False, (153, 153, 153)),
+        CityscapesClass("polegroup", 18, 255, "object", 3, False, True, (153, 153, 153)),
+        CityscapesClass("traffic light", 19, 6, "object", 3, False, False, (250, 170, 30)),
+        CityscapesClass("traffic sign", 20, 7, "object", 3, False, False, (220, 220, 0)),
+        CityscapesClass("vegetation", 21, 8, "nature", 4, False, False, (107, 142, 35)),
+        CityscapesClass("terrain", 22, 9, "nature", 4, False, False, (152, 251, 152)),
+        CityscapesClass("sky", 23, 10, "sky", 5, False, False, (70, 130, 180)),
+        CityscapesClass("person", 24, 11, "human", 6, True, False, (220, 20, 60)),
+        CityscapesClass("rider", 25, 12, "human", 6, True, False, (255, 0, 0)),
+        CityscapesClass("car", 26, 13, "vehicle", 7, True, False, (0, 0, 142)),
+        CityscapesClass("truck", 27, 14, "vehicle", 7, True, False, (0, 0, 70)),
+        CityscapesClass("bus", 28, 15, "vehicle", 7, True, False, (0, 60, 100)),
+        CityscapesClass("caravan", 29, 255, "vehicle", 7, True, True, (0, 0, 90)),
+        CityscapesClass("trailer", 30, 255, "vehicle", 7, True, True, (0, 0, 110)),
+        CityscapesClass("train", 31, 16, "vehicle", 7, True, False, (0, 80, 100)),
+        CityscapesClass("motorcycle", 32, 17, "vehicle", 7, True, False, (0, 0, 230)),
+        CityscapesClass("bicycle", 33, 18, "vehicle", 7, True, False, (119, 11, 32)),
+        CityscapesClass("license plate", -1, -1, "vehicle", 7, False, True, (0, 0, 142)),
+    ]
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        mode: str = "fine",
+        target_type: Union[List[str], str] = "instance",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        transforms: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transforms, transform, target_transform)
+        self.mode = "gtFine" if mode == "fine" else "gtCoarse"
+        self.images_dir = os.path.join(self.root, "leftImg8bit", split)
+        self.targets_dir = os.path.join(self.root, self.mode, split)
+        self.target_type = target_type
+        self.split = split
+        self.images = []
+        self.targets = []
+
+        verify_str_arg(mode, "mode", ("fine", "coarse"))
+        if mode == "fine":
+            valid_modes = ("train", "test", "val")
+        else:
+            valid_modes = ("train", "train_extra", "val")
+        msg = "Unknown value '{}' for argument split if mode is '{}'. Valid values are {{{}}}."
+        msg = msg.format(split, mode, iterable_to_str(valid_modes))
+        verify_str_arg(split, "split", valid_modes, msg)
+
+        if not isinstance(target_type, list):
+            self.target_type = [target_type]
+        [
+            verify_str_arg(value, "target_type", ("instance", "semantic", "polygon", "color"))
+            for value in self.target_type
+        ]
+
+        if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
+
+            if split == "train_extra":
+                image_dir_zip = os.path.join(self.root, "leftImg8bit_trainextra.zip")
+            else:
+                image_dir_zip = os.path.join(self.root, "leftImg8bit_trainvaltest.zip")
+
+            if self.mode == "gtFine":
+                target_dir_zip = os.path.join(self.root, f"{self.mode}_trainvaltest.zip")
+            elif self.mode == "gtCoarse":
+                target_dir_zip = os.path.join(self.root, f"{self.mode}.zip")
+
+            if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip):
+                extract_archive(from_path=image_dir_zip, to_path=self.root)
+                extract_archive(from_path=target_dir_zip, to_path=self.root)
+            else:
+                raise RuntimeError(
+                    "Dataset not found or incomplete. Please make sure all required folders for the"
+                    ' specified "split" and "mode" are inside the "root" directory'
+                )
+
+        for city in os.listdir(self.images_dir):
+            img_dir = os.path.join(self.images_dir, city)
+            target_dir = os.path.join(self.targets_dir, city)
+            for file_name in os.listdir(img_dir):
+                target_types = []
+                for t in self.target_type:
+                    target_name = "{}_{}".format(
+                        file_name.split("_leftImg8bit")[0], self._get_target_suffix(self.mode, t)
+                    )
+                    target_types.append(os.path.join(target_dir, target_name))
+
+                self.images.append(os.path.join(img_dir, file_name))
+                self.targets.append(target_types)
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+        Returns:
+            tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
+            than one item. Otherwise, target is a json object if target_type="polygon", else the image segmentation.
+        """
+
+        image = Image.open(self.images[index]).convert("RGB")
+
+        targets: Any = []
+        for i, t in enumerate(self.target_type):
+            if t == "polygon":
+                target = self._load_json(self.targets[index][i])
+            else:
+                target = Image.open(self.targets[index][i])
+
+            targets.append(target)
+
+        target = tuple(targets) if len(targets) > 1 else targets[0]
+
+        if self.transforms is not None:
+            image, target = self.transforms(image, target)
+
+        return image, target
+
+    def __len__(self) -> int:
+        return len(self.images)
+
+    def extra_repr(self) -> str:
+        lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"]
+        return "\n".join(lines).format(**self.__dict__)
+
+    def _load_json(self, path: str) -> Dict[str, Any]:
+        with open(path) as file:
+            data = json.load(file)
+        return data
+
+    def _get_target_suffix(self, mode: str, target_type: str) -> str:
+        if target_type == "instance":
+            return f"{mode}_instanceIds.png"
+        elif target_type == "semantic":
+            return f"{mode}_labelIds.png"
+        elif target_type == "color":
+            return f"{mode}_color.png"
+        else:
+            return f"{mode}_polygons.json"

+ 88 - 0
libs/vision_libs/datasets/clevr.py

@@ -0,0 +1,88 @@
+import json
+import pathlib
+from typing import Any, Callable, List, Optional, Tuple
+from urllib.parse import urlparse
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class CLEVRClassification(VisionDataset):
+    """`CLEVR <https://cs.stanford.edu/people/jcjohns/clevr/>`_  classification dataset.
+
+    The number of objects in a scene are used as label.
+
+    Args:
+        root (string): Root directory of dataset where directory ``root/clevr`` exists or will be saved to if download is
+            set to True.
+        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
+        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+            version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If
+            dataset is already downloaded, it is not downloaded again.
+    """
+
+    _URL = "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip"
+    _MD5 = "b11922020e72d0cd9154779b2d3d07d2"
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self._base_folder = pathlib.Path(self.root) / "clevr"
+        self._data_folder = self._base_folder / pathlib.Path(urlparse(self._URL).path).stem
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        self._image_files = sorted(self._data_folder.joinpath("images", self._split).glob("*"))
+
+        self._labels: List[Optional[int]]
+        if self._split != "test":
+            with open(self._data_folder / "scenes" / f"CLEVR_{self._split}_scenes.json") as file:
+                content = json.load(file)
+            num_objects = {scene["image_filename"]: len(scene["objects"]) for scene in content["scenes"]}
+            self._labels = [num_objects[image_file.name] for image_file in self._image_files]
+        else:
+            self._labels = [None] * len(self._image_files)
+
+    def __len__(self) -> int:
+        return len(self._image_files)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        image_file = self._image_files[idx]
+        label = self._labels[idx]
+
+        image = Image.open(image_file).convert("RGB")
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.target_transform:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def _check_exists(self) -> bool:
+        return self._data_folder.exists() and self._data_folder.is_dir()
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+
+        download_and_extract_archive(self._URL, str(self._base_folder), md5=self._MD5)
+
+    def extra_repr(self) -> str:
+        return f"split={self._split}"

+ 104 - 0
libs/vision_libs/datasets/coco.py

@@ -0,0 +1,104 @@
+import os.path
+from typing import Any, Callable, List, Optional, Tuple
+
+from PIL import Image
+
+from .vision import VisionDataset
+
+
+class CocoDetection(VisionDataset):
+    """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.
+
+    It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.
+
+    Args:
+        root (string): Root directory where images are downloaded to.
+        annFile (string): Path to json annotation file.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.PILToTensor``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        transforms (callable, optional): A function/transform that takes input sample and its target as entry
+            and returns a transformed version.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        annFile: str,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        transforms: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transforms, transform, target_transform)
+        from pycocotools.coco import COCO
+
+        self.coco = COCO(annFile)
+        self.ids = list(sorted(self.coco.imgs.keys()))
+
+    def _load_image(self, id: int) -> Image.Image:
+        path = self.coco.loadImgs(id)[0]["file_name"]
+        return Image.open(os.path.join(self.root, path)).convert("RGB")
+
+    def _load_target(self, id: int) -> List[Any]:
+        return self.coco.loadAnns(self.coco.getAnnIds(id))
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        id = self.ids[index]
+        image = self._load_image(id)
+        target = self._load_target(id)
+
+        if self.transforms is not None:
+            image, target = self.transforms(image, target)
+
+        return image, target
+
+    def __len__(self) -> int:
+        return len(self.ids)
+
+
+class CocoCaptions(CocoDetection):
+    """`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.
+
+    It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.
+
+    Args:
+        root (string): Root directory where images are downloaded to.
+        annFile (string): Path to json annotation file.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.PILToTensor``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        transforms (callable, optional): A function/transform that takes input sample and its target as entry
+            and returns a transformed version.
+
+    Example:
+
+        .. code:: python
+
+            import torchvision.datasets as dset
+            import torchvision.transforms as transforms
+            cap = dset.CocoCaptions(root = 'dir where images are',
+                                    annFile = 'json annotation file',
+                                    transform=transforms.PILToTensor())
+
+            print('Number of samples: ', len(cap))
+            img, target = cap[3] # load 4th sample
+
+            print("Image Size: ", img.size())
+            print(target)
+
+        Output: ::
+
+            Number of samples: 82783
+            Image Size: (3L, 427L, 640L)
+            [u'A plane emitting smoke stream flying over a mountain.',
+            u'A plane darts across a bright blue sky behind a mountain covered in snow',
+            u'A plane leaves a contrail above the snowy mountain top.',
+            u'A mountain that has a plane flying overheard in the distance.',
+            u'A mountain view with a plume of smoke in the background']
+
+    """
+
+    def _load_target(self, id: int) -> List[str]:
+        return [ann["caption"] for ann in super()._load_target(id)]
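
A usage sketch for CocoDetection; pycocotools and a local COCO copy are required, and the import path and file paths below are placeholders:

    from libs.vision_libs.datasets import CocoDetection

    coco = CocoDetection(root="./coco/val2017",
                         annFile="./coco/annotations/instances_val2017.json")
    img, anns = coco[0]
    # `anns` is the raw list of COCO annotation dicts for the image, each with
    # fields such as "bbox" ([x, y, w, h]) and "category_id".
    if anns:
        print(anns[0]["bbox"], anns[0]["category_id"])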

+ 58 - 0
libs/vision_libs/datasets/country211.py

@@ -0,0 +1,58 @@
+from pathlib import Path
+from typing import Callable, Optional
+
+from .folder import ImageFolder
+from .utils import download_and_extract_archive, verify_str_arg
+
+
+class Country211(ImageFolder):
+    """`The Country211 Data Set <https://github.com/openai/CLIP/blob/main/data/country211.md>`_ from OpenAI.
+
+    This dataset was built by filtering the images from the YFCC100m dataset
+    that have GPS coordinates corresponding to an ISO-3166 country code. The
+    dataset is balanced by sampling 150 train images, 50 validation images, and
+    100 test images for each country.
+
+    Args:
+        root (string): Root directory of the dataset.
+        split (string, optional): The dataset split, supports ``"train"`` (default), ``"valid"`` and ``"test"``.
+        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+            version. E.g., ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and puts it into
+            ``root/country211/``. If dataset is already downloaded, it is not downloaded again.
+    """
+
+    _URL = "https://openaipublic.azureedge.net/clip/data/country211.tgz"
+    _MD5 = "84988d7644798601126c29e9877aab6a"
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        self._split = verify_str_arg(split, "split", ("train", "valid", "test"))
+
+        root = Path(root).expanduser()
+        self.root = str(root)
+        self._base_folder = root / "country211"
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        super().__init__(str(self._base_folder / self._split), transform=transform, target_transform=target_transform)
+        self.root = str(root)
+
+    def _check_exists(self) -> bool:
+        return self._base_folder.exists() and self._base_folder.is_dir()
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+        download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)
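
A usage sketch for Country211 (assumed ``libs.vision_libs.datasets`` import path and ``./data`` root):

    from libs.vision_libs.datasets import Country211

    # Country211 is an ImageFolder over root/country211/<split>/<country>/...,
    # so it exposes the usual `classes` and `class_to_idx` attributes.
    train = Country211(root="./data", split="train", download=True)
    img, label = train[0]
    print(len(train), train.classes[label])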

+ 100 - 0
libs/vision_libs/datasets/dtd.py

@@ -0,0 +1,100 @@
+import os
+import pathlib
+from typing import Any, Callable, Optional, Tuple
+
+import PIL.Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class DTD(VisionDataset):
+    """`Describable Textures Dataset (DTD) <https://www.robots.ox.ac.uk/~vgg/data/dtd/>`_.
+
+    Args:
+        root (string): Root directory of the dataset.
+        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
+        partition (int, optional): The dataset partition. Should be ``1 <= partition <= 10``. Defaults to ``1``.
+
+            .. note::
+
+                The partition only changes which split each image belongs to. Thus, regardless of the selected
+                partition, combining all splits will result in all images.
+
+        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+            version. E.g., ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again. Default is False.
+    """
+
+    _URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz"
+    _MD5 = "fff73e5086ae6bdbea199a49dfb8a4c1"
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        partition: int = 1,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
+        if not isinstance(partition, int) or not (1 <= partition <= 10):
+            raise ValueError(
+                f"Parameter 'partition' should be an integer with `1 <= partition <= 10`, "
+                f"but got {partition} instead"
+            )
+        self._partition = partition
+
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self._base_folder = pathlib.Path(self.root) / type(self).__name__.lower()
+        self._data_folder = self._base_folder / "dtd"
+        self._meta_folder = self._data_folder / "labels"
+        self._images_folder = self._data_folder / "images"
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        self._image_files = []
+        classes = []
+        with open(self._meta_folder / f"{self._split}{self._partition}.txt") as file:
+            for line in file:
+                cls, name = line.strip().split("/")
+                self._image_files.append(self._images_folder.joinpath(cls, name))
+                classes.append(cls)
+
+        self.classes = sorted(set(classes))
+        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
+        self._labels = [self.class_to_idx[cls] for cls in classes]
+
+    def __len__(self) -> int:
+        return len(self._image_files)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        image_file, label = self._image_files[idx], self._labels[idx]
+        image = PIL.Image.open(image_file).convert("RGB")
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.target_transform:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def extra_repr(self) -> str:
+        return f"split={self._split}, partition={self._partition}"
+
+    def _check_exists(self) -> bool:
+        return os.path.exists(self._data_folder) and os.path.isdir(self._data_folder)
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+        download_and_extract_archive(self._URL, download_root=str(self._base_folder), md5=self._MD5)
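
A usage sketch for DTD (assumed ``libs.vision_libs.datasets`` import path and ``./data`` root):

    from libs.vision_libs.datasets import DTD

    # Partition 1 of the train split; the ten partitions only reshuffle which
    # images fall into train/val/test.
    dtd = DTD(root="./data", split="train", partition=1, download=True)
    img, label = dtd[0]
    print(dtd.classes[label], len(dtd))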

+ 58 - 0
libs/vision_libs/datasets/eurosat.py

@@ -0,0 +1,58 @@
+import os
+from typing import Callable, Optional
+
+from .folder import ImageFolder
+from .utils import download_and_extract_archive
+
+
+class EuroSAT(ImageFolder):
+    """RGB version of the `EuroSAT <https://github.com/phelber/eurosat>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where ``root/eurosat`` exists.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again. Default is False.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        self.root = os.path.expanduser(root)
+        self._base_folder = os.path.join(self.root, "eurosat")
+        self._data_folder = os.path.join(self._base_folder, "2750")
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        super().__init__(self._data_folder, transform=transform, target_transform=target_transform)
+        self.root = os.path.expanduser(root)
+
+    def __len__(self) -> int:
+        return len(self.samples)
+
+    def _check_exists(self) -> bool:
+        return os.path.exists(self._data_folder)
+
+    def download(self) -> None:
+
+        if self._check_exists():
+            return
+
+        os.makedirs(self._base_folder, exist_ok=True)
+        download_and_extract_archive(
+            "https://madm.dfki.de/files/sentinel/EuroSAT.zip",
+            download_root=self._base_folder,
+            md5="c8fa014336c82ac7804f0398fcb19387",
+        )
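
A usage sketch for EuroSAT (assumed ``libs.vision_libs.datasets`` import path and ``./data`` root):

    from libs.vision_libs.datasets import EuroSAT

    eurosat = EuroSAT(root="./data", download=True)
    img, label = eurosat[0]  # RGB patch and land-use class index
    print(eurosat.classes[label], len(eurosat))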

+ 67 - 0
libs/vision_libs/datasets/fakedata.py

@@ -0,0 +1,67 @@
+from typing import Any, Callable, Optional, Tuple
+
+import torch
+
+from .. import transforms
+from .vision import VisionDataset
+
+
+class FakeData(VisionDataset):
+    """A fake dataset that generates random images and returns them as PIL images.
+
+    Args:
+        size (int, optional): Size of the dataset. Default: 1000 images
+        image_size(tuple, optional): Size of the returned images. Default: (3, 224, 224)
+        num_classes(int, optional): Number of classes in the dataset. Default: 10
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        random_offset (int): Offsets the index-based random seed used to
+            generate each image. Default: 0
+
+    """
+
+    def __init__(
+        self,
+        size: int = 1000,
+        image_size: Tuple[int, int, int] = (3, 224, 224),
+        num_classes: int = 10,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        random_offset: int = 0,
+    ) -> None:
+        super().__init__(transform=transform, target_transform=target_transform)
+        self.size = size
+        self.num_classes = num_classes
+        self.image_size = image_size
+        self.random_offset = random_offset
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is class_index of the target class.
+        """
+        # create random image that is consistent with the index id
+        if index >= len(self):
+            raise IndexError(f"{self.__class__.__name__} index out of range")
+        rng_state = torch.get_rng_state()
+        torch.manual_seed(index + self.random_offset)
+        img = torch.randn(*self.image_size)
+        target = torch.randint(0, self.num_classes, size=(1,), dtype=torch.long)[0]
+        torch.set_rng_state(rng_state)
+
+        # convert to PIL Image
+        img = transforms.ToPILImage()(img)
+        if self.transform is not None:
+            img = self.transform(img)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target.item()
+
+    def __len__(self) -> int:
+        return self.size
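
FakeData is convenient for smoke-testing a training loop without any files on disk; a sketch under assumed import paths for this vendored copy:

    from libs.vision_libs import transforms
    from libs.vision_libs.datasets import FakeData

    fake = FakeData(size=16, image_size=(3, 64, 64), num_classes=5,
                    transform=transforms.ToTensor())
    img, label = fake[0]
    print(img.shape, label)  # torch.Size([3, 64, 64]) and an int in [0, 5)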

+ 75 - 0
libs/vision_libs/datasets/fer2013.py

@@ -0,0 +1,75 @@
+import csv
+import pathlib
+from typing import Any, Callable, Optional, Tuple
+
+import torch
+from PIL import Image
+
+from .utils import check_integrity, verify_str_arg
+from .vision import VisionDataset
+
+
+class FER2013(VisionDataset):
+    """`FER2013
+    <https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where directory
+            ``root/fer2013`` exists.
+        split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
+        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+            version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+    """
+
+    _RESOURCES = {
+        "train": ("train.csv", "3f0dfb3d3fd99c811a1299cb947e3131"),
+        "test": ("test.csv", "b02c2298636a634e8c2faabbf3ea9a23"),
+    }
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+    ) -> None:
+        self._split = verify_str_arg(split, "split", self._RESOURCES.keys())
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        base_folder = pathlib.Path(self.root) / "fer2013"
+        file_name, md5 = self._RESOURCES[self._split]
+        data_file = base_folder / file_name
+        if not check_integrity(str(data_file), md5=md5):
+            raise RuntimeError(
+                f"{file_name} not found in {base_folder} or corrupted. "
+                f"You can download it from "
+                f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
+            )
+
+        with open(data_file, "r", newline="") as file:
+            self._samples = [
+                (
+                    torch.tensor([int(idx) for idx in row["pixels"].split()], dtype=torch.uint8).reshape(48, 48),
+                    int(row["emotion"]) if "emotion" in row else None,
+                )
+                for row in csv.DictReader(file)
+            ]
+
+    def __len__(self) -> int:
+        return len(self._samples)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        image_tensor, target = self._samples[idx]
+        image = Image.fromarray(image_tensor.numpy())
+
+        if self.transform is not None:
+            image = self.transform(image)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return image, target
+
+    def extra_repr(self) -> str:
+        return f"split={self._split}"
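
A minimal usage sketch for the FER2013 class above. The ./data root is hypothetical and assumes train.csv was downloaded manually from Kaggle into ./data/fer2013/; the import path assumes the repository root is on sys.path.

from torchvision import transforms

from libs.vision_libs.datasets.fer2013 import FER2013  # assumed import path

# Expects ./data/fer2013/train.csv (downloaded manually from Kaggle).
ds = FER2013(root="./data", split="train", transform=transforms.ToTensor())

img, label = ds[0]   # img: 1x48x48 float tensor, label: emotion index 0-6
print(len(ds), img.shape, label)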

+ 114 - 0
libs/vision_libs/datasets/fgvc_aircraft.py

@@ -0,0 +1,114 @@
+from __future__ import annotations
+
+import os
+from typing import Any, Callable, Optional, Tuple
+
+import PIL.Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class FGVCAircraft(VisionDataset):
+    """`FGVC Aircraft <https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/>`_ Dataset.
+
+    The dataset contains 10,000 images of aircraft, with 100 images for each of 100
+    different aircraft model variants, most of which are airplanes.
+    Aircraft models are organized in a three-level hierarchy. The three levels, from
+    finer to coarser, are:
+
+    - ``variant``, e.g. Boeing 737-700. A variant collapses all the models that are visually
+        indistinguishable into one class. The dataset comprises 100 different variants.
+    - ``family``, e.g. Boeing 737. The dataset comprises 70 different families.
+    - ``manufacturer``, e.g. Boeing. The dataset comprises 30 different manufacturers.
+
+    Args:
+        root (string): Root directory of the FGVC Aircraft dataset.
+        split (string, optional): The dataset split, supports ``train``, ``val``,
+            ``trainval`` and ``test``.
+        annotation_level (str, optional): The annotation level, supports ``variant``,
+            ``family`` and ``manufacturer``.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in the root directory. If the dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    _URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz"
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "trainval",
+        annotation_level: str = "variant",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test"))
+        self._annotation_level = verify_str_arg(
+            annotation_level, "annotation_level", ("variant", "family", "manufacturer")
+        )
+
+        self._data_path = os.path.join(self.root, "fgvc-aircraft-2013b")
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        annotation_file = os.path.join(
+            self._data_path,
+            "data",
+            {
+                "variant": "variants.txt",
+                "family": "families.txt",
+                "manufacturer": "manufacturers.txt",
+            }[self._annotation_level],
+        )
+        with open(annotation_file, "r") as f:
+            self.classes = [line.strip() for line in f]
+
+        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
+
+        image_data_folder = os.path.join(self._data_path, "data", "images")
+        labels_file = os.path.join(self._data_path, "data", f"images_{self._annotation_level}_{self._split}.txt")
+
+        self._image_files = []
+        self._labels = []
+
+        with open(labels_file, "r") as f:
+            for line in f:
+                image_name, label_name = line.strip().split(" ", 1)
+                self._image_files.append(os.path.join(image_data_folder, f"{image_name}.jpg"))
+                self._labels.append(self.class_to_idx[label_name])
+
+    def __len__(self) -> int:
+        return len(self._image_files)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        image_file, label = self._image_files[idx], self._labels[idx]
+        image = PIL.Image.open(image_file).convert("RGB")
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.target_transform:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def _download(self) -> None:
+        """
+        Download the FGVC Aircraft dataset archive and extract it under root.
+        """
+        if self._check_exists():
+            return
+        download_and_extract_archive(self._URL, self.root)
+
+    def _check_exists(self) -> bool:
+        return os.path.exists(self._data_path) and os.path.isdir(self._data_path)
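
A minimal usage sketch for the FGVCAircraft class above. The ./data root is hypothetical, the import path assumes the repository root is on sys.path, and download=True fetches and extracts the archive on first use.

from torchvision import transforms

from libs.vision_libs.datasets.fgvc_aircraft import FGVCAircraft  # assumed import path

# ./data is a hypothetical root; download=True fetches and extracts the archive there.
ds = FGVCAircraft(root="./data", split="trainval", annotation_level="variant",
                  transform=transforms.ToTensor(), download=True)

img, label = ds[0]
print(len(ds), ds.classes[label])  # label indexes ds.classes (the 100 variant names)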

+ 166 - 0
libs/vision_libs/datasets/flickr.py

@@ -0,0 +1,166 @@
+import glob
+import os
+from collections import defaultdict
+from html.parser import HTMLParser
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+from PIL import Image
+
+from .vision import VisionDataset
+
+
+class Flickr8kParser(HTMLParser):
+    """Parser for extracting captions from the Flickr8k dataset web page."""
+
+    def __init__(self, root: str) -> None:
+        super().__init__()
+
+        self.root = root
+
+        # Data structure to store captions
+        self.annotations: Dict[str, List[str]] = {}
+
+        # State variables
+        self.in_table = False
+        self.current_tag: Optional[str] = None
+        self.current_img: Optional[str] = None
+
+    def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
+        self.current_tag = tag
+
+        if tag == "table":
+            self.in_table = True
+
+    def handle_endtag(self, tag: str) -> None:
+        self.current_tag = None
+
+        if tag == "table":
+            self.in_table = False
+
+    def handle_data(self, data: str) -> None:
+        if self.in_table:
+            if data == "Image Not Found":
+                self.current_img = None
+            elif self.current_tag == "a":
+                img_id = data.split("/")[-2]
+                img_id = os.path.join(self.root, img_id + "_*.jpg")
+                img_id = glob.glob(img_id)[0]
+                self.current_img = img_id
+                self.annotations[img_id] = []
+            elif self.current_tag == "li" and self.current_img:
+                img_id = self.current_img
+                self.annotations[img_id].append(data.strip())
+
+
+class Flickr8k(VisionDataset):
+    """`Flickr8k Entities <http://hockenmaier.cs.illinois.edu/8k-pictures.html>`_ Dataset.
+
+    Args:
+        root (string): Root directory where images are downloaded to.
+        ann_file (string): Path to annotation file.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.PILToTensor``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        ann_file: str,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self.ann_file = os.path.expanduser(ann_file)
+
+        # Read annotations and store in a dict
+        parser = Flickr8kParser(self.root)
+        with open(self.ann_file) as fh:
+            parser.feed(fh.read())
+        self.annotations = parser.annotations
+
+        self.ids = list(sorted(self.annotations.keys()))
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: Tuple (image, target). target is a list of captions for the image.
+        """
+        img_id = self.ids[index]
+
+        # Image
+        img = Image.open(img_id).convert("RGB")
+        if self.transform is not None:
+            img = self.transform(img)
+
+        # Captions
+        target = self.annotations[img_id]
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.ids)
+
+
+class Flickr30k(VisionDataset):
+    """`Flickr30k Entities <https://bryanplummer.com/Flickr30kEntities/>`_ Dataset.
+
+    Args:
+        root (string): Root directory where images are downloaded to.
+        ann_file (string): Path to annotation file.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.PILToTensor``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        ann_file: str,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self.ann_file = os.path.expanduser(ann_file)
+
+        # Read annotations and store in a dict
+        self.annotations = defaultdict(list)
+        with open(self.ann_file) as fh:
+            for line in fh:
+                img_id, caption = line.strip().split("\t")
+                self.annotations[img_id[:-2]].append(caption)
+
+        self.ids = list(sorted(self.annotations.keys()))
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: Tuple (image, target). target is a list of captions for the image.
+        """
+        img_id = self.ids[index]
+
+        # Image
+        filename = os.path.join(self.root, img_id)
+        img = Image.open(filename).convert("RGB")
+        if self.transform is not None:
+            img = self.transform(img)
+
+        # Captions
+        target = self.annotations[img_id]
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.ids)
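
A minimal usage sketch for the Flickr30k class above (Flickr8k is analogous, but takes the HTML caption page as ann_file). The file locations below are hypothetical; both the images and the tab-separated caption file must be obtained separately, and the import path assumes the repository root is on sys.path.

from torchvision import transforms

from libs.vision_libs.datasets.flickr import Flickr30k  # assumed import path

# Hypothetical paths: root holds the Flickr30k images, ann_file is the
# tab-separated "<image_id>#<n>\t<caption>" file distributed with the dataset.
ds = Flickr30k(root="./data/flickr30k-images",
               ann_file="./data/flickr30k/results_20130124.token",
               transform=transforms.PILToTensor())

img, captions = ds[0]   # captions is a list of strings (typically 5 per image)
print(img.shape, len(captions))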

+ 114 - 0
libs/vision_libs/datasets/flowers102.py

@@ -0,0 +1,114 @@
+from pathlib import Path
+from typing import Any, Callable, Optional, Tuple
+
+import PIL.Image
+
+from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg
+from .vision import VisionDataset
+
+
+class Flowers102(VisionDataset):
+    """`Oxford 102 Flower <https://www.robots.ox.ac.uk/~vgg/data/flowers/102/>`_ Dataset.
+
+    .. warning::
+
+        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
+
+    Oxford 102 Flower is an image classification dataset consisting of 102 flower categories. The
+    flowers were chosen to be flowers commonly occurring in the United Kingdom. Each class consists of
+    between 40 and 258 images.
+
+    The images have large scale, pose and light variations. In addition, there are categories that
+    have large variations within the category, and several very similar categories.
+
+    Args:
+        root (string): Root directory of the dataset.
+        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
+        transform (callable, optional): A function/transform that takes in a PIL image and returns a
+            transformed version. E.g., ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in the root directory. If the dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    _download_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/"
+    _file_dict = {  # filename, md5
+        "image": ("102flowers.tgz", "52808999861908f626f3c1f4e79d11fa"),
+        "label": ("imagelabels.mat", "e0620be6f572b9609742df49c70aed4d"),
+        "setid": ("setid.mat", "a5357ecc9cb78c4bef273ce3793fc85c"),
+    }
+    _splits_map = {"train": "trnid", "val": "valid", "test": "tstid"}
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
+        self._base_folder = Path(self.root) / "flowers-102"
+        self._images_folder = self._base_folder / "jpg"
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        from scipy.io import loadmat
+
+        set_ids = loadmat(self._base_folder / self._file_dict["setid"][0], squeeze_me=True)
+        image_ids = set_ids[self._splits_map[self._split]].tolist()
+
+        labels = loadmat(self._base_folder / self._file_dict["label"][0], squeeze_me=True)
+        image_id_to_label = dict(enumerate((labels["labels"] - 1).tolist(), 1))
+
+        self._labels = []
+        self._image_files = []
+        for image_id in image_ids:
+            self._labels.append(image_id_to_label[image_id])
+            self._image_files.append(self._images_folder / f"image_{image_id:05d}.jpg")
+
+    def __len__(self) -> int:
+        return len(self._image_files)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        image_file, label = self._image_files[idx], self._labels[idx]
+        image = PIL.Image.open(image_file).convert("RGB")
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.target_transform:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def extra_repr(self) -> str:
+        return f"split={self._split}"
+
+    def _check_integrity(self):
+        if not (self._images_folder.exists() and self._images_folder.is_dir()):
+            return False
+
+        for id in ["label", "setid"]:
+            filename, md5 = self._file_dict[id]
+            if not check_integrity(str(self._base_folder / filename), md5):
+                return False
+        return True
+
+    def download(self):
+        if self._check_integrity():
+            return
+        download_and_extract_archive(
+            f"{self._download_url_prefix}{self._file_dict['image'][0]}",
+            str(self._base_folder),
+            md5=self._file_dict["image"][1],
+        )
+        for id in ["label", "setid"]:
+            filename, md5 = self._file_dict[id]
+            download_url(self._download_url_prefix + filename, str(self._base_folder), md5=md5)
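
A minimal usage sketch for the Flowers102 class above; note the scipy requirement called out in the docstring. The ./data root is hypothetical and the import path assumes the repository root is on sys.path.

from torchvision import transforms

from libs.vision_libs.datasets.flowers102 import Flowers102  # assumed import path

# Requires scipy to read the .mat split/label files; download=True fetches them.
ds = Flowers102(root="./data", split="train", download=True,
                transform=transforms.ToTensor())

img, label = ds[0]   # label is an int in [0, 101]
print(len(ds), img.shape, label)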

+ 317 - 0
libs/vision_libs/datasets/folder.py

@@ -0,0 +1,317 @@
+import os
+import os.path
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
+
+from PIL import Image
+
+from .vision import VisionDataset
+
+
+def has_file_allowed_extension(filename: str, extensions: Union[str, Tuple[str, ...]]) -> bool:
+    """Checks if a file is an allowed extension.
+
+    Args:
+        filename (string): path to a file
+        extensions (tuple of strings): extensions to consider (lowercase)
+
+    Returns:
+        bool: True if the filename ends with one of given extensions
+    """
+    return filename.lower().endswith(extensions if isinstance(extensions, str) else tuple(extensions))
+
+
+def is_image_file(filename: str) -> bool:
+    """Checks if a file is an allowed image extension.
+
+    Args:
+        filename (string): path to a file
+
+    Returns:
+        bool: True if the filename ends with a known image extension
+    """
+    return has_file_allowed_extension(filename, IMG_EXTENSIONS)
+
+
+def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
+    """Finds the class folders in a dataset.
+
+    See :class:`DatasetFolder` for details.
+    """
+    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
+    if not classes:
+        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
+
+    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
+    return classes, class_to_idx
+
+
+def make_dataset(
+    directory: str,
+    class_to_idx: Optional[Dict[str, int]] = None,
+    extensions: Optional[Union[str, Tuple[str, ...]]] = None,
+    is_valid_file: Optional[Callable[[str], bool]] = None,
+) -> List[Tuple[str, int]]:
+    """Generates a list of samples of a form (path_to_sample, class).
+
+    See :class:`DatasetFolder` for details.
+
+    Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
+    by default.
+    """
+    directory = os.path.expanduser(directory)
+
+    if class_to_idx is None:
+        _, class_to_idx = find_classes(directory)
+    elif not class_to_idx:
+        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")
+
+    both_none = extensions is None and is_valid_file is None
+    both_something = extensions is not None and is_valid_file is not None
+    if both_none or both_something:
+        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
+
+    if extensions is not None:
+
+        def is_valid_file(x: str) -> bool:
+            return has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]
+
+    is_valid_file = cast(Callable[[str], bool], is_valid_file)
+
+    instances = []
+    available_classes = set()
+    for target_class in sorted(class_to_idx.keys()):
+        class_index = class_to_idx[target_class]
+        target_dir = os.path.join(directory, target_class)
+        if not os.path.isdir(target_dir):
+            continue
+        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
+            for fname in sorted(fnames):
+                path = os.path.join(root, fname)
+                if is_valid_file(path):
+                    item = path, class_index
+                    instances.append(item)
+
+                    if target_class not in available_classes:
+                        available_classes.add(target_class)
+
+    empty_classes = set(class_to_idx.keys()) - available_classes
+    if empty_classes:
+        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
+        if extensions is not None:
+            msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
+        raise FileNotFoundError(msg)
+
+    return instances
+
+
+class DatasetFolder(VisionDataset):
+    """A generic data loader.
+
+    This default directory structure can be customized by overriding the
+    :meth:`find_classes` method.
+
+    Args:
+        root (string): Root directory path.
+        loader (callable): A function to load a sample given its path.
+        extensions (tuple[string]): A list of allowed extensions.
+            Both extensions and is_valid_file should not be passed.
+        transform (callable, optional): A function/transform that takes in
+            a sample and returns a transformed version.
+            E.g, ``transforms.RandomCrop`` for images.
+        target_transform (callable, optional): A function/transform that takes
+            in the target and transforms it.
+        is_valid_file (callable, optional): A function that takes the path of a file
+            and checks if the file is a valid file (used to check for corrupt files).
+            Both extensions and is_valid_file should not be passed.
+
+     Attributes:
+        classes (list): List of the class names sorted alphabetically.
+        class_to_idx (dict): Dict with items (class_name, class_index).
+        samples (list): List of (sample path, class_index) tuples
+        targets (list): The class_index value for each image in the dataset
+    """
+
+    def __init__(
+        self,
+        root: str,
+        loader: Callable[[str], Any],
+        extensions: Optional[Tuple[str, ...]] = None,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        is_valid_file: Optional[Callable[[str], bool]] = None,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        classes, class_to_idx = self.find_classes(self.root)
+        samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)
+
+        self.loader = loader
+        self.extensions = extensions
+
+        self.classes = classes
+        self.class_to_idx = class_to_idx
+        self.samples = samples
+        self.targets = [s[1] for s in samples]
+
+    @staticmethod
+    def make_dataset(
+        directory: str,
+        class_to_idx: Dict[str, int],
+        extensions: Optional[Tuple[str, ...]] = None,
+        is_valid_file: Optional[Callable[[str], bool]] = None,
+    ) -> List[Tuple[str, int]]:
+        """Generates a list of samples of a form (path_to_sample, class).
+
+        This can be overridden to e.g. read files from a compressed zip file instead of from the disk.
+
+        Args:
+            directory (str): root dataset directory, corresponding to ``self.root``.
+            class_to_idx (Dict[str, int]): Dictionary mapping class name to class index.
+            extensions (optional): A list of allowed extensions.
+                Either extensions or is_valid_file should be passed. Defaults to None.
+            is_valid_file (optional): A function that takes the path of a file
+                and checks if the file is a valid file
+                (used to check for corrupt files). Both extensions and
+                is_valid_file should not be passed. Defaults to None.
+
+        Raises:
+            ValueError: In case ``class_to_idx`` is empty.
+            ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
+            FileNotFoundError: In case no valid file was found for any class.
+
+        Returns:
+            List[Tuple[str, int]]: samples of a form (path_to_sample, class)
+        """
+        if class_to_idx is None:
+            # prevent potential bug since make_dataset() would use the class_to_idx logic of the
+            # find_classes() function, instead of using that of the find_classes() method, which
+            # is potentially overridden and thus could have a different logic.
+            raise ValueError("The class_to_idx parameter cannot be None.")
+        return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)
+
+    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
+        """Find the class folders in a dataset structured as follows::
+
+            directory/
+            ├── class_x
+            │   ├── xxx.ext
+            │   ├── xxy.ext
+            │   └── ...
+            │       └── xxz.ext
+            └── class_y
+                ├── 123.ext
+                ├── nsdf3.ext
+                └── ...
+                └── asd932_.ext
+
+        This method can be overridden to only consider
+        a subset of classes, or to adapt to a different dataset directory structure.
+
+        Args:
+            directory(str): Root directory path, corresponding to ``self.root``
+
+        Raises:
+            FileNotFoundError: If ``directory`` has no class folders.
+
+        Returns:
+            (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
+        """
+        return find_classes(directory)
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (sample, target) where target is class_index of the target class.
+        """
+        path, target = self.samples[index]
+        sample = self.loader(path)
+        if self.transform is not None:
+            sample = self.transform(sample)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return sample, target
+
+    def __len__(self) -> int:
+        return len(self.samples)
+
+
+IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
+
+
+def pil_loader(path: str) -> Image.Image:
+    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
+    with open(path, "rb") as f:
+        img = Image.open(f)
+        return img.convert("RGB")
+
+
+# TODO: specify the return type
+def accimage_loader(path: str) -> Any:
+    import accimage
+
+    try:
+        return accimage.Image(path)
+    except OSError:
+        # Potentially a decoding problem, fall back to PIL.Image
+        return pil_loader(path)
+
+
+def default_loader(path: str) -> Any:
+    from torchvision import get_image_backend
+
+    if get_image_backend() == "accimage":
+        return accimage_loader(path)
+    else:
+        return pil_loader(path)
+
+
+class ImageFolder(DatasetFolder):
+    """A generic data loader where the images are arranged in this way by default: ::
+
+        root/dog/xxx.png
+        root/dog/xxy.png
+        root/dog/[...]/xxz.png
+
+        root/cat/123.png
+        root/cat/nsdf3.png
+        root/cat/[...]/asd932_.png
+
+    This class inherits from :class:`~torchvision.datasets.DatasetFolder` so
+    the same methods can be overridden to customize the dataset.
+
+    Args:
+        root (string): Root directory path.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        loader (callable, optional): A function to load an image given its path.
+        is_valid_file (callable, optional): A function that takes the path of an image file
+            and checks if the file is a valid file (used to check for corrupt files)
+
+     Attributes:
+        classes (list): List of the class names sorted alphabetically.
+        class_to_idx (dict): Dict with items (class_name, class_index).
+        imgs (list): List of (image path, class_index) tuples
+    """
+
+    def __init__(
+        self,
+        root: str,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        loader: Callable[[str], Any] = default_loader,
+        is_valid_file: Optional[Callable[[str], bool]] = None,
+    ):
+        super().__init__(
+            root,
+            loader,
+            IMG_EXTENSIONS if is_valid_file is None else None,
+            transform=transform,
+            target_transform=target_transform,
+            is_valid_file=is_valid_file,
+        )
+        self.imgs = self.samples
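
A minimal usage sketch for the ImageFolder loader above. The ./data/pets directory layout (one sub-folder per class) is hypothetical and the import path assumes the repository root is on sys.path.

from torchvision import transforms

from libs.vision_libs.datasets.folder import ImageFolder  # assumed import path

# Hypothetical layout: ./data/pets/cat/*.jpg, ./data/pets/dog/*.jpg, ...
ds = ImageFolder(root="./data/pets",
                 transform=transforms.Compose([transforms.Resize((224, 224)),
                                               transforms.ToTensor()]))

print(ds.classes)          # class folder names, sorted alphabetically
img, target = ds[0]        # target is the index of the sample's class folder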

+ 93 - 0
libs/vision_libs/datasets/food101.py

@@ -0,0 +1,93 @@
+import json
+from pathlib import Path
+from typing import Any, Callable, Optional, Tuple
+
+import PIL.Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class Food101(VisionDataset):
+    """`The Food-101 Data Set <https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/>`_.
+
+    Food-101 is a challenging data set of 101 food categories with 101,000 images.
+    For each class, 250 manually reviewed test images are provided as well as 750 training images.
+    The training images were intentionally not cleaned, and thus still contain some amount of noise,
+    mostly in the form of intense colors and sometimes wrong labels. All images were
+    rescaled to have a maximum side length of 512 pixels.
+
+
+    Args:
+        root (string): Root directory of the dataset.
+        split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
+        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+            version. E.g., ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in the root directory. If the dataset is already downloaded, it is not
+            downloaded again. Default is False.
+    """
+
+    _URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz"
+    _MD5 = "85eeb15f3717b99a5da872d97d918f87"
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self._split = verify_str_arg(split, "split", ("train", "test"))
+        self._base_folder = Path(self.root) / "food-101"
+        self._meta_folder = self._base_folder / "meta"
+        self._images_folder = self._base_folder / "images"
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        self._labels = []
+        self._image_files = []
+        with open(self._meta_folder / f"{split}.json") as f:
+            metadata = json.loads(f.read())
+
+        self.classes = sorted(metadata.keys())
+        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
+
+        for class_label, im_rel_paths in metadata.items():
+            self._labels += [self.class_to_idx[class_label]] * len(im_rel_paths)
+            self._image_files += [
+                self._images_folder.joinpath(*f"{im_rel_path}.jpg".split("/")) for im_rel_path in im_rel_paths
+            ]
+
+    def __len__(self) -> int:
+        return len(self._image_files)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        image_file, label = self._image_files[idx], self._labels[idx]
+        image = PIL.Image.open(image_file).convert("RGB")
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.target_transform:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def extra_repr(self) -> str:
+        return f"split={self._split}"
+
+    def _check_exists(self) -> bool:
+        return all(folder.exists() and folder.is_dir() for folder in (self._meta_folder, self._images_folder))
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+        download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)
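
A minimal usage sketch for the Food101 class above. The ./data root is hypothetical and the import path assumes the repository root is on sys.path; download=True fetches and extracts the archive on first use.

from torchvision import transforms

from libs.vision_libs.datasets.food101 import Food101  # assumed import path

# ./data is a hypothetical root; the archive is downloaded and extracted there.
ds = Food101(root="./data", split="train", download=True,
             transform=transforms.ToTensor())

img, label = ds[0]
print(len(ds), ds.classes[label])  # 75750 training samples, e.g. "apple_pie"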

+ 103 - 0
libs/vision_libs/datasets/gtsrb.py

@@ -0,0 +1,103 @@
+import csv
+import pathlib
+from typing import Any, Callable, Optional, Tuple
+
+import PIL
+
+from .folder import make_dataset
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class GTSRB(VisionDataset):
+    """`German Traffic Sign Recognition Benchmark (GTSRB) <https://benchmark.ini.rub.de/>`_ Dataset.
+
+    Args:
+        root (string): Root directory of the dataset.
+        split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
+        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+            version. E.g., ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in the root directory. If the dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        self._split = verify_str_arg(split, "split", ("train", "test"))
+        self._base_folder = pathlib.Path(root) / "gtsrb"
+        self._target_folder = (
+            self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images")
+        )
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        if self._split == "train":
+            samples = make_dataset(str(self._target_folder), extensions=(".ppm",))
+        else:
+            with open(self._base_folder / "GT-final_test.csv") as csv_file:
+                samples = [
+                    (str(self._target_folder / row["Filename"]), int(row["ClassId"]))
+                    for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True)
+                ]
+
+        self._samples = samples
+        self.transform = transform
+        self.target_transform = target_transform
+
+    def __len__(self) -> int:
+        return len(self._samples)
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+
+        path, target = self._samples[index]
+        sample = PIL.Image.open(path).convert("RGB")
+
+        if self.transform is not None:
+            sample = self.transform(sample)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return sample, target
+
+    def _check_exists(self) -> bool:
+        return self._target_folder.is_dir()
+
+    def download(self) -> None:
+        if self._check_exists():
+            return
+
+        base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/"
+
+        if self._split == "train":
+            download_and_extract_archive(
+                f"{base_url}GTSRB-Training_fixed.zip",
+                download_root=str(self._base_folder),
+                md5="513f3c79a4c5141765e10e952eaa2478",
+            )
+        else:
+            download_and_extract_archive(
+                f"{base_url}GTSRB_Final_Test_Images.zip",
+                download_root=str(self._base_folder),
+                md5="c7e4e6327067d32654124b0fe9e82185",
+            )
+            download_and_extract_archive(
+                f"{base_url}GTSRB_Final_Test_GT.zip",
+                download_root=str(self._base_folder),
+                md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5",
+            )
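
A minimal usage sketch for the GTSRB class above. The ./data root is hypothetical and the import path assumes the repository root is on sys.path; the resize is only there because the sign crops vary in size.

from torchvision import transforms

from libs.vision_libs.datasets.gtsrb import GTSRB  # assumed import path

# ./data is a hypothetical root; sign crops vary in size, so resize before batching.
ds = GTSRB(root="./data", split="train", download=True,
           transform=transforms.Compose([transforms.Resize((32, 32)),
                                         transforms.ToTensor()]))

img, class_id = ds[0]   # class_id is one of the 43 traffic-sign classes
print(len(ds), img.shape)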

+ 151 - 0
libs/vision_libs/datasets/hmdb51.py

@@ -0,0 +1,151 @@
+import glob
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+from torch import Tensor
+
+from .folder import find_classes, make_dataset
+from .video_utils import VideoClips
+from .vision import VisionDataset
+
+
+class HMDB51(VisionDataset):
+    """
+    `HMDB51 <https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/>`_
+    dataset.
+
+    HMDB51 is an action recognition video dataset.
+    This dataset considers every video as a collection of video clips of fixed size, specified
+    by ``frames_per_clip``, where the step in frames between each clip is given by
+    ``step_between_clips``.
+
+    To give an example, for 2 videos with 10 and 15 frames respectively, if ``frames_per_clip=5``
+    and ``step_between_clips=5``, the dataset size will be (2 + 3) = 5, where the first two
+    elements will come from video 1, and the next three elements from video 2.
+    Note that we drop clips which do not have exactly ``frames_per_clip`` elements, so not all
+    frames in a video might be present.
+
+    Internally, it uses a VideoClips object to handle clip creation.
+
+    Args:
+        root (string): Root directory of the HMDB51 Dataset.
+        annotation_path (str): Path to the folder containing the split files.
+        frames_per_clip (int): Number of frames in a clip.
+        step_between_clips (int): Number of frames between each clip.
+        fold (int, optional): Which fold to use. Should be between 1 and 3.
+        train (bool, optional): If ``True``, creates a dataset from the train split,
+            otherwise from the ``test`` split.
+        transform (callable, optional): A function/transform that takes in a TxHxWxC video
+            and returns a transformed version.
+        output_format (str, optional): The format of the output video tensors (before transforms).
+            Can be either "THWC" (default) or "TCHW".
+
+    Returns:
+        tuple: A 3-tuple with the following entries:
+
+            - video (Tensor[T, H, W, C] or Tensor[T, C, H, W]): The `T` video frames
+            - audio(Tensor[K, L]): the audio frames, where `K` is the number of channels
+              and `L` is the number of points
+            - label (int): class of the video clip
+    """
+
+    data_url = "https://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar"
+    splits = {
+        "url": "https://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar",
+        "md5": "15e67781e70dcfbdce2d7dbb9b3344b5",
+    }
+    TRAIN_TAG = 1
+    TEST_TAG = 2
+
+    def __init__(
+        self,
+        root: str,
+        annotation_path: str,
+        frames_per_clip: int,
+        step_between_clips: int = 1,
+        frame_rate: Optional[int] = None,
+        fold: int = 1,
+        train: bool = True,
+        transform: Optional[Callable] = None,
+        _precomputed_metadata: Optional[Dict[str, Any]] = None,
+        num_workers: int = 1,
+        _video_width: int = 0,
+        _video_height: int = 0,
+        _video_min_dimension: int = 0,
+        _audio_samples: int = 0,
+        output_format: str = "THWC",
+    ) -> None:
+        super().__init__(root)
+        if fold not in (1, 2, 3):
+            raise ValueError(f"fold should be between 1 and 3, got {fold}")
+
+        extensions = ("avi",)
+        self.classes, class_to_idx = find_classes(self.root)
+        self.samples = make_dataset(
+            self.root,
+            class_to_idx,
+            extensions,
+        )
+
+        video_paths = [path for (path, _) in self.samples]
+        video_clips = VideoClips(
+            video_paths,
+            frames_per_clip,
+            step_between_clips,
+            frame_rate,
+            _precomputed_metadata,
+            num_workers=num_workers,
+            _video_width=_video_width,
+            _video_height=_video_height,
+            _video_min_dimension=_video_min_dimension,
+            _audio_samples=_audio_samples,
+            output_format=output_format,
+        )
+        # we bookkeep the full version of video clips because we want to be able
+        # to return the metadata of full version rather than the subset version of
+        # video clips
+        self.full_video_clips = video_clips
+        self.fold = fold
+        self.train = train
+        self.indices = self._select_fold(video_paths, annotation_path, fold, train)
+        self.video_clips = video_clips.subset(self.indices)
+        self.transform = transform
+
+    @property
+    def metadata(self) -> Dict[str, Any]:
+        return self.full_video_clips.metadata
+
+    def _select_fold(self, video_list: List[str], annotations_dir: str, fold: int, train: bool) -> List[int]:
+        target_tag = self.TRAIN_TAG if train else self.TEST_TAG
+        split_pattern_name = f"*test_split{fold}.txt"
+        split_pattern_path = os.path.join(annotations_dir, split_pattern_name)
+        annotation_paths = glob.glob(split_pattern_path)
+        selected_files = set()
+        for filepath in annotation_paths:
+            with open(filepath) as fid:
+                lines = fid.readlines()
+            for line in lines:
+                video_filename, tag_string = line.split()
+                tag = int(tag_string)
+                if tag == target_tag:
+                    selected_files.add(video_filename)
+
+        indices = []
+        for video_index, video_path in enumerate(video_list):
+            if os.path.basename(video_path) in selected_files:
+                indices.append(video_index)
+
+        return indices
+
+    def __len__(self) -> int:
+        return self.video_clips.num_clips()
+
+    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]:
+        video, audio, _, video_idx = self.video_clips.get_clip(idx)
+        sample_index = self.indices[video_idx]
+        _, class_index = self.samples[sample_index]
+
+        if self.transform is not None:
+            video = self.transform(video)
+
+        return video, audio, class_index
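
A minimal usage sketch for the HMDB51 class above. The directory layout and paths are hypothetical (videos must be extracted to one folder per action class, with the official split files alongside), the import path assumes the repository root is on sys.path, and decoding the clips requires a working torchvision video backend.

from libs.vision_libs.datasets.hmdb51 import HMDB51  # assumed import path

# Hypothetical paths: one folder per action class with .avi files under root,
# and the official *test_split<fold>.txt files under annotation_path.
ds = HMDB51(root="./data/hmdb51/videos",
            annotation_path="./data/hmdb51/splits",
            frames_per_clip=16, step_between_clips=16,
            fold=1, train=True, output_format="TCHW")

video, audio, label = ds[0]   # video: Tensor[16, C, H, W] with output_format="TCHW"
print(len(ds), video.shape, label)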

+ 218 - 0
libs/vision_libs/datasets/imagenet.py

@@ -0,0 +1,218 @@
+import os
+import shutil
+import tempfile
+from contextlib import contextmanager
+from typing import Any, Dict, Iterator, List, Optional, Tuple
+
+import torch
+
+from .folder import ImageFolder
+from .utils import check_integrity, extract_archive, verify_str_arg
+
+ARCHIVE_META = {
+    "train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"),
+    "val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"),
+    "devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"),
+}
+
+META_FILE = "meta.bin"
+
+
+class ImageNet(ImageFolder):
+    """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.
+
+    .. note::
+        Before using this class, it is required to download ImageNet 2012 dataset from
+        `here <https://image-net.org/challenges/LSVRC/2012/2012-downloads.php>`_ and
+        place the files ``ILSVRC2012_devkit_t12.tar.gz`` and ``ILSVRC2012_img_train.tar``
+        or ``ILSVRC2012_img_val.tar`` based on ``split`` in the root directory.
+
+    Args:
+        root (string): Root directory of the ImageNet Dataset.
+        split (string, optional): The dataset split, supports ``train``, or ``val``.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        loader (callable, optional): A function to load an image given its path.
+
+     Attributes:
+        classes (list): List of the class name tuples.
+        class_to_idx (dict): Dict with items (class_name, class_index).
+        wnids (list): List of the WordNet IDs.
+        wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
+        imgs (list): List of (image path, class_index) tuples
+        targets (list): The class_index value for each image in the dataset
+    """
+
+    def __init__(self, root: str, split: str = "train", **kwargs: Any) -> None:
+        root = self.root = os.path.expanduser(root)
+        self.split = verify_str_arg(split, "split", ("train", "val"))
+
+        self.parse_archives()
+        wnid_to_classes = load_meta_file(self.root)[0]
+
+        super().__init__(self.split_folder, **kwargs)
+        self.root = root
+
+        self.wnids = self.classes
+        self.wnid_to_idx = self.class_to_idx
+        self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
+        self.class_to_idx = {cls: idx for idx, clss in enumerate(self.classes) for cls in clss}
+
+    def parse_archives(self) -> None:
+        if not check_integrity(os.path.join(self.root, META_FILE)):
+            parse_devkit_archive(self.root)
+
+        if not os.path.isdir(self.split_folder):
+            if self.split == "train":
+                parse_train_archive(self.root)
+            elif self.split == "val":
+                parse_val_archive(self.root)
+
+    @property
+    def split_folder(self) -> str:
+        return os.path.join(self.root, self.split)
+
+    def extra_repr(self) -> str:
+        return "Split: {split}".format(**self.__dict__)
+
+
+def load_meta_file(root: str, file: Optional[str] = None) -> Tuple[Dict[str, str], List[str]]:
+    if file is None:
+        file = META_FILE
+    file = os.path.join(root, file)
+
+    if check_integrity(file):
+        return torch.load(file)
+    else:
+        msg = (
+            "The meta file {} is not present in the root directory or is corrupted. "
+            "This file is automatically created by the ImageNet dataset."
+        )
+        raise RuntimeError(msg.format(file, root))
+
+
+def _verify_archive(root: str, file: str, md5: str) -> None:
+    if not check_integrity(os.path.join(root, file), md5):
+        msg = (
+            "The archive {} is not present in the root directory or is corrupted. "
+            "You need to download it externally and place it in {}."
+        )
+        raise RuntimeError(msg.format(file, root))
+
+
+def parse_devkit_archive(root: str, file: Optional[str] = None) -> None:
+    """Parse the devkit archive of the ImageNet2012 classification dataset and save
+    the meta information in a binary file.
+
+    Args:
+        root (str): Root directory containing the devkit archive
+        file (str, optional): Name of devkit archive. Defaults to
+            'ILSVRC2012_devkit_t12.tar.gz'
+    """
+    import scipy.io as sio
+
+    def parse_meta_mat(devkit_root: str) -> Tuple[Dict[int, str], Dict[str, Tuple[str, ...]]]:
+        metafile = os.path.join(devkit_root, "data", "meta.mat")
+        meta = sio.loadmat(metafile, squeeze_me=True)["synsets"]
+        nums_children = list(zip(*meta))[4]
+        meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]
+        idcs, wnids, classes = list(zip(*meta))[:3]
+        classes = [tuple(clss.split(", ")) for clss in classes]
+        idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
+        wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
+        return idx_to_wnid, wnid_to_classes
+
+    def parse_val_groundtruth_txt(devkit_root: str) -> List[int]:
+        file = os.path.join(devkit_root, "data", "ILSVRC2012_validation_ground_truth.txt")
+        with open(file) as txtfh:
+            val_idcs = txtfh.readlines()
+        return [int(val_idx) for val_idx in val_idcs]
+
+    @contextmanager
+    def get_tmp_dir() -> Iterator[str]:
+        tmp_dir = tempfile.mkdtemp()
+        try:
+            yield tmp_dir
+        finally:
+            shutil.rmtree(tmp_dir)
+
+    archive_meta = ARCHIVE_META["devkit"]
+    if file is None:
+        file = archive_meta[0]
+    md5 = archive_meta[1]
+
+    _verify_archive(root, file, md5)
+
+    with get_tmp_dir() as tmp_dir:
+        extract_archive(os.path.join(root, file), tmp_dir)
+
+        devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12")
+        idx_to_wnid, wnid_to_classes = parse_meta_mat(devkit_root)
+        val_idcs = parse_val_groundtruth_txt(devkit_root)
+        val_wnids = [idx_to_wnid[idx] for idx in val_idcs]
+
+        torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE))
+
+
+def parse_train_archive(root: str, file: Optional[str] = None, folder: str = "train") -> None:
+    """Parse the train images archive of the ImageNet2012 classification dataset and
+    prepare it for usage with the ImageNet dataset.
+
+    Args:
+        root (str): Root directory containing the train images archive
+        file (str, optional): Name of train images archive. Defaults to
+            'ILSVRC2012_img_train.tar'
+        folder (str, optional): Optional name for train images folder. Defaults to
+            'train'
+    """
+    archive_meta = ARCHIVE_META["train"]
+    if file is None:
+        file = archive_meta[0]
+    md5 = archive_meta[1]
+
+    _verify_archive(root, file, md5)
+
+    train_root = os.path.join(root, folder)
+    extract_archive(os.path.join(root, file), train_root)
+
+    archives = [os.path.join(train_root, archive) for archive in os.listdir(train_root)]
+    for archive in archives:
+        extract_archive(archive, os.path.splitext(archive)[0], remove_finished=True)
+
+
+def parse_val_archive(
+    root: str, file: Optional[str] = None, wnids: Optional[List[str]] = None, folder: str = "val"
+) -> None:
+    """Parse the validation images archive of the ImageNet2012 classification dataset
+    and prepare it for usage with the ImageNet dataset.
+
+    Args:
+        root (str): Root directory containing the validation images archive
+        file (str, optional): Name of validation images archive. Defaults to
+            'ILSVRC2012_img_val.tar'
+        wnids (list, optional): List of WordNet IDs of the validation images. If None
+            is given, the IDs are loaded from the meta file in the root directory
+        folder (str, optional): Optional name for validation images folder. Defaults to
+            'val'
+    """
+    archive_meta = ARCHIVE_META["val"]
+    if file is None:
+        file = archive_meta[0]
+    md5 = archive_meta[1]
+    if wnids is None:
+        wnids = load_meta_file(root)[1]
+
+    _verify_archive(root, file, md5)
+
+    val_root = os.path.join(root, folder)
+    extract_archive(os.path.join(root, file), val_root)
+
+    images = sorted(os.path.join(val_root, image) for image in os.listdir(val_root))
+
+    for wnid in set(wnids):
+        os.mkdir(os.path.join(val_root, wnid))
+
+    for wnid, img_file in zip(wnids, images):
+        shutil.move(img_file, os.path.join(val_root, wnid, os.path.basename(img_file)))
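
A minimal usage sketch for the ImageNet class above. The ./data/imagenet root is hypothetical and must already contain the manually downloaded archives named in the docstring; the import path assumes the repository root is on sys.path.

from torchvision import transforms

from libs.vision_libs.datasets.imagenet import ImageNet  # assumed import path

# Hypothetical root that already contains ILSVRC2012_devkit_t12.tar.gz and
# ILSVRC2012_img_val.tar; the first instantiation extracts them into root/val.
ds = ImageNet(root="./data/imagenet", split="val",
              transform=transforms.Compose([transforms.Resize(256),
                                            transforms.CenterCrop(224),
                                            transforms.ToTensor()]))

img, target = ds[0]
print(ds.classes[target])   # tuple of human-readable names for the class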

+ 104 - 0
libs/vision_libs/datasets/imagenette.py

@@ -0,0 +1,104 @@
+from pathlib import Path
+from typing import Any, Callable, Optional, Tuple
+
+from PIL import Image
+
+from .folder import find_classes, make_dataset
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class Imagenette(VisionDataset):
+    """`Imagenette <https://github.com/fastai/imagenette#imagenette-1>`_ image classification dataset.
+
+    Args:
+        root (string): Root directory of the Imagenette dataset.
+        split (string, optional): The dataset split. Supports ``"train"`` (default) and ``"val"``.
+        size (string, optional): The image size. Supports ``"full"`` (default), ``"320px"``, and ``"160px"``.
+        download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already
+            downloaded archives are not downloaded again.
+        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+            version, e.g. ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+
+     Attributes:
+        classes (list): List of the class name tuples.
+        class_to_idx (dict): Dict with items (class name, class index).
+        wnids (list): List of the WordNet IDs.
+        wnid_to_idx (dict): Dict with items (WordNet ID, class index).
+    """
+
+    _ARCHIVES = {
+        "full": ("https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz", "fe2fc210e6bb7c5664d602c3cd71e612"),
+        "320px": ("https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz", "3df6f0d01a2c9592104656642f5e78a3"),
+        "160px": ("https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz", "e793b78cc4c9e9a4ccc0c1155377a412"),
+    }
+    _WNID_TO_CLASS = {
+        "n01440764": ("tench", "Tinca tinca"),
+        "n02102040": ("English springer", "English springer spaniel"),
+        "n02979186": ("cassette player",),
+        "n03000684": ("chain saw", "chainsaw"),
+        "n03028079": ("church", "church building"),
+        "n03394916": ("French horn", "horn"),
+        "n03417042": ("garbage truck", "dustcart"),
+        "n03425413": ("gas pump", "gasoline pump", "petrol pump", "island dispenser"),
+        "n03445777": ("golf ball",),
+        "n03888257": ("parachute", "chute"),
+    }
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        size: str = "full",
+        download=False,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        self._split = verify_str_arg(split, "split", ["train", "val"])
+        self._size = verify_str_arg(size, "size", ["full", "320px", "160px"])
+
+        self._url, self._md5 = self._ARCHIVES[self._size]
+        self._size_root = Path(self.root) / Path(self._url).stem
+        self._image_root = str(self._size_root / self._split)
+
+        if download:
+            self._download()
+        elif not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it.")
+
+        self.wnids, self.wnid_to_idx = find_classes(self._image_root)
+        self.classes = [self._WNID_TO_CLASS[wnid] for wnid in self.wnids]
+        self.class_to_idx = {
+            class_name: idx for wnid, idx in self.wnid_to_idx.items() for class_name in self._WNID_TO_CLASS[wnid]
+        }
+        self._samples = make_dataset(self._image_root, self.wnid_to_idx, extensions=".jpeg")
+
+    def _check_exists(self) -> bool:
+        return self._size_root.exists()
+
+    def _download(self):
+        if self._check_exists():
+            raise RuntimeError(
+                f"The directory {self._size_root} already exists. "
+                f"If you want to re-download or re-extract the images, delete the directory."
+            )
+
+        download_and_extract_archive(self._url, self.root, md5=self._md5)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        path, label = self._samples[idx]
+        image = Image.open(path).convert("RGB")
+
+        if self.transform is not None:
+            image = self.transform(image)
+
+        if self.target_transform is not None:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def __len__(self) -> int:
+        return len(self._samples)
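
A minimal usage sketch for the Imagenette class above. The ./data root is hypothetical and the import path assumes the repository root is on sys.path. Note that, per the _download method above, download=True raises if the extracted directory already exists, so it is only meant for the first run.

from torchvision import transforms

from libs.vision_libs.datasets.imagenette import Imagenette  # assumed import path

# ./data is a hypothetical root; the 160px archive is the smallest download.
ds = Imagenette(root="./data", split="train", size="160px", download=True,
                transform=transforms.ToTensor())

img, label = ds[0]
print(len(ds), ds.classes[label])   # e.g. ("tench", "Tinca tinca")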

+ 241 - 0
libs/vision_libs/datasets/inaturalist.py

@@ -0,0 +1,241 @@
+import os
+import os.path
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+CATEGORIES_2021 = ["kingdom", "phylum", "class", "order", "family", "genus"]
+
+DATASET_URLS = {
+    "2017": "https://ml-inat-competition-datasets.s3.amazonaws.com/2017/train_val_images.tar.gz",
+    "2018": "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz",
+    "2019": "https://ml-inat-competition-datasets.s3.amazonaws.com/2019/train_val2019.tar.gz",
+    "2021_train": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train.tar.gz",
+    "2021_train_mini": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train_mini.tar.gz",
+    "2021_valid": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/val.tar.gz",
+}
+
+DATASET_MD5 = {
+    "2017": "7c784ea5e424efaec655bd392f87301f",
+    "2018": "b1c6952ce38f31868cc50ea72d066cc3",
+    "2019": "c60a6e2962c9b8ccbd458d12c8582644",
+    "2021_train": "e0526d53c7f7b2e3167b2b43bb2690ed",
+    "2021_train_mini": "db6ed8330e634445efc8fec83ae81442",
+    "2021_valid": "f6f6e0e242e3d4c9569ba56400938afc",
+}
+
+
+class INaturalist(VisionDataset):
+    """`iNaturalist <https://github.com/visipedia/inat_comp>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where the image files are stored.
+            This class does not require/use annotation files.
+        version (string, optional): Which version of the dataset to download/use. One of
+            '2017', '2018', '2019', '2021_train', '2021_train_mini', '2021_valid'.
+            Default: `2021_train`.
+        target_type (string or list, optional): Type of target to use, for 2021 versions, one of:
+
+            - ``full``: the full category (species)
+            - ``kingdom``: e.g. "Animalia"
+            - ``phylum``: e.g. "Arthropoda"
+            - ``class``: e.g. "Insecta"
+            - ``order``: e.g. "Coleoptera"
+            - ``family``: e.g. "Cleridae"
+            - ``genus``: e.g. "Trichodes"
+
+            for 2017-2019 versions, one of:
+
+            - ``full``: the full (numeric) category
+            - ``super``: the super category, e.g. "Amphibians"
+
+            Can also be a list to output a tuple with all specified target types.
+            Defaults to ``full``.
+        transform (callable, optional): A function/transform that takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        version: str = "2021_train",
+        target_type: Union[List[str], str] = "full",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        self.version = verify_str_arg(version, "version", DATASET_URLS.keys())
+
+        super().__init__(os.path.join(root, version), transform=transform, target_transform=target_transform)
+
+        os.makedirs(root, exist_ok=True)
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        self.all_categories: List[str] = []
+
+        # map: category type -> name of category -> index
+        self.categories_index: Dict[str, Dict[str, int]] = {}
+
+        # list indexed by category id, containing mapping from category type -> index
+        self.categories_map: List[Dict[str, int]] = []
+
+        if not isinstance(target_type, list):
+            target_type = [target_type]
+        if self.version[:4] == "2021":
+            self.target_type = [verify_str_arg(t, "target_type", ("full", *CATEGORIES_2021)) for t in target_type]
+            self._init_2021()
+        else:
+            self.target_type = [verify_str_arg(t, "target_type", ("full", "super")) for t in target_type]
+            self._init_pre2021()
+
+        # index of all files: (full category id, filename)
+        self.index: List[Tuple[int, str]] = []
+
+        for dir_index, dir_name in enumerate(self.all_categories):
+            files = os.listdir(os.path.join(self.root, dir_name))
+            for fname in files:
+                self.index.append((dir_index, fname))
+
+    def _init_2021(self) -> None:
+        """Initialize based on 2021 layout"""
+
+        self.all_categories = sorted(os.listdir(self.root))
+
+        # map: category type -> name of category -> index
+        self.categories_index = {k: {} for k in CATEGORIES_2021}
+
+        for dir_index, dir_name in enumerate(self.all_categories):
+            pieces = dir_name.split("_")
+            if len(pieces) != 8:
+                raise RuntimeError(f"Unexpected category name {dir_name}, wrong number of pieces")
+            if pieces[0] != f"{dir_index:05d}":
+                raise RuntimeError(f"Unexpected category id {pieces[0]}, expecting {dir_index:05d}")
+            cat_map = {}
+            for cat, name in zip(CATEGORIES_2021, pieces[1:7]):
+                if name in self.categories_index[cat]:
+                    cat_id = self.categories_index[cat][name]
+                else:
+                    cat_id = len(self.categories_index[cat])
+                    self.categories_index[cat][name] = cat_id
+                cat_map[cat] = cat_id
+            self.categories_map.append(cat_map)
+
+    def _init_pre2021(self) -> None:
+        """Initialize based on 2017-2019 layout"""
+
+        # map: category type -> name of category -> index
+        self.categories_index = {"super": {}}
+
+        cat_index = 0
+        super_categories = sorted(os.listdir(self.root))
+        for sindex, scat in enumerate(super_categories):
+            self.categories_index["super"][scat] = sindex
+            subcategories = sorted(os.listdir(os.path.join(self.root, scat)))
+            for subcat in subcategories:
+                if self.version == "2017":
+                    # this version does not use ids as directory names
+                    subcat_i = cat_index
+                    cat_index += 1
+                else:
+                    try:
+                        subcat_i = int(subcat)
+                    except ValueError:
+                        raise RuntimeError(f"Unexpected non-numeric dir name: {subcat}")
+                if subcat_i >= len(self.categories_map):
+                    old_len = len(self.categories_map)
+                    self.categories_map.extend([{}] * (subcat_i - old_len + 1))
+                    self.all_categories.extend([""] * (subcat_i - old_len + 1))
+                if self.categories_map[subcat_i]:
+                    raise RuntimeError(f"Duplicate category {subcat}")
+                self.categories_map[subcat_i] = {"super": sindex}
+                self.all_categories[subcat_i] = os.path.join(scat, subcat)
+
+        # validate the dictionary
+        for cindex, c in enumerate(self.categories_map):
+            if not c:
+                raise RuntimeError(f"Missing category {cindex}")
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where the type of target specified by target_type.
+        """
+
+        cat_id, fname = self.index[index]
+        img = Image.open(os.path.join(self.root, self.all_categories[cat_id], fname))
+
+        target: Any = []
+        for t in self.target_type:
+            if t == "full":
+                target.append(cat_id)
+            else:
+                target.append(self.categories_map[cat_id][t])
+        target = tuple(target) if len(target) > 1 else target[0]
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.index)
+
+    def category_name(self, category_type: str, category_id: int) -> str:
+        """
+        Args:
+            category_type (str): one of "full", "kingdom", "phylum", "class", "order", "family", "genus" or "super"
+            category_id (int): an index (class id) from this category
+
+        Returns:
+            the name of the category
+        """
+        if category_type == "full":
+            return self.all_categories[category_id]
+        else:
+            if category_type not in self.categories_index:
+                raise ValueError(f"Invalid category type '{category_type}'")
+            else:
+                for name, id in self.categories_index[category_type].items():
+                    if id == category_id:
+                        return name
+                raise ValueError(f"Invalid category id {category_id} for {category_type}")
+
+    def _check_integrity(self) -> bool:
+        return os.path.exists(self.root) and len(os.listdir(self.root)) > 0
+
+    def download(self) -> None:
+        if self._check_integrity():
+            raise RuntimeError(
+                f"The directory {self.root} already exists. "
+                f"If you want to re-download or re-extract the images, delete the directory."
+            )
+
+        base_root = os.path.dirname(self.root)
+
+        download_and_extract_archive(
+            DATASET_URLS[self.version], base_root, filename=f"{self.version}.tgz", md5=DATASET_MD5[self.version]
+        )
+
+        orig_dir_name = os.path.join(base_root, os.path.basename(DATASET_URLS[self.version]).rstrip(".tar.gz"))
+        if not os.path.exists(orig_dir_name):
+            raise RuntimeError(f"Unable to find downloaded files at {orig_dir_name}")
+        os.rename(orig_dir_name, self.root)
+        print(f"Dataset version '{self.version}' has been downloaded and prepared for use")
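A short usage sketch for INaturalist, assuming the 2021 validation split has already been downloaded and extracted under the illustrative root below (the class appends the version to root). When target_type is a list, the target is a tuple in the same order:

from libs.vision_libs.datasets import INaturalist  # assumed import path

ds = INaturalist(root="data/inat", version="2021_valid",
                 target_type=["full", "kingdom"], download=False)
img, (species_id, kingdom_id) = ds[0]            # PIL image plus a tuple of targets
print(ds.category_name("kingdom", kingdom_id))   # e.g. "Animalia"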

+ 247 - 0
libs/vision_libs/datasets/kinetics.py

@@ -0,0 +1,247 @@
+import csv
+import os
+import time
+import urllib
+from functools import partial
+from multiprocessing import Pool
+from os import path
+from typing import Any, Callable, Dict, Optional, Tuple
+
+from torch import Tensor
+
+from .folder import find_classes, make_dataset
+from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg
+from .video_utils import VideoClips
+from .vision import VisionDataset
+
+
+def _dl_wrap(tarpath: str, videopath: str, line: str) -> None:
+    download_and_extract_archive(line, tarpath, videopath)
+
+
+class Kinetics(VisionDataset):
+    """`Generic Kinetics <https://www.deepmind.com/open-source/kinetics>`_
+    dataset.
+
+    Kinetics-400/600/700 are action recognition video datasets.
+    This dataset considers every video as a collection of video clips of fixed size, specified
+    by ``frames_per_clip``, where the step in frames between each clip is given by
+    ``step_between_clips``.
+
+    To give an example, for 2 videos with 10 and 15 frames respectively, if ``frames_per_clip=5``
+    and ``step_between_clips=5``, the dataset size will be (2 + 3) = 5, where the first two
+    elements will come from video 1, and the next three elements from video 2.
+    Note that we drop clips which do not have exactly ``frames_per_clip`` elements, so not all
+    frames in a video might be present.
+
+    Args:
+        root (string): Root directory of the Kinetics Dataset.
+            Directory should be structured as follows:
+            .. code::
+
+                root/
+                ├── split
+                │   ├──  class1
+                │   │   ├──  vid1.mp4
+                │   │   ├──  vid2.mp4
+                │   │   ├──  vid3.mp4
+                │   │   ├──  ...
+                │   ├──  class2
+                │   │   ├──   vidx.mp4
+                │   │    └── ...
+
+            Note: split is appended automatically using the split argument.
+        frames_per_clip (int): number of frames in a clip
+        num_classes (str): select between Kinetics-400 (default), Kinetics-600, and Kinetics-700,
+            passed as ``"400"``, ``"600"`` or ``"700"``
+        split (str): split of the dataset to consider; supports ``"train"`` (default), ``"val"`` and ``"test"``
+        frame_rate (int, optional): if specified, resample the videos to this frame rate before building
+            clips; if omitted, each clip keeps the original frame rate of its video
+        step_between_clips (int): number of frames between each clip
+        transform (callable, optional): A function/transform that  takes in a TxHxWxC video
+            and returns a transformed version.
+        download (bool): Download the official version of the dataset to root folder.
+        num_workers (int): Use multiple workers for VideoClips creation
+        num_download_workers (int): Use multiprocessing in order to speed up download.
+        output_format (str, optional): The format of the output video tensors (before transforms).
+            Can be either "THWC" or "TCHW" (default).
+            Note that in most other utils and datasets, the default is actually "THWC".
+
+    Returns:
+        tuple: A 3-tuple with the following entries:
+
+            - video (Tensor[T, C, H, W] or Tensor[T, H, W, C]): the `T` video frames in torch.uint8 tensor
+            - audio(Tensor[K, L]): the audio frames, where `K` is the number of channels
+              and `L` is the number of points in torch.float tensor
+            - label (int): class of the video clip
+
+    Raises:
+        RuntimeError: If ``download is True`` and the video archives are already extracted.
+    """
+
+    _TAR_URLS = {
+        "400": "https://s3.amazonaws.com/kinetics/400/{split}/k400_{split}_path.txt",
+        "600": "https://s3.amazonaws.com/kinetics/600/{split}/k600_{split}_path.txt",
+        "700": "https://s3.amazonaws.com/kinetics/700_2020/{split}/k700_2020_{split}_path.txt",
+    }
+    _ANNOTATION_URLS = {
+        "400": "https://s3.amazonaws.com/kinetics/400/annotations/{split}.csv",
+        "600": "https://s3.amazonaws.com/kinetics/600/annotations/{split}.csv",
+        "700": "https://s3.amazonaws.com/kinetics/700_2020/annotations/{split}.csv",
+    }
+
+    def __init__(
+        self,
+        root: str,
+        frames_per_clip: int,
+        num_classes: str = "400",
+        split: str = "train",
+        frame_rate: Optional[int] = None,
+        step_between_clips: int = 1,
+        transform: Optional[Callable] = None,
+        extensions: Tuple[str, ...] = ("avi", "mp4"),
+        download: bool = False,
+        num_download_workers: int = 1,
+        num_workers: int = 1,
+        _precomputed_metadata: Optional[Dict[str, Any]] = None,
+        _video_width: int = 0,
+        _video_height: int = 0,
+        _video_min_dimension: int = 0,
+        _audio_samples: int = 0,
+        _audio_channels: int = 0,
+        _legacy: bool = False,
+        output_format: str = "TCHW",
+    ) -> None:
+
+        # TODO: support test
+        self.num_classes = verify_str_arg(num_classes, arg="num_classes", valid_values=["400", "600", "700"])
+        self.extensions = extensions
+        self.num_download_workers = num_download_workers
+
+        self.root = root
+        self._legacy = _legacy
+
+        if _legacy:
+            print("Using legacy structure")
+            self.split_folder = root
+            self.split = "unknown"
+            output_format = "THWC"
+            if download:
+                raise ValueError("Cannot download the videos using legacy_structure.")
+        else:
+            self.split_folder = path.join(root, split)
+            self.split = verify_str_arg(split, arg="split", valid_values=["train", "val", "test"])
+
+        if download:
+            self.download_and_process_videos()
+
+        super().__init__(self.root)
+
+        self.classes, class_to_idx = find_classes(self.split_folder)
+        self.samples = make_dataset(self.split_folder, class_to_idx, extensions, is_valid_file=None)
+        video_list = [x[0] for x in self.samples]
+        self.video_clips = VideoClips(
+            video_list,
+            frames_per_clip,
+            step_between_clips,
+            frame_rate,
+            _precomputed_metadata,
+            num_workers=num_workers,
+            _video_width=_video_width,
+            _video_height=_video_height,
+            _video_min_dimension=_video_min_dimension,
+            _audio_samples=_audio_samples,
+            _audio_channels=_audio_channels,
+            output_format=output_format,
+        )
+        self.transform = transform
+
+    def download_and_process_videos(self) -> None:
+        """Downloads all the videos to the _root_ folder in the expected format."""
+        tic = time.time()
+        self._download_videos()
+        toc = time.time()
+        print("Elapsed time for downloading in mins ", (toc - tic) / 60)
+        self._make_ds_structure()
+        toc2 = time.time()
+        print("Elapsed time for processing in mins ", (toc2 - toc) / 60)
+        print("Elapsed time overall in mins ", (toc2 - tic) / 60)
+
+    def _download_videos(self) -> None:
+        """download tarballs containing the video to "tars" folder and extract them into the _split_ folder where
+        split is one of the official dataset splits.
+
+        Raises:
+            RuntimeError: if the download folder already exists, to prevent re-downloading the entire dataset.
+        """
+        if path.exists(self.split_folder):
+            raise RuntimeError(
+                f"The directory {self.split_folder} already exists. "
+                f"If you want to re-download or re-extract the images, delete the directory."
+            )
+        tar_path = path.join(self.root, "tars")
+        file_list_path = path.join(self.root, "files")
+
+        split_url = self._TAR_URLS[self.num_classes].format(split=self.split)
+        split_url_filepath = path.join(file_list_path, path.basename(split_url))
+        if not check_integrity(split_url_filepath):
+            download_url(split_url, file_list_path)
+        with open(split_url_filepath) as file:
+            list_video_urls = [urllib.parse.quote(line, safe="/,:") for line in file.read().splitlines()]
+
+        if self.num_download_workers == 1:
+            for line in list_video_urls:
+                download_and_extract_archive(line, tar_path, self.split_folder)
+        else:
+            part = partial(_dl_wrap, tar_path, self.split_folder)
+            poolproc = Pool(self.num_download_workers)
+            poolproc.map(part, list_video_urls)
+
+    def _make_ds_structure(self) -> None:
+        """move videos from
+        split_folder/
+            ├── clip1.avi
+            ├── clip2.avi
+
+        to the correct format as described below:
+        split_folder/
+            ├── class1
+            │   ├── clip1.avi
+
+        """
+        annotation_path = path.join(self.root, "annotations")
+        if not check_integrity(path.join(annotation_path, f"{self.split}.csv")):
+            download_url(self._ANNOTATION_URLS[self.num_classes].format(split=self.split), annotation_path)
+        annotations = path.join(annotation_path, f"{self.split}.csv")
+
+        file_fmtstr = "{ytid}_{start:06}_{end:06}.mp4"
+        with open(annotations) as csvfile:
+            reader = csv.DictReader(csvfile)
+            for row in reader:
+                f = file_fmtstr.format(
+                    ytid=row["youtube_id"],
+                    start=int(row["time_start"]),
+                    end=int(row["time_end"]),
+                )
+                label = row["label"].replace(" ", "_").replace("'", "").replace("(", "").replace(")", "")
+                os.makedirs(path.join(self.split_folder, label), exist_ok=True)
+                downloaded_file = path.join(self.split_folder, f)
+                if path.isfile(downloaded_file):
+                    os.replace(
+                        downloaded_file,
+                        path.join(self.split_folder, label, f),
+                    )
+
+    @property
+    def metadata(self) -> Dict[str, Any]:
+        return self.video_clips.metadata
+
+    def __len__(self) -> int:
+        return self.video_clips.num_clips()
+
+    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]:
+        video, audio, info, video_idx = self.video_clips.get_clip(idx)
+        label = self.samples[video_idx][1]
+
+        if self.transform is not None:
+            video = self.transform(video)
+
+        return video, audio, label
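A usage sketch for Kinetics built on the clip arithmetic described in the class docstring, assuming the videos have already been arranged as root/val/<class>/<video>.mp4 (download=True would instead fetch the official tarballs; the root path is illustrative):

from libs.vision_libs.datasets import Kinetics  # assumed import path

# With frames_per_clip=5 and step_between_clips=5, a 10-frame video contributes
# 2 clips and a 15-frame video contributes 3, so __len__ counts clips, not videos.
ds = Kinetics(root="data/kinetics400", frames_per_clip=5, step_between_clips=5,
              num_classes="400", split="val", num_workers=4)
video, audio, label = ds[0]   # video is a uint8 Tensor[T, C, H, W] with the default output_format="TCHW"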

+ 157 - 0
libs/vision_libs/datasets/kitti.py

@@ -0,0 +1,157 @@
+import csv
+import os
+from typing import Any, Callable, List, Optional, Tuple
+
+from PIL import Image
+
+from .utils import download_and_extract_archive
+from .vision import VisionDataset
+
+
+class Kitti(VisionDataset):
+    """`KITTI <http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark>`_ Dataset.
+
+    It corresponds to the "left color images of object" dataset, for object detection.
+
+    Args:
+        root (string): Root directory where images are downloaded to.
+            Expects the following folder structure if download=False:
+
+            .. code::
+
+                <root>
+                    └── Kitti
+                        └─ raw
+                            ├── training
+                            |   ├── image_2
+                            |   └── label_2
+                            └── testing
+                                └── image_2
+        train (bool, optional): Use ``train`` split if true, else ``test`` split.
+            Defaults to ``train``.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.PILToTensor``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        transforms (callable, optional): A function/transform that takes input sample
+            and its target as entry and returns a transformed version.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+    """
+
+    data_url = "https://s3.eu-central-1.amazonaws.com/avg-kitti/"
+    resources = [
+        "data_object_image_2.zip",
+        "data_object_label_2.zip",
+    ]
+    image_dir_name = "image_2"
+    labels_dir_name = "label_2"
+
+    def __init__(
+        self,
+        root: str,
+        train: bool = True,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        transforms: Optional[Callable] = None,
+        download: bool = False,
+    ):
+        super().__init__(
+            root,
+            transform=transform,
+            target_transform=target_transform,
+            transforms=transforms,
+        )
+        self.images = []
+        self.targets = []
+        self.train = train
+        self._location = "training" if self.train else "testing"
+
+        if download:
+            self.download()
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You may use download=True to download it.")
+
+        image_dir = os.path.join(self._raw_folder, self._location, self.image_dir_name)
+        if self.train:
+            labels_dir = os.path.join(self._raw_folder, self._location, self.labels_dir_name)
+        for img_file in os.listdir(image_dir):
+            self.images.append(os.path.join(image_dir, img_file))
+            if self.train:
+                self.targets.append(os.path.join(labels_dir, f"{img_file.split('.')[0]}.txt"))
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """Get item at a given index.
+
+        Args:
+            index (int): Index
+        Returns:
+            tuple: (image, target), where
+            target is a list of dictionaries with the following keys:
+
+            - type: str
+            - truncated: float
+            - occluded: int
+            - alpha: float
+            - bbox: float[4]
+            - dimensions: float[3]
+            - location: float[3]
+            - rotation_y: float
+
+        """
+        image = Image.open(self.images[index])
+        target = self._parse_target(index) if self.train else None
+        if self.transforms:
+            image, target = self.transforms(image, target)
+        return image, target
+
+    def _parse_target(self, index: int) -> List:
+        target = []
+        with open(self.targets[index]) as inp:
+            content = csv.reader(inp, delimiter=" ")
+            for line in content:
+                target.append(
+                    {
+                        "type": line[0],
+                        "truncated": float(line[1]),
+                        "occluded": int(line[2]),
+                        "alpha": float(line[3]),
+                        "bbox": [float(x) for x in line[4:8]],
+                        "dimensions": [float(x) for x in line[8:11]],
+                        "location": [float(x) for x in line[11:14]],
+                        "rotation_y": float(line[14]),
+                    }
+                )
+        return target
+
+    def __len__(self) -> int:
+        return len(self.images)
+
+    @property
+    def _raw_folder(self) -> str:
+        return os.path.join(self.root, self.__class__.__name__, "raw")
+
+    def _check_exists(self) -> bool:
+        """Check if the data directory exists."""
+        folders = [self.image_dir_name]
+        if self.train:
+            folders.append(self.labels_dir_name)
+        return all(os.path.isdir(os.path.join(self._raw_folder, self._location, fname)) for fname in folders)
+
+    def download(self) -> None:
+        """Download the KITTI data if it doesn't exist already."""
+
+        if self._check_exists():
+            return
+
+        os.makedirs(self._raw_folder, exist_ok=True)
+
+        # download files
+        for fname in self.resources:
+            download_and_extract_archive(
+                url=f"{self.data_url}{fname}",
+                download_root=self._raw_folder,
+                filename=fname,
+            )
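A sketch of reading one KITTI training sample and its parsed annotations; the root is illustrative, and download=True fetches the two official zips listed in ``resources``:

from libs.vision_libs.datasets import Kitti  # assumed import path

ds = Kitti(root="data", train=True, download=True)
image, target = ds[0]
for obj in target:                   # one dict per annotated object
    print(obj["type"], obj["bbox"])  # e.g. "Car" and [left, top, right, bottom]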

+ 255 - 0
libs/vision_libs/datasets/lfw.py

@@ -0,0 +1,255 @@
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+from PIL import Image
+
+from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg
+from .vision import VisionDataset
+
+
+class _LFW(VisionDataset):
+
+    base_folder = "lfw-py"
+    download_url_prefix = "http://vis-www.cs.umass.edu/lfw/"
+
+    file_dict = {
+        "original": ("lfw", "lfw.tgz", "a17d05bd522c52d84eca14327a23d494"),
+        "funneled": ("lfw_funneled", "lfw-funneled.tgz", "1b42dfed7d15c9b2dd63d5e5840c86ad"),
+        "deepfunneled": ("lfw-deepfunneled", "lfw-deepfunneled.tgz", "68331da3eb755a505a502b5aacb3c201"),
+    }
+    checksums = {
+        "pairs.txt": "9f1ba174e4e1c508ff7cdf10ac338a7d",
+        "pairsDevTest.txt": "5132f7440eb68cf58910c8a45a2ac10b",
+        "pairsDevTrain.txt": "4f27cbf15b2da4a85c1907eb4181ad21",
+        "people.txt": "450f0863dd89e85e73936a6d71a3474b",
+        "peopleDevTest.txt": "e4bf5be0a43b5dcd9dc5ccfcb8fb19c5",
+        "peopleDevTrain.txt": "54eaac34beb6d042ed3a7d883e247a21",
+        "lfw-names.txt": "a6d0a479bd074669f656265a6e693f6d",
+    }
+    annot_file = {"10fold": "", "train": "DevTrain", "test": "DevTest"}
+    names = "lfw-names.txt"
+
+    def __init__(
+        self,
+        root: str,
+        split: str,
+        image_set: str,
+        view: str,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(os.path.join(root, self.base_folder), transform=transform, target_transform=target_transform)
+
+        self.image_set = verify_str_arg(image_set.lower(), "image_set", self.file_dict.keys())
+        images_dir, self.filename, self.md5 = self.file_dict[self.image_set]
+
+        self.view = verify_str_arg(view.lower(), "view", ["people", "pairs"])
+        self.split = verify_str_arg(split.lower(), "split", ["10fold", "train", "test"])
+        self.labels_file = f"{self.view}{self.annot_file[self.split]}.txt"
+        self.data: List[Any] = []
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        self.images_dir = os.path.join(self.root, images_dir)
+
+    def _loader(self, path: str) -> Image.Image:
+        with open(path, "rb") as f:
+            img = Image.open(f)
+            return img.convert("RGB")
+
+    def _check_integrity(self) -> bool:
+        st1 = check_integrity(os.path.join(self.root, self.filename), self.md5)
+        st2 = check_integrity(os.path.join(self.root, self.labels_file), self.checksums[self.labels_file])
+        if not st1 or not st2:
+            return False
+        if self.view == "people":
+            return check_integrity(os.path.join(self.root, self.names), self.checksums[self.names])
+        return True
+
+    def download(self) -> None:
+        if self._check_integrity():
+            print("Files already downloaded and verified")
+            return
+        url = f"{self.download_url_prefix}{self.filename}"
+        download_and_extract_archive(url, self.root, filename=self.filename, md5=self.md5)
+        download_url(f"{self.download_url_prefix}{self.labels_file}", self.root)
+        if self.view == "people":
+            download_url(f"{self.download_url_prefix}{self.names}", self.root)
+
+    def _get_path(self, identity: str, no: Union[int, str]) -> str:
+        return os.path.join(self.images_dir, identity, f"{identity}_{int(no):04d}.jpg")
+
+    def extra_repr(self) -> str:
+        return f"Alignment: {self.image_set}\nSplit: {self.split}"
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+
+class LFWPeople(_LFW):
+    """`LFW <http://vis-www.cs.umass.edu/lfw/>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where directory
+            ``lfw-py`` exists or will be saved to if download is set to True.
+        split (string, optional): The image split to use. Can be one of ``train``, ``test``,
+            ``10fold`` (default).
+        image_set (str, optional): Type of image funneling to use, ``original``, ``funneled`` or
+            ``deepfunneled``. Defaults to ``funneled``.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomRotation``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+    """
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "10fold",
+        image_set: str = "funneled",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, split, image_set, "people", transform, target_transform, download)
+
+        self.class_to_idx = self._get_classes()
+        self.data, self.targets = self._get_people()
+
+    def _get_people(self) -> Tuple[List[str], List[int]]:
+        data, targets = [], []
+        with open(os.path.join(self.root, self.labels_file)) as f:
+            lines = f.readlines()
+            n_folds, s = (int(lines[0]), 1) if self.split == "10fold" else (1, 0)
+
+            for fold in range(n_folds):
+                n_lines = int(lines[s])
+                people = [line.strip().split("\t") for line in lines[s + 1 : s + n_lines + 1]]
+                s += n_lines + 1
+                for i, (identity, num_imgs) in enumerate(people):
+                    for num in range(1, int(num_imgs) + 1):
+                        img = self._get_path(identity, num)
+                        data.append(img)
+                        targets.append(self.class_to_idx[identity])
+
+        return data, targets
+
+    def _get_classes(self) -> Dict[str, int]:
+        with open(os.path.join(self.root, self.names)) as f:
+            lines = f.readlines()
+            names = [line.strip().split()[0] for line in lines]
+        class_to_idx = {name: i for i, name in enumerate(names)}
+        return class_to_idx
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: Tuple (image, target) where target is the identity of the person.
+        """
+        img = self._loader(self.data[index])
+        target = self.targets[index]
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def extra_repr(self) -> str:
+        return super().extra_repr() + f"\nClasses (identities): {len(self.class_to_idx)}"
+
+
+class LFWPairs(_LFW):
+    """`LFW <http://vis-www.cs.umass.edu/lfw/>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where directory
+            ``lfw-py`` exists or will be saved to if download is set to True.
+        split (string, optional): The image split to use. Can be one of ``train``, ``test``,
+            ``10fold``. Defaults to ``10fold``.
+        image_set (str, optional): Type of image funneling to use, ``original``, ``funneled`` or
+            ``deepfunneled``. Defaults to ``funneled``.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomRotation``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+    """
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "10fold",
+        image_set: str = "funneled",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, split, image_set, "pairs", transform, target_transform, download)
+
+        self.pair_names, self.data, self.targets = self._get_pairs(self.images_dir)
+
+    def _get_pairs(self, images_dir: str) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]], List[int]]:
+        pair_names, data, targets = [], [], []
+        with open(os.path.join(self.root, self.labels_file)) as f:
+            lines = f.readlines()
+            if self.split == "10fold":
+                n_folds, n_pairs = lines[0].split("\t")
+                n_folds, n_pairs = int(n_folds), int(n_pairs)
+            else:
+                n_folds, n_pairs = 1, int(lines[0])
+            s = 1
+
+            for fold in range(n_folds):
+                matched_pairs = [line.strip().split("\t") for line in lines[s : s + n_pairs]]
+                unmatched_pairs = [line.strip().split("\t") for line in lines[s + n_pairs : s + (2 * n_pairs)]]
+                s += 2 * n_pairs
+                for pair in matched_pairs:
+                    img1, img2, same = self._get_path(pair[0], pair[1]), self._get_path(pair[0], pair[2]), 1
+                    pair_names.append((pair[0], pair[0]))
+                    data.append((img1, img2))
+                    targets.append(same)
+                for pair in unmatched_pairs:
+                    img1, img2, same = self._get_path(pair[0], pair[1]), self._get_path(pair[2], pair[3]), 0
+                    pair_names.append((pair[0], pair[2]))
+                    data.append((img1, img2))
+                    targets.append(same)
+
+        return pair_names, data, targets
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any, int]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image1, image2, target) where target is `0` for different indentities and `1` for same identities.
+        """
+        img1, img2 = self.data[index]
+        img1, img2 = self._loader(img1), self._loader(img2)
+        target = self.targets[index]
+
+        if self.transform is not None:
+            img1, img2 = self.transform(img1), self.transform(img2)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img1, img2, target
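The two public LFW views differ only in what __getitem__ returns. A sketch, with an illustrative root and download enabled:

from libs.vision_libs.datasets import LFWPairs, LFWPeople  # assumed import path

people = LFWPeople(root="data", split="train", image_set="funneled", download=True)
img, identity = people[0]        # identity is an index into people.class_to_idx

pairs = LFWPairs(root="data", split="test", image_set="funneled", download=True)
img1, img2, same = pairs[0]      # same is 1 for matched identities, 0 otherwise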

+ 167 - 0
libs/vision_libs/datasets/lsun.py

@@ -0,0 +1,167 @@
+import io
+import os.path
+import pickle
+import string
+from collections.abc import Iterable
+from typing import Any, Callable, cast, List, Optional, Tuple, Union
+
+from PIL import Image
+
+from .utils import iterable_to_str, verify_str_arg
+from .vision import VisionDataset
+
+
+class LSUNClass(VisionDataset):
+    def __init__(
+        self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None
+    ) -> None:
+        import lmdb
+
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False)
+        with self.env.begin(write=False) as txn:
+            self.length = txn.stat()["entries"]
+        cache_file = "_cache_" + "".join(c for c in root if c in string.ascii_letters)
+        if os.path.isfile(cache_file):
+            self.keys = pickle.load(open(cache_file, "rb"))
+        else:
+            with self.env.begin(write=False) as txn:
+                self.keys = [key for key in txn.cursor().iternext(keys=True, values=False)]
+            pickle.dump(self.keys, open(cache_file, "wb"))
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        img, target = None, None
+        env = self.env
+        with env.begin(write=False) as txn:
+            imgbuf = txn.get(self.keys[index])
+
+        buf = io.BytesIO()
+        buf.write(imgbuf)
+        buf.seek(0)
+        img = Image.open(buf).convert("RGB")
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return self.length
+
+
+class LSUN(VisionDataset):
+    """`LSUN <https://www.yf.io/p/lsun>`_ dataset.
+
+    You will need to install the ``lmdb`` package to use this dataset: run
+    ``pip install lmdb``
+
+    Args:
+        root (string): Root directory for the database files.
+        classes (string or list): One of {'train', 'val', 'test'} or a list of
+            categories to load, e.g. ['bedroom_train', 'church_outdoor_train'].
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        classes: Union[str, List[str]] = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self.classes = self._verify_classes(classes)
+
+        # for each class, create an LSUNClass dataset
+        self.dbs = []
+        for c in self.classes:
+            self.dbs.append(LSUNClass(root=os.path.join(root, f"{c}_lmdb"), transform=transform))
+
+        self.indices = []
+        count = 0
+        for db in self.dbs:
+            count += len(db)
+            self.indices.append(count)
+
+        self.length = count
+
+    def _verify_classes(self, classes: Union[str, List[str]]) -> List[str]:
+        categories = [
+            "bedroom",
+            "bridge",
+            "church_outdoor",
+            "classroom",
+            "conference_room",
+            "dining_room",
+            "kitchen",
+            "living_room",
+            "restaurant",
+            "tower",
+        ]
+        dset_opts = ["train", "val", "test"]
+
+        try:
+            classes = cast(str, classes)
+            verify_str_arg(classes, "classes", dset_opts)
+            if classes == "test":
+                classes = [classes]
+            else:
+                classes = [c + "_" + classes for c in categories]
+        except ValueError:
+            if not isinstance(classes, Iterable):
+                msg = "Expected type str or Iterable for argument classes, but got type {}."
+                raise ValueError(msg.format(type(classes)))
+
+            classes = list(classes)
+            msg_fmtstr_type = "Expected type str for elements in argument classes, but got type {}."
+            for c in classes:
+                verify_str_arg(c, custom_msg=msg_fmtstr_type.format(type(c)))
+                c_short = c.split("_")
+                category, dset_opt = "_".join(c_short[:-1]), c_short[-1]
+
+                msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
+                msg = msg_fmtstr.format(category, "LSUN class", iterable_to_str(categories))
+                verify_str_arg(category, valid_values=categories, custom_msg=msg)
+
+                msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
+                verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)
+
+        return classes
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: Tuple (image, target) where target is the index of the target category.
+        """
+        target = 0
+        sub = 0
+        for ind in self.indices:
+            if index < ind:
+                break
+            target += 1
+            sub = ind
+
+        db = self.dbs[target]
+        index = index - sub
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        img, _ = db[index]
+        return img, target
+
+    def __len__(self) -> int:
+        return self.length
+
+    def extra_repr(self) -> str:
+        return "Classes: {classes}".format(**self.__dict__)

+ 558 - 0
libs/vision_libs/datasets/mnist.py

@@ -0,0 +1,558 @@
+import codecs
+import os
+import os.path
+import shutil
+import string
+import sys
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Tuple
+from urllib.error import URLError
+
+import numpy as np
+import torch
+from PIL import Image
+
+from .utils import _flip_byte_order, check_integrity, download_and_extract_archive, extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class MNIST(VisionDataset):
+    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte``
+            and  ``MNIST/raw/t10k-images-idx3-ubyte`` exist.
+        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
+            otherwise from ``t10k-images-idx3-ubyte``.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+
+    mirrors = [
+        "http://yann.lecun.com/exdb/mnist/",
+        "https://ossci-datasets.s3.amazonaws.com/mnist/",
+    ]
+
+    resources = [
+        ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
+        ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
+        ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
+        ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
+    ]
+
+    training_file = "training.pt"
+    test_file = "test.pt"
+    classes = [
+        "0 - zero",
+        "1 - one",
+        "2 - two",
+        "3 - three",
+        "4 - four",
+        "5 - five",
+        "6 - six",
+        "7 - seven",
+        "8 - eight",
+        "9 - nine",
+    ]
+
+    @property
+    def train_labels(self):
+        warnings.warn("train_labels has been renamed targets")
+        return self.targets
+
+    @property
+    def test_labels(self):
+        warnings.warn("test_labels has been renamed targets")
+        return self.targets
+
+    @property
+    def train_data(self):
+        warnings.warn("train_data has been renamed data")
+        return self.data
+
+    @property
+    def test_data(self):
+        warnings.warn("test_data has been renamed data")
+        return self.data
+
+    def __init__(
+        self,
+        root: str,
+        train: bool = True,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self.train = train  # training set or test set
+
+        if self._check_legacy_exist():
+            self.data, self.targets = self._load_legacy_data()
+            return
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        self.data, self.targets = self._load_data()
+
+    def _check_legacy_exist(self):
+        processed_folder_exists = os.path.exists(self.processed_folder)
+        if not processed_folder_exists:
+            return False
+
+        return all(
+            check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file)
+        )
+
+    def _load_legacy_data(self):
+        # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
+        # directly.
+        data_file = self.training_file if self.train else self.test_file
+        return torch.load(os.path.join(self.processed_folder, data_file))
+
+    def _load_data(self):
+        image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
+        data = read_image_file(os.path.join(self.raw_folder, image_file))
+
+        label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
+        targets = read_label_file(os.path.join(self.raw_folder, label_file))
+
+        return data, targets
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target class.
+        """
+        img, target = self.data[index], int(self.targets[index])
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(img.numpy(), mode="L")
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    @property
+    def raw_folder(self) -> str:
+        return os.path.join(self.root, self.__class__.__name__, "raw")
+
+    @property
+    def processed_folder(self) -> str:
+        return os.path.join(self.root, self.__class__.__name__, "processed")
+
+    @property
+    def class_to_idx(self) -> Dict[str, int]:
+        return {_class: i for i, _class in enumerate(self.classes)}
+
+    def _check_exists(self) -> bool:
+        return all(
+            check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]))
+            for url, _ in self.resources
+        )
+
+    def download(self) -> None:
+        """Download the MNIST data if it doesn't exist already."""
+
+        if self._check_exists():
+            return
+
+        os.makedirs(self.raw_folder, exist_ok=True)
+
+        # download files
+        for filename, md5 in self.resources:
+            for mirror in self.mirrors:
+                url = f"{mirror}{filename}"
+                try:
+                    print(f"Downloading {url}")
+                    download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
+                except URLError as error:
+                    print(f"Failed to download (trying next):\n{error}")
+                    continue
+                finally:
+                    print()
+                break
+            else:
+                raise RuntimeError(f"Error downloading {filename}")
+
+    def extra_repr(self) -> str:
+        split = "Train" if self.train is True else "Test"
+        return f"Split: {split}"
+
+
+class FashionMNIST(MNIST):
+    """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where ``FashionMNIST/raw/train-images-idx3-ubyte``
+            and  ``FashionMNIST/raw/t10k-images-idx3-ubyte`` exist.
+        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
+            otherwise from ``t10k-images-idx3-ubyte``.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+
+    mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"]
+
+    resources = [
+        ("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"),
+        ("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"),
+        ("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"),
+        ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310"),
+    ]
+    classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
+
+
+class KMNIST(MNIST):
+    """`Kuzushiji-MNIST <https://github.com/rois-codh/kmnist>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where ``KMNIST/raw/train-images-idx3-ubyte``
+            and  ``KMNIST/raw/t10k-images-idx3-ubyte`` exist.
+        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
+            otherwise from ``t10k-images-idx3-ubyte``.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+
+    mirrors = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"]
+
+    resources = [
+        ("train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"),
+        ("train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"),
+        ("t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"),
+        ("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134"),
+    ]
+    classes = ["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"]
+
+
+class EMNIST(MNIST):
+    """`EMNIST <https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where ``EMNIST/raw/train-images-idx3-ubyte``
+            and  ``EMNIST/raw/t10k-images-idx3-ubyte`` exist.
+        split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
+            ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
+            which one to use.
+        train (bool, optional): If True, creates dataset from ``training.pt``,
+            otherwise from ``test.pt``.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+
+    url = "https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip"
+    md5 = "58c8d27c78d21e728a6bc7b3cc06412e"
+    splits = ("byclass", "bymerge", "balanced", "letters", "digits", "mnist")
+    # Merged Classes assumes Same structure for both uppercase and lowercase version
+    _merged_classes = {"c", "i", "j", "k", "l", "m", "o", "p", "s", "u", "v", "w", "x", "y", "z"}
+    _all_classes = set(string.digits + string.ascii_letters)
+    classes_split_dict = {
+        "byclass": sorted(list(_all_classes)),
+        "bymerge": sorted(list(_all_classes - _merged_classes)),
+        "balanced": sorted(list(_all_classes - _merged_classes)),
+        "letters": ["N/A"] + list(string.ascii_lowercase),
+        "digits": list(string.digits),
+        "mnist": list(string.digits),
+    }
+
+    def __init__(self, root: str, split: str, **kwargs: Any) -> None:
+        self.split = verify_str_arg(split, "split", self.splits)
+        self.training_file = self._training_file(split)
+        self.test_file = self._test_file(split)
+        super().__init__(root, **kwargs)
+        self.classes = self.classes_split_dict[self.split]
+
+    @staticmethod
+    def _training_file(split) -> str:
+        return f"training_{split}.pt"
+
+    @staticmethod
+    def _test_file(split) -> str:
+        return f"test_{split}.pt"
+
+    @property
+    def _file_prefix(self) -> str:
+        return f"emnist-{self.split}-{'train' if self.train else 'test'}"
+
+    @property
+    def images_file(self) -> str:
+        return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte")
+
+    @property
+    def labels_file(self) -> str:
+        return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte")
+
+    def _load_data(self):
+        return read_image_file(self.images_file), read_label_file(self.labels_file)
+
+    def _check_exists(self) -> bool:
+        return all(check_integrity(file) for file in (self.images_file, self.labels_file))
+
+    def download(self) -> None:
+        """Download the EMNIST data if it doesn't exist already."""
+
+        if self._check_exists():
+            return
+
+        os.makedirs(self.raw_folder, exist_ok=True)
+
+        download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5)
+        gzip_folder = os.path.join(self.raw_folder, "gzip")
+        for gzip_file in os.listdir(gzip_folder):
+            if gzip_file.endswith(".gz"):
+                extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder)
+        shutil.rmtree(gzip_folder)
+
+
+class QMNIST(MNIST):
+    """`QMNIST <https://github.com/facebookresearch/qmnist>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset whose ``raw``
+            subdir contains binary files of the datasets.
+        what (string,optional): Can be 'train', 'test', 'test10k',
+            'test50k', or 'nist' for respectively the mnist compatible
+            training set, the 60k qmnist testing set, the 10k qmnist
+            examples that match the mnist testing set, the 50k
+            remaining qmnist testing examples, or all the nist
+            digits. The default is to select 'train' or 'test'
+            according to the compatibility argument 'train'.
+        compat (bool,optional): A boolean that says whether the target
+            for each example is class number (for compatibility with
+            the MNIST dataloader) or a torch vector containing the
+            full qmnist information. Default=True.
+        download (bool, optional): If True, downloads the dataset from
+            the internet and puts it in root directory. If dataset is
+            already downloaded, it is not downloaded again.
+        transform (callable, optional): A function/transform that
+            takes in a PIL image and returns a transformed
+            version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform
+            that takes in the target and transforms it.
+        train (bool,optional,compatibility): When argument 'what' is
+            not specified, this boolean decides whether to load the
+            training set or the testing set.  Default: True.
+    """
+
+    subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"}
+    resources: Dict[str, List[Tuple[str, str]]] = {  # type: ignore[assignment]
+        "train": [
+            (
+                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz",
+                "ed72d4157d28c017586c42bc6afe6370",
+            ),
+            (
+                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz",
+                "0058f8dd561b90ffdd0f734c6a30e5e4",
+            ),
+        ],
+        "test": [
+            (
+                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz",
+                "1394631089c404de565df7b7aeaf9412",
+            ),
+            (
+                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz",
+                "5b5b05890a5e13444e108efe57b788aa",
+            ),
+        ],
+        "nist": [
+            (
+                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz",
+                "7f124b3b8ab81486c9d8c2749c17f834",
+            ),
+            (
+                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz",
+                "5ed0e788978e45d4a8bd4b7caec3d79d",
+            ),
+        ],
+    }
+    classes = [
+        "0 - zero",
+        "1 - one",
+        "2 - two",
+        "3 - three",
+        "4 - four",
+        "5 - five",
+        "6 - six",
+        "7 - seven",
+        "8 - eight",
+        "9 - nine",
+    ]
+
+    def __init__(
+        self, root: str, what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any
+    ) -> None:
+        if what is None:
+            what = "train" if train else "test"
+        self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
+        self.compat = compat
+        self.data_file = what + ".pt"
+        self.training_file = self.data_file
+        self.test_file = self.data_file
+        super().__init__(root, train, **kwargs)
+
+    @property
+    def images_file(self) -> str:
+        (url, _), _ = self.resources[self.subsets[self.what]]
+        return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
+
+    @property
+    def labels_file(self) -> str:
+        _, (url, _) = self.resources[self.subsets[self.what]]
+        return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
+
+    def _check_exists(self) -> bool:
+        return all(check_integrity(file) for file in (self.images_file, self.labels_file))
+
+    def _load_data(self):
+        data = read_sn3_pascalvincent_tensor(self.images_file)
+        if data.dtype != torch.uint8:
+            raise TypeError(f"data should be of dtype torch.uint8 instead of {data.dtype}")
+        if data.ndimension() != 3:
+            raise ValueError("data should have 3 dimensions instead of {data.ndimension()}")
+
+        targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
+        if targets.ndimension() != 2:
+            raise ValueError(f"targets should have 2 dimensions instead of {targets.ndimension()}")
+
+        if self.what == "test10k":
+            data = data[0:10000, :, :].clone()
+            targets = targets[0:10000, :].clone()
+        elif self.what == "test50k":
+            data = data[10000:, :, :].clone()
+            targets = targets[10000:, :].clone()
+
+        return data, targets
+
+    def download(self) -> None:
+        """Download the QMNIST data if it doesn't exist already.
+        Note that we only download what has been asked for (argument 'what').
+        """
+        if self._check_exists():
+            return
+
+        os.makedirs(self.raw_folder, exist_ok=True)
+        split = self.resources[self.subsets[self.what]]
+
+        for url, md5 in split:
+            download_and_extract_archive(url, self.raw_folder, md5=md5)
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        # redefined to handle the compat flag
+        img, target = self.data[index], self.targets[index]
+        img = Image.fromarray(img.numpy(), mode="L")
+        if self.transform is not None:
+            img = self.transform(img)
+        if self.compat:
+            target = int(target[0])
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+        return img, target
+
+    def extra_repr(self) -> str:
+        return f"Split: {self.what}"
+
+
+def get_int(b: bytes) -> int:
+    return int(codecs.encode(b, "hex"), 16)
+
+
+SN3_PASCALVINCENT_TYPEMAP = {
+    8: torch.uint8,
+    9: torch.int8,
+    11: torch.int16,
+    12: torch.int32,
+    13: torch.float32,
+    14: torch.float64,
+}
+
+
+def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
+    """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
+    Argument may be a filename, compressed filename, or file object.
+    """
+    # read
+    with open(path, "rb") as f:
+        data = f.read()
+
+    # parse
+    if sys.byteorder == "little":
+        magic = get_int(data[0:4])
+        nd = magic % 256
+        ty = magic // 256
+    else:
+        nd = get_int(data[0:1])
+        ty = get_int(data[1:2]) + get_int(data[2:3]) * 256 + get_int(data[3:4]) * 256 * 256
+
+    assert 1 <= nd <= 3
+    assert 8 <= ty <= 14
+    torch_type = SN3_PASCALVINCENT_TYPEMAP[ty]
+    s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]
+
+    if sys.byteorder == "big":
+        for i in range(len(s)):
+            s[i] = int.from_bytes(s[i].to_bytes(4, byteorder="little"), byteorder="big", signed=False)
+
+    parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1)))
+
+    # The MNIST format uses the big endian byte order, while `torch.frombuffer` uses whatever the system uses. In case
+    # that is little endian and the dtype has more than one byte, we need to flip them.
+    if sys.byteorder == "little" and parsed.element_size() > 1:
+        parsed = _flip_byte_order(parsed)
+
+    assert parsed.shape[0] == np.prod(s) or not strict
+    return parsed.view(*s)
+
+
+def read_label_file(path: str) -> torch.Tensor:
+    x = read_sn3_pascalvincent_tensor(path, strict=False)
+    if x.dtype != torch.uint8:
+        raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
+    if x.ndimension() != 1:
+        raise ValueError(f"x should have 1 dimension instead of {x.ndimension()}")
+    return x.long()
+
+
+def read_image_file(path: str) -> torch.Tensor:
+    x = read_sn3_pascalvincent_tensor(path, strict=False)
+    if x.dtype != torch.uint8:
+        raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
+    if x.ndimension() != 3:
+        raise ValueError(f"x should have 3 dimension instead of {x.ndimension()}")
+    return x
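
A minimal usage sketch for the two classes above, assuming the vendored package is importable as libs.vision_libs.datasets and that the archives can still be downloaded (the import path and the "data" root are assumptions, not part of the diff):

# Hedged sketch: EMNIST split selection and QMNIST's compat flag.
from libs.vision_libs.datasets import EMNIST, QMNIST

emnist = EMNIST(root="data", split="balanced", train=True, download=True)
img, label = emnist[0]                 # PIL image and integer label
print(emnist.classes[label])           # symbol for this label in the "balanced" split

qmnist = QMNIST(root="data", what="test10k", compat=True, download=True)
_, target = qmnist[0]                  # compat=True collapses the target to a plain class index
print(len(qmnist), target)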

+ 93 - 0
libs/vision_libs/datasets/moving_mnist.py

@@ -0,0 +1,93 @@
+import os.path
+from typing import Callable, Optional
+
+import numpy as np
+import torch
+from torchvision.datasets.utils import download_url, verify_str_arg
+from torchvision.datasets.vision import VisionDataset
+
+
+class MovingMNIST(VisionDataset):
+    """`MovingMNIST <http://www.cs.toronto.edu/~nitish/unsupervised_video/>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where ``MovingMNIST/mnist_test_seq.npy`` exists.
+        split (string, optional): The dataset split, supports ``None`` (default), ``"train"`` and ``"test"``.
+            If ``split=None``, the full data is returned.
+        split_ratio (int, optional): The number of frames at which to split each sequence. If ``split="train"``,
+            the first frames ``data[:, :split_ratio]`` are returned. If ``split="test"``, the remaining frames
+            ``data[:, split_ratio:]`` are returned. If ``split=None``, this parameter is ignored and all frames are returned.
+        transform (callable, optional): A function/transform that takes in a torch Tensor
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    _URL = "http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy"
+
+    def __init__(
+        self,
+        root: str,
+        split: Optional[str] = None,
+        split_ratio: int = 10,
+        download: bool = False,
+        transform: Optional[Callable] = None,
+    ) -> None:
+        super().__init__(root, transform=transform)
+
+        self._base_folder = os.path.join(self.root, self.__class__.__name__)
+        self._filename = self._URL.split("/")[-1]
+
+        if split is not None:
+            verify_str_arg(split, "split", ("train", "test"))
+        self.split = split
+
+        if not isinstance(split_ratio, int):
+            raise TypeError(f"`split_ratio` should be an integer, but got {type(split_ratio)}")
+        elif not (1 <= split_ratio <= 19):
+            raise ValueError(f"`split_ratio` should be `1 <= split_ratio <= 19`, but got {split_ratio} instead.")
+        self.split_ratio = split_ratio
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it.")
+
+        data = torch.from_numpy(np.load(os.path.join(self._base_folder, self._filename)))
+        if self.split == "train":
+            data = data[: self.split_ratio]
+        elif self.split == "test":
+            data = data[self.split_ratio :]
+        self.data = data.transpose(0, 1).unsqueeze(2).contiguous()
+
+    def __getitem__(self, idx: int) -> torch.Tensor:
+        """
+        Args:
+            index (int): Index
+        Returns:
+            torch.Tensor: Video frames (torch Tensor[T, C, H, W]). The `T` is the number of frames.
+        """
+        data = self.data[idx]
+        if self.transform is not None:
+            data = self.transform(data)
+
+        return data
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def _check_exists(self) -> bool:
+        return os.path.exists(os.path.join(self._base_folder, self._filename))
+
+    def download(self) -> None:
+        if self._check_exists():
+            return
+
+        download_url(
+            url=self._URL,
+            root=self._base_folder,
+            filename=self._filename,
+            md5="be083ec986bfe91a449d63653c411eb2",
+        )
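
The constructor above slices the raw (20, 10000, 64, 64) array along the frame axis before reshaping it to (B, T, C, H, W); a small sketch, assuming the same libs.vision_libs.datasets import path:

# Hedged sketch: frame-wise train/test split controlled by split_ratio.
from libs.vision_libs.datasets import MovingMNIST

train = MovingMNIST(root="data", split="train", split_ratio=10, download=True)
test = MovingMNIST(root="data", split="test", split_ratio=10)

clip = train[0]                        # Tensor[T, C, H, W] == [10, 1, 64, 64]
print(len(train), clip.shape, test[0].shape)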

+ 102 - 0
libs/vision_libs/datasets/omniglot.py

@@ -0,0 +1,102 @@
+from os.path import join
+from typing import Any, Callable, List, Optional, Tuple
+
+from PIL import Image
+
+from .utils import check_integrity, download_and_extract_archive, list_dir, list_files
+from .vision import VisionDataset
+
+
+class Omniglot(VisionDataset):
+    """`Omniglot <https://github.com/brendenlake/omniglot>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where directory
+            ``omniglot-py`` exists.
+        background (bool, optional): If True, creates dataset from the "background" set, otherwise
+            creates from the "evaluation" set. This terminology is defined by the authors.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset zip files from the internet and
+            puts them in the root directory. If the zip files are already downloaded, they are not
+            downloaded again.
+    """
+
+    folder = "omniglot-py"
+    download_url_prefix = "https://raw.githubusercontent.com/brendenlake/omniglot/master/python"
+    zips_md5 = {
+        "images_background": "68d2efa1b9178cc56df9314c21c6e718",
+        "images_evaluation": "6b91aef0f799c5bb55b94e3f2daec811",
+    }
+
+    def __init__(
+        self,
+        root: str,
+        background: bool = True,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(join(root, self.folder), transform=transform, target_transform=target_transform)
+        self.background = background
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        self.target_folder = join(self.root, self._get_target_folder())
+        self._alphabets = list_dir(self.target_folder)
+        self._characters: List[str] = sum(
+            ([join(a, c) for c in list_dir(join(self.target_folder, a))] for a in self._alphabets), []
+        )
+        self._character_images = [
+            [(image, idx) for image in list_files(join(self.target_folder, character), ".png")]
+            for idx, character in enumerate(self._characters)
+        ]
+        self._flat_character_images: List[Tuple[str, int]] = sum(self._character_images, [])
+
+    def __len__(self) -> int:
+        return len(self._flat_character_images)
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target character class.
+        """
+        image_name, character_class = self._flat_character_images[index]
+        image_path = join(self.target_folder, self._characters[character_class], image_name)
+        image = Image.open(image_path, mode="r").convert("L")
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.target_transform:
+            character_class = self.target_transform(character_class)
+
+        return image, character_class
+
+    def _check_integrity(self) -> bool:
+        zip_filename = self._get_target_folder()
+        if not check_integrity(join(self.root, zip_filename + ".zip"), self.zips_md5[zip_filename]):
+            return False
+        return True
+
+    def download(self) -> None:
+        if self._check_integrity():
+            print("Files already downloaded and verified")
+            return
+
+        filename = self._get_target_folder()
+        zip_filename = filename + ".zip"
+        url = self.download_url_prefix + "/" + zip_filename
+        download_and_extract_archive(url, self.root, filename=zip_filename, md5=self.zips_md5[filename])
+
+    def _get_target_folder(self) -> str:
+        return "images_background" if self.background else "images_evaluation"

+ 125 - 0
libs/vision_libs/datasets/oxford_iiit_pet.py

@@ -0,0 +1,125 @@
+import os
+import os.path
+import pathlib
+from typing import Any, Callable, Optional, Sequence, Tuple, Union
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class OxfordIIITPet(VisionDataset):
+    """`Oxford-IIIT Pet Dataset   <https://www.robots.ox.ac.uk/~vgg/data/pets/>`_.
+
+    Args:
+        root (string): Root directory of the dataset.
+        split (string, optional): The dataset split, supports ``"trainval"`` (default) or ``"test"``.
+        target_types (string, sequence of strings, optional): Types of target to use. Can be ``category`` (default) or
+            ``segmentation``. Can also be a list to output a tuple with all specified target types. The types represent:
+
+                - ``category`` (int): Label for one of the 37 pet categories.
+                - ``segmentation`` (PIL image): Segmentation trimap of the image.
+
+            If empty, ``None`` will be returned as target.
+
+        transform (callable, optional): A function/transform that  takes in a PIL image and returns a transformed
+            version. E.g, ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and puts it into
+            ``root/oxford-iiit-pet``. If dataset is already downloaded, it is not downloaded again.
+    """
+
+    _RESOURCES = (
+        ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", "5c4f3ee8e5d25df40f4fd59a7f44e54c"),
+        ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz", "95a8c909bbe2e81eed6a22bccdf3f68f"),
+    )
+    _VALID_TARGET_TYPES = ("category", "segmentation")
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "trainval",
+        target_types: Union[Sequence[str], str] = "category",
+        transforms: Optional[Callable] = None,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ):
+        self._split = verify_str_arg(split, "split", ("trainval", "test"))
+        if isinstance(target_types, str):
+            target_types = [target_types]
+        self._target_types = [
+            verify_str_arg(target_type, "target_types", self._VALID_TARGET_TYPES) for target_type in target_types
+        ]
+
+        super().__init__(root, transforms=transforms, transform=transform, target_transform=target_transform)
+        self._base_folder = pathlib.Path(self.root) / "oxford-iiit-pet"
+        self._images_folder = self._base_folder / "images"
+        self._anns_folder = self._base_folder / "annotations"
+        self._segs_folder = self._anns_folder / "trimaps"
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        image_ids = []
+        self._labels = []
+        with open(self._anns_folder / f"{self._split}.txt") as file:
+            for line in file:
+                image_id, label, *_ = line.strip().split()
+                image_ids.append(image_id)
+                self._labels.append(int(label) - 1)
+
+        self.classes = [
+            " ".join(part.title() for part in raw_cls.split("_"))
+            for raw_cls, _ in sorted(
+                {(image_id.rsplit("_", 1)[0], label) for image_id, label in zip(image_ids, self._labels)},
+                key=lambda image_id_and_label: image_id_and_label[1],
+            )
+        ]
+        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
+
+        self._images = [self._images_folder / f"{image_id}.jpg" for image_id in image_ids]
+        self._segs = [self._segs_folder / f"{image_id}.png" for image_id in image_ids]
+
+    def __len__(self) -> int:
+        return len(self._images)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        image = Image.open(self._images[idx]).convert("RGB")
+
+        target: Any = []
+        for target_type in self._target_types:
+            if target_type == "category":
+                target.append(self._labels[idx])
+            else:  # target_type == "segmentation"
+                target.append(Image.open(self._segs[idx]))
+
+        if not target:
+            target = None
+        elif len(target) == 1:
+            target = target[0]
+        else:
+            target = tuple(target)
+
+        if self.transforms:
+            image, target = self.transforms(image, target)
+
+        return image, target
+
+    def _check_exists(self) -> bool:
+        return all(os.path.isdir(folder) for folder in (self._images_folder, self._anns_folder))
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+
+        for url, md5 in self._RESOURCES:
+            download_and_extract_archive(url, download_root=str(self._base_folder), md5=md5)
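
When several target types are requested, __getitem__ above packs them into a tuple; a sketch under the usual import-path assumption:

# Hedged sketch: category label plus segmentation trimap for the same sample.
from libs.vision_libs.datasets import OxfordIIITPet

pets = OxfordIIITPet(root="data", split="trainval",
                     target_types=["category", "segmentation"], download=True)
image, (label, trimap) = pets[0]          # RGB PIL image, (int, trimap PIL image)
print(pets.classes[label], trimap.size)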

+ 134 - 0
libs/vision_libs/datasets/pcam.py

@@ -0,0 +1,134 @@
+import pathlib
+from typing import Any, Callable, Optional, Tuple
+
+from PIL import Image
+
+from .utils import _decompress, download_file_from_google_drive, verify_str_arg
+from .vision import VisionDataset
+
+
+class PCAM(VisionDataset):
+    """`PCAM Dataset   <https://github.com/basveeling/pcam>`_.
+
+    The PatchCamelyon dataset is a binary classification dataset with 327,680
+    color images (96px x 96px), extracted from histopathologic scans of lymph node
+    sections. Each image is annotated with a binary label indicating presence of
+    metastatic tissue.
+
+    This dataset requires the ``h5py`` package which you can install with ``pip install h5py``.
+
+    Args:
+         root (string): Root directory of the dataset.
+         split (string, optional): The dataset split, supports ``"train"`` (default), ``"test"`` or ``"val"``.
+         transform (callable, optional): A function/transform that  takes in a PIL image and returns a transformed
+             version. E.g, ``transforms.RandomCrop``.
+         target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+         download (bool, optional): If True, downloads the dataset from the internet and puts it into ``root/pcam``. If
+             dataset is already downloaded, it is not downloaded again.
+
+             .. warning::
+
+                To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
+    """
+
+    _FILES = {
+        "train": {
+            "images": (
+                "camelyonpatch_level_2_split_train_x.h5",  # Data file name
+                "1Ka0XfEMiwgCYPdTI-vv6eUElOBnKFKQ2",  # Google Drive ID
+                "1571f514728f59376b705fc836ff4b63",  # md5 hash
+            ),
+            "targets": (
+                "camelyonpatch_level_2_split_train_y.h5",
+                "1269yhu3pZDP8UYFQs-NYs3FPwuK-nGSG",
+                "35c2d7259d906cfc8143347bb8e05be7",
+            ),
+        },
+        "test": {
+            "images": (
+                "camelyonpatch_level_2_split_test_x.h5",
+                "1qV65ZqZvWzuIVthK8eVDhIwrbnsJdbg_",
+                "d8c2d60d490dbd479f8199bdfa0cf6ec",
+            ),
+            "targets": (
+                "camelyonpatch_level_2_split_test_y.h5",
+                "17BHrSrwWKjYsOgTMmoqrIjDy6Fa2o_gP",
+                "60a7035772fbdb7f34eb86d4420cf66a",
+            ),
+        },
+        "val": {
+            "images": (
+                "camelyonpatch_level_2_split_valid_x.h5",
+                "1hgshYGWK8V-eGRy8LToWJJgDU_rXWVJ3",
+                "d5b63470df7cfa627aeec8b9dc0c066e",
+            ),
+            "targets": (
+                "camelyonpatch_level_2_split_valid_y.h5",
+                "1bH8ZRbhSVAhScTS0p9-ZzGnX91cHT3uO",
+                "2b85f58b927af9964a4c15b8f7e8f179",
+            ),
+        },
+    }
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ):
+        try:
+            import h5py
+
+            self.h5py = h5py
+        except ImportError:
+            raise RuntimeError(
+                "h5py is not found. This dataset needs to have h5py installed: please run pip install h5py"
+            )
+
+        self._split = verify_str_arg(split, "split", ("train", "test", "val"))
+
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self._base_folder = pathlib.Path(self.root) / "pcam"
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+    def __len__(self) -> int:
+        images_file = self._FILES[self._split]["images"][0]
+        with self.h5py.File(self._base_folder / images_file) as images_data:
+            return images_data["x"].shape[0]
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        images_file = self._FILES[self._split]["images"][0]
+        with self.h5py.File(self._base_folder / images_file) as images_data:
+            image = Image.fromarray(images_data["x"][idx]).convert("RGB")
+
+        targets_file = self._FILES[self._split]["targets"][0]
+        with self.h5py.File(self._base_folder / targets_file) as targets_data:
+            target = int(targets_data["y"][idx, 0, 0, 0])  # shape is [num_images, 1, 1, 1]
+
+        if self.transform:
+            image = self.transform(image)
+        if self.target_transform:
+            target = self.target_transform(target)
+
+        return image, target
+
+    def _check_exists(self) -> bool:
+        images_file = self._FILES[self._split]["images"][0]
+        targets_file = self._FILES[self._split]["targets"][0]
+        return all(self._base_folder.joinpath(h5_file).exists() for h5_file in (images_file, targets_file))
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+
+        for file_name, file_id, md5 in self._FILES[self._split].values():
+            archive_name = file_name + ".gz"
+            download_file_from_google_drive(file_id, str(self._base_folder), filename=archive_name, md5=md5)
+            _decompress(str(self._base_folder / archive_name))
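
A short sketch of the class above; note that h5py is required and gdown is needed for the Google Drive download (import path and "data" root are assumptions):

# Hedged sketch: PCAM validation split.
from libs.vision_libs.datasets import PCAM

pcam = PCAM(root="data", split="val", download=True)
image, target = pcam[0]                   # 96x96 RGB PIL image, 0/1 metastasis label
print(len(pcam), image.size, target)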

+ 228 - 0
libs/vision_libs/datasets/phototour.py

@@ -0,0 +1,228 @@
+import os
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from PIL import Image
+
+from .utils import download_url
+from .vision import VisionDataset
+
+
+class PhotoTour(VisionDataset):
+    """`Multi-view Stereo Correspondence <http://matthewalunbrown.com/patchdata/patchdata.html>`_ Dataset.
+
+    .. note::
+
+        We only provide the newer version of the dataset, since the authors state that it
+
+            is more suitable for training descriptors based on difference of Gaussian, or Harris corners, as the
+            patches are centred on real interest point detections, rather than being projections of 3D points as is the
+            case in the old dataset.
+
+        The original dataset is available under http://phototour.cs.washington.edu/patches/default.htm.
+
+
+    Args:
+        root (string): Root directory where images are.
+        name (string): Name of the dataset to load.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+    """
+
+    urls = {
+        "notredame_harris": [
+            "http://matthewalunbrown.com/patchdata/notredame_harris.zip",
+            "notredame_harris.zip",
+            "69f8c90f78e171349abdf0307afefe4d",
+        ],
+        "yosemite_harris": [
+            "http://matthewalunbrown.com/patchdata/yosemite_harris.zip",
+            "yosemite_harris.zip",
+            "a73253d1c6fbd3ba2613c45065c00d46",
+        ],
+        "liberty_harris": [
+            "http://matthewalunbrown.com/patchdata/liberty_harris.zip",
+            "liberty_harris.zip",
+            "c731fcfb3abb4091110d0ae8c7ba182c",
+        ],
+        "notredame": [
+            "http://icvl.ee.ic.ac.uk/vbalnt/notredame.zip",
+            "notredame.zip",
+            "509eda8535847b8c0a90bbb210c83484",
+        ],
+        "yosemite": ["http://icvl.ee.ic.ac.uk/vbalnt/yosemite.zip", "yosemite.zip", "533b2e8eb7ede31be40abc317b2fd4f0"],
+        "liberty": ["http://icvl.ee.ic.ac.uk/vbalnt/liberty.zip", "liberty.zip", "fdd9152f138ea5ef2091746689176414"],
+    }
+    means = {
+        "notredame": 0.4854,
+        "yosemite": 0.4844,
+        "liberty": 0.4437,
+        "notredame_harris": 0.4854,
+        "yosemite_harris": 0.4844,
+        "liberty_harris": 0.4437,
+    }
+    stds = {
+        "notredame": 0.1864,
+        "yosemite": 0.1818,
+        "liberty": 0.2019,
+        "notredame_harris": 0.1864,
+        "yosemite_harris": 0.1818,
+        "liberty_harris": 0.2019,
+    }
+    lens = {
+        "notredame": 468159,
+        "yosemite": 633587,
+        "liberty": 450092,
+        "liberty_harris": 379587,
+        "yosemite_harris": 450912,
+        "notredame_harris": 325295,
+    }
+    image_ext = "bmp"
+    info_file = "info.txt"
+    matches_files = "m50_100000_100000_0.txt"
+
+    def __init__(
+        self, root: str, name: str, train: bool = True, transform: Optional[Callable] = None, download: bool = False
+    ) -> None:
+        super().__init__(root, transform=transform)
+        self.name = name
+        self.data_dir = os.path.join(self.root, name)
+        self.data_down = os.path.join(self.root, f"{name}.zip")
+        self.data_file = os.path.join(self.root, f"{name}.pt")
+
+        self.train = train
+        self.mean = self.means[name]
+        self.std = self.stds[name]
+
+        if download:
+            self.download()
+
+        if not self._check_datafile_exists():
+            self.cache()
+
+        # load the serialized data
+        self.data, self.labels, self.matches = torch.load(self.data_file)
+
+    def __getitem__(self, index: int) -> Union[torch.Tensor, Tuple[Any, Any, torch.Tensor]]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (data1, data2, matches) when ``train`` is False; in train mode a single transformed patch is returned.
+        """
+        if self.train:
+            data = self.data[index]
+            if self.transform is not None:
+                data = self.transform(data)
+            return data
+        m = self.matches[index]
+        data1, data2 = self.data[m[0]], self.data[m[1]]
+        if self.transform is not None:
+            data1 = self.transform(data1)
+            data2 = self.transform(data2)
+        return data1, data2, m[2]
+
+    def __len__(self) -> int:
+        return len(self.data if self.train else self.matches)
+
+    def _check_datafile_exists(self) -> bool:
+        return os.path.exists(self.data_file)
+
+    def _check_downloaded(self) -> bool:
+        return os.path.exists(self.data_dir)
+
+    def download(self) -> None:
+        if self._check_datafile_exists():
+            print(f"# Found cached data {self.data_file}")
+            return
+
+        if not self._check_downloaded():
+            # download files
+            url = self.urls[self.name][0]
+            filename = self.urls[self.name][1]
+            md5 = self.urls[self.name][2]
+            fpath = os.path.join(self.root, filename)
+
+            download_url(url, self.root, filename, md5)
+
+            print(f"# Extracting data {self.data_down}\n")
+
+            import zipfile
+
+            with zipfile.ZipFile(fpath, "r") as z:
+                z.extractall(self.data_dir)
+
+            os.unlink(fpath)
+
+    def cache(self) -> None:
+        # process and save as torch files
+        print(f"# Caching data {self.data_file}")
+
+        dataset = (
+            read_image_file(self.data_dir, self.image_ext, self.lens[self.name]),
+            read_info_file(self.data_dir, self.info_file),
+            read_matches_files(self.data_dir, self.matches_files),
+        )
+
+        with open(self.data_file, "wb") as f:
+            torch.save(dataset, f)
+
+    def extra_repr(self) -> str:
+        split = "Train" if self.train is True else "Test"
+        return f"Split: {split}"
+
+
+def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor:
+    """Return a Tensor containing the patches"""
+
+    def PIL2array(_img: Image.Image) -> np.ndarray:
+        """Convert PIL image type to numpy 2D array"""
+        return np.array(_img.getdata(), dtype=np.uint8).reshape(64, 64)
+
+    def find_files(_data_dir: str, _image_ext: str) -> List[str]:
+        """Return a list with the file names of the images containing the patches"""
+        files = []
+        # find those files with the specified extension
+        for file_dir in os.listdir(_data_dir):
+            if file_dir.endswith(_image_ext):
+                files.append(os.path.join(_data_dir, file_dir))
+        return sorted(files)  # sort files in ascending order to keep relations
+
+    patches = []
+    list_files = find_files(data_dir, image_ext)
+
+    for fpath in list_files:
+        img = Image.open(fpath)
+        for y in range(0, img.height, 64):
+            for x in range(0, img.width, 64):
+                patch = img.crop((x, y, x + 64, y + 64))
+                patches.append(PIL2array(patch))
+    return torch.ByteTensor(np.array(patches[:n]))
+
+
+def read_info_file(data_dir: str, info_file: str) -> torch.Tensor:
+    """Return a Tensor containing the list of labels
+    Read the file and keep only the ID of the 3D point.
+    """
+    with open(os.path.join(data_dir, info_file)) as f:
+        labels = [int(line.split()[0]) for line in f]
+    return torch.LongTensor(labels)
+
+
+def read_matches_files(data_dir: str, matches_file: str) -> torch.Tensor:
+    """Return a Tensor containing the ground truth matches
+    Read the file and keep only 3D point ID.
+    Matches are represented with a 1, non matches with a 0.
+    """
+    matches = []
+    with open(os.path.join(data_dir, matches_file)) as f:
+        for line in f:
+            line_split = line.split()
+            matches.append([int(line_split[0]), int(line_split[3]), int(line_split[1] == line_split[4])])
+    return torch.LongTensor(matches)
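
As the __getitem__ above shows, the return type depends on ``train``; a sketch under the usual import-path assumption (the "liberty" archive is several hundred MB):

# Hedged sketch: single patches in train mode, (patch, patch, match) triplets otherwise.
from libs.vision_libs.datasets import PhotoTour

train = PhotoTour(root="data", name="liberty", train=True, download=True)
patch = train[0]                          # 64x64 uint8 patch tensor

test = PhotoTour(root="data", name="liberty", train=False)
patch1, patch2, is_match = test[0]        # two patches and a 0/1 match label
print(patch.shape, int(is_match))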

+ 170 - 0
libs/vision_libs/datasets/places365.py

@@ -0,0 +1,170 @@
+import os
+from os import path
+from typing import Any, Callable, Dict, List, Optional, Tuple
+from urllib.parse import urljoin
+
+from .folder import default_loader
+from .utils import check_integrity, download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class Places365(VisionDataset):
+    r"""`Places365 <http://places2.csail.mit.edu/index.html>`_ classification dataset.
+
+    Args:
+        root (string): Root directory of the Places365 dataset.
+        split (string, optional): The dataset split. Can be one of ``train-standard`` (default), ``train-challenge``,
+            ``val``.
+        small (bool, optional): If ``True``, uses the small images, i.e. resized to 256 x 256 pixels, instead of the
+            high resolution ones.
+        download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already
+            downloaded archives are not downloaded again.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        loader (callable, optional): A function to load an image given its path.
+
+     Attributes:
+        classes (list): List of the class names.
+        class_to_idx (dict): Dict with items (class_name, class_index).
+        imgs (list): List of (image path, class_index) tuples
+        targets (list): The class_index value for each image in the dataset
+
+    Raises:
+        RuntimeError: If ``download is False`` and the meta files, i.e. the devkit, are not present or corrupted.
+        RuntimeError: If ``download is True`` and the image archive is already extracted.
+    """
+    _SPLITS = ("train-standard", "train-challenge", "val")
+    _BASE_URL = "http://data.csail.mit.edu/places/places365/"
+    # {variant: (archive, md5)}
+    _DEVKIT_META = {
+        "standard": ("filelist_places365-standard.tar", "35a0585fee1fa656440f3ab298f8479c"),
+        "challenge": ("filelist_places365-challenge.tar", "70a8307e459c3de41690a7c76c931734"),
+    }
+    # (file, md5)
+    _CATEGORIES_META = ("categories_places365.txt", "06c963b85866bd0649f97cb43dd16673")
+    # {split: (file, md5)}
+    _FILE_LIST_META = {
+        "train-standard": ("places365_train_standard.txt", "30f37515461640559006b8329efbed1a"),
+        "train-challenge": ("places365_train_challenge.txt", "b2931dc997b8c33c27e7329c073a6b57"),
+        "val": ("places365_val.txt", "e9f2fd57bfd9d07630173f4e8708e4b1"),
+    }
+    # {(split, small): (file, md5)}
+    _IMAGES_META = {
+        ("train-standard", False): ("train_large_places365standard.tar", "67e186b496a84c929568076ed01a8aa1"),
+        ("train-challenge", False): ("train_large_places365challenge.tar", "605f18e68e510c82b958664ea134545f"),
+        ("val", False): ("val_large.tar", "9b71c4993ad89d2d8bcbdc4aef38042f"),
+        ("train-standard", True): ("train_256_places365standard.tar", "53ca1c756c3d1e7809517cc47c5561c5"),
+        ("train-challenge", True): ("train_256_places365challenge.tar", "741915038a5e3471ec7332404dfb64ef"),
+        ("val", True): ("val_256.tar", "e27b17d8d44f4af9a78502beb927f808"),
+    }
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train-standard",
+        small: bool = False,
+        download: bool = False,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        loader: Callable[[str], Any] = default_loader,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        self.split = self._verify_split(split)
+        self.small = small
+        self.loader = loader
+
+        self.classes, self.class_to_idx = self.load_categories(download)
+        self.imgs, self.targets = self.load_file_list(download)
+
+        if download:
+            self.download_images()
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        file, target = self.imgs[index]
+        image = self.loader(file)
+
+        if self.transforms is not None:
+            image, target = self.transforms(image, target)
+
+        return image, target
+
+    def __len__(self) -> int:
+        return len(self.imgs)
+
+    @property
+    def variant(self) -> str:
+        return "challenge" if "challenge" in self.split else "standard"
+
+    @property
+    def images_dir(self) -> str:
+        size = "256" if self.small else "large"
+        if self.split.startswith("train"):
+            dir = f"data_{size}_{self.variant}"
+        else:
+            dir = f"{self.split}_{size}"
+        return path.join(self.root, dir)
+
+    def load_categories(self, download: bool = True) -> Tuple[List[str], Dict[str, int]]:
+        def process(line: str) -> Tuple[str, int]:
+            cls, idx = line.split()
+            return cls, int(idx)
+
+        file, md5 = self._CATEGORIES_META
+        file = path.join(self.root, file)
+        if not self._check_integrity(file, md5, download):
+            self.download_devkit()
+
+        with open(file) as fh:
+            class_to_idx = dict(process(line) for line in fh)
+
+        return sorted(class_to_idx.keys()), class_to_idx
+
+    def load_file_list(self, download: bool = True) -> Tuple[List[Tuple[str, int]], List[int]]:
+        def process(line: str, sep="/") -> Tuple[str, int]:
+            image, idx = line.split()
+            return path.join(self.images_dir, image.lstrip(sep).replace(sep, os.sep)), int(idx)
+
+        file, md5 = self._FILE_LIST_META[self.split]
+        file = path.join(self.root, file)
+        if not self._check_integrity(file, md5, download):
+            self.download_devkit()
+
+        with open(file) as fh:
+            images = [process(line) for line in fh]
+
+        _, targets = zip(*images)
+        return images, list(targets)
+
+    def download_devkit(self) -> None:
+        file, md5 = self._DEVKIT_META[self.variant]
+        download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, md5=md5)
+
+    def download_images(self) -> None:
+        if path.exists(self.images_dir):
+            raise RuntimeError(
+                f"The directory {self.images_dir} already exists. If you want to re-download or re-extract the images, "
+                f"delete the directory."
+            )
+
+        file, md5 = self._IMAGES_META[(self.split, self.small)]
+        download_and_extract_archive(urljoin(self._BASE_URL, file), self.root, md5=md5)
+
+        if self.split.startswith("train"):
+            os.rename(self.images_dir.rsplit("_", 1)[0], self.images_dir)
+
+    def extra_repr(self) -> str:
+        return "\n".join(("Split: {split}", "Small: {small}")).format(**self.__dict__)
+
+    def _verify_split(self, split: str) -> str:
+        return verify_str_arg(split, "split", self._SPLITS)
+
+    def _check_integrity(self, file: str, md5: str, download: bool) -> bool:
+        integrity = check_integrity(file, md5=md5)
+        if not integrity and not download:
+            raise RuntimeError(
+                f"The file {file} does not exist or is corrupted. You can set download=True to download it."
+            )
+        return integrity
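
A sketch of the class above using the small (256x256) validation images, which are the lightest download (import path and "data" root are assumptions):

# Hedged sketch: Places365 validation split with resized images.
from libs.vision_libs.datasets import Places365

places = Places365(root="data", split="val", small=True, download=True)
image, target = places[0]                 # PIL image, index into places.classes
print(len(places.classes), places.classes[target])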

+ 86 - 0
libs/vision_libs/datasets/rendered_sst2.py

@@ -0,0 +1,86 @@
+from pathlib import Path
+from typing import Any, Callable, Optional, Tuple
+
+import PIL.Image
+
+from .folder import make_dataset
+from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class RenderedSST2(VisionDataset):
+    """`The Rendered SST2 Dataset <https://github.com/openai/CLIP/blob/main/data/rendered-sst2.md>`_.
+
+    Rendered SST2 is an image classification dataset used to evaluate a model's capability for optical
+    character recognition. This dataset was generated by rendering sentences from the Stanford Sentiment
+    Treebank v2 dataset.
+
+    This dataset contains two classes (positive and negative) and is divided into three splits: a train
+    split containing 6920 images (3610 positive and 3310 negative), a validation split containing 872 images
+    (444 positive and 428 negative), and a test split containing 1821 images (909 positive and 912 negative).
+
+    Args:
+        root (string): Root directory of the dataset.
+        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"`` and ``"test"``.
+        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+            version. E.g, ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again. Default is False.
+    """
+
+    _URL = "https://openaipublic.azureedge.net/clip/data/rendered-sst2.tgz"
+    _MD5 = "2384d08e9dcfa4bd55b324e610496ee5"
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
+        self._split_to_folder = {"train": "train", "val": "valid", "test": "test"}
+        self._base_folder = Path(self.root) / "rendered-sst2"
+        self.classes = ["negative", "positive"]
+        self.class_to_idx = {"negative": 0, "positive": 1}
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        self._samples = make_dataset(str(self._base_folder / self._split_to_folder[self._split]), extensions=("png",))
+
+    def __len__(self) -> int:
+        return len(self._samples)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        image_file, label = self._samples[idx]
+        image = PIL.Image.open(image_file).convert("RGB")
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.target_transform:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def extra_repr(self) -> str:
+        return f"split={self._split}"
+
+    def _check_exists(self) -> bool:
+        for class_label in set(self.classes):
+            if not (self._base_folder / self._split_to_folder[self._split] / class_label).is_dir():
+                return False
+        return True
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+        download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)
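
A sketch of the class above (import path assumed as elsewhere in these notes):

# Hedged sketch: binary sentiment labels over rendered sentences.
from libs.vision_libs.datasets import RenderedSST2

sst2 = RenderedSST2(root="data", split="val", download=True)
image, label = sst2[0]                    # RGB PIL image, 0 = negative / 1 = positive
print(len(sst2), sst2.classes[label])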

+ 3 - 0
libs/vision_libs/datasets/samplers/__init__.py

@@ -0,0 +1,3 @@
+from .clip_sampler import DistributedSampler, RandomClipSampler, UniformClipSampler
+
+__all__ = ("DistributedSampler", "UniformClipSampler", "RandomClipSampler")

+ 172 - 0
libs/vision_libs/datasets/samplers/clip_sampler.py

@@ -0,0 +1,172 @@
+import math
+from typing import cast, Iterator, List, Optional, Sized, Union
+
+import torch
+import torch.distributed as dist
+from torch.utils.data import Sampler
+from torchvision.datasets.video_utils import VideoClips
+
+
+class DistributedSampler(Sampler):
+    """
+    Extension of DistributedSampler, as discussed in
+    https://github.com/pytorch/pytorch/issues/23430
+
+    Example:
+        dataset: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
+        num_replicas: 4
+        shuffle: False
+
+    when group_size = 1
+            RANK    |  shard_dataset
+            =========================
+            rank_0  |  [0, 4, 8, 12]
+            rank_1  |  [1, 5, 9, 13]
+            rank_2  |  [2, 6, 10, 0]
+            rank_3  |  [3, 7, 11, 1]
+
+    when group_size = 2
+
+            RANK    |  shard_dataset
+            =========================
+            rank_0  |  [0, 1, 8, 9]
+            rank_1  |  [2, 3, 10, 11]
+            rank_2  |  [4, 5, 12, 13]
+            rank_3  |  [6, 7, 0, 1]
+
+    """
+
+    def __init__(
+        self,
+        dataset: Sized,
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        shuffle: bool = False,
+        group_size: int = 1,
+    ) -> None:
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank()
+        if len(dataset) % group_size != 0:
+            raise ValueError(
+                f"dataset length must be a multiplier of group size dataset length: {len(dataset)}, group size: {group_size}"
+            )
+        self.dataset = dataset
+        self.group_size = group_size
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        dataset_group_length = len(dataset) // group_size
+        self.num_group_samples = int(math.ceil(dataset_group_length * 1.0 / self.num_replicas))
+        self.num_samples = self.num_group_samples * group_size
+        self.total_size = self.num_samples * self.num_replicas
+        self.shuffle = shuffle
+
+    def __iter__(self) -> Iterator[int]:
+        # deterministically shuffle based on epoch
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+        indices: Union[torch.Tensor, List[int]]
+        if self.shuffle:
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = list(range(len(self.dataset)))
+
+        # add extra samples to make it evenly divisible
+        indices += indices[: (self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+
+        total_group_size = self.total_size // self.group_size
+        indices = torch.reshape(torch.LongTensor(indices), (total_group_size, self.group_size))
+
+        # subsample
+        indices = indices[self.rank : total_group_size : self.num_replicas, :]
+        indices = torch.reshape(indices, (-1,)).tolist()
+        assert len(indices) == self.num_samples
+
+        if isinstance(self.dataset, Sampler):
+            orig_indices = list(iter(self.dataset))
+            indices = [orig_indices[i] for i in indices]
+
+        return iter(indices)
+
+    def __len__(self) -> int:
+        return self.num_samples
+
+    def set_epoch(self, epoch: int) -> None:
+        self.epoch = epoch
+
+
+class UniformClipSampler(Sampler):
+    """
+    Sample `num_clips_per_video` clips for each video, equally spaced.
+    When the number of unique clips in a video is fewer than `num_clips_per_video`,
+    the clips are repeated until `num_clips_per_video` clips are collected.
+
+    Args:
+        video_clips (VideoClips): video clips to sample from
+        num_clips_per_video (int): number of clips to be sampled per video
+    """
+
+    def __init__(self, video_clips: VideoClips, num_clips_per_video: int) -> None:
+        if not isinstance(video_clips, VideoClips):
+            raise TypeError(f"Expected video_clips to be an instance of VideoClips, got {type(video_clips)}")
+        self.video_clips = video_clips
+        self.num_clips_per_video = num_clips_per_video
+
+    def __iter__(self) -> Iterator[int]:
+        idxs = []
+        s = 0
+        # select num_clips_per_video for each video, uniformly spaced
+        for c in self.video_clips.clips:
+            length = len(c)
+            if length == 0:
+                # corner case where video decoding fails
+                continue
+
+            sampled = torch.linspace(s, s + length - 1, steps=self.num_clips_per_video).floor().to(torch.int64)
+            s += length
+            idxs.append(sampled)
+        return iter(cast(List[int], torch.cat(idxs).tolist()))
+
+    def __len__(self) -> int:
+        return sum(self.num_clips_per_video for c in self.video_clips.clips if len(c) > 0)
+
+
+class RandomClipSampler(Sampler):
+    """
+    Samples at most `max_video_clips_per_video` clips for each video randomly
+
+    Args:
+        video_clips (VideoClips): video clips to sample from
+        max_clips_per_video (int): maximum number of clips to be sampled per video
+    """
+
+    def __init__(self, video_clips: VideoClips, max_clips_per_video: int) -> None:
+        if not isinstance(video_clips, VideoClips):
+            raise TypeError(f"Expected video_clips to be an instance of VideoClips, got {type(video_clips)}")
+        self.video_clips = video_clips
+        self.max_clips_per_video = max_clips_per_video
+
+    def __iter__(self) -> Iterator[int]:
+        idxs = []
+        s = 0
+        # select at most max_clips_per_video for each video, randomly
+        for c in self.video_clips.clips:
+            length = len(c)
+            size = min(length, self.max_clips_per_video)
+            sampled = torch.randperm(length)[:size] + s
+            s += length
+            idxs.append(sampled)
+        idxs_ = torch.cat(idxs)
+        # shuffle all clips randomly
+        perm = torch.randperm(len(idxs_))
+        return iter(idxs_[perm].tolist())
+
+    def __len__(self) -> int:
+        return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)
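
The sharding tables in the DistributedSampler docstring above can be reproduced without initializing torch.distributed by passing ``num_replicas`` and ``rank`` explicitly; a sketch (import path assumed):

# Hedged sketch: reproduce the group_size=2 sharding from the docstring.
from libs.vision_libs.datasets.samplers import DistributedSampler

dataset = list(range(14))                 # any Sized object works here
for rank in range(4):
    sampler = DistributedSampler(dataset, num_replicas=4, rank=rank,
                                 shuffle=False, group_size=2)
    print(f"rank_{rank}", list(sampler))  # rank_0 -> [0, 1, 8, 9], rank_3 -> [6, 7, 0, 1]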

+ 123 - 0
libs/vision_libs/datasets/sbd.py

@@ -0,0 +1,123 @@
+import os
+import shutil
+from typing import Any, Callable, Optional, Tuple
+
+import numpy as np
+from PIL import Image
+
+from .utils import download_and_extract_archive, download_url, verify_str_arg
+from .vision import VisionDataset
+
+
+class SBDataset(VisionDataset):
+    """`Semantic Boundaries Dataset <http://home.bharathh.info/pubs/codes/SBD/download.html>`_
+
+    The SBD currently contains annotations from 11355 images taken from the PASCAL VOC 2011 dataset.
+
+    .. note ::
+
+        Please note that the train and val splits included with this dataset are different from
+        the splits in the PASCAL VOC dataset. In particular some "train" images might be part of
+        VOC2012 val.
+        If you are interested in testing on VOC 2012 val, then use `image_set='train_noval'`,
+        which excludes all val images.
+
+    .. warning::
+
+        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
+
+    Args:
+        root (string): Root directory of the Semantic Boundaries Dataset
+        image_set (string, optional): Select the image_set to use, ``train``, ``val`` or ``train_noval``.
+            Image set ``train_noval`` excludes VOC 2012 val images.
+        mode (string, optional): Select target type. Possible values 'boundaries' or 'segmentation'.
+            In case of 'boundaries', the target is an array of shape `[num_classes, H, W]`,
+            where `num_classes=20`.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        transforms (callable, optional): A function/transform that takes input sample and its target as entry
+            and returns a transformed version. Input sample is PIL image and target is a numpy array
+            if `mode='boundaries'` or PIL image if `mode='segmentation'`.
+    """
+
+    url = "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz"
+    md5 = "82b4d87ceb2ed10f6038a1cba92111cb"
+    filename = "benchmark.tgz"
+
+    voc_train_url = "http://home.bharathh.info/pubs/codes/SBD/train_noval.txt"
+    voc_split_filename = "train_noval.txt"
+    voc_split_md5 = "79bff800c5f0b1ec6b21080a3c066722"
+
+    def __init__(
+        self,
+        root: str,
+        image_set: str = "train",
+        mode: str = "boundaries",
+        download: bool = False,
+        transforms: Optional[Callable] = None,
+    ) -> None:
+
+        try:
+            from scipy.io import loadmat
+
+            self._loadmat = loadmat
+        except ImportError:
+            raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")
+
+        super().__init__(root, transforms)
+        self.image_set = verify_str_arg(image_set, "image_set", ("train", "val", "train_noval"))
+        self.mode = verify_str_arg(mode, "mode", ("segmentation", "boundaries"))
+        self.num_classes = 20
+
+        sbd_root = self.root
+        image_dir = os.path.join(sbd_root, "img")
+        mask_dir = os.path.join(sbd_root, "cls")
+
+        if download:
+            download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)
+            extracted_ds_root = os.path.join(self.root, "benchmark_RELEASE", "dataset")
+            for f in ["cls", "img", "inst", "train.txt", "val.txt"]:
+                old_path = os.path.join(extracted_ds_root, f)
+                shutil.move(old_path, sbd_root)
+            download_url(self.voc_train_url, sbd_root, self.voc_split_filename, self.voc_split_md5)
+
+        if not os.path.isdir(sbd_root):
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        split_f = os.path.join(sbd_root, image_set.rstrip("\n") + ".txt")
+
+        with open(split_f) as fh:
+            file_names = [x.strip() for x in fh.readlines()]
+
+        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
+        self.masks = [os.path.join(mask_dir, x + ".mat") for x in file_names]
+
+        self._get_target = self._get_segmentation_target if self.mode == "segmentation" else self._get_boundaries_target
+
+    def _get_segmentation_target(self, filepath: str) -> Image.Image:
+        mat = self._loadmat(filepath)
+        return Image.fromarray(mat["GTcls"][0]["Segmentation"][0])
+
+    def _get_boundaries_target(self, filepath: str) -> np.ndarray:
+        mat = self._loadmat(filepath)
+        return np.concatenate(
+            [np.expand_dims(mat["GTcls"][0]["Boundaries"][0][i][0].toarray(), axis=0) for i in range(self.num_classes)],
+            axis=0,
+        )
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        img = Image.open(self.images[index]).convert("RGB")
+        target = self._get_target(self.masks[index])
+
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.images)
+
+    def extra_repr(self) -> str:
+        lines = ["Image set: {image_set}", "Mode: {mode}"]
+        return "\n".join(lines).format(**self.__dict__)

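A minimal usage sketch for the dataset defined above. The class definition sits above this hunk; `SBDataset` is the name used for it in upstream torchvision and is assumed here, as are the import path and the `./data/sbd` root:

    # Hypothetical usage; class name and root directory are assumptions.
    from libs.vision_libs.datasets.sbd import SBDataset

    ds = SBDataset("./data/sbd", image_set="train", mode="boundaries", download=True)
    img, target = ds[0]            # PIL image, numpy array of shape (20, H, W) in 'boundaries' mode
    print(len(ds), target.shape)
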
+ 109 - 0
libs/vision_libs/datasets/sbu.py

@@ -0,0 +1,109 @@
+import os
+from typing import Any, Callable, Optional, Tuple
+
+from PIL import Image
+
+from .utils import check_integrity, download_and_extract_archive, download_url
+from .vision import VisionDataset
+
+
+class SBU(VisionDataset):
+    """`SBU Captioned Photo <http://www.cs.virginia.edu/~vicente/sbucaptions/>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where tarball
+            ``SBUCaptionedPhotoDataset.tar.gz`` exists.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    url = "https://www.cs.rice.edu/~vo9/sbucaptions/SBUCaptionedPhotoDataset.tar.gz"
+    filename = "SBUCaptionedPhotoDataset.tar.gz"
+    md5_checksum = "9aec147b3488753cf758b4d493422285"
+
+    def __init__(
+        self,
+        root: str,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = True,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        # Read the caption for each photo
+        self.photos = []
+        self.captions = []
+
+        file1 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")
+        file2 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_captions.txt")
+
+        with open(file1) as fh1, open(file2) as fh2:
+            for line1, line2 in zip(fh1, fh2):
+                url = line1.rstrip()
+                photo = os.path.basename(url)
+                filename = os.path.join(self.root, "dataset", photo)
+                if os.path.exists(filename):
+                    caption = line2.rstrip()
+                    self.photos.append(photo)
+                    self.captions.append(caption)
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is a caption for the photo.
+        """
+        filename = os.path.join(self.root, "dataset", self.photos[index])
+        img = Image.open(filename).convert("RGB")
+        if self.transform is not None:
+            img = self.transform(img)
+
+        target = self.captions[index]
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        """The number of photos in the dataset."""
+        return len(self.photos)
+
+    def _check_integrity(self) -> bool:
+        """Check the md5 checksum of the downloaded tarball."""
+        root = self.root
+        fpath = os.path.join(root, self.filename)
+        if not check_integrity(fpath, self.md5_checksum):
+            return False
+        return True
+
+    def download(self) -> None:
+        """Download and extract the tarball, and download each individual photo."""
+
+        if self._check_integrity():
+            print("Files already downloaded and verified")
+            return
+
+        download_and_extract_archive(self.url, self.root, self.root, self.filename, self.md5_checksum)
+
+        # Download individual photos
+        with open(os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")) as fh:
+            for line in fh:
+                url = line.rstrip()
+                try:
+                    download_url(url, os.path.join(self.root, "dataset"))
+                except OSError:
+                    # The images point to public images on Flickr.
+                    # Note: Images might be removed by users at anytime.
+                    pass

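A minimal usage sketch for the `SBU` class above, assuming the repository root is on the import path and `./data/sbu` is a placeholder root. As the code shows, `download=True` fetches the tarball and then each Flickr photo individually, silently skipping photos that are no longer available:

    # Usage sketch; the root directory is a placeholder.
    from libs.vision_libs.datasets.sbu import SBU

    ds = SBU(root="./data/sbu", download=True)   # tarball + per-photo downloads
    img, caption = ds[0]                         # PIL image and its caption string
    print(len(ds), caption)
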
+ 91 - 0
libs/vision_libs/datasets/semeion.py

@@ -0,0 +1,91 @@
+import os.path
+from typing import Any, Callable, Optional, Tuple
+
+import numpy as np
+from PIL import Image
+
+from .utils import check_integrity, download_url
+from .vision import VisionDataset
+
+
+class SEMEION(VisionDataset):
+    r"""`SEMEION <http://archive.ics.uci.edu/ml/datasets/semeion+handwritten+digit>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where the data file
+            ``semeion.data`` is stored.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+    """
+    url = "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data"
+    filename = "semeion.data"
+    md5_checksum = "cb545d371d2ce14ec121470795a77432"
+
+    def __init__(
+        self,
+        root: str,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = True,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        fp = os.path.join(self.root, self.filename)
+        data = np.loadtxt(fp)
+        # convert value to 8 bit unsigned integer
+        # color (white #255) the pixels
+        self.data = (data[:, :256] * 255).astype("uint8")
+        self.data = np.reshape(self.data, (-1, 16, 16))
+        self.labels = np.nonzero(data[:, 256:])[1]
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target class.
+        """
+        img, target = self.data[index], int(self.labels[index])
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(img, mode="L")
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def _check_integrity(self) -> bool:
+        root = self.root
+        fpath = os.path.join(root, self.filename)
+        if not check_integrity(fpath, self.md5_checksum):
+            return False
+        return True
+
+    def download(self) -> None:
+        if self._check_integrity():
+            print("Files already downloaded and verified")
+            return
+
+        root = self.root
+        download_url(self.url, root, self.filename, self.md5_checksum)

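A minimal usage sketch for `SEMEION`, with a placeholder root:

    from libs.vision_libs.datasets.semeion import SEMEION

    ds = SEMEION(root="./data/semeion", download=True)
    img, label = ds[0]             # 16x16 grayscale PIL image, digit label in [0, 9]
    print(len(ds), img.size, label)
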
+ 121 - 0
libs/vision_libs/datasets/stanford_cars.py

@@ -0,0 +1,121 @@
+import pathlib
+from typing import Any, Callable, Optional, Tuple
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, download_url, verify_str_arg
+from .vision import VisionDataset
+
+
+class StanfordCars(VisionDataset):
+    """`Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ Dataset
+
+    The Cars dataset contains 16,185 images of 196 classes of cars. The data is
+    split into 8,144 training images and 8,041 testing images, with each class
+    divided roughly 50-50 between the two splits.
+
+    .. note::
+
+        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
+
+    Args:
+        root (string): Root directory of dataset
+        split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If True, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again."""
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+
+        try:
+            import scipy.io as sio
+        except ImportError:
+            raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")
+
+        super().__init__(root, transform=transform, target_transform=target_transform)
+
+        self._split = verify_str_arg(split, "split", ("train", "test"))
+        self._base_folder = pathlib.Path(root) / "stanford_cars"
+        devkit = self._base_folder / "devkit"
+
+        if self._split == "train":
+            self._annotations_mat_path = devkit / "cars_train_annos.mat"
+            self._images_base_path = self._base_folder / "cars_train"
+        else:
+            self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
+            self._images_base_path = self._base_folder / "cars_test"
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        self._samples = [
+            (
+                str(self._images_base_path / annotation["fname"]),
+                annotation["class"] - 1,  # Original target mapping  starts from 1, hence -1
+            )
+            for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]
+        ]
+
+        self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
+        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
+
+    def __len__(self) -> int:
+        return len(self._samples)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        """Returns pil_image and class_id for given index"""
+        image_path, target = self._samples[idx]
+        pil_image = Image.open(image_path).convert("RGB")
+
+        if self.transform is not None:
+            pil_image = self.transform(pil_image)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+        return pil_image, target
+
+    def download(self) -> None:
+        if self._check_exists():
+            return
+
+        download_and_extract_archive(
+            url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz",
+            download_root=str(self._base_folder),
+            md5="c3b158d763b6e2245038c8ad08e45376",
+        )
+        if self._split == "train":
+            download_and_extract_archive(
+                url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz",
+                download_root=str(self._base_folder),
+                md5="065e5b463ae28d29e77c1b4b166cfe61",
+            )
+        else:
+            download_and_extract_archive(
+                url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz",
+                download_root=str(self._base_folder),
+                md5="4ce7ebf6a94d07f1952d94dd34c4d501",
+            )
+            download_url(
+                url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat",
+                root=str(self._base_folder),
+                md5="b0a2b23655a3edd16d84508592a98d10",
+            )
+
+    def _check_exists(self) -> bool:
+        if not (self._base_folder / "devkit").is_dir():
+            return False
+
+        return self._annotations_mat_path.exists() and self._images_base_path.is_dir()

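A minimal usage sketch for `StanfordCars`, assuming scipy is installed and `./data` is a placeholder root (the devkit and images end up under `./data/stanford_cars`):

    from libs.vision_libs.datasets.stanford_cars import StanfordCars

    train = StanfordCars(root="./data", split="train", download=True)
    img, class_id = train[0]                  # PIL image, 0-based class index
    print(len(train), train.classes[class_id])
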
+ 174 - 0
libs/vision_libs/datasets/stl10.py

@@ -0,0 +1,174 @@
+import os.path
+from typing import Any, Callable, cast, Optional, Tuple
+
+import numpy as np
+from PIL import Image
+
+from .utils import check_integrity, download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class STL10(VisionDataset):
+    """`STL10 <https://cs.stanford.edu/~acoates/stl10/>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where directory
+            ``stl10_binary`` exists.
+        split (string): One of {'train', 'test', 'unlabeled', 'train+unlabeled'}.
+            Selects which portion of the dataset to load.
+        folds (int, optional): One of {0-9} or None.
+            For training, loads one of the 10 pre-defined folds of 1k samples for the
+            standard evaluation procedure. If no value is passed, loads the 5k samples.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    base_folder = "stl10_binary"
+    url = "http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz"
+    filename = "stl10_binary.tar.gz"
+    tgz_md5 = "91f7769df0f17e558f3565bffb0c7dfb"
+    class_names_file = "class_names.txt"
+    folds_list_file = "fold_indices.txt"
+    train_list = [
+        ["train_X.bin", "918c2871b30a85fa023e0c44e0bee87f"],
+        ["train_y.bin", "5a34089d4802c674881badbb80307741"],
+        ["unlabeled_X.bin", "5242ba1fed5e4be9e1e742405eb56ca4"],
+    ]
+
+    test_list = [["test_X.bin", "7f263ba9f9e0b06b93213547f721ac82"], ["test_y.bin", "36f9794fa4beb8a2c72628de14fa638e"]]
+    splits = ("train", "train+unlabeled", "unlabeled", "test")
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        folds: Optional[int] = None,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self.split = verify_str_arg(split, "split", self.splits)
+        self.folds = self._verify_folds(folds)
+
+        if download:
+            self.download()
+        elif not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        # now load the numpy arrays from the binary files
+        self.labels: Optional[np.ndarray]
+        if self.split == "train":
+            self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0])
+            self.labels = cast(np.ndarray, self.labels)
+            self.__load_folds(folds)
+
+        elif self.split == "train+unlabeled":
+            self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0])
+            self.labels = cast(np.ndarray, self.labels)
+            self.__load_folds(folds)
+            unlabeled_data, _ = self.__loadfile(self.train_list[2][0])
+            self.data = np.concatenate((self.data, unlabeled_data))
+            self.labels = np.concatenate((self.labels, np.asarray([-1] * unlabeled_data.shape[0])))
+
+        elif self.split == "unlabeled":
+            self.data, _ = self.__loadfile(self.train_list[2][0])
+            self.labels = np.asarray([-1] * self.data.shape[0])
+        else:  # self.split == 'test':
+            self.data, self.labels = self.__loadfile(self.test_list[0][0], self.test_list[1][0])
+
+        class_file = os.path.join(self.root, self.base_folder, self.class_names_file)
+        if os.path.isfile(class_file):
+            with open(class_file) as f:
+                self.classes = f.read().splitlines()
+
+    def _verify_folds(self, folds: Optional[int]) -> Optional[int]:
+        if folds is None:
+            return folds
+        elif isinstance(folds, int):
+            if folds in range(10):
+                return folds
+            msg = "Value for argument folds should be in the range [0, 10), but got {}."
+            raise ValueError(msg.format(folds))
+        else:
+            msg = "Expected type None or int for argument folds, but got type {}."
+            raise ValueError(msg.format(type(folds)))
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target class.
+        """
+        target: Optional[int]
+        if self.labels is not None:
+            img, target = self.data[index], int(self.labels[index])
+        else:
+            img, target = self.data[index], None
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(np.transpose(img, (1, 2, 0)))
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return self.data.shape[0]
+
+    def __loadfile(self, data_file: str, labels_file: Optional[str] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+        labels = None
+        if labels_file:
+            path_to_labels = os.path.join(self.root, self.base_folder, labels_file)
+            with open(path_to_labels, "rb") as f:
+                labels = np.fromfile(f, dtype=np.uint8) - 1  # 0-based
+
+        path_to_data = os.path.join(self.root, self.base_folder, data_file)
+        with open(path_to_data, "rb") as f:
+            # read whole file in uint8 chunks
+            everything = np.fromfile(f, dtype=np.uint8)
+            images = np.reshape(everything, (-1, 3, 96, 96))
+            images = np.transpose(images, (0, 1, 3, 2))
+
+        return images, labels
+
+    def _check_integrity(self) -> bool:
+        for filename, md5 in self.train_list + self.test_list:
+            fpath = os.path.join(self.root, self.base_folder, filename)
+            if not check_integrity(fpath, md5):
+                return False
+        return True
+
+    def download(self) -> None:
+        if self._check_integrity():
+            print("Files already downloaded and verified")
+            return
+        download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
+        self._check_integrity()
+
+    def extra_repr(self) -> str:
+        return "Split: {split}".format(**self.__dict__)
+
+    def __load_folds(self, folds: Optional[int]) -> None:
+        # loads one of the folds if specified
+        if folds is None:
+            return
+        path_to_folds = os.path.join(self.root, self.base_folder, self.folds_list_file)
+        with open(path_to_folds) as f:
+            str_idx = f.read().splitlines()[folds]
+            list_idx = np.fromstring(str_idx, dtype=np.int64, sep=" ")
+            self.data = self.data[list_idx, :, :, :]
+            if self.labels is not None:
+                self.labels = self.labels[list_idx]

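A minimal usage sketch for `STL10`, with a placeholder root, showing the fold and unlabeled-split behaviour described in the docstring:

    from libs.vision_libs.datasets.stl10 import STL10

    labeled = STL10(root="./data", split="train", folds=0, download=True)  # 1k-sample fold 0
    unlabeled = STL10(root="./data", split="unlabeled")                    # targets are all -1
    img, target = labeled[0]
    print(len(labeled), len(unlabeled), target)
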
+ 76 - 0
libs/vision_libs/datasets/sun397.py

@@ -0,0 +1,76 @@
+from pathlib import Path
+from typing import Any, Callable, Optional, Tuple
+
+import PIL.Image
+
+from .utils import download_and_extract_archive
+from .vision import VisionDataset
+
+
+class SUN397(VisionDataset):
+    """`The SUN397 Data Set <https://vision.princeton.edu/projects/2010/SUN/>`_.
+
+    SUN397, or Scene UNderstanding (SUN), is a dataset for scene recognition consisting of
+    397 categories with 108,754 images.
+
+    Args:
+        root (string): Root directory of the dataset.
+        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
+            version. E.g., ``transforms.RandomCrop``.
+        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    _DATASET_URL = "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz"
+    _DATASET_MD5 = "8ca2778205c41d23104230ba66911c7a"
+
+    def __init__(
+        self,
+        root: str,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self._data_dir = Path(self.root) / "SUN397"
+
+        if download:
+            self._download()
+
+        if not self._check_exists():
+            raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+        with open(self._data_dir / "ClassName.txt") as f:
+            self.classes = [c[3:].strip() for c in f]
+
+        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
+        self._image_files = list(self._data_dir.rglob("sun_*.jpg"))
+
+        self._labels = [
+            self.class_to_idx["/".join(path.relative_to(self._data_dir).parts[1:-1])] for path in self._image_files
+        ]
+
+    def __len__(self) -> int:
+        return len(self._image_files)
+
+    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+        image_file, label = self._image_files[idx], self._labels[idx]
+        image = PIL.Image.open(image_file).convert("RGB")
+
+        if self.transform:
+            image = self.transform(image)
+
+        if self.target_transform:
+            label = self.target_transform(label)
+
+        return image, label
+
+    def _check_exists(self) -> bool:
+        return self._data_dir.is_dir()
+
+    def _download(self) -> None:
+        if self._check_exists():
+            return
+        download_and_extract_archive(self._DATASET_URL, download_root=self.root, md5=self._DATASET_MD5)

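A minimal usage sketch for `SUN397`, with a placeholder root; after extraction the class indexes every `sun_*.jpg` under the `SUN397` directory:

    from libs.vision_libs.datasets.sun397 import SUN397

    ds = SUN397(root="./data", download=True)
    img, label = ds[0]
    print(len(ds), ds.classes[label])
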
+ 129 - 0
libs/vision_libs/datasets/svhn.py

@@ -0,0 +1,129 @@
+import os.path
+from typing import Any, Callable, Optional, Tuple
+
+import numpy as np
+from PIL import Image
+
+from .utils import check_integrity, download_url, verify_str_arg
+from .vision import VisionDataset
+
+
+class SVHN(VisionDataset):
+    """`SVHN <http://ufldl.stanford.edu/housenumbers/>`_ Dataset.
+    Note: The SVHN dataset assigns the label `10` to the digit `0`. However, in this dataset,
+    we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions, which
+    expect the class labels to be in the range `[0, C-1]`.
+
+    .. warning::
+
+        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load data from `.mat` format.
+
+    Args:
+        root (string): Root directory of the dataset where the data is stored.
+        split (string): One of {'train', 'test', 'extra'}.
+            Selects which portion of the dataset to load; 'extra' is the extra training set.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+    """
+
+    split_list = {
+        "train": [
+            "http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
+            "train_32x32.mat",
+            "e26dedcc434d2e4c54c9b2d4a06d8373",
+        ],
+        "test": [
+            "http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
+            "test_32x32.mat",
+            "eb5a983be6a315427106f1b164d9cef3",
+        ],
+        "extra": [
+            "http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
+            "extra_32x32.mat",
+            "a93ce644f1a588dc4d68dda5feec44a7",
+        ],
+    }
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        self.split = verify_str_arg(split, "split", tuple(self.split_list.keys()))
+        self.url = self.split_list[split][0]
+        self.filename = self.split_list[split][1]
+        self.file_md5 = self.split_list[split][2]
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        # import here rather than at top of file because this is
+        # an optional dependency for torchvision
+        import scipy.io as sio
+
+        # reading(loading) mat file as array
+        loaded_mat = sio.loadmat(os.path.join(self.root, self.filename))
+
+        self.data = loaded_mat["X"]
+        # loading from the .mat file gives an np.ndarray of type np.uint8
+        # converting to np.int64, so that we have a LongTensor after
+        # the conversion from the numpy array
+        # the squeeze is needed to obtain a 1D tensor
+        self.labels = loaded_mat["y"].astype(np.int64).squeeze()
+
+        # the svhn dataset assigns the class label "10" to the digit 0
+        # this makes it inconsistent with several loss functions
+        # which expect the class labels to be in the range [0, C-1]
+        np.place(self.labels, self.labels == 10, 0)
+        self.data = np.transpose(self.data, (3, 2, 0, 1))
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target class.
+        """
+        img, target = self.data[index], int(self.labels[index])
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(np.transpose(img, (1, 2, 0)))
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def _check_integrity(self) -> bool:
+        root = self.root
+        md5 = self.split_list[self.split][2]
+        fpath = os.path.join(root, self.filename)
+        return check_integrity(fpath, md5)
+
+    def download(self) -> None:
+        md5 = self.split_list[self.split][2]
+        download_url(self.url, self.root, self.filename, md5)
+
+    def extra_repr(self) -> str:
+        return "Split: {split}".format(**self.__dict__)

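A minimal usage sketch for `SVHN`, assuming scipy is installed and `./data/svhn` is a placeholder root:

    from libs.vision_libs.datasets.svhn import SVHN

    train = SVHN(root="./data/svhn", split="train", download=True)
    img, digit = train[0]          # 32x32 RGB PIL image; digit 0 is labelled 0, not 10
    print(len(train), digit)
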
+ 130 - 0
libs/vision_libs/datasets/ucf101.py

@@ -0,0 +1,130 @@
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+from torch import Tensor
+
+from .folder import find_classes, make_dataset
+from .video_utils import VideoClips
+from .vision import VisionDataset
+
+
+class UCF101(VisionDataset):
+    """
+    `UCF101 <https://www.crcv.ucf.edu/data/UCF101.php>`_ dataset.
+
+    UCF101 is an action recognition video dataset.
+    This dataset considers every video as a collection of video clips of fixed size, specified
+    by ``frames_per_clip``, where the step in frames between each clip is given by
+    ``step_between_clips``. The dataset itself can be downloaded from the dataset website;
+    annotations that ``annotation_path`` should be pointing to can be downloaded from `here
+    <https://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip>`_.
+
+    To give an example, for 2 videos with 10 and 15 frames respectively, if ``frames_per_clip=5``
+    and ``step_between_clips=5``, the dataset size will be (2 + 3) = 5, where the first two
+    elements will come from video 1, and the next three elements from video 2.
+    Note that we drop clips which do not have exactly ``frames_per_clip`` elements, so not all
+    frames in a video might be present.
+
+    Internally, it uses a VideoClips object to handle clip creation.
+
+    Args:
+        root (string): Root directory of the UCF101 Dataset.
+        annotation_path (str): path to the folder containing the split files;
+            see docstring above for download instructions of these files
+        frames_per_clip (int): number of frames in a clip.
+        step_between_clips (int, optional): number of frames between each clip.
+        fold (int, optional): which fold to use. Should be between 1 and 3.
+        train (bool, optional): if ``True``, creates a dataset from the train split,
+            otherwise from the ``test`` split.
+        transform (callable, optional): A function/transform that takes in a TxHxWxC video
+            and returns a transformed version.
+        output_format (str, optional): The format of the output video tensors (before transforms).
+            Can be either "THWC" (default) or "TCHW".
+
+    Returns:
+        tuple: A 3-tuple with the following entries:
+
+            - video (Tensor[T, H, W, C] or Tensor[T, C, H, W]): The `T` video frames
+            - audio (Tensor[K, L]): the audio frames, where `K` is the number of channels
+               and `L` is the number of points
+            - label (int): class of the video clip
+    """
+
+    def __init__(
+        self,
+        root: str,
+        annotation_path: str,
+        frames_per_clip: int,
+        step_between_clips: int = 1,
+        frame_rate: Optional[int] = None,
+        fold: int = 1,
+        train: bool = True,
+        transform: Optional[Callable] = None,
+        _precomputed_metadata: Optional[Dict[str, Any]] = None,
+        num_workers: int = 1,
+        _video_width: int = 0,
+        _video_height: int = 0,
+        _video_min_dimension: int = 0,
+        _audio_samples: int = 0,
+        output_format: str = "THWC",
+    ) -> None:
+        super().__init__(root)
+        if not 1 <= fold <= 3:
+            raise ValueError(f"fold should be between 1 and 3, got {fold}")
+
+        extensions = ("avi",)
+        self.fold = fold
+        self.train = train
+
+        self.classes, class_to_idx = find_classes(self.root)
+        self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
+        video_list = [x[0] for x in self.samples]
+        video_clips = VideoClips(
+            video_list,
+            frames_per_clip,
+            step_between_clips,
+            frame_rate,
+            _precomputed_metadata,
+            num_workers=num_workers,
+            _video_width=_video_width,
+            _video_height=_video_height,
+            _video_min_dimension=_video_min_dimension,
+            _audio_samples=_audio_samples,
+            output_format=output_format,
+        )
+        # we bookkeep the full version of video clips because we want to be able
+        # to return the metadata of full version rather than the subset version of
+        # video clips
+        self.full_video_clips = video_clips
+        self.indices = self._select_fold(video_list, annotation_path, fold, train)
+        self.video_clips = video_clips.subset(self.indices)
+        self.transform = transform
+
+    @property
+    def metadata(self) -> Dict[str, Any]:
+        return self.full_video_clips.metadata
+
+    def _select_fold(self, video_list: List[str], annotation_path: str, fold: int, train: bool) -> List[int]:
+        name = "train" if train else "test"
+        name = f"{name}list{fold:02d}.txt"
+        f = os.path.join(annotation_path, name)
+        selected_files = set()
+        with open(f) as fid:
+            data = fid.readlines()
+            data = [x.strip().split(" ")[0] for x in data]
+            data = [os.path.join(self.root, *x.split("/")) for x in data]
+            selected_files.update(data)
+        indices = [i for i in range(len(video_list)) if video_list[i] in selected_files]
+        return indices
+
+    def __len__(self) -> int:
+        return self.video_clips.num_clips()
+
+    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]:
+        video, audio, info, video_idx = self.video_clips.get_clip(idx)
+        label = self.samples[self.indices[video_idx]][1]
+
+        if self.transform is not None:
+            video = self.transform(video)
+
+        return video, audio, label

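A minimal usage sketch for `UCF101`. The videos and the split files must already be on disk (the class does not download them), and both paths below are placeholders:

    from libs.vision_libs.datasets.ucf101 import UCF101

    ds = UCF101(
        root="./data/UCF-101",                      # extracted video folders, one per class
        annotation_path="./data/ucfTrainTestlist",  # extracted train/test split files
        frames_per_clip=16,
        step_between_clips=8,
        fold=1,
        train=True,
        output_format="TCHW",
    )
    video, audio, label = ds[0]    # video is Tensor[T, C, H, W] with output_format="TCHW"
    print(len(ds), video.shape, label)
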
+ 95 - 0
libs/vision_libs/datasets/usps.py

@@ -0,0 +1,95 @@
+import os
+from typing import Any, Callable, Optional, Tuple
+
+import numpy as np
+from PIL import Image
+
+from .utils import download_url
+from .vision import VisionDataset
+
+
+class USPS(VisionDataset):
+    """`USPS <https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps>`_ Dataset.
+    The data format is: ``[label [index:value ]*256 \\n] * num_lines``, where ``label`` lies in ``[1, 10]``.
+    The value for each pixel lies in ``[-1, 1]``. Here we transform the ``label`` into ``[0, 9]``
+    and make pixel values in ``[0, 255]``.
+
+    Args:
+        root (string): Root directory of dataset to store ``USPS`` data files.
+        train (bool, optional): If True, creates dataset from ``usps.bz2``,
+            otherwise from ``usps.t.bz2``.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+    """
+
+    split_list = {
+        "train": [
+            "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2",
+            "usps.bz2",
+            "ec16c51db3855ca6c91edd34d0e9b197",
+        ],
+        "test": [
+            "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2",
+            "usps.t.bz2",
+            "8ea070ee2aca1ac39742fdd1ef5ed118",
+        ],
+    }
+
+    def __init__(
+        self,
+        root: str,
+        train: bool = True,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(root, transform=transform, target_transform=target_transform)
+        split = "train" if train else "test"
+        url, filename, checksum = self.split_list[split]
+        full_path = os.path.join(self.root, filename)
+
+        if download and not os.path.exists(full_path):
+            download_url(url, self.root, filename, md5=checksum)
+
+        import bz2
+
+        with bz2.open(full_path) as fp:
+            raw_data = [line.decode().split() for line in fp.readlines()]
+            tmp_list = [[x.split(":")[-1] for x in data[1:]] for data in raw_data]
+            imgs = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16))
+            imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)
+            targets = [int(d[0]) - 1 for d in raw_data]
+
+        self.data = imgs
+        self.targets = targets
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target class.
+        """
+        img, target = self.data[index], int(self.targets[index])
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(img, mode="L")
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.data)

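A minimal usage sketch for `USPS`, with a placeholder root:

    from libs.vision_libs.datasets.usps import USPS

    train = USPS(root="./data/usps", train=True, download=True)
    img, digit = train[0]          # 16x16 grayscale PIL image, label in [0, 9]
    print(len(train), digit)
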
+ 459 - 0
libs/vision_libs/datasets/utils.py

@@ -0,0 +1,459 @@
+import bz2
+import gzip
+import hashlib
+import lzma
+import os
+import os.path
+import pathlib
+import re
+import sys
+import tarfile
+import urllib
+import urllib.error
+import urllib.request
+import zipfile
+from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar
+from urllib.parse import urlparse
+
+import numpy as np
+import torch
+from torch.utils.model_zoo import tqdm
+
+from .._internally_replaced_utils import _download_file_from_remote_location, _is_remote_location_available
+
+USER_AGENT = "pytorch/vision"
+
+
+def _save_response_content(
+    content: Iterator[bytes],
+    destination: str,
+    length: Optional[int] = None,
+) -> None:
+    with open(destination, "wb") as fh, tqdm(total=length) as pbar:
+        for chunk in content:
+            # filter out keep-alive new chunks
+            if not chunk:
+                continue
+
+            fh.write(chunk)
+            pbar.update(len(chunk))
+
+
+def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
+    with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
+        _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)
+
+
+def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
+    # Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
+    # not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without
+    # it torchvision.datasets is unusable in these environments since we perform a MD5 check everywhere.
+    if sys.version_info >= (3, 9):
+        md5 = hashlib.md5(usedforsecurity=False)
+    else:
+        md5 = hashlib.md5()
+    with open(fpath, "rb") as f:
+        while chunk := f.read(chunk_size):
+            md5.update(chunk)
+    return md5.hexdigest()
+
+
+def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
+    return md5 == calculate_md5(fpath, **kwargs)
+
+
+def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
+    if not os.path.isfile(fpath):
+        return False
+    if md5 is None:
+        return True
+    return check_md5(fpath, md5)
+
+
+def _get_redirect_url(url: str, max_hops: int = 3) -> str:
+    initial_url = url
+    headers = {"Method": "HEAD", "User-Agent": USER_AGENT}
+
+    for _ in range(max_hops + 1):
+        with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response:
+            if response.url == url or response.url is None:
+                return url
+
+            url = response.url
+    else:
+        raise RecursionError(
+            f"Request to {initial_url} exceeded {max_hops} redirects. The last redirect points to {url}."
+        )
+
+
+def _get_google_drive_file_id(url: str) -> Optional[str]:
+    parts = urlparse(url)
+
+    if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
+        return None
+
+    match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
+    if match is None:
+        return None
+
+    return match.group("id")
+
+
+def download_url(
+    url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3
+) -> None:
+    """Download a file from a url and place it in root.
+
+    Args:
+        url (str): URL to download file from
+        root (str): Directory to place downloaded file in
+        filename (str, optional): Name to save the file under. If None, use the basename of the URL
+        md5 (str, optional): MD5 checksum of the download. If None, do not check
+        max_redirect_hops (int, optional): Maximum number of redirect hops allowed
+    """
+    root = os.path.expanduser(root)
+    if not filename:
+        filename = os.path.basename(url)
+    fpath = os.path.join(root, filename)
+
+    os.makedirs(root, exist_ok=True)
+
+    # check if file is already present locally
+    if check_integrity(fpath, md5):
+        print("Using downloaded and verified file: " + fpath)
+        return
+
+    if _is_remote_location_available():
+        _download_file_from_remote_location(fpath, url)
+    else:
+        # expand redirect chain if needed
+        url = _get_redirect_url(url, max_hops=max_redirect_hops)
+
+        # check if file is located on Google Drive
+        file_id = _get_google_drive_file_id(url)
+        if file_id is not None:
+            return download_file_from_google_drive(file_id, root, filename, md5)
+
+        # download the file
+        try:
+            print("Downloading " + url + " to " + fpath)
+            _urlretrieve(url, fpath)
+        except (urllib.error.URLError, OSError) as e:  # type: ignore[attr-defined]
+            if url[:5] == "https":
+                url = url.replace("https:", "http:")
+                print("Failed download. Trying https -> http instead. Downloading " + url + " to " + fpath)
+                _urlretrieve(url, fpath)
+            else:
+                raise e
+
+    # check integrity of downloaded file
+    if not check_integrity(fpath, md5):
+        raise RuntimeError("File not found or corrupted.")
+
+
+def list_dir(root: str, prefix: bool = False) -> List[str]:
+    """List all directories at a given root
+
+    Args:
+        root (str): Path to directory whose folders need to be listed
+        prefix (bool, optional): If true, prepends the path to each result, otherwise
+            only returns the name of the directories found
+    """
+    root = os.path.expanduser(root)
+    directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))]
+    if prefix is True:
+        directories = [os.path.join(root, d) for d in directories]
+    return directories
+
+
+def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
+    """List all files ending with a suffix at a given root
+
+    Args:
+        root (str): Path to directory whose files need to be listed
+        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
+            It uses the Python "str.endswith" method and is passed directly
+        prefix (bool, optional): If true, prepends the path to each result, otherwise
+            only returns the name of the files found
+    """
+    root = os.path.expanduser(root)
+    files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
+    if prefix is True:
+        files = [os.path.join(root, d) for d in files]
+    return files
+
+
+def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
+    """Download a Google Drive file from  and place it in root.
+
+    Args:
+        file_id (str): id of file to be downloaded
+        root (str): Directory to place downloaded file in
+        filename (str, optional): Name to save the file under. If None, use the id of the file.
+        md5 (str, optional): MD5 checksum of the download. If None, do not check
+    """
+    try:
+        import gdown
+    except ModuleNotFoundError:
+        raise RuntimeError(
+            "To download files from GDrive, 'gdown' is required. You can install it with 'pip install gdown'."
+        )
+
+    root = os.path.expanduser(root)
+    if not filename:
+        filename = file_id
+    fpath = os.path.join(root, filename)
+
+    os.makedirs(root, exist_ok=True)
+
+    if check_integrity(fpath, md5):
+        print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
+        return
+
+    gdown.download(id=file_id, output=fpath, quiet=False, user_agent=USER_AGENT)
+
+    if not check_integrity(fpath, md5):
+        raise RuntimeError("File not found or corrupted.")
+
+
+def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
+    with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
+        tar.extractall(to_path)
+
+
+_ZIP_COMPRESSION_MAP: Dict[str, int] = {
+    ".bz2": zipfile.ZIP_BZIP2,
+    ".xz": zipfile.ZIP_LZMA,
+}
+
+
+def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None:
+    with zipfile.ZipFile(
+        from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED
+    ) as zip:
+        zip.extractall(to_path)
+
+
+_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = {
+    ".tar": _extract_tar,
+    ".zip": _extract_zip,
+}
+_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = {
+    ".bz2": bz2.open,
+    ".gz": gzip.open,
+    ".xz": lzma.open,
+}
+_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {
+    ".tbz": (".tar", ".bz2"),
+    ".tbz2": (".tar", ".bz2"),
+    ".tgz": (".tar", ".gz"),
+}
+
+
+def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
+    """Detect the archive type and/or compression of a file.
+
+    Args:
+        file (str): the filename
+
+    Returns:
+        (tuple): tuple of suffix, archive type, and compression
+
+    Raises:
+        RuntimeError: if file has no suffix or suffix is not supported
+    """
+    suffixes = pathlib.Path(file).suffixes
+    if not suffixes:
+        raise RuntimeError(
+            f"File '{file}' has no suffixes that could be used to detect the archive type and compression."
+        )
+    suffix = suffixes[-1]
+
+    # check if the suffix is a known alias
+    if suffix in _FILE_TYPE_ALIASES:
+        return (suffix, *_FILE_TYPE_ALIASES[suffix])
+
+    # check if the suffix is an archive type
+    if suffix in _ARCHIVE_EXTRACTORS:
+        return suffix, suffix, None
+
+    # check if the suffix is a compression
+    if suffix in _COMPRESSED_FILE_OPENERS:
+        # check for suffix hierarchy
+        if len(suffixes) > 1:
+            suffix2 = suffixes[-2]
+
+            # check if the suffix2 is an archive type
+            if suffix2 in _ARCHIVE_EXTRACTORS:
+                return suffix2 + suffix, suffix2, suffix
+
+        return suffix, None, suffix
+
+    valid_suffixes = sorted(set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS))
+    raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.")
+
+
+def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
+    r"""Decompress a file.
+
+    The compression is automatically detected from the file name.
+
+    Args:
+        from_path (str): Path to the file to be decompressed.
+        to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used.
+        remove_finished (bool): If ``True``, remove the file after the extraction.
+
+    Returns:
+        (str): Path to the decompressed file.
+    """
+    suffix, archive_type, compression = _detect_file_type(from_path)
+    if not compression:
+        raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.")
+
+    if to_path is None:
+        to_path = from_path.replace(suffix, archive_type if archive_type is not None else "")
+
+    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
+    compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression]
+
+    with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh:
+        wfh.write(rfh.read())
+
+    if remove_finished:
+        os.remove(from_path)
+
+    return to_path
+
+
+def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
+    """Extract an archive.
+
+    The archive type and a possible compression is automatically detected from the file name. If the file is compressed
+    but not an archive the call is dispatched to :func:`decompress`.
+
+    Args:
+        from_path (str): Path to the file to be extracted.
+        to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is
+            used.
+        remove_finished (bool): If ``True``, remove the file after the extraction.
+
+    Returns:
+        (str): Path to the directory the file was extracted to.
+    """
+    if to_path is None:
+        to_path = os.path.dirname(from_path)
+
+    suffix, archive_type, compression = _detect_file_type(from_path)
+    if not archive_type:
+        return _decompress(
+            from_path,
+            os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")),
+            remove_finished=remove_finished,
+        )
+
+    # We don't need to check for a missing key here, since this was already done in _detect_file_type()
+    extractor = _ARCHIVE_EXTRACTORS[archive_type]
+
+    extractor(from_path, to_path, compression)
+    if remove_finished:
+        os.remove(from_path)
+
+    return to_path
+
+
+def download_and_extract_archive(
+    url: str,
+    download_root: str,
+    extract_root: Optional[str] = None,
+    filename: Optional[str] = None,
+    md5: Optional[str] = None,
+    remove_finished: bool = False,
+) -> None:
+    download_root = os.path.expanduser(download_root)
+    if extract_root is None:
+        extract_root = download_root
+    if not filename:
+        filename = os.path.basename(url)
+
+    download_url(url, download_root, filename, md5)
+
+    archive = os.path.join(download_root, filename)
+    print(f"Extracting {archive} to {extract_root}")
+    extract_archive(archive, extract_root, remove_finished)
+
+
+def iterable_to_str(iterable: Iterable) -> str:
+    return "'" + "', '".join([str(item) for item in iterable]) + "'"
+
+
+T = TypeVar("T", str, bytes)
+
+
+def verify_str_arg(
+    value: T,
+    arg: Optional[str] = None,
+    valid_values: Optional[Iterable[T]] = None,
+    custom_msg: Optional[str] = None,
+) -> T:
+    if not isinstance(value, str):
+        if arg is None:
+            msg = "Expected type str, but got type {type}."
+        else:
+            msg = "Expected type str for argument {arg}, but got type {type}."
+        msg = msg.format(type=type(value), arg=arg)
+        raise ValueError(msg)
+
+    if valid_values is None:
+        return value
+
+    if value not in valid_values:
+        if custom_msg is not None:
+            msg = custom_msg
+        else:
+            msg = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}."
+            msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values))
+        raise ValueError(msg)
+
+    return value
+
+
+def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray:
+    """Read file in .pfm format. Might contain either 1 or 3 channels of data.
+
+    Args:
+        file_name (str): Path to the file.
+        slice_channels (int): Number of channels to slice out of the file.
+            Useful for reading different data formats stored in .pfm files: Optical Flows, Stereo Disparity Maps, etc.
+    """
+
+    with open(file_name, "rb") as f:
+        header = f.readline().rstrip()
+        if header not in [b"PF", b"Pf"]:
+            raise ValueError("Invalid PFM file")
+
+        dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline())
+        if not dim_match:
+            raise Exception("Malformed PFM header.")
+        w, h = (int(dim) for dim in dim_match.groups())
+
+        scale = float(f.readline().rstrip())
+        if scale < 0:  # little-endian
+            endian = "<"
+            scale = -scale
+        else:
+            endian = ">"  # big-endian
+
+        data = np.fromfile(f, dtype=endian + "f")
+
+    pfm_channels = 3 if header == b"PF" else 1
+
+    data = data.reshape(h, w, pfm_channels).transpose(2, 0, 1)
+    data = np.flip(data, axis=1)  # flip on h dimension
+    data = data[:slice_channels, :, :]
+    return data.astype(np.float32)
+
+
+def _flip_byte_order(t: torch.Tensor) -> torch.Tensor:
+    return (
+        t.contiguous().view(torch.uint8).view(*t.shape, t.element_size()).flip(-1).view(*t.shape[:-1], -1).view(t.dtype)
+    )

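A short sketch of how the helpers above are typically combined; the URL, checksum, and paths are hypothetical placeholders:

    from libs.vision_libs.datasets.utils import download_and_extract_archive, verify_str_arg

    split = verify_str_arg("train", "split", ("train", "val"))  # returns "train"; invalid values raise ValueError

    def fetch_example_archive() -> None:
        # Illustrative only: substitute a real archive URL and checksum before calling.
        download_and_extract_archive(
            url="https://example.com/data.tar.gz",
            download_root="./data",            # where the archive file is saved
            extract_root="./data/extracted",   # where its contents are unpacked
            md5=None,                          # set to the expected checksum to verify the download
        )
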
+ 419 - 0
libs/vision_libs/datasets/video_utils.py

@@ -0,0 +1,419 @@
+import bisect
+import math
+import warnings
+from fractions import Fraction
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, TypeVar, Union
+
+import torch
+from torchvision.io import _probe_video_from_file, _read_video_from_file, read_video, read_video_timestamps
+
+from .utils import tqdm
+
+T = TypeVar("T")
+
+
+def pts_convert(pts: int, timebase_from: Fraction, timebase_to: Fraction, round_func: Callable = math.floor) -> int:
+    """convert pts between different time bases
+    Args:
+        pts: presentation timestamp, float
+        timebase_from: original timebase. Fraction
+        timebase_to: new timebase. Fraction
+        round_func: rounding function.
+    """
+    new_pts = Fraction(pts, 1) * timebase_from / timebase_to
+    return round_func(new_pts)
+
+
+def unfold(tensor: torch.Tensor, size: int, step: int, dilation: int = 1) -> torch.Tensor:
+    """
+    similar to tensor.unfold, but with the dilation
+    and specialized for 1d tensors
+
+    Returns all consecutive windows of `size` elements, with
+    `step` between windows. The distance between each element
+    in a window is given by `dilation`.
+    """
+    if tensor.dim() != 1:
+        raise ValueError(f"tensor should have 1 dimension instead of {tensor.dim()}")
+    o_stride = tensor.stride(0)
+    numel = tensor.numel()
+    new_stride = (step * o_stride, dilation * o_stride)
+    new_size = ((numel - (dilation * (size - 1) + 1)) // step + 1, size)
+    if new_size[0] < 1:
+        new_size = (0, size)
+    return torch.as_strided(tensor, new_size, new_stride)
+
+
+class _VideoTimestampsDataset:
+    """
+    Dataset used to parallelize the reading of the timestamps
+    of a list of videos, given their paths in the filesystem.
+
+    Used in VideoClips and defined at top level, so it can be
+    pickled when forking.
+    """
+
+    def __init__(self, video_paths: List[str]) -> None:
+        self.video_paths = video_paths
+
+    def __len__(self) -> int:
+        return len(self.video_paths)
+
+    def __getitem__(self, idx: int) -> Tuple[List[int], Optional[float]]:
+        return read_video_timestamps(self.video_paths[idx])
+
+
+def _collate_fn(x: T) -> T:
+    """
+    Dummy collate function to be used with _VideoTimestampsDataset
+    """
+    return x
+
+
+class VideoClips:
+    """
+    Given a list of video files, computes all consecutive subvideos of size
+    `clip_length_in_frames`, where the distance between each subvideo in the
+    same video is defined by `frames_between_clips`.
+    If `frame_rate` is specified, it will also resample all the videos to have
+    the same frame rate, and the clips will refer to this frame rate.
+
+    Creating this instance the first time is time-consuming, as it needs to
+    decode all the videos in `video_paths`. It is recommended that you
+    cache the results after instantiation of the class.
+
+    Recreating the clips for different clip lengths is fast, and can be done
+    with the `compute_clips` method.
+
+    Args:
+        video_paths (List[str]): paths to the video files
+        clip_length_in_frames (int): size of a clip in number of frames
+        frames_between_clips (int): step (in frames) between each clip
+        frame_rate (int, optional): if specified, it will resample the video
+            so that it has `frame_rate`, and then the clips will be defined
+            on the resampled video
+        num_workers (int): how many subprocesses to use for data loading.
+            0 means that the data will be loaded in the main process. (default: 0)
+        output_format (str): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
+    """
+
+    def __init__(
+        self,
+        video_paths: List[str],
+        clip_length_in_frames: int = 16,
+        frames_between_clips: int = 1,
+        frame_rate: Optional[int] = None,
+        _precomputed_metadata: Optional[Dict[str, Any]] = None,
+        num_workers: int = 0,
+        _video_width: int = 0,
+        _video_height: int = 0,
+        _video_min_dimension: int = 0,
+        _video_max_dimension: int = 0,
+        _audio_samples: int = 0,
+        _audio_channels: int = 0,
+        output_format: str = "THWC",
+    ) -> None:
+
+        self.video_paths = video_paths
+        self.num_workers = num_workers
+
+        # these options are not valid for pyav backend
+        self._video_width = _video_width
+        self._video_height = _video_height
+        self._video_min_dimension = _video_min_dimension
+        self._video_max_dimension = _video_max_dimension
+        self._audio_samples = _audio_samples
+        self._audio_channels = _audio_channels
+        self.output_format = output_format.upper()
+        if self.output_format not in ("THWC", "TCHW"):
+            raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
+
+        if _precomputed_metadata is None:
+            self._compute_frame_pts()
+        else:
+            self._init_from_metadata(_precomputed_metadata)
+        self.compute_clips(clip_length_in_frames, frames_between_clips, frame_rate)
+
+    def _compute_frame_pts(self) -> None:
+        self.video_pts = []  # len = num_videos. Each entry is a tensor of shape (num_frames_in_video,)
+        self.video_fps: List[int] = []  # len = num_videos
+
+        # strategy: use a DataLoader to parallelize read_video_timestamps
+        # so need to create a dummy dataset first
+        import torch.utils.data
+
+        dl: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
+            _VideoTimestampsDataset(self.video_paths),  # type: ignore[arg-type]
+            batch_size=16,
+            num_workers=self.num_workers,
+            collate_fn=_collate_fn,
+        )
+
+        with tqdm(total=len(dl)) as pbar:
+            for batch in dl:
+                pbar.update(1)
+                batch_pts, batch_fps = list(zip(*batch))
+                # we need to specify dtype=torch.long because for empty list,
+                # torch.as_tensor will use torch.float as default dtype. This
+                # happens when decoding fails and no pts is returned in the list.
+                batch_pts = [torch.as_tensor(pts, dtype=torch.long) for pts in batch_pts]
+                self.video_pts.extend(batch_pts)
+                self.video_fps.extend(batch_fps)
+
+    def _init_from_metadata(self, metadata: Dict[str, Any]) -> None:
+        self.video_paths = metadata["video_paths"]
+        assert len(self.video_paths) == len(metadata["video_pts"])
+        self.video_pts = metadata["video_pts"]
+        assert len(self.video_paths) == len(metadata["video_fps"])
+        self.video_fps = metadata["video_fps"]
+
+    @property
+    def metadata(self) -> Dict[str, Any]:
+        _metadata = {
+            "video_paths": self.video_paths,
+            "video_pts": self.video_pts,
+            "video_fps": self.video_fps,
+        }
+        return _metadata
+
+    def subset(self, indices: List[int]) -> "VideoClips":
+        video_paths = [self.video_paths[i] for i in indices]
+        video_pts = [self.video_pts[i] for i in indices]
+        video_fps = [self.video_fps[i] for i in indices]
+        metadata = {
+            "video_paths": video_paths,
+            "video_pts": video_pts,
+            "video_fps": video_fps,
+        }
+        return type(self)(
+            video_paths,
+            clip_length_in_frames=self.num_frames,
+            frames_between_clips=self.step,
+            frame_rate=self.frame_rate,
+            _precomputed_metadata=metadata,
+            num_workers=self.num_workers,
+            _video_width=self._video_width,
+            _video_height=self._video_height,
+            _video_min_dimension=self._video_min_dimension,
+            _video_max_dimension=self._video_max_dimension,
+            _audio_samples=self._audio_samples,
+            _audio_channels=self._audio_channels,
+            output_format=self.output_format,
+        )
+
+    @staticmethod
+    def compute_clips_for_video(
+        video_pts: torch.Tensor, num_frames: int, step: int, fps: int, frame_rate: Optional[int] = None
+    ) -> Tuple[torch.Tensor, Union[List[slice], torch.Tensor]]:
+        if fps is None:
+            # if for some reason the video doesn't have fps (because doesn't have a video stream)
+            # set the fps to 1. The value doesn't matter, because video_pts is empty anyway
+            fps = 1
+        if frame_rate is None:
+            frame_rate = fps
+        total_frames = len(video_pts) * (float(frame_rate) / fps)
+        _idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate)
+        video_pts = video_pts[_idxs]
+        clips = unfold(video_pts, num_frames, step)
+        if not clips.numel():
+            warnings.warn(
+                "There aren't enough frames in the current video to get a clip for the given clip length and "
+                "frames between clips. The video (and potentially others) will be skipped."
+            )
+        idxs: Union[List[slice], torch.Tensor]
+        if isinstance(_idxs, slice):
+            idxs = [_idxs] * len(clips)
+        else:
+            idxs = unfold(_idxs, num_frames, step)
+        return clips, idxs
+
+    def compute_clips(self, num_frames: int, step: int, frame_rate: Optional[int] = None) -> None:
+        """
+        Compute all consecutive sequences of clips from video_pts.
+        Always returns clips of size `num_frames`, meaning that the
+        last few frames in a video can potentially be dropped.
+
+        Args:
+            num_frames (int): number of frames for the clip
+            step (int): distance between two clips
+            frame_rate (int, optional): The frame rate
+        """
+        self.num_frames = num_frames
+        self.step = step
+        self.frame_rate = frame_rate
+        self.clips = []
+        self.resampling_idxs = []
+        for video_pts, fps in zip(self.video_pts, self.video_fps):
+            clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate)
+            self.clips.append(clips)
+            self.resampling_idxs.append(idxs)
+        clip_lengths = torch.as_tensor([len(v) for v in self.clips])
+        self.cumulative_sizes = clip_lengths.cumsum(0).tolist()
+
+    def __len__(self) -> int:
+        return self.num_clips()
+
+    def num_videos(self) -> int:
+        return len(self.video_paths)
+
+    def num_clips(self) -> int:
+        """
+        Number of subclips that are available in the video list.
+        """
+        return self.cumulative_sizes[-1]
+
+    def get_clip_location(self, idx: int) -> Tuple[int, int]:
+        """
+        Converts a flattened representation of the indices into a video_idx, clip_idx
+        representation.
+        """
+        video_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+        if video_idx == 0:
+            clip_idx = idx
+        else:
+            clip_idx = idx - self.cumulative_sizes[video_idx - 1]
+        return video_idx, clip_idx
+
+    @staticmethod
+    def _resample_video_idx(num_frames: int, original_fps: int, new_fps: int) -> Union[slice, torch.Tensor]:
+        step = float(original_fps) / new_fps
+        if step.is_integer():
+            # optimization: if step is integer, don't need to perform
+            # advanced indexing
+            step = int(step)
+            return slice(None, None, step)
+        idxs = torch.arange(num_frames, dtype=torch.float32) * step
+        idxs = idxs.floor().to(torch.int64)
+        return idxs
+
+    def get_clip(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any], int]:
+        """
+        Gets a subclip from a list of videos.
+
+        Args:
+            idx (int): index of the subclip. Must be between 0 and num_clips().
+
+        Returns:
+            video (Tensor)
+            audio (Tensor)
+            info (Dict)
+            video_idx (int): index of the video in `video_paths`
+        """
+        if idx >= self.num_clips():
+            raise IndexError(f"Index {idx} out of range ({self.num_clips()} clips available)")
+        video_idx, clip_idx = self.get_clip_location(idx)
+        video_path = self.video_paths[video_idx]
+        clip_pts = self.clips[video_idx][clip_idx]
+
+        from torchvision import get_video_backend
+
+        backend = get_video_backend()
+
+        if backend == "pyav":
+            # check for invalid options
+            if self._video_width != 0:
+                raise ValueError("pyav backend doesn't support _video_width != 0")
+            if self._video_height != 0:
+                raise ValueError("pyav backend doesn't support _video_height != 0")
+            if self._video_min_dimension != 0:
+                raise ValueError("pyav backend doesn't support _video_min_dimension != 0")
+            if self._video_max_dimension != 0:
+                raise ValueError("pyav backend doesn't support _video_max_dimension != 0")
+            if self._audio_samples != 0:
+                raise ValueError("pyav backend doesn't support _audio_samples != 0")
+
+        if backend == "pyav":
+            start_pts = clip_pts[0].item()
+            end_pts = clip_pts[-1].item()
+            video, audio, info = read_video(video_path, start_pts, end_pts)
+        else:
+            _info = _probe_video_from_file(video_path)
+            video_fps = _info.video_fps
+            audio_fps = None
+
+            video_start_pts = cast(int, clip_pts[0].item())
+            video_end_pts = cast(int, clip_pts[-1].item())
+
+            audio_start_pts, audio_end_pts = 0, -1
+            audio_timebase = Fraction(0, 1)
+            video_timebase = Fraction(_info.video_timebase.numerator, _info.video_timebase.denominator)
+            if _info.has_audio:
+                audio_timebase = Fraction(_info.audio_timebase.numerator, _info.audio_timebase.denominator)
+                audio_start_pts = pts_convert(video_start_pts, video_timebase, audio_timebase, math.floor)
+                audio_end_pts = pts_convert(video_end_pts, video_timebase, audio_timebase, math.ceil)
+                audio_fps = _info.audio_sample_rate
+            video, audio, _ = _read_video_from_file(
+                video_path,
+                video_width=self._video_width,
+                video_height=self._video_height,
+                video_min_dimension=self._video_min_dimension,
+                video_max_dimension=self._video_max_dimension,
+                video_pts_range=(video_start_pts, video_end_pts),
+                video_timebase=video_timebase,
+                audio_samples=self._audio_samples,
+                audio_channels=self._audio_channels,
+                audio_pts_range=(audio_start_pts, audio_end_pts),
+                audio_timebase=audio_timebase,
+            )
+
+            info = {"video_fps": video_fps}
+            if audio_fps is not None:
+                info["audio_fps"] = audio_fps
+
+        if self.frame_rate is not None:
+            resampling_idx = self.resampling_idxs[video_idx][clip_idx]
+            if isinstance(resampling_idx, torch.Tensor):
+                resampling_idx = resampling_idx - resampling_idx[0]
+            video = video[resampling_idx]
+            info["video_fps"] = self.frame_rate
+        assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}"
+
+        if self.output_format == "TCHW":
+            # [T,H,W,C] --> [T,C,H,W]
+            video = video.permute(0, 3, 1, 2)
+
+        return video, audio, info, video_idx
+
+    def __getstate__(self) -> Dict[str, Any]:
+        video_pts_sizes = [len(v) for v in self.video_pts]
+        # To be back-compatible, we convert data to dtype torch.long as needed
+        # because for empty list, in legacy implementation, torch.as_tensor will
+        # use torch.float as default dtype. This happens when decoding fails and
+        # no pts is returned in the list.
+        video_pts = [x.to(torch.int64) for x in self.video_pts]
+        # video_pts can be an empty list if no frames have been decoded
+        if video_pts:
+            video_pts = torch.cat(video_pts)  # type: ignore[assignment]
+            # avoid bug in https://github.com/pytorch/pytorch/issues/32351
+            # TODO: Revert it once the bug is fixed.
+            video_pts = video_pts.numpy()  # type: ignore[attr-defined]
+
+        # make a copy of the fields of self
+        d = self.__dict__.copy()
+        d["video_pts_sizes"] = video_pts_sizes
+        d["video_pts"] = video_pts
+        # delete the following attributes to reduce the size of dictionary. They
+        # will be re-computed in "__setstate__()"
+        del d["clips"]
+        del d["resampling_idxs"]
+        del d["cumulative_sizes"]
+
+        # for backwards-compatibility
+        d["_version"] = 2
+        return d
+
+    def __setstate__(self, d: Dict[str, Any]) -> None:
+        # for backwards-compatibility
+        if "_version" not in d:
+            self.__dict__ = d
+            return
+
+        video_pts = torch.as_tensor(d["video_pts"], dtype=torch.int64)
+        video_pts = torch.split(video_pts, d["video_pts_sizes"], dim=0)
+        # don't need this info anymore
+        del d["video_pts_sizes"]
+
+        d["video_pts"] = video_pts
+        self.__dict__ = d
+        # recompute attributes "clips", "resampling_idxs" and other derivative ones
+        self.compute_clips(self.num_frames, self.step, self.frame_rate)
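A minimal usage sketch for `VideoClips`, assuming the vendored module is importable as `libs.vision_libs.datasets.video_utils`; the video paths and clip parameters are hypothetical. The expensive step is the first construction (decoding timestamps for every video), so the `metadata` dict is worth caching and feeding back via `_precomputed_metadata`:

    import torch
    from libs.vision_libs.datasets.video_utils import VideoClips

    video_paths = ["a.mp4", "b.mp4"]   # hypothetical files
    clips = VideoClips(video_paths, clip_length_in_frames=16, frames_between_clips=8)
    torch.save(clips.metadata, "clips_metadata.pt")   # cache the per-video pts/fps

    # Later runs: rebuild cheaply from the cache and fetch a clip.
    cached = torch.load("clips_metadata.pt")
    clips = VideoClips(video_paths, clip_length_in_frames=16, frames_between_clips=8,
                       _precomputed_metadata=cached)
    video, audio, info, video_idx = clips.get_clip(0)   # video is [T, H, W, C] with output_format="THWC"
    print(clips.num_clips(), video.shape, info.get("video_fps"))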

+ 110 - 0
libs/vision_libs/datasets/vision.py

@@ -0,0 +1,110 @@
+import os
+from typing import Any, Callable, List, Optional, Tuple
+
+import torch.utils.data as data
+
+from ..utils import _log_api_usage_once
+
+
+class VisionDataset(data.Dataset):
+    """
+    Base class for making datasets which are compatible with torchvision.
+    It is necessary to override the ``__getitem__`` and ``__len__`` methods.
+
+    Args:
+        root (string, optional): Root directory of dataset. Only used for `__repr__`.
+        transforms (callable, optional): A function/transform that takes in
+            an image and a label and returns the transformed versions of both.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+
+    .. note::
+
+        :attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive.
+    """
+
+    _repr_indent = 4
+
+    def __init__(
+        self,
+        root: str = None,  # type: ignore[assignment]
+        transforms: Optional[Callable] = None,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+    ) -> None:
+        _log_api_usage_once(self)
+        if isinstance(root, str):
+            root = os.path.expanduser(root)
+        self.root = root
+
+        has_transforms = transforms is not None
+        has_separate_transform = transform is not None or target_transform is not None
+        if has_transforms and has_separate_transform:
+            raise ValueError("Only transforms or transform/target_transform can be passed as argument")
+
+        # for backwards-compatibility
+        self.transform = transform
+        self.target_transform = target_transform
+
+        if has_separate_transform:
+            transforms = StandardTransform(transform, target_transform)
+        self.transforms = transforms
+
+    def __getitem__(self, index: int) -> Any:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            (Any): Sample and meta data, optionally transformed by the respective transforms.
+        """
+        raise NotImplementedError
+
+    def __len__(self) -> int:
+        raise NotImplementedError
+
+    def __repr__(self) -> str:
+        head = "Dataset " + self.__class__.__name__
+        body = [f"Number of datapoints: {self.__len__()}"]
+        if self.root is not None:
+            body.append(f"Root location: {self.root}")
+        body += self.extra_repr().splitlines()
+        if hasattr(self, "transforms") and self.transforms is not None:
+            body += [repr(self.transforms)]
+        lines = [head] + [" " * self._repr_indent + line for line in body]
+        return "\n".join(lines)
+
+    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
+        lines = transform.__repr__().splitlines()
+        return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]
+
+    def extra_repr(self) -> str:
+        return ""
+
+
+class StandardTransform:
+    def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None:
+        self.transform = transform
+        self.target_transform = target_transform
+
+    def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]:
+        if self.transform is not None:
+            input = self.transform(input)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+        return input, target
+
+    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
+        lines = transform.__repr__().splitlines()
+        return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]
+
+    def __repr__(self) -> str:
+        body = [self.__class__.__name__]
+        if self.transform is not None:
+            body += self._format_transform_repr(self.transform, "Transform: ")
+        if self.target_transform is not None:
+            body += self._format_transform_repr(self.target_transform, "Target transform: ")
+
+        return "\n".join(body)
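A minimal sketch of a dataset built on `VisionDataset` (the class and numbers below are hypothetical). Only `__getitem__` and `__len__` must be overridden; `transform`/`target_transform` are wrapped by the base class into `self.transforms`, a `StandardTransform`:

    from typing import Any, Tuple

    class SquaresDataset(VisionDataset):   # assumes VisionDataset from the file above is in scope
        def __init__(self, n: int, **kwargs) -> None:
            super().__init__(root=None, **kwargs)
            self.n = n

        def __getitem__(self, index: int) -> Tuple[Any, Any]:
            sample, target = index, index ** 2
            if self.transforms is not None:
                sample, target = self.transforms(sample, target)
            return sample, target

        def __len__(self) -> int:
            return self.n

    ds = SquaresDataset(5, target_transform=lambda t: t + 1)
    print(len(ds), ds[3])   # 5 (3, 10)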

+ 224 - 0
libs/vision_libs/datasets/voc.py

@@ -0,0 +1,224 @@
+import collections
+import os
+from xml.etree.ElementTree import Element as ET_Element
+
+from .vision import VisionDataset
+
+try:
+    from defusedxml.ElementTree import parse as ET_parse
+except ImportError:
+    from xml.etree.ElementTree import parse as ET_parse
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+from PIL import Image
+
+from .utils import download_and_extract_archive, verify_str_arg
+
+DATASET_YEAR_DICT = {
+    "2012": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar",
+        "filename": "VOCtrainval_11-May-2012.tar",
+        "md5": "6cd6e144f989b92b3379bac3b3de84fd",
+        "base_dir": os.path.join("VOCdevkit", "VOC2012"),
+    },
+    "2011": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar",
+        "filename": "VOCtrainval_25-May-2011.tar",
+        "md5": "6c3384ef61512963050cb5d687e5bf1e",
+        "base_dir": os.path.join("TrainVal", "VOCdevkit", "VOC2011"),
+    },
+    "2010": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar",
+        "filename": "VOCtrainval_03-May-2010.tar",
+        "md5": "da459979d0c395079b5c75ee67908abb",
+        "base_dir": os.path.join("VOCdevkit", "VOC2010"),
+    },
+    "2009": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar",
+        "filename": "VOCtrainval_11-May-2009.tar",
+        "md5": "a3e00b113cfcfebf17e343f59da3caa1",
+        "base_dir": os.path.join("VOCdevkit", "VOC2009"),
+    },
+    "2008": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar",
+        "filename": "VOCtrainval_14-Jul-2008.tar",
+        "md5": "2629fa636546599198acfcfbfcf1904a",
+        "base_dir": os.path.join("VOCdevkit", "VOC2008"),
+    },
+    "2007": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar",
+        "filename": "VOCtrainval_06-Nov-2007.tar",
+        "md5": "c52e279531787c972589f7e41ab4ae64",
+        "base_dir": os.path.join("VOCdevkit", "VOC2007"),
+    },
+    "2007-test": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar",
+        "filename": "VOCtest_06-Nov-2007.tar",
+        "md5": "b6e924de25625d8de591ea690078ad9f",
+        "base_dir": os.path.join("VOCdevkit", "VOC2007"),
+    },
+}
+
+
+class _VOCBase(VisionDataset):
+    _SPLITS_DIR: str
+    _TARGET_DIR: str
+    _TARGET_FILE_EXT: str
+
+    def __init__(
+        self,
+        root: str,
+        year: str = "2012",
+        image_set: str = "train",
+        download: bool = False,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        transforms: Optional[Callable] = None,
+    ):
+        super().__init__(root, transforms, transform, target_transform)
+
+        self.year = verify_str_arg(year, "year", valid_values=[str(yr) for yr in range(2007, 2013)])
+
+        valid_image_sets = ["train", "trainval", "val"]
+        if year == "2007":
+            valid_image_sets.append("test")
+        self.image_set = verify_str_arg(image_set, "image_set", valid_image_sets)
+
+        key = "2007-test" if year == "2007" and image_set == "test" else year
+        dataset_year_dict = DATASET_YEAR_DICT[key]
+
+        self.url = dataset_year_dict["url"]
+        self.filename = dataset_year_dict["filename"]
+        self.md5 = dataset_year_dict["md5"]
+
+        base_dir = dataset_year_dict["base_dir"]
+        voc_root = os.path.join(self.root, base_dir)
+
+        if download:
+            download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)
+
+        if not os.path.isdir(voc_root):
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
+
+        splits_dir = os.path.join(voc_root, "ImageSets", self._SPLITS_DIR)
+        split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt")
+        with open(split_f) as f:
+            file_names = [x.strip() for x in f.readlines()]
+
+        image_dir = os.path.join(voc_root, "JPEGImages")
+        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
+
+        target_dir = os.path.join(voc_root, self._TARGET_DIR)
+        self.targets = [os.path.join(target_dir, x + self._TARGET_FILE_EXT) for x in file_names]
+
+        assert len(self.images) == len(self.targets)
+
+    def __len__(self) -> int:
+        return len(self.images)
+
+
+class VOCSegmentation(_VOCBase):
+    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
+
+    Args:
+        root (string): Root directory of the VOC Dataset.
+        year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
+        image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
+            ``year=="2007"``, can also be ``"test"``.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        transforms (callable, optional): A function/transform that takes input sample and its target as entry
+            and returns a transformed version.
+    """
+
+    _SPLITS_DIR = "Segmentation"
+    _TARGET_DIR = "SegmentationClass"
+    _TARGET_FILE_EXT = ".png"
+
+    @property
+    def masks(self) -> List[str]:
+        return self.targets
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is the image segmentation.
+        """
+        img = Image.open(self.images[index]).convert("RGB")
+        target = Image.open(self.masks[index])
+
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
+
+        return img, target
+
+
+class VOCDetection(_VOCBase):
+    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.
+
+    Args:
+        root (string): Root directory of the VOC Dataset.
+        year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
+        image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
+            ``year=="2007"``, can also be ``"test"``.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, required): A function/transform that takes in the
+            target and transforms it.
+        transforms (callable, optional): A function/transform that takes input sample and its target as entry
+            and returns a transformed version.
+    """
+
+    _SPLITS_DIR = "Main"
+    _TARGET_DIR = "Annotations"
+    _TARGET_FILE_EXT = ".xml"
+
+    @property
+    def annotations(self) -> List[str]:
+        return self.targets
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is a dictionary of the XML tree.
+        """
+        img = Image.open(self.images[index]).convert("RGB")
+        target = self.parse_voc_xml(ET_parse(self.annotations[index]).getroot())
+
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
+
+        return img, target
+
+    @staticmethod
+    def parse_voc_xml(node: ET_Element) -> Dict[str, Any]:
+        voc_dict: Dict[str, Any] = {}
+        children = list(node)
+        if children:
+            def_dic: Dict[str, Any] = collections.defaultdict(list)
+            for dc in map(VOCDetection.parse_voc_xml, children):
+                for ind, v in dc.items():
+                    def_dic[ind].append(v)
+            if node.tag == "annotation":
+                def_dic["object"] = [def_dic["object"]]
+            voc_dict = {node.tag: {ind: v[0] if len(v) == 1 else v for ind, v in def_dic.items()}}
+        if node.text:
+            text = node.text.strip()
+            if not children:
+                voc_dict[node.tag] = text
+        return voc_dict
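A minimal sketch of what `VOCDetection.parse_voc_xml` produces, run on a hand-written annotation (the XML below is illustrative, not a real VOC file). Repeated tags are collected into lists, and `<object>` entries are always wrapped in a list at the `annotation` level:

    from xml.etree.ElementTree import fromstring

    xml = (
        "<annotation>"
        "<filename>000001.jpg</filename>"
        "<object><name>dog</name><difficult>0</difficult></object>"
        "<object><name>person</name><difficult>0</difficult></object>"
        "</annotation>"
    )
    target = VOCDetection.parse_voc_xml(fromstring(xml))
    # {'annotation': {'filename': '000001.jpg',
    #                 'object': [{'name': 'dog', 'difficult': '0'},
    #                            {'name': 'person', 'difficult': '0'}]}}
    print(target["annotation"]["object"][0]["name"])   # dog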

+ 195 - 0
libs/vision_libs/datasets/widerface.py

@@ -0,0 +1,195 @@
+import os
+from os.path import abspath, expanduser
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from PIL import Image
+
+from .utils import download_and_extract_archive, download_file_from_google_drive, extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class WIDERFace(VisionDataset):
+    """`WIDERFace <http://shuoyang1213.me/WIDERFACE/>`_ Dataset.
+
+    Args:
+        root (string): Root directory where images and annotations are downloaded to.
+            Expects the following folder structure if download=False:
+
+            .. code::
+
+                <root>
+                    └── widerface
+                        ├── wider_face_split ('wider_face_split.zip' if compressed)
+                        ├── WIDER_train ('WIDER_train.zip' if compressed)
+                        ├── WIDER_val ('WIDER_val.zip' if compressed)
+                        └── WIDER_test ('WIDER_test.zip' if compressed)
+        split (string): The dataset split to use. One of {``train``, ``val``, ``test``}.
+            Defaults to ``train``.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+
+            .. warning::
+
+                To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
+
+    """
+
+    BASE_FOLDER = "widerface"
+    FILE_LIST = [
+        # File ID                             MD5 Hash                            Filename
+        ("15hGDLhsx8bLgLcIRD5DhYt5iBxnjNF1M", "3fedf70df600953d25982bcd13d91ba2", "WIDER_train.zip"),
+        ("1GUCogbp16PMGa39thoMMeWxp7Rp5oM8Q", "dfa7d7e790efa35df3788964cf0bbaea", "WIDER_val.zip"),
+        ("1HIfDbVEWKmsYKJZm4lchTBDLW5N7dY5T", "e5d8f4248ed24c334bbd12f49c29dd40", "WIDER_test.zip"),
+    ]
+    ANNOTATIONS_FILE = (
+        "http://shuoyang1213.me/WIDERFACE/support/bbx_annotation/wider_face_split.zip",
+        "0e3767bcf0e326556d407bf5bff5d27c",
+        "wider_face_split.zip",
+    )
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super().__init__(
+            root=os.path.join(root, self.BASE_FOLDER), transform=transform, target_transform=target_transform
+        )
+        # check arguments
+        self.split = verify_str_arg(split, "split", ("train", "val", "test"))
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download and prepare it")
+
+        self.img_info: List[Dict[str, Union[str, Dict[str, torch.Tensor]]]] = []
+        if self.split in ("train", "val"):
+            self.parse_train_val_annotations_file()
+        else:
+            self.parse_test_annotations_file()
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is a dict of annotations for all faces in the image.
+            target=None for the test split.
+        """
+
+        # stay consistent with other datasets and return a PIL Image
+        img = Image.open(self.img_info[index]["img_path"])
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        target = None if self.split == "test" else self.img_info[index]["annotations"]
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.img_info)
+
+    def extra_repr(self) -> str:
+        lines = ["Split: {split}"]
+        return "\n".join(lines).format(**self.__dict__)
+
+    def parse_train_val_annotations_file(self) -> None:
+        filename = "wider_face_train_bbx_gt.txt" if self.split == "train" else "wider_face_val_bbx_gt.txt"
+        filepath = os.path.join(self.root, "wider_face_split", filename)
+
+        with open(filepath) as f:
+            lines = f.readlines()
+            file_name_line, num_boxes_line, box_annotation_line = True, False, False
+            num_boxes, box_counter = 0, 0
+            labels = []
+            for line in lines:
+                line = line.rstrip()
+                if file_name_line:
+                    img_path = os.path.join(self.root, "WIDER_" + self.split, "images", line)
+                    img_path = abspath(expanduser(img_path))
+                    file_name_line = False
+                    num_boxes_line = True
+                elif num_boxes_line:
+                    num_boxes = int(line)
+                    num_boxes_line = False
+                    box_annotation_line = True
+                elif box_annotation_line:
+                    box_counter += 1
+                    line_split = line.split(" ")
+                    line_values = [int(x) for x in line_split]
+                    labels.append(line_values)
+                    if box_counter >= num_boxes:
+                        box_annotation_line = False
+                        file_name_line = True
+                        labels_tensor = torch.tensor(labels)
+                        self.img_info.append(
+                            {
+                                "img_path": img_path,
+                                "annotations": {
+                                    "bbox": labels_tensor[:, 0:4].clone(),  # x, y, width, height
+                                    "blur": labels_tensor[:, 4].clone(),
+                                    "expression": labels_tensor[:, 5].clone(),
+                                    "illumination": labels_tensor[:, 6].clone(),
+                                    "occlusion": labels_tensor[:, 7].clone(),
+                                    "pose": labels_tensor[:, 8].clone(),
+                                    "invalid": labels_tensor[:, 9].clone(),
+                                },
+                            }
+                        )
+                        box_counter = 0
+                        labels.clear()
+                else:
+                    raise RuntimeError(f"Error parsing annotation file {filepath}")
+
+    def parse_test_annotations_file(self) -> None:
+        filepath = os.path.join(self.root, "wider_face_split", "wider_face_test_filelist.txt")
+        filepath = abspath(expanduser(filepath))
+        with open(filepath) as f:
+            lines = f.readlines()
+            for line in lines:
+                line = line.rstrip()
+                img_path = os.path.join(self.root, "WIDER_test", "images", line)
+                img_path = abspath(expanduser(img_path))
+                self.img_info.append({"img_path": img_path})
+
+    def _check_integrity(self) -> bool:
+        # Allow original archive to be deleted (zip). Only need the extracted images
+        all_files = self.FILE_LIST.copy()
+        all_files.append(self.ANNOTATIONS_FILE)
+        for (_, md5, filename) in all_files:
+            file, ext = os.path.splitext(filename)
+            extracted_dir = os.path.join(self.root, file)
+            if not os.path.exists(extracted_dir):
+                return False
+        return True
+
+    def download(self) -> None:
+        if self._check_integrity():
+            print("Files already downloaded and verified")
+            return
+
+        # download and extract image data
+        for (file_id, md5, filename) in self.FILE_LIST:
+            download_file_from_google_drive(file_id, self.root, filename, md5)
+            filepath = os.path.join(self.root, filename)
+            extract_archive(filepath)
+
+        # download and extract annotation files
+        download_and_extract_archive(
+            url=self.ANNOTATIONS_FILE[0], download_root=self.root, md5=self.ANNOTATIONS_FILE[1]
+        )
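A hypothetical usage sketch for `WIDERFace`, assuming the archives have already been downloaded and extracted under `<root>/widerface` as described in the docstring (`root="data"` below is illustrative):

    ds = WIDERFace(root="data", split="train")   # assumes WIDERFace from the file above is in scope
    img, target = ds[0]
    # target is a dict of per-face tensors; each bbox row is (x, y, width, height)
    print(len(ds), img.size, target["bbox"].shape)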

+ 92 - 0
libs/vision_libs/extension.py

@@ -0,0 +1,92 @@
+import os
+import sys
+
+import torch
+
+from ._internally_replaced_utils import _get_extension_path
+
+
+_HAS_OPS = False
+
+
+def _has_ops():
+    return False
+
+
+try:
+    # On Windows, Python 3.8.x has the `os.add_dll_directory` call,
+    # which is used to configure the DLL search path.
+    # To find CUDA-related DLLs we need to make sure the
+    # conda environment/bin path is configured. Please take a look:
+    # https://stackoverflow.com/questions/59330863/cant-import-dll-module-in-python
+    # Please note: if some path can't be added using add_dll_directory, we simply ignore it
+    if os.name == "nt" and sys.version_info < (3, 9):
+        env_path = os.environ["PATH"]
+        path_arr = env_path.split(";")
+        for path in path_arr:
+            if os.path.exists(path):
+                try:
+                    os.add_dll_directory(path)  # type: ignore[attr-defined]
+                except Exception:
+                    pass
+
+    lib_path = _get_extension_path("_C")
+    torch.ops.load_library(lib_path)
+    _HAS_OPS = True
+
+    def _has_ops():  # noqa: F811
+        return True
+
+except (ImportError, OSError):
+    pass
+
+
+def _assert_has_ops():
+    if not _has_ops():
+        raise RuntimeError(
+            "Couldn't load custom C++ ops. This can happen if your PyTorch and "
+            "torchvision versions are incompatible, or if you had errors while compiling "
+            "torchvision from source. For further information on the compatible versions, check "
+            "https://github.com/pytorch/vision#installation for the compatibility matrix. "
+            "Please check your PyTorch version with torch.__version__ and your torchvision "
+            "version with torchvision.__version__ and verify if they are compatible, and if not "
+            "please reinstall torchvision so that it matches your PyTorch install."
+        )
+
+
+def _check_cuda_version():
+    """
+    Make sure that CUDA versions match between the pytorch install and torchvision install
+    """
+    if not _HAS_OPS:
+        return -1
+    from torch.version import cuda as torch_version_cuda
+
+    _version = torch.ops.torchvision._cuda_version()
+    if _version != -1 and torch_version_cuda is not None:
+        tv_version = str(_version)
+        if int(tv_version) < 10000:
+            tv_major = int(tv_version[0])
+            tv_minor = int(tv_version[2])
+        else:
+            tv_major = int(tv_version[0:2])
+            tv_minor = int(tv_version[3])
+        t_version = torch_version_cuda.split(".")
+        t_major = int(t_version[0])
+        t_minor = int(t_version[1])
+        if t_major != tv_major:
+            raise RuntimeError(
+                "Detected that PyTorch and torchvision were compiled with different CUDA major versions. "
+                f"PyTorch has CUDA Version={t_major}.{t_minor} and torchvision has "
+                f"CUDA Version={tv_major}.{tv_minor}. "
+                "Please reinstall the torchvision that matches your PyTorch install."
+            )
+    return _version
+
+
+def _load_library(lib_name):
+    lib_path = _get_extension_path(lib_name)
+    torch.ops.load_library(lib_path)
+
+
+_check_cuda_version()
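A minimal sketch of how the helpers above are typically consumed (the import path reflects this vendored copy and is an assumption): probe whether the compiled `_C` extension loaded before relying on custom C++/CUDA ops:

    from libs.vision_libs.extension import _assert_has_ops, _check_cuda_version, _has_ops

    if _has_ops():
        # returns the CUDA version the ops were built with, or -1 for a CPU-only build
        print("custom ops loaded, CUDA build:", _check_cuda_version())
    else:
        # _assert_has_ops() would raise the descriptive RuntimeError defined above
        print("compiled extension not available; C++/CUDA ops cannot be used")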

BIN
libs/vision_libs/image.pyd


+ 69 - 0
libs/vision_libs/io/__init__.py

@@ -0,0 +1,69 @@
+from typing import Any, Dict, Iterator
+
+import torch
+
+from ..utils import _log_api_usage_once
+
+try:
+    from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER
+except ModuleNotFoundError:
+    _HAS_GPU_VIDEO_DECODER = False
+
+from ._video_opt import (
+    _HAS_VIDEO_OPT,
+    _probe_video_from_file,
+    _probe_video_from_memory,
+    _read_video_from_file,
+    _read_video_from_memory,
+    _read_video_timestamps_from_file,
+    _read_video_timestamps_from_memory,
+    Timebase,
+    VideoMetaData,
+)
+from .image import (
+    decode_image,
+    decode_jpeg,
+    decode_png,
+    encode_jpeg,
+    encode_png,
+    ImageReadMode,
+    read_file,
+    read_image,
+    write_file,
+    write_jpeg,
+    write_png,
+)
+from .video import read_video, read_video_timestamps, write_video
+from .video_reader import VideoReader
+
+
+__all__ = [
+    "write_video",
+    "read_video",
+    "read_video_timestamps",
+    "_read_video_from_file",
+    "_read_video_timestamps_from_file",
+    "_probe_video_from_file",
+    "_read_video_from_memory",
+    "_read_video_timestamps_from_memory",
+    "_probe_video_from_memory",
+    "_HAS_VIDEO_OPT",
+    "_HAS_GPU_VIDEO_DECODER",
+    "_read_video_clip_from_memory",
+    "_read_video_meta_data",
+    "VideoMetaData",
+    "Timebase",
+    "ImageReadMode",
+    "decode_image",
+    "decode_jpeg",
+    "decode_png",
+    "encode_jpeg",
+    "encode_png",
+    "read_file",
+    "read_image",
+    "write_file",
+    "write_jpeg",
+    "write_png",
+    "Video",
+    "VideoReader",
+]
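A minimal sketch using the image helpers re-exported above (the file names are hypothetical and the import path assumes this vendored copy):

    from libs.vision_libs.io import ImageReadMode, read_image, write_jpeg

    img = read_image("example.jpg", mode=ImageReadMode.RGB)   # uint8 tensor of shape [3, H, W]
    print(img.shape, img.dtype)
    write_jpeg(img, "example_copy.jpg", quality=90)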

+ 8 - 0
libs/vision_libs/io/_load_gpu_decoder.py

@@ -0,0 +1,8 @@
+from ..extension import _load_library
+
+
+try:
+    _load_library("Decoder")
+    _HAS_GPU_VIDEO_DECODER = True
+except (ImportError, OSError):
+    _HAS_GPU_VIDEO_DECODER = False
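A minimal sketch of how the `_HAS_GPU_VIDEO_DECODER` flag is meant to be consumed (the fallback messages are illustrative only):

    from libs.vision_libs.io import _HAS_GPU_VIDEO_DECODER

    if _HAS_GPU_VIDEO_DECODER:
        print("GPU video decoding is available")
    else:
        print("GPU decoder extension not built; falling back to CPU decoding")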

+ 512 - 0
libs/vision_libs/io/_video_opt.py

@@ -0,0 +1,512 @@
+import math
+import warnings
+from fractions import Fraction
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+
+from ..extension import _load_library
+
+
+try:
+    _load_library("video_reader")
+    _HAS_VIDEO_OPT = True
+except (ImportError, OSError):
+    _HAS_VIDEO_OPT = False
+
+default_timebase = Fraction(0, 1)
+
+
+# simple class for torch scripting
+# the complex Fraction class from fractions module is not scriptable
+class Timebase:
+    __annotations__ = {"numerator": int, "denominator": int}
+    __slots__ = ["numerator", "denominator"]
+
+    def __init__(
+        self,
+        numerator: int,
+        denominator: int,
+    ) -> None:
+        self.numerator = numerator
+        self.denominator = denominator
+
+
+class VideoMetaData:
+    __annotations__ = {
+        "has_video": bool,
+        "video_timebase": Timebase,
+        "video_duration": float,
+        "video_fps": float,
+        "has_audio": bool,
+        "audio_timebase": Timebase,
+        "audio_duration": float,
+        "audio_sample_rate": float,
+    }
+    __slots__ = [
+        "has_video",
+        "video_timebase",
+        "video_duration",
+        "video_fps",
+        "has_audio",
+        "audio_timebase",
+        "audio_duration",
+        "audio_sample_rate",
+    ]
+
+    def __init__(self) -> None:
+        self.has_video = False
+        self.video_timebase = Timebase(0, 1)
+        self.video_duration = 0.0
+        self.video_fps = 0.0
+        self.has_audio = False
+        self.audio_timebase = Timebase(0, 1)
+        self.audio_duration = 0.0
+        self.audio_sample_rate = 0.0
+
+
+def _validate_pts(pts_range: Tuple[int, int]) -> None:
+
+    if pts_range[0] > pts_range[1] > 0:
+        raise ValueError(
+            f"Start pts should not be larger than end pts, got start pts: {pts_range[0]} and end pts: {pts_range[1]}"
+        )
+
+
+def _fill_info(
+    vtimebase: torch.Tensor,
+    vfps: torch.Tensor,
+    vduration: torch.Tensor,
+    atimebase: torch.Tensor,
+    asample_rate: torch.Tensor,
+    aduration: torch.Tensor,
+) -> VideoMetaData:
+    """
+    Build and populate a VideoMetaData struct with info about the video
+    """
+    meta = VideoMetaData()
+    if vtimebase.numel() > 0:
+        meta.video_timebase = Timebase(int(vtimebase[0].item()), int(vtimebase[1].item()))
+        timebase = vtimebase[0].item() / float(vtimebase[1].item())
+        if vduration.numel() > 0:
+            meta.has_video = True
+            meta.video_duration = float(vduration.item()) * timebase
+    if vfps.numel() > 0:
+        meta.video_fps = float(vfps.item())
+    if atimebase.numel() > 0:
+        meta.audio_timebase = Timebase(int(atimebase[0].item()), int(atimebase[1].item()))
+        timebase = atimebase[0].item() / float(atimebase[1].item())
+        if aduration.numel() > 0:
+            meta.has_audio = True
+            meta.audio_duration = float(aduration.item()) * timebase
+    if asample_rate.numel() > 0:
+        meta.audio_sample_rate = float(asample_rate.item())
+
+    return meta
+
+
+def _align_audio_frames(
+    aframes: torch.Tensor, aframe_pts: torch.Tensor, audio_pts_range: Tuple[int, int]
+) -> torch.Tensor:
+    start, end = aframe_pts[0], aframe_pts[-1]
+    num_samples = aframes.size(0)
+    step_per_aframe = float(end - start + 1) / float(num_samples)
+    s_idx = 0
+    e_idx = num_samples
+    if start < audio_pts_range[0]:
+        s_idx = int((audio_pts_range[0] - start) / step_per_aframe)
+    if audio_pts_range[1] != -1 and end > audio_pts_range[1]:
+        e_idx = int((audio_pts_range[1] - end) / step_per_aframe)
+    return aframes[s_idx:e_idx, :]
+
+
+def _read_video_from_file(
+    filename: str,
+    seek_frame_margin: float = 0.25,
+    read_video_stream: bool = True,
+    video_width: int = 0,
+    video_height: int = 0,
+    video_min_dimension: int = 0,
+    video_max_dimension: int = 0,
+    video_pts_range: Tuple[int, int] = (0, -1),
+    video_timebase: Fraction = default_timebase,
+    read_audio_stream: bool = True,
+    audio_samples: int = 0,
+    audio_channels: int = 0,
+    audio_pts_range: Tuple[int, int] = (0, -1),
+    audio_timebase: Fraction = default_timebase,
+) -> Tuple[torch.Tensor, torch.Tensor, VideoMetaData]:
+    """
+    Reads a video from a file, returning both the video frames and the audio frames
+
+    Args:
+    filename (str): path to the video file
+    seek_frame_margin (double, optional): seeking frame in the stream is imprecise. Thus,
+        when video_start_pts is specified, we seek the pts earlier by seek_frame_margin seconds
+    read_video_stream (bool, optional): whether to read the video stream
+    video_width/video_height/video_min_dimension/video_max_dimension (int): together decide
+        the size of decoded frames:
+
+            - When video_width = 0, video_height = 0, video_min_dimension = 0,
+                and video_max_dimension = 0, keep the original frame resolution
+            - When video_width = 0, video_height = 0, video_min_dimension != 0,
+                and video_max_dimension = 0, keep the aspect ratio and resize the
+                frame so that shorter edge size is video_min_dimension
+            - When video_width = 0, video_height = 0, video_min_dimension = 0,
+                and video_max_dimension != 0, keep the aspect ratio and resize
+                the frame so that longer edge size is video_max_dimension
+            - When video_width = 0, video_height = 0, video_min_dimension != 0,
+                and video_max_dimension != 0, resize the frame so that shorter
+                edge size is video_min_dimension, and longer edge size is
+                video_max_dimension. The aspect ratio may not be preserved
+            - When video_width = 0, video_height != 0, video_min_dimension = 0,
+                and video_max_dimension = 0, keep the aspect ratio and resize
+                the frame so that frame video_height is $video_height
+            - When video_width != 0, video_height == 0, video_min_dimension = 0,
+                and video_max_dimension = 0, keep the aspect ratio and resize
+                the frame so that frame video_width is $video_width
+            - When video_width != 0, video_height != 0, video_min_dimension = 0,
+                and video_max_dimension = 0, resize the frame so that frame
+                video_width and  video_height are set to $video_width and
+                $video_height, respectively
+    video_pts_range (list(int), optional): the start and end presentation timestamp of video stream
+    video_timebase (Fraction, optional): a Fraction rational number which denotes timebase in video stream
+    read_audio_stream (bool, optional): whether to read the audio stream
+    audio_samples (int, optional): audio sampling rate
+    audio_channels (int, optional): audio channels
+    audio_pts_range (list(int), optional): the start and end presentation timestamp of audio stream
+    audio_timebase (Fraction, optional): a Fraction rational number which denotes time base in audio stream
+
+    Returns
+        vframes (Tensor[T, H, W, C]): the `T` video frames
+        aframes (Tensor[L, K]): the audio frames, where `L` is the number of points and
+            `K` is the number of audio_channels
+        info (Dict): metadata for the video and audio. Can contain the fields video_fps (float)
+            and audio_fps (int)
+    """
+    _validate_pts(video_pts_range)
+    _validate_pts(audio_pts_range)
+
+    result = torch.ops.video_reader.read_video_from_file(
+        filename,
+        seek_frame_margin,
+        0,  # getPtsOnly
+        read_video_stream,
+        video_width,
+        video_height,
+        video_min_dimension,
+        video_max_dimension,
+        video_pts_range[0],
+        video_pts_range[1],
+        video_timebase.numerator,
+        video_timebase.denominator,
+        read_audio_stream,
+        audio_samples,
+        audio_channels,
+        audio_pts_range[0],
+        audio_pts_range[1],
+        audio_timebase.numerator,
+        audio_timebase.denominator,
+    )
+    vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration = result
+    info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
+    if aframes.numel() > 0:
+        # when audio stream is found
+        aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)
+    return vframes, aframes, info
+
+
+def _read_video_timestamps_from_file(filename: str) -> Tuple[List[int], List[int], VideoMetaData]:
+    """
+    Decode all video and audio frames in the video. Only the pts
+    (presentation timestamps) are returned. The actual frame pixel data is not
+    copied. Thus, it is much faster than read_video(...)
+    """
+    result = torch.ops.video_reader.read_video_from_file(
+        filename,
+        0,  # seek_frame_margin
+        1,  # getPtsOnly
+        1,  # read_video_stream
+        0,  # video_width
+        0,  # video_height
+        0,  # video_min_dimension
+        0,  # video_max_dimension
+        0,  # video_start_pts
+        -1,  # video_end_pts
+        0,  # video_timebase_num
+        1,  # video_timebase_den
+        1,  # read_audio_stream
+        0,  # audio_samples
+        0,  # audio_channels
+        0,  # audio_start_pts
+        -1,  # audio_end_pts
+        0,  # audio_timebase_num
+        1,  # audio_timebase_den
+    )
+    _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, asample_rate, aduration = result
+    info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
+
+    vframe_pts = vframe_pts.numpy().tolist()
+    aframe_pts = aframe_pts.numpy().tolist()
+    return vframe_pts, aframe_pts, info
+
+
+def _probe_video_from_file(filename: str) -> VideoMetaData:
+    """
+    Probe a video file and return VideoMetaData with info about the video
+    """
+    result = torch.ops.video_reader.probe_video_from_file(filename)
+    vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
+    info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
+    return info
+
+
+def _read_video_from_memory(
+    video_data: torch.Tensor,
+    seek_frame_margin: float = 0.25,
+    read_video_stream: int = 1,
+    video_width: int = 0,
+    video_height: int = 0,
+    video_min_dimension: int = 0,
+    video_max_dimension: int = 0,
+    video_pts_range: Tuple[int, int] = (0, -1),
+    video_timebase_numerator: int = 0,
+    video_timebase_denominator: int = 1,
+    read_audio_stream: int = 1,
+    audio_samples: int = 0,
+    audio_channels: int = 0,
+    audio_pts_range: Tuple[int, int] = (0, -1),
+    audio_timebase_numerator: int = 0,
+    audio_timebase_denominator: int = 1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Reads a video from memory, returning both the video frames and the audio frames.
+    This function is torchscriptable.
+
+    Args:
+    video_data (data type could be 1) torch.Tensor, dtype=torch.int8 or 2) python bytes):
+        compressed video content stored in either 1) torch.Tensor 2) python bytes
+    seek_frame_margin (double, optional): seeking frame in the stream is imprecise.
+        Thus, when video_start_pts is specified, we seek the pts earlier by seek_frame_margin seconds
+    read_video_stream (int, optional): whether to read the video stream. If yes, set to 1. Otherwise, 0
+    video_width/video_height/video_min_dimension/video_max_dimension (int): together decide
+        the size of decoded frames:
+
+            - When video_width = 0, video_height = 0, video_min_dimension = 0,
+                and video_max_dimension = 0, keep the original frame resolution
+            - When video_width = 0, video_height = 0, video_min_dimension != 0,
+                and video_max_dimension = 0, keep the aspect ratio and resize the
+                frame so that shorter edge size is video_min_dimension
+            - When video_width = 0, video_height = 0, video_min_dimension = 0,
+                and video_max_dimension != 0, keep the aspect ratio and resize
+                the frame so that longer edge size is video_max_dimension
+            - When video_width = 0, video_height = 0, video_min_dimension != 0,
+                and video_max_dimension != 0, resize the frame so that shorter
+                edge size is video_min_dimension, and longer edge size is
+                video_max_dimension. The aspect ratio may not be preserved
+            - When video_width = 0, video_height != 0, video_min_dimension = 0,
+                and video_max_dimension = 0, keep the aspect ratio and resize
+                the frame so that frame video_height is $video_height
+            - When video_width != 0, video_height == 0, video_min_dimension = 0,
+                and video_max_dimension = 0, keep the aspect ratio and resize
+                the frame so that frame video_width is $video_width
+            - When video_width != 0, video_height != 0, video_min_dimension = 0,
+                and video_max_dimension = 0, resize the frame so that frame
+                video_width and  video_height are set to $video_width and
+                $video_height, respectively
+    video_pts_range (list(int), optional): the start and end presentation timestamp of video stream
+    video_timebase_numerator / video_timebase_denominator (float, optional): a rational
+        number which denotes timebase in video stream
+    read_audio_stream (int, optional): whether to read the audio stream. If yes, set to 1. Otherwise, 0
+    audio_samples (int, optional): audio sampling rate
+    audio_channels (int, optional): audio channels
+    audio_pts_range (list(int), optional): the start and end presentation timestamp of audio stream
+    audio_timebase_numerator / audio_timebase_denominator (float, optional):
+        a rational number which denotes time base in audio stream
+
+    Returns:
+        vframes (Tensor[T, H, W, C]): the `T` video frames
+        aframes (Tensor[L, K]): the audio frames, where `L` is the number of points and
+            `K` is the number of channels
+    """
+
+    _validate_pts(video_pts_range)
+    _validate_pts(audio_pts_range)
+
+    if not isinstance(video_data, torch.Tensor):
+        with warnings.catch_warnings():
+            # Ignore the warning because we actually don't modify the buffer in this function
+            warnings.filterwarnings("ignore", message="The given buffer is not writable")
+            video_data = torch.frombuffer(video_data, dtype=torch.uint8)
+
+    result = torch.ops.video_reader.read_video_from_memory(
+        video_data,
+        seek_frame_margin,
+        0,  # getPtsOnly
+        read_video_stream,
+        video_width,
+        video_height,
+        video_min_dimension,
+        video_max_dimension,
+        video_pts_range[0],
+        video_pts_range[1],
+        video_timebase_numerator,
+        video_timebase_denominator,
+        read_audio_stream,
+        audio_samples,
+        audio_channels,
+        audio_pts_range[0],
+        audio_pts_range[1],
+        audio_timebase_numerator,
+        audio_timebase_denominator,
+    )
+
+    vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration = result
+
+    if aframes.numel() > 0:
+        # when audio stream is found
+        aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)
+
+    return vframes, aframes
+
+
+def _read_video_timestamps_from_memory(
+    video_data: torch.Tensor,
+) -> Tuple[List[int], List[int], VideoMetaData]:
+    """
+    Decode all frames in the video. Only pts (presentation timestamp) is returned.
+    The actual frame pixel data is not copied. Thus, read_video_timestamps(...)
+    is much faster than read_video(...)
+    """
+    if not isinstance(video_data, torch.Tensor):
+        with warnings.catch_warnings():
+            # Ignore the warning because we actually don't modify the buffer in this function
+            warnings.filterwarnings("ignore", message="The given buffer is not writable")
+            video_data = torch.frombuffer(video_data, dtype=torch.uint8)
+    result = torch.ops.video_reader.read_video_from_memory(
+        video_data,
+        0,  # seek_frame_margin
+        1,  # getPtsOnly
+        1,  # read_video_stream
+        0,  # video_width
+        0,  # video_height
+        0,  # video_min_dimension
+        0,  # video_max_dimension
+        0,  # video_start_pts
+        -1,  # video_end_pts
+        0,  # video_timebase_num
+        1,  # video_timebase_den
+        1,  # read_audio_stream
+        0,  # audio_samples
+        0,  # audio_channels
+        0,  # audio_start_pts
+        -1,  # audio_end_pts
+        0,  # audio_timebase_num
+        1,  # audio_timebase_den
+    )
+    _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, asample_rate, aduration = result
+    info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
+
+    vframe_pts = vframe_pts.numpy().tolist()
+    aframe_pts = aframe_pts.numpy().tolist()
+    return vframe_pts, aframe_pts, info
+
+
+def _probe_video_from_memory(
+    video_data: torch.Tensor,
+) -> VideoMetaData:
+    """
+    Probe a video in memory and return VideoMetaData with info about the video.
+    This function is torchscriptable.
+    """
+    if not isinstance(video_data, torch.Tensor):
+        with warnings.catch_warnings():
+            # Ignore the warning because we actually don't modify the buffer in this function
+            warnings.filterwarnings("ignore", message="The given buffer is not writable")
+            video_data = torch.frombuffer(video_data, dtype=torch.uint8)
+    result = torch.ops.video_reader.probe_video_from_memory(video_data)
+    vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
+    info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
+    return info
+
+
+def _read_video(
+    filename: str,
+    start_pts: Union[float, Fraction] = 0,
+    end_pts: Optional[Union[float, Fraction]] = None,
+    pts_unit: str = "pts",
+) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]:
+    if end_pts is None:
+        end_pts = float("inf")
+
+    if pts_unit == "pts":
+        warnings.warn(
+            "The pts_unit 'pts' gives wrong results and will be removed in a "
+            + "follow-up version. Please use pts_unit 'sec'."
+        )
+
+    info = _probe_video_from_file(filename)
+
+    has_video = info.has_video
+    has_audio = info.has_audio
+
+    def get_pts(time_base):
+        start_offset = start_pts
+        end_offset = end_pts
+        if pts_unit == "sec":
+            start_offset = int(math.floor(start_pts * (1 / time_base)))
+            if end_offset != float("inf"):
+                end_offset = int(math.ceil(end_pts * (1 / time_base)))
+        if end_offset == float("inf"):
+            end_offset = -1
+        return start_offset, end_offset
+
+    video_pts_range = (0, -1)
+    video_timebase = default_timebase
+    if has_video:
+        video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
+        video_pts_range = get_pts(video_timebase)
+
+    audio_pts_range = (0, -1)
+    audio_timebase = default_timebase
+    if has_audio:
+        audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator)
+        audio_pts_range = get_pts(audio_timebase)
+
+    vframes, aframes, info = _read_video_from_file(
+        filename,
+        read_video_stream=True,
+        video_pts_range=video_pts_range,
+        video_timebase=video_timebase,
+        read_audio_stream=True,
+        audio_pts_range=audio_pts_range,
+        audio_timebase=audio_timebase,
+    )
+    _info = {}
+    if has_video:
+        _info["video_fps"] = info.video_fps
+    if has_audio:
+        _info["audio_fps"] = info.audio_sample_rate
+
+    return vframes, aframes, _info
+
+
+def _read_video_timestamps(
+    filename: str, pts_unit: str = "pts"
+) -> Tuple[Union[List[int], List[Fraction]], Optional[float]]:
+    if pts_unit == "pts":
+        warnings.warn(
+            "The pts_unit 'pts' gives wrong results and will be removed in a "
+            + "follow-up version. Please use pts_unit 'sec'."
+        )
+
+    pts: Union[List[int], List[Fraction]]
+    pts, _, info = _read_video_timestamps_from_file(filename)
+
+    if pts_unit == "sec":
+        video_time_base = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
+        pts = [x * video_time_base for x in pts]
+
+    video_fps = info.video_fps if info.has_video else None
+
+    return pts, video_fps
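
The eight resizing rules documented for ``_read_video_from_memory`` above are easiest to check against a small helper. The sketch below is illustrative only (``decoded_frame_size`` is a hypothetical name, not part of this diff) and mirrors the documented cases without reproducing the C++ decoder's exact rounding::

    def decoded_frame_size(orig_h, orig_w, video_width=0, video_height=0,
                           video_min_dimension=0, video_max_dimension=0):
        """Return the (height, width) of decoded frames under the documented rules."""
        short_edge = min(orig_h, orig_w)
        long_edge = max(orig_h, orig_w)
        if video_width == 0 and video_height == 0:
            if video_min_dimension == 0 and video_max_dimension == 0:
                return orig_h, orig_w                           # keep original resolution
            if video_min_dimension != 0 and video_max_dimension == 0:
                scale = video_min_dimension / short_edge        # shorter edge -> min_dimension
                return round(orig_h * scale), round(orig_w * scale)
            if video_min_dimension == 0 and video_max_dimension != 0:
                scale = video_max_dimension / long_edge         # longer edge -> max_dimension
                return round(orig_h * scale), round(orig_w * scale)
            # both set: shorter edge -> min, longer edge -> max; aspect ratio may change
            if orig_h <= orig_w:
                return video_min_dimension, video_max_dimension
            return video_max_dimension, video_min_dimension
        if video_width == 0:                                    # only video_height given
            return video_height, round(orig_w * video_height / orig_h)
        if video_height == 0:                                   # only video_width given
            return round(orig_h * video_width / orig_w), video_width
        return video_height, video_width                        # both given: exact size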

+ 264 - 0
libs/vision_libs/io/image.py

@@ -0,0 +1,264 @@
+from enum import Enum
+from warnings import warn
+
+import torch
+
+from ..extension import _load_library
+from ..utils import _log_api_usage_once
+
+
+try:
+    _load_library("image")
+except (ImportError, OSError) as e:
+    warn(
+        f"Failed to load image Python extension: '{e}'"
+        f"If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. "
+        f"Otherwise, there might be something wrong with your environment. "
+        f"Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?"
+    )
+
+
+class ImageReadMode(Enum):
+    """
+    Support for various modes while reading images.
+
+    Use ``ImageReadMode.UNCHANGED`` for loading the image as-is,
+    ``ImageReadMode.GRAY`` for converting to grayscale,
+    ``ImageReadMode.GRAY_ALPHA`` for grayscale with transparency,
+    ``ImageReadMode.RGB`` for RGB and ``ImageReadMode.RGB_ALPHA`` for
+    RGB with transparency.
+    """
+
+    UNCHANGED = 0
+    GRAY = 1
+    GRAY_ALPHA = 2
+    RGB = 3
+    RGB_ALPHA = 4
+
+
+def read_file(path: str) -> torch.Tensor:
+    """
+    Reads and outputs the bytes contents of a file as a uint8 Tensor
+    with one dimension.
+
+    Args:
+        path (str): the path to the file to be read
+
+    Returns:
+        data (Tensor)
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(read_file)
+    data = torch.ops.image.read_file(path)
+    return data
+
+
+def write_file(filename: str, data: torch.Tensor) -> None:
+    """
+    Writes the contents of an uint8 tensor with one dimension to a
+    file.
+
+    Args:
+        filename (str): the path to the file to be written
+        data (Tensor): the contents to be written to the output file
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(write_file)
+    torch.ops.image.write_file(filename, data)
+
+
+def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
+    """
+    Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor.
+    Optionally converts the image to the desired format.
+    The values of the output tensor are uint8 in [0, 255].
+
+    Args:
+        input (Tensor[1]): a one dimensional uint8 tensor containing
+            the raw bytes of the PNG image.
+        mode (ImageReadMode): the read mode used for optionally
+            converting the image. Default: ``ImageReadMode.UNCHANGED``.
+            See `ImageReadMode` class for more information on various
+            available modes.
+
+    Returns:
+        output (Tensor[image_channels, image_height, image_width])
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(decode_png)
+    output = torch.ops.image.decode_png(input, mode.value, False)
+    return output
+
+
+def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
+    """
+    Takes an input tensor in CHW layout and returns a buffer with the contents
+    of its corresponding PNG file.
+
+    Args:
+        input (Tensor[channels, image_height, image_width]): int8 image tensor of
+            ``c`` channels, where ``c`` must be 3 or 1.
+        compression_level (int): Compression factor for the resulting file, it must be a number
+            between 0 and 9. Default: 6
+
+    Returns:
+        Tensor[1]: A one dimensional int8 tensor that contains the raw bytes of the
+            PNG file.
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(encode_png)
+    output = torch.ops.image.encode_png(input, compression_level)
+    return output
+
+
+def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
+    """
+    Takes an input tensor in CHW layout (or HW in the case of grayscale images)
+    and saves it in a PNG file.
+
+    Args:
+        input (Tensor[channels, image_height, image_width]): int8 image tensor of
+            ``c`` channels, where ``c`` must be 1 or 3.
+        filename (str): Path to save the image.
+        compression_level (int): Compression factor for the resulting file, it must be a number
+            between 0 and 9. Default: 6
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(write_png)
+    output = encode_png(input, compression_level)
+    write_file(filename, output)
+
+
+def decode_jpeg(
+    input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, device: str = "cpu"
+) -> torch.Tensor:
+    """
+    Decodes a JPEG image into a 3 dimensional RGB or grayscale Tensor.
+    Optionally converts the image to the desired format.
+    The values of the output tensor are uint8 between 0 and 255.
+
+    Args:
+        input (Tensor[1]): a one dimensional uint8 tensor containing
+            the raw bytes of the JPEG image. This tensor must be on CPU,
+            regardless of the ``device`` parameter.
+        mode (ImageReadMode): the read mode used for optionally
+            converting the image. The supported modes are: ``ImageReadMode.UNCHANGED``,
+            ``ImageReadMode.GRAY`` and ``ImageReadMode.RGB``.
+            Default: ``ImageReadMode.UNCHANGED``.
+            See ``ImageReadMode`` class for more information on various
+            available modes.
+        device (str or torch.device): The device on which the decoded image will
+            be stored. If a cuda device is specified, the image will be decoded
+            with `nvjpeg <https://developer.nvidia.com/nvjpeg>`_. This is only
+            supported for CUDA version >= 10.1
+
+            .. betastatus:: device parameter
+
+            .. warning::
+                There is a memory leak in the nvjpeg library for CUDA versions < 11.6.
+                Make sure to rely on CUDA 11.6 or above before using ``device="cuda"``.
+
+    Returns:
+        output (Tensor[image_channels, image_height, image_width])
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(decode_jpeg)
+    device = torch.device(device)
+    if device.type == "cuda":
+        output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device)
+    else:
+        output = torch.ops.image.decode_jpeg(input, mode.value)
+    return output
+
+
+def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
+    """
+    Takes an input tensor in CHW layout and returns a buffer with the contents
+    of its corresponding JPEG file.
+
+    Args:
+        input (Tensor[channels, image_height, image_width]): int8 image tensor of
+            ``c`` channels, where ``c`` must be 1 or 3.
+        quality (int): Quality of the resulting JPEG file, it must be a number between
+            1 and 100. Default: 75
+
+    Returns:
+        output (Tensor[1]): A one dimensional int8 tensor that contains the raw bytes of the
+            JPEG file.
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(encode_jpeg)
+    if quality < 1 or quality > 100:
+        raise ValueError("Image quality should be a positive number between 1 and 100")
+
+    output = torch.ops.image.encode_jpeg(input, quality)
+    return output
+
+
+def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
+    """
+    Takes an input tensor in CHW layout and saves it in a JPEG file.
+
+    Args:
+        input (Tensor[channels, image_height, image_width]): int8 image tensor of ``c``
+            channels, where ``c`` must be 1 or 3.
+        filename (str): Path to save the image.
+        quality (int): Quality of the resulting JPEG file, it must be a number
+            between 1 and 100. Default: 75
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(write_jpeg)
+    output = encode_jpeg(input, quality)
+    write_file(filename, output)
+
+
+def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
+    """
+    Detects whether an image is a JPEG or PNG and performs the appropriate
+    operation to decode the image into a 3 dimensional RGB or grayscale Tensor.
+
+    Optionally converts the image to the desired format.
+    The values of the output tensor are uint8 in [0, 255].
+
+    Args:
+        input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
+            PNG or JPEG image.
+        mode (ImageReadMode): the read mode used for optionally converting the image.
+            Default: ``ImageReadMode.UNCHANGED``.
+            See ``ImageReadMode`` class for more information on various
+            available modes.
+
+    Returns:
+        output (Tensor[image_channels, image_height, image_width])
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(decode_image)
+    output = torch.ops.image.decode_image(input, mode.value)
+    return output
+
+
+def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
+    """
+    Reads a JPEG or PNG image into a 3 dimensional RGB or grayscale Tensor.
+    Optionally converts the image to the desired format.
+    The values of the output tensor are uint8 in [0, 255].
+
+    Args:
+        path (str): path of the JPEG or PNG image.
+        mode (ImageReadMode): the read mode used for optionally converting the image.
+            Default: ``ImageReadMode.UNCHANGED``.
+            See ``ImageReadMode`` class for more information on various
+            available modes.
+
+    Returns:
+        output (Tensor[image_channels, image_height, image_width])
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(read_image)
+    data = read_file(path)
+    return decode_image(data, mode)
+
+
+def _read_png_16(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
+    data = read_file(path)
+    return torch.ops.image.decode_png(data, mode.value, True)
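
For reference, the image routines added above compose into a simple decode/re-encode round trip. A minimal usage sketch, assuming the package is importable as ``torchvision.io`` (or equivalently the vendored ``libs.vision_libs.io``) and that an RGB JPEG exists at ``assets/dog1.jpg``::

    import torch
    from torchvision.io import (ImageReadMode, decode_jpeg, encode_png,
                                read_file, read_image, write_png)

    # decode a JPEG straight into an RGB CHW uint8 tensor
    img = read_image("assets/dog1.jpg", mode=ImageReadMode.RGB)   # Tensor[3, H, W]

    # equivalent two-step variant: read raw bytes, then decode
    raw = read_file("assets/dog1.jpg")                            # one-dimensional uint8 tensor
    assert torch.equal(img, decode_jpeg(raw, mode=ImageReadMode.RGB))

    # re-encode as PNG, either to disk or to an in-memory byte tensor
    write_png(img, "dog1_copy.png", compression_level=6)
    png_bytes = encode_png(img, compression_level=6)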

+ 415 - 0
libs/vision_libs/io/video.py

@@ -0,0 +1,415 @@
+import gc
+import math
+import os
+import re
+import warnings
+from fractions import Fraction
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..utils import _log_api_usage_once
+from . import _video_opt
+
+try:
+    import av
+
+    av.logging.set_level(av.logging.ERROR)
+    if not hasattr(av.video.frame.VideoFrame, "pict_type"):
+        av = ImportError(
+            """\
+Your version of PyAV is too old for the necessary video operations in torchvision.
+If you are on Python 3.5, you will have to build from source (the conda-forge
+packages are not up-to-date).  See
+https://github.com/mikeboers/PyAV#installation for instructions on how to
+install PyAV on your system.
+"""
+        )
+except ImportError:
+    av = ImportError(
+        """\
+PyAV is not installed, and is necessary for the video operations in torchvision.
+See https://github.com/mikeboers/PyAV#installation for instructions on how to
+install PyAV on your system.
+"""
+    )
+
+
+def _check_av_available() -> None:
+    if isinstance(av, Exception):
+        raise av
+
+
+def _av_available() -> bool:
+    return not isinstance(av, Exception)
+
+
+# PyAV has some reference cycles
+_CALLED_TIMES = 0
+_GC_COLLECTION_INTERVAL = 10
+
+
+def write_video(
+    filename: str,
+    video_array: torch.Tensor,
+    fps: float,
+    video_codec: str = "libx264",
+    options: Optional[Dict[str, Any]] = None,
+    audio_array: Optional[torch.Tensor] = None,
+    audio_fps: Optional[float] = None,
+    audio_codec: Optional[str] = None,
+    audio_options: Optional[Dict[str, Any]] = None,
+) -> None:
+    """
+    Writes a 4d tensor in [T, H, W, C] format to a video file
+
+    Args:
+        filename (str): path where the video will be saved
+        video_array (Tensor[T, H, W, C]): tensor containing the individual frames,
+            as a uint8 tensor in [T, H, W, C] format
+        fps (Number): video frames per second
+        video_codec (str): the name of the video codec, e.g. "libx264", "h264", etc.
+        options (Dict): dictionary containing options to be passed into the PyAV video stream
+        audio_array (Tensor[C, N]): tensor containing the audio, where C is the number of channels
+            and N is the number of samples
+        audio_fps (Number): audio sample rate, typically 44100 or 48000
+        audio_codec (str): the name of the audio codec, e.g. "mp3", "aac", etc.
+        audio_options (Dict): dictionary containing options to be passed into the PyAV audio stream
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(write_video)
+    _check_av_available()
+    video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy()
+
+    # PyAV does not accept frame rates with a fractional part and would
+    # throw an OverflowException for them, so round the fps first
+    if isinstance(fps, float):
+        fps = np.round(fps)
+
+    with av.open(filename, mode="w") as container:
+        stream = container.add_stream(video_codec, rate=fps)
+        stream.width = video_array.shape[2]
+        stream.height = video_array.shape[1]
+        stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
+        stream.options = options or {}
+
+        if audio_array is not None:
+            audio_format_dtypes = {
+                "dbl": "<f8",
+                "dblp": "<f8",
+                "flt": "<f4",
+                "fltp": "<f4",
+                "s16": "<i2",
+                "s16p": "<i2",
+                "s32": "<i4",
+                "s32p": "<i4",
+                "u8": "u1",
+                "u8p": "u1",
+            }
+            a_stream = container.add_stream(audio_codec, rate=audio_fps)
+            a_stream.options = audio_options or {}
+
+            num_channels = audio_array.shape[0]
+            audio_layout = "stereo" if num_channels > 1 else "mono"
+            audio_sample_fmt = container.streams.audio[0].format.name
+
+            format_dtype = np.dtype(audio_format_dtypes[audio_sample_fmt])
+            audio_array = torch.as_tensor(audio_array).numpy().astype(format_dtype)
+
+            frame = av.AudioFrame.from_ndarray(audio_array, format=audio_sample_fmt, layout=audio_layout)
+
+            frame.sample_rate = audio_fps
+
+            for packet in a_stream.encode(frame):
+                container.mux(packet)
+
+            for packet in a_stream.encode():
+                container.mux(packet)
+
+        for img in video_array:
+            frame = av.VideoFrame.from_ndarray(img, format="rgb24")
+            frame.pict_type = "NONE"
+            for packet in stream.encode(frame):
+                container.mux(packet)
+
+        # Flush stream
+        for packet in stream.encode():
+            container.mux(packet)
+
+
+def _read_from_stream(
+    container: "av.container.Container",
+    start_offset: float,
+    end_offset: float,
+    pts_unit: str,
+    stream: "av.stream.Stream",
+    stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]],
+) -> List["av.frame.Frame"]:
+    global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
+    _CALLED_TIMES += 1
+    if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
+        gc.collect()
+
+    if pts_unit == "sec":
+        # TODO: we should change all of this from ground up to simply take
+        # sec and convert to MS in C++
+        start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
+        if end_offset != float("inf"):
+            end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
+    else:
+        warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")
+
+    frames = {}
+    should_buffer = True
+    max_buffer_size = 5
+    if stream.type == "video":
+        # DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
+        # so need to buffer some extra frames to sort everything
+        # properly
+        extradata = stream.codec_context.extradata
+        # overly complicated way of finding if `divx_packed` is set, following
+        # https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
+        if extradata and b"DivX" in extradata:
+            # can't use regex directly because of some weird characters sometimes...
+            pos = extradata.find(b"DivX")
+            d = extradata[pos:]
+            o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d)
+            if o is None:
+                o = re.search(rb"DivX(\d+)b(\d+)(\w)", d)
+            if o is not None:
+                should_buffer = o.group(3) == b"p"
+    seek_offset = start_offset
+    # some files don't seek to the right location, so better be safe here
+    seek_offset = max(seek_offset - 1, 0)
+    if should_buffer:
+        # FIXME this is kind of a hack, but we will jump to the previous keyframe
+        # so this will be safe
+        seek_offset = max(seek_offset - max_buffer_size, 0)
+    try:
+        # TODO check if stream needs to always be the video stream here or not
+        container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
+    except av.AVError:
+        # TODO add some warnings in this case
+        # print("Corrupted file?", container.name)
+        return []
+    buffer_count = 0
+    try:
+        for _idx, frame in enumerate(container.decode(**stream_name)):
+            frames[frame.pts] = frame
+            if frame.pts >= end_offset:
+                if should_buffer and buffer_count < max_buffer_size:
+                    buffer_count += 1
+                    continue
+                break
+    except av.AVError:
+        # TODO add a warning
+        pass
+    # ensure that the results are sorted wrt the pts
+    result = [frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset]
+    if len(frames) > 0 and start_offset > 0 and start_offset not in frames:
+        # if there is no frame that exactly matches the pts of start_offset
+        # add the last frame smaller than start_offset, to guarantee that
+        # we will have all the necessary data. This is most useful for audio
+        preceding_frames = [i for i in frames if i < start_offset]
+        if len(preceding_frames) > 0:
+            first_frame_pts = max(preceding_frames)
+            result.insert(0, frames[first_frame_pts])
+    return result
+
+
+def _align_audio_frames(
+    aframes: torch.Tensor, audio_frames: List["av.frame.Frame"], ref_start: int, ref_end: float
+) -> torch.Tensor:
+    start, end = audio_frames[0].pts, audio_frames[-1].pts
+    total_aframes = aframes.shape[1]
+    step_per_aframe = (end - start + 1) / total_aframes
+    s_idx = 0
+    e_idx = total_aframes
+    if start < ref_start:
+        s_idx = int((ref_start - start) / step_per_aframe)
+    if end > ref_end:
+        e_idx = int((ref_end - end) / step_per_aframe)
+    return aframes[:, s_idx:e_idx]
+
+
+def read_video(
+    filename: str,
+    start_pts: Union[float, Fraction] = 0,
+    end_pts: Optional[Union[float, Fraction]] = None,
+    pts_unit: str = "pts",
+    output_format: str = "THWC",
+) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
+    """
+    Reads a video from a file, returning both the video frames and the audio frames
+
+    Args:
+        filename (str): path to the video file
+        start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
+            The start presentation time of the video
+        end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
+            The end presentation time
+        pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
+            either 'pts' or 'sec'. Defaults to 'pts'.
+        output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
+
+    Returns:
+        vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
+        aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
+        info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(read_video)
+
+    output_format = output_format.upper()
+    if output_format not in ("THWC", "TCHW"):
+        raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
+
+    from torchvision import get_video_backend
+
+    if not os.path.exists(filename):
+        raise RuntimeError(f"File not found: {filename}")
+
+    if get_video_backend() != "pyav":
+        vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
+    else:
+        _check_av_available()
+
+        if end_pts is None:
+            end_pts = float("inf")
+
+        if end_pts < start_pts:
+            raise ValueError(
+                f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
+            )
+
+        info = {}
+        video_frames = []
+        audio_frames = []
+        audio_timebase = _video_opt.default_timebase
+
+        try:
+            with av.open(filename, metadata_errors="ignore") as container:
+                if container.streams.audio:
+                    audio_timebase = container.streams.audio[0].time_base
+                if container.streams.video:
+                    video_frames = _read_from_stream(
+                        container,
+                        start_pts,
+                        end_pts,
+                        pts_unit,
+                        container.streams.video[0],
+                        {"video": 0},
+                    )
+                    video_fps = container.streams.video[0].average_rate
+                    # guard against potentially corrupted files
+                    if video_fps is not None:
+                        info["video_fps"] = float(video_fps)
+
+                if container.streams.audio:
+                    audio_frames = _read_from_stream(
+                        container,
+                        start_pts,
+                        end_pts,
+                        pts_unit,
+                        container.streams.audio[0],
+                        {"audio": 0},
+                    )
+                    info["audio_fps"] = container.streams.audio[0].rate
+
+        except av.AVError:
+            # TODO raise a warning?
+            pass
+
+        vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
+        aframes_list = [frame.to_ndarray() for frame in audio_frames]
+
+        if vframes_list:
+            vframes = torch.as_tensor(np.stack(vframes_list))
+        else:
+            vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
+
+        if aframes_list:
+            aframes = np.concatenate(aframes_list, 1)
+            aframes = torch.as_tensor(aframes)
+            if pts_unit == "sec":
+                start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
+                if end_pts != float("inf"):
+                    end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
+            aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
+        else:
+            aframes = torch.empty((1, 0), dtype=torch.float32)
+
+    if output_format == "TCHW":
+        # [T,H,W,C] --> [T,C,H,W]
+        vframes = vframes.permute(0, 3, 1, 2)
+
+    return vframes, aframes, info
+
+
+def _can_read_timestamps_from_packets(container: "av.container.Container") -> bool:
+    extradata = container.streams[0].codec_context.extradata
+    if extradata is None:
+        return False
+    if b"Lavc" in extradata:
+        return True
+    return False
+
+
+def _decode_video_timestamps(container: "av.container.Container") -> List[int]:
+    if _can_read_timestamps_from_packets(container):
+        # fast path
+        return [x.pts for x in container.demux(video=0) if x.pts is not None]
+    else:
+        return [x.pts for x in container.decode(video=0) if x.pts is not None]
+
+
+def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[int], Optional[float]]:
+    """
+    List the video frames timestamps.
+
+    Note that the function decodes the whole video frame-by-frame.
+
+    Args:
+        filename (str): path to the video file
+        pts_unit (str, optional): unit in which timestamp values will be returned
+            either 'pts' or 'sec'. Defaults to 'pts'.
+
+    Returns:
+        pts (List[int] if pts_unit = 'pts', List[Fraction] if pts_unit = 'sec'):
+            presentation timestamps for each one of the frames in the video.
+        video_fps (float, optional): the frame rate for the video
+
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(read_video_timestamps)
+    from torchvision import get_video_backend
+
+    if get_video_backend() != "pyav":
+        return _video_opt._read_video_timestamps(filename, pts_unit)
+
+    _check_av_available()
+
+    video_fps = None
+    pts = []
+
+    try:
+        with av.open(filename, metadata_errors="ignore") as container:
+            if container.streams.video:
+                video_stream = container.streams.video[0]
+                video_time_base = video_stream.time_base
+                try:
+                    pts = _decode_video_timestamps(container)
+                except av.AVError:
+                    warnings.warn(f"Failed decoding frames for file {filename}")
+                video_fps = float(video_stream.average_rate)
+    except av.AVError as e:
+        msg = f"Failed to open container for {filename}; Caught error: {e}"
+        warnings.warn(msg, RuntimeWarning)
+
+    pts.sort()
+
+    if pts_unit == "sec":
+        pts = [x * video_time_base for x in pts]
+
+    return pts, video_fps
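
The two public entry points above (``read_video`` and ``read_video_timestamps``), together with ``write_video``, can be exercised as below. A minimal sketch, assuming the package is importable as ``torchvision.io`` (or the vendored ``libs.vision_libs.io``) and a local ``video.mp4`` with an audio track; the file name and the 2–4 second window are purely illustrative::

    from torchvision.io import read_video, read_video_timestamps, write_video

    # decode frames (and audio) between 2 s and 4 s, returned as [T, C, H, W]
    vframes, aframes, info = read_video(
        "video.mp4", start_pts=2.0, end_pts=4.0, pts_unit="sec", output_format="TCHW"
    )
    print(vframes.shape, aframes.shape, info)   # info may contain video_fps / audio_fps

    # list per-frame presentation timestamps in seconds (decodes the whole file)
    pts, fps = read_video_timestamps("video.mp4", pts_unit="sec")

    # write the clip back out; write_video expects uint8 frames in [T, H, W, C]
    write_video("clip.mp4", vframes.permute(0, 2, 3, 1), fps=info["video_fps"])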

+ 286 - 0
libs/vision_libs/io/video_reader.py

@@ -0,0 +1,286 @@
+import io
+import warnings
+
+from typing import Any, Dict, Iterator
+
+import torch
+
+from ..utils import _log_api_usage_once
+
+from ._video_opt import _HAS_VIDEO_OPT
+
+if _HAS_VIDEO_OPT:
+
+    def _has_video_opt() -> bool:
+        return True
+
+else:
+
+    def _has_video_opt() -> bool:
+        return False
+
+
+try:
+    import av
+
+    av.logging.set_level(av.logging.ERROR)
+    if not hasattr(av.video.frame.VideoFrame, "pict_type"):
+        av = ImportError(
+            """\
+Your version of PyAV is too old for the necessary video operations in torchvision.
+If you are on Python 3.5, you will have to build from source (the conda-forge
+packages are not up-to-date).  See
+https://github.com/mikeboers/PyAV#installation for instructions on how to
+install PyAV on your system.
+"""
+        )
+except ImportError:
+    av = ImportError(
+        """\
+PyAV is not installed, and is necessary for the video operations in torchvision.
+See https://github.com/mikeboers/PyAV#installation for instructions on how to
+install PyAV on your system.
+"""
+    )
+
+
+class VideoReader:
+    """
+    Fine-grained video-reading API.
+    Supports frame-by-frame reading of various streams from a single video
+    container. Much like the previous video_reader API, it supports the following
+    backends: video_reader, pyav, and cuda.
+    Backends can be set via the `torchvision.set_video_backend` function.
+
+    .. betastatus:: VideoReader class
+
+    Example:
+        The following example creates a :mod:`VideoReader` object, seeks to the 2 s
+        point, and returns a single frame::
+
+            import torchvision
+            video_path = "path_to_a_test_video"
+            reader = torchvision.io.VideoReader(video_path, "video")
+            reader.seek(2.0)
+            frame = next(reader)
+
+        :mod:`VideoReader` implements the iterable API, which makes it suitable for
+        use in conjunction with :mod:`itertools` for more advanced reading.
+        As such, we can use a :mod:`VideoReader` instance inside for loops::
+
+            reader.seek(2)
+            for frame in reader:
+                frames.append(frame['data'])
+            # additionally, `seek` implements a fluent API, so we can do
+            for frame in reader.seek(2):
+                frames.append(frame['data'])
+
+        With :mod:`itertools`, we can read all frames between 2 and 5 seconds with the
+        following code::
+
+            for frame in itertools.takewhile(lambda x: x['pts'] <= 5, reader.seek(2)):
+                frames.append(frame['data'])
+
+        and similarly, reading 10 frames after the 2s timestamp can be achieved
+        as follows::
+
+            for frame in itertools.islice(reader.seek(2), 10):
+                frames.append(frame['data'])
+
+    .. note::
+
+        Each stream descriptor consists of two parts: stream type (e.g. 'video') and
+        a unique stream id (which is determined by the video encoding).
+        In this way, if the video container contains multiple
+        streams of the same type, users can access the one they want.
+        If only stream type is passed, the decoder auto-detects first stream of that type.
+
+    Args:
+        src (string, bytes object, or tensor): The media source.
+            If string-type, it must be a file path supported by FFMPEG.
+            If bytes, should be an in-memory representation of a file supported by FFMPEG.
+            If Tensor, it is interpreted internally as byte buffer.
+            It must be one-dimensional, of type ``torch.uint8``.
+
+        stream (string, optional): descriptor of the required stream, followed by the stream id,
+            in the format ``{stream_type}:{stream_id}``. Defaults to ``"video:0"``.
+            Currently available options include ``['video', 'audio']``
+
+        num_threads (int, optional): number of threads used by the codec to decode video.
+            Default value (0) enables multithreading with codec-dependent heuristic. The performance
+            will depend on the version of FFMPEG codecs supported.
+    """
+
+    def __init__(
+        self,
+        src: str,
+        stream: str = "video",
+        num_threads: int = 0,
+    ) -> None:
+        _log_api_usage_once(self)
+        from .. import get_video_backend
+
+        self.backend = get_video_backend()
+        if isinstance(src, str):
+            if not src:
+                raise ValueError("src cannot be empty")
+        elif isinstance(src, bytes):
+            if self.backend in ["cuda"]:
+                raise RuntimeError(
+                    "VideoReader cannot be initialized from bytes object when using cuda or pyav backend."
+                )
+            elif self.backend == "pyav":
+                src = io.BytesIO(src)
+            else:
+                with warnings.catch_warnings():
+                    # Ignore the warning because we actually don't modify the buffer in this function
+                    warnings.filterwarnings("ignore", message="The given buffer is not writable")
+                    src = torch.frombuffer(src, dtype=torch.uint8)
+        elif isinstance(src, torch.Tensor):
+            if self.backend in ["cuda", "pyav"]:
+                raise RuntimeError(
+                    "VideoReader cannot be initialized from Tensor object when using cuda or pyav backend."
+                )
+        else:
+            raise ValueError(f"src must be either string, Tensor or bytes object. Got {type(src)}")
+
+        if self.backend == "cuda":
+            device = torch.device("cuda")
+            self._c = torch.classes.torchvision.GPUDecoder(src, device)
+
+        elif self.backend == "video_reader":
+            if isinstance(src, str):
+                self._c = torch.classes.torchvision.Video(src, stream, num_threads)
+            elif isinstance(src, torch.Tensor):
+                self._c = torch.classes.torchvision.Video("", "", 0)
+                self._c.init_from_memory(src, stream, num_threads)
+
+        elif self.backend == "pyav":
+            self.container = av.open(src, metadata_errors="ignore")
+            # TODO: load metadata
+            stream_type = stream.split(":")[0]
+            stream_id = 0 if len(stream.split(":")) == 1 else int(stream.split(":")[1])
+            self.pyav_stream = {stream_type: stream_id}
+            self._c = self.container.decode(**self.pyav_stream)
+
+            # TODO: add extradata exception
+
+        else:
+            raise RuntimeError("Unknown video backend: {}".format(self.backend))
+
+    def __next__(self) -> Dict[str, Any]:
+        """Decodes and returns the next frame of the current stream.
+        Frames are encoded as a dict with mandatory
+        data and pts fields, where data is a tensor, and pts is a
+        presentation timestamp of the frame expressed in seconds
+        as a float.
+
+        Returns:
+            (dict): a dictionary containing the decoded frame (``data``)
+            and the corresponding timestamp (``pts``) in seconds
+
+        """
+        if self.backend == "cuda":
+            frame = self._c.next()
+            if frame.numel() == 0:
+                raise StopIteration
+            return {"data": frame, "pts": None}
+        elif self.backend == "video_reader":
+            frame, pts = self._c.next()
+        else:
+            try:
+                frame = next(self._c)
+                pts = float(frame.pts * frame.time_base)
+                if "video" in self.pyav_stream:
+                    frame = torch.tensor(frame.to_rgb().to_ndarray()).permute(2, 0, 1)
+                elif "audio" in self.pyav_stream:
+                    frame = torch.tensor(frame.to_ndarray()).permute(1, 0)
+                else:
+                    frame = None
+            except av.error.EOFError:
+                raise StopIteration
+
+        if frame.numel() == 0:
+            raise StopIteration
+
+        return {"data": frame, "pts": pts}
+
+    def __iter__(self) -> Iterator[Dict[str, Any]]:
+        return self
+
+    def seek(self, time_s: float, keyframes_only: bool = False) -> "VideoReader":
+        """Seek within current stream.
+
+        Args:
+            time_s (float): seek time in seconds
+            keyframes_only (bool): whether to seek only to keyframes
+
+        .. note::
+            Current implementation is the so-called precise seek. This
+            means that, following a seek, the call to :mod:`next()` will return the
+            frame with the exact timestamp if it exists, or
+            the first frame with a timestamp larger than ``time_s``.
+        """
+        if self.backend in ["cuda", "video_reader"]:
+            self._c.seek(time_s, keyframes_only)
+        else:
+            # handle special case as pyav doesn't catch it
+            if time_s < 0:
+                time_s = 0
+            temp_str = self.container.streams.get(**self.pyav_stream)[0]
+            offset = int(round(time_s / temp_str.time_base))
+            if not keyframes_only:
+                warnings.warn("Accurate seek is not implemented for pyav backend")
+            self.container.seek(offset, backward=True, any_frame=False, stream=temp_str)
+            self._c = self.container.decode(**self.pyav_stream)
+        return self
+
+    def get_metadata(self) -> Dict[str, Any]:
+        """Returns video metadata
+
+        Returns:
+            (dict): dictionary containing duration and frame rate for every stream
+        """
+        if self.backend == "pyav":
+            metadata = {}  # type:  Dict[str, Any]
+            for stream in self.container.streams:
+                if stream.type not in metadata:
+                    if stream.type == "video":
+                        rate_n = "fps"
+                    else:
+                        rate_n = "framerate"
+                    metadata[stream.type] = {rate_n: [], "duration": []}
+
+                rate = stream.average_rate if stream.average_rate is not None else stream.sample_rate
+
+                metadata[stream.type]["duration"].append(float(stream.duration * stream.time_base))
+                metadata[stream.type][rate_n].append(float(rate))
+            return metadata
+        return self._c.get_metadata()
+
+    def set_current_stream(self, stream: str) -> bool:
+        """Set current stream.
+        Explicitly define the stream we are operating on.
+
+        Args:
+            stream (string): descriptor of the required stream. Defaults to ``"video:0"``
+                Currently available stream types include ``['video', 'audio']``.
+                Each descriptor consists of two parts: stream type (e.g. 'video') and
+                a unique stream id (which is determined by the video encoding).
+                In this way, if the video container contains multiple
+                streams of the same type, users can access the one they want.
+                If only stream type is passed, the decoder auto-detects first stream
+                of that type and returns it.
+
+        Returns:
+            (bool): True on success, False otherwise
+        """
+        if self.backend == "cuda":
+            warnings.warn("GPU decoding only works with video stream.")
+        if self.backend == "pyav":
+            stream_type = stream.split(":")[0]
+            stream_id = 0 if len(stream.split(":")) == 1 else int(stream.split(":")[1])
+            self.pyav_stream = {stream_type: stream_id}
+            self._c = self.container.decode(**self.pyav_stream)
+            return True
+        return self._c.set_current_stream(stream)
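
Beyond the frame-iteration examples already given in the class docstring, the metadata and stream-selection methods can be combined as follows. A minimal sketch, assuming the ``pyav`` backend, that the package is importable as ``torchvision`` (or the vendored ``libs.vision_libs``), and a ``video.mp4`` containing both a video and an audio stream::

    import torchvision
    from torchvision.io import VideoReader

    torchvision.set_video_backend("pyav")
    reader = VideoReader("video.mp4", "video")

    meta = reader.get_metadata()
    # e.g. {'video': {'fps': [...], 'duration': [...]}, 'audio': {'framerate': [...], 'duration': [...]}}
    duration = meta["video"]["duration"][0]

    # switch to the first audio stream and collect samples from the 1 s mark onwards
    reader.set_current_stream("audio:0")
    audio_chunks = [frame["data"] for frame in reader.seek(1.0)]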

+ 23 - 0
libs/vision_libs/models/__init__.py

@@ -0,0 +1,23 @@
+from .alexnet import *
+from .convnext import *
+from .densenet import *
+from .efficientnet import *
+from .googlenet import *
+from .inception import *
+from .mnasnet import *
+from .mobilenet import *
+from .regnet import *
+from .resnet import *
+from .shufflenetv2 import *
+from .squeezenet import *
+from .vgg import *
+from .vision_transformer import *
+from .swin_transformer import *
+from .maxvit import *
+from . import detection, optical_flow, quantization, segmentation, video
+
+# The Weights and WeightsEnum are developer-facing utils that we make public for
+# downstream libs like torchgeo https://github.com/pytorch/vision/issues/7094
+# TODO: we could / should document them publicly, but it's not clear where, as
+# they're not intended for end users.
+from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models, Weights, WeightsEnum

+ 277 - 0
libs/vision_libs/models/_api.py

@@ -0,0 +1,277 @@
+import fnmatch
+import importlib
+import inspect
+import sys
+from dataclasses import dataclass
+from enum import Enum
+from functools import partial
+from inspect import signature
+from types import ModuleType
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union
+
+from torch import nn
+
+from .._internally_replaced_utils import load_state_dict_from_url
+
+
+__all__ = ["WeightsEnum", "Weights", "get_model", "get_model_builder", "get_model_weights", "get_weight", "list_models"]
+
+
+@dataclass
+class Weights:
+    """
+    This class is used to group important attributes associated with the pre-trained weights.
+
+    Args:
+        url (str): The location where we find the weights.
+        transforms (Callable): A callable that constructs the preprocessing method (or validation preset transforms)
+            needed to use the model. The reason we attach a constructor method rather than an already constructed
+            object is because the specific object might have memory and thus we want to delay initialization until
+            needed.
+        meta (Dict[str, Any]): Stores meta-data related to the weights of the model and its configuration. These can be
+            informative attributes (for example the number of parameters/flops, recipe link/methods used in training
+            etc), configuration parameters (for example the `num_classes`) needed to construct the model or important
+            meta-data (for example the `classes` of a classification model) needed to use the model.
+    """
+
+    url: str
+    transforms: Callable
+    meta: Dict[str, Any]
+
+    def __eq__(self, other: Any) -> bool:
+        # We need this custom implementation for correct deep-copy and deserialization behavior.
+        # TL;DR: After the definition of an enum, creating a new instance, i.e. by deep-copying or deserializing it,
+        # involves an equality check against the defined members. Unfortunately, the `transforms` attribute is often
+        # defined with `functools.partial` and `fn = partial(...); assert deepcopy(fn) != fn`. Without custom handling
+        # for it, the check against the defined members would fail and effectively prevent the weights from being
+        # deep-copied or deserialized.
+        # See https://github.com/pytorch/vision/pull/7107 for details.
+        if not isinstance(other, Weights):
+            return NotImplemented
+
+        if self.url != other.url:
+            return False
+
+        if self.meta != other.meta:
+            return False
+
+        if isinstance(self.transforms, partial) and isinstance(other.transforms, partial):
+            return (
+                self.transforms.func == other.transforms.func
+                and self.transforms.args == other.transforms.args
+                and self.transforms.keywords == other.transforms.keywords
+            )
+        else:
+            return self.transforms == other.transforms
+
+
+class WeightsEnum(Enum):
+    """
+    This class is the parent class of all model weights. Each model building method receives an optional `weights`
+    parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type
+    `Weights`.
+
+    Args:
+        value (Weights): The data class entry with the weight information.
+    """
+
+    @classmethod
+    def verify(cls, obj: Any) -> Any:
+        if obj is not None:
+            if type(obj) is str:
+                obj = cls[obj.replace(cls.__name__ + ".", "")]
+            elif not isinstance(obj, cls):
+                raise TypeError(
+                    f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
+                )
+        return obj
+
+    def get_state_dict(self, *args: Any, **kwargs: Any) -> Mapping[str, Any]:
+        return load_state_dict_from_url(self.url, *args, **kwargs)
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}.{self._name_}"
+
+    @property
+    def url(self):
+        return self.value.url
+
+    @property
+    def transforms(self):
+        return self.value.transforms
+
+    @property
+    def meta(self):
+        return self.value.meta
+
+
+def get_weight(name: str) -> WeightsEnum:
+    """
+    Gets the weights enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1"
+
+    Args:
+        name (str): The name of the weight enum entry.
+
+    Returns:
+        WeightsEnum: The requested weight enum.
+    """
+    try:
+        enum_name, value_name = name.split(".")
+    except ValueError:
+        raise ValueError(f"Invalid weight name provided: '{name}'.")
+
+    base_module_name = ".".join(sys.modules[__name__].__name__.split(".")[:-1])
+    base_module = importlib.import_module(base_module_name)
+    model_modules = [base_module] + [
+        x[1]
+        for x in inspect.getmembers(base_module, inspect.ismodule)
+        if x[1].__file__.endswith("__init__.py")  # type: ignore[union-attr]
+    ]
+
+    weights_enum = None
+    for m in model_modules:
+        potential_class = m.__dict__.get(enum_name, None)
+        if potential_class is not None and issubclass(potential_class, WeightsEnum):
+            weights_enum = potential_class
+            break
+
+    if weights_enum is None:
+        raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.")
+
+    return weights_enum[value_name]
+
+
+def get_model_weights(name: Union[Callable, str]) -> Type[WeightsEnum]:
+    """
+    Returns the weights enum class associated to the given model.
+
+    Args:
+        name (callable or str): The model builder function or the name under which it is registered.
+
+    Returns:
+        weights_enum (WeightsEnum): The weights enum class associated with the model.
+    """
+    model = get_model_builder(name) if isinstance(name, str) else name
+    return _get_enum_from_fn(model)
+
+
+def _get_enum_from_fn(fn: Callable) -> Type[WeightsEnum]:
+    """
+    Internal method that gets the weight enum of a specific model builder method.
+
+    Args:
+        fn (Callable): The builder method used to create the model.
+    Returns:
+        WeightsEnum: The requested weight enum.
+    """
+    sig = signature(fn)
+    if "weights" not in sig.parameters:
+        raise ValueError("The method is missing the 'weights' argument.")
+
+    ann = signature(fn).parameters["weights"].annotation
+    weights_enum = None
+    if isinstance(ann, type) and issubclass(ann, WeightsEnum):
+        weights_enum = ann
+    else:
+        # handle cases like Union[Optional, T]
+        # TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8
+        for t in ann.__args__:  # type: ignore[union-attr]
+            if isinstance(t, type) and issubclass(t, WeightsEnum):
+                weights_enum = t
+                break
+
+    if weights_enum is None:
+        raise ValueError(
+            "The WeightsEnum class for the specific method couldn't be retrieved. Make sure the typing info is correct."
+        )
+
+    return weights_enum
+
+
+M = TypeVar("M", bound=nn.Module)
+
+BUILTIN_MODELS = {}
+
+
+def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], Callable[..., M]]:
+    def wrapper(fn: Callable[..., M]) -> Callable[..., M]:
+        key = name if name is not None else fn.__name__
+        if key in BUILTIN_MODELS:
+            raise ValueError(f"An entry is already registered under the name '{key}'.")
+        BUILTIN_MODELS[key] = fn
+        return fn
+
+    return wrapper
+
+
+def list_models(
+    module: Optional[ModuleType] = None,
+    include: Union[Iterable[str], str, None] = None,
+    exclude: Union[Iterable[str], str, None] = None,
+) -> List[str]:
+    """
+    Returns a list with the names of registered models.
+
+    Args:
+        module (ModuleType, optional): The module from which we want to extract the available models.
+        include (str or Iterable[str], optional): Filter(s) for including the models from the set of all models.
+            Filters are passed to `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__ to match Unix shell-style
+            wildcards. In case of many filters, the results is the union of individual filters.
+        exclude (str or Iterable[str], optional): Filter(s) applied after include_filters to remove models.
+            Filter are passed to `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__ to match Unix shell-style
+            wildcards. In case of many filters, the results is removal of all the models that match any individual filter.
+
+    Returns:
+        models (list): A list with the names of available models.
+    """
+    all_models = {
+        k for k, v in BUILTIN_MODELS.items() if module is None or v.__module__.rsplit(".", 1)[0] == module.__name__
+    }
+    if include:
+        models: Set[str] = set()
+        if isinstance(include, str):
+            include = [include]
+        for include_filter in include:
+            models = models | set(fnmatch.filter(all_models, include_filter))
+    else:
+        models = all_models
+
+    if exclude:
+        if isinstance(exclude, str):
+            exclude = [exclude]
+        for exclude_filter in exclude:
+            models = models - set(fnmatch.filter(all_models, exclude_filter))
+    return sorted(models)
+
+
+def get_model_builder(name: str) -> Callable[..., nn.Module]:
+    """
+    Gets the model name and returns the model builder method.
+
+    Args:
+        name (str): The name under which the model is registered.
+
+    Returns:
+        fn (Callable): The model builder method.
+    """
+    name = name.lower()
+    try:
+        fn = BUILTIN_MODELS[name]
+    except KeyError:
+        raise ValueError(f"Unknown model {name}")
+    return fn
+
+
+def get_model(name: str, **config: Any) -> nn.Module:
+    """
+    Gets the model name and configuration and returns an instantiated model.
+
+    Args:
+        name (str): The name under which the model is registered.
+        **config (Any): parameters passed to the model builder method.
+
+    Returns:
+        model (nn.Module): The initialized model.
+    """
+    fn = get_model_builder(name)
+    return fn(**config)
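
The registry above is what the individual model modules build on via ``register_model``. A hedged sketch of how a downstream module might use it, assuming the package is importable as ``torchvision.models`` (or the vendored ``libs.vision_libs.models``); ``tiny_net`` is a hypothetical model name used purely for illustration::

    from torch import nn
    from torchvision.models import get_model, get_model_weights, list_models
    from torchvision.models._api import register_model

    @register_model()                       # registered under the function name "tiny_net"
    def tiny_net(**kwargs) -> nn.Module:
        return nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

    assert "tiny_net" in list_models()
    model = get_model("tiny_net")           # builds the model via the registered function

    # built-in models work the same way and also expose their weight enums
    weights_enum = get_model_weights("resnet50")   # e.g. ResNet50_Weights
    print(list(weights_enum))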

+ 1554 - 0
libs/vision_libs/models/_meta.py

@@ -0,0 +1,1554 @@
+"""
+This file is part of the private API. Please do not refer to any variables defined here directly as they will be
+removed on future versions without warning.
+"""
+
+# This will eventually be replaced with a call at torchvision.datasets.info("imagenet").categories
+_IMAGENET_CATEGORIES = [
+    "tench",
+    "goldfish",
+    "great white shark",
+    "tiger shark",
+    "hammerhead",
+    "electric ray",
+    "stingray",
+    "cock",
+    "hen",
+    "ostrich",
+    "brambling",
+    "goldfinch",
+    "house finch",
+    "junco",
+    "indigo bunting",
+    "robin",
+    "bulbul",
+    "jay",
+    "magpie",
+    "chickadee",
+    "water ouzel",
+    "kite",
+    "bald eagle",
+    "vulture",
+    "great grey owl",
+    "European fire salamander",
+    "common newt",
+    "eft",
+    "spotted salamander",
+    "axolotl",
+    "bullfrog",
+    "tree frog",
+    "tailed frog",
+    "loggerhead",
+    "leatherback turtle",
+    "mud turtle",
+    "terrapin",
+    "box turtle",
+    "banded gecko",
+    "common iguana",
+    "American chameleon",
+    "whiptail",
+    "agama",
+    "frilled lizard",
+    "alligator lizard",
+    "Gila monster",
+    "green lizard",
+    "African chameleon",
+    "Komodo dragon",
+    "African crocodile",
+    "American alligator",
+    "triceratops",
+    "thunder snake",
+    "ringneck snake",
+    "hognose snake",
+    "green snake",
+    "king snake",
+    "garter snake",
+    "water snake",
+    "vine snake",
+    "night snake",
+    "boa constrictor",
+    "rock python",
+    "Indian cobra",
+    "green mamba",
+    "sea snake",
+    "horned viper",
+    "diamondback",
+    "sidewinder",
+    "trilobite",
+    "harvestman",
+    "scorpion",
+    "black and gold garden spider",
+    "barn spider",
+    "garden spider",
+    "black widow",
+    "tarantula",
+    "wolf spider",
+    "tick",
+    "centipede",
+    "black grouse",
+    "ptarmigan",
+    "ruffed grouse",
+    "prairie chicken",
+    "peacock",
+    "quail",
+    "partridge",
+    "African grey",
+    "macaw",
+    "sulphur-crested cockatoo",
+    "lorikeet",
+    "coucal",
+    "bee eater",
+    "hornbill",
+    "hummingbird",
+    "jacamar",
+    "toucan",
+    "drake",
+    "red-breasted merganser",
+    "goose",
+    "black swan",
+    "tusker",
+    "echidna",
+    "platypus",
+    "wallaby",
+    "koala",
+    "wombat",
+    "jellyfish",
+    "sea anemone",
+    "brain coral",
+    "flatworm",
+    "nematode",
+    "conch",
+    "snail",
+    "slug",
+    "sea slug",
+    "chiton",
+    "chambered nautilus",
+    "Dungeness crab",
+    "rock crab",
+    "fiddler crab",
+    "king crab",
+    "American lobster",
+    "spiny lobster",
+    "crayfish",
+    "hermit crab",
+    "isopod",
+    "white stork",
+    "black stork",
+    "spoonbill",
+    "flamingo",
+    "little blue heron",
+    "American egret",
+    "bittern",
+    "crane bird",
+    "limpkin",
+    "European gallinule",
+    "American coot",
+    "bustard",
+    "ruddy turnstone",
+    "red-backed sandpiper",
+    "redshank",
+    "dowitcher",
+    "oystercatcher",
+    "pelican",
+    "king penguin",
+    "albatross",
+    "grey whale",
+    "killer whale",
+    "dugong",
+    "sea lion",
+    "Chihuahua",
+    "Japanese spaniel",
+    "Maltese dog",
+    "Pekinese",
+    "Shih-Tzu",
+    "Blenheim spaniel",
+    "papillon",
+    "toy terrier",
+    "Rhodesian ridgeback",
+    "Afghan hound",
+    "basset",
+    "beagle",
+    "bloodhound",
+    "bluetick",
+    "black-and-tan coonhound",
+    "Walker hound",
+    "English foxhound",
+    "redbone",
+    "borzoi",
+    "Irish wolfhound",
+    "Italian greyhound",
+    "whippet",
+    "Ibizan hound",
+    "Norwegian elkhound",
+    "otterhound",
+    "Saluki",
+    "Scottish deerhound",
+    "Weimaraner",
+    "Staffordshire bullterrier",
+    "American Staffordshire terrier",
+    "Bedlington terrier",
+    "Border terrier",
+    "Kerry blue terrier",
+    "Irish terrier",
+    "Norfolk terrier",
+    "Norwich terrier",
+    "Yorkshire terrier",
+    "wire-haired fox terrier",
+    "Lakeland terrier",
+    "Sealyham terrier",
+    "Airedale",
+    "cairn",
+    "Australian terrier",
+    "Dandie Dinmont",
+    "Boston bull",
+    "miniature schnauzer",
+    "giant schnauzer",
+    "standard schnauzer",
+    "Scotch terrier",
+    "Tibetan terrier",
+    "silky terrier",
+    "soft-coated wheaten terrier",
+    "West Highland white terrier",
+    "Lhasa",
+    "flat-coated retriever",
+    "curly-coated retriever",
+    "golden retriever",
+    "Labrador retriever",
+    "Chesapeake Bay retriever",
+    "German short-haired pointer",
+    "vizsla",
+    "English setter",
+    "Irish setter",
+    "Gordon setter",
+    "Brittany spaniel",
+    "clumber",
+    "English springer",
+    "Welsh springer spaniel",
+    "cocker spaniel",
+    "Sussex spaniel",
+    "Irish water spaniel",
+    "kuvasz",
+    "schipperke",
+    "groenendael",
+    "malinois",
+    "briard",
+    "kelpie",
+    "komondor",
+    "Old English sheepdog",
+    "Shetland sheepdog",
+    "collie",
+    "Border collie",
+    "Bouvier des Flandres",
+    "Rottweiler",
+    "German shepherd",
+    "Doberman",
+    "miniature pinscher",
+    "Greater Swiss Mountain dog",
+    "Bernese mountain dog",
+    "Appenzeller",
+    "EntleBucher",
+    "boxer",
+    "bull mastiff",
+    "Tibetan mastiff",
+    "French bulldog",
+    "Great Dane",
+    "Saint Bernard",
+    "Eskimo dog",
+    "malamute",
+    "Siberian husky",
+    "dalmatian",
+    "affenpinscher",
+    "basenji",
+    "pug",
+    "Leonberg",
+    "Newfoundland",
+    "Great Pyrenees",
+    "Samoyed",
+    "Pomeranian",
+    "chow",
+    "keeshond",
+    "Brabancon griffon",
+    "Pembroke",
+    "Cardigan",
+    "toy poodle",
+    "miniature poodle",
+    "standard poodle",
+    "Mexican hairless",
+    "timber wolf",
+    "white wolf",
+    "red wolf",
+    "coyote",
+    "dingo",
+    "dhole",
+    "African hunting dog",
+    "hyena",
+    "red fox",
+    "kit fox",
+    "Arctic fox",
+    "grey fox",
+    "tabby",
+    "tiger cat",
+    "Persian cat",
+    "Siamese cat",
+    "Egyptian cat",
+    "cougar",
+    "lynx",
+    "leopard",
+    "snow leopard",
+    "jaguar",
+    "lion",
+    "tiger",
+    "cheetah",
+    "brown bear",
+    "American black bear",
+    "ice bear",
+    "sloth bear",
+    "mongoose",
+    "meerkat",
+    "tiger beetle",
+    "ladybug",
+    "ground beetle",
+    "long-horned beetle",
+    "leaf beetle",
+    "dung beetle",
+    "rhinoceros beetle",
+    "weevil",
+    "fly",
+    "bee",
+    "ant",
+    "grasshopper",
+    "cricket",
+    "walking stick",
+    "cockroach",
+    "mantis",
+    "cicada",
+    "leafhopper",
+    "lacewing",
+    "dragonfly",
+    "damselfly",
+    "admiral",
+    "ringlet",
+    "monarch",
+    "cabbage butterfly",
+    "sulphur butterfly",
+    "lycaenid",
+    "starfish",
+    "sea urchin",
+    "sea cucumber",
+    "wood rabbit",
+    "hare",
+    "Angora",
+    "hamster",
+    "porcupine",
+    "fox squirrel",
+    "marmot",
+    "beaver",
+    "guinea pig",
+    "sorrel",
+    "zebra",
+    "hog",
+    "wild boar",
+    "warthog",
+    "hippopotamus",
+    "ox",
+    "water buffalo",
+    "bison",
+    "ram",
+    "bighorn",
+    "ibex",
+    "hartebeest",
+    "impala",
+    "gazelle",
+    "Arabian camel",
+    "llama",
+    "weasel",
+    "mink",
+    "polecat",
+    "black-footed ferret",
+    "otter",
+    "skunk",
+    "badger",
+    "armadillo",
+    "three-toed sloth",
+    "orangutan",
+    "gorilla",
+    "chimpanzee",
+    "gibbon",
+    "siamang",
+    "guenon",
+    "patas",
+    "baboon",
+    "macaque",
+    "langur",
+    "colobus",
+    "proboscis monkey",
+    "marmoset",
+    "capuchin",
+    "howler monkey",
+    "titi",
+    "spider monkey",
+    "squirrel monkey",
+    "Madagascar cat",
+    "indri",
+    "Indian elephant",
+    "African elephant",
+    "lesser panda",
+    "giant panda",
+    "barracouta",
+    "eel",
+    "coho",
+    "rock beauty",
+    "anemone fish",
+    "sturgeon",
+    "gar",
+    "lionfish",
+    "puffer",
+    "abacus",
+    "abaya",
+    "academic gown",
+    "accordion",
+    "acoustic guitar",
+    "aircraft carrier",
+    "airliner",
+    "airship",
+    "altar",
+    "ambulance",
+    "amphibian",
+    "analog clock",
+    "apiary",
+    "apron",
+    "ashcan",
+    "assault rifle",
+    "backpack",
+    "bakery",
+    "balance beam",
+    "balloon",
+    "ballpoint",
+    "Band Aid",
+    "banjo",
+    "bannister",
+    "barbell",
+    "barber chair",
+    "barbershop",
+    "barn",
+    "barometer",
+    "barrel",
+    "barrow",
+    "baseball",
+    "basketball",
+    "bassinet",
+    "bassoon",
+    "bathing cap",
+    "bath towel",
+    "bathtub",
+    "beach wagon",
+    "beacon",
+    "beaker",
+    "bearskin",
+    "beer bottle",
+    "beer glass",
+    "bell cote",
+    "bib",
+    "bicycle-built-for-two",
+    "bikini",
+    "binder",
+    "binoculars",
+    "birdhouse",
+    "boathouse",
+    "bobsled",
+    "bolo tie",
+    "bonnet",
+    "bookcase",
+    "bookshop",
+    "bottlecap",
+    "bow",
+    "bow tie",
+    "brass",
+    "brassiere",
+    "breakwater",
+    "breastplate",
+    "broom",
+    "bucket",
+    "buckle",
+    "bulletproof vest",
+    "bullet train",
+    "butcher shop",
+    "cab",
+    "caldron",
+    "candle",
+    "cannon",
+    "canoe",
+    "can opener",
+    "cardigan",
+    "car mirror",
+    "carousel",
+    "carpenter's kit",
+    "carton",
+    "car wheel",
+    "cash machine",
+    "cassette",
+    "cassette player",
+    "castle",
+    "catamaran",
+    "CD player",
+    "cello",
+    "cellular telephone",
+    "chain",
+    "chainlink fence",
+    "chain mail",
+    "chain saw",
+    "chest",
+    "chiffonier",
+    "chime",
+    "china cabinet",
+    "Christmas stocking",
+    "church",
+    "cinema",
+    "cleaver",
+    "cliff dwelling",
+    "cloak",
+    "clog",
+    "cocktail shaker",
+    "coffee mug",
+    "coffeepot",
+    "coil",
+    "combination lock",
+    "computer keyboard",
+    "confectionery",
+    "container ship",
+    "convertible",
+    "corkscrew",
+    "cornet",
+    "cowboy boot",
+    "cowboy hat",
+    "cradle",
+    "crane",
+    "crash helmet",
+    "crate",
+    "crib",
+    "Crock Pot",
+    "croquet ball",
+    "crutch",
+    "cuirass",
+    "dam",
+    "desk",
+    "desktop computer",
+    "dial telephone",
+    "diaper",
+    "digital clock",
+    "digital watch",
+    "dining table",
+    "dishrag",
+    "dishwasher",
+    "disk brake",
+    "dock",
+    "dogsled",
+    "dome",
+    "doormat",
+    "drilling platform",
+    "drum",
+    "drumstick",
+    "dumbbell",
+    "Dutch oven",
+    "electric fan",
+    "electric guitar",
+    "electric locomotive",
+    "entertainment center",
+    "envelope",
+    "espresso maker",
+    "face powder",
+    "feather boa",
+    "file",
+    "fireboat",
+    "fire engine",
+    "fire screen",
+    "flagpole",
+    "flute",
+    "folding chair",
+    "football helmet",
+    "forklift",
+    "fountain",
+    "fountain pen",
+    "four-poster",
+    "freight car",
+    "French horn",
+    "frying pan",
+    "fur coat",
+    "garbage truck",
+    "gasmask",
+    "gas pump",
+    "goblet",
+    "go-kart",
+    "golf ball",
+    "golfcart",
+    "gondola",
+    "gong",
+    "gown",
+    "grand piano",
+    "greenhouse",
+    "grille",
+    "grocery store",
+    "guillotine",
+    "hair slide",
+    "hair spray",
+    "half track",
+    "hammer",
+    "hamper",
+    "hand blower",
+    "hand-held computer",
+    "handkerchief",
+    "hard disc",
+    "harmonica",
+    "harp",
+    "harvester",
+    "hatchet",
+    "holster",
+    "home theater",
+    "honeycomb",
+    "hook",
+    "hoopskirt",
+    "horizontal bar",
+    "horse cart",
+    "hourglass",
+    "iPod",
+    "iron",
+    "jack-o'-lantern",
+    "jean",
+    "jeep",
+    "jersey",
+    "jigsaw puzzle",
+    "jinrikisha",
+    "joystick",
+    "kimono",
+    "knee pad",
+    "knot",
+    "lab coat",
+    "ladle",
+    "lampshade",
+    "laptop",
+    "lawn mower",
+    "lens cap",
+    "letter opener",
+    "library",
+    "lifeboat",
+    "lighter",
+    "limousine",
+    "liner",
+    "lipstick",
+    "Loafer",
+    "lotion",
+    "loudspeaker",
+    "loupe",
+    "lumbermill",
+    "magnetic compass",
+    "mailbag",
+    "mailbox",
+    "maillot",
+    "maillot tank suit",
+    "manhole cover",
+    "maraca",
+    "marimba",
+    "mask",
+    "matchstick",
+    "maypole",
+    "maze",
+    "measuring cup",
+    "medicine chest",
+    "megalith",
+    "microphone",
+    "microwave",
+    "military uniform",
+    "milk can",
+    "minibus",
+    "miniskirt",
+    "minivan",
+    "missile",
+    "mitten",
+    "mixing bowl",
+    "mobile home",
+    "Model T",
+    "modem",
+    "monastery",
+    "monitor",
+    "moped",
+    "mortar",
+    "mortarboard",
+    "mosque",
+    "mosquito net",
+    "motor scooter",
+    "mountain bike",
+    "mountain tent",
+    "mouse",
+    "mousetrap",
+    "moving van",
+    "muzzle",
+    "nail",
+    "neck brace",
+    "necklace",
+    "nipple",
+    "notebook",
+    "obelisk",
+    "oboe",
+    "ocarina",
+    "odometer",
+    "oil filter",
+    "organ",
+    "oscilloscope",
+    "overskirt",
+    "oxcart",
+    "oxygen mask",
+    "packet",
+    "paddle",
+    "paddlewheel",
+    "padlock",
+    "paintbrush",
+    "pajama",
+    "palace",
+    "panpipe",
+    "paper towel",
+    "parachute",
+    "parallel bars",
+    "park bench",
+    "parking meter",
+    "passenger car",
+    "patio",
+    "pay-phone",
+    "pedestal",
+    "pencil box",
+    "pencil sharpener",
+    "perfume",
+    "Petri dish",
+    "photocopier",
+    "pick",
+    "pickelhaube",
+    "picket fence",
+    "pickup",
+    "pier",
+    "piggy bank",
+    "pill bottle",
+    "pillow",
+    "ping-pong ball",
+    "pinwheel",
+    "pirate",
+    "pitcher",
+    "plane",
+    "planetarium",
+    "plastic bag",
+    "plate rack",
+    "plow",
+    "plunger",
+    "Polaroid camera",
+    "pole",
+    "police van",
+    "poncho",
+    "pool table",
+    "pop bottle",
+    "pot",
+    "potter's wheel",
+    "power drill",
+    "prayer rug",
+    "printer",
+    "prison",
+    "projectile",
+    "projector",
+    "puck",
+    "punching bag",
+    "purse",
+    "quill",
+    "quilt",
+    "racer",
+    "racket",
+    "radiator",
+    "radio",
+    "radio telescope",
+    "rain barrel",
+    "recreational vehicle",
+    "reel",
+    "reflex camera",
+    "refrigerator",
+    "remote control",
+    "restaurant",
+    "revolver",
+    "rifle",
+    "rocking chair",
+    "rotisserie",
+    "rubber eraser",
+    "rugby ball",
+    "rule",
+    "running shoe",
+    "safe",
+    "safety pin",
+    "saltshaker",
+    "sandal",
+    "sarong",
+    "sax",
+    "scabbard",
+    "scale",
+    "school bus",
+    "schooner",
+    "scoreboard",
+    "screen",
+    "screw",
+    "screwdriver",
+    "seat belt",
+    "sewing machine",
+    "shield",
+    "shoe shop",
+    "shoji",
+    "shopping basket",
+    "shopping cart",
+    "shovel",
+    "shower cap",
+    "shower curtain",
+    "ski",
+    "ski mask",
+    "sleeping bag",
+    "slide rule",
+    "sliding door",
+    "slot",
+    "snorkel",
+    "snowmobile",
+    "snowplow",
+    "soap dispenser",
+    "soccer ball",
+    "sock",
+    "solar dish",
+    "sombrero",
+    "soup bowl",
+    "space bar",
+    "space heater",
+    "space shuttle",
+    "spatula",
+    "speedboat",
+    "spider web",
+    "spindle",
+    "sports car",
+    "spotlight",
+    "stage",
+    "steam locomotive",
+    "steel arch bridge",
+    "steel drum",
+    "stethoscope",
+    "stole",
+    "stone wall",
+    "stopwatch",
+    "stove",
+    "strainer",
+    "streetcar",
+    "stretcher",
+    "studio couch",
+    "stupa",
+    "submarine",
+    "suit",
+    "sundial",
+    "sunglass",
+    "sunglasses",
+    "sunscreen",
+    "suspension bridge",
+    "swab",
+    "sweatshirt",
+    "swimming trunks",
+    "swing",
+    "switch",
+    "syringe",
+    "table lamp",
+    "tank",
+    "tape player",
+    "teapot",
+    "teddy",
+    "television",
+    "tennis ball",
+    "thatch",
+    "theater curtain",
+    "thimble",
+    "thresher",
+    "throne",
+    "tile roof",
+    "toaster",
+    "tobacco shop",
+    "toilet seat",
+    "torch",
+    "totem pole",
+    "tow truck",
+    "toyshop",
+    "tractor",
+    "trailer truck",
+    "tray",
+    "trench coat",
+    "tricycle",
+    "trimaran",
+    "tripod",
+    "triumphal arch",
+    "trolleybus",
+    "trombone",
+    "tub",
+    "turnstile",
+    "typewriter keyboard",
+    "umbrella",
+    "unicycle",
+    "upright",
+    "vacuum",
+    "vase",
+    "vault",
+    "velvet",
+    "vending machine",
+    "vestment",
+    "viaduct",
+    "violin",
+    "volleyball",
+    "waffle iron",
+    "wall clock",
+    "wallet",
+    "wardrobe",
+    "warplane",
+    "washbasin",
+    "washer",
+    "water bottle",
+    "water jug",
+    "water tower",
+    "whiskey jug",
+    "whistle",
+    "wig",
+    "window screen",
+    "window shade",
+    "Windsor tie",
+    "wine bottle",
+    "wing",
+    "wok",
+    "wooden spoon",
+    "wool",
+    "worm fence",
+    "wreck",
+    "yawl",
+    "yurt",
+    "web site",
+    "comic book",
+    "crossword puzzle",
+    "street sign",
+    "traffic light",
+    "book jacket",
+    "menu",
+    "plate",
+    "guacamole",
+    "consomme",
+    "hot pot",
+    "trifle",
+    "ice cream",
+    "ice lolly",
+    "French loaf",
+    "bagel",
+    "pretzel",
+    "cheeseburger",
+    "hotdog",
+    "mashed potato",
+    "head cabbage",
+    "broccoli",
+    "cauliflower",
+    "zucchini",
+    "spaghetti squash",
+    "acorn squash",
+    "butternut squash",
+    "cucumber",
+    "artichoke",
+    "bell pepper",
+    "cardoon",
+    "mushroom",
+    "Granny Smith",
+    "strawberry",
+    "orange",
+    "lemon",
+    "fig",
+    "pineapple",
+    "banana",
+    "jackfruit",
+    "custard apple",
+    "pomegranate",
+    "hay",
+    "carbonara",
+    "chocolate sauce",
+    "dough",
+    "meat loaf",
+    "pizza",
+    "potpie",
+    "burrito",
+    "red wine",
+    "espresso",
+    "cup",
+    "eggnog",
+    "alp",
+    "bubble",
+    "cliff",
+    "coral reef",
+    "geyser",
+    "lakeside",
+    "promontory",
+    "sandbar",
+    "seashore",
+    "valley",
+    "volcano",
+    "ballplayer",
+    "groom",
+    "scuba diver",
+    "rapeseed",
+    "daisy",
+    "yellow lady's slipper",
+    "corn",
+    "acorn",
+    "hip",
+    "buckeye",
+    "coral fungus",
+    "agaric",
+    "gyromitra",
+    "stinkhorn",
+    "earthstar",
+    "hen-of-the-woods",
+    "bolete",
+    "ear",
+    "toilet tissue",
+]
+
+# To be replaced with torchvision.datasets.info("coco").categories
+_COCO_CATEGORIES = [
+    "__background__",
+    "person",
+    "bicycle",
+    "car",
+    "motorcycle",
+    "airplane",
+    "bus",
+    "train",
+    "truck",
+    "boat",
+    "traffic light",
+    "fire hydrant",
+    "N/A",
+    "stop sign",
+    "parking meter",
+    "bench",
+    "bird",
+    "cat",
+    "dog",
+    "horse",
+    "sheep",
+    "cow",
+    "elephant",
+    "bear",
+    "zebra",
+    "giraffe",
+    "N/A",
+    "backpack",
+    "umbrella",
+    "N/A",
+    "N/A",
+    "handbag",
+    "tie",
+    "suitcase",
+    "frisbee",
+    "skis",
+    "snowboard",
+    "sports ball",
+    "kite",
+    "baseball bat",
+    "baseball glove",
+    "skateboard",
+    "surfboard",
+    "tennis racket",
+    "bottle",
+    "N/A",
+    "wine glass",
+    "cup",
+    "fork",
+    "knife",
+    "spoon",
+    "bowl",
+    "banana",
+    "apple",
+    "sandwich",
+    "orange",
+    "broccoli",
+    "carrot",
+    "hot dog",
+    "pizza",
+    "donut",
+    "cake",
+    "chair",
+    "couch",
+    "potted plant",
+    "bed",
+    "N/A",
+    "dining table",
+    "N/A",
+    "N/A",
+    "toilet",
+    "N/A",
+    "tv",
+    "laptop",
+    "mouse",
+    "remote",
+    "keyboard",
+    "cell phone",
+    "microwave",
+    "oven",
+    "toaster",
+    "sink",
+    "refrigerator",
+    "N/A",
+    "book",
+    "clock",
+    "vase",
+    "scissors",
+    "teddy bear",
+    "hair drier",
+    "toothbrush",
+]
+
+# To be replaced with torchvision.datasets.info("coco_kp")
+_COCO_PERSON_CATEGORIES = ["no person", "person"]
+_COCO_PERSON_KEYPOINT_NAMES = [
+    "nose",
+    "left_eye",
+    "right_eye",
+    "left_ear",
+    "right_ear",
+    "left_shoulder",
+    "right_shoulder",
+    "left_elbow",
+    "right_elbow",
+    "left_wrist",
+    "right_wrist",
+    "left_hip",
+    "right_hip",
+    "left_knee",
+    "right_knee",
+    "left_ankle",
+    "right_ankle",
+]
+
+# To be replaced with torchvision.datasets.info("voc").categories
+_VOC_CATEGORIES = [
+    "__background__",
+    "aeroplane",
+    "bicycle",
+    "bird",
+    "boat",
+    "bottle",
+    "bus",
+    "car",
+    "cat",
+    "chair",
+    "cow",
+    "diningtable",
+    "dog",
+    "horse",
+    "motorbike",
+    "person",
+    "pottedplant",
+    "sheep",
+    "sofa",
+    "train",
+    "tvmonitor",
+]
+
+# To be replaced with torchvision.datasets.info("kinetics400").categories
+_KINETICS400_CATEGORIES = [
+    "abseiling",
+    "air drumming",
+    "answering questions",
+    "applauding",
+    "applying cream",
+    "archery",
+    "arm wrestling",
+    "arranging flowers",
+    "assembling computer",
+    "auctioning",
+    "baby waking up",
+    "baking cookies",
+    "balloon blowing",
+    "bandaging",
+    "barbequing",
+    "bartending",
+    "beatboxing",
+    "bee keeping",
+    "belly dancing",
+    "bench pressing",
+    "bending back",
+    "bending metal",
+    "biking through snow",
+    "blasting sand",
+    "blowing glass",
+    "blowing leaves",
+    "blowing nose",
+    "blowing out candles",
+    "bobsledding",
+    "bookbinding",
+    "bouncing on trampoline",
+    "bowling",
+    "braiding hair",
+    "breading or breadcrumbing",
+    "breakdancing",
+    "brush painting",
+    "brushing hair",
+    "brushing teeth",
+    "building cabinet",
+    "building shed",
+    "bungee jumping",
+    "busking",
+    "canoeing or kayaking",
+    "capoeira",
+    "carrying baby",
+    "cartwheeling",
+    "carving pumpkin",
+    "catching fish",
+    "catching or throwing baseball",
+    "catching or throwing frisbee",
+    "catching or throwing softball",
+    "celebrating",
+    "changing oil",
+    "changing wheel",
+    "checking tires",
+    "cheerleading",
+    "chopping wood",
+    "clapping",
+    "clay pottery making",
+    "clean and jerk",
+    "cleaning floor",
+    "cleaning gutters",
+    "cleaning pool",
+    "cleaning shoes",
+    "cleaning toilet",
+    "cleaning windows",
+    "climbing a rope",
+    "climbing ladder",
+    "climbing tree",
+    "contact juggling",
+    "cooking chicken",
+    "cooking egg",
+    "cooking on campfire",
+    "cooking sausages",
+    "counting money",
+    "country line dancing",
+    "cracking neck",
+    "crawling baby",
+    "crossing river",
+    "crying",
+    "curling hair",
+    "cutting nails",
+    "cutting pineapple",
+    "cutting watermelon",
+    "dancing ballet",
+    "dancing charleston",
+    "dancing gangnam style",
+    "dancing macarena",
+    "deadlifting",
+    "decorating the christmas tree",
+    "digging",
+    "dining",
+    "disc golfing",
+    "diving cliff",
+    "dodgeball",
+    "doing aerobics",
+    "doing laundry",
+    "doing nails",
+    "drawing",
+    "dribbling basketball",
+    "drinking",
+    "drinking beer",
+    "drinking shots",
+    "driving car",
+    "driving tractor",
+    "drop kicking",
+    "drumming fingers",
+    "dunking basketball",
+    "dying hair",
+    "eating burger",
+    "eating cake",
+    "eating carrots",
+    "eating chips",
+    "eating doughnuts",
+    "eating hotdog",
+    "eating ice cream",
+    "eating spaghetti",
+    "eating watermelon",
+    "egg hunting",
+    "exercising arm",
+    "exercising with an exercise ball",
+    "extinguishing fire",
+    "faceplanting",
+    "feeding birds",
+    "feeding fish",
+    "feeding goats",
+    "filling eyebrows",
+    "finger snapping",
+    "fixing hair",
+    "flipping pancake",
+    "flying kite",
+    "folding clothes",
+    "folding napkins",
+    "folding paper",
+    "front raises",
+    "frying vegetables",
+    "garbage collecting",
+    "gargling",
+    "getting a haircut",
+    "getting a tattoo",
+    "giving or receiving award",
+    "golf chipping",
+    "golf driving",
+    "golf putting",
+    "grinding meat",
+    "grooming dog",
+    "grooming horse",
+    "gymnastics tumbling",
+    "hammer throw",
+    "headbanging",
+    "headbutting",
+    "high jump",
+    "high kick",
+    "hitting baseball",
+    "hockey stop",
+    "holding snake",
+    "hopscotch",
+    "hoverboarding",
+    "hugging",
+    "hula hooping",
+    "hurdling",
+    "hurling (sport)",
+    "ice climbing",
+    "ice fishing",
+    "ice skating",
+    "ironing",
+    "javelin throw",
+    "jetskiing",
+    "jogging",
+    "juggling balls",
+    "juggling fire",
+    "juggling soccer ball",
+    "jumping into pool",
+    "jumpstyle dancing",
+    "kicking field goal",
+    "kicking soccer ball",
+    "kissing",
+    "kitesurfing",
+    "knitting",
+    "krumping",
+    "laughing",
+    "laying bricks",
+    "long jump",
+    "lunge",
+    "making a cake",
+    "making a sandwich",
+    "making bed",
+    "making jewelry",
+    "making pizza",
+    "making snowman",
+    "making sushi",
+    "making tea",
+    "marching",
+    "massaging back",
+    "massaging feet",
+    "massaging legs",
+    "massaging person's head",
+    "milking cow",
+    "mopping floor",
+    "motorcycling",
+    "moving furniture",
+    "mowing lawn",
+    "news anchoring",
+    "opening bottle",
+    "opening present",
+    "paragliding",
+    "parasailing",
+    "parkour",
+    "passing American football (in game)",
+    "passing American football (not in game)",
+    "peeling apples",
+    "peeling potatoes",
+    "petting animal (not cat)",
+    "petting cat",
+    "picking fruit",
+    "planting trees",
+    "plastering",
+    "playing accordion",
+    "playing badminton",
+    "playing bagpipes",
+    "playing basketball",
+    "playing bass guitar",
+    "playing cards",
+    "playing cello",
+    "playing chess",
+    "playing clarinet",
+    "playing controller",
+    "playing cricket",
+    "playing cymbals",
+    "playing didgeridoo",
+    "playing drums",
+    "playing flute",
+    "playing guitar",
+    "playing harmonica",
+    "playing harp",
+    "playing ice hockey",
+    "playing keyboard",
+    "playing kickball",
+    "playing monopoly",
+    "playing organ",
+    "playing paintball",
+    "playing piano",
+    "playing poker",
+    "playing recorder",
+    "playing saxophone",
+    "playing squash or racquetball",
+    "playing tennis",
+    "playing trombone",
+    "playing trumpet",
+    "playing ukulele",
+    "playing violin",
+    "playing volleyball",
+    "playing xylophone",
+    "pole vault",
+    "presenting weather forecast",
+    "pull ups",
+    "pumping fist",
+    "pumping gas",
+    "punching bag",
+    "punching person (boxing)",
+    "push up",
+    "pushing car",
+    "pushing cart",
+    "pushing wheelchair",
+    "reading book",
+    "reading newspaper",
+    "recording music",
+    "riding a bike",
+    "riding camel",
+    "riding elephant",
+    "riding mechanical bull",
+    "riding mountain bike",
+    "riding mule",
+    "riding or walking with horse",
+    "riding scooter",
+    "riding unicycle",
+    "ripping paper",
+    "robot dancing",
+    "rock climbing",
+    "rock scissors paper",
+    "roller skating",
+    "running on treadmill",
+    "sailing",
+    "salsa dancing",
+    "sanding floor",
+    "scrambling eggs",
+    "scuba diving",
+    "setting table",
+    "shaking hands",
+    "shaking head",
+    "sharpening knives",
+    "sharpening pencil",
+    "shaving head",
+    "shaving legs",
+    "shearing sheep",
+    "shining shoes",
+    "shooting basketball",
+    "shooting goal (soccer)",
+    "shot put",
+    "shoveling snow",
+    "shredding paper",
+    "shuffling cards",
+    "side kick",
+    "sign language interpreting",
+    "singing",
+    "situp",
+    "skateboarding",
+    "ski jumping",
+    "skiing (not slalom or crosscountry)",
+    "skiing crosscountry",
+    "skiing slalom",
+    "skipping rope",
+    "skydiving",
+    "slacklining",
+    "slapping",
+    "sled dog racing",
+    "smoking",
+    "smoking hookah",
+    "snatch weight lifting",
+    "sneezing",
+    "sniffing",
+    "snorkeling",
+    "snowboarding",
+    "snowkiting",
+    "snowmobiling",
+    "somersaulting",
+    "spinning poi",
+    "spray painting",
+    "spraying",
+    "springboard diving",
+    "squat",
+    "sticking tongue out",
+    "stomping grapes",
+    "stretching arm",
+    "stretching leg",
+    "strumming guitar",
+    "surfing crowd",
+    "surfing water",
+    "sweeping floor",
+    "swimming backstroke",
+    "swimming breast stroke",
+    "swimming butterfly stroke",
+    "swing dancing",
+    "swinging legs",
+    "swinging on something",
+    "sword fighting",
+    "tai chi",
+    "taking a shower",
+    "tango dancing",
+    "tap dancing",
+    "tapping guitar",
+    "tapping pen",
+    "tasting beer",
+    "tasting food",
+    "testifying",
+    "texting",
+    "throwing axe",
+    "throwing ball",
+    "throwing discus",
+    "tickling",
+    "tobogganing",
+    "tossing coin",
+    "tossing salad",
+    "training dog",
+    "trapezing",
+    "trimming or shaving beard",
+    "trimming trees",
+    "triple jump",
+    "tying bow tie",
+    "tying knot (not on a tie)",
+    "tying tie",
+    "unboxing",
+    "unloading truck",
+    "using computer",
+    "using remote controller (not gaming)",
+    "using segway",
+    "vault",
+    "waiting in line",
+    "walking the dog",
+    "washing dishes",
+    "washing feet",
+    "washing hair",
+    "washing hands",
+    "water skiing",
+    "water sliding",
+    "watering plants",
+    "waxing back",
+    "waxing chest",
+    "waxing eyebrows",
+    "waxing legs",
+    "weaving basket",
+    "welding",
+    "whistling",
+    "windsurfing",
+    "wrapping present",
+    "wrestling",
+    "writing",
+    "yawning",
+    "yoga",
+    "zumba",
+]

+ 256 - 0
libs/vision_libs/models/_utils.py

@@ -0,0 +1,256 @@
+import functools
+import inspect
+import warnings
+from collections import OrderedDict
+from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union
+
+from torch import nn
+
+from .._utils import sequence_to_str
+from ._api import WeightsEnum
+
+
+class IntermediateLayerGetter(nn.ModuleDict):
+    """
+    Module wrapper that returns intermediate layers from a model
+
+    It has a strong assumption that the modules have been registered
+    into the model in the same order as they are used.
+    This means that the same nn.Module should **not** be reused
+    twice in the forward pass if you want this to work.
+
+    Additionally, it is only able to query submodules that are directly
+    assigned to the model. So if `model` is passed, `model.feature1` can
+    be returned, but not `model.feature1.layer2`.
+
+    Args:
+        model (nn.Module): model on which we will extract the features
+        return_layers (Dict[name, new_name]): a dict whose keys are the names
+            of the modules for which the activations will be returned, and
+            whose values are the names under which those activations appear
+            in the output (chosen by the user).
+
+    Examples::
+
+        >>> m = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT)
+        >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
+        >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
+        >>>     {'layer1': 'feat1', 'layer3': 'feat2'})
+        >>> out = new_m(torch.rand(1, 3, 224, 224))
+        >>> print([(k, v.shape) for k, v in out.items()])
+        >>>     [('feat1', torch.Size([1, 64, 56, 56])),
+        >>>      ('feat2', torch.Size([1, 256, 14, 14]))]
+    """
+
+    _version = 2
+    __annotations__ = {
+        "return_layers": Dict[str, str],
+    }
+
+    def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
+        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
+            raise ValueError("return_layers are not present in model")
+        orig_return_layers = return_layers
+        return_layers = {str(k): str(v) for k, v in return_layers.items()}
+        layers = OrderedDict()
+        for name, module in model.named_children():
+            layers[name] = module
+            if name in return_layers:
+                del return_layers[name]
+            if not return_layers:
+                break
+
+        super().__init__(layers)
+        self.return_layers = orig_return_layers
+
+    def forward(self, x):
+        out = OrderedDict()
+        for name, module in self.items():
+            x = module(x)
+            if name in self.return_layers:
+                out_name = self.return_layers[name]
+                out[out_name] = x
+        return out
+
+
+def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
+    """
+    This function is taken from the original TensorFlow repo.
+    It ensures that all layers have a channel number that is divisible by ``divisor``.
+    The original can be seen here:
+    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+    """
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that rounding down does not reduce the value by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
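
A few worked examples of the rounding rule above (values chosen for illustration only):

```python
# Rounds to the nearest multiple of `divisor`, clamped below by `min_value`
# (the divisor itself when not given) and never more than 10% below `v`.
assert _make_divisible(37.5, 8) == 40  # nearest multiple of 8
assert _make_divisible(11.0, 8) == 16  # 8 would be >10% below 11, so bump up by one divisor
assert _make_divisible(3.0, 8) == 8    # clamped to min_value (= divisor)
```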
+
+
+D = TypeVar("D")
+
+
+def kwonly_to_pos_or_kw(fn: Callable[..., D]) -> Callable[..., D]:
+    """Decorates a function that uses keyword only parameters to also allow them being passed as positionals.
+
+    For example, consider the use case of changing the signature of ``old_fn`` into the one from ``new_fn``:
+
+    .. code::
+
+        def old_fn(foo, bar, baz=None):
+            ...
+
+        def new_fn(foo, *, bar, baz=None):
+            ...
+
+    Calling ``old_fn("foo", "bar, "baz")`` was valid, but the same call is no longer valid with ``new_fn``. To keep BC
+    and at the same time warn the user of the deprecation, this decorator can be used:
+
+    .. code::
+
+        @kwonly_to_pos_or_kw
+        def new_fn(foo, *, bar, baz=None):
+            ...
+
+        new_fn("foo", "bar, "baz")
+    """
+    params = inspect.signature(fn).parameters
+
+    try:
+        keyword_only_start_idx = next(
+            idx for idx, param in enumerate(params.values()) if param.kind == param.KEYWORD_ONLY
+        )
+    except StopIteration:
+        raise TypeError(f"Found no keyword-only parameter on function '{fn.__name__}'") from None
+
+    keyword_only_params = tuple(inspect.signature(fn).parameters)[keyword_only_start_idx:]
+
+    @functools.wraps(fn)
+    def wrapper(*args: Any, **kwargs: Any) -> D:
+        args, keyword_only_args = args[:keyword_only_start_idx], args[keyword_only_start_idx:]
+        if keyword_only_args:
+            keyword_only_kwargs = dict(zip(keyword_only_params, keyword_only_args))
+            warnings.warn(
+                f"Using {sequence_to_str(tuple(keyword_only_kwargs.keys()), separate_last='and ')} as positional "
+                f"parameter(s) is deprecated since 0.13 and may be removed in the future. Please use keyword parameter(s) "
+                f"instead."
+            )
+            kwargs.update(keyword_only_kwargs)
+
+        return fn(*args, **kwargs)
+
+    return wrapper
+
+
+W = TypeVar("W", bound=WeightsEnum)
+M = TypeVar("M", bound=nn.Module)
+V = TypeVar("V")
+
+
+def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]):
+    """Decorates a model builder with the new interface to make it compatible with the old.
+
+    In particular this handles two things:
+
+    1. Allows positional parameters again, but emits a deprecation warning in case they are used. See
+        :func:`torchvision.prototype.utils._internal.kwonly_to_pos_or_kw` for details.
+    2. Handles the default value change from ``pretrained=False`` to ``weights=None`` and ``pretrained=True`` to
+        ``weights=Weights`` and emits a deprecation warning with instructions for the new interface.
+
+    Args:
+        **weights (Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): Deprecated parameter
+            name and default value for the legacy ``pretrained=True``. The default value can be a callable in which
+            case it will be called with a dictionary of the keyword arguments. The only key that is guaranteed to be in
+            the dictionary is the deprecated parameter name passed as first element in the tuple. All other parameters
+            should be accessed with :meth:`~dict.get`.
+    """
+
+    def outer_wrapper(builder: Callable[..., M]) -> Callable[..., M]:
+        @kwonly_to_pos_or_kw
+        @functools.wraps(builder)
+        def inner_wrapper(*args: Any, **kwargs: Any) -> M:
+            for weights_param, (pretrained_param, default) in weights.items():  # type: ignore[union-attr]
+                # If neither the weights nor the pretrained parameter was passed, or the weights argument already uses
+                # the new-style arguments, there is nothing to do. Note that we cannot use `None` as a sentinel for the
+                # weight argument, since it is a valid value.
+                sentinel = object()
+                weights_arg = kwargs.get(weights_param, sentinel)
+                if (
+                    (weights_param not in kwargs and pretrained_param not in kwargs)
+                    or isinstance(weights_arg, WeightsEnum)
+                    or (isinstance(weights_arg, str) and weights_arg != "legacy")
+                    or weights_arg is None
+                ):
+                    continue
+
+                # If the pretrained parameter was passed as positional argument, it is now mapped to
+                # `kwargs[weights_param]`. This happens because the @kwonly_to_pos_or_kw decorator uses the current
+                # signature to infer the names of positionally passed arguments and thus has no knowledge that there
+                # used to be a pretrained parameter.
+                pretrained_positional = weights_arg is not sentinel
+                if pretrained_positional:
+                    # We put the pretrained argument under its legacy name in the keyword argument dictionary to have
+                    # unified access to the value if the default value is a callable.
+                    kwargs[pretrained_param] = pretrained_arg = kwargs.pop(weights_param)
+                else:
+                    pretrained_arg = kwargs[pretrained_param]
+
+                if pretrained_arg:
+                    default_weights_arg = default(kwargs) if callable(default) else default
+                    if not isinstance(default_weights_arg, WeightsEnum):
+                        raise ValueError(f"No weights available for model {builder.__name__}")
+                else:
+                    default_weights_arg = None
+
+                if not pretrained_positional:
+                    warnings.warn(
+                        f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "
+                        f"please use '{weights_param}' instead."
+                    )
+
+                msg = (
+                    f"Arguments other than a weight enum or `None` for '{weights_param}' are deprecated since 0.13 and "
+                    f"may be removed in the future. "
+                    f"The current behavior is equivalent to passing `{weights_param}={default_weights_arg}`."
+                )
+                if pretrained_arg:
+                    msg = (
+                        f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.DEFAULT` "
+                        f"to get the most up-to-date weights."
+                    )
+                warnings.warn(msg)
+
+                del kwargs[pretrained_param]
+                kwargs[weights_param] = default_weights_arg
+
+            return builder(*args, **kwargs)
+
+        return inner_wrapper
+
+    return outer_wrapper
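
A hypothetical sketch of the call-site behaviour this decorator enables; `toy_builder` is illustrative only, and `AlexNet_Weights` refers to the enum defined in `alexnet.py` further down in this commit:

```python
# Hypothetical builder (not part of this diff): the decorator maps the legacy
# `pretrained=True` flag onto the new `weights=` keyword and emits a warning.
@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1))
def toy_builder(*, weights=None, progress=True, **kwargs):
    weights = AlexNet_Weights.verify(weights)
    return nn.Linear(1, 1)  # stand-in module; a real builder would construct and load the model

toy_builder(pretrained=True)  # warns; equivalent to weights=AlexNet_Weights.IMAGENET1K_V1
toy_builder(weights=None)     # new-style call, no warning
```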
+
+
+def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None:
+    if param in kwargs:
+        if kwargs[param] != new_value:
+            raise ValueError(f"The parameter '{param}' expected value {new_value} but got {kwargs[param]} instead.")
+    else:
+        kwargs[param] = new_value
+
+
+def _ovewrite_value_param(param: str, actual: Optional[V], expected: V) -> V:
+    if actual is not None:
+        if actual != expected:
+            raise ValueError(f"The parameter '{param}' expected value {expected} but got {actual} instead.")
+    return expected
+
+
+class _ModelURLs(dict):
+    def __getitem__(self, item):
+        warnings.warn(
+            "Accessing the model URLs via the internal dictionary of the module is deprecated since 0.13 and may "
+            "be removed in the future. Please access them via the appropriate Weights Enum instead."
+        )
+        return super().__getitem__(item)

+ 119 - 0
libs/vision_libs/models/alexnet.py

@@ -0,0 +1,119 @@
+from functools import partial
+from typing import Any, Optional
+
+import torch
+import torch.nn as nn
+
+from ..transforms._presets import ImageClassification
+from ..utils import _log_api_usage_once
+from ._api import register_model, Weights, WeightsEnum
+from ._meta import _IMAGENET_CATEGORIES
+from ._utils import _ovewrite_named_param, handle_legacy_interface
+
+
+__all__ = ["AlexNet", "AlexNet_Weights", "alexnet"]
+
+
+class AlexNet(nn.Module):
+    def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:
+        super().__init__()
+        _log_api_usage_once(self)
+        self.features = nn.Sequential(
+            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+            nn.Conv2d(64, 192, kernel_size=5, padding=2),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+            nn.Conv2d(192, 384, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(384, 256, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(256, 256, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+        )
+        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
+        self.classifier = nn.Sequential(
+            nn.Dropout(p=dropout),
+            nn.Linear(256 * 6 * 6, 4096),
+            nn.ReLU(inplace=True),
+            nn.Dropout(p=dropout),
+            nn.Linear(4096, 4096),
+            nn.ReLU(inplace=True),
+            nn.Linear(4096, num_classes),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.features(x)
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        x = self.classifier(x)
+        return x
+
+
+class AlexNet_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
+        transforms=partial(ImageClassification, crop_size=224),
+        meta={
+            "num_params": 61100840,
+            "min_size": (63, 63),
+            "categories": _IMAGENET_CATEGORIES,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 56.522,
+                    "acc@5": 79.066,
+                }
+            },
+            "_ops": 0.714,
+            "_file_size": 233.087,
+            "_docs": """
+                These weights reproduce closely the results of the paper using a simplified training recipe.
+            """,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1))
+def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
+    """AlexNet model architecture from `One weird trick for parallelizing convolutional neural networks <https://arxiv.org/abs/1404.5997>`__.
+
+    .. note::
+        AlexNet was originally introduced in the `ImageNet Classification with
+        Deep Convolutional Neural Networks
+        <https://papers.nips.cc/paper/2012/hash/c399862d3b9d6b76c8436e924a68c45b-Abstract.html>`__
+        paper. Our implementation is based instead on the "One weird trick"
+        paper above.
+
+    Args:
+        weights (:class:`~torchvision.models.AlexNet_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.AlexNet_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.AlexNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/alexnet.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.AlexNet_Weights
+        :members:
+    """
+
+    weights = AlexNet_Weights.verify(weights)
+
+    if weights is not None:
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+
+    model = AlexNet(**kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model
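
A hedged usage sketch for the builder above, assuming the vendored package re-exports `alexnet` and `AlexNet_Weights` from `models/__init__.py` as torchvision does, and that the upstream checkpoint URL is reachable:

```python
from libs.vision_libs.models import alexnet, AlexNet_Weights

weights = AlexNet_Weights.IMAGENET1K_V1
model = alexnet(weights=weights)       # downloads and loads the ImageNet checkpoint
model.eval()

preprocess = weights.transforms()      # ImageClassification preset (crop_size=224)
# logits = model(preprocess(img).unsqueeze(0))  # img: a PIL image or CHW tensor
```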

+ 414 - 0
libs/vision_libs/models/convnext.py

@@ -0,0 +1,414 @@
+from functools import partial
+from typing import Any, Callable, List, Optional, Sequence
+
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+
+from ..ops.misc import Conv2dNormActivation, Permute
+from ..ops.stochastic_depth import StochasticDepth
+from ..transforms._presets import ImageClassification
+from ..utils import _log_api_usage_once
+from ._api import register_model, Weights, WeightsEnum
+from ._meta import _IMAGENET_CATEGORIES
+from ._utils import _ovewrite_named_param, handle_legacy_interface
+
+
+__all__ = [
+    "ConvNeXt",
+    "ConvNeXt_Tiny_Weights",
+    "ConvNeXt_Small_Weights",
+    "ConvNeXt_Base_Weights",
+    "ConvNeXt_Large_Weights",
+    "convnext_tiny",
+    "convnext_small",
+    "convnext_base",
+    "convnext_large",
+]
+
+
+class LayerNorm2d(nn.LayerNorm):
+    def forward(self, x: Tensor) -> Tensor:
+        x = x.permute(0, 2, 3, 1)
+        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        x = x.permute(0, 3, 1, 2)
+        return x
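
A quick sanity check of `LayerNorm2d` (a sketch; shapes chosen arbitrarily): it normalizes over the channel dimension of an NCHW tensor, so with the default affine initialization each spatial position ends up roughly zero-mean across channels:

```python
import torch

x = torch.randn(2, 8, 4, 4)                # NCHW input
y = LayerNorm2d(8)(x)                      # normalized over the 8 channels, shape preserved
print(y.shape, y.mean(dim=1).abs().max())  # torch.Size([2, 8, 4, 4]), value close to 0
```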
+
+
+class CNBlock(nn.Module):
+    def __init__(
+        self,
+        dim,
+        layer_scale: float,
+        stochastic_depth_prob: float,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+    ) -> None:
+        super().__init__()
+        if norm_layer is None:
+            norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+        self.block = nn.Sequential(
+            nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
+            Permute([0, 2, 3, 1]),
+            norm_layer(dim),
+            nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
+            nn.GELU(),
+            nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
+            Permute([0, 3, 1, 2]),
+        )
+        self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
+        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
+
+    def forward(self, input: Tensor) -> Tensor:
+        result = self.layer_scale * self.block(input)
+        result = self.stochastic_depth(result)
+        result += input
+        return result
+
+
+class CNBlockConfig:
+    # Stores information listed in Section 3 of the ConvNeXt paper
+    def __init__(
+        self,
+        input_channels: int,
+        out_channels: Optional[int],
+        num_layers: int,
+    ) -> None:
+        self.input_channels = input_channels
+        self.out_channels = out_channels
+        self.num_layers = num_layers
+
+    def __repr__(self) -> str:
+        s = self.__class__.__name__ + "("
+        s += "input_channels={input_channels}"
+        s += ", out_channels={out_channels}"
+        s += ", num_layers={num_layers}"
+        s += ")"
+        return s.format(**self.__dict__)
+
+
+class ConvNeXt(nn.Module):
+    def __init__(
+        self,
+        block_setting: List[CNBlockConfig],
+        stochastic_depth_prob: float = 0.0,
+        layer_scale: float = 1e-6,
+        num_classes: int = 1000,
+        block: Optional[Callable[..., nn.Module]] = None,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__()
+        _log_api_usage_once(self)
+
+        if not block_setting:
+            raise ValueError("The block_setting should not be empty")
+        elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])):
+            raise TypeError("The block_setting should be List[CNBlockConfig]")
+
+        if block is None:
+            block = CNBlock
+
+        if norm_layer is None:
+            norm_layer = partial(LayerNorm2d, eps=1e-6)
+
+        layers: List[nn.Module] = []
+
+        # Stem
+        firstconv_output_channels = block_setting[0].input_channels
+        layers.append(
+            Conv2dNormActivation(
+                3,
+                firstconv_output_channels,
+                kernel_size=4,
+                stride=4,
+                padding=0,
+                norm_layer=norm_layer,
+                activation_layer=None,
+                bias=True,
+            )
+        )
+
+        total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
+        stage_block_id = 0
+        for cnf in block_setting:
+            # Bottlenecks
+            stage: List[nn.Module] = []
+            for _ in range(cnf.num_layers):
+                # adjust stochastic depth probability based on the depth of the stage block
+                sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
+                stage.append(block(cnf.input_channels, layer_scale, sd_prob))
+                stage_block_id += 1
+            layers.append(nn.Sequential(*stage))
+            if cnf.out_channels is not None:
+                # Downsampling
+                layers.append(
+                    nn.Sequential(
+                        norm_layer(cnf.input_channels),
+                        nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
+                    )
+                )
+
+        self.features = nn.Sequential(*layers)
+        self.avgpool = nn.AdaptiveAvgPool2d(1)
+
+        lastblock = block_setting[-1]
+        lastconv_output_channels = (
+            lastblock.out_channels if lastblock.out_channels is not None else lastblock.input_channels
+        )
+        self.classifier = nn.Sequential(
+            norm_layer(lastconv_output_channels), nn.Flatten(1), nn.Linear(lastconv_output_channels, num_classes)
+        )
+
+        for m in self.modules():
+            if isinstance(m, (nn.Conv2d, nn.Linear)):
+                nn.init.trunc_normal_(m.weight, std=0.02)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+
+    def _forward_impl(self, x: Tensor) -> Tensor:
+        x = self.features(x)
+        x = self.avgpool(x)
+        x = self.classifier(x)
+        return x
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self._forward_impl(x)
+
+
+def _convnext(
+    block_setting: List[CNBlockConfig],
+    stochastic_depth_prob: float,
+    weights: Optional[WeightsEnum],
+    progress: bool,
+    **kwargs: Any,
+) -> ConvNeXt:
+    if weights is not None:
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+
+    model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model
+
+
+_COMMON_META = {
+    "min_size": (32, 32),
+    "categories": _IMAGENET_CATEGORIES,
+    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext",
+    "_docs": """
+        These weights improve upon the results of the original paper by using a modified version of TorchVision's
+        `new training recipe
+        <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
+    """,
+}
+
+
+class ConvNeXt_Tiny_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
+        transforms=partial(ImageClassification, crop_size=224, resize_size=236),
+        meta={
+            **_COMMON_META,
+            "num_params": 28589128,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 82.520,
+                    "acc@5": 96.146,
+                }
+            },
+            "_ops": 4.456,
+            "_file_size": 109.119,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class ConvNeXt_Small_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/convnext_small-0c510722.pth",
+        transforms=partial(ImageClassification, crop_size=224, resize_size=230),
+        meta={
+            **_COMMON_META,
+            "num_params": 50223688,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 83.616,
+                    "acc@5": 96.650,
+                }
+            },
+            "_ops": 8.684,
+            "_file_size": 191.703,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class ConvNeXt_Base_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/convnext_base-6075fbad.pth",
+        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+        meta={
+            **_COMMON_META,
+            "num_params": 88591464,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 84.062,
+                    "acc@5": 96.870,
+                }
+            },
+            "_ops": 15.355,
+            "_file_size": 338.064,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class ConvNeXt_Large_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/convnext_large-ea097f82.pth",
+        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
+        meta={
+            **_COMMON_META,
+            "num_params": 197767336,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 84.414,
+                    "acc@5": 96.976,
+                }
+            },
+            "_ops": 34.361,
+            "_file_size": 754.537,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
+def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
+    """ConvNeXt Tiny model architecture from the
+    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.
+
+    Args:
+        weights (:class:`~torchvision.models.convnext.ConvNeXt_Tiny_Weights`, optional): The pretrained
+            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Tiny_Weights`
+            below for more details and possible values. By default, no pre-trained weights are used.
+        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.ConvNeXt_Tiny_Weights
+        :members:
+    """
+    weights = ConvNeXt_Tiny_Weights.verify(weights)
+
+    block_setting = [
+        CNBlockConfig(96, 192, 3),
+        CNBlockConfig(192, 384, 3),
+        CNBlockConfig(384, 768, 9),
+        CNBlockConfig(768, None, 3),
+    ]
+    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1)
+    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
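
A short usage sketch for the builder above (no pretrained weights; the `stochastic_depth_prob` override is picked up by the `kwargs.pop` call in the builder, and the class count is illustrative):

```python
import torch

model = convnext_tiny(weights=None, stochastic_depth_prob=0.2, num_classes=10)
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 10])
```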
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1))
+def convnext_small(
+    *, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any
+) -> ConvNeXt:
+    """ConvNeXt Small model architecture from the
+    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.
+
+    Args:
+        weights (:class:`~torchvision.models.convnext.ConvNeXt_Small_Weights`, optional): The pretrained
+            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Small_Weights`
+            below for more details and possible values. By default, no pre-trained weights are used.
+        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.ConvNeXt_Small_Weights
+        :members:
+    """
+    weights = ConvNeXt_Small_Weights.verify(weights)
+
+    block_setting = [
+        CNBlockConfig(96, 192, 3),
+        CNBlockConfig(192, 384, 3),
+        CNBlockConfig(384, 768, 27),
+        CNBlockConfig(768, None, 3),
+    ]
+    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4)
+    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1))
+def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
+    """ConvNeXt Base model architecture from the
+    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.
+
+    Args:
+        weights (:class:`~torchvision.models.convnext.ConvNeXt_Base_Weights`, optional): The pretrained
+            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Base_Weights`
+            below for more details and possible values. By default, no pre-trained weights are used.
+        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.ConvNeXt_Base_Weights
+        :members:
+    """
+    weights = ConvNeXt_Base_Weights.verify(weights)
+
+    block_setting = [
+        CNBlockConfig(128, 256, 3),
+        CNBlockConfig(256, 512, 3),
+        CNBlockConfig(512, 1024, 27),
+        CNBlockConfig(1024, None, 3),
+    ]
+    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
+    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1))
+def convnext_large(
+    *, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any
+) -> ConvNeXt:
+    """ConvNeXt Large model architecture from the
+    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.
+
+    Args:
+        weights (:class:`~torchvision.models.convnext.ConvNeXt_Large_Weights`, optional): The pretrained
+            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Large_Weights`
+            below for more details and possible values. By default, no pre-trained weights are used.
+        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.ConvNeXt_Large_Weights
+        :members:
+    """
+    weights = ConvNeXt_Large_Weights.verify(weights)
+
+    block_setting = [
+        CNBlockConfig(192, 384, 3),
+        CNBlockConfig(384, 768, 3),
+        CNBlockConfig(768, 1536, 27),
+        CNBlockConfig(1536, None, 3),
+    ]
+    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
+    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
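A quick usage sketch for the factory functions above; the import path assumes the vendored package is importable as `libs.vision_libs`, and the input shape is illustrative only.

```python
import torch
# Assumed import path for the vendored copy; adjust to your project layout.
from libs.vision_libs.models.convnext import convnext_tiny, ConvNeXt_Tiny_Weights

model = convnext_tiny(weights=None)  # randomly initialised, no download
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])

# To load the ImageNet checkpoint referenced in the Weights enum:
# model = convnext_tiny(weights=ConvNeXt_Tiny_Weights.IMAGENET1K_V1)
```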

+ 448 - 0
libs/vision_libs/models/densenet.py

@@ -0,0 +1,448 @@
+import re
+from collections import OrderedDict
+from functools import partial
+from typing import Any, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from torch import Tensor
+
+from ..transforms._presets import ImageClassification
+from ..utils import _log_api_usage_once
+from ._api import register_model, Weights, WeightsEnum
+from ._meta import _IMAGENET_CATEGORIES
+from ._utils import _ovewrite_named_param, handle_legacy_interface
+
+__all__ = [
+    "DenseNet",
+    "DenseNet121_Weights",
+    "DenseNet161_Weights",
+    "DenseNet169_Weights",
+    "DenseNet201_Weights",
+    "densenet121",
+    "densenet161",
+    "densenet169",
+    "densenet201",
+]
+
+
+class _DenseLayer(nn.Module):
+    def __init__(
+        self, num_input_features: int, growth_rate: int, bn_size: int, drop_rate: float, memory_efficient: bool = False
+    ) -> None:
+        super().__init__()
+        self.norm1 = nn.BatchNorm2d(num_input_features)
+        self.relu1 = nn.ReLU(inplace=True)
+        self.conv1 = nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)
+
+        self.norm2 = nn.BatchNorm2d(bn_size * growth_rate)
+        self.relu2 = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)
+
+        self.drop_rate = float(drop_rate)
+        self.memory_efficient = memory_efficient
+
+    def bn_function(self, inputs: List[Tensor]) -> Tensor:
+        concated_features = torch.cat(inputs, 1)
+        bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))  # noqa: T484
+        return bottleneck_output
+
+    # todo: rewrite when torchscript supports any
+    def any_requires_grad(self, input: List[Tensor]) -> bool:
+        for tensor in input:
+            if tensor.requires_grad:
+                return True
+        return False
+
+    @torch.jit.unused  # noqa: T484
+    def call_checkpoint_bottleneck(self, input: List[Tensor]) -> Tensor:
+        def closure(*inputs):
+            return self.bn_function(inputs)
+
+        return cp.checkpoint(closure, *input)
+
+    @torch.jit._overload_method  # noqa: F811
+    def forward(self, input: List[Tensor]) -> Tensor:  # noqa: F811
+        pass
+
+    @torch.jit._overload_method  # noqa: F811
+    def forward(self, input: Tensor) -> Tensor:  # noqa: F811
+        pass
+
+    # torchscript does not yet support *args, so we overload method
+    # allowing it to take either a List[Tensor] or single Tensor
+    def forward(self, input: Tensor) -> Tensor:  # noqa: F811
+        if isinstance(input, Tensor):
+            prev_features = [input]
+        else:
+            prev_features = input
+
+        if self.memory_efficient and self.any_requires_grad(prev_features):
+            if torch.jit.is_scripting():
+                raise Exception("Memory Efficient not supported in JIT")
+
+            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
+        else:
+            bottleneck_output = self.bn_function(prev_features)
+
+        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
+        if self.drop_rate > 0:
+            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
+        return new_features
+
+
+class _DenseBlock(nn.ModuleDict):
+    _version = 2
+
+    def __init__(
+        self,
+        num_layers: int,
+        num_input_features: int,
+        bn_size: int,
+        growth_rate: int,
+        drop_rate: float,
+        memory_efficient: bool = False,
+    ) -> None:
+        super().__init__()
+        for i in range(num_layers):
+            layer = _DenseLayer(
+                num_input_features + i * growth_rate,
+                growth_rate=growth_rate,
+                bn_size=bn_size,
+                drop_rate=drop_rate,
+                memory_efficient=memory_efficient,
+            )
+            self.add_module("denselayer%d" % (i + 1), layer)
+
+    def forward(self, init_features: Tensor) -> Tensor:
+        features = [init_features]
+        for name, layer in self.items():
+            new_features = layer(features)
+            features.append(new_features)
+        return torch.cat(features, 1)
+
+
+class _Transition(nn.Sequential):
+    def __init__(self, num_input_features: int, num_output_features: int) -> None:
+        super().__init__()
+        self.norm = nn.BatchNorm2d(num_input_features)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv = nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)
+        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
+
+
+class DenseNet(nn.Module):
+    r"""Densenet-BC model class, based on
+    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.
+
+    Args:
+        growth_rate (int) - how many filters to add each layer (`k` in paper)
+        block_config (list of 4 ints) - how many layers in each pooling block
+        num_init_features (int) - the number of filters to learn in the first convolution layer
+        bn_size (int) - multiplicative factor for number of bottle neck layers
+          (i.e. bn_size * k features in the bottleneck layer)
+        drop_rate (float) - dropout rate after each dense layer
+        num_classes (int) - number of classification classes
+        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
+          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
+    """
+
+    def __init__(
+        self,
+        growth_rate: int = 32,
+        block_config: Tuple[int, int, int, int] = (6, 12, 24, 16),
+        num_init_features: int = 64,
+        bn_size: int = 4,
+        drop_rate: float = 0,
+        num_classes: int = 1000,
+        memory_efficient: bool = False,
+    ) -> None:
+
+        super().__init__()
+        _log_api_usage_once(self)
+
+        # First convolution
+        self.features = nn.Sequential(
+            OrderedDict(
+                [
+                    ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
+                    ("norm0", nn.BatchNorm2d(num_init_features)),
+                    ("relu0", nn.ReLU(inplace=True)),
+                    ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
+                ]
+            )
+        )
+
+        # Each denseblock
+        num_features = num_init_features
+        for i, num_layers in enumerate(block_config):
+            block = _DenseBlock(
+                num_layers=num_layers,
+                num_input_features=num_features,
+                bn_size=bn_size,
+                growth_rate=growth_rate,
+                drop_rate=drop_rate,
+                memory_efficient=memory_efficient,
+            )
+            self.features.add_module("denseblock%d" % (i + 1), block)
+            num_features = num_features + num_layers * growth_rate
+            if i != len(block_config) - 1:
+                trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
+                self.features.add_module("transition%d" % (i + 1), trans)
+                num_features = num_features // 2
+
+        # Final batch norm
+        self.features.add_module("norm5", nn.BatchNorm2d(num_features))
+
+        # Linear layer
+        self.classifier = nn.Linear(num_features, num_classes)
+
+        # Official init from torch repo.
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x: Tensor) -> Tensor:
+        features = self.features(x)
+        out = F.relu(features, inplace=True)
+        out = F.adaptive_avg_pool2d(out, (1, 1))
+        out = torch.flatten(out, 1)
+        out = self.classifier(out)
+        return out
+
+
+def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) -> None:
+    # '.'s are no longer allowed in module names, but previous _DenseLayer
+    # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
+    # They are also in the checkpoints in model_urls. This pattern is used
+    # to find such keys.
+    pattern = re.compile(
+        r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
+    )
+
+    state_dict = weights.get_state_dict(progress=progress, check_hash=True)
+    for key in list(state_dict.keys()):
+        res = pattern.match(key)
+        if res:
+            new_key = res.group(1) + res.group(2)
+            state_dict[new_key] = state_dict[key]
+            del state_dict[key]
+    model.load_state_dict(state_dict)
+
+
+def _densenet(
+    growth_rate: int,
+    block_config: Tuple[int, int, int, int],
+    num_init_features: int,
+    weights: Optional[WeightsEnum],
+    progress: bool,
+    **kwargs: Any,
+) -> DenseNet:
+    if weights is not None:
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+
+    model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
+
+    if weights is not None:
+        _load_state_dict(model=model, weights=weights, progress=progress)
+
+    return model
+
+
+_COMMON_META = {
+    "min_size": (29, 29),
+    "categories": _IMAGENET_CATEGORIES,
+    "recipe": "https://github.com/pytorch/vision/pull/116",
+    "_docs": """These weights are ported from LuaTorch.""",
+}
+
+
+class DenseNet121_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
+        transforms=partial(ImageClassification, crop_size=224),
+        meta={
+            **_COMMON_META,
+            "num_params": 7978856,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 74.434,
+                    "acc@5": 91.972,
+                }
+            },
+            "_ops": 2.834,
+            "_file_size": 30.845,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class DenseNet161_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
+        transforms=partial(ImageClassification, crop_size=224),
+        meta={
+            **_COMMON_META,
+            "num_params": 28681000,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 77.138,
+                    "acc@5": 93.560,
+                }
+            },
+            "_ops": 7.728,
+            "_file_size": 110.369,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class DenseNet169_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
+        transforms=partial(ImageClassification, crop_size=224),
+        meta={
+            **_COMMON_META,
+            "num_params": 14149480,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 75.600,
+                    "acc@5": 92.806,
+                }
+            },
+            "_ops": 3.36,
+            "_file_size": 54.708,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+class DenseNet201_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/densenet201-c1103571.pth",
+        transforms=partial(ImageClassification, crop_size=224),
+        meta={
+            **_COMMON_META,
+            "num_params": 20013928,
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 76.896,
+                    "acc@5": 93.370,
+                }
+            },
+            "_ops": 4.291,
+            "_file_size": 77.373,
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1))
+def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
+    r"""Densenet-121 model from
+    `Densely Connected Convolutional Networks <https://arxiv.org/abs/1608.06993>`_.
+
+    Args:
+        weights (:class:`~torchvision.models.DenseNet121_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.DenseNet121_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.densenet.DenseNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.DenseNet121_Weights
+        :members:
+    """
+    weights = DenseNet121_Weights.verify(weights)
+
+    return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.IMAGENET1K_V1))
+def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
+    r"""Densenet-161 model from
+    `Densely Connected Convolutional Networks <https://arxiv.org/abs/1608.06993>`_.
+
+    Args:
+        weights (:class:`~torchvision.models.DenseNet161_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.DenseNet161_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.densenet.DenseNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.DenseNet161_Weights
+        :members:
+    """
+    weights = DenseNet161_Weights.verify(weights)
+
+    return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.IMAGENET1K_V1))
+def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
+    r"""Densenet-169 model from
+    `Densely Connected Convolutional Networks <https://arxiv.org/abs/1608.06993>`_.
+
+    Args:
+        weights (:class:`~torchvision.models.DenseNet169_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.DenseNet169_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.densenet.DenseNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.DenseNet169_Weights
+        :members:
+    """
+    weights = DenseNet169_Weights.verify(weights)
+
+    return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.IMAGENET1K_V1))
+def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
+    r"""Densenet-201 model from
+    `Densely Connected Convolutional Networks <https://arxiv.org/abs/1608.06993>`_.
+
+    Args:
+        weights (:class:`~torchvision.models.DenseNet201_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.DenseNet201_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.densenet.DenseNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.DenseNet201_Weights
+        :members:
+    """
+    weights = DenseNet201_Weights.verify(weights)
+
+    return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)
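Usage sketch for the DenseNet factories above. The import path assumes the vendored layout, and the extra keyword argument merely demonstrates that unconsumed kwargs reach the `DenseNet` constructor.

```python
import torch
from libs.vision_libs.models.densenet import densenet121, DenseNet121_Weights  # assumed path

model = densenet121(weights=None, memory_efficient=True)  # kwargs are forwarded to DenseNet
model.eval()
with torch.no_grad():
    out = model(torch.randn(2, 3, 224, 224))
print(out.shape)  # torch.Size([2, 1000])

# With pretrained weights (legacy 'norm.1'-style keys are remapped by _load_state_dict):
# model = densenet121(weights=DenseNet121_Weights.IMAGENET1K_V1)
```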

+ 7 - 0
libs/vision_libs/models/detection/__init__.py

@@ -0,0 +1,7 @@
+from .faster_rcnn import *
+from .fcos import *
+from .keypoint_rcnn import *
+from .mask_rcnn import *
+from .retinanet import *
+from .ssd import *
+from .ssdlite import *

+ 540 - 0
libs/vision_libs/models/detection/_utils.py

@@ -0,0 +1,540 @@
+import math
+from collections import OrderedDict
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+from torchvision.ops import complete_box_iou_loss, distance_box_iou_loss, FrozenBatchNorm2d, generalized_box_iou_loss
+
+
+class BalancedPositiveNegativeSampler:
+    """
+    This class samples batches, ensuring that they contain a fixed proportion of positives
+    """
+
+    def __init__(self, batch_size_per_image: int, positive_fraction: float) -> None:
+        """
+        Args:
+            batch_size_per_image (int): number of elements to be selected per image
+            positive_fraction (float): percentage of positive elements per batch
+        """
+        self.batch_size_per_image = batch_size_per_image
+        self.positive_fraction = positive_fraction
+
+    def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
+        """
+        Args:
+            matched_idxs: list of tensors containing -1, 0 or positive values.
+                Each tensor corresponds to a specific image.
+                -1 values are ignored, 0 are considered as negatives and > 0 as
+                positives.
+
+        Returns:
+            pos_idx (list[tensor])
+            neg_idx (list[tensor])
+
+        Returns two lists of binary masks for each image.
+        The first list contains the positive elements that were selected,
+        and the second list the negative examples that were selected.
+        """
+        pos_idx = []
+        neg_idx = []
+        for matched_idxs_per_image in matched_idxs:
+            positive = torch.where(matched_idxs_per_image >= 1)[0]
+            negative = torch.where(matched_idxs_per_image == 0)[0]
+
+            num_pos = int(self.batch_size_per_image * self.positive_fraction)
+            # protect against not enough positive examples
+            num_pos = min(positive.numel(), num_pos)
+            num_neg = self.batch_size_per_image - num_pos
+            # protect against not enough negative examples
+            num_neg = min(negative.numel(), num_neg)
+
+            # randomly select positive and negative examples
+            perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
+            perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]
+
+            pos_idx_per_image = positive[perm1]
+            neg_idx_per_image = negative[perm2]
+
+            # create binary mask from indices
+            pos_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
+            neg_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
+
+            pos_idx_per_image_mask[pos_idx_per_image] = 1
+            neg_idx_per_image_mask[neg_idx_per_image] = 1
+
+            pos_idx.append(pos_idx_per_image_mask)
+            neg_idx.append(neg_idx_per_image_mask)
+
+        return pos_idx, neg_idx
+
+
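A small, self-contained sketch of how the sampler above is typically driven; the matched indices are made up and the import path assumes the vendored layout.

```python
import torch
from libs.vision_libs.models.detection._utils import BalancedPositiveNegativeSampler  # assumed path

sampler = BalancedPositiveNegativeSampler(batch_size_per_image=8, positive_fraction=0.5)
# One image: -1 = ignore, 0 = negative, >= 1 = matched to a ground-truth box.
matched_idxs = [torch.tensor([-1, 0, 0, 2, 1, 0, 0, 3, 0, 0])]
pos_masks, neg_masks = sampler(matched_idxs)
# Only 3 positives exist, so 5 negatives are drawn to fill the batch of 8.
print(int(pos_masks[0].sum()), int(neg_masks[0].sum()))  # 3 5
```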
+@torch.jit._script_if_tracing
+def encode_boxes(reference_boxes: Tensor, proposals: Tensor, weights: Tensor) -> Tensor:
+    """
+    Encode a set of proposals with respect to some
+    reference boxes
+
+    Args:
+        reference_boxes (Tensor): reference boxes
+        proposals (Tensor): boxes to be encoded
+        weights (Tensor[4]): the weights for ``(x, y, w, h)``
+    """
+
+    # perform some unpacking to make it JIT-fusion friendly
+    wx = weights[0]
+    wy = weights[1]
+    ww = weights[2]
+    wh = weights[3]
+
+    proposals_x1 = proposals[:, 0].unsqueeze(1)
+    proposals_y1 = proposals[:, 1].unsqueeze(1)
+    proposals_x2 = proposals[:, 2].unsqueeze(1)
+    proposals_y2 = proposals[:, 3].unsqueeze(1)
+
+    reference_boxes_x1 = reference_boxes[:, 0].unsqueeze(1)
+    reference_boxes_y1 = reference_boxes[:, 1].unsqueeze(1)
+    reference_boxes_x2 = reference_boxes[:, 2].unsqueeze(1)
+    reference_boxes_y2 = reference_boxes[:, 3].unsqueeze(1)
+
+    # implementation starts here
+    ex_widths = proposals_x2 - proposals_x1
+    ex_heights = proposals_y2 - proposals_y1
+    ex_ctr_x = proposals_x1 + 0.5 * ex_widths
+    ex_ctr_y = proposals_y1 + 0.5 * ex_heights
+
+    gt_widths = reference_boxes_x2 - reference_boxes_x1
+    gt_heights = reference_boxes_y2 - reference_boxes_y1
+    gt_ctr_x = reference_boxes_x1 + 0.5 * gt_widths
+    gt_ctr_y = reference_boxes_y1 + 0.5 * gt_heights
+
+    targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
+    targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
+    targets_dw = ww * torch.log(gt_widths / ex_widths)
+    targets_dh = wh * torch.log(gt_heights / ex_heights)
+
+    targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
+    return targets
+
+
+class BoxCoder:
+    """
+    This class encodes and decodes a set of bounding boxes into
+    the representation used for training the regressors.
+    """
+
+    def __init__(
+        self, weights: Tuple[float, float, float, float], bbox_xform_clip: float = math.log(1000.0 / 16)
+    ) -> None:
+        """
+        Args:
+            weights (4-element tuple)
+            bbox_xform_clip (float)
+        """
+        self.weights = weights
+        self.bbox_xform_clip = bbox_xform_clip
+
+    def encode(self, reference_boxes: List[Tensor], proposals: List[Tensor]) -> List[Tensor]:
+        boxes_per_image = [len(b) for b in reference_boxes]
+        reference_boxes = torch.cat(reference_boxes, dim=0)
+        proposals = torch.cat(proposals, dim=0)
+        targets = self.encode_single(reference_boxes, proposals)
+        return targets.split(boxes_per_image, 0)
+
+    def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
+        """
+        Encode a set of proposals with respect to some
+        reference boxes
+
+        Args:
+            reference_boxes (Tensor): reference boxes
+            proposals (Tensor): boxes to be encoded
+        """
+        dtype = reference_boxes.dtype
+        device = reference_boxes.device
+        weights = torch.as_tensor(self.weights, dtype=dtype, device=device)
+        targets = encode_boxes(reference_boxes, proposals, weights)
+
+        return targets
+
+    def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
+        torch._assert(
+            isinstance(boxes, (list, tuple)),
+            "This function expects boxes of type list or tuple.",
+        )
+        torch._assert(
+            isinstance(rel_codes, torch.Tensor),
+            "This function expects rel_codes of type torch.Tensor.",
+        )
+        boxes_per_image = [b.size(0) for b in boxes]
+        concat_boxes = torch.cat(boxes, dim=0)
+        box_sum = 0
+        for val in boxes_per_image:
+            box_sum += val
+        if box_sum > 0:
+            rel_codes = rel_codes.reshape(box_sum, -1)
+        pred_boxes = self.decode_single(rel_codes, concat_boxes)
+        if box_sum > 0:
+            pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
+        return pred_boxes
+
+    def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
+        """
+        From a set of original boxes and encoded relative box offsets,
+        get the decoded boxes.
+
+        Args:
+            rel_codes (Tensor): encoded boxes
+            boxes (Tensor): reference boxes.
+        """
+
+        boxes = boxes.to(rel_codes.dtype)
+
+        widths = boxes[:, 2] - boxes[:, 0]
+        heights = boxes[:, 3] - boxes[:, 1]
+        ctr_x = boxes[:, 0] + 0.5 * widths
+        ctr_y = boxes[:, 1] + 0.5 * heights
+
+        wx, wy, ww, wh = self.weights
+        dx = rel_codes[:, 0::4] / wx
+        dy = rel_codes[:, 1::4] / wy
+        dw = rel_codes[:, 2::4] / ww
+        dh = rel_codes[:, 3::4] / wh
+
+        # Prevent sending too large values into torch.exp()
+        dw = torch.clamp(dw, max=self.bbox_xform_clip)
+        dh = torch.clamp(dh, max=self.bbox_xform_clip)
+
+        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
+        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
+        pred_w = torch.exp(dw) * widths[:, None]
+        pred_h = torch.exp(dh) * heights[:, None]
+
+        # Distance from center to box's corner.
+        c_to_c_h = torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
+        c_to_c_w = torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w
+
+        pred_boxes1 = pred_ctr_x - c_to_c_w
+        pred_boxes2 = pred_ctr_y - c_to_c_h
+        pred_boxes3 = pred_ctr_x + c_to_c_w
+        pred_boxes4 = pred_ctr_y + c_to_c_h
+        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1)
+        return pred_boxes
+
+
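A quick round-trip sketch for `BoxCoder`: deltas produced by `encode_single` should decode back to the ground-truth boxes. The boxes below are arbitrary and the import path is assumed.

```python
import torch
from libs.vision_libs.models.detection._utils import BoxCoder  # assumed path

coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
proposals = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 20.0, 15.0]])
gt_boxes = torch.tensor([[1.0, 1.0, 11.0, 9.0], [4.0, 6.0, 22.0, 16.0]])

deltas = coder.encode_single(gt_boxes, proposals)           # (dx, dy, dw, dh) per box
decoded = coder.decode(deltas, [proposals])                 # shape [N, 1, 4]
print(torch.allclose(decoded[:, 0], gt_boxes, atol=1e-4))   # True
```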
+class BoxLinearCoder:
+    """
+    The linear box-to-box transform defined in FCOS. The transformation is parameterized
+    by the distance from the center of (square) src box to 4 edges of the target box.
+    """
+
+    def __init__(self, normalize_by_size: bool = True) -> None:
+        """
+        Args:
+            normalize_by_size (bool): normalize deltas by the size of src (anchor) boxes.
+        """
+        self.normalize_by_size = normalize_by_size
+
+    def encode(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
+        """
+        Encode a set of proposals with respect to some reference boxes
+
+        Args:
+            reference_boxes (Tensor): reference boxes
+            proposals (Tensor): boxes to be encoded
+
+        Returns:
+            Tensor: the encoded relative box offsets that can be used to
+            decode the boxes.
+
+        """
+
+        # get the center of reference_boxes
+        reference_boxes_ctr_x = 0.5 * (reference_boxes[..., 0] + reference_boxes[..., 2])
+        reference_boxes_ctr_y = 0.5 * (reference_boxes[..., 1] + reference_boxes[..., 3])
+
+        # get box regression transformation deltas
+        target_l = reference_boxes_ctr_x - proposals[..., 0]
+        target_t = reference_boxes_ctr_y - proposals[..., 1]
+        target_r = proposals[..., 2] - reference_boxes_ctr_x
+        target_b = proposals[..., 3] - reference_boxes_ctr_y
+
+        targets = torch.stack((target_l, target_t, target_r, target_b), dim=-1)
+
+        if self.normalize_by_size:
+            reference_boxes_w = reference_boxes[..., 2] - reference_boxes[..., 0]
+            reference_boxes_h = reference_boxes[..., 3] - reference_boxes[..., 1]
+            reference_boxes_size = torch.stack(
+                (reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=-1
+            )
+            targets = targets / reference_boxes_size
+        return targets
+
+    def decode(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
+
+        """
+        From a set of original boxes and encoded relative box offsets,
+        get the decoded boxes.
+
+        Args:
+            rel_codes (Tensor): encoded boxes
+            boxes (Tensor): reference boxes.
+
+        Returns:
+            Tensor: the predicted boxes with the encoded relative box offsets.
+
+        .. note::
+            This method assumes that ``rel_codes`` and ``boxes`` have same size for 0th dimension. i.e. ``len(rel_codes) == len(boxes)``.
+
+        """
+
+        boxes = boxes.to(dtype=rel_codes.dtype)
+
+        ctr_x = 0.5 * (boxes[..., 0] + boxes[..., 2])
+        ctr_y = 0.5 * (boxes[..., 1] + boxes[..., 3])
+
+        if self.normalize_by_size:
+            boxes_w = boxes[..., 2] - boxes[..., 0]
+            boxes_h = boxes[..., 3] - boxes[..., 1]
+
+            list_box_size = torch.stack((boxes_w, boxes_h, boxes_w, boxes_h), dim=-1)
+            rel_codes = rel_codes * list_box_size
+
+        pred_boxes1 = ctr_x - rel_codes[..., 0]
+        pred_boxes2 = ctr_y - rel_codes[..., 1]
+        pred_boxes3 = ctr_x + rel_codes[..., 2]
+        pred_boxes4 = ctr_y + rel_codes[..., 3]
+
+        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=-1)
+        return pred_boxes
+
+
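A matching sketch for `BoxLinearCoder`: per the code above, the first argument to `encode` supplies the (anchor) centres and the second the target edges, so the offsets are simply the left/top/right/bottom distances. Normalisation is disabled here only to keep the numbers readable.

```python
import torch
from libs.vision_libs.models.detection._utils import BoxLinearCoder  # assumed path

coder = BoxLinearCoder(normalize_by_size=False)
anchors = torch.tensor([[0.0, 0.0, 10.0, 10.0]])   # centre at (5, 5)
gt = torch.tensor([[2.0, 1.0, 9.0, 8.0]])

rel = coder.encode(anchors, gt)      # tensor([[3., 4., 4., 3.]]) = (l, t, r, b)
back = coder.decode(rel, anchors)    # recovers gt
print(rel, torch.allclose(back, gt))
```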
+class Matcher:
+    """
+    This class assigns to each predicted "element" (e.g., a box) a ground-truth
+    element. Each predicted element will have exactly zero or one matches; each
+    ground-truth element may be assigned to zero or more predicted elements.
+
+    Matching is based on the MxN match_quality_matrix, that characterizes how well
+    each (ground-truth, predicted)-pair match. For example, if the elements are
+    boxes, the matrix may contain box IoU overlap values.
+
+    The matcher returns a tensor of size N containing the index of the ground-truth
+    element m that matches to prediction n. If there is no match, a negative value
+    is returned.
+    """
+
+    BELOW_LOW_THRESHOLD = -1
+    BETWEEN_THRESHOLDS = -2
+
+    __annotations__ = {
+        "BELOW_LOW_THRESHOLD": int,
+        "BETWEEN_THRESHOLDS": int,
+    }
+
+    def __init__(self, high_threshold: float, low_threshold: float, allow_low_quality_matches: bool = False) -> None:
+        """
+        Args:
+            high_threshold (float): quality values greater than or equal to
+                this value are candidate matches.
+            low_threshold (float): a lower quality threshold used to stratify
+                matches into three levels:
+                1) matches >= high_threshold
+                2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold)
+                3) BELOW_LOW_THRESHOLD matches in [0, low_threshold)
+            allow_low_quality_matches (bool): if True, produce additional matches
+                for predictions that have only low-quality match candidates. See
+                set_low_quality_matches_ for more details.
+        """
+        self.BELOW_LOW_THRESHOLD = -1
+        self.BETWEEN_THRESHOLDS = -2
+        torch._assert(low_threshold <= high_threshold, "low_threshold should be <= high_threshold")
+        self.high_threshold = high_threshold
+        self.low_threshold = low_threshold
+        self.allow_low_quality_matches = allow_low_quality_matches
+
+    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
+        """
+        Args:
+            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
+            pairwise quality between M ground-truth elements and N predicted elements.
+
+        Returns:
+            matches (Tensor[int64]): an N tensor where N[i] is a matched gt in
+            [0, M - 1] or a negative value indicating that prediction i could not
+            be matched.
+        """
+        if match_quality_matrix.numel() == 0:
+            # empty targets or proposals not supported during training
+            if match_quality_matrix.shape[0] == 0:
+                raise ValueError("No ground-truth boxes available for one of the images during training")
+            else:
+                raise ValueError("No proposal boxes available for one of the images during training")
+
+        # match_quality_matrix is M (gt) x N (predicted)
+        # Max over gt elements (dim 0) to find best gt candidate for each prediction
+        matched_vals, matches = match_quality_matrix.max(dim=0)
+        if self.allow_low_quality_matches:
+            all_matches = matches.clone()
+        else:
+            all_matches = None  # type: ignore[assignment]
+
+        # Assign candidate matches with low quality to negative (unassigned) values
+        below_low_threshold = matched_vals < self.low_threshold
+        between_thresholds = (matched_vals >= self.low_threshold) & (matched_vals < self.high_threshold)
+        matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD
+        matches[between_thresholds] = self.BETWEEN_THRESHOLDS
+
+        if self.allow_low_quality_matches:
+            if all_matches is None:
+                torch._assert(False, "all_matches should not be None")
+            else:
+                self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)
+
+        return matches
+
+    def set_low_quality_matches_(self, matches: Tensor, all_matches: Tensor, match_quality_matrix: Tensor) -> None:
+        """
+        Produce additional matches for predictions that have only low-quality matches.
+        Specifically, for each ground-truth find the set of predictions that have
+        maximum overlap with it (including ties); for each prediction in that set, if
+        it is unmatched, then match it to the ground-truth with which it has the highest
+        quality value.
+        """
+        # For each gt, find the prediction with which it has the highest quality
+        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
+        # Find the highest quality match available, even if it is low, including ties
+        gt_pred_pairs_of_highest_quality = torch.where(match_quality_matrix == highest_quality_foreach_gt[:, None])
+        # Example gt_pred_pairs_of_highest_quality:
+        # (tensor([0, 1, 1, 2, 2, 3, 3, 4, 5, 5]),
+        #  tensor([39796, 32055, 32070, 39190, 40255, 40390, 41455, 45470, 45325, 46390]))
+        # Each element in the first tensor is a gt index, and each element in second tensor is a prediction index
+        # Note how gt items 1, 2, 3, and 5 each have two ties
+
+        pred_inds_to_update = gt_pred_pairs_of_highest_quality[1]
+        matches[pred_inds_to_update] = all_matches[pred_inds_to_update]
+
+
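A toy run of the matcher above, with a hand-written 2x3 quality (IoU-like) matrix; the thresholds are arbitrary and the import path is assumed.

```python
import torch
from libs.vision_libs.models.detection._utils import Matcher  # assumed path

matcher = Matcher(high_threshold=0.7, low_threshold=0.3, allow_low_quality_matches=False)
quality = torch.tensor([[0.9, 0.2, 0.4],    # gt 0 vs predictions 0..2
                        [0.1, 0.8, 0.5]])   # gt 1 vs predictions 0..2
print(matcher(quality))  # tensor([ 0,  1, -2]); prediction 2 lands between the thresholds
```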
+class SSDMatcher(Matcher):
+    def __init__(self, threshold: float) -> None:
+        super().__init__(threshold, threshold, allow_low_quality_matches=False)
+
+    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
+        matches = super().__call__(match_quality_matrix)
+
+        # For each gt, find the prediction with which it has the highest quality
+        _, highest_quality_pred_foreach_gt = match_quality_matrix.max(dim=1)
+        matches[highest_quality_pred_foreach_gt] = torch.arange(
+            highest_quality_pred_foreach_gt.size(0), dtype=torch.int64, device=highest_quality_pred_foreach_gt.device
+        )
+
+        return matches
+
+
+def overwrite_eps(model: nn.Module, eps: float) -> None:
+    """
+    This method overwrites the default eps values of all the
+    FrozenBatchNorm2d layers of the model with the provided value.
+    This is necessary to address the BC-breaking change introduced
+    by the bug-fix at pytorch/vision#2933. The overwrite is applied
+    only when the pretrained weights are loaded to maintain compatibility
+    with previous versions.
+
+    Args:
+        model (nn.Module): The model on which we perform the overwrite.
+        eps (float): The new value of eps.
+    """
+    for module in model.modules():
+        if isinstance(module, FrozenBatchNorm2d):
+            module.eps = eps
+
+
+def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
+    """
+    This method retrieves the number of output channels of a specific model.
+
+    Args:
+        model (nn.Module): The model for which we estimate the out_channels.
+            It should return a single Tensor or an OrderedDict[Tensor].
+        size (Tuple[int, int]): The size (wxh) of the input.
+
+    Returns:
+        out_channels (List[int]): A list of the output channels of the model.
+    """
+    in_training = model.training
+    model.eval()
+
+    with torch.no_grad():
+        # Use dummy data to retrieve the feature map sizes to avoid hard-coding their values
+        device = next(model.parameters()).device
+        tmp_img = torch.zeros((1, 3, size[1], size[0]), device=device)
+        features = model(tmp_img)
+        if isinstance(features, torch.Tensor):
+            features = OrderedDict([("0", features)])
+        out_channels = [x.size(1) for x in features.values()]
+
+    if in_training:
+        model.train()
+
+    return out_channels
+
+
+@torch.jit.unused
+def _fake_cast_onnx(v: Tensor) -> int:
+    return v  # type: ignore[return-value]
+
+
+def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int:
+    """
+    ONNX spec requires the k-value to be less than or equal to the number of inputs along
+    provided dim. Certain models use the number of elements along a particular axis instead of K
+    if K exceeds the number of elements along that axis. Previously, python's min() function was
+    used to determine whether to use the provided k-value or the specified dim axis value.
+
+    However, in cases where the model is being exported in tracing mode, python min() is
+    static, causing the model to be traced incorrectly and eventually fail at the topk node.
+    In order to avoid this situation, in tracing mode, torch.min() is used instead.
+
+    Args:
+        input (Tensor): The original input tensor.
+        orig_kval (int): The provided k-value.
+        axis(int): Axis along which we retrieve the input size.
+
+    Returns:
+        min_kval (int): Appropriately selected k-value.
+    """
+    if not torch.jit.is_tracing():
+        return min(orig_kval, input.size(axis))
+    axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0)
+    min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
+    return _fake_cast_onnx(min_kval)
+
+
+def _box_loss(
+    type: str,
+    box_coder: BoxCoder,
+    anchors_per_image: Tensor,
+    matched_gt_boxes_per_image: Tensor,
+    bbox_regression_per_image: Tensor,
+    cnf: Optional[Dict[str, float]] = None,
+) -> Tensor:
+    torch._assert(type in ["l1", "smooth_l1", "ciou", "diou", "giou"], f"Unsupported loss: {type}")
+
+    if type == "l1":
+        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
+        return F.l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
+    elif type == "smooth_l1":
+        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
+        beta = cnf["beta"] if cnf is not None and "beta" in cnf else 1.0
+        return F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum", beta=beta)
+    else:
+        bbox_per_image = box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
+        eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7
+        if type == "ciou":
+            return complete_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
+        if type == "diou":
+            return distance_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
+        # otherwise giou
+        return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
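A sketch of driving `_box_loss` with a GIoU objective. The anchors, ground-truth boxes and the faked regression output are illustrative; the import path assumes the vendored layout.

```python
import torch
from libs.vision_libs.models.detection import _utils as det_utils  # assumed path

box_coder = det_utils.BoxCoder(weights=(10.0, 10.0, 5.0, 5.0))
anchors = torch.tensor([[0.0, 0.0, 10.0, 10.0], [10.0, 10.0, 30.0, 30.0]])
gt = torch.tensor([[1.0, 1.0, 11.0, 11.0], [12.0, 9.0, 28.0, 31.0]])

# For "giou"/"diou"/"ciou" the raw deltas are decoded back to boxes internally,
# so we pass (noisy) encoded deltas as a stand-in for the head's output.
pred_deltas = box_coder.encode_single(gt, anchors) + 0.05 * torch.randn(2, 4)
loss = det_utils._box_loss("giou", box_coder, anchors, gt, pred_deltas)
print(loss)  # scalar sum-reduced loss
```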

+ 268 - 0
libs/vision_libs/models/detection/anchor_utils.py

@@ -0,0 +1,268 @@
+import math
+from typing import List, Optional
+
+import torch
+from torch import nn, Tensor
+
+from .image_list import ImageList
+
+
+class AnchorGenerator(nn.Module):
+    """
+    Module that generates anchors for a set of feature maps and
+    image sizes.
+
+    The module supports computing anchors at multiple sizes and aspect ratios
+    per feature map. This module assumes aspect ratio = height / width for
+    each anchor.
+
+    sizes and aspect_ratios should have the same number of elements, and it should
+    correspond to the number of feature maps.
+
+    sizes[i] and aspect_ratios[i] can have an arbitrary number of elements,
+    and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors
+    per spatial location for feature map i.
+
+    Args:
+        sizes (Tuple[Tuple[int]]):
+        aspect_ratios (Tuple[Tuple[float]]):
+    """
+
+    __annotations__ = {
+        "cell_anchors": List[torch.Tensor],
+    }
+
+    def __init__(
+        self,
+        sizes=((128, 256, 512),),
+        aspect_ratios=((0.5, 1.0, 2.0),),
+    ):
+        super().__init__()
+
+        if not isinstance(sizes[0], (list, tuple)):
+            # TODO change this
+            sizes = tuple((s,) for s in sizes)
+        if not isinstance(aspect_ratios[0], (list, tuple)):
+            aspect_ratios = (aspect_ratios,) * len(sizes)
+
+        self.sizes = sizes
+        self.aspect_ratios = aspect_ratios
+        self.cell_anchors = [
+            self.generate_anchors(size, aspect_ratio) for size, aspect_ratio in zip(sizes, aspect_ratios)
+        ]
+
+    # TODO: https://github.com/pytorch/pytorch/issues/26792
+    # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
+    # (scales, aspect_ratios) are usually an element of zip(self.sizes, self.aspect_ratios)
+    # This method assumes aspect ratio = height / width for an anchor.
+    def generate_anchors(
+        self,
+        scales: List[int],
+        aspect_ratios: List[float],
+        dtype: torch.dtype = torch.float32,
+        device: torch.device = torch.device("cpu"),
+    ) -> Tensor:
+        scales = torch.as_tensor(scales, dtype=dtype, device=device)
+        aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
+        h_ratios = torch.sqrt(aspect_ratios)
+        w_ratios = 1 / h_ratios
+
+        ws = (w_ratios[:, None] * scales[None, :]).view(-1)
+        hs = (h_ratios[:, None] * scales[None, :]).view(-1)
+
+        base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
+        return base_anchors.round()
+
+    def set_cell_anchors(self, dtype: torch.dtype, device: torch.device):
+        self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) for cell_anchor in self.cell_anchors]
+
+    def num_anchors_per_location(self) -> List[int]:
+        return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]
+
+    # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
+    # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
+    def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]:
+        anchors = []
+        cell_anchors = self.cell_anchors
+        torch._assert(cell_anchors is not None, "cell_anchors should not be None")
+        torch._assert(
+            len(grid_sizes) == len(strides) == len(cell_anchors),
+            "Anchors should be Tuple[Tuple[int]] because each feature "
+            "map could potentially have different sizes and aspect ratios. "
+            "There needs to be a match between the number of "
+            "feature maps passed and the number of sizes / aspect ratios specified.",
+        )
+
+        for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors):
+            grid_height, grid_width = size
+            stride_height, stride_width = stride
+            device = base_anchors.device
+
+            # For output anchor, compute [x_center, y_center, x_center, y_center]
+            shifts_x = torch.arange(0, grid_width, dtype=torch.int32, device=device) * stride_width
+            shifts_y = torch.arange(0, grid_height, dtype=torch.int32, device=device) * stride_height
+            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
+            shift_x = shift_x.reshape(-1)
+            shift_y = shift_y.reshape(-1)
+            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
+
+            # For every (base anchor, output anchor) pair,
+            # offset each zero-centered base anchor by the center of the output anchor.
+            anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4))
+
+        return anchors
+
+    def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
+        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
+        image_size = image_list.tensors.shape[-2:]
+        dtype, device = feature_maps[0].dtype, feature_maps[0].device
+        strides = [
+            [
+                torch.empty((), dtype=torch.int64, device=device).fill_(image_size[0] // g[0]),
+                torch.empty((), dtype=torch.int64, device=device).fill_(image_size[1] // g[1]),
+            ]
+            for g in grid_sizes
+        ]
+        self.set_cell_anchors(dtype, device)
+        anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides)
+        anchors: List[List[torch.Tensor]] = []
+        for _ in range(len(image_list.image_sizes)):
+            anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps]
+            anchors.append(anchors_in_image)
+        anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
+        return anchors
+
+
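A self-contained sketch of the anchor generator above on two fake feature maps; the sizes, ratios and tensor shapes are illustrative and the import paths assume the vendored layout.

```python
import torch
from libs.vision_libs.models.detection.anchor_utils import AnchorGenerator  # assumed path
from libs.vision_libs.models.detection.image_list import ImageList          # assumed path

anchor_gen = AnchorGenerator(
    sizes=((32, 64), (128,)),                  # one tuple of sizes per feature map
    aspect_ratios=((0.5, 1.0, 2.0),) * 2,
)
images = torch.randn(2, 3, 224, 224)
image_list = ImageList(images, [(224, 224), (224, 224)])
feature_maps = [torch.randn(2, 256, 28, 28), torch.randn(2, 256, 14, 14)]

anchors = anchor_gen(image_list, feature_maps)  # one Tensor[N, 4] per image
print([a.shape for a in anchors])               # 28*28*6 + 14*14*3 = 5292 anchors each
```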
+class DefaultBoxGenerator(nn.Module):
+    """
+    This module generates the default boxes of SSD for a set of feature maps and image sizes.
+
+    Args:
+        aspect_ratios (List[List[int]]): A list with all the aspect ratios used in each feature map.
+        min_ratio (float): The minimum scale :math:`\text{s}_{\text{min}}` of the default boxes used in the estimation
+            of the scales of each feature map. It is used only if the ``scales`` parameter is not provided.
+        max_ratio (float): The maximum scale :math:`\text{s}_{\text{max}}` of the default boxes used in the estimation
+            of the scales of each feature map. It is used only if the ``scales`` parameter is not provided.
+        scales (List[float], optional): The scales of the default boxes. If not provided it will be estimated using
+            the ``min_ratio`` and ``max_ratio`` parameters.
+        steps (List[int], optional): It's a hyper-parameter that affects the tiling of default boxes. If not provided
+            it will be estimated from the data.
+        clip (bool): Whether the standardized values of default boxes should be clipped between 0 and 1. The clipping
+            is applied while the boxes are encoded in format ``(cx, cy, w, h)``.
+    """
+
+    def __init__(
+        self,
+        aspect_ratios: List[List[int]],
+        min_ratio: float = 0.15,
+        max_ratio: float = 0.9,
+        scales: Optional[List[float]] = None,
+        steps: Optional[List[int]] = None,
+        clip: bool = True,
+    ):
+        super().__init__()
+        if steps is not None and len(aspect_ratios) != len(steps):
+            raise ValueError("aspect_ratios and steps should have the same length")
+        self.aspect_ratios = aspect_ratios
+        self.steps = steps
+        self.clip = clip
+        num_outputs = len(aspect_ratios)
+
+        # Estimation of default boxes scales
+        if scales is None:
+            if num_outputs > 1:
+                range_ratio = max_ratio - min_ratio
+                self.scales = [min_ratio + range_ratio * k / (num_outputs - 1.0) for k in range(num_outputs)]
+                self.scales.append(1.0)
+            else:
+                self.scales = [min_ratio, max_ratio]
+        else:
+            self.scales = scales
+
+        self._wh_pairs = self._generate_wh_pairs(num_outputs)
+
+    def _generate_wh_pairs(
+        self, num_outputs: int, dtype: torch.dtype = torch.float32, device: torch.device = torch.device("cpu")
+    ) -> List[Tensor]:
+        _wh_pairs: List[Tensor] = []
+        for k in range(num_outputs):
+            # Adding the 2 default width-height pairs for aspect ratio 1 and scale s'k
+            s_k = self.scales[k]
+            s_prime_k = math.sqrt(self.scales[k] * self.scales[k + 1])
+            wh_pairs = [[s_k, s_k], [s_prime_k, s_prime_k]]
+
+            # Adding 2 pairs for each aspect ratio of the feature map k
+            for ar in self.aspect_ratios[k]:
+                sq_ar = math.sqrt(ar)
+                w = self.scales[k] * sq_ar
+                h = self.scales[k] / sq_ar
+                wh_pairs.extend([[w, h], [h, w]])
+
+            _wh_pairs.append(torch.as_tensor(wh_pairs, dtype=dtype, device=device))
+        return _wh_pairs
+
+    def num_anchors_per_location(self) -> List[int]:
+        # Estimate num of anchors based on aspect ratios: 2 default boxes + 2 * ratios of feature map.
+        return [2 + 2 * len(r) for r in self.aspect_ratios]
+
+    # Default Boxes calculation based on page 6 of SSD paper
+    def _grid_default_boxes(
+        self, grid_sizes: List[List[int]], image_size: List[int], dtype: torch.dtype = torch.float32
+    ) -> Tensor:
+        default_boxes = []
+        for k, f_k in enumerate(grid_sizes):
+            # Now add the default boxes for each width-height pair
+            if self.steps is not None:
+                x_f_k = image_size[1] / self.steps[k]
+                y_f_k = image_size[0] / self.steps[k]
+            else:
+                y_f_k, x_f_k = f_k
+
+            shifts_x = ((torch.arange(0, f_k[1]) + 0.5) / x_f_k).to(dtype=dtype)
+            shifts_y = ((torch.arange(0, f_k[0]) + 0.5) / y_f_k).to(dtype=dtype)
+            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
+            shift_x = shift_x.reshape(-1)
+            shift_y = shift_y.reshape(-1)
+
+            shifts = torch.stack((shift_x, shift_y) * len(self._wh_pairs[k]), dim=-1).reshape(-1, 2)
+            # Clipping the default boxes while the boxes are encoded in format (cx, cy, w, h)
+            _wh_pair = self._wh_pairs[k].clamp(min=0, max=1) if self.clip else self._wh_pairs[k]
+            wh_pairs = _wh_pair.repeat((f_k[0] * f_k[1]), 1)
+
+            default_box = torch.cat((shifts, wh_pairs), dim=1)
+
+            default_boxes.append(default_box)
+
+        return torch.cat(default_boxes, dim=0)
+
+    def __repr__(self) -> str:
+        s = (
+            f"{self.__class__.__name__}("
+            f"aspect_ratios={self.aspect_ratios}"
+            f", clip={self.clip}"
+            f", scales={self.scales}"
+            f", steps={self.steps}"
+            ")"
+        )
+        return s
+
+    def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
+        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
+        image_size = image_list.tensors.shape[-2:]
+        dtype, device = feature_maps[0].dtype, feature_maps[0].device
+        default_boxes = self._grid_default_boxes(grid_sizes, image_size, dtype=dtype)
+        default_boxes = default_boxes.to(device)
+
+        dboxes = []
+        x_y_size = torch.tensor([image_size[1], image_size[0]], device=default_boxes.device)
+        for _ in image_list.image_sizes:
+            dboxes_in_image = default_boxes
+            dboxes_in_image = torch.cat(
+                [
+                    (dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:]) * x_y_size,
+                    (dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:]) * x_y_size,
+                ],
+                -1,
+            )
+            dboxes.append(dboxes_in_image)
+        return dboxes
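An analogous sketch for `DefaultBoxGenerator` on a 300x300 SSD-style input; the feature-map shapes and ratios are illustrative and the import paths assume the vendored layout.

```python
import torch
from libs.vision_libs.models.detection.anchor_utils import DefaultBoxGenerator  # assumed path
from libs.vision_libs.models.detection.image_list import ImageList              # assumed path

dbox_gen = DefaultBoxGenerator(aspect_ratios=[[2], [2, 3]], min_ratio=0.2, max_ratio=0.9)
print(dbox_gen)                             # repr shows the estimated scales
print(dbox_gen.num_anchors_per_location())  # [4, 6]

images = torch.randn(1, 3, 300, 300)
image_list = ImageList(images, [(300, 300)])
feature_maps = [torch.randn(1, 512, 38, 38), torch.randn(1, 256, 19, 19)]
boxes = dbox_gen(image_list, feature_maps)  # per-image Tensor[N, 4] in (x1, y1, x2, y2)
print(boxes[0].shape)                       # 38*38*4 + 19*19*6 = 7942 default boxes
```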

+ 244 - 0
libs/vision_libs/models/detection/backbone_utils.py

@@ -0,0 +1,244 @@
+import warnings
+from typing import Callable, Dict, List, Optional, Union
+
+from torch import nn, Tensor
+from torchvision.ops import misc as misc_nn_ops
+from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool
+
+from .. import mobilenet, resnet
+from .._api import _get_enum_from_fn, WeightsEnum
+from .._utils import handle_legacy_interface, IntermediateLayerGetter
+
+
+class BackboneWithFPN(nn.Module):
+    """
+    Adds a FPN on top of a model.
+    Internally, it uses torchvision.models._utils.IntermediateLayerGetter to
+    extract a submodel that returns the feature maps specified in return_layers.
+    The same limitations of IntermediateLayerGetter apply here.
+    Args:
+        backbone (nn.Module)
+        return_layers (Dict[name, new_name]): a dict whose keys are the names of the
+            modules whose activations will be returned, and whose values are the
+            (user-specified) names under which those activations are returned.
+        in_channels_list (List[int]): number of channels for each feature map
+            that is returned, in the order they are present in the OrderedDict
+        out_channels (int): number of channels in the FPN.
+        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
+    Attributes:
+        out_channels (int): the number of channels in the FPN
+    """
+
+    def __init__(
+        self,
+        backbone: nn.Module,
+        return_layers: Dict[str, str],
+        in_channels_list: List[int],
+        out_channels: int,
+        extra_blocks: Optional[ExtraFPNBlock] = None,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+    ) -> None:
+        super().__init__()
+
+        if extra_blocks is None:
+            extra_blocks = LastLevelMaxPool()
+
+        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
+        self.fpn = FeaturePyramidNetwork(
+            in_channels_list=in_channels_list,
+            out_channels=out_channels,
+            extra_blocks=extra_blocks,
+            norm_layer=norm_layer,
+        )
+        self.out_channels = out_channels
+
+    def forward(self, x: Tensor) -> Dict[str, Tensor]:
+        x = self.body(x)
+        x = self.fpn(x)
+        return x
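
BackboneWithFPN simply chains an IntermediateLayerGetter with a FeaturePyramidNetwork, so every returned level carries out_channels channels. A minimal sketch of wiring it around a ResNet-50 trunk by hand (the _resnet_fpn_extractor helper below does the same and adds the freezing logic; the channel widths are the standard ResNet-50 stage outputs):

    import torch
    from torchvision.models import resnet50

    body = resnet50()
    return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
    in_channels_list = [256, 512, 1024, 2048]  # ResNet-50 stage output widths
    backbone = BackboneWithFPN(body, return_layers, in_channels_list, out_channels=256)

    feats = backbone(torch.rand(1, 3, 224, 224))
    print([(k, v.shape) for k, v in feats.items()])  # '0'..'3' plus the extra 'pool' level, all with 256 channels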
+
+
+@handle_legacy_interface(
+    weights=(
+        "pretrained",
+        lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
+    ),
+)
+def resnet_fpn_backbone(
+    *,
+    backbone_name: str,
+    weights: Optional[WeightsEnum],
+    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
+    trainable_layers: int = 3,
+    returned_layers: Optional[List[int]] = None,
+    extra_blocks: Optional[ExtraFPNBlock] = None,
+) -> BackboneWithFPN:
+    """
+    Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.
+
+    Examples::
+
+        >>> import torch
+        >>> from torchvision.models import ResNet50_Weights
+        >>> from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
+        >>> backbone = resnet_fpn_backbone(backbone_name='resnet50', weights=ResNet50_Weights.DEFAULT, trainable_layers=3)
+        >>> # get some dummy image
+        >>> x = torch.rand(1,3,64,64)
+        >>> # compute the output
+        >>> output = backbone(x)
+        >>> print([(k, v.shape) for k, v in output.items()])
+        >>> # returns
+        >>>   [('0', torch.Size([1, 256, 16, 16])),
+        >>>    ('1', torch.Size([1, 256, 8, 8])),
+        >>>    ('2', torch.Size([1, 256, 4, 4])),
+        >>>    ('3', torch.Size([1, 256, 2, 2])),
+        >>>    ('pool', torch.Size([1, 256, 1, 1]))]
+
+    Args:
+        backbone_name (string): resnet architecture. Possible values are 'resnet18', 'resnet34', 'resnet50',
+             'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'
+        weights (WeightsEnum, optional): The pretrained weights for the model
+        norm_layer (callable): it is recommended to use the default value. For details visit:
+            (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
+        trainable_layers (int): number of trainable (not frozen) layers starting from final block.
+            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
+        returned_layers (list of int): The layers of the network to return. Each entry must be in ``[1, 4]``.
+            By default, all layers are returned.
+        extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
+            be performed. It is expected to take the fpn features, the original
+            features and the names of the original features as input, and to return
+            a new list of feature maps and their corresponding names. By
+            default, a ``LastLevelMaxPool`` is used.
+    """
+    backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
+    return _resnet_fpn_extractor(backbone, trainable_layers, returned_layers, extra_blocks)
+
+
+def _resnet_fpn_extractor(
+    backbone: resnet.ResNet,
+    trainable_layers: int,
+    returned_layers: Optional[List[int]] = None,
+    extra_blocks: Optional[ExtraFPNBlock] = None,
+    norm_layer: Optional[Callable[..., nn.Module]] = None,
+) -> BackboneWithFPN:
+
+    # select layers that won't be frozen
+    if trainable_layers < 0 or trainable_layers > 5:
+        raise ValueError(f"Trainable layers should be in the range [0,5], got {trainable_layers}")
+    layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
+    if trainable_layers == 5:
+        layers_to_train.append("bn1")
+    for name, parameter in backbone.named_parameters():
+        if all([not name.startswith(layer) for layer in layers_to_train]):
+            parameter.requires_grad_(False)
+
+    if extra_blocks is None:
+        extra_blocks = LastLevelMaxPool()
+
+    if returned_layers is None:
+        returned_layers = [1, 2, 3, 4]
+    if min(returned_layers) <= 0 or max(returned_layers) >= 5:
+        raise ValueError(f"Each returned layer should be in the range [1,4]. Got {returned_layers}")
+    return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)}
+
+    in_channels_stage2 = backbone.inplanes // 8
+    in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
+    out_channels = 256
+    return BackboneWithFPN(
+        backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
+    )
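
As a worked example of the channel arithmetic above: for resnet50, backbone.inplanes is 2048 after construction, so with the default returned_layers of [1, 2, 3, 4]:

    # in_channels_stage2 = 2048 // 8 = 256
    # in_channels_list   = [256 * 2 ** (i - 1) for i in [1, 2, 3, 4]]  # -> [256, 512, 1024, 2048]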
+
+
+def _validate_trainable_layers(
+    is_trained: bool,
+    trainable_backbone_layers: Optional[int],
+    max_value: int,
+    default_value: int,
+) -> int:
+    # don't freeze any layers if pretrained model or backbone is not used
+    if not is_trained:
+        if trainable_backbone_layers is not None:
+            warnings.warn(
+                "Changing trainable_backbone_layers has no effect if "
+                "neither pretrained nor pretrained_backbone have been set to True, "
+                f"falling back to trainable_backbone_layers={max_value} so that all layers are trainable"
+            )
+        trainable_backbone_layers = max_value
+
+    # by default freeze first blocks
+    if trainable_backbone_layers is None:
+        trainable_backbone_layers = default_value
+    if trainable_backbone_layers < 0 or trainable_backbone_layers > max_value:
+        raise ValueError(
+            f"Trainable backbone layers should be in the range [0,{max_value}], got {trainable_backbone_layers} "
+        )
+    return trainable_backbone_layers
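
_validate_trainable_layers centralizes the "how much of the backbone to freeze" decision. A short sketch of the two branches (the argument values are illustrative):

    # Not pretrained: an explicit request is ignored (with a warning) and all layers stay trainable.
    _validate_trainable_layers(is_trained=False, trainable_backbone_layers=2,
                               max_value=5, default_value=3)  # -> 5

    # Pretrained, no explicit request: fall back to the default.
    _validate_trainable_layers(is_trained=True, trainable_backbone_layers=None,
                               max_value=5, default_value=3)  # -> 3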
+
+
+@handle_legacy_interface(
+    weights=(
+        "pretrained",
+        lambda kwargs: _get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
+    ),
+)
+def mobilenet_backbone(
+    *,
+    backbone_name: str,
+    weights: Optional[WeightsEnum],
+    fpn: bool,
+    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
+    trainable_layers: int = 2,
+    returned_layers: Optional[List[int]] = None,
+    extra_blocks: Optional[ExtraFPNBlock] = None,
+) -> nn.Module:
+    backbone = mobilenet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
+    return _mobilenet_extractor(backbone, fpn, trainable_layers, returned_layers, extra_blocks)
+
+
+def _mobilenet_extractor(
+    backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3],
+    fpn: bool,
+    trainable_layers: int,
+    returned_layers: Optional[List[int]] = None,
+    extra_blocks: Optional[ExtraFPNBlock] = None,
+    norm_layer: Optional[Callable[..., nn.Module]] = None,
+) -> nn.Module:
+    backbone = backbone.features
+    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
+    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
+    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
+    num_stages = len(stage_indices)
+
+    # find the index of the layer from which we won't freeze
+    if trainable_layers < 0 or trainable_layers > num_stages:
+        raise ValueError(f"Trainable layers should be in the range [0,{num_stages}], got {trainable_layers} ")
+    freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
+
+    for b in backbone[:freeze_before]:
+        for parameter in b.parameters():
+            parameter.requires_grad_(False)
+
+    out_channels = 256
+    if fpn:
+        if extra_blocks is None:
+            extra_blocks = LastLevelMaxPool()
+
+        if returned_layers is None:
+            returned_layers = [num_stages - 2, num_stages - 1]
+        if min(returned_layers) < 0 or max(returned_layers) >= num_stages:
+            raise ValueError(f"Each returned layer should be in the range [0,{num_stages - 1}], got {returned_layers} ")
+        return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)}
+
+        in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
+        return BackboneWithFPN(
+            backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
+        )
+    else:
+        m = nn.Sequential(
+            backbone,
+            # depthwise linear combination of channels to reduce their size
+            nn.Conv2d(backbone[-1].out_channels, out_channels, 1),
+        )
+        m.out_channels = out_channels  # type: ignore[assignment]
+        return m
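
A minimal usage sketch for the MobileNet path, assuming the upstream torchvision import location (this repository vendors the same module under libs/vision_libs, so the local import path may differ):

    import torch
    from torchvision.models.detection.backbone_utils import mobilenet_backbone

    # fpn=True feeds the two deepest strided stages through an FPN, yielding
    # keys '0', '1' and the extra 'pool' level, all with 256 channels.
    # Pass a weights enum (e.g. MobileNet_V3_Large_Weights.DEFAULT) for pretrained weights.
    backbone = mobilenet_backbone(
        backbone_name="mobilenet_v3_large",
        weights=None,
        fpn=True,
        trainable_layers=2,
    )
    out = backbone(torch.rand(1, 3, 320, 320))
    print([(k, v.shape) for k, v in out.items()])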

Some files were not shown because too many files changed in this diff