"""Inference and visualisation helpers for the box + wire (line) model.

The model is fed one 512x512 RGB image and returns a list of prediction dicts:
detection dicts carrying ``'boxes'`` / ``'scores'`` and, as the last element, a
dict whose ``'wires'`` entry holds the line-head output. ``box_line_`` matches
each detected box with the best line inside it; the ``show_*`` functions draw
boxes, lines and matched pairs with matplotlib.
"""
import os
import time
from datetime import datetime

import cv2
import matplotlib.pyplot as plt
import numpy as np
import skimage  # noqa: F401  -- currently unused in this module
import torch
from PIL import Image
from rtree import index  # noqa: F401  -- currently unused in this module
from torchvision import transforms

from models.wirenet.postprocess import postprocess

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Side length (pixels) that input images are resized to before inference.
_INPUT_SIZE = 512
# Wire-head line coordinates are divided by this and multiplied by _INPUT_SIZE
# (presumably the head predicts on a 128x128 grid -- confirm against the model).
_WIRE_GRID = 128
# Placeholder line stored by box_line_ for a box that contains no line.
_EMPTY_LINE = [[0.0, 0.0], [0.0, 0.0]]


def _decode_wires(pred, img_shape):
    """Extract, de-pad and post-process the wire (line) predictions.

    Args:
        pred: model output list; ``pred[-1]['wires']`` must contain ``'lines'``
            and ``'score'`` tensors whose first (batch) index 0 is used.
        img_shape: shape of the CHW image tensor; its last two entries (H, W)
            set the merge distance passed to ``postprocess`` (1% of the diagonal).

    Returns:
        The ``(lines, scores)`` pair returned by ``postprocess``. Line endpoints
        are ``(row, col)`` pairs, as the drawing code plots ``x = p[1], y = p[0]``.
    """
    lines = pred[-1]['wires']['lines'][0].cpu().numpy() / _WIRE_GRID * _INPUT_SIZE
    scores = pred[-1]['wires']['score'].cpu().numpy()[0]
    # Cut the list at the first line that repeats lines[0]; the tail appears to
    # be padding emitted by the head (NOTE(review): confirm with the model).
    for i in range(1, len(lines)):
        if (lines[i] == lines[0]).all():
            lines = lines[:i]
            scores = scores[:i]
            break
    height, width = img_shape[-2], img_shape[-1]
    diag = (height ** 2 + width ** 2) ** 0.5
    return postprocess(lines, scores, diag * 0.01, 0, False)


def _draw_box(ax, box, edgecolor):
    """Draw an unfilled ``(x0, y0, x1, y1)`` rectangle on ``ax``."""
    x0, y0, x1, y1 = box
    ax.add_patch(
        plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=edgecolor, linewidth=1))


def _draw_line(ax, a, b, line_color, point_size):
    """Draw segment ``a``-``b`` (endpoints given as ``(row, col)``) with its endpoints."""
    ax.scatter(a[1], a[0], c='#871F78', s=point_size)
    ax.scatter(b[1], b[0], c='#871F78', s=point_size)
    ax.plot([a[1], b[1]], [a[0], b[0]], c=line_color, linewidth=1)


def _draw_matches(ax, col, boxes, lines, box_scores, line_scores, box_th, line_th):
    """Draw each box together with the line matched to it by ``box_line_``.

    Boxes whose matched line is the empty placeholder are skipped. A pair is drawn
    when the box score reaches ``box_th`` OR the line score reaches ``line_th``;
    every drawn pair gets the next palette colour (wrapping around the palette).
    """
    empty = np.array(_EMPTY_LINE)
    drawn = 0
    for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
        # Skip boxes that contain no line.
        if np.array_equal(line, empty):
            continue
        if box_score >= box_th or line_score >= line_th:
            pair_color = col[drawn % len(col)]
            _draw_box(ax, box, pair_color)
            a, b = line
            _draw_line(ax, a, b, pair_color, 10)
            drawn += 1


def _save_figure(save_path, filename):
    """Save the current figure as ``<save_path>/<timestamp>/<filename>``.

    ``save_path`` is treated as a directory. If it is not a path (e.g. ``True``
    used as a flag), the timestamp directory is created in the current working
    directory, which is where the file was always written before.
    """
    base = save_path if isinstance(save_path, (str, os.PathLike)) else ''
    out_path = os.path.join(base, datetime.now().strftime("%Y%m%d_%H%M%S"), filename)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    plt.savefig(out_path)
    print(f"Saved result image to {out_path}")
    return out_path


