process.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
#!/usr/bin/env python3
"""Process a dataset with a trained neural network.

Usage:
    process.py [options] <yaml-config> <checkpoint> <image-dir> <output-dir>
    process.py (-h | --help )

Arguments:
   <yaml-config>                 Path to the YAML hyper-parameter file
   <checkpoint>                  Path to the checkpoint
   <image-dir>                   Path to the directory containing processed images
   <output-dir>                  Path to the output directory

Options:
   -h --help                     Show this screen.
   -d --devices <devices>        Comma-separated GPU devices [default: 0]
   --plot                        Plot the result
"""
  16. import os
  17. import sys
  18. import shlex
  19. import pprint
  20. import random
  21. import os.path as osp
  22. import threading
  23. import subprocess
  24. import numpy as np
  25. import torch
  26. import matplotlib as mpl
  27. import skimage.io
  28. import matplotlib.pyplot as plt
  29. from docopt import docopt
  30. import lcnn
  31. from lcnn.utils import recursive_to
  32. from lcnn.config import C, M
  33. from lcnn.datasets import WireframeDataset, collate
  34. from lcnn.models.line_vectorizer import LineVectorizer
  35. from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner
  36. def main():
  37. args = docopt(__doc__)
  38. config_file = args["<yaml-config>"] or "config/wireframe.yaml"
  39. C.update(C.from_yaml(filename=config_file))
  40. M.update(C.model)
  41. pprint.pprint(C, indent=4)
  42. random.seed(0)
  43. np.random.seed(0)
  44. torch.manual_seed(0)
  45. device_name = "cpu"
  46. os.environ["CUDA_VISIBLE_DEVICES"] = args["--devices"]
  47. if torch.cuda.is_available():
  48. device_name = "cuda"
  49. torch.backends.cudnn.deterministic = True
  50. torch.cuda.manual_seed(0)
  51. print("Let's use", torch.cuda.device_count(), "GPU(s)!")
  52. else:
  53. print("CUDA is not available")
  54. device = torch.device(device_name)
  55. if M.backbone == "stacked_hourglass":
  56. model = lcnn.models.hg(
  57. depth=M.depth,
  58. head=lambda c_in, c_out: MultitaskHead(c_in, c_out),
  59. num_stacks=M.num_stacks,
  60. num_blocks=M.num_blocks,
  61. num_classes=sum(sum(M.head_size, [])),
  62. )
  63. else:
  64. raise NotImplementedError
  65. checkpoint = torch.load(args["<checkpoint>"])
  66. model = MultitaskLearner(model)
  67. model = LineVectorizer(model)
  68. model.load_state_dict(checkpoint["model_state_dict"])
  69. model = model.to(device)
  70. model.eval()
  71. loader = torch.utils.data.DataLoader(
  72. WireframeDataset(args["<image-dir>"], split="valid"),
  73. shuffle=False,
  74. batch_size=M.batch_size,
  75. collate_fn=collate,
  76. num_workers=C.io.num_workers,
  77. pin_memory=True,
  78. )
  79. os.makedirs(args["<output-dir>"], exist_ok=True)
  80. for batch_idx, (image, meta, target) in enumerate(loader):
  81. with torch.no_grad():
  82. input_dict = {
  83. "image": recursive_to(image, device),
  84. "meta": recursive_to(meta, device),
  85. "target": recursive_to(target, device),
  86. "do_evaluation": True,
  87. }
  88. H = model(input_dict)["heatmaps"]
  89. for i in range(M.batch_size):
  90. index = batch_idx * M.batch_size + i
  91. np.savez(
  92. osp.join(args["<output-dir>"], f"{index:06}.npz"),
  93. **{k: v[i].cpu().numpy() for k, v in H.items()},
  94. )
  95. if not args["--plot"]:
  96. continue
  97. im = image[i].cpu().numpy().transpose(1, 2, 0)
  98. im = im * M.image.stddev + M.image.mean
  99. lines = H["lines"][i].cpu().numpy() * 4
  100. scores = H["score"][i].cpu().numpy()
  101. if len(lines) > 0 and not (lines[0] == 0).all():
  102. for i, ((a, b), s) in enumerate(zip(lines, scores)):
  103. if i > 0 and (lines[i] == lines[0]).all():
  104. break
  105. plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4)
  106. plt.show()
  107. cmap = plt.get_cmap("jet")
  108. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  109. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  110. sm.set_array([])
  111. def c(x):
  112. return sm.to_rgba(x)
  113. if __name__ == "__main__":
  114. main()