123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534 |
- import os
- import time
- import torch
- import matplotlib.pyplot as plt
- import numpy as np
- from torchvision import transforms
- from models.wirenet.postprocess import postprocess
- from datetime import datetime
def box_line(pred):
    """Pair every predicted box with the best-scoring line that lies inside it.

    Unlike ``box_line_``, ``pred`` is not modified; the pairs are returned.

    :param pred: model output list. ``pred[:-1]`` are per-image detection
        dicts with a ``'boxes'`` tensor whose rows are ``(x0, y0, x1, y1)``;
        ``pred[-1]['wires']`` holds per-image ``'lines'`` (endpoint pairs,
        each endpoint stored as ``(y, x)``) and ``'score'`` tensors.
    :return: one list per image with one dict per box, e.g.
        ``{'box': [0.0, 34.2, 151.7, 125.1],
           'line': array([[y0, x0], [y1, x1]], dtype=float32)}``.
        ``'line'`` is the endpoint array (rescaled by 512/128) of the
        highest-scoring line whose both endpoints lie inside the box (bounds
        inclusive, score > 0, earliest line wins ties), or ``[]`` if none fits.
    """
    matched = [[] for _ in range(len(pred) - 1)]
    wires = pred[-1]['wires']
    for idx, det in enumerate(pred[:-1]):
        lines = wires['lines'][idx].detach().cpu().numpy() / 128 * 512
        scores = wires['score'][idx].detach().cpu().numpy()
        for box in det['boxes']:
            x0, y0, x1, y1 = (float(v) for v in box)
            entry = {'box': box.tolist(), 'line': []}
            best_score = 0.0
            for line, score in zip(lines, scores):
                # Endpoints are (y, x): column 1 is tested against the box's
                # x-range and column 0 against its y-range, the same convention
                # box_line_ and the plotting code (x=a[1], y=a[0]) use.
                ys, xs = line[:, 0], line[:, 1]
                inside = (x0 <= xs.min() and xs.max() <= x1
                          and y0 <= ys.min() and ys.max() <= y1)
                if inside and score > best_score:
                    entry['line'] = line
                    best_score = score
            matched[idx].append(entry)
    return matched
def box_line_(pred):
    """Attach to each image's detections the best-scoring line inside each box.

    For every image ``idx`` in ``pred[:-1]`` two keys are written in place:

    * ``pred[idx]['line']``: float32 tensor of shape ``(num_boxes, 2, 2)``.
      Row ``k`` holds the endpoints ``[[y, x], [y, x]]`` (rescaled by 512/128)
      of the highest-scoring line whose both endpoints lie inside box ``k``
      (bounds inclusive), or all zeros when no line with a positive score
      fits. Ties go to the earliest line.
    * ``pred[idx]['line_score']``: float32 tensor of shape ``(num_boxes,)``
      with the matched line's score, ``0.0`` when unmatched.

    :param pred: model output list. ``pred[:-1]`` are per-image detection
        dicts with a ``'boxes'`` tensor whose rows are ``(x0, y0, x1, y1)``;
        ``pred[-1]['wires']`` holds per-image ``'lines'`` and ``'score'``.
    :return: ``pred`` itself (mutated in place).
    """
    wires = pred[-1]['wires']
    for idx, det in enumerate(pred[:-1]):
        boxes = det['boxes'].detach().cpu().numpy().reshape(-1, 4)
        lines = wires['lines'][idx].detach().cpu().numpy().reshape(-1, 2, 2) / 128 * 512
        scores = wires['score'][idx].detach().cpu().numpy().reshape(-1)

        num_boxes = len(boxes)
        # Fixed-shape outputs, so images without boxes still give (0, 2, 2).
        best_lines = np.zeros((num_boxes, 2, 2), dtype=np.float32)
        best_scores = np.zeros(num_boxes, dtype=np.float32)

        if num_boxes and len(lines):
            # Broadcast to (num_boxes, num_lines, 2). Endpoint column 1 is x and
            # column 0 is y (the plotting code draws x=a[1], y=a[0]).
            x0, y0, x1, y1 = (boxes[:, k, None, None] for k in range(4))
            xs, ys = lines[None, :, :, 1], lines[None, :, :, 0]
            inside = ((xs >= x0) & (xs <= x1) & (ys >= y0) & (ys <= y1)).all(axis=2)
            # Only strictly positive scores qualify (a running maximum that
            # starts at 0.0 and only moves on a strictly greater score).
            candidates = inside & (scores[None, :] > 0.0)
            ranked = np.where(candidates, scores[None, :], -np.inf)
            best = ranked.argmax(axis=1)  # argmax returns the first maximum
            has_match = candidates.any(axis=1)
            best_lines[has_match] = lines[best[has_match]]
            best_scores[has_match] = scores[best[has_match]]

        pred[idx]['line'] = torch.from_numpy(best_lines)
        pred[idx]['line_score'] = torch.from_numpy(best_scores)
    return pred
