import multiprocessing as mp
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import skimage
import skimage.transform  # older scikit-image does not load submodules on `import skimage`
import torch
from PIL import Image
from rtree import index
from torchvision import transforms

from models.line_detect.line_net import linenet_resnet50_fpn
from models.line_detect.postprocess import show_predict

# Use the 'spawn' start method for worker processes (CUDA does not support
# 'fork'-started subprocesses, and 'spawn' is the only method on Windows).
mp.set_start_method('spawn', force=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'device:{device}')

# Predicted line coordinates are rescaled by _IMAGE_SIZE / _LINE_GRID_SIZE,
# presumably from the line head's 128x128 output grid to the 512x512 resized
# input image used in `predict` — confirm against the model definition.
_LINE_GRID_SIZE = 128
_IMAGE_SIZE = 512


def load_best_model(model, save_path, device):
    """Load weights from a checkpoint into ``model`` if the file exists.

    The checkpoint is read as a dict with ``'model_state_dict'``, ``'epoch'``
    and ``'loss'`` entries. If ``save_path`` does not exist, a message is
    printed and ``model`` is left untouched.

    Args:
        model: Module whose state dict is loaded in place.
        save_path: Path to the checkpoint file.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        The same ``model`` object.
    """
    if os.path.exists(save_path):
        checkpoint = torch.load(save_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        epoch = checkpoint['epoch']
        loss = checkpoint['loss']
        print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
    else:
        print(f"No saved model found at {save_path}")
    return model


def _to_numpy(value):
    """Return ``value`` as a NumPy array, moving torch tensors to host memory first."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return np.asarray(value)


def process_box(box, lines, scores):
    """Pick, for every box, the longest predicted line lying entirely inside it.

    Args:
        box: Sequence of boxes for one prediction entry; each row is indexed as
            ``b[0]..b[3]``. An endpoint is inside when its column 1 lies in
            ``[b[0], b[2]]`` and its column 0 lies in ``[b[1], b[3]]``
            (presumably boxes are (x1, y1, x2, y2) and endpoints are (y, x) —
            confirm against the model output).
        lines: Tensor or array indexed as ``lines[0, j]`` -> one line's
            ``(2, 2)`` endpoints, in line-grid coordinates.
        scores: Unused; kept so the call signature stays unchanged.

    Returns:
        ``(valid_lines, valid_scores)``: for each box that contains at least one
        line of strictly positive length, the longest such line (rescaled by
        ``_IMAGE_SIZE / _LINE_GRID_SIZE``) and its length, which is used as the
        score. When several lines tie for the longest, the first one wins.
    """
    valid_lines = []   # selected line per box
    valid_scores = []  # its length, used as the score

    all_lines = _to_numpy(lines)[0] / _LINE_GRID_SIZE * _IMAGE_SIZE  # (num_lines, 2, 2)
    col0 = all_lines[:, :, 0]
    col1 = all_lines[:, :, 1]
    lengths = np.linalg.norm(all_lines[:, 0] - all_lines[:, 1], axis=-1)
    # Only strictly positive lengths qualify (the original scan started at 0.0
    # and required `length > max_length`); NaN lengths never qualify.
    positive = lengths > 0.0

    for b in box:
        inside = ((col1 >= b[0]) & (col1 <= b[2])
                  & (col0 >= b[1]) & (col0 <= b[3])).all(axis=1)
        candidates = inside & positive
        if not candidates.any():
            continue
        # argmax returns the first maximum, matching a sequential strict-'>' scan.
        best = int(np.argmax(np.where(candidates, lengths, -np.inf)))
        valid_lines.append(all_lines[best].copy())
        valid_scores.append(lengths[best])

    return valid_lines, valid_scores


def box_line_optimized_parallel(pred):
    """Attach to each detection entry the best line found inside each of its boxes.

    ``pred[:-1]`` are per-image dicts with a ``'boxes'`` tensor, and
    ``pred[-1]['wires']`` holds ``'lines'`` and ``'score'``. Entries that
    receive at least one line get ``'line'`` and ``'line_score'`` tensors added
    in place and are returned; all other entries are dropped.

    Returns:
        List of the updated prediction dicts. Empty if there are no box entries
        or no box contains a valid line.
    """
    # Move wire data to host memory once. Workers then receive plain NumPy
    # arrays instead of (possibly CUDA) tensors, which cannot be shared across
    # processes on every platform.
    lines = _to_numpy(pred[-1]['wires']['lines'])
    scores = _to_numpy(pred[-1]['wires']['score'][0])

    boxes = [box_['boxes'].cpu().numpy() for box_ in pred[0:-1]]
    if not boxes:
        return []  # mp.Pool(processes=0) would raise ValueError

    jobs = [(box, lines, scores) for box in boxes]
    num_processes = min(mp.cpu_count(), len(boxes))
    if num_processes == 1:
        # One worker gives no parallelism. Skip spawning a process that would
        # re-import this module (torch included) just to run one job.
        results = [process_box(*job) for job in jobs]
    else:
        with mp.Pool(processes=num_processes) as pool:
            results = pool.starmap(process_box, jobs)

    filtered_pred = []
    for idx_box, (valid_lines, valid_scores) in enumerate(results):
        if valid_lines:
            pred[idx_box]['line'] = torch.tensor(np.stack(valid_lines))
            pred[idx_box]['line_score'] = torch.tensor(np.asarray(valid_scores))
            filtered_pred.append(pred[idx_box])
    return filtered_pred


def predict(pt_path, model, img):
    """Run line detection on one image and visualize the matched box lines.

    Args:
        pt_path: Checkpoint path passed to ``load_best_model``.
        model: Detection model; it is switched to eval mode.
        img: A ``PIL.Image`` or a path to an image file (converted to RGB).
    """
    model = load_best_model(model, pt_path, device)
    model.eval()

    if isinstance(img, str):
        img = Image.open(img).convert("RGB")

    transform = transforms.ToTensor()
    img_tensor = transform(img)
    im = img_tensor.permute(1, 2, 0)  # [H, W, 3]
    im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512))  # (512, 512, 3)
    img_tensor = torch.tensor(im_resized).permute(2, 0, 1)

    with torch.no_grad():
        t_start = time.time()
        predictions = model([img_tensor.to(device)])
        t_end = time.time()
    print(f'Prediction used: {t_end - t_start:.4f} seconds')

    boxes = predictions[0]['boxes'].shape
    lines = predictions[-1]['wires']['lines'].shape
    lines_scores = predictions[-1]['wires']['score'].shape
    print(f'Predictions - boxes: {boxes}, lines: {lines}, lines_scores: {lines_scores}')

    t_start = time.time()
    pred = box_line_optimized_parallel(predictions)
    t_end = time.time()
    print(f'Matched boxes and lines used: {t_end - t_start:.4f} seconds')

    # Nothing to draw if no box contained a valid line.
    if not pred:
        print("No valid predictions found. Skipping visualization.")
        return

    # Draw only the matched lines.
    show_predict(img_tensor, pred, t_start)


if __name__ == '__main__':
    t_start = time.time()
    print(f'Start to predict: {t_start}')

    model = linenet_resnet50_fpn().to(device)
    pt_path = r'D:\python\PycharmProjects\20250214\weight\resnet50_best_e100.pth'
    img_path = r'C:\Users\m2337\Desktop\p\49.jpg'
    predict(pt_path, model, img_path)

    t_end = time.time()
    print(f'Total prediction time: {t_end - t_start:.4f} seconds')