process.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. #!/usr/bin/env python3
  2. """Process a dataset with the trained neural network
  3. Usage:
  4. process.py [options] <yaml-config> <checkpoint> <image-dir> <output-dir>
  5. process.py (-h | --help )
  6. Arguments:
  7. <yaml-config> Path to the yaml hyper-parameter file
  8. <checkpoint> Path to the checkpoint
  9. <image-dir> Path to the directory containing processed images
  10. <output-dir> Path to the output directory
  11. Options:
  12. -h --help Show this screen.
  13. -d --devices <devices> Comma seperated GPU devices [default: 0]
  14. --plot Plot the result
  15. """
  16. import os
  17. import sys
  18. import shlex
  19. import pprint
  20. import random
  21. import os.path as osp
  22. import threading
  23. import subprocess
  24. import yaml
  25. import numpy as np
  26. import torch
  27. import matplotlib as mpl
  28. import skimage.io
  29. import matplotlib.pyplot as plt
  30. from docopt import docopt
  31. import lcnn
  32. from lcnn.utils import recursive_to
  33. from lcnn.config import C, M
  34. from lcnn.datasets import WireframeDataset, collate
  35. from lcnn.models.line_vectorizer import LineVectorizer
  36. from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner
  37. def main():
  38. args = docopt(__doc__)
  39. config_file = args["<yaml-config>"] or "config/wireframe.yaml"
  40. C.update(C.from_yaml(filename=config_file))
  41. M.update(C.model)
  42. pprint.pprint(C, indent=4)
  43. random.seed(0)
  44. np.random.seed(0)
  45. torch.manual_seed(0)
  46. device_name = "cpu"
  47. os.environ["CUDA_VISIBLE_DEVICES"] = args["--devices"]
  48. if torch.cuda.is_available():
  49. device_name = "cuda"
  50. torch.backends.cudnn.deterministic = True
  51. torch.cuda.manual_seed(0)
  52. print("Let's use", torch.cuda.device_count(), "GPU(s)!")
  53. else:
  54. print("CUDA is not available")
  55. device = torch.device(device_name)
  56. if M.backbone == "stacked_hourglass":
  57. model = lcnn.models.hg(
  58. depth=M.depth,
  59. head=lambda c_in, c_out: MultitaskHead(c_in, c_out),
  60. num_stacks=M.num_stacks,
  61. num_blocks=M.num_blocks,
  62. num_classes=sum(sum(M.head_size, [])),
  63. )
  64. else:
  65. raise NotImplementedError
  66. checkpoint = torch.load(args["<checkpoint>"])
  67. model = MultitaskLearner(model)
  68. model = LineVectorizer(model)
  69. model.load_state_dict(checkpoint["model_state_dict"])
  70. model = model.to(device)
  71. model.eval()
  72. loader = torch.utils.data.DataLoader(
  73. WireframeDataset(args["<image-dir>"], split="valid"),
  74. shuffle=False,
  75. batch_size=M.batch_size,
  76. collate_fn=collate,
  77. num_workers=C.io.num_workers if os.name != "nt" else 0,
  78. pin_memory=True,
  79. )
  80. os.makedirs(args["<output-dir>"], exist_ok=True)
  81. for batch_idx, (image, meta, target) in enumerate(loader):
  82. with torch.no_grad():
  83. input_dict = {
  84. "image": recursive_to(image, device),
  85. "meta": recursive_to(meta, device),
  86. "target": recursive_to(target, device),
  87. "mode": "validation",
  88. }
  89. H = model(input_dict)["preds"]
  90. for i in range(M.batch_size):
  91. index = batch_idx * M.batch_size + i
  92. np.savez(
  93. osp.join(args["<output-dir>"], f"{index:06}.npz"),
  94. **{k: v[i].cpu().numpy() for k, v in H.items()},
  95. )
  96. if not args["--plot"]:
  97. continue
  98. im = image[i].cpu().numpy().transpose(1, 2, 0)
  99. im = im * M.image.stddev + M.image.mean
  100. lines = H["lines"][i].cpu().numpy() * 4
  101. scores = H["score"][i].cpu().numpy()
  102. if len(lines) > 0 and not (lines[0] == 0).all():
  103. for i, ((a, b), s) in enumerate(zip(lines, scores)):
  104. if i > 0 and (lines[i] == lines[0]).all():
  105. break
  106. plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4)
  107. plt.show()
  108. cmap = plt.get_cmap("jet")
  109. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  110. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  111. sm.set_array([])
  112. def c(x):
  113. return sm.to_rgba(x)
  114. if __name__ == "__main__":
  115. main()