"""Inference and visualisation helpers for the box + wire (line) detector.

The module matches each detected box with the best-scoring wire segment that
lies inside it (``box_line_``), offers several matplotlib views of the result
(``show_all``, ``show_box_or_line``, ``show_predict``) and wraps the whole
flow in the ``Predict`` / ``Predict1`` runner classes.
"""
import os
import time
from datetime import datetime

import cv2
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import skimage
import torch
from PIL import Image
from rtree import index
from torchvision import transforms

from models.wirenet.postprocess import postprocess

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Thresholds used by the third panel of ``show_all``.
# NOTE(review): this panel ignores the caller's ``threshold`` while
# ``show_predict`` honours it — confirm whether that is intentional.
_MATCH_BOX_TH = 0.7
_MATCH_LINE_TH = 0.9

# Colour used for line end points in every view.
_ENDPOINT_COLOR = '#871F78'


def _strip_padding(lines, scores):
    """Cut ``lines``/``scores`` at the first repetition of ``lines[0]``.

    Presumably the network pads its fixed-size output by repeating the first
    segment — TODO confirm against the model code.
    """
    for i in range(1, len(lines)):
        if (lines[i] == lines[0]).all():
            return lines[:i], scores[:i]
    return lines, scores


def _decode_wires(pred, height, width):
    """Return the post-processed ``(lines, scores)`` from the wire head.

    Reads ``pred[-1]['wires']``, rescales the coordinates by ``/ 128 * 512``
    (assumes a 128-pixel wire map and a 512-pixel input — TODO confirm),
    drops the padded tail and runs the project ``postprocess`` with a
    distance threshold of 1% of the image diagonal.
    """
    lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
    scores = pred[-1]['wires']['score'].cpu().numpy()[0]
    lines, scores = _strip_padding(lines, scores)
    diag = (height ** 2 + width ** 2) ** 0.5
    line, score = postprocess(lines, scores, diag * 0.01, 0, False)
    return line, score


def box_line_(imgs, pred):
    """Attach the best wire segment inside each box to the predictions.

    For every entry of ``pred[0:-2]`` and every box in its ``'boxes'``, the
    highest-scoring segment whose two end points both lie inside the box is
    selected.  Boxes without such a segment get the all-zero placeholder
    ``[[0, 0], [0, 0]]`` and a score of ``0.0``.

    Args:
        imgs: Image tensor laid out as [C, H, W]; only its size is used.
        pred: Model output.  ``pred[-1]['wires']`` holds the raw lines and
            scores; each entry of ``pred[0:-2]`` holds a ``'boxes'`` tensor
            of ``(x0, y0, x1, y1)`` rows.

    Returns:
        The same ``pred`` object, with ``'line'`` and ``'line_score'``
        tensors added to each processed entry (modified in place).
    """
    im = imgs.permute(1, 2, 0).cpu().numpy()
    line, score = _decode_wires(pred, im.shape[0], im.shape[1])

    for idx, box_ in enumerate(pred[0:-2]):
        # Convert once to numpy so the comparisons below do not depend on
        # which device the boxes live on.
        boxes = box_['boxes'].detach().cpu().numpy()
        line_ = []
        score_ = []
        for x0, y0, x1, y1 in boxes:
            score_max = 0.0
            best = [[0.0, 0.0], [0.0, 0.0]]
            for cand, cand_score in zip(line, score):
                # Segment end points are stored as (y, x).
                (ay, ax), (by, bx) = cand
                inside = (x0 <= ax <= x1 and x0 <= bx <= x1
                          and y0 <= ay <= y1 and y0 <= by <= y1)
                if inside and cand_score > score_max:
                    best = cand
                    score_max = cand_score
            line_.append(best)
            score_.append(score_max)
        # reshape keeps an empty result shaped (0, 2, 2) instead of (0,).
        pred[idx]['line'] = torch.tensor(np.array(line_).reshape(-1, 2, 2))
        pred[idx]['line_score'] = torch.tensor(score_)
    return pred


