eval-mAPJ.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
#!/usr/bin/env python3
"""Evaluate mAPJ for LCNN, AFM, and Wireframe

Usage:
    eval-mAPJ.py <path>...
    eval-mAPJ.py (-h | --help )

Examples:
    python eval-mAPJ.py logs/*

Arguments:
    <path>       One or more directories that contain *.npz

Options:
    -h --help    Show this screen.
"""
  13. import glob
  14. import os
  15. import os.path as osp
  16. import re
  17. from collections import defaultdict
  18. import cv2
  19. import matplotlib.pyplot as plt
  20. import numpy as np
  21. from docopt import docopt
  22. from scipy.io import loadmat
  23. import lcnn.models
  24. from lcnn.metric import mAPJ, post_jheatmap
  25. GT = "data/york/valid/*.npz"
  26. IM = "data/wireframe/valid-images/*.jpg"
  27. WF = "/data/lcnn/wirebase/result/junc/2/17"
  28. WF = "/data/lcnn/wirebase/result/wireframe/wireframe_2_rerun-baseline_0.5_0.5_york"
  29. AFM = "/data/lcnn/wirebase/result/wireframe/afm/*.npz"
  30. AFM = "/data/lcnn/logs/york-afm/*.npz"
  31. DIST = [0.5, 1.0, 2.0]
  32. def evaluate_lcnn(im_list, gt_list, lcnn_list):
  33. # define result array to aggregate (n x 3) where 3 is (x, y, score)
  34. all_junc = np.zeros((0, 3))
  35. all_offset_junc = np.zeros((0, 3))
  36. # for each detected junction, which image they correspond to
  37. all_junc_ids = np.zeros(0, dtype=np.int32)
  38. # gt is a list since the variable gt number per image
  39. all_jc_gt = []
  40. for i, (lcnn_fn, gt_fn) in enumerate(zip(lcnn_list, gt_list)):
  41. with np.load(lcnn_fn) as npz:
  42. result = {name: arr for name, arr in npz.items()}
  43. jmap = result["jmap"]
  44. joff = result["joff"]
  45. with np.load(gt_fn) as npz:
  46. junc_gt = npz["junc"][:, :2]
  47. # for j in junc_gt:
  48. # plt.scatter(round(j[1]), round(j[0]), c="red")
  49. # for j in juncs_wf:
  50. # plt.scatter(round(j[1]), round(j[0]), c="blue")
  51. # plt.show()
  52. jun_c = post_jheatmap(jmap[0])
  53. all_junc = np.vstack((all_junc, jun_c))
  54. jun_o_c = post_jheatmap(jmap[0], offset=joff[0])
  55. all_offset_junc = np.vstack((all_offset_junc, jun_o_c))
  56. all_jc_gt.append(junc_gt)
  57. all_junc_ids = np.hstack((all_junc_ids, np.array([i] * len(jun_c))))
  58. # sometimes filter all and concat empty list will change dtype
  59. all_junc_ids = all_junc_ids.astype(np.int64)
  60. ap_jc = mAPJ(all_junc, all_jc_gt, DIST, all_junc_ids)
  61. ap_joc = mAPJ(all_offset_junc, all_jc_gt, DIST, all_junc_ids)
  62. print(f" {ap_jc:.1f} | {ap_joc:.1f}")
  63. def evaluate_wireframe(im_list, gt_list):
  64. print("Compute WF mAP")
  65. juncs_wf = load_wf()
  66. all_junc = np.zeros((0, 3))
  67. all_junc_ids = np.zeros(0, dtype=np.int32)
  68. all_jc_gt = []
  69. for i, (im_fn, gt_fn, junc_wf) in enumerate(zip(im_list, gt_list, juncs_wf)):
  70. im = cv2.imread(im_fn)
  71. im = cv2.resize(im, (128, 128))
  72. with np.load(gt_fn) as npz:
  73. junc_gt = npz["junc"][:, :2]
  74. jun_c = sorted(junc_wf, key=lambda x: -x[2])[:1000]
  75. all_junc = np.vstack((all_junc, jun_c))
  76. all_jc_gt.append(junc_gt)
  77. all_junc_ids = np.hstack((all_junc_ids, np.array([i] * len(jun_c))))
  78. all_junc_ids = all_junc_ids.astype(np.int64)
  79. ap_jc = mAPJ(all_junc, all_jc_gt, DIST, all_junc_ids)
  80. print(f" {ap_jc:.1f}")
  81. def evaluate_afm(im_list, gt_list):
  82. print("Compute AFM mAP")
  83. all_junc = np.zeros((0, 3))
  84. all_junc_ids = np.zeros(0, dtype=np.int32)
  85. all_jc_gt = []
  86. afm = glob.glob(AFM)
  87. afm.sort()
  88. for i, (im_fn, gt_fn, afm_fn) in enumerate(zip(im_list, gt_list, afm)):
  89. im = cv2.imread(im_fn)
  90. im = cv2.resize(im, (128, 128))
  91. with np.load(gt_fn) as npz:
  92. junc_gt = npz["junc"][:, :2]
  93. with np.load(afm_fn) as fafm:
  94. afm_line = fafm["lines"]
  95. afm_score = fafm["score"]
  96. jun_c = []
  97. # plt.imshow(im)
  98. for line, score in zip(afm_line, afm_score):
  99. jun_c.append(list(line[0]) + [score])
  100. jun_c.append(list(line[1]) + [score])
  101. # plt.plot([line[0][1], line[1][1]], [line[0][0], line[1][0]], c="blue")
  102. # for line in gt_line:
  103. # plt.plot([line[0][1], line[1][1]], [line[0][0], line[1][0]], c="red")
  104. # plt.show()
  105. jun_c = np.array(jun_c)
  106. all_junc = np.vstack((all_junc, jun_c))
  107. all_jc_gt.append(junc_gt)
  108. all_junc_ids = np.hstack((all_junc_ids, np.array([i] * len(jun_c))))
  109. all_junc_ids = all_junc_ids.astype(np.int64)
  110. ap_jc = mAPJ(all_junc, all_jc_gt, DIST, all_junc_ids)
  111. print(f" {ap_jc:.1f}")
  112. def load_wf():
  113. pts = [defaultdict(int) for _ in range(102)]
  114. for thres in sorted(
  115. [
  116. 0.001,
  117. 0.01,
  118. 0.1,
  119. 0.5,
  120. 2,
  121. 6,
  122. 10,
  123. 20,
  124. 30,
  125. 50,
  126. 80,
  127. 100,
  128. 150,
  129. 200,
  130. 400,
  131. 800,
  132. 3200,
  133. 1600,
  134. 6400,
  135. 51200,
  136. 12800,
  137. 25600,
  138. 102400,
  139. 204800,
  140. ]
  141. ):
  142. mats = sorted(glob.glob(f"{WF}/{thres}/*.mat"))
  143. for i, mat in enumerate(mats):
  144. juncs = loadmat(mat)["lines"].reshape(-1, 2)
  145. if len(juncs) == 0:
  146. continue
  147. juncs *= 128 / 512
  148. for j in juncs:
  149. pts[i][tuple(j)] += 1
  150. pts = pts[: len(mats)]
  151. return [np.array([(k[1], k[0], v) for k, v in ipts.items()]) for ipts in pts]
  152. def main():
  153. # args = docopt(__doc__)
  154. gt_list = sorted(glob.glob(GT))
  155. im_list = sorted(glob.glob(IM))
  156. evaluate_afm(im_list, gt_list)
  157. # evaluate_wireframe(im_list, gt_list)
  158. # for path in args["<path>"]:
  159. # print("Evaluating", path)
  160. # lcnn_list = sorted(glob.glob(osp.join(path, "*.npz")))
  161. # evaluate_lcnn(im_list, gt_list, lcnn_list)
  162. if __name__ == "__main__":
  163. main()