plot-sAP.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. #!/usr/bin/env python3
  2. import sys
  3. import glob
  4. import os.path as osp
  5. import cv2
  6. import numpy as np
  7. import scipy.io
  8. import matplotlib as mpl
  9. import numpy.linalg as LA
  10. import matplotlib.pyplot as plt
  11. try:
  12. sys.path.append(".")
  13. sys.path.append("..")
  14. import lcnn.utils
  15. import lcnn.metric
  16. except Exception:
  17. raise
  18. # Change the directory here
  19. PRED = "logs/190418-201834-f8934c6-lr4d10/npz/000312000/*.npz"
  20. PRED = "post/jmap_0008/*.npz"
  21. GT = "data/wireframe/valid/*.npz"
  22. # PRED = "logs/190506-001532-york/*.npz"
  23. # GT = "data/york/valid/*.npz"
  24. WF = "/data/lcnn/wirebase/result/wireframe/wireframe_1_rerun-baseline_0.5_0.5/*"
  25. AFM = "/data/lcnn/wirebase/result/wireframe/afm/*.npz"
  26. mpl.rcParams.update({"font.size": 16})
  27. plt.rcParams["font.family"] = "Times New Roman"
  28. del mpl.font_manager.weight_dict["roman"]
  29. mpl.font_manager._rebuild()
  30. def wireframe_score(T=10):
  31. gts = glob.glob(GT)
  32. gts.sort()
  33. dirs = glob.glob(WF)
  34. dirs.sort(key=lambda x: -float(osp.split(x)[-1]))
  35. precision, recall = [], []
  36. for threshold in dirs:
  37. print("Processing", threshold)
  38. mat_files = glob.glob(osp.join(threshold, "*.mat"))
  39. mat_files.sort()
  40. tp, fp, total_gt = 0, 0, 0
  41. for i, (gt_name, matf) in enumerate(zip(gts, mat_files)):
  42. line_pred = scipy.io.loadmat(matf)["lines"].reshape(-1, 2, 2)
  43. img = cv2.imread(matf.replace(".mat", ".jpg"))
  44. line_pred[:, :, 0] *= 128 / img.shape[1]
  45. line_pred[:, :, 1] *= 128 / img.shape[0]
  46. line_pred = line_pred[:, :, ::-1]
  47. with np.load(gt_name) as fgt:
  48. line_gt = fgt["lpos"][:, :, :2]
  49. tp_, fp_ = lcnn.metric.msTPFP(line_pred, line_gt, T)
  50. tp += tp_.sum()
  51. fp += fp_.sum()
  52. total_gt += len(line_gt)
  53. recall.append(tp / total_gt)
  54. precision.append(tp / (tp + fp))
  55. recall = np.concatenate(([0.0], recall, [1.0]))
  56. precision = np.concatenate(([0.0], precision, [0.0]))
  57. for i in range(precision.size - 1, 0, -1):
  58. precision[i - 1] = max(precision[i - 1], precision[i])
  59. i = np.where(recall[1:] != recall[:-1])[0]
  60. ap = np.sum((recall[i + 1] - recall[i]) * precision[i + 1])
  61. np.savez(
  62. "/data/lcnn/results/sAP/wireframe.npz",
  63. x=np.maximum(0.005, recall[:-1]),
  64. y=precision[:-1],
  65. )
  66. plt.plot(
  67. np.maximum(0.005, recall[:-1]),
  68. precision[:-1],
  69. label="Wireframe",
  70. linewidth=3,
  71. c="C1",
  72. )
  73. print("Huang sAP:", ap)
  74. def line_score(threshold=10):
  75. preds = sorted(glob.glob(PRED))
  76. gts = sorted(glob.glob(GT))
  77. afm = sorted(glob.glob(AFM))
  78. lcnn_tp, lcnn_fp, lcnn_scores = [], [], []
  79. lsd_tp, lsd_fp, lsd_scores = [], [], []
  80. afm_tp, afm_fp, afm_scores = [], [], []
  81. n_gt = 0
  82. for pred_name, gt_name, afm_name in zip(preds, gts, afm):
  83. image = gt_name.replace("_label.npz", ".png")
  84. img = cv2.imread(image, 0)
  85. lsd = cv2.createLineSegmentDetector(cv2.LSD_REFINE_ADV)
  86. lsd_line, _, _, lsd_score = lsd.detect(img)
  87. lsd_line = lsd_line.reshape(-1, 2, 2)[:, :, ::-1]
  88. lsd_score = lsd_score.flatten()
  89. # print(lines.shape)
  90. # print(nfa.shape)
  91. with np.load(pred_name) as fpred:
  92. lcnn_line = fpred["lines"][:, :, :2]
  93. lcnn_score = fpred["score"]
  94. lcnn_line = lcnn_line[:, :, :2]
  95. with np.load(gt_name) as fgt:
  96. gt_line = fgt["lpos"][:, :, :2]
  97. with np.load(afm_name) as fafm:
  98. afm_line = fafm["lines"].reshape(-1, 2, 2)[:, :, ::-1]
  99. afm_score = -fafm["scores"]
  100. h = fafm["h"]
  101. w = fafm["w"]
  102. afm_line[:, :, 0] *= 128 / h
  103. afm_line[:, :, 1] *= 128 / w
  104. for i, ((a, b), s) in enumerate(zip(lcnn_line, lcnn_score)):
  105. if i > 0 and (lcnn_line[i] == lcnn_line[0]).all():
  106. lcnn_line = lcnn_line[:i]
  107. lcnn_score = lcnn_score[:i]
  108. break
  109. # plt.figure("LCNN")
  110. # for a, b in lcnn_line:
  111. # plt.plot([a[1], b[1]], [a[0], b[0]], linewidth=4)
  112. # plt.figure("GT")
  113. # for a, b in gt_line:
  114. # plt.plot([a[1], b[1]], [a[0], b[0]], linewidth=4)
  115. # plt.figure("LSD")
  116. # for a, b in lsd_line:
  117. # plt.plot([a[1], b[1]], [a[0], b[0]], linewidth=4)
  118. # plt.figure("AFM")
  119. # for a, b in afm_line:
  120. # plt.plot([a[1], b[1]], [a[0], b[0]], linewidth=4)
  121. # plt.show()
  122. tp, fp = lcnn.metric.msTPFP(lcnn_line, gt_line, threshold)
  123. lcnn_tp.append(tp)
  124. lcnn_fp.append(fp)
  125. lcnn_scores.append(lcnn_score)
  126. tp, fp = lcnn.metric.msTPFP(lsd_line, gt_line, threshold)
  127. lsd_tp.append(tp)
  128. lsd_fp.append(fp)
  129. lsd_scores.append(lsd_score)
  130. tp, fp = lcnn.metric.msTPFP(afm_line, gt_line, threshold)
  131. afm_tp.append(tp)
  132. afm_fp.append(fp)
  133. afm_scores.append(afm_score)
  134. n_gt += len(gt_line)
  135. lcnn_tp = np.concatenate(lcnn_tp)
  136. lcnn_fp = np.concatenate(lcnn_fp)
  137. lcnn_scores = np.concatenate(lcnn_scores)
  138. lcnn_index = np.argsort(-lcnn_scores)
  139. lcnn_tp = lcnn_tp[lcnn_index]
  140. lcnn_fp = lcnn_fp[lcnn_index]
  141. lcnn_tp = np.cumsum(lcnn_tp) / n_gt
  142. lcnn_fp = np.cumsum(lcnn_fp) / n_gt
  143. lsd_tp = np.concatenate(lsd_tp)
  144. lsd_fp = np.concatenate(lsd_fp)
  145. lsd_scores = np.concatenate(lsd_scores)
  146. lsd_index = np.argsort(-lsd_scores)
  147. lsd_tp = lsd_tp[lsd_index]
  148. lsd_fp = lsd_fp[lsd_index]
  149. lsd_tp = np.cumsum(lsd_tp) / n_gt
  150. lsd_fp = np.cumsum(lsd_fp) / n_gt
  151. afm_tp = np.concatenate(afm_tp)
  152. afm_fp = np.concatenate(afm_fp)
  153. afm_scores = np.concatenate(afm_scores)
  154. afm_index = np.argsort(-afm_scores)
  155. afm_tp = afm_tp[afm_index]
  156. afm_fp = afm_fp[afm_index]
  157. afm_tp = np.cumsum(afm_tp) / n_gt
  158. afm_fp = np.cumsum(afm_fp) / n_gt
  159. lcnn_re, lcnn_pr = lcnn_tp, lcnn_tp / (lcnn_tp + lcnn_fp)
  160. afm_re, afm_pr = afm_tp, afm_tp / (afm_tp + afm_fp)
  161. # lsd_re, lsd_pr = lsd_tp, lsd_tp / (lsd_tp + lsd_fp)
  162. T = 0.005
  163. plt.plot(afm_re[afm_re > T], afm_pr[afm_re > T], label="AFM", linewidth=3, c="C2")
  164. plt.plot(
  165. lcnn_re[lcnn_re > T], lcnn_pr[lcnn_re > T], label="L-CNN", linewidth=3, c="C3"
  166. )
  167. np.savez(
  168. "/data/lcnn/results/sAP/afm.npz", x=afm_re[afm_re > T], y=afm_pr[afm_re > T]
  169. )
  170. np.savez(
  171. "/data/lcnn/results/sAP/lcnn.npz", x=lcnn_re[lcnn_re > T], y=lcnn_pr[lcnn_re > T]
  172. )
  173. # plt.plot(lsd_re, lsd_pr, label="LSD", linewidth=2)
  174. plt.grid(True)
  175. plt.axis([0.0, 1.0, 0.0, 1.0])
  176. plt.xticks(np.arange(0, 1.0, step=0.1))
  177. plt.yticks(np.arange(0, 1.0, step=0.1))
  178. plt.xlabel("Recall")
  179. plt.ylabel("Precision")
  180. plt.legend(loc="upper right")
  181. f_scores = np.linspace(0.2, 0.8, num=8)
  182. for f_score in f_scores:
  183. x = np.linspace(0.01, 1)
  184. y = f_score * x / (2 * x - f_score)
  185. l, = plt.plot(x[y >= 0], y[y >= 0], color="green", alpha=0.3)
  186. plt.annotate("f={0:0.1}".format(f_score), xy=(0.9, y[45] + 0.02), alpha=0.4)
  187. plt.title("PR Curve for sAP10")
  188. plt.savefig("sAP.pdf", format="pdf", bbox_inches="tight")
  189. plt.savefig("sAP.svg", format="svg", bbox_inches="tight")
  190. plt.show()
  191. print(
  192. f"Processing {PRED}:\n"
  193. + f" LSD sAP{threshold}: {lcnn.metric.ap(lsd_tp, lsd_fp)}\n"
  194. + f" AFM sAP{threshold}: {lcnn.metric.ap(afm_tp, afm_fp)}\n"
  195. + f" L-CNN sAP{threshold}: {lcnn.metric.ap(lcnn_tp, lcnn_fp)}"
  196. )
  197. cmap = plt.get_cmap("jet")
  198. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  199. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  200. sm.set_array([])
  201. def c(x):
  202. return sm.to_rgba(x)
  203. if __name__ == "__main__":
  204. plt.tight_layout()
  205. wireframe_score()
  206. line_score()