demo.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. #!/usr/bin/env python3
  2. """Process an image with the trained neural network
  3. Usage:
  4. demo.py [options] <yaml-config> <checkpoint> <image>
  5. demo.py (-h | --help )
  6. Arguments:
  7. <yaml-config> Path to the yaml hyper-parameter file
  8. <checkpoint> Path to the checkpoint
  9. <image> Path to the directory containing processed images
  10. Options:
  11. -h --help Show this screen.
  12. -d --devices <devices> Comma seperated GPU devices [default: 0]
  13. """
  14. import os
  15. import os.path as osp
  16. import pprint
  17. import random
  18. import matplotlib as mpl
  19. import matplotlib.pyplot as plt
  20. import numpy as np
  21. import skimage.io
  22. import skimage.transform
  23. import torch
  24. import yaml
  25. from docopt import docopt
  26. import lcnn
  27. from lcnn.config import C, M
  28. from lcnn.models.line_vectorizer import LineVectorizer
  29. from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner
  30. from lcnn.postprocess import postprocess
  31. from lcnn.utils import recursive_to
  32. PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
  33. cmap = plt.get_cmap("jet")
  34. norm = mpl.colors.Normalize(vmin=0.9, vmax=1.0)
  35. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  36. sm.set_array([])
  37. def c(x):
  38. return sm.to_rgba(x)
  39. def main():
  40. args = docopt(__doc__)
  41. config_file = args["<yaml-config>"] or "config/wireframe.yaml"
  42. C.update(C.from_yaml(filename=config_file))
  43. M.update(C.model)
  44. pprint.pprint(C, indent=4)
  45. random.seed(0)
  46. np.random.seed(0)
  47. torch.manual_seed(0)
  48. device_name = "cpu"
  49. os.environ["CUDA_VISIBLE_DEVICES"] = args["--devices"]
  50. if torch.cuda.is_available():
  51. device_name = "cuda"
  52. torch.backends.cudnn.deterministic = True
  53. torch.cuda.manual_seed(0)
  54. print("Let's use", torch.cuda.device_count(), "GPU(s)!")
  55. else:
  56. print("CUDA is not available")
  57. device = torch.device(device_name)
  58. checkpoint = torch.load(args["<checkpoint>"], map_location=device)
  59. # Load model
  60. model = lcnn.models.hg(
  61. depth=M.depth,
  62. head=lambda c_in, c_out: MultitaskHead(c_in, c_out),
  63. num_stacks=M.num_stacks,
  64. num_blocks=M.num_blocks,
  65. num_classes=sum(sum(M.head_size, [])),
  66. )
  67. model = MultitaskLearner(model)
  68. model = LineVectorizer(model)
  69. model.load_state_dict(checkpoint["model_state_dict"])
  70. model = model.to(device)
  71. model.eval()
  72. im = skimage.io.imread(args["<image>"])[:, :, :3]
  73. im_resized = skimage.transform.resize(im, (512, 512)) * 255
  74. image = (im_resized - M.image.mean) / M.image.stddev
  75. image = torch.from_numpy(np.rollaxis(image, 2)[None].copy()).float()
  76. with torch.no_grad():
  77. input_dict = {
  78. "image": image.to(device),
  79. "meta": [
  80. {
  81. "junc": torch.zeros(1, 2).to(device),
  82. "jtyp": torch.zeros(1, dtype=torch.uint8).to(device),
  83. "Lpos": torch.zeros(2, 2, dtype=torch.uint8).to(device),
  84. "Lneg": torch.zeros(2, 2, dtype=torch.uint8).to(device),
  85. }
  86. ],
  87. "target": {
  88. "jmap": torch.zeros([1, 1, 128, 128]).to(device),
  89. "joff": torch.zeros([1, 1, 2, 128, 128]).to(device),
  90. },
  91. "do_evaluation": True,
  92. }
  93. H = model(input_dict)["preds"]
  94. lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
  95. scores = H["score"][0].cpu().numpy()
  96. for i in range(1, len(lines)):
  97. if (lines[i] == lines[0]).all():
  98. lines = lines[:i]
  99. scores = scores[:i]
  100. break
  101. # postprocess lines to remove overlapped lines
  102. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  103. nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
  104. plt.gca().set_axis_off()
  105. plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
  106. plt.margins(0, 0)
  107. plt.gca().xaxis.set_major_locator(plt.NullLocator())
  108. plt.gca().yaxis.set_major_locator(plt.NullLocator())
  109. for i, t in enumerate([0.95, 0.96, 0.97, 0.98, 0.99]):
  110. for (a, b), s in zip(nlines, nscores):
  111. if s < t:
  112. continue
  113. plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
  114. plt.scatter(a[1], a[0], **PLTOPTS)
  115. plt.scatter(b[1], b[0], **PLTOPTS)
  116. plt.imshow(im)
  117. plt.show()
  118. if __name__ == "__main__":
  119. main()