def show_(imgs, pred, epoch, writer):
    """Render boxes and their matched lines for the first image and log it.

    No score threshold is applied. Every box in ``pred[0]['boxes']`` and every
    row of ``pred[0]['line']`` (as filled in by ``box_line_``) is drawn on
    ``imgs[0]``. The figure is rasterised, converted with
    ``transforms.ToTensor()`` and logged via
    ``writer.add_image("all", ..., epoch)``. ``writer`` is presumably a
    TensorBoard ``SummaryWriter`` (TODO confirm with the caller).

    :param imgs: batch of image tensors; ``imgs[0]`` is permuted CHW -> HWC.
    :param pred: model output whose ``pred[0]`` has ``'boxes'`` and ``'line'``.
    :param epoch: step value passed to ``add_image``.
    :param writer: object exposing ``add_image(tag, img, step)``.
    """
    palette = color()
    image = imgs[0].permute(1, 2, 0).detach().cpu().numpy()
    boxes = pred[0]['boxes'].detach().cpu().numpy()
    lines = pred[0]['line'].detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    for idx, (x0, y0, x1, y1) in enumerate(boxes):
        # Cycle the palette so more boxes than colours cannot raise IndexError.
        ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False,
                                   edgecolor=palette[idx % len(palette)], linewidth=1))
    for idx, (a, b) in enumerate(lines):
        # Endpoints are (y, x); matplotlib expects (x, y).
        ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
        ax.plot([a[1], b[1]], [a[0], b[0]], c=palette[idx % len(palette)], linewidth=1)

    # Rasterise via buffer_rgba(). tostring_rgb() is deprecated since
    # Matplotlib 3.8, and the RGBA buffer already has the real pixel shape,
    # so no reshape from get_width_height() is needed.
    fig.canvas.draw()
    image_from_plot = np.ascontiguousarray(np.asarray(fig.canvas.buffer_rgba())[..., :3])
    plt.close(fig)
    img2 = transforms.ToTensor()(image_from_plot)
    writer.add_image("all", img2, epoch)
def show_predict(imgs, pred, t_start, box_th=0.7, line_th=0.9):
    """Display matched box/line pairs for one image, filtered by score.

    Expects ``pred[0]`` to already hold ``'line'`` and ``'line_score'``
    (filled in by ``box_line_``). Boxes whose matched line is all zeros (no
    match) are skipped. The remaining pairs are drawn when the box score is at
    least ``box_th`` OR the line score is at least ``line_th``. The time
    elapsed since ``t_start`` (a ``time.time()`` value) is printed before
    ``plt.show()``.

    :param imgs: one image tensor in CHW layout.
    :param pred: model output whose ``pred[0]`` has ``'boxes'``, ``'scores'``,
        ``'line'`` and ``'line_score'``.
    :param t_start: start timestamp used for the printed timing.
    :param box_th: box-score threshold (default 0.7, the previous fixed value).
    :param line_th: line-score threshold (default 0.9, the previous fixed value).
    """
    col = color()
    im = imgs.permute(1, 2, 0).detach().cpu().numpy()  # CHW -> HWC, e.g. [512, 512, 3]
    boxes = pred[0]['boxes'].detach().cpu().numpy()
    box_scores = pred[0]['scores'].detach().cpu().numpy()
    lines = pred[0]['line'].detach().cpu().numpy()
    line_scores = pred[0]['line_score'].detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    color_idx = 0
    for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
        # Boxes without a matched line carry an all-zero placeholder.
        if not line.any():
            continue
        if not (box_score >= box_th or line_score >= line_th):
            continue
        x0, y0, x1, y1 = box
        (ya, xa), (yb, xb) = line  # endpoints are stored as (y, x)
        edge = col[color_idx % len(col)]
        ax.add_patch(
            plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=edge, linewidth=1))
        ax.scatter([xa, xb], [ya, yb], c='#871F78', s=10)
        ax.plot([xa, xb], [ya, yb], c=edge, linewidth=1)
        color_idx += 1
    t_end = time.time()
    print(f'predict used:{t_end - t_start}')
    plt.show()
