group_by_aspect_ratio.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. import bisect
  2. import copy
  3. import math
  4. from collections import defaultdict
  5. from itertools import chain, repeat
  6. import numpy as np
  7. import torch
  8. import torch.utils.data
  9. import torchvision
  10. from PIL import Image
  11. from torch.utils.data.sampler import BatchSampler, Sampler
  12. from torch.utils.model_zoo import tqdm
  13. def _repeat_to_at_least(iterable, n):
  14. repeat_times = math.ceil(n / len(iterable))
  15. repeated = chain.from_iterable(repeat(iterable, repeat_times))
  16. return list(repeated)
  17. class GroupedBatchSampler(BatchSampler):
  18. """
  19. Wraps another sampler to yield a mini-batch of indices.
  20. It enforces that the batch only contain elements from the same group.
  21. It also tries to provide mini-batches which follows an ordering which is
  22. as close as possible to the ordering from the original sampler.
  23. Args:
  24. sampler (Sampler): Base sampler.
  25. group_ids (list[int]): If the sampler produces indices in range [0, N),
  26. `group_ids` must be a list of `N` ints which contains the group id of each sample.
  27. The group ids must be a continuous set of integers starting from
  28. 0, i.e. they must be in the range [0, num_groups).
  29. batch_size (int): Size of mini-batch.
  30. """
  31. def __init__(self, sampler, group_ids, batch_size):
  32. if not isinstance(sampler, Sampler):
  33. raise ValueError(f"sampler should be an instance of torch.utils.data.Sampler, but got sampler={sampler}")
  34. self.sampler = sampler
  35. self.group_ids = group_ids
  36. self.batch_size = batch_size
  37. def __iter__(self):
  38. buffer_per_group = defaultdict(list)
  39. samples_per_group = defaultdict(list)
  40. num_batches = 0
  41. for idx in self.sampler:
  42. group_id = self.group_ids[idx]
  43. buffer_per_group[group_id].append(idx)
  44. samples_per_group[group_id].append(idx)
  45. if len(buffer_per_group[group_id]) == self.batch_size:
  46. yield buffer_per_group[group_id]
  47. num_batches += 1
  48. del buffer_per_group[group_id]
  49. assert len(buffer_per_group[group_id]) < self.batch_size
  50. # now we have run out of elements that satisfy
  51. # the group criteria, let's return the remaining
  52. # elements so that the size of the sampler is
  53. # deterministic
  54. expected_num_batches = len(self)
  55. num_remaining = expected_num_batches - num_batches
  56. if num_remaining > 0:
  57. # for the remaining batches, take first the buffers with the largest number
  58. # of elements
  59. for group_id, _ in sorted(buffer_per_group.items(), key=lambda x: len(x[1]), reverse=True):
  60. remaining = self.batch_size - len(buffer_per_group[group_id])
  61. samples_from_group_id = _repeat_to_at_least(samples_per_group[group_id], remaining)
  62. buffer_per_group[group_id].extend(samples_from_group_id[:remaining])
  63. assert len(buffer_per_group[group_id]) == self.batch_size
  64. yield buffer_per_group[group_id]
  65. num_remaining -= 1
  66. if num_remaining == 0:
  67. break
  68. assert num_remaining == 0
  69. def __len__(self):
  70. return len(self.sampler) // self.batch_size
  71. def _compute_aspect_ratios_slow(dataset, indices=None):
  72. print(
  73. "Your dataset doesn't support the fast path for "
  74. "computing the aspect ratios, so will iterate over "
  75. "the full dataset and load every image instead. "
  76. "This might take some time..."
  77. )
  78. if indices is None:
  79. indices = range(len(dataset))
  80. class SubsetSampler(Sampler):
  81. def __init__(self, indices):
  82. self.indices = indices
  83. def __iter__(self):
  84. return iter(self.indices)
  85. def __len__(self):
  86. return len(self.indices)
  87. sampler = SubsetSampler(indices)
  88. data_loader = torch.utils.data.DataLoader(
  89. dataset,
  90. batch_size=1,
  91. sampler=sampler,
  92. num_workers=14, # you might want to increase it for faster processing
  93. collate_fn=lambda x: x[0],
  94. )
  95. aspect_ratios = []
  96. with tqdm(total=len(dataset)) as pbar:
  97. for _i, (img, _) in enumerate(data_loader):
  98. pbar.update(1)
  99. height, width = img.shape[-2:]
  100. aspect_ratio = float(width) / float(height)
  101. aspect_ratios.append(aspect_ratio)
  102. return aspect_ratios
  103. def _compute_aspect_ratios_custom_dataset(dataset, indices=None):
  104. if indices is None:
  105. indices = range(len(dataset))
  106. aspect_ratios = []
  107. for i in indices:
  108. height, width = dataset.get_height_and_width(i)
  109. aspect_ratio = float(width) / float(height)
  110. aspect_ratios.append(aspect_ratio)
  111. return aspect_ratios
  112. def _compute_aspect_ratios_coco_dataset(dataset, indices=None):
  113. if indices is None:
  114. indices = range(len(dataset))
  115. aspect_ratios = []
  116. for i in indices:
  117. img_info = dataset.coco.imgs[dataset.ids[i]]
  118. aspect_ratio = float(img_info["width"]) / float(img_info["height"])
  119. aspect_ratios.append(aspect_ratio)
  120. return aspect_ratios
  121. def _compute_aspect_ratios_voc_dataset(dataset, indices=None):
  122. if indices is None:
  123. indices = range(len(dataset))
  124. aspect_ratios = []
  125. for i in indices:
  126. # this doesn't load the data into memory, because PIL loads it lazily
  127. width, height = Image.open(dataset.images[i]).size
  128. aspect_ratio = float(width) / float(height)
  129. aspect_ratios.append(aspect_ratio)
  130. return aspect_ratios
  131. def _compute_aspect_ratios_subset_dataset(dataset, indices=None):
  132. if indices is None:
  133. indices = range(len(dataset))
  134. ds_indices = [dataset.indices[i] for i in indices]
  135. return compute_aspect_ratios(dataset.dataset, ds_indices)
  136. def compute_aspect_ratios(dataset, indices=None):
  137. if hasattr(dataset, "get_height_and_width"):
  138. return _compute_aspect_ratios_custom_dataset(dataset, indices)
  139. if isinstance(dataset, torchvision.datasets.CocoDetection):
  140. return _compute_aspect_ratios_coco_dataset(dataset, indices)
  141. if isinstance(dataset, torchvision.datasets.VOCDetection):
  142. return _compute_aspect_ratios_voc_dataset(dataset, indices)
  143. if isinstance(dataset, torch.utils.data.Subset):
  144. return _compute_aspect_ratios_subset_dataset(dataset, indices)
  145. # slow path
  146. return _compute_aspect_ratios_slow(dataset, indices)
  147. def _quantize(x, bins):
  148. bins = copy.deepcopy(bins)
  149. bins = sorted(bins)
  150. quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
  151. return quantized
  152. def create_aspect_ratio_groups(dataset, k=0):
  153. aspect_ratios = compute_aspect_ratios(dataset)
  154. bins = (2 ** np.linspace(-1, 1, 2 * k + 1)).tolist() if k > 0 else [1.0]
  155. groups = _quantize(aspect_ratios, bins)
  156. # count number of elements per group
  157. counts = np.unique(groups, return_counts=True)[1]
  158. fbins = [0] + bins + [np.inf]
  159. print(f"Using {fbins} as bins for aspect ratio quantization")
  160. print(f"Count of instances per bin: {counts}")
  161. return groups