# zjf's version scores each candidate line by its length; this version scores
# it by the line detector's confidence and keeps the maximum instead.
# from fastapi import FastAPI, File, UploadFile, HTTPException
# from fastapi.responses import FileResponse
# from fastapi.staticfiles import StaticFiles
import os
import torch
import numpy as np
from PIL import Image
import skimage.io
import skimage.color
import skimage.transform  # required by skimage.transform.resize in preprocess_image
from torchvision import transforms
import shutil
import matplotlib.pyplot as plt
from models.line_detect.line_net import linenet_resnet50_fpn
from models.line_detect.postprocess import postprocess
from rtree import index
import time
import multiprocessing as mp
# from fastapi.middleware.cors import CORSMiddleware

# Initialise FastAPI
# app = FastAPI()

# Add the CORS middleware
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["*"],  # allow every origin
#     allow_credentials=True,
#     allow_methods=["*"],
#     allow_headers=["*"],
# )

# Use 'spawn' as the multiprocessing start method.
# NOTE(review): presumably because forking a process that has already
# initialised CUDA is unsafe — confirm.
mp.set_start_method('spawn', force=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def load_model(model_path):
    """Build the ResNet50-FPN line network and load its trained weights.

    Args:
        model_path: checkpoint file containing a ``'model_state_dict'`` entry.

    Returns:
        The model on ``device``, switched to eval mode.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
    """
    # Check first so we do not build the network just to throw it away.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"No saved model found at {model_path}")
    model = linenet_resnet50_fpn().to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded model from {model_path}")
    model.eval()
    return model


def preprocess_image(image_path):
    """Load an image as RGB and resize it to 512x512.

    Returns:
        A float tensor of shape ``(3, 512, 512)``.
    """
    # Close the underlying file once the pixels have been converted.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    img_tensor = transforms.ToTensor()(img)  # (3, H, W), values in [0, 1]
    resized_img = skimage.transform.resize(
        img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32), (512, 512)
    )
    return torch.tensor(resized_img).permute(2, 0, 1)


def save_plot(output_path: str):
    """Save the current matplotlib figure to ``output_path`` and close it."""
    plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
    print(f"Saved plot to {output_path}")
    plt.close()


def get_colors():
    """Return a fixed list of hex colour strings used to tell boxes apart."""
    return [
        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c', '#98df8a', '#d62728', '#ff9896',
        '#9467bd', '#c5b0d5', '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f', '#c7c7c7',
        '#bcbd22', '#dbdb8d', '#17becf', '#9edae5', '#8dd3c7', '#ffffb3', '#bebada', '#fb8072',
        '#80b1d3', '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5', '#ffed6f', '#8da0cb',
        '#e78ac3', '#e5c494', '#b3b3b3', '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173', '#ad494a', '#843c39', '#dd8452',
        '#f7f7f7', '#cccccc', '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6', '#99d8c9',
        '#66c2a4', '#2ca25f', '#008d4c', '#005a32', '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5',
        '#7bccc4', '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4', '#bfd3e6', '#9ebcda',
        '#8c96c6', '#8c6bb4', '#88419d', '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
        '#bfbfbf', '#969696', '#737373', '#525252', '#252525', '#000000', '#ffffff', '#ffeda0',
        '#fed976', '#feb24c', '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026', '#ff6f61',
        '#ff9e64', '#ff6347', '#ffa07a', '#fa8072',
    ]


