split_dota.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. import itertools
  3. from glob import glob
  4. from math import ceil
  5. from pathlib import Path
  6. import cv2
  7. import numpy as np
  8. from PIL import Image
  9. from tqdm import tqdm
  10. from ultralytics.data.utils import exif_size, img2label_paths
  11. from ultralytics.utils.checks import check_requirements
  def bbox_iof(polygon1, bbox2, eps=1e-6):
      """
      Calculate Intersection over Foreground (IoF) between polygons and axis-aligned bounding boxes.

      IoF is the area of (polygon ∩ box) divided by the area of the polygon itself, i.e. the fraction of each
      polygon that lies inside each box.

      Args:
          polygon1 (np.ndarray): Polygons as [x1, y1, x2, y2, x3, y3, x4, y4]; reshaped internally to (n, 4, 2).
          bbox2 (np.ndarray): Boxes as [x_min, y_min, x_max, y_max], shape (m, 4).
          eps (float, optional): Lower bound applied to polygon areas to avoid division by zero. Defaults to 1e-6.

      Returns:
          (np.ndarray): IoF scores of shape (n, m) for bbox2 of shape (m, 4); entry [i, j] scores polygon i
              against box j.
      """
      check_requirements("shapely")
      from shapely.geometry import Polygon

      quads = polygon1.reshape(-1, 4, 2)

      # Cheap prefilter: overlap of each polygon's axis-aligned hull with each box. Pairs whose hulls do not
      # overlap cannot intersect, so the exact shapely intersection is only computed for nonzero entries.
      hull = np.concatenate([quads.min(axis=-2), quads.max(axis=-2)], axis=-1)
      inter_lt = np.maximum(hull[:, None, :2], bbox2[..., :2])
      inter_rb = np.minimum(hull[:, None, 2:], bbox2[..., 2:])
      inter_wh = np.clip(inter_rb - inter_lt, 0, np.inf)
      hull_overlap = inter_wh[..., 0] * inter_wh[..., 1]

      # Turn each box into a 4-corner polygon: (left, top), (right, top), (right, bottom), (left, bottom).
      x0, y0, x1, y1 = (bbox2[..., k] for k in range(4))
      box_quads = np.stack([x0, y0, x1, y0, x1, y1, x0, y1], axis=-1).reshape(-1, 4, 2)

      fg_shapes = [Polygon(q) for q in quads]
      box_shapes = [Polygon(q) for q in box_quads]

      inter_area = np.zeros(hull_overlap.shape)
      for idx in zip(*np.nonzero(hull_overlap)):
          inter_area[idx] = fg_shapes[idx[0]].intersection(box_shapes[idx[-1]]).area

      # The denominator is each polygon's own ("foreground") area, stored as float32 and clipped away from zero.
      fg_area = np.array([s.area for s in fg_shapes], dtype=np.float32)[..., None]
      fg_area = np.clip(fg_area, eps, np.inf)

      scores = inter_area / fg_area
      return scores[..., None] if scores.ndim == 1 else scores
  def load_yolo_dota(data_root, split="train"):
      """
      Load DOTA dataset annotations for one split.

      Args:
          data_root (str): Data root.
          split (str): The split data set, could be `train` or `val`.

      Returns:
          (List[dict]): One dict per file found in `data_root/images/<split>`, with keys:
              - `ori_size`: (h, w), unpacked from the (w, h) returned by `exif_size`.
              - `label`: np.float32 array with one row per non-empty line of the matching label file.
              - `filepath`: the image path.

      Raises:
          AssertionError: If `split` is not `train`/`val` or the image directory does not exist.

      Notes:
          The directory structure assumed for the DOTA dataset:
              - data_root
                  - images
                      - train
                      - val
                  - labels
                      - train
                      - val
      """
      assert split in {"train", "val"}, f"Split must be 'train' or 'val', not {split}."
      im_dir = Path(data_root) / "images" / split
      assert im_dir.exists(), f"Can't find {im_dir}, please check your data root."
      im_files = glob(str(im_dir / "*"))
      lb_files = img2label_paths(im_files)
      annos = []
      for im_file, lb_file in zip(im_files, lb_files):
          # Image.open() is lazy and keeps the file handle open until the image is closed or garbage-collected;
          # the context manager releases it right after the size is read so large datasets don't exhaust handles.
          with Image.open(im_file) as img:
              w, h = exif_size(img)
          with open(lb_file) as f:
              lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
          lb = np.array(lb, dtype=np.float32)
          annos.append(dict(ori_size=(h, w), label=lb, filepath=im_file))
      return annos
  def get_windows(im_size, crop_sizes=(1024,), gaps=(200,), im_rate_thr=0.6, eps=0.01):
      """
      Get the coordinates of sliding windows that tile an image.

      Args:
          im_size (tuple): Original image size, (h, w).
          crop_sizes (List(int)): Window side length for each scale; paired element-wise with `gaps`.
          gaps (List(int)): Gap between crops for each scale; consecutive windows start `crop_size - gap` apart.
          im_rate_thr (float): Threshold of the window area lying inside the image divided by the window area.
          eps (float): Tolerance used to find the windows tied with the best in-image ratio.

      Returns:
          (np.ndarray): int64 array of windows [x_start, y_start, x_stop, y_stop], shape (k, 4). If no window
              exceeds `im_rate_thr`, the window(s) with the highest in-image ratio (within `eps`) are kept.
      """

      def _axis_starts(length, crop_size, step):
          # Offsets spaced `step` apart; the last is pulled back so that window ends exactly at `length`.
          if length <= crop_size:
              return [0]
          starts = [step * k for k in range(ceil((length - crop_size) / step + 1))]
          if len(starts) > 1 and starts[-1] + crop_size > length:
              starts[-1] = length - crop_size
          return starts

      img_h, img_w = im_size
      per_scale = []
      for crop_size, gap in zip(crop_sizes, gaps):
          assert crop_size > gap, f"invalid crop_size gap pair [{crop_size} {gap}]"
          step = crop_size - gap
          x_starts = _axis_starts(img_w, crop_size, step)
          y_starts = _axis_starts(img_h, crop_size, step)
          origins = np.array(list(itertools.product(x_starts, y_starts)), dtype=np.int64)
          per_scale.append(np.concatenate([origins, origins + crop_size], axis=1))
      windows = np.concatenate(per_scale, axis=0)

      # Fraction of each window's area that actually falls inside the image.
      clipped = windows.copy()
      clipped[:, 0::2] = np.clip(clipped[:, 0::2], 0, img_w)
      clipped[:, 1::2] = np.clip(clipped[:, 1::2], 0, img_h)
      inside_area = (clipped[:, 2] - clipped[:, 0]) * (clipped[:, 3] - clipped[:, 1])
      window_area = (windows[:, 2] - windows[:, 0]) * (windows[:, 3] - windows[:, 1])
      ratios = inside_area / window_area

      # Never return nothing: if every window is mostly outside the image, keep the best-covered one(s).
      if not (ratios > im_rate_thr).any():
          ratios[np.abs(ratios - ratios.max()) < eps] = 1
      return windows[ratios > im_rate_thr]
  def get_window_obj(anno, windows, iof_thr=0.7):
      """
      Get the objects that belong to each window.

      Args:
          anno (dict): Annotation dict with `ori_size` as (h, w) and `label` rows of
              [cls, x1, y1, x2, y2, x3, y3, x4, y4], with coordinates normalized by image width/height.
              `anno` is not modified.
          windows (np.ndarray): Window boxes [x_start, y_start, x_stop, y_stop], shape (m, 4).
          iof_thr (float): Minimum IoF (object area inside the window / object area) for an object to be kept.

      Returns:
          (List[np.ndarray]): One array per window holding the label rows whose IoF with that window is at least
              `iof_thr`. Coordinates are in original-image pixels and are not yet shifted or normalized to the
              window. When there are no labels, every window gets an empty (0, 9) float32 array.
      """
      h, w = anno["ori_size"]
      label = anno["label"]
      if not len(label):
          return [np.zeros((0, 9), dtype=np.float32) for _ in range(len(windows))]  # window_anns

      # Scale a copy so the caller's anno["label"] stays normalized. Scaling in place would make a repeated
      # call on the same annotation multiply the coordinates by (w, h) a second time.
      label = label.copy()
      label[:, 1::2] *= w
      label[:, 2::2] *= h
      iofs = bbox_iof(label[:, 1:], windows)
      # Boolean-mask indexing yields an independent array per window.
      return [label[iofs[:, i] >= iof_thr] for i in range(len(windows))]  # window_anns
  def crop_and_save(anno, windows, window_objs, im_dir, lb_dir, allow_background_images=True):
      """
      Crop images and save new labels.

      Args:
          anno (dict): Annotation dict; only `filepath` is read here.
          windows (np.ndarray): Window boxes [x_start, y_start, x_stop, y_stop]; each row must support `.tolist()`.
          window_objs (list): Per-window label arrays in original-image pixel coordinates, aligned with `windows`.
              They are not modified.
          im_dir (str): The output directory path of images.
          lb_dir (str): The output directory path of labels.
          allow_background_images (bool): Whether to also save image patches whose window has no labels.

      Notes:
          Each patch is saved as `{stem}__{x_stop - x_start}__{x_start}___{y_start}.jpg` in `im_dir`. A matching
          `.txt` label file is written to `lb_dir` only when the window has labels. Its coordinates are shifted to
          the patch origin and normalized by the patch width/height (the actual patch, which can be smaller than
          the window at image borders), formatted with `.6g`.
      """
      im = cv2.imread(anno["filepath"])
      name = Path(anno["filepath"]).stem
      for i, window in enumerate(windows):
          x_start, y_start, x_stop, y_stop = window.tolist()
          new_name = f"{name}__{x_stop - x_start}__{x_start}___{y_start}"
          patch_im = im[y_start:y_stop, x_start:x_stop]
          ph, pw = patch_im.shape[:2]
          label = window_objs[i]
          if len(label) or allow_background_images:
              cv2.imwrite(str(Path(im_dir) / f"{new_name}.jpg"), patch_im)
          if len(label):
              # Transform a copy: shifting/normalizing in place would silently corrupt the caller's window_objs
              # (and a second call with the same inputs would shift the coordinates twice).
              label = label.copy()
              label[:, 1::2] -= x_start
              label[:, 2::2] -= y_start
              label[:, 1::2] /= pw
              label[:, 2::2] /= ph
              with open(Path(lb_dir) / f"{new_name}.txt", "w") as f:
                  for lb in label:
                      formatted_coords = [f"{coord:.6g}" for coord in lb[1:]]
                      f.write(f"{int(lb[0])} {' '.join(formatted_coords)}\n")
  def split_images_and_labels(data_root, save_dir, split="train", crop_sizes=(1024,), gaps=(200,)):
      """
      Split both images and labels of one dataset split into windowed patches.

      Args:
          data_root (str): Root of the source dataset, read via `load_yolo_dota`.
          save_dir (str): Root of the output dataset; missing directories are created.
          split (str): Name of the split to process.
          crop_sizes (tuple): Window sizes, paired element-wise with `gaps` and passed to `get_windows`.
          gaps (tuple): Gaps between crops for each window size, passed to `get_windows`.

      Notes:
          The directory structure assumed for the DOTA dataset:
              - data_root
                  - images
                      - split
                  - labels
                      - split
          and the output directory structure is:
              - save_dir
                  - images
                      - split
                  - labels
                      - split
      """
      img_out = Path(save_dir) / "images" / split
      lbl_out = Path(save_dir) / "labels" / split
      for out_dir in (img_out, lbl_out):
          out_dir.mkdir(parents=True, exist_ok=True)

      annos = load_yolo_dota(data_root, split=split)
      for anno in tqdm(annos, total=len(annos), desc=split):
          wins = get_windows(anno["ori_size"], crop_sizes, gaps)
          crop_and_save(anno, wins, get_window_obj(anno, wins), str(img_out), str(lbl_out))
  def split_trainval(data_root, save_dir, crop_size=1024, gap=200, rates=(1.0,)):
      """
      Split the train and val sets of DOTA into windowed patches.

      Args:
          data_root (str): Root of the source dataset.
          save_dir (str): Root of the output dataset.
          crop_size (int): Base window size.
          gap (int): Base gap between windows.
          rates (tuple): Scale rates; each rate r adds a window scale of int(crop_size / r) with gap int(gap / r).

      Notes:
          The directory structure assumed for the DOTA dataset:
              - data_root
                  - images
                      - train
                      - val
                  - labels
                      - train
                      - val
          and the output directory structure is:
              - save_dir
                  - images
                      - train
                      - val
                  - labels
                      - train
                      - val
      """
      # Single pass over `rates` so a one-shot iterable (e.g. a generator) still produces both lists.
      scaled = [(int(crop_size / r), int(gap / r)) for r in rates]
      crop_sizes = [size for size, _ in scaled]
      gaps = [g for _, g in scaled]
      for split in ("train", "val"):
          split_images_and_labels(data_root, save_dir, split, crop_sizes, gaps)
  def split_test(data_root, save_dir, crop_size=1024, gap=200, rates=(1.0,)):
      """
      Split the test set of DOTA into windowed patches; labels are not included within this set.

      Args:
          data_root (str): Root of the source dataset.
          save_dir (str): Root of the output dataset; patches go to `save_dir/images/test`.
          crop_size (int): Base window size.
          gap (int): Base gap between windows.
          rates (tuple): Scale rates; each rate r adds a window scale of int(crop_size / r) with gap int(gap / r).

      Raises:
          AssertionError: If `data_root/images/test` does not exist.

      Notes:
          The directory structure assumed for the DOTA dataset:
              - data_root
                  - images
                      - test
          and the output directory structure is:
              - save_dir
                  - images
                      - test
      """
      crop_sizes, gaps = [], []
      for r in rates:
          crop_sizes.append(int(crop_size / r))
          gaps.append(int(gap / r))

      im_dir = Path(data_root) / "images" / "test"
      # Validate the input before creating any output directories.
      assert im_dir.exists(), f"Can't find {im_dir}, please check your data root."
      out_dir = Path(save_dir) / "images" / "test"
      out_dir.mkdir(parents=True, exist_ok=True)

      im_files = glob(str(im_dir / "*"))
      for im_file in tqdm(im_files, total=len(im_files), desc="test"):
          # Close the lazily-opened PIL handle as soon as the size is known; otherwise it stays open until GC.
          with Image.open(im_file) as img:
              w, h = exif_size(img)
          windows = get_windows((h, w), crop_sizes=crop_sizes, gaps=gaps)
          im = cv2.imread(im_file)
          name = Path(im_file).stem
          for window in windows:
              x_start, y_start, x_stop, y_stop = window.tolist()
              new_name = f"{name}__{x_stop - x_start}__{x_start}___{y_start}"
              patch_im = im[y_start:y_stop, x_start:x_stop]
              cv2.imwrite(str(out_dir / f"{new_name}.jpg"), patch_im)
  if __name__ == "__main__":
      # Split DOTAv2 train/val (images + labels), then the test images, into DOTAv2-split.
      paths = dict(data_root="DOTAv2", save_dir="DOTAv2-split")
      split_trainval(**paths)
      split_test(**paths)