plot-sAP.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. #!/usr/bin/env python3
  2. import sys
  3. import glob
  4. import os.path as osp
  5. import cv2
  6. import numpy as np
  7. import scipy.io
  8. import matplotlib as mpl
  9. import numpy.linalg as LA
  10. import matplotlib.pyplot as plt
  11. try:
  12. sys.path.append(".")
  13. sys.path.append("..")
  14. import lcnn.utils
  15. import lcnn.metric
  16. except Exception:
  17. raise
  18. # Change the directory here
  19. PRED = "logs/190418-201834-f8934c6-lr4d10/npz/000312000/*.npz"
  20. GT = "data/wireframe/valid/*.npz"
  21. # PRED = "logs/190506-001532-york/*.npz"
  22. # GT = "data/york/valid/*.npz"
  23. WF = "/data/lcnn/wirebase/result/wireframe/wireframe_1_rerun-baseline_0.5_0.5/*"
  24. AFM = "/data/lcnn/wirebase/result/wireframe/afm/*.npz"
  25. mpl.rcParams.update({"font.size": 16})
  26. plt.rcParams["font.family"] = "Times New Roman"
  27. del mpl.font_manager.weight_dict["roman"]
  28. mpl.font_manager._rebuild()
  29. def wireframe_score(T=10):
  30. gts = glob.glob(GT)
  31. gts.sort()
  32. dirs = glob.glob(WF)
  33. dirs.sort(key=lambda x: -float(osp.split(x)[-1]))
  34. precision, recall = [], []
  35. for threshold in dirs:
  36. print("Processing", threshold)
  37. mat_files = glob.glob(osp.join(threshold, "*.mat"))
  38. mat_files.sort()
  39. tp, fp, total_gt = 0, 0, 0
  40. for i, (gt_name, matf) in enumerate(zip(gts, mat_files)):
  41. line_pred = scipy.io.loadmat(matf)["lines"].reshape(-1, 2, 2)
  42. img = cv2.imread(matf.replace(".mat", ".jpg"))
  43. line_pred[:, :, 0] *= 128 / img.shape[1]
  44. line_pred[:, :, 1] *= 128 / img.shape[0]
  45. line_pred = line_pred[:, :, ::-1]
  46. with np.load(gt_name) as fgt:
  47. line_gt = fgt["lpos"][:, :, :2]
  48. tp_, fp_ = lcnn.metric.msTPFP(line_pred, line_gt, T)
  49. tp += tp_.sum()
  50. fp += fp_.sum()
  51. total_gt += len(line_gt)
  52. recall.append(tp / total_gt)
  53. precision.append(tp / (tp + fp))
  54. recall = np.concatenate(([0.0], recall, [1.0]))
  55. precision = np.concatenate(([0.0], precision, [0.0]))
  56. for i in range(precision.size - 1, 0, -1):
  57. precision[i - 1] = max(precision[i - 1], precision[i])
  58. i = np.where(recall[1:] != recall[:-1])[0]
  59. ap = np.sum((recall[i + 1] - recall[i]) * precision[i + 1])
  60. plt.plot(
  61. np.maximum(0.005, recall[:-1]),
  62. precision[:-1],
  63. label="Wireframe",
  64. linewidth=3,
  65. c="C1",
  66. )
  67. print("Huang sAP:", ap)
  68. def line_score(threshold=10):
  69. preds = sorted(glob.glob(PRED))
  70. gts = sorted(glob.glob(GT))
  71. afm = sorted(glob.glob(AFM))
  72. lcnn_tp, lcnn_fp, lcnn_scores = [], [], []
  73. lsd_tp, lsd_fp, lsd_scores = [], [], []
  74. afm_tp, afm_fp, afm_scores = [], [], []
  75. n_gt = 0
  76. for pred_name, gt_name, afm_name in zip(preds, gts, afm):
  77. image = gt_name.replace("_label.npz", ".png")
  78. img = cv2.imread(image, 0)
  79. lsd = cv2.createLineSegmentDetector(cv2.LSD_REFINE_ADV)
  80. lsd_line, _, _, lsd_score = lsd.detect(img)
  81. lsd_line = lsd_line.reshape(-1, 2, 2)[:, :, ::-1]
  82. lsd_score = lsd_score.flatten()
  83. # print(lines.shape)
  84. # print(nfa.shape)
  85. with np.load(pred_name) as fpred:
  86. lcnn_line = fpred["lines"][:, :, :2]
  87. lcnn_score = fpred["score"]
  88. lcnn_line = lcnn_line[:, :, :2]
  89. with np.load(gt_name) as fgt:
  90. gt_line = fgt["lpos"][:, :, :2]
  91. with np.load(afm_name) as fafm:
  92. afm_line = fafm["lines"].reshape(-1, 2, 2)[:, :, ::-1]
  93. afm_score = -fafm["scores"]
  94. h = fafm["h"]
  95. w = fafm["w"]
  96. afm_line[:, :, 0] *= 128 / h
  97. afm_line[:, :, 1] *= 128 / w
  98. for i, ((a, b), s) in enumerate(zip(lcnn_line, lcnn_score)):
  99. if i > 0 and (lcnn_line[i] == lcnn_line[0]).all():
  100. lcnn_line = lcnn_line[:i]
  101. lcnn_score = lcnn_score[:i]
  102. break
  103. # plt.figure("LCNN")
  104. # for a, b in lcnn_line:
  105. # plt.plot([a[1], b[1]], [a[0], b[0]], linewidth=4)
  106. # plt.figure("GT")
  107. # for a, b in gt_line:
  108. # plt.plot([a[1], b[1]], [a[0], b[0]], linewidth=4)
  109. # plt.figure("LSD")
  110. # for a, b in lsd_line:
  111. # plt.plot([a[1], b[1]], [a[0], b[0]], linewidth=4)
  112. # plt.figure("AFM")
  113. # for a, b in afm_line:
  114. # plt.plot([a[1], b[1]], [a[0], b[0]], linewidth=4)
  115. # plt.show()
  116. tp, fp = lcnn.metric.msTPFP(lcnn_line, gt_line, threshold)
  117. lcnn_tp.append(tp)
  118. lcnn_fp.append(fp)
  119. lcnn_scores.append(lcnn_score)
  120. tp, fp = lcnn.metric.msTPFP(lsd_line, gt_line, threshold)
  121. lsd_tp.append(tp)
  122. lsd_fp.append(fp)
  123. lsd_scores.append(lsd_score)
  124. tp, fp = lcnn.metric.msTPFP(afm_line, gt_line, threshold)
  125. afm_tp.append(tp)
  126. afm_fp.append(fp)
  127. afm_scores.append(afm_score)
  128. n_gt += len(gt_line)
  129. lcnn_tp = np.concatenate(lcnn_tp)
  130. lcnn_fp = np.concatenate(lcnn_fp)
  131. lcnn_scores = np.concatenate(lcnn_scores)
  132. lcnn_index = np.argsort(-lcnn_scores)
  133. lcnn_tp = lcnn_tp[lcnn_index]
  134. lcnn_fp = lcnn_fp[lcnn_index]
  135. lcnn_tp = np.cumsum(lcnn_tp) / n_gt
  136. lcnn_fp = np.cumsum(lcnn_fp) / n_gt
  137. lsd_tp = np.concatenate(lsd_tp)
  138. lsd_fp = np.concatenate(lsd_fp)
  139. lsd_scores = np.concatenate(lsd_scores)
  140. lsd_index = np.argsort(-lsd_scores)
  141. lsd_tp = lsd_tp[lsd_index]
  142. lsd_fp = lsd_fp[lsd_index]
  143. lsd_tp = np.cumsum(lsd_tp) / n_gt
  144. lsd_fp = np.cumsum(lsd_fp) / n_gt
  145. afm_tp = np.concatenate(afm_tp)
  146. afm_fp = np.concatenate(afm_fp)
  147. afm_scores = np.concatenate(afm_scores)
  148. afm_index = np.argsort(-afm_scores)
  149. afm_tp = afm_tp[afm_index]
  150. afm_fp = afm_fp[afm_index]
  151. afm_tp = np.cumsum(afm_tp) / n_gt
  152. afm_fp = np.cumsum(afm_fp) / n_gt
  153. lcnn_re, lcnn_pr = lcnn_tp, lcnn_tp / (lcnn_tp + lcnn_fp)
  154. afm_re, afm_pr = afm_tp, afm_tp / (afm_tp + afm_fp)
  155. # lsd_re, lsd_pr = lsd_tp, lsd_tp / (lsd_tp + lsd_fp)
  156. T = 0.005
  157. plt.plot(afm_re[afm_re > T], afm_pr[afm_re > T], label="AFM", linewidth=3, c="C2")
  158. plt.plot(lcnn_re[lcnn_re > T], lcnn_pr[lcnn_re > T], label="L-CNN", linewidth=3, c="C3")
  159. # plt.plot(lsd_re, lsd_pr, label="LSD", linewidth=2)
  160. plt.grid(True)
  161. plt.axis([0.0, 1.0, 0.0, 1.0])
  162. plt.xticks(np.arange(0, 1.0, step=0.1))
  163. plt.yticks(np.arange(0, 1.0, step=0.1))
  164. plt.xlabel("Recall")
  165. plt.ylabel("Precision")
  166. plt.legend(loc="upper right")
  167. f_scores = np.linspace(0.2, 0.8, num=8)
  168. for f_score in f_scores:
  169. x = np.linspace(0.01, 1)
  170. y = f_score * x / (2 * x - f_score)
  171. l, = plt.plot(x[y >= 0], y[y >= 0], color="green", alpha=0.3)
  172. plt.annotate("f={0:0.1}".format(f_score), xy=(0.9, y[45] + 0.02), alpha=0.4)
  173. plt.title("PR Curve for sAP10")
  174. plt.savefig("sAP.pdf", format="pdf", bbox_inches="tight")
  175. plt.savefig("sAP.svg", format="svg", bbox_inches="tight")
  176. plt.show()
  177. print(
  178. f"Processing {PRED}:\n"
  179. + f" LSD sAP{threshold}: {lcnn.metric.ap(lsd_tp, lsd_fp)}\n"
  180. + f" AFM sAP{threshold}: {lcnn.metric.ap(afm_tp, afm_fp)}\n"
  181. + f" L-CNN sAP{threshold}: {lcnn.metric.ap(lcnn_tp, lcnn_fp)}"
  182. )
  183. cmap = plt.get_cmap("jet")
  184. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  185. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  186. sm.set_array([])
  187. def c(x):
  188. return sm.to_rgba(x)
  189. if __name__ == "__main__":
  190. plt.tight_layout()
  191. wireframe_score()
  192. line_score()