|
@@ -1,8 +1,12 @@
|
|
|
+import time
|
|
|
+
|
|
|
import torch
|
|
|
import matplotlib.pyplot as plt
|
|
|
import numpy as np
|
|
|
from torchvision import transforms
|
|
|
|
|
|
+from models.wirenet.postprocess import postprocess
|
|
|
+
|
|
|
|
|
|
def box_line(pred):
|
|
|
'''
|
|
@@ -34,27 +38,33 @@ def box_line(pred):
|
|
|
|
|
|
|
|
|
def box_line_(pred):
|
|
|
- '''
|
|
|
- 形式同pred
|
|
|
- '''
|
|
|
for idx, box_ in enumerate(pred[0:-1]):
|
|
|
box = box_['boxes'] # 是一个tensor
|
|
|
line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
|
|
|
score = pred[-1]['wires']['score'][idx]
|
|
|
line_ = []
|
|
|
+ score_ = []
|
|
|
+
|
|
|
for i in box:
|
|
|
score_max = 0.0
|
|
|
tmp = [[0.0, 0.0], [0.0, 0.0]]
|
|
|
+
|
|
|
for j in range(len(line)):
|
|
|
- if (line[j][0][0] >= i[0] and line[j][1][0] >= i[0] and line[j][0][0] <= i[2] and
|
|
|
- line[j][1][0] <= i[2] and line[j][0][1] >= i[1] and line[j][1][1] >= i[1] and
|
|
|
- line[j][0][1] <= i[3] and line[j][1][1] <= i[3]):
|
|
|
+ if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
|
|
|
+ line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
|
|
|
+ line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
|
|
|
+ line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
|
|
|
+
|
|
|
if score[j] > score_max:
|
|
|
tmp = line[j]
|
|
|
score_max = score[j]
|
|
|
line_.append(tmp)
|
|
|
+ score_.append(score_max)
|
|
|
processed_list = torch.tensor(line_)
|
|
|
pred[idx]['line'] = processed_list
|
|
|
+
|
|
|
+ processed_s_list = torch.tensor(score_)
|
|
|
+ pred[idx]['line_score'] = processed_s_list
|
|
|
return pred
|
|
|
|
|
|
|
|
@@ -81,7 +91,7 @@ def show_(imgs, pred, epoch, writer):
|
|
|
'#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
|
|
|
'#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
|
|
|
]
|
|
|
- print(len(col))
|
|
|
+ # print(len(col))
|
|
|
im = imgs[0].permute(1, 2, 0)
|
|
|
boxes = pred[0]['boxes'].cpu().numpy()
|
|
|
line = pred[0]['line'].cpu().numpy()
|
|
@@ -96,9 +106,9 @@ def show_(imgs, pred, epoch, writer):
|
|
|
plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
|
|
|
|
|
|
for idx, (a, b) in enumerate(line):
|
|
|
- ax.scatter(a[0], a[1], c=col[99 - idx], s=2)
|
|
|
- ax.scatter(b[0], b[1], c=col[99 - idx], s=2)
|
|
|
- ax.plot([a[0], b[0]], [a[1], b[1]], c=col[idx], linewidth=1)
|
|
|
+ ax.scatter(a[1], a[0], c='#871F78', s=2)
|
|
|
+ ax.scatter(b[1], b[0], c='#871F78', s=2)
|
|
|
+ ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
|
|
|
|
|
|
# 将Matplotlib图像转换为Tensor
|
|
|
fig.canvas.draw()
|
|
@@ -109,3 +119,53 @@ def show_(imgs, pred, epoch, writer):
|
|
|
|
|
|
writer.add_image("all", img2, epoch)
|
|
|
|
|
|
+
|
|
|
+def show_predict(imgs, pred, t_start):
|
|
|
+ col = [
|
|
|
+ '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
|
|
|
+ '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
|
|
|
+ '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
|
|
|
+ '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
|
|
|
+ '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
|
|
|
+ '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
|
|
|
+ '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
|
|
|
+ '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
|
|
|
+ '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
|
|
|
+ '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
|
|
|
+ '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
|
|
|
+ '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
|
|
|
+ '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
|
|
|
+ '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
|
|
|
+ '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
|
|
|
+ '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
|
|
|
+ '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
|
|
|
+ '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
|
|
|
+ '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
|
|
|
+ '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
|
|
|
+ ]
|
|
|
+ print(len(col))
|
|
|
+ im = imgs.permute(1, 2, 0)
|
|
|
+ boxes = pred[0]['boxes'].cpu().numpy()
|
|
|
+ box_scores = pred[0]['scores'].cpu().numpy()
|
|
|
+ lines = pred[0]['line'].cpu().numpy()
|
|
|
+ line_scores = pred[0]['line_score'].cpu().numpy()
|
|
|
+
|
|
|
+ # 可视化预测结
|
|
|
+ fig, ax = plt.subplots(figsize=(10, 10))
|
|
|
+ ax.imshow(np.array(im))
|
|
|
+ idx = 0
|
|
|
+
|
|
|
+ for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
|
|
|
+ x0, y0, x1, y1 = box
|
|
|
+ a, b = line
|
|
|
+ if box_score > 0.7 and line_score > 0.9:
|
|
|
+ ax.add_patch(
|
|
|
+ plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
|
|
|
+ ax.scatter(a[1], a[0], c='#871F78', s=2)
|
|
|
+ ax.scatter(b[1], b[0], c='#871F78', s=2)
|
|
|
+ ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
|
|
|
+ idx = idx + 1
|
|
|
+ t_end = time.time()
|
|
|
+ print(f'predict used:{t_end - t_start}')
|
|
|
+
|
|
|
+ plt.show()
|