#!/usr/bin/env python3
"""Process an image with the trained neural network
Usage:
    demo.py [options] YAML_CONFIG CHECKPOINT IMAGES...
    demo.py (-h | --help )

Arguments:
   YAML_CONFIG                   Path to the yaml hyper-parameter file
   CHECKPOINT                    Path to the checkpoint
   IMAGES                        Path to images

Options:
   -h --help                     Show this screen.
   -d --devices DEVICES          Comma seperated GPU devices [default: 0]
"""
# Run from a terminal, e.g.:
#   python ./predict.py -d 0 config/wireframe.yaml CHECKPOINT IMAGE [IMAGE ...]

import os
import os.path as osp
import pprint
import random

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import skimage.io
import skimage.transform
import torch
import yaml
from docopt import docopt

import lcnn
from lcnn.config import C, M
from lcnn.models.line_vectorizer import LineVectorizer
from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner
from lcnn.postprocess import postprocess
from lcnn.utils import recursive_to
from torchvision.utils import draw_bounding_boxes

# Keyword arguments shared by every end-point scatter marker.
PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}

# plt.get_cmap returns the colour-map object registered under the given name.
cmap = plt.get_cmap("jet")
# Scores are mapped onto the colour map over the range [0.9, 1.0].
norm = mpl.colors.Normalize(vmin=0.9, vmax=1.0)
# Scalar -> RGBA mapper built from the colour map and the normalisation.
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])

# Score thresholds; one figure is rendered and saved per threshold.
SCORE_THRESHOLDS = (0.5, 0.95)


def c(x):
    """Return the RGBA colour that the module-level mapper assigns to score *x*."""
    return sm.to_rgba(x)


def _output_path(imname, threshold):
    """Return the SVG path for *imname* at the given score threshold.

    The suffix is swapped with ``splitext`` so that an input that is not a
    ``.png`` (e.g. ``.jpg``) can never resolve to the input file itself.
    """
    return osp.splitext(imname)[0] + f"-{threshold:.02f}.svg"


def _select_device(devices):
    """Expose *devices* through CUDA_VISIBLE_DEVICES and return the torch device."""
    device_name = "cpu"
    os.environ["CUDA_VISIBLE_DEVICES"] = devices
    if torch.cuda.is_available():
        device_name = "cuda"
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed(0)
        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
    else:
        print("CUDA is not available")
    return torch.device(device_name)


def _build_model(checkpoint, device):
    """Build the network, load the checkpoint weights and switch to eval mode."""
    # Backbone.
    model = lcnn.models.fasterrcnn_resnet50(
        # num_stacks=M.num_stacks,
        num_classes=sum(sum(M.head_size, [])),
    )
    # Alternative backbone, kept for reference:
    # model = lcnn.models.hg(
    #     depth=M.depth,
    #     head=lambda c_in, c_out: MultitaskHead(c_in, c_out),
    #     num_stacks=M.num_stacks,
    #     num_blocks=M.num_blocks,
    #     num_classes=sum(sum(M.head_size, [])),
    # )
    model = MultitaskLearner(model)
    model = LineVectorizer(model)
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)
    model.eval()
    return model


def _load_image(imname):
    """Read *imname* and return ``(im, image)``.

    ``im`` is the image as read from disk, forced to three channels;
    ``image`` is the float tensor fed to the network: resized to 512x512,
    normalised with ``M.image.mean`` / ``M.image.stddev``, channel axis
    moved to the front and a leading batch axis added.
    """
    im = skimage.io.imread(imname)
    if im.ndim == 2:
        # Single-channel input: replicate it into three channels.
        im = np.repeat(im[:, :, None], 3, 2)
    # Keep only the first three channels (drops e.g. an alpha channel).
    im = im[:, :, :3]
    im_resized = skimage.transform.resize(im, (512, 512)) * 255
    image = (im_resized - M.image.mean) / M.image.stddev
    image = torch.from_numpy(np.rollaxis(image, 2)[None].copy()).float()
    return im, image


def _make_input(image, device):
    """Wrap *image* in the dictionary passed to the model.

    Every field other than the image is a zero-filled placeholder; presumably
    the model's forward pass expects these keys even in "testing" mode
    (TODO confirm against the model definition).
    """
    return {
        "image": image.to(device),
        "meta": [
            {
                "junc_coords": torch.zeros(1, 2).to(device),
                "jtyp": torch.zeros(1, dtype=torch.uint8).to(device),
                "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
                "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
            }
        ],
        "target": {
            "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
            "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
        },
        "target_b": None,
        "mode": "testing",
    }


def _trim_repeated(lines, scores):
    """Cut *lines* and *scores* at the first later entry equal to ``lines[0]``.

    NOTE(review): this looks like it strips padding made of copies of the
    first line -- confirm against the vectorizer's output format.
    """
    for i in range(1, len(lines)):
        if (lines[i] == lines[0]).all():
            return lines[:i], scores[:i]
    return lines, scores


def _render(im, imname, nlines, nscores, threshold):
    """Draw the lines scoring at least *threshold* over *im*; save and show it."""
    plt.gca().set_axis_off()
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    for (a, b), s in zip(nlines, nscores):
        if s < threshold:
            continue
        # Points are stored as (row, col); matplotlib wants x=col, y=row.
        plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
        plt.scatter(a[1], a[0], **PLTOPTS)
        plt.scatter(b[1], b[0], **PLTOPTS)
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    plt.imshow(im)
    plt.savefig(_output_path(imname, threshold), bbox_inches="tight")
    plt.show()
    plt.close()


def main():
    """Run the trained network on every image given on the command line.

    Reads the YAML configuration, checkpoint path, image paths and GPU list
    from the docopt-parsed command line, then for each image runs the model,
    post-processes the predicted lines and writes one SVG per score
    threshold next to the input image.
    """
    args = docopt(__doc__)

    # Load the hyper-parameters and expose the model section through M.
    C.update(C.from_yaml(filename=args["YAML_CONFIG"]))
    M.update(C.model)

    # Fixed seeds for reproducibility.
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    # Device configuration.
    device = _select_device(args["--devices"])

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(args["CHECKPOINT"], map_location=device)
    model = _build_model(checkpoint, device)

    for imname in args["IMAGES"]:
        print(f"Processing {imname}")
        im, image = _load_image(imname)

        with torch.no_grad():
            result = model(_make_input(image, device))
            # print(result)
            H = result["preds"]

        # Divide by 128 and multiply by (height, width) of the input image;
        # presumably the network emits coordinates on a 128x128 grid
        # (TODO confirm).
        lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
        scores = H["score"][0].cpu().numpy()
        lines, scores = _trim_repeated(lines, scores)

        # Postprocess lines to remove overlapped lines; the distance
        # threshold is 1% of the image diagonal.
        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
        nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)

        for threshold in SCORE_THRESHOLDS:
            _render(im, imname, nlines, nscores, threshold)


if __name__ == "__main__":
    main()