caltech.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. import os
  2. import os.path
  3. from typing import Any, Callable, List, Optional, Tuple, Union
  4. from PIL import Image
  5. from .utils import download_and_extract_archive, verify_str_arg
  6. from .vision import VisionDataset
  7. class Caltech101(VisionDataset):
  8. """`Caltech 101 <https://data.caltech.edu/records/20086>`_ Dataset.
  9. .. warning::
  10. This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
  11. Args:
  12. root (string): Root directory of dataset where directory
  13. ``caltech101`` exists or will be saved to if download is set to True.
  14. target_type (string or list, optional): Type of target to use, ``category`` or
  15. ``annotation``. Can also be a list to output a tuple with all specified
  16. target types. ``category`` represents the target class, and
  17. ``annotation`` is a list of points from a hand-generated outline.
  18. Defaults to ``category``.
  19. transform (callable, optional): A function/transform that takes in an PIL image
  20. and returns a transformed version. E.g, ``transforms.RandomCrop``
  21. target_transform (callable, optional): A function/transform that takes in the
  22. target and transforms it.
  23. download (bool, optional): If true, downloads the dataset from the internet and
  24. puts it in root directory. If dataset is already downloaded, it is not
  25. downloaded again.
  26. .. warning::
  27. To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
  28. """
  29. def __init__(
  30. self,
  31. root: str,
  32. target_type: Union[List[str], str] = "category",
  33. transform: Optional[Callable] = None,
  34. target_transform: Optional[Callable] = None,
  35. download: bool = False,
  36. ) -> None:
  37. super().__init__(os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform)
  38. os.makedirs(self.root, exist_ok=True)
  39. if isinstance(target_type, str):
  40. target_type = [target_type]
  41. self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) for t in target_type]
  42. if download:
  43. self.download()
  44. if not self._check_integrity():
  45. raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
  46. self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
  47. self.categories.remove("BACKGROUND_Google") # this is not a real class
  48. # For some reason, the category names in "101_ObjectCategories" and
  49. # "Annotations" do not always match. This is a manual map between the
  50. # two. Defaults to using same name, since most names are fine.
  51. name_map = {
  52. "Faces": "Faces_2",
  53. "Faces_easy": "Faces_3",
  54. "Motorbikes": "Motorbikes_16",
  55. "airplanes": "Airplanes_Side_2",
  56. }
  57. self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories))
  58. self.index: List[int] = []
  59. self.y = []
  60. for (i, c) in enumerate(self.categories):
  61. n = len(os.listdir(os.path.join(self.root, "101_ObjectCategories", c)))
  62. self.index.extend(range(1, n + 1))
  63. self.y.extend(n * [i])
  64. def __getitem__(self, index: int) -> Tuple[Any, Any]:
  65. """
  66. Args:
  67. index (int): Index
  68. Returns:
  69. tuple: (image, target) where the type of target specified by target_type.
  70. """
  71. import scipy.io
  72. img = Image.open(
  73. os.path.join(
  74. self.root,
  75. "101_ObjectCategories",
  76. self.categories[self.y[index]],
  77. f"image_{self.index[index]:04d}.jpg",
  78. )
  79. )
  80. target: Any = []
  81. for t in self.target_type:
  82. if t == "category":
  83. target.append(self.y[index])
  84. elif t == "annotation":
  85. data = scipy.io.loadmat(
  86. os.path.join(
  87. self.root,
  88. "Annotations",
  89. self.annotation_categories[self.y[index]],
  90. f"annotation_{self.index[index]:04d}.mat",
  91. )
  92. )
  93. target.append(data["obj_contour"])
  94. target = tuple(target) if len(target) > 1 else target[0]
  95. if self.transform is not None:
  96. img = self.transform(img)
  97. if self.target_transform is not None:
  98. target = self.target_transform(target)
  99. return img, target
  100. def _check_integrity(self) -> bool:
  101. # can be more robust and check hash of files
  102. return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))
  103. def __len__(self) -> int:
  104. return len(self.index)
  105. def download(self) -> None:
  106. if self._check_integrity():
  107. print("Files already downloaded and verified")
  108. return
  109. download_and_extract_archive(
  110. "https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp",
  111. self.root,
  112. filename="101_ObjectCategories.tar.gz",
  113. md5="b224c7392d521a49829488ab0f1120d9",
  114. )
  115. download_and_extract_archive(
  116. "https://drive.google.com/file/d/175kQy3UsZ0wUEHZjqkUDdNVssr7bgh_m",
  117. self.root,
  118. filename="Annotations.tar",
  119. md5="6f83eeb1f24d99cab4eb377263132c91",
  120. )
  121. def extra_repr(self) -> str:
  122. return "Target type: {target_type}".format(**self.__dict__)
  123. class Caltech256(VisionDataset):
  124. """`Caltech 256 <https://data.caltech.edu/records/20087>`_ Dataset.
  125. Args:
  126. root (string): Root directory of dataset where directory
  127. ``caltech256`` exists or will be saved to if download is set to True.
  128. transform (callable, optional): A function/transform that takes in an PIL image
  129. and returns a transformed version. E.g, ``transforms.RandomCrop``
  130. target_transform (callable, optional): A function/transform that takes in the
  131. target and transforms it.
  132. download (bool, optional): If true, downloads the dataset from the internet and
  133. puts it in root directory. If dataset is already downloaded, it is not
  134. downloaded again.
  135. """
  136. def __init__(
  137. self,
  138. root: str,
  139. transform: Optional[Callable] = None,
  140. target_transform: Optional[Callable] = None,
  141. download: bool = False,
  142. ) -> None:
  143. super().__init__(os.path.join(root, "caltech256"), transform=transform, target_transform=target_transform)
  144. os.makedirs(self.root, exist_ok=True)
  145. if download:
  146. self.download()
  147. if not self._check_integrity():
  148. raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
  149. self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))
  150. self.index: List[int] = []
  151. self.y = []
  152. for (i, c) in enumerate(self.categories):
  153. n = len(
  154. [
  155. item
  156. for item in os.listdir(os.path.join(self.root, "256_ObjectCategories", c))
  157. if item.endswith(".jpg")
  158. ]
  159. )
  160. self.index.extend(range(1, n + 1))
  161. self.y.extend(n * [i])
  162. def __getitem__(self, index: int) -> Tuple[Any, Any]:
  163. """
  164. Args:
  165. index (int): Index
  166. Returns:
  167. tuple: (image, target) where target is index of the target class.
  168. """
  169. img = Image.open(
  170. os.path.join(
  171. self.root,
  172. "256_ObjectCategories",
  173. self.categories[self.y[index]],
  174. f"{self.y[index] + 1:03d}_{self.index[index]:04d}.jpg",
  175. )
  176. )
  177. target = self.y[index]
  178. if self.transform is not None:
  179. img = self.transform(img)
  180. if self.target_transform is not None:
  181. target = self.target_transform(target)
  182. return img, target
  183. def _check_integrity(self) -> bool:
  184. # can be more robust and check hash of files
  185. return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))
  186. def __len__(self) -> int:
  187. return len(self.index)
  188. def download(self) -> None:
  189. if self._check_integrity():
  190. print("Files already downloaded and verified")
  191. return
  192. download_and_extract_archive(
  193. "https://drive.google.com/file/d/1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK",
  194. self.root,
  195. filename="256_ObjectCategories.tar",
  196. md5="67b4f42ca05d46448c6bb8ecd2220f6d",
  197. )