- # 下面的都没有进行box与line的一一匹配
def show_line(imgs, pred, t_start, line_th=None):
    """Display the post-processed predicted lines of one image (lines only).

    No box/line matching is done. Lines for image 0 are taken from
    ``pred[-1]['wires']``, rescaled by 512/128 and passed to ``postprocess``
    with ``diag * 0.01`` as its third argument, where ``diag`` is the image
    diagonal in pixels. That argument is presumably a distance tolerance
    (TODO confirm against ``postprocess``). The number of surviving lines and
    the time elapsed since ``t_start`` are printed.

    :param imgs: one image tensor in CHW layout.
    :param pred: model output whose last element holds ``['wires']``.
    :param t_start: start timestamp used for the printed timing.
    :param line_th: optional score threshold; lines scored below it are not
        drawn. ``None`` (default) draws every line, as before.
    """
    im = imgs.permute(1, 2, 0).detach().cpu().numpy()
    line = pred[-1]['wires']['lines'][0].detach().cpu().numpy() / 128 * 512
    line_score = pred[-1]['wires']['score'].detach().cpu().numpy()[0]
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    line, line_score = postprocess(line, line_score, diag * 0.01, 0, False)
    print(f'lines num:{len(line)}')

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    for idx, (a, b) in enumerate(line):
        if line_th is not None and line_score[idx] < line_th:
            continue
        # Endpoints are (y, x); matplotlib expects (x, y).
        ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
        ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
    t_end = time.time()
    print(f'show_line used:{t_end - t_start}')
    plt.show()
def show_line_optimized(imgs, pred, t_start, threshold=0.9):
    """Display post-processed lines of one image whose score reaches ``threshold``.

    Variant of ``show_line``. Lines for image 0 are rescaled by 512/128 and
    passed through ``postprocess`` (third argument ``diag * 0.01``, ``diag``
    being the image diagonal). Only lines scored at least ``threshold`` are
    drawn. The time elapsed since ``t_start`` is printed before ``plt.show()``.

    :param imgs: one image tensor in CHW layout.
    :param pred: model output whose last element holds ``['wires']``.
    :param t_start: start timestamp used for the printed timing.
    :param threshold: minimum line score to draw (default 0.9, the previous
        fixed value).
    """
    im = imgs.permute(1, 2, 0).detach().cpu().numpy()
    line_data = pred[-1]['wires']['lines'][0].detach().cpu().numpy() / 128 * 512
    line_scores = pred[-1]['wires']['score'].detach().cpu().numpy()[0]
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    nlines, nscores = postprocess(line_data, line_scores, diag * 0.01, 0, False)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    for (a, b), s in zip(nlines, nscores):
        if s < threshold:
            continue
        # Endpoints are (y, x); matplotlib expects (x, y).
        ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
        ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
    t_end = time.time()
    print(f'show_line_optimized used:{t_end - t_start}')
    plt.show()
