소스 검색

Update demo.py

Yichao Zhou 6 년 전
부모
커밋
ce9895ef8d
2개의 변경된 파일60개의 추가작업 그리고 51개의 파일을 삭제
  1. 58 51
      demo.py
  2. 2 0
      lcnn/models/line_vectorizer.py

+ 58 - 51
demo.py

@@ -1,13 +1,13 @@
 #!/usr/bin/env python3
 """Process an image with the trained neural network
 Usage:
-    demo.py [options] <yaml-config> <checkpoint> <image>
+    demo.py [options] <yaml-config> <checkpoint> <images>...
     demo.py (-h | --help )
 
 Arguments:
    <yaml-config>                 Path to the yaml hyper-parameter file
    <checkpoint>                  Path to the checkpoint
-   <image>                       Path to the directory containing processed images
+   <images>                      Path to images
 
 Options:
    -h --help                     Show this screen.
@@ -83,55 +83,62 @@ def main():
     model = model.to(device)
     model.eval()
 
-    im = skimage.io.imread(args["<image>"])[:, :, :3]
-    im_resized = skimage.transform.resize(im, (512, 512)) * 255
-    image = (im_resized - M.image.mean) / M.image.stddev
-    image = torch.from_numpy(np.rollaxis(image, 2)[None].copy()).float()
-    with torch.no_grad():
-        input_dict = {
-            "image": image.to(device),
-            "meta": [
-                {
-                    "junc": torch.zeros(1, 2).to(device),
-                    "jtyp": torch.zeros(1, dtype=torch.uint8).to(device),
-                    "Lpos": torch.zeros(2, 2, dtype=torch.uint8).to(device),
-                    "Lneg": torch.zeros(2, 2, dtype=torch.uint8).to(device),
-                }
-            ],
-            "target": {
-                "jmap": torch.zeros([1, 1, 128, 128]).to(device),
-                "joff": torch.zeros([1, 1, 2, 128, 128]).to(device),
-            },
-            "do_evaluation": True,
-        }
-        H = model(input_dict)["preds"]
-
-    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
-    scores = H["score"][0].cpu().numpy()
-    for i in range(1, len(lines)):
-        if (lines[i] == lines[0]).all():
-            lines = lines[:i]
-            scores = scores[:i]
-            break
-
-    # postprocess lines to remove overlapped lines
-    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
-    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
-
-    plt.gca().set_axis_off()
-    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
-    plt.margins(0, 0)
-    plt.gca().xaxis.set_major_locator(plt.NullLocator())
-    plt.gca().yaxis.set_major_locator(plt.NullLocator())
-    for i, t in enumerate([0.95, 0.96, 0.97, 0.98, 0.99]):
-        for (a, b), s in zip(nlines, nscores):
-            if s < t:
-                continue
-            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
-            plt.scatter(a[1], a[0], **PLTOPTS)
-            plt.scatter(b[1], b[0], **PLTOPTS)
-        plt.imshow(im)
-        plt.show()
+    for imname in args["<images>"]:
+        print(f"Processing {imname}")
+        im = skimage.io.imread(imname)
+        if im.ndim == 2:
+            im = np.repeat(im[:, :, None], 3, 2)
+        im = im[:, :, :3]
+        im_resized = skimage.transform.resize(im, (512, 512)) * 255
+        image = (im_resized - M.image.mean) / M.image.stddev
+        image = torch.from_numpy(np.rollaxis(image, 2)[None].copy()).float()
+        with torch.no_grad():
+            input_dict = {
+                "image": image.to(device),
+                "meta": [
+                    {
+                        "junc": torch.zeros(1, 2).to(device),
+                        "jtyp": torch.zeros(1, dtype=torch.uint8).to(device),
+                        "Lpos": torch.zeros(2, 2, dtype=torch.uint8).to(device),
+                        "Lneg": torch.zeros(2, 2, dtype=torch.uint8).to(device),
+                    }
+                ],
+                "target": {
+                    "jmap": torch.zeros([1, 1, 128, 128]).to(device),
+                    "joff": torch.zeros([1, 1, 2, 128, 128]).to(device),
+                },
+                "do_evaluation": True,
+            }
+            H = model(input_dict)["preds"]
+
+        lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
+        scores = H["score"][0].cpu().numpy()
+        for i in range(1, len(lines)):
+            if (lines[i] == lines[0]).all():
+                lines = lines[:i]
+                scores = scores[:i]
+                break
+
+        # postprocess lines to remove overlapped lines
+        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+        nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+
+        for i, t in enumerate([0.94, 0.95, 0.96, 0.97, 0.98, 0.99]):
+            plt.gca().set_axis_off()
+            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+            plt.margins(0, 0)
+            for (a, b), s in zip(nlines, nscores):
+                if s < t:
+                    continue
+                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+                plt.scatter(a[1], a[0], **PLTOPTS)
+                plt.scatter(b[1], b[0], **PLTOPTS)
+            plt.gca().xaxis.set_major_locator(plt.NullLocator())
+            plt.gca().yaxis.set_major_locator(plt.NullLocator())
+            plt.imshow(im)
+            plt.savefig(imname.replace(".png", f"-{t:.02f}.svg"), bbox_inches="tight")
+            plt.show()
+            plt.close()
 
 
 if __name__ == "__main__":

+ 2 - 0
lcnn/models/line_vectorizer.py

@@ -162,6 +162,8 @@ class LineVectorizer(nn.Module):
                 K = min(int((jmap > M.eval_junc_thres).float().sum().item()), max_K)
             else:
                 K = min(int(N * 2 + 2), max_K)
+            if K < 2:
+                K = 2
             device = jmap.device
 
             # index: [N_TYPE, K]