def box_line_(imgs, pred):
    """Match each detected box with the best-scoring line lying fully inside it.

    For every box, consider the post-processed wire lines whose two endpoints both
    lie inside the box (bounds inclusive). Keep the one with the highest strictly
    positive score; ties go to the earliest line. A box with no such line gets
    the placeholder ``[[0, 0], [0, 0]]`` and score ``0.0``.

    Args:
        imgs: CHW image tensor the predictions were made on (only H/W are used).
        pred: model output list. ``pred[-1]['wires']`` holds the line-head output;
            the detection dicts (each with a ``'boxes'`` tensor of
            ``(x0, y0, x1, y1)`` rows) are taken from ``pred[0:-2]``.

    Returns:
        ``pred`` itself, with ``'line'`` (one matched line per box) and
        ``'line_score'`` tensors added in place to each entry of ``pred[0:-2]``.
    """
    line, score = _decode_wires(pred, imgs.shape)
    lines_arr = np.asarray(line)
    scores_arr = np.asarray(score)
    has_lines = lines_arr.ndim == 3 and len(lines_arr) > 0
    if has_lines:
        ys = lines_arr[..., 0]  # (N, 2): row of both endpoints
        xs = lines_arr[..., 1]  # (N, 2): column of both endpoints

    # NOTE(review): pred[0:-2] skips the last TWO entries; if the model ever
    # returns only [detections, wires], no box gets a 'line' -- confirm layout.
    for idx, det in enumerate(pred[0:-2]):
        boxes = det['boxes'].detach().cpu().numpy()
        matched_lines = []
        matched_scores = []
        for x0, y0, x1, y1 in boxes:
            best, best_score = _EMPTY_LINE, 0.0
            if has_lines:
                inside = ((xs >= x0) & (xs <= x1) & (ys >= y0) & (ys <= y1)).all(axis=1)
                # Only inside lines with a positive score are candidates;
                # argmax returns the first maximum, i.e. the earliest line on ties.
                cand = np.where(inside & (scores_arr > 0.0), scores_arr, -np.inf)
                j = int(np.argmax(cand))
                if cand[j] > 0.0:
                    best, best_score = line[j], score[j]
            matched_lines.append(best)
            matched_scores.append(best_score)
        pred[idx]['line'] = torch.tensor(np.array(matched_lines))
        pred[idx]['line_score'] = torch.tensor(matched_scores)
    return pred


def set_thresholds(threshold):
    """Split ``threshold`` into ``(box_threshold, line_threshold)``.

    Args:
        threshold: a single number used for both, or a list/tuple of exactly two.

    Raises:
        ValueError: a list/tuple that does not have exactly two elements.
        TypeError: any other type.
    """
    if isinstance(threshold, (list, tuple)):
        if len(threshold) != 2:
            raise ValueError("Threshold list must contain exactly two elements.")
        a, b = threshold
    elif isinstance(threshold, (int, float)):
        a = b = threshold
    else:
        raise TypeError("Threshold must be either a list of two numbers or a single number.")
    return a, b


def color():
    """Return a fresh list of 100 hex colours used to distinguish boxes."""
    return [
        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c', '#98df8a', '#d62728', '#ff9896',
        '#9467bd', '#c5b0d5', '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f', '#c7c7c7',
        '#bcbd22', '#dbdb8d', '#17becf', '#9edae5', '#8dd3c7', '#ffffb3', '#bebada', '#fb8072',
        '#80b1d3', '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5', '#ffed6f', '#8da0cb',
        '#e78ac3', '#e5c494', '#b3b3b3', '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173', '#ad494a', '#843c39', '#dd8452',
        '#f7f7f7', '#cccccc', '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6', '#99d8c9',
        '#66c2a4', '#2ca25f', '#008d4c', '#005a32', '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5',
        '#7bccc4', '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4', '#bfd3e6', '#9ebcda',
        '#8c96c6', '#8c6bb4', '#88419d', '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
        '#bfbfbf', '#969696', '#737373', '#525252', '#252525', '#000000', '#ffffff', '#ffeda0',
        '#fed976', '#feb24c', '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026', '#ff6f61',
        '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
    ]


