|
@@ -0,0 +1,173 @@
|
|
|
+#!/usr/bin/env python3
|
|
|
+"""Process an image with the trained neural network
|
|
|
+Usage:
|
|
|
+ demo.py [options] <yaml-config> <checkpoint> <images>...
|
|
|
+ demo.py (-h | --help )
|
|
|
+
|
|
|
+Arguments:
|
|
|
+ <yaml-config> Path to the yaml hyper-parameter file
|
|
|
+ <checkpoint> Path to the checkpoint
|
|
|
+ <images> Path to images
|
|
|
+
|
|
|
+Options:
|
|
|
+ -h --help Show this screen.
|
|
|
+ -d --devices <devices> Comma separated GPU devices [default: 0]
|
|
|
+"""
|
|
|
+
|
|
|
+# Run from a terminal: python ./predict.py -d 0 config/wireframe.yaml <path-to-pretrained-pth> <path-to-image>
|
|
|
+
|
|
|
+import os
|
|
|
+import os.path as osp
|
|
|
+import pprint
|
|
|
+import random
|
|
|
+
|
|
|
+import matplotlib as mpl
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+import numpy as np
|
|
|
+import skimage.io
|
|
|
+import skimage.transform
|
|
|
+import torch
|
|
|
+import yaml
|
|
|
+from docopt import docopt
|
|
|
+
|
|
|
+import lcnn
|
|
|
+from lcnn.config import C, M
|
|
|
+from lcnn.models.line_vectorizer import LineVectorizer
|
|
|
+from lcnn.models.multitask_learner import MultitaskHead, MultitaskLearner
|
|
|
+from lcnn.postprocess import postprocess
|
|
|
+from lcnn.utils import recursive_to
|
|
|
+from torchvision.utils import draw_bounding_boxes
|
|
|
+
|
|
|
# Keyword arguments shared by every plt.scatter call that marks a line endpoint.
PLTOPTS = dict(color="#33FFFF", s=15, edgecolors="none", zorder=5)

# plt.get_cmap looks a colormap up by name and returns the colormap object.
cmap = plt.get_cmap("jet")
# Values in [0.9, 1.0] are spread linearly over the colormap.
norm = mpl.colors.Normalize(vmin=0.9, vmax=1.0)
# Mappable combining the colormap and the normalisation above.
sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
# Empty data array; presumably so the mappable can back a colorbar -- TODO confirm.
sm.set_array([])
|
|
|
+
|
|
|
+
|
|
|
def c(x):
    """Return the RGBA colour that the module-level mappable ``sm`` assigns to *x*."""
    rgba = sm.to_rgba(x)
    return rgba
|
|
|
+
|
|
|
+
|
|
|
def main():
    """Run the trained line detector on every image named on the command line.

    Arguments are parsed by docopt from the module docstring:
    ``<yaml-config>`` (hyper-parameter file), ``<checkpoint>`` (file holding a
    ``model_state_dict`` entry), one or more ``<images>``, and the optional
    ``--devices`` list that is written to ``CUDA_VISIBLE_DEVICES``.

    For each image the predicted lines are filtered with ``postprocess`` and
    drawn over the original image once per score threshold (0.5 and 0.95).
    Each rendering is saved as ``<image path without extension>-<t>.svg`` and
    then handed to ``plt.show()``.
    """
    args = docopt(__doc__)

    # Take the configuration from the command line so the script does not
    # depend on the working directory or on machine-specific paths.
    config_file = args["<yaml-config>"]
    # Fall back to GPU 0 in case docopt did not supply a default.
    devices = args["--devices"] or "0"

    C.update(C.from_yaml(filename=config_file))
    M.update(C.model)

    # Fixed seeds so repeated runs behave the same.
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    # Restrict the visible GPUs before CUDA is queried for the first time.
    device_name = "cpu"
    os.environ["CUDA_VISIBLE_DEVICES"] = devices

    if torch.cuda.is_available():
        device_name = "cuda"
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed(0)
        print("Let's use", torch.cuda.device_count(), "GPU(s)!")
    else:
        print("CUDA is not available")

    device = torch.device(device_name)

    checkpoint = torch.load(args["<checkpoint>"], map_location=device)

    # Backbone. load_state_dict is strict by default, so the network built
    # here has to match the one the checkpoint was trained with.
    model = lcnn.models.fasterrcnn_resnet50(
        num_classes=sum(sum(M.head_size, [])),
    )
    model = MultitaskLearner(model)
    model = LineVectorizer(model)
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)
    model.eval()

    thresholds = (0.5, 0.95)

    for imname in args["<images>"]:
        print(f"Processing {imname}")
        im = skimage.io.imread(imname)
        if im.ndim == 2:
            # Grayscale input: replicate the single channel three times.
            im = np.repeat(im[:, :, None], 3, 2)
        # Drop an alpha channel if there is one.
        im = im[:, :, :3]

        # resize() returns floats (in [0, 1] for integer input), hence * 255
        # before normalising with the configured mean / stddev.
        im_resized = skimage.transform.resize(im, (512, 512)) * 255
        image = (im_resized - M.image.mean) / M.image.stddev
        # HWC -> 1 x C x H x W float tensor.
        image = torch.from_numpy(np.rollaxis(image, 2)[None].copy()).float()

        with torch.no_grad():
            # "meta" and "target" hold all-zero placeholders; presumably the
            # forward pass expects these keys even in testing mode -- TODO confirm.
            input_dict = {
                "image": image.to(device),
                "meta": [
                    {
                        "junc_coords": torch.zeros(1, 2).to(device),
                        "jtyp": torch.zeros(1, dtype=torch.uint8).to(device),
                        "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
                        "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device),
                    }
                ],
                "target": {
                    "junc_map": torch.zeros([1, 1, 128, 128]).to(device),
                    "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device),
                },
                "target_b": None,
                "mode": "testing",
            }
            H = model(input_dict)["preds"]

        # Rescale from the model's 128-unit coordinate range (presumably the
        # 128 x 128 heatmap grid) to the size of the original image.
        lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
        scores = H["score"][0].cpu().numpy()
        # Cut both arrays at the first repetition of line 0; the output looks
        # padded by repeating lines -- NOTE(review): confirm against the model.
        for i in range(1, len(lines)):
            if (lines[i] == lines[0]).all():
                lines = lines[:i]
                scores = scores[:i]
                break

        # postprocess lines to remove overlapped lines
        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
        nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)

        # Build the output name from the path without its extension. A plain
        # str.replace(".png", ...) leaves non-PNG names unchanged, which would
        # make savefig overwrite the input image.
        stem = osp.splitext(imname)[0]

        for t in thresholds:
            plt.gca().set_axis_off()
            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
            plt.margins(0, 0)
            for (a, b), s in zip(nlines, nscores):
                if s < t:
                    continue
                # Points are (row, col); matplotlib wants x = col, y = row.
                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
                plt.scatter(a[1], a[0], **PLTOPTS)
                plt.scatter(b[1], b[0], **PLTOPTS)
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            plt.imshow(im)
            plt.savefig(f"{stem}-{t:.02f}.svg", bbox_inches="tight")
            plt.show()
            plt.close()
|
|
|
+
|
|
|
+
|
|
|
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|