kitti.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. import csv
  2. import os
  3. from typing import Any, Callable, List, Optional, Tuple
  4. from PIL import Image
  5. from .utils import download_and_extract_archive
  6. from .vision import VisionDataset
  class Kitti(VisionDataset):
      """`KITTI <http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark>`_ Dataset.

      It corresponds to the "left color images of object" dataset, for object detection.

      Args:
          root (string): Root directory where images are downloaded to.
              Expects the following folder structure if download=False:

              .. code::

                  <root>
                      └── Kitti
                          └─ raw
                              ├── training
                              |   ├── image_2
                              |   └── label_2
                              └── testing
                                  └── image_2
          train (bool, optional): Use ``train`` split if true, else ``test`` split.
              Defaults to ``train``.
          transform (callable, optional): A function/transform that takes in a PIL image
              and returns a transformed version. E.g, ``transforms.PILToTensor``
          target_transform (callable, optional): A function/transform that takes in the
              target and transforms it.
          transforms (callable, optional): A function/transform that takes input sample
              and its target as entry and returns a transformed version.
          download (bool, optional): If true, downloads the dataset from the internet and
              puts it in root directory. If dataset is already downloaded, it is not
              downloaded again.

      Raises:
          RuntimeError: If the expected image (and, for the train split, label)
              directories are missing after the optional download step.
      """

      # Base URL the archives listed in ``resources`` are fetched from.
      data_url = "https://s3.eu-central-1.amazonaws.com/avg-kitti/"
      resources = [
          "data_object_image_2.zip",
          "data_object_label_2.zip",
      ]
      # Sub-directory names looked up below ``<raw folder>/<training|testing>``.
      image_dir_name = "image_2"
      labels_dir_name = "label_2"

      def __init__(
          self,
          root: str,
          train: bool = True,
          transform: Optional[Callable] = None,
          target_transform: Optional[Callable] = None,
          transforms: Optional[Callable] = None,
          download: bool = False,
      ):
          super().__init__(
              root,
              transform=transform,
              target_transform=target_transform,
              transforms=transforms,
          )
          self.images = []  # paths of the image files, one per sample
          self.targets = []  # paths of the label files; stays empty for the test split
          self.train = train
          self._location = "training" if self.train else "testing"

          if download:
              self.download()
          if not self._check_exists():
              raise RuntimeError("Dataset not found. You may use download=True to download it.")

          image_dir = os.path.join(self._raw_folder, self._location, self.image_dir_name)
          labels_dir = os.path.join(self._raw_folder, self._location, self.labels_dir_name)
          # os.listdir() returns entries in arbitrary order; sorting keeps the
          # index -> sample mapping reproducible across runs and filesystems.
          for img_file in sorted(os.listdir(image_dir)):
              self.images.append(os.path.join(image_dir, img_file))
              if self.train:
                  # The label file shares the image's base name (text before the first dot).
                  self.targets.append(os.path.join(labels_dir, f"{img_file.split('.')[0]}.txt"))

      def __getitem__(self, index: int) -> Tuple[Any, Any]:
          """Get item at a given index.

          Args:
              index (int): Index

          Returns:
              tuple: (image, target), where
              target is ``None`` for the test split, otherwise a list of
              dictionaries with the following keys:

              - type: str
              - truncated: float
              - occluded: int
              - alpha: float
              - bbox: float[4]
              - dimensions: float[3]
              - location: float[3]
              - rotation_y: float
          """
          image = Image.open(self.images[index])
          # Only the training split has label files to parse.
          target = self._parse_target(index) if self.train else None
          if self.transforms:
              image, target = self.transforms(image, target)
          return image, target

      def _parse_target(self, index: int) -> List:
          """Parse the label file of sample ``index`` into a list of dictionaries.

          Each non-empty row of the space-delimited label file becomes one
          dictionary with the keys listed in :meth:`__getitem__`.
          """
          target = []
          # newline="" is the mode the csv module documents for files handed to csv.reader.
          with open(self.targets[index], newline="") as inp:
              content = csv.reader(inp, delimiter=" ")
              for line in content:
                  if not line:
                      # csv.reader yields [] for a blank row; indexing it would raise IndexError.
                      continue
                  target.append(
                      {
                          "type": line[0],
                          "truncated": float(line[1]),
                          "occluded": int(line[2]),
                          "alpha": float(line[3]),
                          "bbox": [float(x) for x in line[4:8]],
                          "dimensions": [float(x) for x in line[8:11]],
                          "location": [float(x) for x in line[11:14]],
                          "rotation_y": float(line[14]),
                      }
                  )
          return target

      def __len__(self) -> int:
          """Return the number of image files found for the selected split."""
          return len(self.images)

      @property
      def _raw_folder(self) -> str:
          """Directory ``<root>/<class name>/raw`` that holds the extracted archives."""
          return os.path.join(self.root, self.__class__.__name__, "raw")

      def _check_exists(self) -> bool:
          """Check if the data directory exists."""
          folders = [self.image_dir_name]
          if self.train:
              # Labels are only required for the training split.
              folders.append(self.labels_dir_name)
          return all(os.path.isdir(os.path.join(self._raw_folder, self._location, fname)) for fname in folders)

      def download(self) -> None:
          """Download the KITTI data if it doesn't exist already."""
          if self._check_exists():
              return

          os.makedirs(self._raw_folder, exist_ok=True)

          # download files
          for fname in self.resources:
              download_and_extract_archive(
                  url=f"{self.data_url}{fname}",
                  download_root=self._raw_folder,
                  filename=fname,
              )