utils.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. import math
  2. import os.path as osp
  3. import multiprocessing
  4. from timeit import default_timer as timer
  5. import numpy as np
  6. import torch
  7. import matplotlib.pyplot as plt
  8. class benchmark(object):
  9. def __init__(self, msg, enable=True, fmt="%0.3g"):
  10. self.msg = msg
  11. self.fmt = fmt
  12. self.enable = enable
  13. def __enter__(self):
  14. if self.enable:
  15. self.start = timer()
  16. return self
  17. def __exit__(self, *args):
  18. if self.enable:
  19. t = timer() - self.start
  20. print(("%s : " + self.fmt + " seconds") % (self.msg, t))
  21. self.time = t
  22. def quiver(x, y, ax):
  23. ax.set_xlim(0, x.shape[1])
  24. ax.set_ylim(x.shape[0], 0)
  25. ax.quiver(
  26. x,
  27. y,
  28. units="xy",
  29. angles="xy",
  30. scale_units="xy",
  31. scale=1,
  32. minlength=0.01,
  33. width=0.1,
  34. color="b",
  35. )
  36. def recursive_to(input, device):
  37. if isinstance(input, torch.Tensor):
  38. return input.to(device)
  39. if isinstance(input, dict):
  40. for name in input:
  41. if isinstance(input[name], torch.Tensor):
  42. input[name] = input[name].to(device)
  43. return input
  44. if isinstance(input, list):
  45. for i, item in enumerate(input):
  46. input[i] = recursive_to(item, device)
  47. return input
  48. assert False
  49. def np_softmax(x, axis=0):
  50. """Compute softmax values for each sets of scores in x."""
  51. e_x = np.exp(x - np.max(x))
  52. return e_x / e_x.sum(axis=axis, keepdims=True)
  53. def argsort2d(arr):
  54. return np.dstack(np.unravel_index(np.argsort(arr.ravel()), arr.shape))[0]
  55. def __parallel_handle(f, q_in, q_out):
  56. while True:
  57. i, x = q_in.get()
  58. if i is None:
  59. break
  60. q_out.put((i, f(x)))
  61. def parmap(f, X, nprocs=multiprocessing.cpu_count(), progress_bar=lambda x: x):
  62. if nprocs == 0:
  63. nprocs = multiprocessing.cpu_count()
  64. q_in = multiprocessing.Queue(1)
  65. q_out = multiprocessing.Queue()
  66. proc = [
  67. multiprocessing.Process(target=__parallel_handle, args=(f, q_in, q_out))
  68. for _ in range(nprocs)
  69. ]
  70. for p in proc:
  71. p.daemon = True
  72. p.start()
  73. try:
  74. sent = [q_in.put((i, x)) for i, x in enumerate(X)]
  75. [q_in.put((None, None)) for _ in range(nprocs)]
  76. res = [q_out.get() for _ in progress_bar(range(len(sent)))]
  77. [p.join() for p in proc]
  78. except KeyboardInterrupt:
  79. q_in.close()
  80. q_out.close()
  81. raise
  82. return [x for i, x in sorted(res)]