|
@@ -1,3 +1,4 @@
|
|
|
|
+import os
|
|
import time
|
|
import time
|
|
|
|
|
|
import torch
|
|
import torch
|
|
@@ -7,6 +8,8 @@ from torchvision import transforms
|
|
|
|
|
|
from models.wirenet.postprocess import postprocess
|
|
from models.wirenet.postprocess import postprocess
|
|
|
|
|
|
|
|
+from datetime import datetime
|
|
|
|
+
|
|
|
|
|
|
def box_line(pred):
|
|
def box_line(pred):
|
|
'''
|
|
'''
|
|
@@ -188,7 +191,8 @@ def show_line(imgs, pred, t_start):
|
|
|
|
|
|
diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
|
|
diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
|
|
line, line_score = postprocess(line, line_score, diag * 0.01, 0, False)
|
|
line, line_score = postprocess(line, line_score, diag * 0.01, 0, False)
|
|
-
|
|
|
|
|
|
+ print(f'lines num:{len(line)}')
|
|
|
|
+ #
|
|
# count = np.sum(line_score > 0.9)
|
|
# count = np.sum(line_score > 0.9)
|
|
# print(f'draw line number:{count}')
|
|
# print(f'draw line number:{count}')
|
|
|
|
|
|
@@ -197,8 +201,8 @@ def show_line(imgs, pred, t_start):
|
|
ax.imshow(np.array(im))
|
|
ax.imshow(np.array(im))
|
|
|
|
|
|
for idx, (a, b) in enumerate(line):
|
|
for idx, (a, b) in enumerate(line):
|
|
- if line_score[idx] < 0.9:
|
|
|
|
- continue
|
|
|
|
|
|
+ # if line_score[idx] < 0.7:
|
|
|
|
+ # continue
|
|
ax.scatter(a[1], a[0], c='#871F78', s=2)
|
|
ax.scatter(a[1], a[0], c='#871F78', s=2)
|
|
ax.scatter(b[1], b[0], c='#871F78', s=2)
|
|
ax.scatter(b[1], b[0], c='#871F78', s=2)
|
|
ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
|
|
ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
|
|
@@ -281,7 +285,55 @@ def show_box(imgs, pred, t_start):
|
|
|
|
|
|
|
|
|
|
# 将show_line与show_box合并,传入参数确定显示框还是线 都不显示,输出原图
|
|
# 将show_line与show_box合并,传入参数确定显示框还是线 都不显示,输出原图
|
|
-def show_box_or_line(imgs, pred, show_line=False, show_box=False):
|
|
|
|
|
|
+# def show_box_or_line(imgs, pred, show_line=False, show_box=False):
|
|
|
|
+# col = [
|
|
|
|
+# '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
|
|
|
|
+# '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
|
|
|
|
+# '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
|
|
|
|
+# '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
|
|
|
|
+# '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
|
|
|
|
+# '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
|
|
|
|
+# '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
|
|
|
|
+# '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
|
|
|
|
+# '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
|
|
|
|
+# '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
|
|
|
|
+# '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
|
|
|
|
+# '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
|
|
|
|
+# '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
|
|
|
|
+# '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
|
|
|
|
+# '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
|
|
|
|
+# '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
|
|
|
|
+# '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
|
|
|
|
+# '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
|
|
|
|
+# '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
|
|
|
|
+# '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
|
|
|
|
+# ]
|
|
|
|
+# # print(len(col))
|
|
|
|
+# im = imgs.permute(1, 2, 0)
|
|
|
|
+# boxes = pred[0]['boxes'].cpu().numpy()
|
|
|
|
+# line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
|
|
|
|
+#
|
|
|
|
+# # 可视化预测结
|
|
|
|
+# fig, ax = plt.subplots(figsize=(10, 10))
|
|
|
|
+# ax.imshow(np.array(im))
|
|
|
|
+#
|
|
|
|
+# if show_box:
|
|
|
|
+# for idx, box in enumerate(boxes):
|
|
|
|
+# x0, y0, x1, y1 = box
|
|
|
|
+# ax.add_patch(
|
|
|
|
+# plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
|
|
|
|
+#
|
|
|
|
+# if show_line:
|
|
|
|
+# for idx, (a, b) in enumerate(line):
|
|
|
|
+# ax.scatter(a[1], a[0], c='#871F78', s=2)
|
|
|
|
+# ax.scatter(b[1], b[0], c='#871F78', s=2)
|
|
|
|
+# ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
|
|
|
|
+#
|
|
|
|
+# plt.show()
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+# 将show_line与show_box合并,传入参数确定显示框还是线 一起画
|
|
|
|
+def show_box_and_line(imgs, pred, show_line=False, show_box=False):
|
|
col = [
|
|
col = [
|
|
'#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
|
|
'#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
|
|
'#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
|
|
'#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
|
|
@@ -310,27 +362,43 @@ def show_box_or_line(imgs, pred, show_line=False, show_box=False):
|
|
line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
|
|
line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
|
|
|
|
|
|
# 可视化预测结
|
|
# 可视化预测结
|
|
- fig, ax = plt.subplots(figsize=(10, 10))
|
|
|
|
- ax.imshow(np.array(im))
|
|
|
|
|
|
+ fig, axs = plt.subplots(1, 2, figsize=(10, 10))
|
|
|
|
|
|
if show_box:
|
|
if show_box:
|
|
|
|
+ axs[0].imshow(np.array(im))
|
|
for idx, box in enumerate(boxes):
|
|
for idx, box in enumerate(boxes):
|
|
x0, y0, x1, y1 = box
|
|
x0, y0, x1, y1 = box
|
|
- ax.add_patch(
|
|
|
|
|
|
+ axs[0].add_patch(
|
|
plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
|
|
plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
|
|
|
|
+ axs[0].set_title('Boxes')
|
|
|
|
|
|
if show_line:
|
|
if show_line:
|
|
|
|
+ axs[1].imshow(np.array(im))
|
|
for idx, (a, b) in enumerate(line):
|
|
for idx, (a, b) in enumerate(line):
|
|
- ax.scatter(a[1], a[0], c='#871F78', s=2)
|
|
|
|
- ax.scatter(b[1], b[0], c='#871F78', s=2)
|
|
|
|
- ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
|
|
|
|
|
|
+ axs[1].scatter(a[1], a[0], c='#871F78', s=2)
|
|
|
|
+ axs[1].scatter(b[1], b[0], c='#871F78', s=2)
|
|
|
|
+ axs[1].plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
|
|
|
|
+ axs[1].set_title('Lines')
|
|
|
|
|
|
|
|
+ # 调整子图之间的距离,防止标题和标签重叠
|
|
|
|
+ plt.tight_layout()
|
|
plt.show()
|
|
plt.show()
|
|
|
|
|
|
|
|
|
|
-# 将show_line与show_box合并,传入参数确定显示框还是线 一起画
|
|
|
|
-def show_box_and_line(imgs, pred, show_line=False, show_box=False):
|
|
|
|
- col = [
|
|
|
|
|
|
+def set_thresholds(threshold):
|
|
|
|
+ if isinstance(threshold, list):
|
|
|
|
+ if len(threshold) != 2:
|
|
|
|
+ raise ValueError("Threshold list must contain exactly two elements.")
|
|
|
|
+ a, b = threshold
|
|
|
|
+ elif isinstance(threshold, (int, float)):
|
|
|
|
+ a = b = threshold
|
|
|
|
+ else:
|
|
|
|
+ raise TypeError("Threshold must be either a list of two numbers or a single number.")
|
|
|
|
+
|
|
|
|
+ return a, b
|
|
|
|
+
|
|
|
|
+def color():
|
|
|
|
+ return [
|
|
'#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
|
|
'#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
|
|
'#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
|
|
'#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
|
|
'#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
|
|
'#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
|
|
@@ -352,30 +420,115 @@ def show_box_and_line(imgs, pred, show_line=False, show_box=False):
|
|
'#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
|
|
'#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
|
|
'#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
|
|
'#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
|
|
]
|
|
]
|
|
- # print(len(col))
|
|
|
|
|
|
+
|
|
|
|
+def show_all(imgs, pred, threshold, save_path, show):
|
|
|
|
+ col = color()
|
|
|
|
+ box_th, line_th = set_thresholds(threshold)
|
|
im = imgs.permute(1, 2, 0)
|
|
im = imgs.permute(1, 2, 0)
|
|
|
|
+
|
|
boxes = pred[0]['boxes'].cpu().numpy()
|
|
boxes = pred[0]['boxes'].cpu().numpy()
|
|
|
|
+ box_scores = pred[0]['scores'].cpu().numpy()
|
|
line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
|
|
line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
|
|
|
|
+ line_score = pred[-1]['wires']['score'].cpu().numpy()[0]
|
|
|
|
+
|
|
|
|
+ diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
|
|
|
|
+ line, line_score = postprocess(line, line_score, diag * 0.01, 0, False)
|
|
|
|
+
|
|
|
|
+ fig, axs = plt.subplots(1, 3, figsize=(10, 10))
|
|
|
|
+
|
|
|
|
+ axs[0].imshow(np.array(im))
|
|
|
|
+ for idx, box in enumerate(boxes):
|
|
|
|
+ if box_scores[idx] < box_th:
|
|
|
|
+ continue
|
|
|
|
+ x0, y0, x1, y1 = box
|
|
|
|
+ axs[0].add_patch(
|
|
|
|
+ plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
|
|
|
|
+ axs[0].set_title('Boxes')
|
|
|
|
+
|
|
|
|
+ axs[1].imshow(np.array(im))
|
|
|
|
+ for idx, (a, b) in enumerate(line):
|
|
|
|
+ if line_score[idx] < line_th:
|
|
|
|
+ continue
|
|
|
|
+ axs[1].scatter(a[1], a[0], c='#871F78', s=2)
|
|
|
|
+ axs[1].scatter(b[1], b[0], c='#871F78', s=2)
|
|
|
|
+ axs[1].plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
|
|
|
|
+ axs[1].set_title('Lines')
|
|
|
|
+
|
|
|
|
+ axs[2].imshow(np.array(im))
|
|
|
|
+ lines = pred[0]['line'].cpu().numpy()
|
|
|
|
+ line_scores = pred[0]['line_score'].cpu().numpy()
|
|
|
|
+ idx = 0
|
|
|
|
+ tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
|
|
|
|
+ for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
|
|
|
|
+ x0, y0, x1, y1 = box
|
|
|
|
+ # 框中无线的跳过
|
|
|
|
+ if np.array_equal(line, tmp):
|
|
|
|
+ continue
|
|
|
|
+ a, b = line
|
|
|
|
+ if box_score >= 0.7 or line_score >= 0.9:
|
|
|
|
+ axs[2].add_patch(
|
|
|
|
+ plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
|
|
|
|
+ axs[2].scatter(a[1], a[0], c='#871F78', s=10)
|
|
|
|
+ axs[2].scatter(b[1], b[0], c='#871F78', s=10)
|
|
|
|
+ axs[2].plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
|
|
|
|
+ idx = idx + 1
|
|
|
|
+ axs[2].set_title('Boxes and Lines')
|
|
|
|
+
|
|
|
|
+ if save_path:
|
|
|
|
+ save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'box_line.png')
|
|
|
|
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
|
|
|
+
|
|
|
|
+ plt.savefig(save_path)
|
|
|
|
+ print(f"Saved result image to {save_path}")
|
|
|
|
+
|
|
|
|
+ if show:
|
|
|
|
+ # 调整子图之间的距离,防止标题和标签重叠
|
|
|
|
+ plt.tight_layout()
|
|
|
|
+ plt.show()
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def show_box_or_line(imgs, pred, threshold, save_path = None, show_line=False, show_box=False):
|
|
|
|
+ col = color()
|
|
|
|
+ box_th, line_th = set_thresholds(threshold)
|
|
|
|
+ im = imgs.permute(1, 2, 0)
|
|
|
|
+
|
|
|
|
+ boxes = pred[0]['boxes'].cpu().numpy()
|
|
|
|
+ box_scores = pred[0]['scores'].cpu().numpy()
|
|
|
|
+ line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
|
|
|
|
+ line_score = pred[-1]['wires']['score'].cpu().numpy()[0]
|
|
|
|
|
|
# 可视化预测结
|
|
# 可视化预测结
|
|
- fig, axs = plt.subplots(1, 2, figsize=(10, 10))
|
|
|
|
|
|
+ fig, ax = plt.subplots(figsize=(10, 10))
|
|
|
|
+ ax.imshow(np.array(im))
|
|
|
|
|
|
if show_box:
|
|
if show_box:
|
|
- axs[0].imshow(np.array(im))
|
|
|
|
for idx, box in enumerate(boxes):
|
|
for idx, box in enumerate(boxes):
|
|
|
|
+ if box_scores[idx] < box_th:
|
|
|
|
+ continue
|
|
x0, y0, x1, y1 = box
|
|
x0, y0, x1, y1 = box
|
|
- axs[0].add_patch(
|
|
|
|
|
|
+ ax.add_patch(
|
|
plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
|
|
plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
|
|
- axs[0].set_title('Boxes')
|
|
|
|
|
|
+ if save_path:
|
|
|
|
+ save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'box.png')
|
|
|
|
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
|
|
|
+
|
|
|
|
+ plt.savefig(save_path)
|
|
|
|
+ print(f"Saved result image to {save_path}")
|
|
|
|
+
|
|
|
|
|
|
if show_line:
|
|
if show_line:
|
|
- axs[1].imshow(np.array(im))
|
|
|
|
for idx, (a, b) in enumerate(line):
|
|
for idx, (a, b) in enumerate(line):
|
|
- axs[1].scatter(a[1], a[0], c='#871F78', s=2)
|
|
|
|
- axs[1].scatter(b[1], b[0], c='#871F78', s=2)
|
|
|
|
- axs[1].plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
|
|
|
|
- axs[1].set_title('Lines')
|
|
|
|
|
|
+ if line_score[idx] < line_th:
|
|
|
|
+ continue
|
|
|
|
+ ax.scatter(a[1], a[0], c='#871F78', s=2)
|
|
|
|
+ ax.scatter(b[1], b[0], c='#871F78', s=2)
|
|
|
|
+ ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
|
|
|
|
+ if save_path:
|
|
|
|
+ save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'line.png')
|
|
|
|
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
|
|
|
|
|
- # 调整子图之间的距离,防止标题和标签重叠
|
|
|
|
- plt.tight_layout()
|
|
|
|
- plt.show()
|
|
|
|
|
|
+ plt.savefig(save_path)
|
|
|
|
+ print(f"Saved result image to {save_path}")
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+ plt.show()
|