| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130 |
#!/usr/bin/env python3
"""Post-process the output of the neural network.

Usage:
    post.py [options] <input-dir> <output-dir>
    post.py ( -h | --help )

Examples:
    post.py logs/logname/npz/000336000 result/logname

Arguments:
    input-dir                    Directory that stores the npz
    output-dir                   Output directory

Options:
   -h --help                     Show this screen.
   --plot                        Generate images besides npz files
   --thresholds=<thresholds>     A comma-separated list for thresholding
                                 [default: 0.006,0.010,0.015]
"""
- import glob
- import math
- import os
- import os.path as osp
- import sys
- import cv2
- import matplotlib as mpl
- import matplotlib.pyplot as plt
- import numpy as np
- from docopt import docopt
- from lcnn.postprocess import postprocess
- from lcnn.utils import parmap
# Keyword arguments for plt.scatter when marking segment endpoints:
# small cyan dots drawn above the segments (zorder 5).
PLTOPTS = dict(color="#33FFFF", s=1.2, edgecolors="none", zorder=5)

# Score -> colour mapping used by c(): scores are normalised over
# [0.92, 1.02] and looked up in the "jet" colormap.
norm = mpl.colors.Normalize(vmin=0.92, vmax=1.02)
cmap = plt.get_cmap("jet")
sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
# Attach an empty data array. NOTE(review): presumably a workaround for older
# matplotlib versions that required one before the mappable could be used
# (e.g. for a colorbar); it does not affect to_rgba.
sm.set_array([])
def c(x):
    """Return the RGBA colour for line score ``x``.

    The lookup goes through the module-level ScalarMappable ``sm``, which
    normalises scores over [0.92, 1.02] and maps them through "jet".
    """
    rgba = sm.to_rgba(x)
    return rgba
def imshow(im):
    """Replace the current figure with a borderless one that shows ``im``.

    The new figure is 1 inch tall and as wide as the image aspect ratio
    requires. A single axes fills the whole figure with its axis hidden, and
    the limits sit on the outer pixel edges, so a later ``savefig`` contains
    only the image (plus whatever is drawn on top of it).

    Args:
        im: image array indexed as (row, column[, channel]).
    """
    plt.close()  # drop the previous current figure before starting a new one
    n_rows, n_cols = im.shape[0], im.shape[1]
    figure = plt.figure()
    # forward=False: only record the size, do not resize any GUI window.
    figure.set_size_inches(float(n_cols) / float(n_rows), 1, forward=False)
    axes = plt.Axes(figure, [0.0, 0.0, 1.0, 1.0])
    axes.set_axis_off()
    figure.add_axes(axes)
    # Pixel centres are at integer coordinates, so edges are at +-0.5;
    # the y-axis is inverted so row 0 is at the top.
    plt.xlim([-0.5, n_cols - 0.5])
    plt.ylim([n_rows - 0.5, -0.5])
    plt.imshow(im)
# Score cut-offs used when plotting predictions; overlay i (saved with the
# suffix "_i") shows only lines whose score exceeds the i-th cut-off.
_PLOT_SCORE_THRESHOLDS = (0.96, 0.97, 0.98, 0.99)


def _drop_repeated_tail(lines, scores):
    """Truncate ``lines``/``scores`` at the first later entry equal to ``lines[0]``.

    NOTE(review): this presumably strips padding produced by cycling through
    the predicted lines to fill a fixed-size output; confirm against the
    model's output code.
    """
    for i in range(1, len(lines)):
        if (lines[i] == lines[0]).all():
            return lines[:i], scores[:i]
    return lines, scores


def _save_line_plot(im, segments, colors, png_name):
    """Draw ``segments`` over the BGR image ``im`` and save to ``png_name``.

    Each segment is a pair of endpoints indexed as (row, column). ``colors``
    gives one matplotlib colour per segment. The figure is closed after
    saving so worker processes do not keep figures open.
    """
    imshow(im[:, :, ::-1])  # cv2 loads BGR; matplotlib expects RGB
    for (a, b), color in zip(segments, colors):
        plt.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=0.5)
        plt.scatter(a[1], a[0], **PLTOPTS)
        plt.scatter(b[1], b[0], **PLTOPTS)
    plt.savefig(png_name, dpi=500, bbox_inches=0)
    plt.close()


