123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173 |
- #!/usr/bin/env python3
- """Process an image with the trained neural network
- Usage:
- demo.py [options] <yaml-config> <checkpoint> <images>...
- demo.py (-h | --help )
- Arguments:
- <yaml-config> Path to the yaml hyper-parameter file
- <checkpoint> Path to the checkpoint
- <images> Path to images
- Options:
- -h --help Show this screen.
- -d --devices <devices> Comma seperated GPU devices [default: 0]
- """
- # Run from a terminal: python ./predict.py -d 0 config/wireframe.yaml <path-to-pretrained-pth> <path-to-image>
- import os
- import os.path as osp
- import pprint
- import random
- import matplotlib as mpl
- import matplotlib.pyplot as plt
- import numpy as np
- import skimage.io
- import skimage.transform
- import torch
- import yaml
- from docopt import docopt
- import lcnn
- from lcnn.config import C, M
- from lcnn.models.line_vectorizer import LineVectorizer
- from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner
- from lcnn.postprocess import postprocess
- from lcnn.utils import recursive_to
- from torchvision.utils import draw_bounding_boxes
# Keyword arguments shared by every plt.scatter call that marks a line endpoint.
PLTOPTS = dict(color="#33FFFF", s=15, edgecolors="none", zorder=5)

# plt.get_cmap takes a colormap name and returns the matching colormap object.
cmap = plt.get_cmap("jet")

# Scores are normalised over the 0.9-1.0 interval before the colormap lookup.
norm = mpl.colors.Normalize(vmin=0.9, vmax=1.0)

# Mappable that turns a score into an RGBA colour (see c() below); it is given
# an empty data array because it is only used for value -> colour lookups.
sm = plt.cm.ScalarMappable(
    cmap=cmap,
    norm=norm,
)
sm.set_array([])
def c(x):
    """Return the RGBA colour that the module-level mappable ``sm`` assigns to *x*."""
    rgba = sm.to_rgba(x)
    return rgba
def _build_model(checkpoint_path, device):
    """Build the detection network, restore its weights and put it in eval mode.

    Args:
        checkpoint_path: Path to a checkpoint file containing a
            ``"model_state_dict"`` entry.
        device: ``torch.device`` the weights are mapped to and the model is
            moved onto.

    Returns:
        The ``LineVectorizer``-wrapped model, ready for inference.
    """
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    backbone = lcnn.models.fasterrcnn_resnet50(
        num_classes=sum(sum(M.head_size, [])),
    )
    model = LineVectorizer(MultitaskLearner(backbone))
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)
    model.eval()
    return model


def _load_image(imname):
    """Read an image and build the normalised network input from it.

    Args:
        imname: Path of the image file to read.

    Returns:
        A ``(im, image)`` pair: ``im`` is the image as read from disk, forced
        to three channels; ``image`` is a float tensor holding a 512x512
        resized, mean/stddev-normalised copy with a leading batch axis and the
        channel axis moved in front of the spatial axes.
    """
    im = skimage.io.imread(imname)
    if im.ndim == 2:
        # Grayscale input: replicate the single channel three times.
        im = np.repeat(im[:, :, None], 3, 2)
    # Drop any channel past the third (e.g. an alpha channel).
    im = im[:, :, :3]

    # skimage's resize returns floats in [0, 1]; scale back to 0-255 before
    # applying the configured mean / stddev.
    im_resized = skimage.transform.resize(im, (512, 512)) * 255
    image = (im_resized - M.image.mean) / M.image.stddev
    image = torch.from_numpy(np.rollaxis(image, 2)[None].copy()).float()
    return im, image