def show_all(imgs, pred, threshold, save_path, match_threshold=(0.7, 0.9)):
    """Show boxes, lines and matched box/line pairs side by side.

    Args:
        imgs: CHW image tensor.
        pred: model output after ``box_line_`` (``pred[0]`` must hold ``'line'``).
        threshold: box/line score thresholds for the first two panels
            (see ``set_thresholds``).
        save_path: if truthy, the figure is saved under this directory
            (see ``_save_figure``).
        match_threshold: box/line score thresholds for the matched-pair panel;
            defaults to the values this panel always used.
    """
    col = color()
    box_th, line_th = set_thresholds(threshold)
    match_box_th, match_line_th = set_thresholds(match_threshold)
    im = imgs.permute(1, 2, 0).cpu().numpy()
    boxes = pred[0]['boxes'].cpu().numpy()
    box_scores = pred[0]['scores'].cpu().numpy()
    line, line_score = _decode_wires(pred, imgs.shape)

    fig, axs = plt.subplots(1, 3, figsize=(10, 10))

    axs[0].imshow(im)
    for idx, box in enumerate(boxes):
        if box_scores[idx] < box_th:
            continue
        _draw_box(axs[0], box, col[idx % len(col)])
    axs[0].set_title('Boxes')

    axs[1].imshow(im)
    for idx, (a, b) in enumerate(line):
        if line_score[idx] < line_th:
            continue
        _draw_line(axs[1], a, b, 'red', 2)
    axs[1].set_title('Lines')

    axs[2].imshow(im)
    _draw_matches(axs[2], col, boxes, pred[0]['line'].cpu().numpy(), box_scores,
                  pred[0]['line_score'].cpu().numpy(), match_box_th, match_line_th)
    axs[2].set_title('Boxes and Lines')

    # Adjust spacing between subplots so titles and labels do not overlap;
    # done before saving so the saved file gets the same layout.
    plt.tight_layout()
    if save_path:
        _save_figure(save_path, 'box_line.png')
    plt.show()


def show_box_or_line(imgs, pred, threshold, save_path=None, show_line=False, show_box=False):
    """Draw the detected boxes and/or wire lines on a single image.

    When ``save_path`` is truthy, ``box.png`` is saved right after the boxes are
    drawn and ``line.png`` after the lines (it also contains the boxes if both
    are enabled).
    """
    col = color()
    box_th, line_th = set_thresholds(threshold)
    im = imgs.permute(1, 2, 0).cpu().numpy()

    # Visualise the predictions.
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)

    if show_box:
        boxes = pred[0]['boxes'].cpu().numpy()
        box_scores = pred[0]['scores'].cpu().numpy()
        for idx, box in enumerate(boxes):
            if box_scores[idx] < box_th:
                continue
            _draw_box(ax, box, col[idx % len(col)])
        if save_path:
            _save_figure(save_path, 'box.png')

    if show_line:
        line, line_score = _decode_wires(pred, imgs.shape)
        for idx, (a, b) in enumerate(line):
            if line_score[idx] < line_th:
                continue
            _draw_line(ax, a, b, 'red', 2)
        if save_path:
            _save_figure(save_path, 'line.png')

    plt.show()


def show_predict(imgs, pred, threshold, t_start):
    """Draw matched box/line pairs and print the time elapsed since ``t_start``.

    Uses ``pred[0]['line']`` / ``pred[0]['line_score']`` as written by
    ``box_line_``. These hold exactly one entry per box, so they are zipped
    with the boxes directly. Re-running the wire de-padding/``postprocess``
    on them would drop or reorder entries and misalign boxes and lines.
    """
    col = color()
    box_th, line_th = set_thresholds(threshold)
    im = imgs.permute(1, 2, 0).cpu().numpy()  # HWC image, e.g. [512, 512, 3]
    boxes = pred[0]['boxes'].cpu().numpy()
    box_scores = pred[0]['scores'].cpu().numpy()
    lines = pred[0]['line'].cpu().numpy()
    line_scores = pred[0]['line_score'].cpu().numpy()

    # Visualise the predictions.
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    _draw_matches(ax, col, boxes, lines, box_scores, line_scores, box_th, line_th)

    t_end = time.time()
    print(f'predict used:{t_end - t_start}')
    plt.show()


def _model_device(model, default):
    """Return the device of the model's first parameter, or ``default`` if it has none."""
    try:
        return next(model.parameters()).device
    except StopIteration:
        return default