def show_box(imgs, pred, t_start, box_th=0.7):
    """Display only the predicted boxes of one image, filtered by score.

    No box/line matching is done. Boxes from ``pred[0]['boxes']`` scored below
    ``box_th`` are skipped. The time elapsed since ``t_start`` is printed
    before ``plt.show()``.

    :param imgs: one image tensor in CHW layout.
    :param pred: model output whose ``pred[0]`` has ``'boxes'`` and ``'scores'``.
    :param t_start: start timestamp used for the printed timing.
    :param box_th: box-score threshold (default 0.7, the previous fixed value).
    """
    palette = color()
    image = imgs.permute(1, 2, 0).detach().cpu().numpy()
    boxes = pred[0]['boxes'].detach().cpu().numpy()
    box_scores = pred[0]['scores'].detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    for idx, ((x0, y0, x1, y1), score) in enumerate(zip(boxes, box_scores)):
        if score < box_th:
            continue
        # Cycle the palette so more boxes than colours cannot raise IndexError.
        ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False,
                                   edgecolor=palette[idx % len(palette)], linewidth=1))
    t_end = time.time()
    print(f'show_box used:{t_end - t_start}')
    plt.show()
- # 将show_line与show_box合并,传入参数确定显示框还是线 都不显示,输出原图
- # def show_box_or_line(imgs, pred, show_line=False, show_box=False):
- # col = [
- # '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
- # '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
- # '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
- # '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
- # '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
- # '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
- # '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
- # '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
- # '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
- # '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
- # '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
- # '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
- # '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
- # '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
- # '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
- # '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
- # '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
- # '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
- # '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
- # '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
- # ]
- # # print(len(col))
- # im = imgs.permute(1, 2, 0)
- # boxes = pred[0]['boxes'].cpu().numpy()
- # line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
- #
- # # 可视化预测结
- # fig, ax = plt.subplots(figsize=(10, 10))
- # ax.imshow(np.array(im))
- #
- # if show_box:
- # for idx, box in enumerate(boxes):
- # x0, y0, x1, y1 = box
- # ax.add_patch(
- # plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
- #
- # if show_line:
- # for idx, (a, b) in enumerate(line):
- # ax.scatter(a[1], a[0], c='#871F78', s=2)
- # ax.scatter(b[1], b[0], c='#871F78', s=2)
- # ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
- #
- # plt.show()
def show_box_and_line(imgs, pred, show_line=False, show_box=False):
    """Draw boxes and raw lines in two side-by-side panels (no thresholds).

    Left panel ("Boxes"): every box in ``pred[0]['boxes']``, only if
    ``show_box``. Right panel ("Lines"): every line of image 0 from
    ``pred[-1]['wires']``, rescaled by 512/128, only if ``show_line``. A
    disabled panel is left empty.

    :param imgs: one image tensor in CHW layout.
    :param pred: model output (``pred[0]['boxes']`` and ``pred[-1]['wires']``).
    :param show_line: whether to fill the line panel.
    :param show_box: whether to fill the box panel.
    """
    col = color()
    im = imgs.permute(1, 2, 0).detach().cpu().numpy()
    boxes = pred[0]['boxes'].detach().cpu().numpy()
    line = pred[-1]['wires']['lines'][0].detach().cpu().numpy() / 128 * 512

    fig, axs = plt.subplots(1, 2, figsize=(10, 10))
    if show_box:
        axs[0].imshow(im)
        for idx, (x0, y0, x1, y1) in enumerate(boxes):
            # Cycle the palette so more boxes than colours cannot raise IndexError.
            axs[0].add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False,
                                           edgecolor=col[idx % len(col)], linewidth=1))
        axs[0].set_title('Boxes')
    if show_line:
        axs[1].imshow(im)
        for a, b in line:
            # Endpoints are (y, x); matplotlib expects (x, y).
            axs[1].scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
            axs[1].plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
        axs[1].set_title('Lines')
    # Adjust subplot spacing so titles and tick labels do not overlap.
    plt.tight_layout()
    plt.show()
