Yichao Zhou před 6 roky
rodič
revize
90432d88f5
6 změnil soubory, kde provedl 491 přidání a 484 odebrání
  1. 1 1
      README.md
  2. 3 0
      config/wireframe.yaml
  3. 467 478
      figs/PR-sAP10.svg
  4. 4 3
      lcnn/models/line_vectorizer.py
  5. 15 1
      misc/plot-sAP.py
  6. 1 1
      post.py

+ 1 - 1
README.md

@@ -23,7 +23,7 @@ The following table reports the performance metrics of several wireframe and lin
 | [LSD](https://ieeexplore.ieee.org/document/4731268/) |                 /                  |              52.0             |              61.0                |                /                 |              
 |  [AFM](https://github.com/cherubicXN/afm_cvpr2019)   |                24.4                |              69.5               |              77.2              |               23.3               |           
 | [Wireframe](https://github.com/huangkuns/wireframe)  |                5.1                 |              67.8               |              72.6              |               40.9               |              
-|                      **L-CNN**                       |              **62.9**              |            **83.0**             |            **81.6**            |             **59.3**             |               
+|                      **L-CNN**                       |              **62.9**              |            **82.8**             |            **81.2**            |             **59.3**             |               
 
 ### Precision-Recall Curves
 <p align="center">

+ 3 - 0
config/wireframe.yaml

@@ -56,6 +56,9 @@ model:
   use_slop: 0
   use_conv: 0
 
+  # junction threashold for evaluation (See #5)
+  eval_junc_thres: 0.008
+
 optim:
   name: Adam
   lr: 4.0e-4

Rozdílová data souboru nebyla zobrazena, protože soubor je příliš velký
+ 467 - 478
figs/PR-sAP10.svg


+ 4 - 3
lcnn/models/line_vectorizer.py

@@ -156,11 +156,12 @@ class LineVectorizer(nn.Module):
             n_type = jmap.shape[0]
             jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
             joff = joff.reshape(n_type, 2, -1)
+            max_K = M.n_dyn_junc // n_type
             N = len(junc)
             if do_evaluation:
-                K = 440
+                K = min(int((jmap > M.eval_junc_thres).float().sum().item()), 2 * max_K)
             else:
-                K = min(int(N * 2 + 2), M.n_dyn_junc // n_type)
+                K = min(int(N * 2 + 2), max_K)
             device = jmap.device
 
             # index: [N_TYPE, K]
@@ -184,7 +185,7 @@ class LineVectorizer(nn.Module):
             match[cost > 1.5 * 1.5] = N
             match = match.flatten()
 
-            _ = torch.arange(len(match), device=device)
+            _ = torch.arange(n_type * K, device=device)
             u, v = torch.meshgrid(_, _)
             u, v = u.flatten(), v.flatten()
             up, vp = match[u], match[v]

+ 15 - 1
misc/plot-sAP.py

@@ -21,6 +21,7 @@ except Exception:
 
 # Change the directory here
 PRED = "logs/190418-201834-f8934c6-lr4d10/npz/000312000/*.npz"
+PRED = "post/jmap_0008/*.npz"
 GT = "data/wireframe/valid/*.npz"
 # PRED = "logs/190506-001532-york/*.npz"
 # GT = "data/york/valid/*.npz"
@@ -69,6 +70,11 @@ def wireframe_score(T=10):
     i = np.where(recall[1:] != recall[:-1])[0]
     ap = np.sum((recall[i + 1] - recall[i]) * precision[i + 1])
 
+    np.savez(
+        "/data/lcnn/results/sAP/wireframe.npz",
+        x=np.maximum(0.005, recall[:-1]),
+        y=precision[:-1],
+    )
     plt.plot(
         np.maximum(0.005, recall[:-1]),
         precision[:-1],
@@ -183,7 +189,15 @@ def line_score(threshold=10):
 
     T = 0.005
     plt.plot(afm_re[afm_re > T], afm_pr[afm_re > T], label="AFM", linewidth=3, c="C2")
-    plt.plot(lcnn_re[lcnn_re > T], lcnn_pr[lcnn_re > T], label="L-CNN", linewidth=3, c="C3")
+    plt.plot(
+        lcnn_re[lcnn_re > T], lcnn_pr[lcnn_re > T], label="L-CNN", linewidth=3, c="C3"
+    )
+    np.savez(
+        "/data/lcnn/results/sAP/afm.npz", x=afm_re[afm_re > T], y=afm_pr[afm_re > T]
+    )
+    np.savez(
+        "/data/lcnn/results/sAP/lcnn.npz", x=lcnn_re[lcnn_re > T], y=lcnn_pr[lcnn_re > T]
+    )
     # plt.plot(lsd_re, lsd_pr, label="LSD", linewidth=2)
 
     plt.grid(True)

+ 1 - 1
post.py

@@ -183,7 +183,7 @@ def main():
                     plt.scatter(b[1], b[0], **PLTOPTS)
                 plt.savefig(npz_name.replace(".npz", ".png"), dpi=500, bbox_inches=0)
 
-                thres = [0.97, 0.98, 0.99]
+                thres = [0.96, 0.97, 0.98, 0.99]
                 for i, t in enumerate(thres):
                     imshow(im[:, :, ::-1])
                     for (a, b), s in zip(nlines[nscores > t], nscores[nscores > t]):

Některé soubory nejsou zobrazeny, neboť je v těchto rozdílových datech změněno mnoho souborů