def main():
    """Post-process every prediction npz found in <input-dir>.

    Predictions are paired, by sorted order, with the validation images in
    ``data/wireframe/valid-images`` and ground-truth files in
    ``data/wireframe/valid``.

    For each value in ``--thresholds``:

    1. The predicted lines are rescaled from the 128-unit frame they are
       stored in to image pixels.
    2. They are passed to ``postprocess`` with a distance threshold of
       ``threshold * image diagonal``.
    3. The result is scaled back and written to
       ``<output-dir>/<threshold, '.' -> '_'>/<input name>.npz``.

    With ``--plot``, a ground-truth overlay and one prediction overlay per
    entry of ``_PLOT_SCORE_THRESHOLDS`` are saved next to each npz.

    Raises:
        FileNotFoundError: (in a worker) if an image cannot be read.
    """
    args = docopt(__doc__)
    files = sorted(glob.glob(osp.join(args["<input-dir>"], "*.npz")))
    inames = sorted(glob.glob("data/wireframe/valid-images/*.jpg"))
    gts = sorted(glob.glob("data/wireframe/valid/*.npz"))
    prefix = args["<output-dir>"]
    thresholds = list(map(float, args["--thresholds"].split(",")))

    # Inputs are matched purely by sorted position and zip() stops at the
    # shortest list, so a count mismatch usually means misaligned inputs.
    if not len(files) == len(inames) == len(gts):
        print(
            f"Warning: {len(files)} predictions, {len(inames)} images and "
            f"{len(gts)} ground-truth files; pairing by sorted order and "
            "stopping at the shortest list",
            file=sys.stderr,
        )
    inputs = list(zip(files, inames, gts))

    def handle(allname):
        fname, iname, gtname = allname
        print("Processing", fname)
        im = cv2.imread(iname)
        if im is None:  # cv2.imread reports failure by returning None
            raise FileNotFoundError(f"Cannot read image {iname}")
        height, width = im.shape[0], im.shape[1]
        with np.load(fname) as f:
            lines = f["lines"]
            scores = f["score"]
        with np.load(gtname) as f:
            gtlines = f["lpos"][:, :, :2]
        # Coordinates are stored in a 128-unit frame; scale them to pixels.
        gtlines[:, :, 0] *= height / 128
        gtlines[:, :, 1] *= width / 128
        lines, scores = _drop_repeated_tail(lines, scores)
        lines[:, :, 0] *= height / 128
        lines[:, :, 1] *= width / 128
        diag = (height ** 2 + width ** 2) ** 0.5
        base_name = osp.split(fname)[-1]
        for threshold in thresholds:
            nlines, nscores = postprocess(lines, scores, diag * threshold, 0, False)
            outdir = osp.join(prefix, f"{threshold:.3f}".replace(".", "_"))
            os.makedirs(outdir, exist_ok=True)
            npz_name = osp.join(outdir, base_name)
            # splitext only touches the file extension, unlike str.replace,
            # which would also rewrite ".npz" occurring in directory names.
            stem = osp.splitext(npz_name)[0]
            if args["--plot"]:
                _save_line_plot(
                    im, gtlines, ["orange"] * len(gtlines), stem + ".png"
                )
                for i, t in enumerate(_PLOT_SCORE_THRESHOLDS):
                    keep = nscores > t
                    _save_line_plot(
                        im,
                        nlines[keep],
                        [c(s) for s in nscores[keep]],
                        f"{stem}_{i}.png",
                    )
            # Scale back to the 128-unit frame on a copy, so arrays returned
            # by postprocess are never mutated, and an empty result (nothing
            # to index as [:, :, k]) is saved as-is instead of crashing.
            saved = np.array(nlines, copy=True)
            if saved.size:
                saved[:, :, 0] *= 128 / height
                saved[:, :, 1] *= 128 / width
            np.savez_compressed(npz_name, lines=saved, score=nscores)

    parmap(handle, inputs, 12)
if __name__ == "__main__":
    # Script entry point: main() parses the command line with docopt and
    # post-processes all prediction files.
    main()
|