def set_thresholds(threshold):
    """Split ``threshold`` into a ``(box_threshold, line_threshold)`` pair.

    :param threshold: a single real number used for both thresholds (Python
        ``int``/``float`` or a numpy scalar such as ``np.float32``), or a
        two-element list/tuple ``[box_th, line_th]``.
    :return: tuple ``(box_th, line_th)``.
    :raises ValueError: if a list/tuple does not have exactly two elements.
    :raises TypeError: for any other type.
    """
    import numbers  # local import: the module header is outside this block

    if isinstance(threshold, (list, tuple)):
        if len(threshold) != 2:
            raise ValueError("Threshold list must contain exactly two elements.")
        box_th, line_th = threshold
        return box_th, line_th
    # numbers.Real also covers numpy floating/integer scalars, which are not
    # subclasses of Python's float/int (e.g. np.float32).
    if isinstance(threshold, numbers.Real):
        return threshold, threshold
    raise TypeError("Threshold must be either a list of two numbers or a single number.")
# Fixed 100-entry colour cycle used to tell boxes/lines apart in the plots.
# It contains some repeated and near-white entries; they are kept unchanged so
# that colour assignment by index stays stable.
_PALETTE = (
    '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
    '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
    '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
    '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
    '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
    '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
    '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
    '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
    '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
    '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
    '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
    '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
    '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
    '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
    '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
    '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
    '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
    '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
    '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
    '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072',
)


def color():
    """Return the plotting palette as a new list of 100 hex colour strings.

    A fresh list is built on every call, so callers may mutate it freely.
    """
    return list(_PALETTE)
def show_all(imgs, pred, threshold, save_path, show, pair_box_th=0.7, pair_line_th=0.9):
    """Plot boxes, lines and matched box/line pairs of one image in three panels.

    * Panel 1 ("Boxes"): boxes from ``pred[0]`` whose score is not below the
      box threshold.
    * Panel 2 ("Lines"): lines of image 0 from ``pred[-1]['wires']``, rescaled
      by 512/128, run through ``postprocess`` (third argument ``diag * 0.01``,
      ``diag`` being the image diagonal), whose score is not below the line
      threshold.
    * Panel 3 ("Boxes and Lines"): pairs from ``pred[0]['line']`` /
      ``pred[0]['line_score']`` (filled in by ``box_line_``). Boxes with an
      all-zero line are skipped. A pair is drawn when the box score is at
      least ``pair_box_th`` OR the line score is at least ``pair_line_th``.
      These are independent of ``threshold``; the defaults keep the previous
      fixed 0.7 / 0.9.

    :param imgs: one image tensor in CHW layout.
    :param pred: model output (see panels above).
    :param threshold: a number or ``[box_th, line_th]``, see ``set_thresholds``.
    :param save_path: falsy -> do not save. A str / ``os.PathLike`` is used as
        the base directory; any other truthy value (e.g. ``True``) means the
        current working directory. The figure is written to
        ``<base>/<YYYYmmdd_HHMMSS>/box_line.png``.
    :param show: truthy -> ``plt.show()``; otherwise the figure is closed so
        repeated calls do not accumulate open figures.
    :param pair_box_th: box-score threshold for panel 3.
    :param pair_line_th: line-score threshold for panel 3.
    """
    col = color()
    box_th, line_th = set_thresholds(threshold)
    im = imgs.permute(1, 2, 0).detach().cpu().numpy()
    boxes = pred[0]['boxes'].detach().cpu().numpy()
    box_scores = pred[0]['scores'].detach().cpu().numpy()
    raw_lines = pred[-1]['wires']['lines'][0].detach().cpu().numpy() / 128 * 512
    raw_scores = pred[-1]['wires']['score'].detach().cpu().numpy()[0]
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    lines, line_scores = postprocess(raw_lines, raw_scores, diag * 0.01, 0, False)

    fig, axs = plt.subplots(1, 3, figsize=(10, 10))

    # Panel 1: thresholded boxes (palette cycled to avoid IndexError).
    axs[0].imshow(im)
    for idx, ((x0, y0, x1, y1), score) in enumerate(zip(boxes, box_scores)):
        if score < box_th:
            continue
        axs[0].add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False,
                                       edgecolor=col[idx % len(col)], linewidth=1))
    axs[0].set_title('Boxes')

    # Panel 2: thresholded, post-processed lines; endpoints are (y, x).
    axs[1].imshow(im)
    for (a, b), score in zip(lines, line_scores):
        if score < line_th:
            continue
        axs[1].scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
        axs[1].plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
    axs[1].set_title('Lines')

    # Panel 3: matched box/line pairs.
    axs[2].imshow(im)
    pair_lines = pred[0]['line'].detach().cpu().numpy()
    pair_scores = pred[0]['line_score'].detach().cpu().numpy()
    color_idx = 0
    for box, pair, box_score, pair_score in zip(boxes, pair_lines, box_scores, pair_scores):
        # Boxes without a matched line carry an all-zero placeholder.
        if not pair.any():
            continue
        if not (box_score >= pair_box_th or pair_score >= pair_line_th):
            continue
        x0, y0, x1, y1 = box
        (ya, xa), (yb, xb) = pair
        edge = col[color_idx % len(col)]
        axs[2].add_patch(
            plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=edge, linewidth=1))
        axs[2].scatter([xa, xb], [ya, yb], c='#871F78', s=10)
        axs[2].plot([xa, xb], [ya, yb], c=edge, linewidth=1)
        color_idx += 1
    axs[2].set_title('Boxes and Lines')

    # Lay out before saving so the saved file matches what is shown.
    fig.tight_layout()
    if save_path:
        base = save_path if isinstance(save_path, (str, os.PathLike)) else ''
        out_file = os.path.join(base, datetime.now().strftime("%Y%m%d_%H%M%S"), 'box_line.png')
        os.makedirs(os.path.dirname(out_file), exist_ok=True)
        fig.savefig(out_file)
        print(f"Saved result image to {out_file}")
    if show:
        plt.show()
    else:
        plt.close(fig)