def set_thresholds(threshold):
    """Split ``threshold`` into a ``(box_threshold, line_threshold)`` pair.

    Args:
        threshold: Either a single number used for both thresholds, or a
            list/tuple of exactly two numbers ``[box, line]``.

    Returns:
        Tuple ``(box_threshold, line_threshold)``.

    Raises:
        ValueError: If a list/tuple does not contain exactly two elements.
        TypeError: If ``threshold`` is neither a number nor a list/tuple.
    """
    if isinstance(threshold, (list, tuple)):
        if len(threshold) != 2:
            raise ValueError("Threshold list must contain exactly two elements.")
        a, b = threshold
    elif isinstance(threshold, (int, float)):
        a = b = threshold
    else:
        raise TypeError("Threshold must be either a list of two numbers or a single number.")
    return a, b


def color():
    """Return the list of hex colours used to draw boxes and matched lines."""
    return [
        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
        '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
        '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
        '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
        '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
        '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
        '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
        '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
        '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
        '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
        '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
        '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
        '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
        '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
        '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072',
    ]


def _save_figure(fig, save_path, filename):
    """Save ``fig`` as ``<save_path>/<timestamp>/<filename>``.

    When ``save_path`` is truthy but not a path (e.g. ``True``) the
    timestamped directory is created relative to the working directory.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    if isinstance(save_path, (str, os.PathLike)):
        out_path = os.path.join(save_path, stamp, filename)
    else:
        out_path = os.path.join(stamp, filename)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path)
    print(f"Saved result image to {out_path}")
    return out_path


def _draw_boxes(ax, boxes, box_scores, box_th, palette):
    """Draw every box whose score reaches ``box_th``."""
    for idx, (box, box_score) in enumerate(zip(boxes, box_scores)):
        if box_score < box_th:
            continue
        x0, y0, x1, y1 = box
        ax.add_patch(
            plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False,
                          # Wrap around so many boxes cannot overrun the palette.
                          edgecolor=palette[idx % len(palette)], linewidth=1))


def _draw_lines(ax, lines, line_scores, line_th):
    """Draw every segment whose score reaches ``line_th`` (points are (y, x))."""
    for (a, b), line_score in zip(lines, line_scores):
        if line_score < line_th:
            continue
        ax.scatter(a[1], a[0], c=_ENDPOINT_COLOR, s=2)
        ax.scatter(b[1], b[0], c=_ENDPOINT_COLOR, s=2)
        ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)


def _draw_matched(ax, boxes, box_scores, lines, line_scores,
                  box_th, line_th, palette):
    """Draw each box together with the segment matched to it.

    ``lines[i]`` / ``line_scores[i]`` must belong to ``boxes[i]``.  Boxes
    whose segment is the all-zero placeholder are skipped; the rest are
    drawn when either score reaches its threshold.
    """
    empty = np.array([[0.0, 0.0], [0.0, 0.0]])
    drawn = 0
    for box, line, box_score, line_score in zip(boxes, lines,
                                                box_scores, line_scores):
        # Skip boxes that contain no line.
        if np.array_equal(line, empty):
            continue
        if box_score >= box_th or line_score >= line_th:
            x0, y0, x1, y1 = box
            a, b = line
            shade = palette[drawn % len(palette)]
            ax.add_patch(
                plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False,
                              edgecolor=shade, linewidth=1))
            ax.scatter(a[1], a[0], c=_ENDPOINT_COLOR, s=10)
            ax.scatter(b[1], b[0], c=_ENDPOINT_COLOR, s=10)
            ax.plot([a[1], b[1]], [a[0], b[0]], c=shade, linewidth=1)
            drawn += 1


def show_all(imgs, pred, threshold, save_path):
    """Show boxes, lines and matched box/line pairs side by side.

    Args:
        imgs: Image tensor laid out as [C, H, W].
        pred: Predictions already processed by ``box_line_`` (``pred[0]``
            must contain ``'line'`` and ``'line_score'``).
        threshold: Value accepted by ``set_thresholds``; applied to the
            first two panels.
        save_path: If truthy, the figure is saved as
            ``<save_path>/<timestamp>/box_line.png``.
    """
    col = color()
    box_th, line_th = set_thresholds(threshold)
    image = imgs.permute(1, 2, 0).cpu().numpy()

    boxes = pred[0]['boxes'].cpu().numpy()
    box_scores = pred[0]['scores'].cpu().numpy()
    wires, wire_scores = _decode_wires(pred, image.shape[0], image.shape[1])

    fig, axs = plt.subplots(1, 3, figsize=(10, 10))

    axs[0].imshow(image)
    _draw_boxes(axs[0], boxes, box_scores, box_th, col)
    axs[0].set_title('Boxes')

    axs[1].imshow(image)
    _draw_lines(axs[1], wires, wire_scores, line_th)
    axs[1].set_title('Lines')

    axs[2].imshow(image)
    _draw_matched(axs[2], boxes, box_scores,
                  pred[0]['line'].cpu().numpy(),
                  pred[0]['line_score'].cpu().numpy(),
                  _MATCH_BOX_TH, _MATCH_LINE_TH, col)
    axs[2].set_title('Boxes and Lines')

    # Adjust the spacing before saving so titles and labels do not overlap
    # in the saved image either.
    fig.tight_layout()
    if save_path:
        _save_figure(fig, save_path, 'box_line.png')
    plt.show()
    plt.close(fig)


def show_box_or_line(imgs, pred, threshold, save_path=None, show_line=False, show_box=False):
    """Show the detected boxes and/or the detected lines on the image.

    Args:
        imgs: Image tensor laid out as [C, H, W].
        pred: Raw model predictions.
        threshold: Value accepted by ``set_thresholds``.
        save_path: If truthy, ``box.png`` is saved after the boxes are drawn
            and ``line.png`` after the lines are drawn, each under
            ``<save_path>/<timestamp>/``.
        show_line: Draw the post-processed wire segments.
        show_box: Draw the detected boxes.
    """
    col = color()
    box_th, line_th = set_thresholds(threshold)
    image = imgs.permute(1, 2, 0).cpu().numpy()

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)

    if show_box:
        boxes = pred[0]['boxes'].cpu().numpy()
        box_scores = pred[0]['scores'].cpu().numpy()
        _draw_boxes(ax, boxes, box_scores, box_th, col)
        if save_path:
            _save_figure(fig, save_path, 'box.png')

    if show_line:
        wires, wire_scores = _decode_wires(pred, image.shape[0], image.shape[1])
        _draw_lines(ax, wires, wire_scores, line_th)
        if save_path:
            _save_figure(fig, save_path, 'line.png')

    plt.show()
    plt.close(fig)


def show_predict(imgs, pred, threshold, t_start):
    """Show each box together with its matched line and print the elapsed time.

    Args:
        imgs: Image tensor laid out as [C, H, W].
        pred: Predictions already processed by ``box_line_``.
        threshold: Value accepted by ``set_thresholds``; a pair is drawn when
            the box score or the line score reaches its threshold.
        t_start: ``time.time()`` value the elapsed time is measured from.
    """
    col = color()
    box_th, line_th = set_thresholds(threshold)
    image = imgs.permute(1, 2, 0).cpu().numpy()

    boxes = pred[0]['boxes'].cpu().numpy()
    box_scores = pred[0]['scores'].cpu().numpy()
    # The matched lines are index-aligned with the boxes, so they must not be
    # truncated or post-processed again: that would pair boxes with the
    # wrong segments.
    lines = pred[0]['line'].cpu().numpy()
    line_scores = pred[0]['line_score'].cpu().numpy()

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    _draw_matched(ax, boxes, box_scores, lines, line_scores,
                  box_th, line_th, col)

    t_end = time.time()
    print(f'predict used:{t_end - t_start}')
    plt.show()
    plt.close(fig)


class _BasePredictor:
    """Shared implementation behind ``Predict`` and ``Predict1``."""

    def __init__(self, model, img, type=0, threshold=0.5, save_path=None, show_line=False, show_box=False):
        """Store the configuration and load the input image.

        Args:
            model: Model object used for inference.
            img: Input image, either a file path or an already loaded image.
            type: Display mode — 0: boxes, lines and matched pairs;
                1: lines only; 2: boxes only; 3: matched pairs only.
            threshold: Score threshold(s), see ``set_thresholds``.
            save_path: Optional location for saved figures.
            show_line: Stored on the instance; not read by this class.
            show_box: Stored on the instance; not read by this class.
        """
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.model = model
        self.img = self.load_image(img)
        self.type = type
        self.threshold = threshold
        self.save_path = save_path
        self.show_line = show_line
        self.show_box = show_box

    def load_best_model(self, model, save_path, device):
        """Load ``checkpoint['model_state_dict']`` from ``save_path`` into ``model``.

        Prints a message and leaves ``model`` untouched when the file does
        not exist.  Returns ``model`` in both cases.
        """
        if os.path.exists(save_path):
            checkpoint = torch.load(save_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            print(f"No saved model found at {save_path}")
        return model

    def load_image(self, img):
        """Open ``img`` as an RGB image when it is a path; otherwise return it as is."""
        if isinstance(img, str):
            img = Image.open(img).convert("RGB")
        return img

    def preprocess_image(self, img):
        """Convert ``img`` to a [C, 512, 512] float tensor.

        The image is resized with bilinear interpolation when its spatial
        size is not already 512x512.
        """
        img_tensor = transforms.ToTensor()(img)  # [3, H, W]

        t_start = time.time()
        _, height, width = img_tensor.shape
        if (height, width) != (512, 512):
            hwc = np.ascontiguousarray(
                img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32))
            resized = cv2.resize(hwc, (512, 512), interpolation=cv2.INTER_LINEAR)
            if resized.ndim == 2:
                # cv2 drops the channel axis of single-channel images.
                resized = resized[:, :, np.newaxis]
            img_tensor = torch.from_numpy(resized).permute(2, 0, 1).contiguous()
        t_end = time.time()
        print(f"Image preprocessing used: {t_end - t_start:.4f} seconds")
        return img_tensor

    def _get_model(self):
        """Return the model to run; subclasses may load weights here."""
        return self.model

    def predict(self):
        """Run inference and display/save the result selected by ``self.type``."""
        model = self._get_model()
        # The input is moved to self.device below, so the model must be there too.
        model.to(self.device)
        model.eval()

        img_ = self.preprocess_image(self.img)

        with torch.no_grad():
            predictions = model([img_.to(self.device)])
        print("Model predictions completed.")

        if self.type in (0, 3):
            # Box/line matching is only needed by the views that draw pairs.
            t_start = time.time()
            pred = box_line_(img_, predictions)
            t_end = time.time()
            print(f"Matched boxes and lines used: {t_end - t_start:.4f} seconds")
            if self.type == 0:
                show_all(img_, pred, self.threshold, save_path=self.save_path)
            else:
                show_predict(img_, pred, self.threshold, t_start)
        elif self.type == 1:
            show_box_or_line(img_, predictions, self.threshold, save_path=self.save_path, show_line=True)
        elif self.type == 2:
            show_box_or_line(img_, predictions, self.threshold, save_path=self.save_path, show_box=True)

    def run(self):
        """Run the full prediction flow."""
        self.predict()


class Predict(_BasePredictor):
    """Predictor that loads the model weights from a checkpoint file first."""

    def __init__(self, pt_path, model, img, type=0, threshold=0.5, save_path=None, show_line=False, show_box=False):
        """Initialise the predictor.

        Args:
            pt_path: Path of the checkpoint holding ``'model_state_dict'``.
            model: Model definition (weights not loaded yet).
            img: Input image, either a file path or an already loaded image.
            type: Display mode — 0: boxes, lines and matched pairs;
                1: lines only; 2: boxes only; 3: matched pairs only.
            threshold: Score threshold(s), see ``set_thresholds``.
            save_path: Optional location for saved figures.
            show_line: Stored on the instance; not read by this class.
            show_box: Stored on the instance; not read by this class.
        """
        super().__init__(model, img, type=type, threshold=threshold,
                         save_path=save_path, show_line=show_line,
                         show_box=show_box)
        self.pt_path = pt_path

    def _get_model(self):
        """Load the checkpoint at ``self.pt_path`` into the model and return it."""
        return self.load_best_model(self.model, self.pt_path, self.device)


class Predict1(_BasePredictor):
    """Predictor for a model that is used as passed in (no checkpoint loading).

    Constructor arguments: ``model, img, type=0, threshold=0.5,
    save_path=None, show_line=False, show_box=False``.
    """