def _as_numpy(x):
    """Return ``x`` as a NumPy array, copying it to host memory if it is a tensor."""
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def process_box(box, lines, scores):
    """For every bounding box, find the highest-scoring line lying fully inside it.

    Args:
        box: array of boxes, one ``(x0, y0, x1, y1)`` row per box.
        lines: line predictions indexed as ``lines[0, j]``. Each line is a pair
            of ``(y, x)`` points that is rescaled here by ``/ 128 * 512``.
            Either a torch tensor (any device) or a NumPy array.
        scores: per-line confidences aligned with ``lines[0]`` (tensor or array).

    Returns:
        ``(valid_lines, valid_scores)`` with one entry per box. The line score
        (confidence) is used as the ranking key. An earlier variant used the
        line length instead. A box with no contained line of positive score
        gets an all-zero ``(2, 2)`` line and score ``0.0``.
    """
    line_pts = _as_numpy(lines)[0] / 128 * 512  # (N, 2, 2), points stored as (y, x)
    line_scores = _as_numpy(scores).reshape(-1)[: len(line_pts)]
    ys = line_pts[:, :, 0]
    xs = line_pts[:, :, 1]
    # A running maximum seeded at 0.0 only ever accepts strictly positive scores.
    positive = line_scores > 0

    valid_lines = []
    valid_scores = []
    for b in box:
        x0, y0, x1, y1 = b[0], b[1], b[2], b[3]
        # Both endpoints of the line must lie inside the box (inclusive bounds).
        inside = ((xs >= x0) & (xs <= x1) & (ys >= y0) & (ys <= y1)).all(axis=1)
        candidates = inside & positive
        if candidates.any():
            # argmax returns the first index of the maximum, so ties go to the
            # lowest line index.
            best = int(np.argmax(np.where(candidates, line_scores, -np.inf)))
            valid_lines.append(line_pts[best])
            valid_scores.append(float(line_scores[best]))
        else:
            valid_lines.append(np.zeros((2, 2), dtype=line_pts.dtype))
            valid_scores.append(0.0)
    return valid_lines, valid_scores


def box_line_optimized_parallel(pred):
    """Match every predicted box with its best line, one task per box set.

    Args:
        pred: model output list. Every entry except the last carries a
            ``'boxes'`` tensor. The last entry carries ``['wires']['lines']``
            and ``['wires']['score']``.

    Returns:
        The entries of ``pred[:-1]`` that had at least one box. Each is updated
        in place with a ``'line'`` tensor (one 2x2 line per box) and a
        ``'line_score'`` tensor.
    """
    # Copy line data to host memory once. Workers then receive plain NumPy
    # arrays rather than (possibly CUDA) tensors, and process_box needs no
    # per-line device transfers.
    lines = _as_numpy(pred[-1]['wires']['lines'])
    scores = _as_numpy(pred[-1]['wires']['score'][0])
    boxes = [box_['boxes'].cpu().numpy() for box_ in pred[0:-1]]
    if not boxes:
        # mp.Pool(processes=0) would raise ValueError.
        return []

    tasks = [(box, lines, scores) for box in boxes]
    num_processes = min(mp.cpu_count(), len(boxes))
    if num_processes == 1:
        # One task gains nothing from a pool. A spawned worker would have to
        # re-import this module (torch, the model code) before doing any work.
        results = [process_box(*task) for task in tasks]
    else:
        with mp.Pool(processes=num_processes) as pool:
            results = pool.starmap(process_box, tasks)

    filtered_pred = []
    for idx_box, (valid_lines, valid_scores) in enumerate(results):
        if valid_lines:
            pred[idx_box]['line'] = torch.from_numpy(np.stack(valid_lines))
            pred[idx_box]['line_score'] = torch.tensor(valid_scores)
            filtered_pred.append(pred[idx_box])
    return filtered_pred


def predict(image_path, model_path="/data/lm/resnet50_best_e280.pth"):
    """Run line detection on one image and save the visualisations.

    Writes temp_result.png, temp_result_box.png, temp_result_line.png and the
    side-by-side combined_result.png to the current working directory.

    Args:
        image_path: path of the input image.
        model_path: checkpoint to load (defaults to the previously hard-coded path).

    Returns:
        Path of the combined image.
    """
    start_time = time.time()

    # Saving an uploaded file (FastAPI variant):
    # os.makedirs("uploaded_images", exist_ok=True)
    # image_path = f"{file.filename}"
    # with open(image_path, "wb") as f:
    #     shutil.copyfileobj(file.file, f)

    model = load_model(model_path)

    img_tensor = preprocess_image(image_path)
    im = img_tensor.permute(1, 2, 0).cpu().numpy()  # HWC array for matplotlib

    with torch.no_grad():
        predictions = model([img_tensor.to(device)])

    # Match lines to boxes (mutates `predictions` entries in place).
    t_start = time.time()
    filtered_pred = box_line_optimized_parallel(predictions)
    t_end = time.time()
    print(f'Matched boxes and lines used: {t_end - t_start:.2f} seconds')

    # Render each view; the timings they print are measured from t_start.
    output_path_box = show_box(im, predictions, t_start)
    output_path_line = show_line(im, predictions, t_start)
    output_path_boxandline = show_predict(im, filtered_pred, t_start)

    combined_image_path = "combined_result.png"
    combine_images(
        [output_path_boxandline, output_path_box, output_path_line],
        titles=["Box and Line", "Box", "Line"],
        output_path=combined_image_path,
    )

    end_time = time.time()
    print(f'Total time: {end_time - start_time:.2f} seconds')
    # FastAPI variant — return the combined image:
    # return FileResponse(combined_image_path, media_type="image/png", filename="combined_result.png")
    # except Exception as e:
    #     raise HTTPException(status_code=500, detail=str(e))
    return combined_image_path


