Jelajahi Sumber

More comments on the data format. Closes #2

Yichao Zhou 7 tahun lalu
induk
melakukan
7aee92c1c8
2 mengubah file dengan 25 tambahan dan 16 penghapusan
  1. 11 9
      dataset/wireframe.py
  2. 14 7
      lcnn/datasets.py

+ 11 - 9
dataset/wireframe.py

@@ -91,7 +91,6 @@ def save_heatmap(prefix, image, lines):
             vint0, vint1 = to_int(v0[:2] / 2), to_int(v1[:2] / 2)
             rr, cc, value = skimage.draw.line_aa(*vint0, *vint1)
             lneg.append([v0, v1, i0, i1, np.average(np.minimum(value, llmap[rr, cc]))])
-            # assert np.sum((v0 - v1) ** 2) > 0.01
 
     assert len(lneg) != 0
     lneg.sort(key=lambda l: -l[-1])
@@ -115,17 +114,20 @@ def save_heatmap(prefix, image, lines):
     #     plt.plot([junc[i0][1], junc[i1][1]], [junc[i0][0], junc[i1][0]])
     # plt.show()
 
+    # For junc, lpos, and lneg, which store junction coordinates, the last
+    # dimension is (y, x, t), where t represents the type of that junction.  In
+    # the wireframe dataset, t is always zero.
     np.savez_compressed(
         f"{prefix}_label.npz",
         aspect_ratio=image.shape[1] / image.shape[0],
-        jmap=jmap,  # [J, H, W]
-        joff=joff,  # [J, 2, H, W]
-        lmap=lmap,  # [H, W]
-        junc=junc,  # [Na, 3]
-        Lpos=Lpos,  # [M, 2]
-        Lneg=Lneg,  # [M, 2]
-        lpos=lpos,  # [Np, 2, 3]   (y, x, t) for the last dim
-        lneg=lneg,  # [Nn, 2, 3]
+        jmap=jmap,  # [J, H, W]    Junction heat map
+        joff=joff,  # [J, 2, H, W] Junction offset within each pixel
+        lmap=lmap,  # [H, W]       Line heat map with anti-aliasing
+        junc=junc,  # [Na, 3]      Junction coordinates
+        Lpos=Lpos,  # [M, 2]       Positive lines represented with junction indices
+        Lneg=Lneg,  # [M, 2]       Negative lines represented with junction indices
+        lpos=lpos,  # [Np, 2, 3]   Positive lines represented with junction coordinates
+        lneg=lneg,  # [Nn, 2, 3]   Negative lines represented with junction coordinates
     )
     cv2.imwrite(f"{prefix}.png", image)
 

+ 14 - 7
lcnn/datasets.py

@@ -1,12 +1,12 @@
-import os
 import glob
 import json
 import math
+import os
 import random
 
 import numpy as np
-import torch
 import numpy.linalg as LA
+import torch
 from skimage import io
 from torch.utils.data import Dataset
 from torch.utils.data.dataloader import default_collate
@@ -15,11 +15,7 @@ from lcnn.config import M
 
 
 class WireframeDataset(Dataset):
-    def __init__(
-        self,
-        rootdir,
-        split,
-    ):
+    def __init__(self, rootdir, split):
         self.rootdir = rootdir
         filelist = glob.glob(f"{rootdir}/{split}/*_label.npz")
         filelist.sort()
@@ -39,6 +35,17 @@ class WireframeDataset(Dataset):
         image = (image - M.image.mean) / M.image.stddev
         image = np.rollaxis(image, 2).copy()
 
+        # npz["jmap"]: [J, H, W]    Junction heat map
+        # npz["joff"]: [J, 2, H, W] Junction offset within each pixel
+        # npz["lmap"]: [H, W]       Line heat map with anti-aliasing
+        # npz["junc"]: [Na, 3]      Junction coordinates
+        # npz["Lpos"]: [M, 2]       Positive lines represented with junction indices
+        # npz["Lneg"]: [M, 2]       Negative lines represented with junction indices
+        # npz["lpos"]: [Np, 2, 3]   Positive lines represented with junction coordinates
+        # npz["lneg"]: [Nn, 2, 3]   Negative lines represented with junction coordinates
+        #
+        # For junc, lpos, and lneg, which store junction coordinates, the last
+        # dimension is (y, x, t), where t represents the type of that junction.
         with np.load(self.filelist[idx]) as npz:
             target = {
                 name: torch.from_numpy(npz[name]).float()