post.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. #!/usr/bin/env python3
  2. """Post-processing the output of neural network
  3. Usage:
  4. post.py [options] <input-dir> <output-dir>
  5. post.py ( -h | --help )
  6. Examples:
  7. post.py logs/logname/npz/000336000 result/logname
  8. Arguments:
  9. input-dir Directory that stores the npz
  10. output-dir Output directory
  11. Options:
  12. -h --help Show this screen.
  13. --plot Generate images besides npz files
  14. --thresholds=<thresholds> A comma-separated list for thresholding
  15. [default: 0.006,0.010,0.015]
  16. """
  17. import os
  18. import sys
  19. import glob
  20. import math
  21. import os.path as osp
  22. import cv2
  23. import numpy as np
  24. import matplotlib as mpl
  25. import matplotlib.pyplot as plt
  26. from docopt import docopt
  27. from lcnn.utils import parmap
  28. cmap = plt.get_cmap("jet")
  29. norm = mpl.colors.Normalize(vmin=0.92, vmax=1.02)
  30. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  31. sm.set_array([])
  32. def c(x):
  33. return sm.to_rgba(x)
  34. def imshow(im):
  35. plt.close()
  36. sizes = im.shape
  37. height = float(sizes[0])
  38. width = float(sizes[1])
  39. fig = plt.figure()
  40. fig.set_size_inches(width / height, 1, forward=False)
  41. ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
  42. ax.set_axis_off()
  43. fig.add_axes(ax)
  44. plt.xlim([-0.5, sizes[1] - 0.5])
  45. plt.ylim([sizes[0] - 0.5, -0.5])
  46. plt.imshow(im)
  47. def pline(x1, y1, x2, y2, x, y):
  48. px = x2 - x1
  49. py = y2 - y1
  50. dd = px * px + py * py
  51. u = ((x - x1) * px + (y - y1) * py) / max(1e-9, float(dd))
  52. dx = x1 + u * px - x
  53. dy = y1 + u * py - y
  54. return dx * dx + dy * dy
  55. def psegment(x1, y1, x2, y2, x, y):
  56. px = x2 - x1
  57. py = y2 - y1
  58. dd = px * px + py * py
  59. u = max(min(((x - x1) * px + (y - y1) * py) / float(dd), 1), 0)
  60. dx = x1 + u * px - x
  61. dy = y1 + u * py - y
  62. return dx * dx + dy * dy
  63. def plambda(x1, y1, x2, y2, x, y):
  64. px = x2 - x1
  65. py = y2 - y1
  66. dd = px * px + py * py
  67. return ((x - x1) * px + (y - y1) * py) / max(1e-9, float(dd))
  68. def process(lines, scores, threshold=0.01, tol=1e9, do_clip=False):
  69. nlines, nscores = [], []
  70. for (p, q), score in zip(lines, scores):
  71. start, end = 0, 1
  72. for a, b in nlines:
  73. if (
  74. min(
  75. max(pline(*p, *q, *a), pline(*p, *q, *b)),
  76. max(pline(*a, *b, *p), pline(*a, *b, *q)),
  77. )
  78. > threshold ** 2
  79. ):
  80. continue
  81. lambda_a = plambda(*p, *q, *a)
  82. lambda_b = plambda(*p, *q, *b)
  83. if lambda_a > lambda_b:
  84. lambda_a, lambda_b = lambda_b, lambda_a
  85. lambda_a -= tol
  86. lambda_b += tol
  87. # case 1: skip (if not do_clip)
  88. if start < lambda_a and lambda_b < end:
  89. continue
  90. # not intersect
  91. if lambda_b < start or lambda_a > end:
  92. continue
  93. # cover
  94. if lambda_a <= start and end <= lambda_b:
  95. start = 10
  96. break
  97. # case 2 & 3:
  98. if lambda_a <= start and start <= lambda_b:
  99. start = lambda_b
  100. if lambda_a <= end and end <= lambda_b:
  101. end = lambda_a
  102. if start >= end:
  103. break
  104. if start >= end:
  105. continue
  106. nlines.append(np.array([p + (q - p) * start, p + (q - p) * end]))
  107. nscores.append(score)
  108. return np.array(nlines), np.array(nscores)
  109. def main():
  110. args = docopt(__doc__)
  111. files = sorted(glob.glob(osp.join(args["<input-dir>"], "*.npz")))
  112. inames = sorted(glob.glob("data/wireframe/valid-images/*.jpg"))
  113. gts = sorted(glob.glob("data/wireframe/valid/*.npz"))
  114. prefix = args["<output-dir>"]
  115. inputs = list(zip(files, inames, gts))
  116. thresholds = list(map(float, args["--thresholds"].split(",")))
  117. def handle(allname):
  118. fname, iname, gtname = allname
  119. print("Processing", fname)
  120. im = cv2.imread(iname)
  121. with np.load(fname) as f:
  122. lines = f["lines"]
  123. scores = f["score"]
  124. with np.load(gtname) as f:
  125. gtlines = f["lpos"][:, :, :2]
  126. gtlines[:, :, 0] *= im.shape[0] / 128
  127. gtlines[:, :, 1] *= im.shape[1] / 128
  128. for i in range(1, len(lines)):
  129. if (lines[i] == lines[0]).all():
  130. lines = lines[:i]
  131. scores = scores[:i]
  132. break
  133. lines[:, :, 0] *= im.shape[0] / 128
  134. lines[:, :, 1] *= im.shape[1] / 128
  135. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  136. for threshold in thresholds:
  137. nlines, nscores = process(lines, scores, diag * threshold, 0, False)
  138. outdir = osp.join(prefix, f"{threshold:.3f}".replace(".", "_"))
  139. os.makedirs(outdir, exist_ok=True)
  140. npz_name = osp.join(outdir, osp.split(fname)[-1])
  141. PLTOPTS = {"color": "#33FFFF", "s": 1.2, "edgecolors": "none", "zorder": 5}
  142. if args["--plot"]:
  143. # plot gt
  144. imshow(im[:, :, ::-1])
  145. for (a, b) in gtlines:
  146. plt.plot([a[1], b[1]], [a[0], b[0]], c="orange", linewidth=0.5)
  147. plt.scatter(a[1], a[0], **PLTOPTS)
  148. plt.scatter(b[1], b[0], **PLTOPTS)
  149. plt.savefig(npz_name.replace(".npz", ".png"), dpi=500, bbox_inches=0)
  150. thres = [0.96, 0.97, 0.98, 0.99]
  151. for i, t in enumerate(thres):
  152. imshow(im[:, :, ::-1])
  153. for (a, b), s in zip(nlines[nscores > t], nscores[nscores > t]):
  154. plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=0.5)
  155. plt.scatter(a[1], a[0], **PLTOPTS)
  156. plt.scatter(b[1], b[0], **PLTOPTS)
  157. plt.savefig(
  158. npz_name.replace(".npz", f"_{i}.png"), dpi=500, bbox_inches=0
  159. )
  160. nlines[:, :, 0] *= 128 / im.shape[0]
  161. nlines[:, :, 1] *= 128 / im.shape[1]
  162. np.savez_compressed(npz_name, lines=nlines, score=nscores)
  163. parmap(handle, inputs, 12)
# Run the post-processing pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()