eval-APH.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. #!/usr/bin/env python3
  2. """Evaluate APH for LCNN
  3. Usage:
  4. eval-APH.py <src> <dst>
  5. eval-APH.py (-h | --help )
  6. Examples:
  7. python eval-APH.py logs/*
  8. Arguments:
  9. <src> Source directory that stores preprocessed npz
  10. <dst> Temporary output directory
  11. Options:
  12. -h --help Show this screen.
  13. """
  14. import os
  15. import glob
  16. import os.path as osp
  17. import subprocess
  18. import numpy as np
  19. import scipy.io as sio
  20. import matplotlib as mpl
  21. import matplotlib.pyplot as plt
  22. from scipy import interpolate
  23. from docopt import docopt
  24. mpl.rcParams.update({"font.size": 18})
  25. plt.rcParams["font.family"] = "Times New Roman"
  26. del mpl.font_manager.weight_dict["roman"]
  27. mpl.font_manager._rebuild()
  28. image_path = "data/wireframe/valid-images/"
  29. line_gt_path = "data/wireframe/valid/"
  30. output_size = 128
  31. def main():
  32. args = docopt(__doc__)
  33. src_dir = args["<src>"]
  34. tar_dir = args["<dst>"]
  35. output_file = osp.join(tar_dir, "result.mat")
  36. target_dir = osp.join(tar_dir, "mat")
  37. os.makedirs(target_dir, exist_ok=True)
  38. print(f"intermediate matlab results will be saved at: {target_dir}")
  39. file_list = glob.glob(osp.join(src_dir, "*.npz"))
  40. thresh = [0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.97, 0.99, 0.995, 0.999, 0.9995, 0.9999]
  41. for t in thresh:
  42. for fname in file_list:
  43. name = fname.split("/")[-1].split(".")[0]
  44. mat_name = name + ".mat"
  45. npz = np.load(fname)
  46. lines = npz["lines"].reshape(-1, 4)
  47. scores = npz["score"]
  48. for j in range(len(scores) - 1):
  49. if scores[j + 1] == scores[0]:
  50. lines = lines[: j + 1]
  51. scores = scores[: j + 1]
  52. break
  53. idx = np.where(scores > t)[0]
  54. os.makedirs(osp.join(target_dir, str(t)), exist_ok=True)
  55. sio.savemat(osp.join(target_dir, str(t), mat_name), {"lines": lines[idx]})
  56. cmd = "matlab -nodisplay -nodesktop "
  57. cmd += '-r "dbstop if error; '
  58. cmd += "eval_release('{:s}', '{:s}', '{:s}', '{:s}', {:d}); quit;\"".format(
  59. image_path, line_gt_path, output_file, target_dir, output_size
  60. )
  61. print("Running:\n{}".format(cmd))
  62. os.environ["MATLABPATH"] = "matlab/"
  63. subprocess.call(cmd, shell=True)
  64. mat = sio.loadmat(output_file)
  65. tps = mat["sumtp"]
  66. fps = mat["sumfp"]
  67. N = mat["sumgt"]
  68. rcs = sorted(list((tps / N)[:, 0]))
  69. prs = sorted(list((tps / np.maximum(tps + fps, 1e-9))[:, 0]))[::-1]
  70. print(
  71. "f measure is: ",
  72. (2 * np.array(prs) * np.array(rcs) / (np.array(prs) + np.array(rcs))).max(),
  73. )
  74. recall = np.concatenate(([0.0], rcs, [1.0]))
  75. precision = np.concatenate(([0.0], prs, [0.0]))
  76. for i in range(precision.size - 1, 0, -1):
  77. precision[i - 1] = max(precision[i - 1], precision[i])
  78. i = np.where(recall[1:] != recall[:-1])[0]
  79. print("AP is: ", np.sum((recall[i + 1] - recall[i]) * precision[i + 1]))
  80. f = interpolate.interp1d(rcs, prs, kind="cubic", bounds_error=False)
  81. x = np.arange(0, 1, 0.01) * rcs[-1]
  82. y = f(x)
  83. plt.plot(x, y, linewidth=3, label="L-CNN")
  84. f_scores = np.linspace(0.2, 0.8, num=8)
  85. for f_score in f_scores:
  86. x = np.linspace(0.01, 1)
  87. y = f_score * x / (2 * x - f_score)
  88. l, = plt.plot(x[y >= 0], y[y >= 0], color="green", alpha=0.3)
  89. plt.annotate("f={0:0.1}".format(f_score), xy=(0.9, y[45] + 0.02), alpha=0.4)
  90. plt.grid(True)
  91. plt.axis([0.0, 1.0, 0.0, 1.0])
  92. plt.xticks(np.arange(0, 1.0, step=0.1))
  93. plt.xlabel("Recall")
  94. plt.ylabel("Precision")
  95. plt.yticks(np.arange(0, 1.0, step=0.1))
  96. plt.legend(loc=3)
  97. plt.title("PR Curve for APH")
  98. plt.savefig("apH.pdf", format="pdf", bbox_inches="tight")
  99. plt.savefig("apH.svg", format="svg", bbox_inches="tight")
  100. plt.show()
  101. if __name__ == "__main__":
  102. plt.tight_layout()
  103. main()