# _dataset_wrapper.py

# type: ignore

from __future__ import annotations

import collections.abc
import contextlib
from collections import defaultdict
from copy import copy

import torch

from torchvision import datasets, tv_tensors
from torchvision.transforms.v2 import functional as F

__all__ = ["wrap_dataset_for_transforms_v2"]


def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
    """Wrap a ``torchvision.dataset`` for usage with :mod:`torchvision.transforms.v2`.

    Example:
        >>> dataset = torchvision.datasets.CocoDetection(...)
        >>> dataset = wrap_dataset_for_transforms_v2(dataset)
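
        ``target_keys`` (see below) controls which dictionary keys are returned for detection-style
        targets; the set here is only an illustrative choice for a COCO-style dataset:

        >>> dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"boxes", "labels", "masks"})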
    .. note::

        For now, only the most popular datasets are supported. Furthermore, the wrapper only supports dataset
        configurations that are fully supported by ``torchvision.transforms.v2``. If you encounter an error prompting you
        to raise an issue to ``torchvision`` for a dataset or configuration that you need, please do so.

    The dataset samples are wrapped according to the description below.

    Special cases:

        * :class:`~torchvision.datasets.CocoDetection`: Instead of returning the target as list of dicts, the wrapper
          returns a dict of lists. In addition, the key-value-pairs ``"boxes"`` (in ``XYXY`` coordinate format),
          ``"masks"`` and ``"labels"`` are added and wrap the data in the corresponding ``torchvision.tv_tensors``.
          The original keys are preserved. If ``target_keys`` is omitted, returns only the values for the
          ``"image_id"``, ``"boxes"``, and ``"labels"``.
        * :class:`~torchvision.datasets.VOCDetection`: The key-value-pairs ``"boxes"`` and ``"labels"`` are added to
          the target and wrap the data in the corresponding ``torchvision.tv_tensors``. The original keys are
          preserved. If ``target_keys`` is omitted, returns only the values for the ``"boxes"`` and ``"labels"``.
        * :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to the ``XYXY``
          coordinate format and wrapped into a :class:`~torchvision.tv_tensors.BoundingBoxes` tv_tensor.
        * :class:`~torchvision.datasets.Kitti`: Instead of returning the target as list of dicts, the wrapper returns a
          dict of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and wrap the data
          in the corresponding ``torchvision.tv_tensors``. The original keys are preserved. If ``target_keys`` is
          omitted, returns only the values for the ``"boxes"`` and ``"labels"``.
        * :class:`~torchvision.datasets.OxfordIIITPet`: The target for ``target_type="segmentation"`` is wrapped into a
          :class:`~torchvision.tv_tensors.Mask` tv_tensor.
        * :class:`~torchvision.datasets.Cityscapes`: The target for ``target_type="semantic"`` is wrapped into a
          :class:`~torchvision.tv_tensors.Mask` tv_tensor. The target for ``target_type="instance"`` is *replaced* by
          a dictionary with the key-value-pairs ``"masks"`` (as :class:`~torchvision.tv_tensors.Mask` tv_tensor) and
          ``"labels"``.
        * :class:`~torchvision.datasets.WIDERFace`: The value for key ``"bbox"`` in the target is converted to ``XYXY``
          coordinate format and wrapped into a :class:`~torchvision.tv_tensors.BoundingBoxes` tv_tensor.

    Image classification datasets

        This wrapper is a no-op for image classification datasets, since they were already fully supported by
        :mod:`torchvision.transforms` and thus no change is needed for :mod:`torchvision.transforms.v2`.

    Segmentation datasets

        Segmentation datasets, e.g. :class:`~torchvision.datasets.VOCSegmentation`, return a two-tuple of
        :class:`PIL.Image.Image`'s. This wrapper leaves the image as is (first item), while wrapping the
        segmentation mask into a :class:`~torchvision.tv_tensors.Mask` (second item).

    Video classification datasets

        Video classification datasets, e.g. :class:`~torchvision.datasets.Kinetics`, return a three-tuple containing a
        :class:`torch.Tensor` for the video and audio and an :class:`int` as label. This wrapper wraps the video into a
        :class:`~torchvision.tv_tensors.Video` while leaving the other items as is.

        .. note::

            Only datasets constructed with ``output_format="TCHW"`` are supported, since the alternative
            ``output_format="THWC"`` is not supported by :mod:`torchvision.transforms.v2`.

    Args:
        dataset: the dataset instance to wrap for compatibility with transforms v2.
        target_keys: Target keys to return in case the target is a dictionary. If ``None`` (default), selected keys are
            specific to the dataset. If ``"all"``, returns the full target. Can also be a collection of strings for
            fine-grained access. Currently only supported for :class:`~torchvision.datasets.CocoDetection`,
            :class:`~torchvision.datasets.VOCDetection`, :class:`~torchvision.datasets.Kitti`, and
            :class:`~torchvision.datasets.WIDERFace`. See above for details.
    """
    if not (
        target_keys is None
        or target_keys == "all"
        or (isinstance(target_keys, collections.abc.Collection) and all(isinstance(key, str) for key in target_keys))
    ):
        raise ValueError(
            f"`target_keys` can be None, 'all', or a collection of strings denoting the keys to be returned, "
            f"but got {target_keys}"
        )

    # Imagine we have isinstance(dataset, datasets.ImageNet). This will create a new class with the name
    # "WrappedImageNet" at runtime that doubly inherits from VisionDatasetTVTensorWrapper (see below) as well as the
    # original ImageNet class. This allows the user to do regular isinstance(wrapped_dataset, datasets.ImageNet)
    # checks, while we can still inject everything that we need.
    wrapped_dataset_cls = type(f"Wrapped{type(dataset).__name__}", (VisionDatasetTVTensorWrapper, type(dataset)), {})

    # Since VisionDatasetTVTensorWrapper comes before ImageNet in the MRO, calling the class hits
    # VisionDatasetTVTensorWrapper.__init__ first. Since we are never doing super().__init__(...), the constructor of
    # ImageNet is never hit. That is by design, since we don't want to create the dataset instance again, but rather
    # have the existing instance as attribute on the new object.
    return wrapped_dataset_cls(dataset, target_keys)
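
# Illustration only (not executed at import time): assuming `coco` is a `datasets.CocoDetection` instance,
#
#     wrapped = wrap_dataset_for_transforms_v2(coco)
#     isinstance(wrapped, datasets.CocoDetection)        # True, thanks to the dynamically created subclass
#     isinstance(wrapped, VisionDatasetTVTensorWrapper)  # True as well
#
# so downstream code that checks for the original dataset class keeps working.
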
class WrapperFactories(dict):
    def register(self, dataset_cls):
        def decorator(wrapper_factory):
            self[dataset_cls] = wrapper_factory
            return wrapper_factory

        return decorator


# We need this two-stage design, i.e. a wrapper factory producing the actual wrapper, since some wrappers depend on
# the dataset instance rather than just the class, because they require user-defined instance attributes. Thus, we
# can provide a mapping from the dataset class to the factory here, but can only instantiate the wrapper at runtime
# when we have access to the dataset instance.
WRAPPER_FACTORIES = WrapperFactories()
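
# Factories are registered either with decorator syntax, e.g.
#
#     @WRAPPER_FACTORIES.register(datasets.Caltech101)
#     def caltech101_wrapper_factory(dataset, target_keys): ...
#
# or by calling the returned decorator directly, e.g.
#
#     WRAPPER_FACTORIES.register(datasets.CocoCaptions)(identity_wrapper_factory)
#
# Both forms are used further down in this module.
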
class VisionDatasetTVTensorWrapper:
    def __init__(self, dataset, target_keys):
        dataset_cls = type(dataset)

        if not isinstance(dataset, datasets.VisionDataset):
            raise TypeError(
                f"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, "
                f"but got a '{dataset_cls.__name__}' instead.\n"
                f"For an example of how to perform the wrapping for custom datasets, see\n\n"
                "https://pytorch.org/vision/main/auto_examples/plot_tv_tensors.html#do-i-have-to-wrap-the-output-of-the-datasets-myself"
            )

        for cls in dataset_cls.mro():
            if cls in WRAPPER_FACTORIES:
                wrapper_factory = WRAPPER_FACTORIES[cls]
                if target_keys is not None and cls not in {
                    datasets.CocoDetection,
                    datasets.VOCDetection,
                    datasets.Kitti,
                    datasets.WIDERFace,
                }:
                    raise ValueError(
                        f"`target_keys` is currently only supported for `CocoDetection`, `VOCDetection`, `Kitti`, "
                        f"and `WIDERFace`, but got {cls.__name__}."
                    )
                break
            elif cls is datasets.VisionDataset:
                # TODO: If we have documentation on how to do that, put a link in the error message.
                msg = f"No wrapper exists for dataset class {dataset_cls.__name__}. Please wrap the output yourself."
                if dataset_cls in datasets.__dict__.values():
                    msg = (
                        f"{msg} If an automated wrapper for this dataset would be useful for you, "
                        f"please open an issue at https://github.com/pytorch/vision/issues."
                    )
                raise TypeError(msg)

        self._dataset = dataset
        self._target_keys = target_keys
        self._wrapper = wrapper_factory(dataset, target_keys)

        # We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply them.
        # Although internally, `datasets.VisionDataset` merges `transform` and `target_transform` into the joint
        # `transforms`
        # https://github.com/pytorch/vision/blob/135a0f9ea9841b6324b4fe8974e2543cbb95709a/torchvision/datasets/vision.py#L52-L54
        # some (if not most) datasets still use `transform` and `target_transform` individually. Thus, we need to
        # disable all three here to be able to extract the untransformed sample to wrap.
        self.transform, dataset.transform = dataset.transform, None
        self.target_transform, dataset.target_transform = dataset.target_transform, None
        self.transforms, dataset.transforms = dataset.transforms, None

    def __getattr__(self, item):
        with contextlib.suppress(AttributeError):
            return object.__getattribute__(self, item)

        return getattr(self._dataset, item)

    def __getitem__(self, idx):
        # This gets us the raw sample since we disabled the transforms for the underlying dataset in the constructor
        # of this class
        sample = self._dataset[idx]

        sample = self._wrapper(idx, sample)

        # Regardless of whether the user has supplied the transforms individually (`transform` and `target_transform`)
        # or joint (`transforms`), we can access the full functionality through `transforms`
        if self.transforms is not None:
            sample = self.transforms(*sample)

        return sample

    def __len__(self):
        return len(self._dataset)

    # TODO: maybe we should use __getstate__ and __setstate__ instead of __reduce__, as recommended in the docs.
    def __reduce__(self):
        # __reduce__ gets called when we try to pickle the dataset.
        # In a DataLoader with spawn context, this gets called `num_workers` times from the main process.
        # We have to reset the [target_]transform[s] attributes of the dataset
        # to their original values, because we previously set them to None in __init__().
        dataset = copy(self._dataset)
        dataset.transform = self.transform
        dataset.transforms = self.transforms
        dataset.target_transform = self.target_transform

        return wrap_dataset_for_transforms_v2, (dataset, self._target_keys)


def raise_not_supported(description):
    raise RuntimeError(
        f"{description} is currently not supported by this wrapper. "
        f"If this would be helpful for you, please open an issue at https://github.com/pytorch/vision/issues."
    )


def identity(item):
    return item


def identity_wrapper_factory(dataset, target_keys):
    def wrapper(idx, sample):
        return sample

    return wrapper


def pil_image_to_mask(pil_image):
    return tv_tensors.Mask(pil_image)


def parse_target_keys(target_keys, *, available, default):
    if target_keys is None:
        target_keys = default
    if target_keys == "all":
        target_keys = available
    else:
        target_keys = set(target_keys)
        extra = target_keys - available
        if extra:
            raise ValueError(f"Target keys {sorted(extra)} are not available")

    return target_keys
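
# For example, parse_target_keys(None, available={"boxes", "labels"}, default={"labels"}) yields {"labels"},
# parse_target_keys("all", ...) yields the full `available` set, and requesting a key that is not in
# `available` raises a ValueError.
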
def list_of_dicts_to_dict_of_lists(list_of_dicts):
    dict_of_lists = defaultdict(list)
    for dct in list_of_dicts:
        for key, value in dct.items():
            dict_of_lists[key].append(value)
    return dict(dict_of_lists)
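
# For example, [{"bbox": b0, "category_id": c0}, {"bbox": b1, "category_id": c1}] becomes
# {"bbox": [b0, b1], "category_id": [c0, c1]}, i.e. the "batched" target layout that the detection
# wrappers below build their tv_tensors from.
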
def wrap_target_by_type(target, *, target_types, type_wrappers):
    if not isinstance(target, (tuple, list)):
        target = [target]

    wrapped_target = tuple(
        type_wrappers.get(target_type, identity)(item) for target_type, item in zip(target_types, target)
    )

    if len(wrapped_target) == 1:
        wrapped_target = wrapped_target[0]

    return wrapped_target
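
# For example, for a dataset constructed with target_types=["category", "segmentation"] and
# type_wrappers={"segmentation": pil_image_to_mask}, the category label passes through `identity`
# unchanged while the segmentation image is wrapped into a tv_tensors.Mask.
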
def classification_wrapper_factory(dataset, target_keys):
    return identity_wrapper_factory(dataset, target_keys)


for dataset_cls in [
    datasets.Caltech256,
    datasets.CIFAR10,
    datasets.CIFAR100,
    datasets.ImageNet,
    datasets.MNIST,
    datasets.FashionMNIST,
    datasets.GTSRB,
    datasets.DatasetFolder,
    datasets.ImageFolder,
    datasets.Imagenette,
]:
    WRAPPER_FACTORIES.register(dataset_cls)(classification_wrapper_factory)


def segmentation_wrapper_factory(dataset, target_keys):
    def wrapper(idx, sample):
        image, mask = sample
        return image, pil_image_to_mask(mask)

    return wrapper


for dataset_cls in [
    datasets.VOCSegmentation,
]:
    WRAPPER_FACTORIES.register(dataset_cls)(segmentation_wrapper_factory)


def video_classification_wrapper_factory(dataset, target_keys):
    if dataset.video_clips.output_format == "THWC":
        raise RuntimeError(
            f"{type(dataset).__name__} with `output_format='THWC'` is not supported by this wrapper, "
            f"since it is not compatible with the transformations. Please use `output_format='TCHW'` instead."
        )

    def wrapper(idx, sample):
        video, audio, label = sample

        video = tv_tensors.Video(video)

        return video, audio, label

    return wrapper


for dataset_cls in [
    datasets.HMDB51,
    datasets.Kinetics,
    datasets.UCF101,
]:
    WRAPPER_FACTORIES.register(dataset_cls)(video_classification_wrapper_factory)


@WRAPPER_FACTORIES.register(datasets.Caltech101)
def caltech101_wrapper_factory(dataset, target_keys):
    if "annotation" in dataset.target_type:
        raise_not_supported("Caltech101 dataset with `target_type=['annotation', ...]`")

    return classification_wrapper_factory(dataset, target_keys)


@WRAPPER_FACTORIES.register(datasets.CocoDetection)
def coco_detection_wrapper_factory(dataset, target_keys):
    target_keys = parse_target_keys(
        target_keys,
        available={
            # native
            "segmentation",
            "area",
            "iscrowd",
            "image_id",
            "bbox",
            "category_id",
            # added by the wrapper
            "boxes",
            "masks",
            "labels",
        },
        default={"image_id", "boxes", "labels"},
    )

    def segmentation_to_mask(segmentation, *, canvas_size):
        from pycocotools import mask
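        # COCO stores a segmentation either as a list of polygons or as an (un)compressed RLE dict;
        # `frPyObjects` handles both representations, and polygon lists are merged into a single mask below.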
        segmentation = (
            mask.frPyObjects(segmentation, *canvas_size)
            if isinstance(segmentation, dict)
            else mask.merge(mask.frPyObjects(segmentation, *canvas_size))
        )
        return torch.from_numpy(mask.decode(segmentation))

    def wrapper(idx, sample):
        image_id = dataset.ids[idx]

        image, target = sample

        if not target:
            return image, dict(image_id=image_id)

        canvas_size = tuple(F.get_size(image))

        batched_target = list_of_dicts_to_dict_of_lists(target)
        target = {}

        if "image_id" in target_keys:
            target["image_id"] = image_id

        if "boxes" in target_keys:
            target["boxes"] = F.convert_bounding_box_format(
                tv_tensors.BoundingBoxes(
                    batched_target["bbox"],
                    format=tv_tensors.BoundingBoxFormat.XYWH,
                    canvas_size=canvas_size,
                ),
                new_format=tv_tensors.BoundingBoxFormat.XYXY,
            )

        if "masks" in target_keys:
            target["masks"] = tv_tensors.Mask(
                torch.stack(
                    [
                        segmentation_to_mask(segmentation, canvas_size=canvas_size)
                        for segmentation in batched_target["segmentation"]
                    ]
                ),
            )

        if "labels" in target_keys:
            target["labels"] = torch.tensor(batched_target["category_id"])

        for target_key in target_keys - {"image_id", "boxes", "masks", "labels"}:
            target[target_key] = batched_target[target_key]

        return image, target

    return wrapper


WRAPPER_FACTORIES.register(datasets.CocoCaptions)(identity_wrapper_factory)


VOC_DETECTION_CATEGORIES = [
    "__background__",
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]
VOC_DETECTION_CATEGORY_TO_IDX = dict(zip(VOC_DETECTION_CATEGORIES, range(len(VOC_DETECTION_CATEGORIES))))
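# For example, VOC_DETECTION_CATEGORY_TO_IDX["aeroplane"] == 1; index 0 is reserved for "__background__".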
@WRAPPER_FACTORIES.register(datasets.VOCDetection)
def voc_detection_wrapper_factory(dataset, target_keys):
    target_keys = parse_target_keys(
        target_keys,
        available={
            # native
            "annotation",
            # added by the wrapper
            "boxes",
            "labels",
        },
        default={"boxes", "labels"},
    )

    def wrapper(idx, sample):
        image, target = sample

        batched_instances = list_of_dicts_to_dict_of_lists(target["annotation"]["object"])

        if "annotation" not in target_keys:
            target = {}

        if "boxes" in target_keys:
            target["boxes"] = tv_tensors.BoundingBoxes(
                [
                    [int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")]
                    for bndbox in batched_instances["bndbox"]
                ],
                format=tv_tensors.BoundingBoxFormat.XYXY,
                canvas_size=(image.height, image.width),
            )

        if "labels" in target_keys:
            target["labels"] = torch.tensor(
                [VOC_DETECTION_CATEGORY_TO_IDX[category] for category in batched_instances["name"]]
            )

        return image, target

    return wrapper


@WRAPPER_FACTORIES.register(datasets.SBDataset)
def sbd_wrapper(dataset, target_keys):
    if dataset.mode == "boundaries":
        raise_not_supported("SBDataset with mode='boundaries'")

    return segmentation_wrapper_factory(dataset, target_keys)


@WRAPPER_FACTORIES.register(datasets.CelebA)
def celeba_wrapper_factory(dataset, target_keys):
    if any(target_type in dataset.target_type for target_type in ["attr", "landmarks"]):
        raise_not_supported("`CelebA` dataset with `target_type=['attr', 'landmarks', ...]`")

    def wrapper(idx, sample):
        image, target = sample

        target = wrap_target_by_type(
            target,
            target_types=dataset.target_type,
            type_wrappers={
                "bbox": lambda item: F.convert_bounding_box_format(
                    tv_tensors.BoundingBoxes(
                        item,
                        format=tv_tensors.BoundingBoxFormat.XYWH,
                        canvas_size=(image.height, image.width),
                    ),
                    new_format=tv_tensors.BoundingBoxFormat.XYXY,
                ),
            },
        )

        return image, target

    return wrapper


KITTI_CATEGORIES = ["Car", "Van", "Truck", "Pedestrian", "Person_sitting", "Cyclist", "Tram", "Misc", "DontCare"]
KITTI_CATEGORY_TO_IDX = dict(zip(KITTI_CATEGORIES, range(len(KITTI_CATEGORIES))))


@WRAPPER_FACTORIES.register(datasets.Kitti)
def kitti_wrapper_factory(dataset, target_keys):
    target_keys = parse_target_keys(
        target_keys,
        available={
            # native
            "type",
            "truncated",
            "occluded",
            "alpha",
            "bbox",
            "dimensions",
            "location",
            "rotation_y",
            # added by the wrapper
            "boxes",
            "labels",
        },
        default={"boxes", "labels"},
    )

    def wrapper(idx, sample):
        image, target = sample

        if target is None:
            return image, target

        batched_target = list_of_dicts_to_dict_of_lists(target)

        target = {}

        if "boxes" in target_keys:
            target["boxes"] = tv_tensors.BoundingBoxes(
                batched_target["bbox"],
                format=tv_tensors.BoundingBoxFormat.XYXY,
                canvas_size=(image.height, image.width),
            )

        if "labels" in target_keys:
            target["labels"] = torch.tensor([KITTI_CATEGORY_TO_IDX[category] for category in batched_target["type"]])

        for target_key in target_keys - {"boxes", "labels"}:
            target[target_key] = batched_target[target_key]

        return image, target

    return wrapper


@WRAPPER_FACTORIES.register(datasets.OxfordIIITPet)
def oxford_iiit_pet_wrapper_factory(dataset, target_keys):
    def wrapper(idx, sample):
        image, target = sample

        if target is not None:
            target = wrap_target_by_type(
                target,
                target_types=dataset._target_types,
                type_wrappers={
                    "segmentation": pil_image_to_mask,
                },
            )

        return image, target

    return wrapper


@WRAPPER_FACTORIES.register(datasets.Cityscapes)
def cityscapes_wrapper_factory(dataset, target_keys):
    if any(target_type in dataset.target_type for target_type in ["polygon", "color"]):
        raise_not_supported("`Cityscapes` dataset with `target_type=['polygon', 'color', ...]`")

    def instance_segmentation_wrapper(mask):
        # See https://github.com/mcordts/cityscapesScripts/blob/8da5dd00c9069058ccc134654116aac52d4f6fa2/cityscapesscripts/preparation/json2instanceImg.py#L7-L21
        data = pil_image_to_mask(mask)
        masks = []
        labels = []
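        # Per the cityscapesScripts reference above, "stuff" pixels carry plain label ids, while instance
        # pixels are encoded as `label_id * 1000 + instance_id`; the division below recovers the label id.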
        for id in data.unique():
            masks.append(data == id)
            label = id
            if label >= 1_000:
                label //= 1_000
            labels.append(label)
        return dict(masks=tv_tensors.Mask(torch.stack(masks)), labels=torch.stack(labels))

    def wrapper(idx, sample):
        image, target = sample

        target = wrap_target_by_type(
            target,
            target_types=dataset.target_type,
            type_wrappers={
                "instance": instance_segmentation_wrapper,
                "semantic": pil_image_to_mask,
            },
        )

        return image, target

    return wrapper


@WRAPPER_FACTORIES.register(datasets.WIDERFace)
def widerface_wrapper(dataset, target_keys):
    target_keys = parse_target_keys(
        target_keys,
        available={
            "bbox",
            "blur",
            "expression",
            "illumination",
            "occlusion",
            "pose",
            "invalid",
        },
        default="all",
    )

    def wrapper(idx, sample):
        image, target = sample

        if target is None:
            return image, target

        target = {key: target[key] for key in target_keys}

        if "bbox" in target_keys:
            target["bbox"] = F.convert_bounding_box_format(
                tv_tensors.BoundingBoxes(
                    target["bbox"], format=tv_tensors.BoundingBoxFormat.XYWH, canvas_size=(image.height, image.width)
                ),
                new_format=tv_tensors.BoundingBoxFormat.XYXY,
            )

        return image, target

    return wrapper