|
|
@@ -1,13 +1,13 @@
|
|
|
#!/usr/bin/env python3
|
|
|
"""Process an image with the trained neural network
|
|
|
Usage:
|
|
|
- demo.py [options] <yaml-config> <checkpoint> <image>
|
|
|
+ demo.py [options] <yaml-config> <checkpoint> <images>...
|
|
|
demo.py (-h | --help )
|
|
|
|
|
|
Arguments:
|
|
|
<yaml-config> Path to the yaml hyper-parameter file
|
|
|
<checkpoint> Path to the checkpoint
|
|
|
- <image> Path to the directory containing processed images
|
|
|
+ <images> Path to images
|
|
|
|
|
|
Options:
|
|
|
-h --help Show this screen.
|
|
|
@@ -83,55 +83,62 @@ def main():
|
|
|
model = model.to(device)
|
|
|
model.eval()
|
|
|
|
|
|
- im = skimage.io.imread(args["<image>"])[:, :, :3]
|
|
|
- im_resized = skimage.transform.resize(im, (512, 512)) * 255
|
|
|
- image = (im_resized - M.image.mean) / M.image.stddev
|
|
|
- image = torch.from_numpy(np.rollaxis(image, 2)[None].copy()).float()
|
|
|
- with torch.no_grad():
|
|
|
- input_dict = {
|
|
|
- "image": image.to(device),
|
|
|
- "meta": [
|
|
|
- {
|
|
|
- "junc": torch.zeros(1, 2).to(device),
|
|
|
- "jtyp": torch.zeros(1, dtype=torch.uint8).to(device),
|
|
|
- "Lpos": torch.zeros(2, 2, dtype=torch.uint8).to(device),
|
|
|
- "Lneg": torch.zeros(2, 2, dtype=torch.uint8).to(device),
|
|
|
- }
|
|
|
- ],
|
|
|
- "target": {
|
|
|
- "jmap": torch.zeros([1, 1, 128, 128]).to(device),
|
|
|
- "joff": torch.zeros([1, 1, 2, 128, 128]).to(device),
|
|
|
- },
|
|
|
- "do_evaluation": True,
|
|
|
- }
|
|
|
- H = model(input_dict)["preds"]
|
|
|
-
|
|
|
- lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
|
|
|
- scores = H["score"][0].cpu().numpy()
|
|
|
- for i in range(1, len(lines)):
|
|
|
- if (lines[i] == lines[0]).all():
|
|
|
- lines = lines[:i]
|
|
|
- scores = scores[:i]
|
|
|
- break
|
|
|
-
|
|
|
- # postprocess lines to remove overlapped lines
|
|
|
- diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
|
|
|
- nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
|
|
|
-
|
|
|
- plt.gca().set_axis_off()
|
|
|
- plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
|
|
|
- plt.margins(0, 0)
|
|
|
- plt.gca().xaxis.set_major_locator(plt.NullLocator())
|
|
|
- plt.gca().yaxis.set_major_locator(plt.NullLocator())
|
|
|
- for i, t in enumerate([0.95, 0.96, 0.97, 0.98, 0.99]):
|
|
|
- for (a, b), s in zip(nlines, nscores):
|
|
|
- if s < t:
|
|
|
- continue
|
|
|
- plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
|
|
|
- plt.scatter(a[1], a[0], **PLTOPTS)
|
|
|
- plt.scatter(b[1], b[0], **PLTOPTS)
|
|
|
- plt.imshow(im)
|
|
|
- plt.show()
|
|
|
+ for imname in args["<images>"]:
|
|
|
+ print(f"Processing {imname}")
|
|
|
+ im = skimage.io.imread(imname)
|
|
|
+ if im.ndim == 2:
|
|
|
+ im = np.repeat(im[:, :, None], 3, 2)
|
|
|
+ im = im[:, :, :3]
|
|
|
+ im_resized = skimage.transform.resize(im, (512, 512)) * 255
|
|
|
+ image = (im_resized - M.image.mean) / M.image.stddev
|
|
|
+ image = torch.from_numpy(np.rollaxis(image, 2)[None].copy()).float()
|
|
|
+ with torch.no_grad():
|
|
|
+ input_dict = {
|
|
|
+ "image": image.to(device),
|
|
|
+ "meta": [
|
|
|
+ {
|
|
|
+ "junc": torch.zeros(1, 2).to(device),
|
|
|
+ "jtyp": torch.zeros(1, dtype=torch.uint8).to(device),
|
|
|
+ "Lpos": torch.zeros(2, 2, dtype=torch.uint8).to(device),
|
|
|
+ "Lneg": torch.zeros(2, 2, dtype=torch.uint8).to(device),
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "target": {
|
|
|
+ "jmap": torch.zeros([1, 1, 128, 128]).to(device),
|
|
|
+ "joff": torch.zeros([1, 1, 2, 128, 128]).to(device),
|
|
|
+ },
|
|
|
+ "do_evaluation": True,
|
|
|
+ }
|
|
|
+ H = model(input_dict)["preds"]
|
|
|
+
|
|
|
+ lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
|
|
|
+ scores = H["score"][0].cpu().numpy()
|
|
|
+ for i in range(1, len(lines)):
|
|
|
+ if (lines[i] == lines[0]).all():
|
|
|
+ lines = lines[:i]
|
|
|
+ scores = scores[:i]
|
|
|
+ break
|
|
|
+
|
|
|
+ # postprocess lines to remove overlapped lines
|
|
|
+ diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
|
|
|
+ nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
|
|
|
+
|
|
|
+ for i, t in enumerate([0.94, 0.95, 0.96, 0.97, 0.98, 0.99]):
|
|
|
+ plt.gca().set_axis_off()
|
|
|
+ plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
|
|
|
+ plt.margins(0, 0)
|
|
|
+ for (a, b), s in zip(nlines, nscores):
|
|
|
+ if s < t:
|
|
|
+ continue
|
|
|
+ plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
|
|
|
+ plt.scatter(a[1], a[0], **PLTOPTS)
|
|
|
+ plt.scatter(b[1], b[0], **PLTOPTS)
|
|
|
+ plt.gca().xaxis.set_major_locator(plt.NullLocator())
|
|
|
+ plt.gca().yaxis.set_major_locator(plt.NullLocator())
|
|
|
+ plt.imshow(im)
|
|
|
+ plt.savefig(imname.replace(".png", f"-{t:.02f}.svg"), bbox_inches="tight")
|
|
|
+ plt.show()
|
|
|
+ plt.close()
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|