post.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. #!/usr/bin/env python3
  2. """Post-processing the output of neural network
  3. Usage:
  4. post.py [options] <input-dir> <output-dir>
  5. post.py ( -h | --help )
  6. Examples:
  7. post.py logs/logname/npz/000336000 result/logname
  8. Arguments:
  9. input-dir Directory that stores the npz
  10. output-dir Output directory
  11. Options:
  12. -h --help Show this screen.
  13. --plot Generate images besides npz files
  14. --thresholds=<thresholds> A comma-separated list for thresholding
  15. [default: 0.006,0.010,0.015]
  16. """
  17. import glob
  18. import math
  19. import os
  20. import os.path as osp
  21. import sys
  22. import cv2
  23. import matplotlib as mpl
  24. import matplotlib.pyplot as plt
  25. import numpy as np
  26. from docopt import docopt
  27. from lcnn.postprocess import postprocess
  28. from lcnn.utils import parmap
  29. PLTOPTS = {"color": "#33FFFF", "s": 1.2, "edgecolors": "none", "zorder": 5}
  30. cmap = plt.get_cmap("jet")
  31. norm = mpl.colors.Normalize(vmin=0.92, vmax=1.02)
  32. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  33. sm.set_array([])
  34. def c(x):
  35. return sm.to_rgba(x)
  36. def imshow(im):
  37. plt.close()
  38. sizes = im.shape
  39. height = float(sizes[0])
  40. width = float(sizes[1])
  41. fig = plt.figure()
  42. fig.set_size_inches(width / height, 1, forward=False)
  43. ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
  44. ax.set_axis_off()
  45. fig.add_axes(ax)
  46. plt.xlim([-0.5, sizes[1] - 0.5])
  47. plt.ylim([sizes[0] - 0.5, -0.5])
  48. plt.imshow(im)
  49. def main():
  50. args = docopt(__doc__)
  51. files = sorted(glob.glob(osp.join(args["<input-dir>"], "*.npz")))
  52. inames = sorted(glob.glob("data/wireframe/valid-images/*.jpg"))
  53. gts = sorted(glob.glob("data/wireframe/valid/*.npz"))
  54. prefix = args["<output-dir>"]
  55. inputs = list(zip(files, inames, gts))
  56. thresholds = list(map(float, args["--thresholds"].split(",")))
  57. def handle(allname):
  58. fname, iname, gtname = allname
  59. print("Processing", fname)
  60. im = cv2.imread(iname)
  61. with np.load(fname) as f:
  62. lines = f["lines"]
  63. scores = f["score"]
  64. with np.load(gtname) as f:
  65. gtlines = f["lpos"][:, :, :2]
  66. gtlines[:, :, 0] *= im.shape[0] / 128
  67. gtlines[:, :, 1] *= im.shape[1] / 128
  68. for i in range(1, len(lines)):
  69. if (lines[i] == lines[0]).all():
  70. lines = lines[:i]
  71. scores = scores[:i]
  72. break
  73. lines[:, :, 0] *= im.shape[0] / 128
  74. lines[:, :, 1] *= im.shape[1] / 128
  75. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  76. for threshold in thresholds:
  77. nlines, nscores = postprocess(lines, scores, diag * threshold, 0, False)
  78. outdir = osp.join(prefix, f"{threshold:.3f}".replace(".", "_"))
  79. os.makedirs(outdir, exist_ok=True)
  80. npz_name = osp.join(outdir, osp.split(fname)[-1])
  81. if args["--plot"]:
  82. # plot gt
  83. imshow(im[:, :, ::-1])
  84. for (a, b) in gtlines:
  85. plt.plot([a[1], b[1]], [a[0], b[0]], c="orange", linewidth=0.5)
  86. plt.scatter(a[1], a[0], **PLTOPTS)
  87. plt.scatter(b[1], b[0], **PLTOPTS)
  88. plt.savefig(npz_name.replace(".npz", ".png"), dpi=500, bbox_inches=0)
  89. thres = [0.96, 0.97, 0.98, 0.99]
  90. for i, t in enumerate(thres):
  91. imshow(im[:, :, ::-1])
  92. for (a, b), s in zip(nlines[nscores > t], nscores[nscores > t]):
  93. plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=0.5)
  94. plt.scatter(a[1], a[0], **PLTOPTS)
  95. plt.scatter(b[1], b[0], **PLTOPTS)
  96. plt.savefig(
  97. npz_name.replace(".npz", f"_{i}.png"), dpi=500, bbox_inches=0
  98. )
  99. nlines[:, :, 0] *= 128 / im.shape[0]
  100. nlines[:, :, 1] *= 128 / im.shape[1]
  101. np.savez_compressed(npz_name, lines=nlines, score=nscores)
  102. parmap(handle, inputs, 12)
  103. if __name__ == "__main__":
  104. main()