def combine_images(image_paths: list, titles: list, output_path: str):
    """Place the given images side by side, each with its title, and save the result."""
    # squeeze=False keeps `axes` a 2-D array even for a single image, so it is
    # always iterable.
    fig, axes = plt.subplots(1, len(image_paths), figsize=(15, 5), squeeze=False)
    for ax, img_path, title in zip(axes[0], image_paths, titles):
        ax.imshow(plt.imread(img_path))
        ax.set_title(title)
        ax.axis("off")
    plt.savefig(output_path, bbox_inches="tight", pad_inches=0)
    plt.close(fig)


def show_box(im, predictions, t_start):
    """Draw the boxes of ``predictions[0]`` with score >= 0.7 and save the plot."""
    boxes = predictions[0]['boxes'].cpu().numpy()
    box_scores = predictions[0]['scores'].cpu().numpy()
    colors = get_colors()

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    for idx, (box, score) in enumerate(zip(boxes, box_scores)):
        if score < 0.7:
            continue
        x0, y0, x1, y1 = box
        ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False,
                                   edgecolor=colors[idx % len(colors)], linewidth=1))

    t_end = time.time()
    print(f'show_box used: {t_end - t_start:.2f} seconds')
    output_path = "temp_result_box.png"
    save_plot(output_path)
    return output_path


def show_line(im, predictions, t_start):
    """Draw the post-processed lines with score >= 0.9 and save the plot."""
    lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
    line_scores = predictions[-1]['wires']['score'].cpu().numpy()[0]
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    for (a, b), s in zip(nlines, nscores):
        if s < 0.9:
            continue
        # Points are (y, x): plot x = index 1, y = index 0.
        ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
        ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)

    t_end = time.time()
    print(f'show_line used: {t_end - t_start:.2f} seconds')
    output_path = "temp_result_line.png"
    save_plot(output_path)
    return output_path


def show_predict(im, filtered_pred, t_start):
    """Draw each matched box together with its line, in a shared colour, and save the plot."""
    colors = get_colors()
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5  # loop-invariant

    for idx, pred in enumerate(filtered_pred):
        boxes = pred['boxes'].cpu().numpy()
        box_scores = pred['scores'].cpu().numpy()
        lines = pred['line'].cpu().numpy()
        line_scores = pred['line_score'].cpu().numpy()
        # NOTE(review): boxes are zipped with postprocess's output below. If
        # postprocess drops or merges lines, boxes and lines fall out of
        # alignment — confirm it returns exactly one line per input line.
        nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)

        for box_idx, (box, line, box_score, line_score) in enumerate(
                zip(boxes, nlines, box_scores, nscores)):
            if box_score < 0.7 or line_score < 0.9:
                continue
            # Skip boxes that have no valid line.
            if line is None or len(line) == 0:
                continue

            x0, y0, x1, y1 = box
            a, b = line
            color = colors[(idx + box_idx) % len(colors)]  # one colour per box

            ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False,
                                       edgecolor=color, linewidth=1))
            ax.scatter(a[1], a[0], c=color, s=10)
            ax.scatter(b[1], b[0], c=color, s=10)
            ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)

    t_end = time.time()
    print(f'show_predict used: {t_end - t_start:.2f} seconds')
    output_path = "temp_result.png"
    save_plot(output_path)
    return output_path


if __name__ == "__main__":
    predict('/data/lm/wirenet_1000/images/train/00031643_1.png')