# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import hashlib
import json
import os
import random
import subprocess
import time
import zipfile
from multiprocessing.pool import ThreadPool
from pathlib import Path
from tarfile import is_tarfile

import cv2
import numpy as np
from PIL import Image, ImageOps

from ultralytics.nn.autobackend import check_class_names
from ultralytics.utils import (
    DATASETS_DIR,
    LOGGER,
    NUM_THREADS,
    ROOT,
    SETTINGS_FILE,
    TQDM,
    clean_url,
    colorstr,
    emojis,
    is_dir_writeable,
    yaml_load,
    yaml_save,
)
from ultralytics.utils.checks import check_file, check_font, is_ascii
from ultralytics.utils.downloads import download, safe_download, unzip_file
from ultralytics.utils.ops import segments2boxes

HELP_URL = "See https://docs.ultralytics.com/datasets for dataset formatting guidance."
IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm", "heic"}  # image suffixes
VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"}  # video suffixes
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true"  # global pin_memory for dataloaders
FORMATS_HELP_MSG = f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"


def img2label_paths(img_paths):
    """Define label paths as a function of image paths."""
    sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}"  # /images/, /labels/ substrings
    return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]


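# Usage sketch (illustrative paths, not part of this module): only the last "/images/" path component
# is swapped for "/labels/" and the file extension is replaced with ".txt", so on a POSIX filesystem:
#   img2label_paths(["datasets/coco8/images/train/000000000009.jpg"])
#   -> ["datasets/coco8/labels/train/000000000009.txt"]

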
def get_hash(paths):
    """Returns a single hash value of a list of paths (files or dirs)."""
    size = sum(os.path.getsize(p) for p in paths if os.path.exists(p))  # sizes
    h = hashlib.sha256(str(size).encode())  # hash sizes
    h.update("".join(paths).encode())  # hash paths
    return h.hexdigest()  # return hash


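# Usage sketch (illustrative): the fingerprint mixes the total on-disk size with the concatenated path
# strings, so adding, removing, renaming, reordering, or resizing files all change it. A dataset cache
# can therefore be validated roughly like this (hypothetical cache dict and file lists):
#   assert cache["hash"] == get_hash(label_files + image_files)

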
def exif_size(img: Image.Image):
    """Returns exif-corrected PIL size."""
    s = img.size  # (width, height)
    if img.format == "JPEG":  # only support JPEG images
        try:
            if exif := img.getexif():
                rotation = exif.get(274, None)  # the EXIF key for the orientation tag is 274
                if rotation in {6, 8}:  # rotation 270 or 90
                    s = s[1], s[0]
        except Exception:
            pass
    return s


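# Illustrative note: EXIF orientation values 6 and 8 mean the JPEG pixels are stored rotated by 90°/270°,
# so width and height are swapped to report the displayed size. For a hypothetical portrait photo stored
# as 4032x3024 with orientation 6:
#   with Image.open("portrait.jpg") as im:
#       exif_size(im)  # -> (3024, 4032) even though im.size == (4032, 3024)

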
def verify_image(args):
    """Verify one image."""
    (im_file, cls), prefix = args
    # Number (found, corrupt), message
    nf, nc, msg = 0, 0, ""
    try:
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = exif_size(im)  # image size
        shape = (shape[1], shape[0])  # hw
        assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
        assert im.format.lower() in IMG_FORMATS, f"Invalid image format {im.format}. {FORMATS_HELP_MSG}"
        if im.format.lower() in {"jpg", "jpeg"}:
            with open(im_file, "rb") as f:
                f.seek(-2, 2)
                if f.read() != b"\xff\xd9":  # corrupt JPEG
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
                    msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"
        nf = 1
    except Exception as e:
        nc = 1
        msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
    return (im_file, cls), nf, nc, msg


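# Usage sketch (hedged; mirrors how a dataset class might map this helper over (im_file, cls) pairs,
# the sample list and prefix below are hypothetical and itertools.repeat must be imported):
#   samples = [("img1.jpg", 0), ("img2.jpg", 1)]
#   with ThreadPool(NUM_THREADS) as pool:
#       results = pool.imap(verify_image, zip(samples, repeat("prefix: ")))
#   # each result is ((im_file, cls), nf, nc, msg) with nf=1 for a good image, nc=1 for a corrupt one

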
def verify_image_label(args):
    """Verify one image-label pair."""
    im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args
    # Number (missing, found, empty, corrupt), message, segments, keypoints
    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
    try:
        # Verify images
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = exif_size(im)  # image size
        shape = (shape[1], shape[0])  # hw
        assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
        assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}. {FORMATS_HELP_MSG}"
        if im.format.lower() in {"jpg", "jpeg"}:
            with open(im_file, "rb") as f:
                f.seek(-2, 2)
                if f.read() != b"\xff\xd9":  # corrupt JPEG
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
                    msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"

        # Verify labels
        if os.path.isfile(lb_file):
            nf = 1  # label found
            with open(lb_file) as f:
                lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
                if any(len(x) > 6 for x in lb) and (not keypoint):  # is segment
                    classes = np.array([x[0] for x in lb], dtype=np.float32)
                    segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb]  # (cls, xy1...)
                    lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                lb = np.array(lb, dtype=np.float32)
            if nl := len(lb):
                if keypoint:
                    assert lb.shape[1] == (5 + nkpt * ndim), f"labels require {(5 + nkpt * ndim)} columns each"
                    points = lb[:, 5:].reshape(-1, ndim)[:, :2]
                else:
                    assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
                    points = lb[:, 1:]
                assert points.max() <= 1, f"non-normalized or out of bounds coordinates {points[points > 1]}"
                assert lb.min() >= 0, f"negative label values {lb[lb < 0]}"

                # All labels
                max_cls = lb[:, 0].max()  # max label count
                assert max_cls <= num_cls, (
                    f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. "
                    f"Possible class labels are 0-{num_cls - 1}"
                )
                _, i = np.unique(lb, axis=0, return_index=True)
                if len(i) < nl:  # duplicate row check
                    lb = lb[i]  # remove duplicates
                    if segments:
                        segments = [segments[x] for x in i]
                    msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed"
            else:
                ne = 1  # label empty
                lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
        else:
            nm = 1  # label missing
            lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
        if keypoint:
            keypoints = lb[:, 5:].reshape(-1, nkpt, ndim)
            if ndim == 2:
                kpt_mask = np.where((keypoints[..., 0] < 0) | (keypoints[..., 1] < 0), 0.0, 1.0).astype(np.float32)
                keypoints = np.concatenate([keypoints, kpt_mask[..., None]], axis=-1)  # (nl, nkpt, 3)
        lb = lb[:, :5]
        return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
    except Exception as e:
        nc = 1
        msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
        return [None, None, None, None, None, nm, nf, ne, nc, msg]


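# Usage sketch (hypothetical arguments for a detect task): each worker receives a 7-tuple
# (im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim) and returns per-image counters.
#   args = ("img.jpg", "img.txt", "prefix: ", False, 80, 0, 0)
#   im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg = verify_image_label(args)
#   # lb is an (n, 5) array of (cls, x, y, w, h) rows; nm/nf/ne/nc flag missing/found/empty/corrupt

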
def visualize_image_annotations(image_path, txt_path, label_map):
    """
    Visualizes YOLO annotations (bounding boxes and class labels) on an image.

    This function reads an image and its corresponding annotation file in YOLO format, then
    draws bounding boxes around detected objects and labels them with their respective class names.
    The bounding box colors are assigned based on the class ID, and the text color is dynamically
    adjusted for readability, depending on the background color's luminance.

    Args:
        image_path (str): The path to the image file to annotate, in any format supported by PIL (e.g., .jpg, .png).
        txt_path (str): The path to the annotation file in YOLO format, which should contain one line per object with:
            - class_id (int): The class index.
            - x_center (float): The X center of the bounding box (relative to image width).
            - y_center (float): The Y center of the bounding box (relative to image height).
            - width (float): The width of the bounding box (relative to image width).
            - height (float): The height of the bounding box (relative to image height).
        label_map (dict): A dictionary that maps class IDs (integers) to class labels (strings).

    Example:
        >>> label_map = {0: "cat", 1: "dog", 2: "bird"}  # It should include all annotated classes
        >>> visualize_image_annotations("path/to/image.jpg", "path/to/annotations.txt", label_map)
    """
    import matplotlib.pyplot as plt

    from ultralytics.utils.plotting import colors

    img = np.array(Image.open(image_path))
    img_height, img_width = img.shape[:2]
    annotations = []
    with open(txt_path) as file:
        for line in file:
            class_id, x_center, y_center, width, height = map(float, line.split())
            x = (x_center - width / 2) * img_width
            y = (y_center - height / 2) * img_height
            w = width * img_width
            h = height * img_height
            annotations.append((x, y, w, h, int(class_id)))
    fig, ax = plt.subplots(1)  # Plot the image and annotations
    for x, y, w, h, label in annotations:
        color = tuple(c / 255 for c in colors(label, True))  # Get and normalize the RGB color
        rect = plt.Rectangle((x, y), w, h, linewidth=2, edgecolor=color, facecolor="none")  # Create a rectangle
        ax.add_patch(rect)
        luminance = 0.2126 * color[0] + 0.7152 * color[1] + 0.0722 * color[2]  # Formula for luminance
        ax.text(x, y - 5, label_map[label], color="white" if luminance < 0.5 else "black", backgroundcolor=color)
    ax.imshow(img)
    plt.show()


def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
    """
    Convert a list of polygons to a binary mask of the specified image size.

    Args:
        imgsz (tuple): The size of the image as (height, width).
        polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where
            N is the number of polygons, and M is the number of points such that M % 2 = 0.
        color (int, optional): The color value to fill in the polygons on the mask. Defaults to 1.
        downsample_ratio (int, optional): Factor by which to downsample the mask. Defaults to 1.

    Returns:
        (np.ndarray): A binary mask of the specified image size with the polygons filled in.
    """
    mask = np.zeros(imgsz, dtype=np.uint8)
    polygons = np.asarray(polygons, dtype=np.int32)
    polygons = polygons.reshape((polygons.shape[0], -1, 2))
    cv2.fillPoly(mask, polygons, color=color)
    nh, nw = (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio)
    # Note: filling at full resolution and resizing afterwards keeps the loss calculation consistent with mask_ratio=1
    return cv2.resize(mask, (nw, nh))


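# Usage sketch (illustrative values): one flat [x1, y1, x2, y2, ...] polygon rasterized into a mask.
#   poly = [np.array([10, 10, 50, 10, 50, 50, 10, 50], dtype=np.float32)]  # a square in pixel coords
#   polygon2mask((64, 64), poly).shape  # -> (64, 64), values are 0 or `color` (default 1)
#   polygon2mask((64, 64), poly, downsample_ratio=2).shape  # -> (32, 32), filled first, then resized

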
def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
    """
    Convert a list of polygons to a set of binary masks of the specified image size.

    Args:
        imgsz (tuple): The size of the image as (height, width).
        polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where
            N is the number of polygons, and M is the number of points such that M % 2 = 0.
        color (int): The color value to fill in the polygons on the masks.
        downsample_ratio (int, optional): Factor by which to downsample each mask. Defaults to 1.

    Returns:
        (np.ndarray): A set of binary masks of the specified image size with the polygons filled in.
    """
    return np.array([polygon2mask(imgsz, [x.reshape(-1)], color, downsample_ratio) for x in polygons])


def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
    """Return a single overlap mask of shape `imgsz` downsampled by `downsample_ratio`, plus the area-sorted index."""
    masks = np.zeros(
        (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
        dtype=np.int32 if len(segments) > 255 else np.uint8,
    )
    areas = []
    ms = []
    for si in range(len(segments)):
        mask = polygon2mask(imgsz, [segments[si].reshape(-1)], downsample_ratio=downsample_ratio, color=1)
        ms.append(mask.astype(masks.dtype))
        areas.append(mask.sum())
    areas = np.asarray(areas)
    index = np.argsort(-areas)
    ms = np.array(ms)[index]
    for i in range(len(segments)):
        mask = ms[i] * (i + 1)
        masks = masks + mask
        masks = np.clip(masks, a_min=0, a_max=i + 1)
    return masks, index


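# Usage sketch (hedged, `segments` is a hypothetical list of polygon arrays): segments are rasterized
# individually, sorted by area (largest first), and merged into one index mask, so where objects overlap
# the smaller (later-sorted) object wins. Pixels are 0 for background and i+1 for the i-th sorted mask.
#   masks, sort_index = polygons2masks_overlap((640, 640), segments, downsample_ratio=4)
#   masks.shape  # -> (160, 160)

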
def find_dataset_yaml(path: Path) -> Path:
    """
    Find and return the YAML file associated with a Detect, Segment or Pose dataset.

    This function searches for a YAML file at the root level of the provided directory first, and if not found, it
    performs a recursive search. It prefers YAML files that have the same stem as the provided path. An AssertionError
    is raised if no YAML file is found or if multiple YAML files are found.

    Args:
        path (Path): The directory path to search for the YAML file.

    Returns:
        (Path): The path of the found YAML file.
    """
    files = list(path.glob("*.yaml")) or list(path.rglob("*.yaml"))  # try root level first and then recursive
    assert files, f"No YAML file found in '{path.resolve()}'"
    if len(files) > 1:
        files = [f for f in files if f.stem == path.stem]  # prefer *.yaml files that match
    assert len(files) == 1, f"Expected 1 YAML file in '{path.resolve()}', but found {len(files)}.\n{files}"
    return files[0]


def check_det_dataset(dataset, autodownload=True):
    """
    Download, verify, and/or unzip a dataset if not found locally.

    This function checks the availability of a specified dataset, and if not found, it has the option to download and
    unzip the dataset. It then reads and parses the accompanying YAML data, ensuring key requirements are met and also
    resolves paths related to the dataset.

    Args:
        dataset (str): Path to the dataset or dataset descriptor (like a YAML file).
        autodownload (bool, optional): Whether to automatically download the dataset if not found. Defaults to True.

    Returns:
        (dict): Parsed dataset information and paths.
    """
    file = check_file(dataset)

    # Download (optional)
    extract_dir = ""
    if zipfile.is_zipfile(file) or is_tarfile(file):
        new_dir = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)
        file = find_dataset_yaml(DATASETS_DIR / new_dir)
        extract_dir, autodownload = file.parent, False

    # Read YAML
    data = yaml_load(file, append_filename=True)  # dictionary

    # Checks
    for k in "train", "val":
        if k not in data:
            if k != "val" or "validation" not in data:
                raise SyntaxError(
                    emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs.")
                )
            LOGGER.info("WARNING ⚠️ renaming data YAML 'validation' key to 'val' to match YOLO format.")
            data["val"] = data.pop("validation")  # replace 'validation' key with 'val' key
  291. if "names" not in data and "nc" not in data:
  292. raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs."))
  293. if "names" in data and "nc" in data and len(data["names"]) != data["nc"]:
  294. raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match."))
  295. if "names" not in data:
  296. data["names"] = [f"class_{i}" for i in range(data["nc"])]
  297. else:
  298. data["nc"] = len(data["names"])
  299. data["names"] = check_class_names(data["names"])
  300. # Resolve paths
  301. path = Path(extract_dir or data.get("path") or Path(data.get("yaml_file", "")).parent) # dataset root
  302. if not path.is_absolute():
  303. path = (DATASETS_DIR / path).resolve()
  304. # Set paths
  305. data["path"] = path # download scripts
  306. for k in "train", "val", "test", "minival":
  307. if data.get(k): # prepend path
  308. if isinstance(data[k], str):
  309. x = (path / data[k]).resolve()
  310. if not x.exists() and data[k].startswith("../"):
  311. x = (path / data[k][3:]).resolve()
  312. data[k] = str(x)
  313. else:
  314. data[k] = [str((path / x).resolve()) for x in data[k]]
  315. # Parse YAML
  316. val, s = (data.get(x) for x in ("val", "download"))
  317. if val:
  318. val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
  319. if not all(x.exists() for x in val):
  320. name = clean_url(dataset) # dataset name with URL auth stripped
  321. m = f"\nDataset '{name}' images not found ⚠️, missing path '{[x for x in val if not x.exists()][0]}'"
  322. if s and autodownload:
  323. LOGGER.warning(m)
  324. else:
  325. m += f"\nNote dataset download directory is '{DATASETS_DIR}'. You can update this in '{SETTINGS_FILE}'"
  326. raise FileNotFoundError(m)
  327. t = time.time()
  328. r = None # success
  329. if s.startswith("http") and s.endswith(".zip"): # URL
  330. safe_download(url=s, dir=DATASETS_DIR, delete=True)
  331. elif s.startswith("bash "): # bash script
  332. LOGGER.info(f"Running {s} ...")
  333. r = os.system(s)
  334. else: # python script
  335. exec(s, {"yaml": data})
  336. dt = f"({round(time.time() - t, 1)}s)"
  337. s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in {0, None} else f"failure {dt} ❌"
  338. LOGGER.info(f"Dataset download {s}\n")
  339. check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf") # download fonts
  340. return data # dictionary
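# Minimal data YAML sketch accepted by check_det_dataset() (hypothetical coco8-style example):
#   path: coco8                # dataset root; resolved against DATASETS_DIR when relative
#   train: images/train        # required
#   val: images/val            # required ('validation' is also accepted and renamed to 'val')
#   names:                     # either 'names' or 'nc' is required; 'nc' is derived from 'names'
#     0: person
#     1: bicycle
#   download: https://...zip   # optional URL, bash command, or python snippet run when images are missing

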
def check_cls_dataset(dataset, split=""):
    """
    Checks a classification dataset such as Imagenet.

    This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information.
    If the dataset is not found locally, it attempts to download the dataset from the internet and save it locally.

    Args:
        dataset (str | Path): The name of the dataset.
        split (str, optional): The split of the dataset. Either 'val', 'test', or ''. Defaults to ''.

    Returns:
        (dict): A dictionary containing the following keys:
            - 'train' (Path): The directory path containing the training set of the dataset.
            - 'val' (Path): The directory path containing the validation set of the dataset.
            - 'test' (Path): The directory path containing the test set of the dataset.
            - 'nc' (int): The number of classes in the dataset.
            - 'names' (dict): A dictionary of class names in the dataset.
    """
    # Download (optional if dataset=https://file.zip is passed directly)
    if str(dataset).startswith(("http:/", "https:/")):
        dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)
    elif Path(dataset).suffix in {".zip", ".tar", ".gz"}:
        file = check_file(dataset)
        dataset = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)

    dataset = Path(dataset)
    data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
    if not data_dir.is_dir():
        LOGGER.warning(f"\nDataset not found ⚠️, missing path {data_dir}, attempting download...")
        t = time.time()
        if str(dataset) == "imagenet":
            subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
        else:
            url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{dataset}.zip"
            download(url, dir=data_dir.parent)
        s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
        LOGGER.info(s)
    train_set = data_dir / "train"
    val_set = (
        data_dir / "val"
        if (data_dir / "val").exists()
        else data_dir / "validation"
        if (data_dir / "validation").exists()
        else None
    )  # data/val or data/validation
    test_set = data_dir / "test" if (data_dir / "test").exists() else None  # data/test
    if split == "val" and not val_set:
        LOGGER.warning("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
    elif split == "test" and not test_set:
        LOGGER.warning("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")

    nc = len([x for x in (data_dir / "train").glob("*") if x.is_dir()])  # number of classes
    names = [x.name for x in (data_dir / "train").iterdir() if x.is_dir()]  # class names list
    names = dict(enumerate(sorted(names)))

    # Print to console
    for k, v in {"train": train_set, "val": val_set, "test": test_set}.items():
        prefix = f"{colorstr(f'{k}:')} {v}..."
        if v is None:
            LOGGER.info(prefix)
        else:
            files = [path for path in v.rglob("*.*") if path.suffix[1:].lower() in IMG_FORMATS]
            nf = len(files)  # number of files
            nd = len({file.parent for file in files})  # number of directories
            if nf == 0:
                if k == "train":
                    raise FileNotFoundError(emojis(f"{dataset} '{k}:' no training images found ❌ "))
                else:
                    LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found")
            elif nd != nc:
                LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}")
            else:
                LOGGER.info(f"{prefix} found {nf} images in {nd} classes ✅ ")

    return {"train": train_set, "val": val_set, "test": test_set, "nc": nc, "names": names}


class HUBDatasetStats:
    """
    A class for generating HUB dataset JSON and `-hub` dataset directory.

    Args:
        path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip). Default is 'coco8.yaml'.
        task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'obb', 'classify'. Default is 'detect'.
        autodownload (bool): Attempt to download dataset if not found locally. Default is False.

    Example:
        Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
        i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
        ```python
        from ultralytics.data.utils import HUBDatasetStats

        stats = HUBDatasetStats("path/to/coco8.zip", task="detect")  # detect dataset
        stats = HUBDatasetStats("path/to/coco8-seg.zip", task="segment")  # segment dataset
        stats = HUBDatasetStats("path/to/coco8-pose.zip", task="pose")  # pose dataset
        stats = HUBDatasetStats("path/to/dota8.zip", task="obb")  # OBB dataset
        stats = HUBDatasetStats("path/to/imagenet10.zip", task="classify")  # classification dataset

        stats.get_json(save=True)
        stats.process_images()
        ```
    """

    def __init__(self, path="coco8.yaml", task="detect", autodownload=False):
        """Initialize class."""
        path = Path(path).resolve()
        LOGGER.info(f"Starting HUB dataset checks for {path}....")

        self.task = task  # detect, segment, pose, classify, obb
        if self.task == "classify":
            unzip_dir = unzip_file(path)
            data = check_cls_dataset(unzip_dir)
            data["path"] = unzip_dir
        else:  # detect, segment, pose, obb
            _, data_dir, yaml_path = self._unzip(Path(path))
            try:
                # Load YAML with checks
                data = yaml_load(yaml_path)
                data["path"] = ""  # strip path since YAML should be in dataset root for all HUB datasets
                yaml_save(yaml_path, data)
                data = check_det_dataset(yaml_path, autodownload)  # dict
                data["path"] = data_dir  # YAML path should be set to '' (relative) or parent (absolute)
            except Exception as e:
                raise Exception("error/HUB/dataset_stats/init") from e

        self.hub_dir = Path(f"{data['path']}-hub")
        self.im_dir = self.hub_dir / "images"
        self.stats = {"nc": len(data["names"]), "names": list(data["names"].values())}  # statistics dictionary
        self.data = data

    @staticmethod
    def _unzip(path):
        """Unzip data.zip."""
        if not str(path).endswith(".zip"):  # path is data.yaml
            return False, None, path
        unzip_dir = unzip_file(path, path=path.parent)
        assert unzip_dir.is_dir(), (
            f"Error unzipping {path}, {unzip_dir} not found. path/to/abc.zip MUST unzip to path/to/abc/"
        )
        return True, str(unzip_dir), find_dataset_yaml(unzip_dir)  # zipped, data_dir, yaml_path

    def _hub_ops(self, f):
        """Saves a compressed image for HUB previews."""
        compress_one_image(f, self.im_dir / Path(f).name)  # save to dataset-hub

    def get_json(self, save=False, verbose=False):
        """Return dataset JSON for Ultralytics HUB."""

        def _round(labels):
            """Update labels to integer class and 4 decimal place floats."""
            if self.task == "detect":
                coordinates = labels["bboxes"]
            elif self.task in {"segment", "obb"}:  # Segment and OBB use segments. OBB segments are normalized xyxyxyxy
                coordinates = [x.flatten() for x in labels["segments"]]
            elif self.task == "pose":
                n, nk, nd = labels["keypoints"].shape
                coordinates = np.concatenate((labels["bboxes"], labels["keypoints"].reshape(n, nk * nd)), 1)
            else:
                raise ValueError(f"Undefined dataset task={self.task}.")
            zipped = zip(labels["cls"], coordinates)
            return [[int(c[0]), *(round(float(x), 4) for x in points)] for c, points in zipped]

        for split in "train", "val", "test":
            self.stats[split] = None  # predefine
            path = self.data.get(split)

            # Check split
            if path is None:  # no split
                continue
            files = [f for f in Path(path).rglob("*.*") if f.suffix[1:].lower() in IMG_FORMATS]  # image files in split
            if not files:  # no images
                continue

            # Get dataset statistics
            if self.task == "classify":
                from torchvision.datasets import ImageFolder  # scope for faster 'import ultralytics'

                dataset = ImageFolder(self.data[split])

                x = np.zeros(len(dataset.classes)).astype(int)
                for im in dataset.imgs:
                    x[im[1]] += 1

                self.stats[split] = {
                    "instance_stats": {"total": len(dataset), "per_class": x.tolist()},
                    "image_stats": {"total": len(dataset), "unlabelled": 0, "per_class": x.tolist()},
                    "labels": [{Path(k).name: v} for k, v in dataset.imgs],
                }
            else:
                from ultralytics.data import YOLODataset

                dataset = YOLODataset(img_path=self.data[split], data=self.data, task=self.task)
                x = np.array(
                    [
                        np.bincount(label["cls"].astype(int).flatten(), minlength=self.data["nc"])
                        for label in TQDM(dataset.labels, total=len(dataset), desc="Statistics")
                    ]
                )  # shape(128x80)
                self.stats[split] = {
                    "instance_stats": {"total": int(x.sum()), "per_class": x.sum(0).tolist()},
                    "image_stats": {
                        "total": len(dataset),
                        "unlabelled": int(np.all(x == 0, 1).sum()),
                        "per_class": (x > 0).sum(0).tolist(),
                    },
                    "labels": [{Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)],
                }

        # Save, print and return
        if save:
            self.hub_dir.mkdir(parents=True, exist_ok=True)  # makes dataset-hub/
            stats_path = self.hub_dir / "stats.json"
            LOGGER.info(f"Saving {stats_path.resolve()}...")
            with open(stats_path, "w") as f:
                json.dump(self.stats, f)  # save stats.json
        if verbose:
            LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
        return self.stats

    def process_images(self):
        """Compress images for Ultralytics HUB."""
        from ultralytics.data import YOLODataset  # ClassificationDataset

        self.im_dir.mkdir(parents=True, exist_ok=True)  # makes dataset-hub/images/
        for split in "train", "val", "test":
            if self.data.get(split) is None:
                continue
            dataset = YOLODataset(img_path=self.data[split], data=self.data)
            with ThreadPool(NUM_THREADS) as pool:
                for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f"{split} images"):
                    pass
        LOGGER.info(f"Done. All images saved to {self.im_dir}")
        return self.im_dir


def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
    """
    Compresses a single image file to a reduced size while preserving its aspect ratio and quality using either the
    Python Imaging Library (PIL) or OpenCV. If the input image is smaller than the maximum dimension, it will not be
    resized.

    Args:
        f (str): The path to the input image file.
        f_new (str, optional): The path to the output image file. If not specified, the input file will be overwritten.
        max_dim (int, optional): The maximum dimension (width or height) of the output image. Default is 1920 pixels.
        quality (int, optional): The image compression quality as a percentage. Default is 50%.

    Example:
        ```python
        from pathlib import Path

        from ultralytics.data.utils import compress_one_image

        for f in Path("path/to/dataset").rglob("*.jpg"):
            compress_one_image(f)
        ```
    """
    try:  # use PIL
        im = Image.open(f)
        r = max_dim / max(im.height, im.width)  # ratio
        if r < 1.0:  # image too large
            im = im.resize((int(im.width * r), int(im.height * r)))
        im.save(f_new or f, "JPEG", quality=quality, optimize=True)  # save
    except Exception as e:  # use OpenCV
        LOGGER.info(f"WARNING ⚠️ HUB ops PIL failure {f}: {e}")
        im = cv2.imread(f)
        im_height, im_width = im.shape[:2]
        r = max_dim / max(im_height, im_width)  # ratio
        if r < 1.0:  # image too large
            im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
        cv2.imwrite(str(f_new or f), im)


def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annotated_only=False):
    """
    Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.

    Args:
        path (Path, optional): Path to images directory. Defaults to DATASETS_DIR / 'coco8/images'.
        weights (list | tuple, optional): Train, validation, and test split fractions. Defaults to (0.9, 0.1, 0.0).
        annotated_only (bool, optional): If True, only images with an associated txt file are used. Defaults to False.

    Example:
        ```python
        from ultralytics.data.utils import autosplit

        autosplit()
        ```
    """
    path = Path(path)  # images dir
    files = sorted(x for x in path.rglob("*.*") if x.suffix[1:].lower() in IMG_FORMATS)  # image files only
    n = len(files)  # number of files
    random.seed(0)  # for reproducibility
    indices = random.choices([0, 1, 2], weights=weights, k=n)  # assign each image to a split

    txt = ["autosplit_train.txt", "autosplit_val.txt", "autosplit_test.txt"]  # 3 txt files
    for x in txt:
        if (path.parent / x).exists():
            (path.parent / x).unlink()  # remove existing

    LOGGER.info(f"Autosplitting images from {path}" + ", using *.txt labeled images only" * annotated_only)
    for i, img in TQDM(zip(indices, files), total=n):
        if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
            with open(path.parent / txt[i], "a") as f:
                f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n")  # add image to txt file


def load_dataset_cache_file(path):
    """Load an Ultralytics *.cache dictionary from path."""
    import gc

    gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
    cache = np.load(str(path), allow_pickle=True).item()  # load dict
    gc.enable()
    return cache


def save_dataset_cache_file(prefix, path, x, version):
    """Save an Ultralytics dataset *.cache dictionary x to path."""
    x["version"] = version  # add cache version
    if is_dir_writeable(path.parent):
        if path.exists():
            path.unlink()  # remove *.cache file if exists
        np.save(str(path), x)  # save cache for next time
        path.with_suffix(".cache.npy").rename(path)  # remove .npy suffix
        LOGGER.info(f"{prefix}New cache created: {path}")
    else:
        LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")


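# Round-trip sketch (hedged; the cache contents and version string below are illustrative only):
#   cache = {"hash": get_hash([]), "results": (0, 0, 0, 0, 0)}  # hypothetical cache dict
#   save_dataset_cache_file("prefix: ", Path("labels.cache"), cache, version="1.0")
#   loaded = load_dataset_cache_file(Path("labels.cache"))
#   loaded["version"]  # -> "1.0"; the remaining keys match what was saved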