demo.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. #!/usr/bin/env python3
  2. """Process an image with the trained neural network
  3. Usage:
  4. demo.py [options] <yaml-config> <checkpoint> <images>...
  5. demo.py (-h | --help )
  6. Arguments:
  7. <yaml-config> Path to the yaml hyper-parameter file
  8. <checkpoint> Path to the checkpoint
  9. <images> Path to images
  10. Options:
  11. -h --help Show this screen.
  12. -d --devices <devices> Comma seperated GPU devices [default: 0]
  13. """
  14. import os
  15. import os.path as osp
  16. import pprint
  17. import random
  18. import matplotlib as mpl
  19. import matplotlib.pyplot as plt
  20. import numpy as np
  21. import skimage.io
  22. import skimage.transform
  23. import torch
  24. import yaml
  25. from docopt import docopt
  26. import lcnn
  27. from lcnn.config import C, M
  28. from lcnn.models.line_vectorizer import LineVectorizer
  29. from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner
  30. from lcnn.postprocess import postprocess
  31. from lcnn.utils import recursive_to
  32. PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
  33. cmap = plt.get_cmap("jet")
  34. norm = mpl.colors.Normalize(vmin=0.9, vmax=1.0)
  35. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  36. sm.set_array([])
  37. def c(x):
  38. return sm.to_rgba(x)
  39. def main():
  40. args = docopt(__doc__)
  41. config_file = args["<yaml-config>"] or "config/wireframe.yaml"
  42. C.update(C.from_yaml(filename=config_file))
  43. M.update(C.model)
  44. pprint.pprint(C, indent=4)
  45. random.seed(0)
  46. np.random.seed(0)
  47. torch.manual_seed(0)
  48. device_name = "cpu"
  49. os.environ["CUDA_VISIBLE_DEVICES"] = args["--devices"]
  50. if torch.cuda.is_available():
  51. device_name = "cuda"
  52. torch.backends.cudnn.deterministic = True
  53. torch.cuda.manual_seed(0)
  54. print("Let's use", torch.cuda.device_count(), "GPU(s)!")
  55. else:
  56. print("CUDA is not available")
  57. device = torch.device(device_name)
  58. checkpoint = torch.load(args["<checkpoint>"], map_location=device)
  59. # Load model
  60. model = lcnn.models.hg(
  61. depth=M.depth,
  62. head=lambda c_in, c_out: MultitaskHead(c_in, c_out),
  63. num_stacks=M.num_stacks,
  64. num_blocks=M.num_blocks,
  65. num_classes=sum(sum(M.head_size, [])),
  66. )
  67. model = MultitaskLearner(model)
  68. model = LineVectorizer(model)
  69. model.load_state_dict(checkpoint["model_state_dict"])
  70. model = model.to(device)
  71. model.eval()
  72. for imname in args["<images>"]:
  73. print(f"Processing {imname}")
  74. im = skimage.io.imread(imname)
  75. if im.ndim == 2:
  76. im = np.repeat(im[:, :, None], 3, 2)
  77. im = im[:, :, :3]
  78. im_resized = skimage.transform.resize(im, (512, 512)) * 255
  79. image = (im_resized - M.image.mean) / M.image.stddev
  80. image = torch.from_numpy(np.rollaxis(image, 2)[None].copy()).float()
  81. with torch.no_grad():
  82. input_dict = {
  83. "image": image.to(device),
  84. "meta": [
  85. {
  86. "junc": torch.zeros(1, 2).to(device),
  87. "jtyp": torch.zeros(1, dtype=torch.uint8).to(device),
  88. "Lpos": torch.zeros(2, 2, dtype=torch.uint8).to(device),
  89. "Lneg": torch.zeros(2, 2, dtype=torch.uint8).to(device),
  90. }
  91. ],
  92. "target": {
  93. "jmap": torch.zeros([1, 1, 128, 128]).to(device),
  94. "joff": torch.zeros([1, 1, 2, 128, 128]).to(device),
  95. },
  96. "mode": "testing",
  97. }
  98. H = model(input_dict)["preds"]
  99. lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
  100. scores = H["score"][0].cpu().numpy()
  101. for i in range(1, len(lines)):
  102. if (lines[i] == lines[0]).all():
  103. lines = lines[:i]
  104. scores = scores[:i]
  105. break
  106. # postprocess lines to remove overlapped lines
  107. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  108. nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
  109. for i, t in enumerate([0.94, 0.95, 0.96, 0.97, 0.98, 0.99]):
  110. plt.gca().set_axis_off()
  111. plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
  112. plt.margins(0, 0)
  113. for (a, b), s in zip(nlines, nscores):
  114. if s < t:
  115. continue
  116. plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
  117. plt.scatter(a[1], a[0], **PLTOPTS)
  118. plt.scatter(b[1], b[0], **PLTOPTS)
  119. plt.gca().xaxis.set_major_locator(plt.NullLocator())
  120. plt.gca().yaxis.set_major_locator(plt.NullLocator())
  121. plt.imshow(im)
  122. plt.savefig(imname.replace(".png", f"-{t:.02f}.svg"), bbox_inches="tight")
  123. plt.show()
  124. plt.close()
  125. if __name__ == "__main__":
  126. main()