"""Run the line-detection network on one image and visualize its output.

The script loads a checkpoint, resizes the input image to a square, runs the
model, pairs every predicted box with the best-scoring line that lies fully
inside it, and draws boxes, lines and the paired result with matplotlib.
"""
import multiprocessing as mp
import os
import shutil
import time

import matplotlib.pyplot as plt
import numpy as np
import skimage.color
import skimage.io
import skimage.transform
import torch
from PIL import Image
from rtree import index
from torchvision import transforms

from models.line_detect.line_net import linenet_resnet50_fpn
from models.line_detect.postprocess import postprocess

mp.set_start_method('spawn', force=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Checkpoint used when ``predict`` is called without an explicit model path.
DEFAULT_MODEL_PATH = (
    r'D:\python\PycharmProjects\20250214\weight\linenet_wts\r50fpn_wts_e350\best.pth'
)

# Side length (pixels) the input image is resized to before inference.
_IMAGE_SIZE = 512
# Line coordinates are divided by this value before being scaled to
# ``_IMAGE_SIZE``; presumably the resolution of the line head's output
# grid -- TODO confirm against the model definition.
_LINE_COORD_DIVISOR = 128

_COLORS = (
    '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c', '#98df8a',
    '#d62728', '#ff9896', '#9467bd', '#c5b0d5', '#8c564b', '#c49c94',
    '#e377c2', '#f7b6d2', '#7f7f7f', '#c7c7c7', '#bcbd22', '#dbdb8d',
    '#17becf', '#9edae5', '#8dd3c7', '#ffffb3', '#bebada', '#fb8072',
    '#80b1d3', '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
    '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3', '#fdbf6f',
    '#ff7f00', '#cab2d6', '#637939', '#b5cf6b', '#cedb9c', '#8c6d31',
    '#e7969c', '#d6616b', '#7b4173', '#ad494a', '#843c39', '#dd8452',
    '#f7f7f7', '#cccccc', '#969696', '#525252', '#f7fcfd', '#e5f5f9',
    '#ccece6', '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
    '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4', '#4eb3d3',
    '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4', '#bfd3e6', '#9ebcda',
    '#8c96c6', '#8c6bb4', '#88419d', '#810f7c', '#4d004b', '#f7f7f7',
    '#efefef', '#d9d9d9', '#bfbfbf', '#969696', '#737373', '#525252',
    '#252525', '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
    '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026', '#ff6f61',
    '#ff9e64', '#ff6347', '#ffa07a', '#fa8072',
)


def load_model(model_path):
    """Build the line network and load weights from ``model_path``.

    Args:
        model_path: Path to a checkpoint containing a ``'model_state_dict'``
            entry.

    Returns:
        The model, moved to ``device`` and switched to evaluation mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    # Fail before constructing the network when the checkpoint is missing.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"No saved model found at {model_path}")

    model = linenet_resnet50_fpn().to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded model from {model_path}")
    model.eval()
    return model


def preprocess_image(image_path):
    """Load an image as RGB and resize it to the network input size.

    Args:
        image_path: Path of the image file to open.

    Returns:
        A channel-first tensor of the image resized to
        ``(_IMAGE_SIZE, _IMAGE_SIZE)``.
    """
    img = Image.open(image_path).convert("RGB")
    img_tensor = transforms.ToTensor()(img)
    # skimage works on channel-last arrays, so convert, resize, convert back.
    channel_last = img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32)
    resized_img = skimage.transform.resize(
        channel_last, (_IMAGE_SIZE, _IMAGE_SIZE)
    )
    return torch.tensor(resized_img).permute(2, 0, 1)


def save_plot(output_path: str, *, show: bool = False):
    """Save the current matplotlib figure to ``output_path`` and close it.

    Args:
        output_path: Destination image path.
        show: When true, display the figure after it has been written. The
            save must come first: once a blocking ``plt.show()`` returns the
            figure is no longer registered with pyplot, and saving afterwards
            would write an empty image.
    """
    plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
    print(f"Saved plot to {output_path}")
    if show:
        plt.show()
    plt.close()


def get_colors():
    """Return a fresh list of the predefined hex color strings."""
    return list(_COLORS)


def _to_numpy(values):
    """Convert a torch tensor (or any array-like) to a numpy array."""
    if isinstance(values, torch.Tensor):
        return values.detach().cpu().numpy()
    return np.asarray(values)


def process_box(box, lines, scores):
    """Pick, for every box, the highest-scoring line fully inside it.

    Args:
        box: Iterable of boxes, each ``(x0, y0, x1, y1)``.
        lines: Line endpoints indexed as ``lines[0, j]`` giving a ``2 x 2``
            entry per line; each endpoint is compared as ``(y, x)``.
        scores: One score per line, indexed like the second axis of ``lines``.

    Returns:
        ``(valid_lines, valid_scores)``: two lists with one entry per box.
        A box with no contained line of positive score gets an all-zero
        ``2 x 2`` line and a score of ``0.0``.
    """
    # Convert once instead of once per (box, line) pair.
    line_coords = _to_numpy(lines)[0] / _LINE_COORD_DIVISOR * _IMAGE_SIZE
    line_scores = _to_numpy(scores)
    ys = line_coords[:, :, 0]
    xs = line_coords[:, :, 1]
    no_line = np.zeros((2, 2), dtype=line_coords.dtype)

    valid_lines = []
    valid_scores = []
    for x0, y0, x1, y1 in box:
        # A line qualifies only when both of its endpoints are inside the box.
        inside = (
            (xs >= x0) & (xs <= x1) & (ys >= y0) & (ys <= y1)
        ).all(axis=1)
        candidates = np.flatnonzero(inside)

        best = None
        if candidates.size:
            top = candidates[np.argmax(line_scores[candidates])]
            # Only strictly positive scores count as a match.
            if line_scores[top] > 0.0:
                best = top

        if best is None:
            valid_lines.append(no_line.copy())
            valid_scores.append(0.0)
        else:
            valid_lines.append(line_coords[best])
            valid_scores.append(float(line_scores[best]))

    return valid_lines, valid_scores


def box_line_optimized_parallel(pred):
    """Attach the best matching line to every box in ``pred``.

    The last element of ``pred`` carries the line predictions under
    ``['wires']``; every earlier element carries ``'boxes'``. Each box entry
    with at least one box is updated in place with ``'line'`` and
    ``'line_score'`` tensors.

    NOTE: despite the name, matching currently runs serially in this process.

    Args:
        pred: Model output list as described above.

    Returns:
        The list of updated box entries (entries without boxes are omitted).
    """
    wires = pred[-1]['wires']
    lines = wires['lines']
    scores = wires['score'][0]
    boxes_per_entry = [entry['boxes'].cpu().numpy() for entry in pred[0:-1]]

    # One process_box call per box entry, each returning (lines, scores).
    results = [process_box(boxes, lines, scores) for boxes in boxes_per_entry]

    filtered_pred = []
    for idx_box, (valid_lines, valid_scores) in enumerate(results):
        if valid_lines:
            pred[idx_box]['line'] = torch.tensor(np.array(valid_lines))
            pred[idx_box]['line_score'] = torch.tensor(valid_scores)
            filtered_pred.append(pred[idx_box])
    return filtered_pred


def predict(image_path, model_path=DEFAULT_MODEL_PATH):
    """Run detection on ``image_path`` and display/save the visualizations.

    Args:
        image_path: Image to analyse.
        model_path: Checkpoint to load; defaults to ``DEFAULT_MODEL_PATH``.
    """
    start_time = time.time()

    model = load_model(model_path)
    img_tensor = preprocess_image(image_path)
    im = img_tensor.permute(1, 2, 0).cpu().numpy()

    with torch.no_grad():
        predictions = model([img_tensor.to(device)])

    t_start = time.time()
    filtered_pred = box_line_optimized_parallel(predictions)
    t_end = time.time()
    print(f'Matched boxes and lines used: {t_end - t_start:.2f} seconds')

    show_box(im, predictions, t_start)
    show_line(im, predictions, t_start)
    show_predict(im, filtered_pred, t_start)

    end_time = time.time()
    print(f'Total time: {end_time - start_time:.2f} seconds')


def combine_images(image_paths: list, titles: list, output_path: str):
    """Place several saved images side by side and write them as one image.

    Args:
        image_paths: Paths of the images to combine, left to right.
        titles: Title for each image, in the same order.
        output_path: Destination path for the combined image.
    """
    # squeeze=False keeps ``axes`` two-dimensional even for a single image.
    fig, axes = plt.subplots(
        1, len(image_paths), figsize=(15, 5), squeeze=False
    )
    for ax, img_path, title in zip(axes[0], image_paths, titles):
        ax.imshow(plt.imread(img_path))
        ax.set_title(title)
        ax.axis("off")
    fig.savefig(output_path, bbox_inches="tight", pad_inches=0)
    plt.close(fig)


def show_box(im, predictions, t_start):
    """Draw boxes scoring at least 0.7, save the figure and display it.

    Args:
        im: Image array passed to ``imshow``.
        predictions: Model output; ``predictions[0]`` supplies ``'boxes'``
            and ``'scores'``.
        t_start: Reference timestamp used for the elapsed-time message.

    Returns:
        Path of the saved image.
    """
    boxes = predictions[0]['boxes'].cpu().numpy()
    box_scores = predictions[0]['scores'].cpu().numpy()
    colors = get_colors()

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    for idx, (box, score) in enumerate(zip(boxes, box_scores)):
        if score < 0.7:
            continue
        x0, y0, x1, y1 = box
        ax.add_patch(
            plt.Rectangle(
                (x0, y0), x1 - x0, y1 - y0,
                fill=False, edgecolor=colors[idx % len(colors)], linewidth=1,
            )
        )

    t_end = time.time()
    print(f'show_box used: {t_end - t_start:.2f} seconds')

    output_path = "temp_result_box.png"
    save_plot(output_path, show=True)
    return output_path


def show_line(im, predictions, t_start):
    """Draw post-processed lines scoring at least 0.9, save and display.

    Args:
        im: Image array passed to ``imshow``.
        predictions: Model output; the last element supplies
            ``['wires']['lines']`` and ``['wires']['score']``.
        t_start: Reference timestamp used for the elapsed-time message.

    Returns:
        Path of the saved image.
    """
    wires = predictions[-1]['wires']
    lines = wires['lines'][0].cpu().numpy() / _LINE_COORD_DIVISOR * _IMAGE_SIZE
    line_scores = wires['score'].cpu().numpy()[0]

    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    for (a, b), s in zip(nlines, nscores):
        if s < 0.9:
            continue
        # Endpoints are (y, x); matplotlib expects x first.
        ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
        ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)

    t_end = time.time()
    print(f'show_line used: {t_end - t_start:.2f} seconds')

    output_path = "temp_result_line.png"
    save_plot(output_path, show=True)
    return output_path


def show_predict(imgs, pred, t_start):
    """Display each box of ``pred[0]`` together with its matched line.

    Args:
        imgs: Image array to draw on.
        pred: Box entries carrying ``'boxes'``, ``'scores'``, ``'line'`` and
            ``'line_score'``; only the first entry is drawn.
        t_start: Reference timestamp used for the elapsed-time message.
    """
    if not pred:
        print('No box/line predictions to display')
        return

    colors = get_colors()
    print(imgs.shape)

    first = pred[0]
    boxes = first['boxes'].cpu().numpy()
    box_scores = first['scores'].cpu().numpy()
    lines = first['line'].cpu().numpy()
    line_scores = first['line_score'].cpu().numpy()

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(np.array(imgs))

    no_line = np.zeros((2, 2))
    drawn = 0
    for box, line, box_score, line_score in zip(
        boxes, lines, box_scores, line_scores
    ):
        # An all-zero line marks a box without any matched line.
        if np.array_equal(line, no_line):
            continue
        # NOTE(review): a box is drawn when EITHER threshold passes; confirm
        # that requiring both was not the intent.
        if box_score < 0.7 and line_score < 0.9:
            continue

        x0, y0, x1, y1 = box
        a, b = line
        # Wrap around so more boxes than colors cannot raise IndexError.
        color = colors[drawn % len(colors)]
        ax.add_patch(
            plt.Rectangle(
                (x0, y0), x1 - x0, y1 - y0,
                fill=False, edgecolor=color, linewidth=1,
            )
        )
        ax.scatter(a[1], a[0], c='#871F78', s=10)
        ax.scatter(b[1], b[0], c='#871F78', s=10)
        ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
        drawn += 1

    t_end = time.time()
    print(f'predict used:{t_end - t_start}')
    plt.show()


if __name__ == "__main__":
    predict(r'C:\Users\m2337\Desktop\p\22.png')