def _detect_lines(model, image, im_shape, device):
    """Run the network on one image and return post-processed lines and scores.

    Args:
        model: Network returned by ``_build_model``.
        image: Input tensor returned by ``_load_image``.
        im_shape: ``shape`` of the original image; the first two entries are
            used to rescale the predicted coordinates.
        device: ``torch.device`` the inputs are moved onto.

    Returns:
        The ``(lines, scores)`` pair produced by ``lcnn.postprocess.postprocess``.
    """
    with torch.no_grad():
        # The meta / target entries are zero-filled placeholders; no ground
        # truth is available in "testing" mode.
        input_dict = {
            "image": image.to(device),
            "meta": [
                {
                    "junc_coords": torch.zeros(1, 2).to(device),
                    "jtyp": torch.zeros(1, dtype=torch.uint8).to(device),
                    "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
                    "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
                }
            ],
            "target": {
                "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
                "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
            },
            "target_b": None,
            "mode": "testing",
        }
        H = model(input_dict)["preds"]

    # Predictions are divided by 128 and multiplied by the image height/width
    # (presumably the network works on a 128x128 grid -- TODO confirm).
    lines = H["lines"][0].cpu().numpy() / 128 * im_shape[:2]
    scores = H["score"][0].cpu().numpy()

    # Cut the list at the first later entry identical to entry 0
    # (presumably the output is padded by repeating it -- TODO confirm).
    for i in range(1, len(lines)):
        if (lines[i] == lines[0]).all():
            lines = lines[:i]
            scores = scores[:i]
            break

    # postprocess lines to remove overlapped lines
    diag = (im_shape[0] ** 2 + im_shape[1] ** 2) ** 0.5
    return postprocess(lines, scores, diag * 0.01, 0, False)


def _save_plots(im, imname, nlines, nscores):
    """Draw the lines over the image, once per score threshold, and save as SVG.

    For each threshold in (0.5, 0.95) the lines scoring at least that value are
    drawn over ``im``; the figure is written next to the input as
    ``<stem>-<threshold>.svg``, shown, and closed.

    Args:
        im: Image array to display underneath the lines.
        imname: Path of the input image; its extension is replaced to build the
            output file name.
        nlines: Iterable of line endpoint pairs ``(a, b)``; index 1 of each
            endpoint is plotted on the x axis and index 0 on the y axis.
        nscores: Scores aligned with ``nlines``.
    """
    # Strip whatever extension the input has.  The previous
    # imname.replace(".png", ...) left non-PNG paths unchanged, so savefig
    # overwrote the input image itself.
    stem = osp.splitext(imname)[0]

    for t in (0.5, 0.95):
        plt.gca().set_axis_off()
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        for (a, b), s in zip(nlines, nscores):
            if s < t:
                continue
            # zorder=s draws higher-scoring lines on top of lower-scoring ones.
            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
            plt.scatter(a[1], a[0], **PLTOPTS)
            plt.scatter(b[1], b[0], **PLTOPTS)
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.imshow(im)
        plt.savefig(f"{stem}-{t:.02f}.svg", bbox_inches="tight")
        plt.show()
        plt.close()


def main():
    """Command-line entry point: detect and plot wireframe lines for each image.

    Parses the command line described by the module docstring, loads the YAML
    hyper-parameters into ``C`` / ``M``, seeds the random number generators,
    restores the model from the checkpoint and then processes every image path
    given on the command line.
    """
    args = docopt(__doc__)

    # Honour the command-line arguments; they were previously parsed but
    # ignored in favour of hard-coded values.
    config_file = args["<yaml-config>"] or "config/wireframe.yaml"
    C.update(C.from_yaml(filename=config_file))
    M.update(C.model)

    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    # Must be set before CUDA is first queried so the selection takes effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = args["--devices"]
    device_name = "cpu"
    if torch.cuda.is_available():
        device_name = "cuda"
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed(0)
        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
    else:
        print("CUDA is not available")
    device = torch.device(device_name)

    model = _build_model(args["<checkpoint>"], device)

    for imname in args["<images>"]:
        print(f"Processing {imname}")
        im, image = _load_image(imname)
        nlines, nscores = _detect_lines(model, image, im.shape, device)
        _save_plots(im, imname, nlines, nscores)
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|