class Predict:
    """Run the model on one image, match boxes with lines and visualise the result."""

    def __init__(self, model, img, type=0, threshold=0.5, save_path=None, show_line=False, show_box=False):
        """Initialise the predictor.

        Args:
            model: the model (weights already loaded); inputs are sent to the
                device of its parameters.
            img: input image, either a file path or a PIL image.
            type: display mode. 0 shows everything (lines, boxes and matched
                pairs), 1 shows lines, 2 shows boxes, 3 shows matched pairs.
                Other values show nothing.
            threshold: score threshold(s) for filtering (see ``set_thresholds``).
            save_path: optional directory to save result images under.
            show_line, show_box: stored for compatibility; not used by ``predict``.
        """
        self.model = model
        self.device = _model_device(model, device)
        self.img = self.load_image(img)
        self.type = type
        self.threshold = threshold
        self.save_path = save_path
        self.show_line = show_line
        self.show_box = show_box

    def load_best_model(self, model, save_path, device):
        """Load ``checkpoint['model_state_dict']`` from ``save_path`` into ``model`` if the file exists.

        Note: ``torch.load`` unpickles the file -- only load trusted checkpoints.
        """
        if os.path.exists(save_path):
            checkpoint = torch.load(save_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            print(f"No saved model found at {save_path}")
        return model

    def load_image(self, img):
        """Load an image from a path, or pass a PIL image through, as RGB."""
        if isinstance(img, str):
            img = Image.open(img).convert("RGB")
        elif isinstance(img, Image.Image) and img.mode != "RGB":
            # Grayscale/RGBA inputs would give 1 or 4 channels downstream.
            img = img.convert("RGB")
        return img

    def preprocess_image(self, img):
        """Convert ``img`` to a float CHW tensor resized to 512x512."""
        img_tensor = transforms.ToTensor()(img)  # [C, H, W]
        t_start = time.time()
        im = img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32)  # [H, W, C]
        if im.shape != (_INPUT_SIZE, _INPUT_SIZE, 3):
            im = cv2.resize(im, (_INPUT_SIZE, _INPUT_SIZE), interpolation=cv2.INTER_LINEAR)
            if im.ndim == 2:
                # cv2.resize drops the channel axis of single-channel images.
                im = im[:, :, np.newaxis]
        img_ = torch.from_numpy(np.ascontiguousarray(im)).permute(2, 0, 1)  # [3, 512, 512]
        t_end = time.time()
        print(f"Image preprocessing used: {t_end - t_start:.4f} seconds")
        return img_

    def _infer(self):
        """Preprocess the stored image and run the model in eval/no-grad mode."""
        self.model.eval()
        img_ = self.preprocess_image(self.img)
        with torch.no_grad():
            predictions = self.model([img_.to(self.device)])
        print("Model predictions completed.")
        return img_, predictions

    @staticmethod
    def _match(img_, predictions):
        """Run ``box_line_`` with timing; return ``(pred, t_start)``."""
        t_start = time.time()
        pred = box_line_(img_, predictions)  # box/line matching
        t_end = time.time()
        print(f"Matched boxes and lines used: {t_end - t_start:.4f} seconds")
        return pred, t_start

    def _show(self, img_, pred, t_start):
        """Display or save the results according to ``self.type``."""
        if self.type == 0:
            show_all(img_, pred, self.threshold, save_path=self.save_path)
        elif self.type == 1:
            show_box_or_line(img_, pred, self.threshold, save_path=self.save_path, show_line=True)
        elif self.type == 2:
            show_box_or_line(img_, pred, self.threshold, save_path=self.save_path, show_box=True)
        elif self.type == 3:
            show_predict(img_, pred, self.threshold, t_start)

    def predict(self):
        """Run inference, always match boxes with lines, then display."""
        img_, predictions = self._infer()
        pred, t_start = self._match(img_, predictions)
        self._show(img_, pred, t_start)

    def run(self):
        """Run the prediction pipeline."""
        self.predict()


class Predict1(Predict):
    """Variant of ``Predict`` that runs box/line matching only for types 0 and 3.

    Types 1 and 2 only draw raw boxes or lines, so the matching step is skipped.
    """

    def __init__(self, model, img, type=0, threshold=0.5, save_path=None, show_line=False, show_box=False):
        """Initialise the predictor; see ``Predict.__init__`` for the arguments."""
        # Use the model's own device so the input and the weights always agree.
        self.device = _model_device(model, device)
        self.model = model
        self.img = self.load_image(img)
        self.type = type
        self.threshold = threshold
        self.save_path = save_path
        self.show_line = show_line
        self.show_box = show_box

    def predict(self):
        """Run inference, match boxes/lines only when the display mode needs it, then display."""
        img_, predictions = self._infer()
        if self.type in (0, 3):
            pred, t_start = self._match(img_, predictions)
        else:
            pred, t_start = predictions, None
        self._show(img_, pred, t_start)