eval-mAPJ.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. #!/usr/bin/env python3
  2. """Evaluate mAPJ for LCNN, AFM, and Wireframe
  3. Usage:
  4. eval-mAPJ.py <path>...
  5. eval-mAPJ.py (-h | --help )
  6. Examples:
  7. python eval-mAPJ.py logs/*
  8. Arguments:
  9. <path> One or more directories that contain *.npz
  10. Options:
  11. -h --help Show this screen.
  12. """
  13. import os
  14. import re
  15. import glob
  16. import os.path as osp
  17. from collections import defaultdict
  18. import cv2
  19. import numpy as np
  20. import matplotlib.pyplot as plt
  21. from docopt import docopt
  22. from scipy.io import loadmat
  23. import lcnn.models
  24. from lcnn.metric import mAPJ, post_jheatmap
  # Ground-truth annotation archives of the validation split; each file holds a
  # ``junc`` array whose first two columns are junction coordinates.
  GT = "data/wireframe/valid/*.npz"
  # Validation images; paired with GT purely by sorted filename order.
  IM = "data/wireframe/valid-images/*.jpg"
  # Root of the Wireframe-baseline junction results; load_wf() reads
  # ``{WF}/{threshold}/*.mat`` plus the companion ``*_5.png`` images.
  WF = "/data/wirebase/result/junc/2/17"
  # AFM result archives, each with ``lines``, ``scores``, ``h`` and ``w``.
  AFM = "/data/wirebase/result/wireframe/afm/*.npz"
  # Distance thresholds passed to mAPJ (presumably in the 128x128 frame the
  # predictions are rescaled to -- TODO confirm against lcnn.metric.mAPJ).
  DIST = [0.5, 1.0, 2.0]
  def evaluate_lcnn(im_list, gt_list, lcnn_list):
      """Print the LCNN junction mAP, without and with the predicted offsets.

      Args:
          im_list: Image paths. Not used here; kept so the signature parallels
              ``evaluate_wireframe`` and ``evaluate_afm``.
          gt_list: Ground-truth ``.npz`` paths; the first two columns of each
              file's ``junc`` array are the ground-truth junction positions.
          lcnn_list: LCNN prediction ``.npz`` paths, each holding ``jmap`` and
              ``joff``. Paired with ``gt_list`` by position; ``zip`` stops at
              the shorter of the two lists.

      Prints ``" <mAPJ> | <mAPJ with offsets>"`` computed with the DIST
      thresholds.
      """
      # Per-image pieces are collected in lists and concatenated once at the
      # end; stacking inside the loop would copy everything each iteration.
      # Each list is seeded with an empty block so zero images still stack.
      juncs = [np.zeros((0, 3))]
      offset_juncs = [np.zeros((0, 3))]
      junc_ids = [np.zeros(0, dtype=np.int64)]
      offset_junc_ids = [np.zeros(0, dtype=np.int64)]
      # A list (not an array) because the number of GT junctions varies.
      all_jc_gt = []

      for i, (lcnn_fn, gt_fn) in enumerate(zip(lcnn_list, gt_list)):
          # Read only the two arrays needed rather than every array stored.
          with np.load(lcnn_fn) as npz:
              jmap = npz["jmap"]
              joff = npz["joff"]
          with np.load(gt_fn) as npz:
              junc_gt = npz["junc"][:, :2]

          jun_c = post_jheatmap(jmap[0])
          jun_o_c = post_jheatmap(jmap[0], offset=joff[0])

          juncs.append(jun_c)
          offset_juncs.append(jun_o_c)
          # Image ids are derived from each result's own length so the offset
          # variant stays aligned even if it yields a different count.
          junc_ids.append(np.full(len(jun_c), i, dtype=np.int64))
          offset_junc_ids.append(np.full(len(jun_o_c), i, dtype=np.int64))
          all_jc_gt.append(junc_gt)

      ap_jc = mAPJ(np.vstack(juncs), all_jc_gt, DIST, np.concatenate(junc_ids))
      ap_joc = mAPJ(
          np.vstack(offset_juncs),
          all_jc_gt,
          DIST,
          np.concatenate(offset_junc_ids),
      )
      print(f" {ap_jc:.1f} | {ap_joc:.1f}")
  def evaluate_wireframe(im_list, gt_list, juncs_wf=None):
      """Print the junction mAP of the Wireframe baseline.

      Args:
          im_list: Image paths. Only bounds how many images are evaluated
              (iteration stops at the shortest of the three sequences).
          gt_list: Ground-truth ``.npz`` paths with a ``junc`` array.
          juncs_wf: Per-image detections as rows of (y, x, score), e.g. the
              output of ``load_wf()``. When None they are loaded with
              ``load_wf()``.

      Only the 1000 highest-scoring detections of each image are scored.
      """
      print("Compute WF mAP")
      # Use caller-supplied detections when given; otherwise load from disk.
      if juncs_wf is None:
          juncs_wf = load_wf()

      juncs = [np.zeros((0, 3))]
      junc_ids = [np.zeros(0, dtype=np.int64)]
      all_jc_gt = []
      for i, (_im_fn, gt_fn, junc_wf) in enumerate(zip(im_list, gt_list, juncs_wf)):
          with np.load(gt_fn) as npz:
              junc_gt = npz["junc"][:, :2]

          # An image without detections may arrive as an empty 1-D array;
          # reshaping makes it a (0, 3) block that stacks cleanly.
          junc_wf = np.asarray(junc_wf, dtype=float).reshape(-1, 3)
          # Stable sort on the negated score == sorted(..., key=lambda x: -x[2]).
          order = np.argsort(-junc_wf[:, 2], kind="stable")[:1000]
          jun_c = junc_wf[order]

          juncs.append(jun_c)
          junc_ids.append(np.full(len(jun_c), i, dtype=np.int64))
          all_jc_gt.append(junc_gt)

      ap_jc = mAPJ(np.vstack(juncs), all_jc_gt, DIST, np.concatenate(junc_ids))
      print(f" {ap_jc:.1f}")
  def evaluate_afm(im_list, gt_list, afm=None):
      """Print the junction mAP of AFM line detections.

      Both endpoints of every AFM line become junction candidates, each scored
      with the negated AFM line score, after rescaling to a 128x128 frame.

      Args:
          im_list: Image paths. Only bounds how many images are evaluated
              (iteration stops at the shortest of the three sequences).
          gt_list: Ground-truth ``.npz`` paths with a ``junc`` array.
          afm: AFM result ``.npz`` paths holding ``lines``, ``scores``, ``h``
              and ``w``. When None, the files matching AFM are used in
              sorted order.
      """
      print("Compute AFM mAP")
      # Use caller-supplied result files when given; otherwise glob them.
      if afm is None:
          afm = sorted(glob.glob(AFM))

      juncs = [np.zeros((0, 3))]
      junc_ids = [np.zeros(0, dtype=np.int64)]
      all_jc_gt = []
      for i, (_im_fn, gt_fn, afm_fn) in enumerate(zip(im_list, gt_list, afm)):
          with np.load(gt_fn) as npz:
              junc_gt = npz["junc"][:, :2]
          with np.load(afm_fn) as fafm:
              # One (2, 2) block per line; ``::-1`` swaps the two coordinates
              # of each endpoint. Cast to float so the in-place rescaling
              # below also works when ``lines`` is stored as integers.
              afm_line = fafm["lines"].reshape(-1, 2, 2)[:, :, ::-1].astype(float)
              afm_score = -fafm["scores"]
              h = fafm["h"]
              w = fafm["w"]
          # After the swap, coordinate 0 scales with the height, 1 with the width.
          afm_line[:, :, 0] *= 128 / h
          afm_line[:, :, 1] *= 128 / w

          # Vectorized form of "append (p0, score) then (p1, score) per line";
          # truncating to the shorter input mirrors zip(afm_line, afm_score).
          n_lines = min(len(afm_line), len(afm_score))
          endpoints = afm_line[:n_lines].reshape(-1, 2)
          scores = np.repeat(np.asarray(afm_score[:n_lines], dtype=float), 2)
          # Always (2 * n_lines, 3), so a file with no lines stacks cleanly.
          jun_c = np.column_stack((endpoints, scores))

          juncs.append(jun_c)
          junc_ids.append(np.full(len(jun_c), i, dtype=np.int64))
          all_jc_gt.append(junc_gt)

      ap_jc = mAPJ(np.vstack(juncs), all_jc_gt, DIST, np.concatenate(junc_ids))
      print(f" {ap_jc:.1f}")
  def load_wf(wf_dir=WF, n_thresholds=10):
      """Load Wireframe-baseline junctions, scored by how many thresholds find them.

      For each threshold ``t`` in ``range(n_thresholds)``, the sorted
      ``{wf_dir}/{t}/*.mat`` files are matched to images by position. The
      ``junctions`` of each file are rescaled to a 128x128 frame using the size
      of the companion ``*_5.png`` image. A junction's score is the number of
      thresholds at which exactly the same rescaled coordinates appear.

      Args:
          wf_dir: Root directory of the per-threshold results.
          n_thresholds: Number of threshold sub-directories (0 .. n-1).

      Returns:
          One float array per image (as many as the largest threshold
          directory holds) with rows (y, x, votes); images without junctions
          get an empty (0, 3) array.

      Raises:
          FileNotFoundError: if a companion ``*_5.png`` cannot be read.
      """
      # votes[i] maps a junction's (x, y) tuple to its threshold count; the
      # list grows on demand instead of being capped at a fixed size.
      votes = []
      for thres in range(n_thresholds):
          mats = sorted(glob.glob(f"{wf_dir}/{thres}/*.mat"))
          while len(votes) < len(mats):
              votes.append(defaultdict(int))
          for i, mat in enumerate(mats):
              juncs = loadmat(mat)["junctions"]
              if len(juncs) == 0:
                  continue
              png = mat.replace(".mat", "_5.png")
              img = cv2.imread(png)
              if img is None:
                  raise FileNotFoundError(f"cannot read image {png} for {mat}")
              # Copy as float so the in-place rescaling is valid for any dtype.
              juncs = juncs.astype(float)
              juncs[:, 0] *= 128 / img.shape[1]  # x, scaled by image width
              juncs[:, 1] *= 128 / img.shape[0]  # y, scaled by image height
              for j in juncs:
                  votes[i][tuple(j)] += 1
      # Emit (y, x, votes) rows; reshape keeps empty images at shape (0, 3).
      return [
          np.array(
              [(key[1], key[0], count) for key, count in image_votes.items()],
              dtype=float,
          ).reshape(-1, 3)
          for image_votes in votes
      ]
  def main():
      """Parse the command line and print LCNN junction mAP for every <path>.

      Each <path> directory's ``*.npz`` predictions are paired, in sorted
      order, with the ground truth matched by GT and the images matched by IM.
      """
      args = docopt(__doc__)
      gt_list, im_list = sorted(glob.glob(GT)), sorted(glob.glob(IM))
      for result_dir in args["<path>"]:
          print("Evaluating", result_dir)
          predictions = sorted(glob.glob(osp.join(result_dir, "*.npz")))
          evaluate_lcnn(im_list, gt_list, predictions)


  if __name__ == "__main__":
      main()