|
|
@@ -1,3 +1,14 @@
|
|
|
+# import time
|
|
|
+# import torch
|
|
|
+# from PIL import Image
|
|
|
+# from torchvision import transforms
|
|
|
+# from skimage.transform import resize
|
|
|
+
|
|
|
+import time
|
|
|
+
|
|
|
+import cv2
|
|
|
+import skimage
|
|
|
+
|
|
|
import os
|
|
|
|
|
|
import torch
|
|
|
@@ -5,126 +16,356 @@ from PIL import Image
|
|
|
import matplotlib.pyplot as plt
|
|
|
import matplotlib as mpl
|
|
|
import numpy as np
|
|
|
-from models.line_detect.line_net import linenet_resnet50_fpn
|
|
|
+# from models.line_detect.line_net import linenet_resnet50_fpn
|
|
|
from torchvision import transforms
|
|
|
|
|
|
from models.wirenet.postprocess import postprocess
|
|
|
+from rtree import index
|
|
|
+
|
|
|
+from datetime import datetime
|
|
|
|
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
|
|
|
|
|
-def load_best_model(model, save_path, device):
|
|
|
- if os.path.exists(save_path):
|
|
|
- checkpoint = torch.load(save_path, map_location=device)
|
|
|
- model.load_state_dict(checkpoint['model_state_dict'])
|
|
|
- # if optimizer is not None:
|
|
|
- # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
|
|
- epoch = checkpoint['epoch']
|
|
|
- loss = checkpoint['loss']
|
|
|
- print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
|
|
|
+def box_line_(imgs, pred): # 默认置信度
|
|
|
+ im = imgs.permute(1, 2, 0).cpu().numpy()
|
|
|
+ lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
|
|
|
+ scores = pred[-1]['wires']['score'].cpu().numpy()[0]
|
|
|
+
|
|
|
+ # print(f'111:{len(lines)}')
|
|
|
+ for i in range(1, len(lines)):
|
|
|
+ if (lines[i] == lines[0]).all():
|
|
|
+ lines = lines[:i]
|
|
|
+ scores = scores[:i]
|
|
|
+ break
|
|
|
+ diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
|
|
|
+ line, score = postprocess(lines, scores, diag * 0.01, 0, False)
|
|
|
+ # print(f'333:{len(lines)}')
|
|
|
+ for idx, box_ in enumerate(pred[0:-1]):
|
|
|
+ box = box_['boxes'] # 是一个tensor
|
|
|
+
|
|
|
+ line_ = []
|
|
|
+ score_ = []
|
|
|
+
|
|
|
+ for i in box:
|
|
|
+ score_max = 0.0
|
|
|
+ tmp = [[0.0, 0.0], [0.0, 0.0]]
|
|
|
+
|
|
|
+ for j in range(len(line)):
|
|
|
+ if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
|
|
|
+ line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
|
|
|
+ line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
|
|
|
+ line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
|
|
|
+
|
|
|
+ if score[j] > score_max:
|
|
|
+ tmp = line[j]
|
|
|
+ score_max = score[j]
|
|
|
+ line_.append(tmp)
|
|
|
+ score_.append(score_max)
|
|
|
+ processed_list = torch.tensor(np.array(line_))
|
|
|
+ pred[idx]['line'] = processed_list
|
|
|
+
|
|
|
+ processed_s_list = torch.tensor(score_)
|
|
|
+ pred[idx]['line_score'] = processed_s_list
|
|
|
+ return pred
|
|
|
+
|
|
|
+
|
|
|
+def set_thresholds(threshold):
|
|
|
+ if isinstance(threshold, list):
|
|
|
+ if len(threshold) != 2:
|
|
|
+ raise ValueError("Threshold list must contain exactly two elements.")
|
|
|
+ a, b = threshold
|
|
|
+ elif isinstance(threshold, (int, float)):
|
|
|
+ a = b = threshold
|
|
|
else:
|
|
|
- print(f"No saved model found at {save_path}")
|
|
|
- return model
|
|
|
+ raise TypeError("Threshold must be either a list of two numbers or a single number.")
|
|
|
+
|
|
|
+ return a, b
|
|
|
+
|
|
|
+
|
|
|
+def color():
|
|
|
+ return [
|
|
|
+ '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
|
|
|
+ '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
|
|
|
+ '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
|
|
|
+ '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
|
|
|
+ '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
|
|
|
+ '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
|
|
|
+ '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
|
|
|
+ '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
|
|
|
+ '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
|
|
|
+ '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
|
|
|
+ '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
|
|
|
+ '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
|
|
|
+ '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
|
|
|
+ '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
|
|
|
+ '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
|
|
|
+ '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
|
|
|
+ '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
|
|
|
+ '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
|
|
|
+ '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
|
|
|
+ '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
|
|
|
+ ]
|
|
|
|
|
|
|
|
|
-cmap = plt.get_cmap("jet")
|
|
|
-norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
|
|
|
-sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
|
|
-sm.set_array([])
|
|
|
+def show_all(imgs, pred, threshold, save_path):
|
|
|
+ col = color()
|
|
|
+ box_th, line_th = set_thresholds(threshold)
|
|
|
+ im = imgs.permute(1, 2, 0)
|
|
|
+
|
|
|
+ boxes = pred[0]['boxes'].cpu().numpy()
|
|
|
+ box_scores = pred[0]['scores'].cpu().numpy()
|
|
|
+ lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
|
|
|
+ scores = pred[-1]['wires']['score'].cpu().numpy()[0]
|
|
|
+
|
|
|
+ for i in range(1, len(lines)):
|
|
|
+ if (lines[i] == lines[0]).all():
|
|
|
+ lines = lines[:i]
|
|
|
+ scores = scores[:i]
|
|
|
+ break
|
|
|
+ diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
|
|
|
+ line, line_score = postprocess(lines, scores, diag * 0.01, 0, False)
|
|
|
|
|
|
+ fig, axs = plt.subplots(1, 3, figsize=(10, 10))
|
|
|
|
|
|
-def c(x):
|
|
|
- return sm.to_rgba(x)
|
|
|
+ axs[0].imshow(np.array(im))
|
|
|
+ for idx, box in enumerate(boxes):
|
|
|
+ if box_scores[idx] < box_th:
|
|
|
+ continue
|
|
|
+ x0, y0, x1, y1 = box
|
|
|
+ axs[0].add_patch(
|
|
|
+ plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
|
|
|
+ axs[0].set_title('Boxes')
|
|
|
|
|
|
+ axs[1].imshow(np.array(im))
|
|
|
+ for idx, (a, b) in enumerate(line):
|
|
|
+ if line_score[idx] < line_th:
|
|
|
+ continue
|
|
|
+ axs[1].scatter(a[1], a[0], c='#871F78', s=2)
|
|
|
+ axs[1].scatter(b[1], b[0], c='#871F78', s=2)
|
|
|
+ axs[1].plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
|
|
|
+ axs[1].set_title('Lines')
|
|
|
|
|
|
-def imshow(im):
|
|
|
- plt.close()
|
|
|
+ axs[2].imshow(np.array(im))
|
|
|
+ lines = pred[0]['line'].cpu().numpy()
|
|
|
+ line_scores = pred[0]['line_score'].cpu().numpy()
|
|
|
+ idx = 0
|
|
|
+ tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
|
|
|
+ for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
|
|
|
+ x0, y0, x1, y1 = box
|
|
|
+ # 框中无线的跳过
|
|
|
+ if np.array_equal(line, tmp):
|
|
|
+ continue
|
|
|
+ a, b = line
|
|
|
+ if box_score >= 0.7 or line_score >= 0.9:
|
|
|
+ axs[2].add_patch(
|
|
|
+ plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
|
|
|
+ axs[2].scatter(a[1], a[0], c='#871F78', s=10)
|
|
|
+ axs[2].scatter(b[1], b[0], c='#871F78', s=10)
|
|
|
+ axs[2].plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
|
|
|
+ idx = idx + 1
|
|
|
+ axs[2].set_title('Boxes and Lines')
|
|
|
+
|
|
|
+ if save_path:
|
|
|
+ save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'box_line.png')
|
|
|
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
|
|
+
|
|
|
+ plt.savefig(save_path)
|
|
|
+ print(f"Saved result image to {save_path}")
|
|
|
+
|
|
|
+ # if show:
|
|
|
+ # 调整子图之间的距离,防止标题和标签重叠
|
|
|
plt.tight_layout()
|
|
|
- plt.imshow(im)
|
|
|
- plt.colorbar(sm, fraction=0.046)
|
|
|
- plt.xlim([0, im.shape[0]])
|
|
|
- plt.ylim([im.shape[0], 0])
|
|
|
+ plt.show()
|
|
|
|
|
|
|
|
|
-def show_line(img, pred):
|
|
|
- im = img.permute(1, 2, 0)
|
|
|
+def show_box_or_line(imgs, pred, threshold, save_path=None, show_line=False, show_box=False):
|
|
|
+ col = color()
|
|
|
+ box_th, line_th = set_thresholds(threshold)
|
|
|
+ im = imgs.permute(1, 2, 0)
|
|
|
|
|
|
- # 创建图形和坐标轴
|
|
|
+ # 可视化预测结
|
|
|
fig, ax = plt.subplots(figsize=(10, 10))
|
|
|
- # 绘制原始图像
|
|
|
ax.imshow(np.array(im))
|
|
|
- # 绘制边界框
|
|
|
- boxes = pred[0]['boxes'].cpu().numpy()
|
|
|
- boxes_scores = pred[0]['scores'].cpu().numpy()
|
|
|
|
|
|
- # for box in boxes:
|
|
|
- # x0, y0, x1, y1 = box
|
|
|
- # rect = plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor='yellow', linewidth=1)
|
|
|
- # ax.add_patch(rect) # 将矩形添加到 Axes 对象上
|
|
|
+ if show_box:
|
|
|
+ boxes = pred[0]['boxes'].cpu().numpy()
|
|
|
+ box_scores = pred[0]['scores'].cpu().numpy()
|
|
|
+ for idx, box in enumerate(boxes):
|
|
|
+ if box_scores[idx] < box_th:
|
|
|
+ continue
|
|
|
+ x0, y0, x1, y1 = box
|
|
|
+ ax.add_patch(
|
|
|
+ plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
|
|
|
+ if save_path:
|
|
|
+ save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'box.png')
|
|
|
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
|
|
+
|
|
|
+ plt.savefig(save_path)
|
|
|
+ print(f"Saved result image to {save_path}")
|
|
|
+
|
|
|
+ if show_line:
|
|
|
+ lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
|
|
|
+ scores = pred[-1]['wires']['score'].cpu().numpy()[0]
|
|
|
+
|
|
|
+ for i in range(1, len(lines)):
|
|
|
+ if (lines[i] == lines[0]).all():
|
|
|
+ lines = lines[:i]
|
|
|
+ scores = scores[:i]
|
|
|
+ break
|
|
|
+ diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
|
|
|
+ line, line_score = postprocess(lines, scores, diag * 0.01, 0, False)
|
|
|
+
|
|
|
+ for idx, (a, b) in enumerate(line):
|
|
|
+ if line_score[idx] < line_th:
|
|
|
+ continue
|
|
|
+ ax.scatter(a[1], a[0], c='#871F78', s=2)
|
|
|
+ ax.scatter(b[1], b[0], c='#871F78', s=2)
|
|
|
+ ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
|
|
|
+ if save_path:
|
|
|
+ save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'line.png')
|
|
|
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
|
|
+
|
|
|
+ plt.savefig(save_path)
|
|
|
+ print(f"Saved result image to {save_path}")
|
|
|
+
|
|
|
+ plt.show()
|
|
|
+
|
|
|
+
|
|
|
+def show_predict(imgs, pred, threshold, t_start):
|
|
|
+ col = color()
|
|
|
+ box_th, line_th = set_thresholds(threshold)
|
|
|
+ im = imgs.permute(1, 2, 0) # 处理为 [512, 512, 3]
|
|
|
+ boxes = pred[0]['boxes'].cpu().numpy()
|
|
|
+ box_scores = pred[0]['scores'].cpu().numpy()
|
|
|
+ lines = pred[0]['line'].cpu().numpy()
|
|
|
+ scores = pred[0]['line_score'].cpu().numpy()
|
|
|
|
|
|
- for b, s in zip(boxes, boxes_scores):
|
|
|
- # print(f'box:{b}, box_score:{s}')
|
|
|
- if s < 0.7:
|
|
|
- continue
|
|
|
- x0, y0, x1, y1 = b
|
|
|
- rect = plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor='yellow', linewidth=1)
|
|
|
- ax.add_patch(rect) # 将矩形添加到 Axes 对象上
|
|
|
-
|
|
|
- PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
|
|
|
- H = pred[-1]['wires']
|
|
|
- lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
|
|
|
- scores = H["score"][0].cpu().numpy()
|
|
|
for i in range(1, len(lines)):
|
|
|
if (lines[i] == lines[0]).all():
|
|
|
lines = lines[:i]
|
|
|
scores = scores[:i]
|
|
|
break
|
|
|
-
|
|
|
- # 后处理线条以去除重叠的线条
|
|
|
diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
|
|
|
- nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
|
|
|
+ line1, line_score1 = postprocess(lines, scores, diag * 0.01, 0, False)
|
|
|
+
|
|
|
+ # 可视化预测结
|
|
|
+ fig, ax = plt.subplots(figsize=(10, 10))
|
|
|
+ ax.imshow(np.array(im))
|
|
|
+ idx = 0
|
|
|
+
|
|
|
+ tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
|
|
|
+
|
|
|
+ for box, line, box_score, line_score in zip(boxes, line1, box_scores, line_score1):
|
|
|
+ x0, y0, x1, y1 = box
|
|
|
+ # 框中无线的跳过
|
|
|
+ if np.array_equal(line, tmp):
|
|
|
+ continue
|
|
|
+ a, b = line
|
|
|
+ if box_score >= box_th or line_score >= line_th:
|
|
|
+ ax.add_patch(
|
|
|
+ plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
|
|
|
+ ax.scatter(a[1], a[0], c='#871F78', s=10)
|
|
|
+ ax.scatter(b[1], b[0], c='#871F78', s=10)
|
|
|
+ ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
|
|
|
+ idx = idx + 1
|
|
|
+ t_end = time.time()
|
|
|
+ print(f'predict used:{t_end - t_start}')
|
|
|
|
|
|
- # 根据分数绘制线条
|
|
|
- for i, t in enumerate([0.9]):
|
|
|
- for (a, b), s in zip(nlines, nscores):
|
|
|
- if s < t:
|
|
|
- continue
|
|
|
- ax.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s) # 在 Axes 上绘制线条
|
|
|
- ax.scatter(a[1], a[0], **PLTOPTS) # 在 Axes 上绘制散点
|
|
|
- ax.scatter(b[1], b[0], **PLTOPTS) # 在 Axes 上绘制散点
|
|
|
-
|
|
|
- # 隐藏坐标轴
|
|
|
- ax.set_axis_off()
|
|
|
- plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
|
|
|
- plt.margins(0, 0)
|
|
|
- ax.xaxis.set_major_locator(plt.NullLocator())
|
|
|
- ax.yaxis.set_major_locator(plt.NullLocator())
|
|
|
-
|
|
|
- # 显示图像
|
|
|
plt.show()
|
|
|
|
|
|
|
|
|
-def predict(pt_path, model, img):
|
|
|
- model = load_best_model(model, pt_path, device)
|
|
|
+class Predict:
|
|
|
+ def __init__(self, pt_path, model, img, type=0, threshold=0.5, save_path=None, show_line=False, show_box=False):
|
|
|
+ """
|
|
|
+ 初始化预测器。
|
|
|
+
|
|
|
+ 参数:
|
|
|
+ pt_path: 模型权重文件路径。
|
|
|
+ model: 模型定义(未加载权重)。
|
|
|
+ img: 输入图像(路径或 PIL 图像对象)。
|
|
|
+ type: 预测类型(0: 全部显示,线图、框图、线框匹配图,1: 显示线图,2: 显示框图,3: 线框匹配图)。
|
|
|
+ threshold: 阈值,用于过滤预测结果。
|
|
|
+ save_path: 保存结果的路径(可选)。
|
|
|
+ show: 是否显示结果。
|
|
|
+ device: 运行设备(默认 'cuda')。
|
|
|
+ """
|
|
|
+ self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
|
|
|
+ self.model = model
|
|
|
+ self.pt_path = pt_path
|
|
|
+ self.img = self.load_image(img)
|
|
|
+ self.type = type
|
|
|
+ self.threshold = threshold
|
|
|
+ self.save_path = save_path
|
|
|
+ self.show_line = show_line
|
|
|
+ self.show_box = show_box
|
|
|
+
|
|
|
+ def load_best_model(self, model, save_path, device):
|
|
|
+ if os.path.exists(save_path):
|
|
|
+ checkpoint = torch.load(save_path, map_location=device)
|
|
|
+ model.load_state_dict(checkpoint['model_state_dict'])
|
|
|
+ # if optimizer is not None:
|
|
|
+ # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
|
|
+ # epoch = checkpoint['epoch']
|
|
|
+ # loss = checkpoint['loss']
|
|
|
+ # print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
|
|
|
+ else:
|
|
|
+ print(f"No saved model found at {save_path}")
|
|
|
+ return model
|
|
|
+
|
|
|
+ def load_image(self, img):
|
|
|
+ """加载图像"""
|
|
|
+ if isinstance(img, str):
|
|
|
+ img = Image.open(img).convert("RGB")
|
|
|
+ return img
|
|
|
+
|
|
|
+ def preprocess_image(self, img):
|
|
|
+ """预处理图像"""
|
|
|
+ transform = transforms.ToTensor()
|
|
|
+ img_tensor = transform(img) # [3, H, W]
|
|
|
+
|
|
|
+ # 调整大小为 512x512
|
|
|
+ t_start = time.time()
|
|
|
+ im = img_tensor.permute(1, 2, 0) # [H, W, 3]
|
|
|
+ # im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512)) # (512, 512, 3)
|
|
|
+ if im.shape != (512, 512, 3):
|
|
|
+ im_resized = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_LINEAR)
|
|
|
+ img_ = torch.tensor(im_resized).permute(2, 0, 1) # [3, 512, 512]
|
|
|
+ t_end = time.time()
|
|
|
+ print(f"Image preprocessing used: {t_end - t_start:.4f} seconds")
|
|
|
+
|
|
|
+ return img_
|
|
|
|
|
|
- model.eval()
|
|
|
+ def predict(self):
|
|
|
+ """执行预测"""
|
|
|
+ model = self.load_best_model(self.model, self.pt_path, device)
|
|
|
|
|
|
- if isinstance(img, str):
|
|
|
- img = Image.open(img).convert("RGB")
|
|
|
+ model.eval()
|
|
|
|
|
|
- transform = transforms.ToTensor()
|
|
|
- img_tensor = transform(img)
|
|
|
+ # 预处理图像
|
|
|
+ img_ = self.preprocess_image(self.img)
|
|
|
|
|
|
- with torch.no_grad():
|
|
|
- predictions = model([img_tensor])
|
|
|
- print(predictions[0])
|
|
|
+ # 模型推理
|
|
|
+ with torch.no_grad():
|
|
|
+ predictions = model([img_.to(self.device)])
|
|
|
+ print("Model predictions completed.")
|
|
|
|
|
|
- show_line(img_tensor, predictions)
|
|
|
+ # 后处理
|
|
|
+ t_start = time.time()
|
|
|
+ pred = box_line_(img_, predictions) # 线框匹配
|
|
|
+ t_end = time.time()
|
|
|
+ print(f"Matched boxes and lines used: {t_end - t_start:.4f} seconds")
|
|
|
|
|
|
+ # 根据类型显示或保存结果
|
|
|
+ if self.type == 0:
|
|
|
+ show_all(img_, pred, self.threshold, save_path=self.save_path)
|
|
|
+ elif self.type == 1:
|
|
|
+ show_box_or_line(img_, predictions, self.threshold, save_path=self.save_path, show_line=True)
|
|
|
+ elif self.type == 2:
|
|
|
+ show_box_or_line(img_, predictions, self.threshold, save_path=self.save_path, show_box=True)
|
|
|
+ elif self.type == 3:
|
|
|
+ show_predict(img_, pred, self.threshold, t_start)
|
|
|
|
|
|
-if __name__ == '__main__':
|
|
|
- model = linenet_resnet50_fpn()
|
|
|
- pt_path = r'D:\python\PycharmProjects\lcnn-master\lcnn_\20250212\MultiVisionModels\models\line_detect\linenet_wts\resnet50_best_e8.pth'
|
|
|
- # img_path = f'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-41-38_SaveImage.png' # 工件图
|
|
|
- img_path = f'D:\python\PycharmProjects\data\images/train/00558656_3.png' # wireframe图
|
|
|
- predict(pt_path, model, img_path)
|
|
|
+ def run(self):
|
|
|
+ """运行预测流程"""
|
|
|
+ self.predict()
|