predict.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
#!/usr/bin/env python3
"""Process images with the trained neural network.

Usage:
    predict.py [options] <yaml-config> <checkpoint> <images>...
    predict.py (-h | --help )

Arguments:
    <yaml-config>                Path to the yaml hyper-parameter file
    <checkpoint>                 Path to the checkpoint
    <images>                     Path to images

Options:
   -h --help                     Show this screen.
   -d --devices <devices>        Comma-separated GPU devices [default: 0]
"""
# Run from a terminal: python ./predict.py -d 0 config/wireframe.yaml <path-to-pretrained-pth> <path-to-image>
  15. import os
  16. import os.path as osp
  17. import pprint
  18. import random
  19. import matplotlib as mpl
  20. import matplotlib.pyplot as plt
  21. import numpy as np
  22. import skimage.io
  23. import skimage.transform
  24. import torch
  25. import yaml
  26. from docopt import docopt
  27. import lcnn
  28. from lcnn.config import C, M
  29. from lcnn.models.line_vectorizer import LineVectorizer
  30. from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner
  31. from lcnn.postprocess import postprocess
  32. from lcnn.utils import recursive_to
  33. from torchvision.utils import draw_bounding_boxes
  34. PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
  35. cmap = plt.get_cmap("jet")
  36. # # get_cmap 函数是 Matplotlib 中用于获取色彩映射对象的关键函数。它可以接受色彩映射的名称作为参数,返回相应的色彩映射对象。
  37. norm = mpl.colors.Normalize(vmin=0.9, vmax=1.0)
  38. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) # 颜色映射的颜色条
  39. sm.set_array([])
  40. def c(x):
  41. return sm.to_rgba(x)
  42. def main():
  43. args = docopt(__doc__)
  44. config = {
  45. # 数据集配置
  46. 'datadir': r'D:\python\PycharmProjects\data', # 数据集目录
  47. 'config_file': 'config/wireframe.yaml', # 配置文件路径
  48. # GPU配置
  49. 'devices': '0', # 使用的GPU设备
  50. 'identifier': 'fasterrcnn_resnet50', # 训练标识符 stacked_hourglass unet
  51. }
  52. # 更新配置
  53. C.update(C.from_yaml(filename=config['config_file']))
  54. M.update(C.model)
  55. random.seed(0)
  56. np.random.seed(0)
  57. torch.manual_seed(0)
  58. # 设备配置
  59. device_name = "cpu"
  60. os.environ["CUDA_VISIBLE_DEVICES"] = config['devices']
  61. if torch.cuda.is_available():
  62. device_name = "cuda"
  63. torch.backends.cudnn.deterministic = True
  64. torch.cuda.manual_seed(0)
  65. print("Let's use", torch.cuda.device_count(), "GPU(s)!")
  66. else:
  67. print("CUDA is not available")
  68. device = torch.device(device_name)
  69. checkpoint = torch.load(args["<checkpoint>"], map_location=device) ##############
  70. # Load model # backbone
  71. model = lcnn.models.fasterrcnn_resnet50(
  72. # num_stacks=M.num_stacks,
  73. num_classes=sum(sum(M.head_size, [])),
  74. )
  75. # model = lcnn.models.hg(
  76. # depth=M.depth,
  77. # head=lambda c_in, c_out: MultitaskHead(c_in, c_out),
  78. # num_stacks=M.num_stacks,
  79. # num_blocks=M.num_blocks,
  80. # num_classes=sum(sum(M.head_size, [])),
  81. # )
  82. model = MultitaskLearner(model)
  83. model = LineVectorizer(model)
  84. model.load_state_dict(checkpoint["model_state_dict"])
  85. model = model.to(device)
  86. model.eval()
  87. for imname in args["<images>"]:
  88. print(f"Processing {imname}")
  89. im = skimage.io.imread(imname)
  90. if im.ndim == 2:
  91. im = np.repeat(im[:, :, None], 3, 2)
  92. im = im[:, :, :3]
  93. im_resized = skimage.transform.resize(im, (512, 512)) * 255
  94. image = (im_resized - M.image.mean) / M.image.stddev
  95. image = torch.from_numpy(np.rollaxis(image, 2)[None].copy()).float()
  96. with torch.no_grad():
  97. input_dict = {
  98. "image": image.to(device),
  99. "meta": [
  100. {
  101. "junc_coords": torch.zeros(1, 2).to(device),
  102. "jtyp": torch.zeros(1, dtype=torch.uint8).to(device),
  103. "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
  104. "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
  105. }
  106. ],
  107. "target": {
  108. "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
  109. "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
  110. },
  111. "target_b": None,
  112. "mode": "testing",
  113. }
  114. result = model(input_dict)
  115. # print(result)
  116. H = result["preds"]
  117. boxed_image = draw_bounding_boxes((image[0] * 255).to(torch.uint8), result["box"][0]["boxes"],
  118. colors="yellow", width=1)
  119. lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
  120. scores = H["score"][0].cpu().numpy()
  121. for i in range(1, len(lines)):
  122. if (lines[i] == lines[0]).all():
  123. lines = lines[:i]
  124. scores = scores[:i]
  125. break
  126. # postprocess lines to remove overlapped lines
  127. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  128. nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
  129. for i, t in enumerate([0.5, 0.95]):
  130. plt.gca().set_axis_off()
  131. plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
  132. plt.margins(0, 0)
  133. for (a, b), s in zip(nlines, nscores):
  134. if s < t:
  135. continue
  136. plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
  137. plt.scatter(a[1], a[0], **PLTOPTS)
  138. plt.scatter(b[1], b[0], **PLTOPTS)
  139. plt.gca().xaxis.set_major_locator(plt.NullLocator())
  140. plt.gca().yaxis.set_major_locator(plt.NullLocator())
  141. plt.imshow(im)
  142. plt.savefig(imname.replace(".png", f"-{t:.02f}.svg"), bbox_inches="tight")
  143. plt.show()
  144. plt.close()
  145. if __name__ == "__main__":
  146. main()