def show_box_or_line(imgs, pred, threshold, save_path=None, show_line=False, show_box=False):
    """Overlay thresholded boxes and/or raw lines of one image and show it.

    Unlike ``show_all``, lines are rescaled by 512/128 but NOT passed through
    ``postprocess``.

    :param imgs: one image tensor in CHW layout.
    :param pred: model output (``pred[0]`` boxes/scores, ``pred[-1]['wires']``).
    :param threshold: a number or ``[box_th, line_th]``, see ``set_thresholds``.
    :param save_path: falsy -> do not save. A str / ``os.PathLike`` is used as
        the base directory; any other truthy value (e.g. ``True``) means the
        current working directory. Snapshots go into a single
        ``<base>/<YYYYmmdd_HHMMSS>/`` directory: ``box.png`` right after the
        boxes are drawn and ``line.png`` right after the lines are drawn (so it
        also contains the boxes when both flags are set).
    :param show_line: draw lines whose score is not below the line threshold.
    :param show_box: draw boxes whose score is not below the box threshold.
    """
    col = color()
    box_th, line_th = set_thresholds(threshold)
    im = imgs.permute(1, 2, 0).detach().cpu().numpy()
    boxes = pred[0]['boxes'].detach().cpu().numpy()
    box_scores = pred[0]['scores'].detach().cpu().numpy()
    line = pred[-1]['wires']['lines'][0].detach().cpu().numpy() / 128 * 512
    line_score = pred[-1]['wires']['score'].detach().cpu().numpy()[0]

    # One timestamp for the whole call so box.png and line.png share a folder.
    out_dir = None
    if save_path:
        base = save_path if isinstance(save_path, (str, os.PathLike)) else ''
        out_dir = os.path.join(base, datetime.now().strftime("%Y%m%d_%H%M%S"))

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)

    def _snapshot(name):
        # Save the figure as it currently looks into out_dir/name.
        out_file = os.path.join(out_dir, name)
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(out_file)
        print(f"Saved result image to {out_file}")

    if show_box:
        for idx, ((x0, y0, x1, y1), score) in enumerate(zip(boxes, box_scores)):
            if score < box_th:
                continue
            # Cycle the palette so more boxes than colours cannot raise IndexError.
            ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False,
                                       edgecolor=col[idx % len(col)], linewidth=1))
        if out_dir is not None:
            _snapshot('box.png')
    if show_line:
        for (a, b), score in zip(line, line_score):
            if score < line_th:
                continue
            # Endpoints are (y, x); matplotlib expects (x, y).
            ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
            ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
        if out_dir is not None:
            _snapshot('line.png')
    plt.show()
|