import datetime
import errno
import os
import time
from collections import defaultdict, deque

import torch
import torch.distributed as dist
from torch.utils.data.dataloader import default_collate


class SmoothedValue:
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
        )
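

# Illustrative usage of SmoothedValue (not part of the original module; the
# training loop and `compute_loss` below are hypothetical):
#
#     loss_meter = SmoothedValue(window_size=20, fmt="{median:.4f} ({global_avg:.4f})")
#     for step in range(num_steps):
#         loss = compute_loss()            # hypothetical helper returning a scalar
#         loss_meter.update(float(loss))   # window stats track the last 20 values
#     print(loss_meter)                    # prints "median (global_avg)" per `fmt`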


def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]
    data_list = [None] * world_size
    dist.all_gather_object(data_list, data)
    return data_list
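

# Illustrative usage of all_gather (assumes torch.distributed is already
# initialized; `local_stats` is a hypothetical picklable object):
#
#     local_stats = {"rank": get_rank(), "num_samples": 128}
#     gathered = all_gather(local_stats)   # list with one entry per rank
#     if is_main_process():
#         total = sum(s["num_samples"] for s in gathered)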


def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that all processes
    have the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.inference_mode():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict
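

# Illustrative usage of reduce_dict (assumes an initialized process group and a
# hypothetical `loss_dict` of 0-dim tensors on the current device, one entry
# per loss term):
#
#     loss_dict = {"loss_cls": loss_cls, "loss_box": loss_box}
#     loss_dict_reduced = reduce_dict(loss_dict)          # averaged across ranks
#     losses_reduced = sum(loss_dict_reduced.values())    # scalar for logging only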


class MetricLogger:
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(f"{name}: {str(meter)}")
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        if torch.cuda.is_available():
            log_msg = self.delimiter.join(
                [
                    header,
                    "[{0" + space_fmt + "}/{1}]",
                    "eta: {eta}",
                    "{meters}",
                    "time: {time}",
                    "data: {data}",
                    "max mem: {memory:.0f}",
                ]
            )
        else:
            log_msg = self.delimiter.join(
                [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
            )
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB,
                        )
                    )
                else:
                    print(
                        log_msg.format(
                            i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
                        )
                    )
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(f"{header} Total time: {total_time_str} ({total_time / len(iterable):.4f} s / it)")
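

# Illustrative usage of MetricLogger.log_every in a training loop (hypothetical
# `data_loader`, `optimizer`, and `train_step`; only the logger calls are real):
#
#     metric_logger = MetricLogger(delimiter="  ")
#     metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
#     for images, targets in metric_logger.log_every(data_loader, print_freq=10, header="Epoch: [0]"):
#         loss = train_step(images, targets)   # hypothetical helper returning a scalar
#         metric_logger.update(loss=float(loss), lr=optimizer.param_groups[0]["lr"])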


def collate_fn(batch):
    # Transpose a list of (image, target) samples into a pair of tuples
    # (images, targets), so variable-sized samples need not be stacked into one tensor.
    return tuple(zip(*batch))


def collate_fn_wirepoint(batch):
    # Same transpose as collate_fn, used for wirepoint-style targets
    # (targets carrying a "wires" field).
    batch = tuple(zip(*batch))
    return batch
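

# Illustrative usage of collate_fn with a DataLoader (hypothetical `dataset`
# yielding (image, target) pairs of varying sizes, which default collation
# cannot stack into a single tensor):
#
#     from torch.utils.data import DataLoader
#     data_loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
#     for images, targets in data_loader:
#         ...   # images and targets are tuples of per-sample items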


def mkdir(path):
    try:
        os.makedirs(path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise


def setup_for_distributed(is_master):
    """
    This function disables printing when not in the master process.
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
    elif "SLURM_PROCID" in os.environ:
        args.rank = int(os.environ["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print("Not using distributed mode")
        args.distributed = False
        return

    args.distributed = True
    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"
    print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
    torch.distributed.init_process_group(
        backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
    )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
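

# Illustrative launch for init_distributed_mode (hypothetical train.py that
# builds an argparse namespace with a `dist_url` attribute, e.g. "env://"):
#
#     torchrun --nproc_per_node=4 train.py
#
# torchrun sets RANK, WORLD_SIZE, and LOCAL_RANK, so the first branch above is
# taken; under SLURM, SLURM_PROCID supplies the rank instead and `dist_url`
# must point to a reachable init method.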