# from fastapi import FastAPI, File, UploadFile, HTTPException
# from fastapi.responses import FileResponse
# from fastapi.staticfiles import StaticFiles
import os
import torch
import numpy as np
from PIL import Image
import skimage.io
import skimage.color
import skimage.transform  # preprocess_image uses skimage.transform.resize; import it explicitly
from torchvision import transforms
import shutil
import matplotlib.pyplot as plt
from models.line_detect.line_net import linenet_resnet50_fpn
from models.wirenet.postprocess import postprocess
from rtree import index
import time
import multiprocessing as mp
# from fastapi.middleware.cors import CORSMiddleware

# Initialise FastAPI
# app = FastAPI()

# Add CORS middleware
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["*"],  # allow every origin
#     allow_credentials=True,
#     allow_methods=["*"],
#     allow_headers=["*"],
# )

# Use the 'spawn' start method for the worker pool in
# box_line_optimized_parallel (PyTorch's docs note CUDA does not support 'fork').
mp.set_start_method('spawn', force=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Checkpoint loaded by predict() when no explicit model_path is given.
DEFAULT_MODEL_PATH = r'D:\python\PycharmProjects\20250214\weight\linenet_wts\r50fpn_wts_e350\best.pth'


def _to_numpy(x):
    """Return `x` as a numpy array, moving torch tensors to host memory first."""
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def load_model(model_path):
    """Build linenet_resnet50_fpn on `device`, load its weights and return it in eval mode.

    Args:
        model_path: path to a checkpoint dict containing 'model_state_dict'.

    Raises:
        FileNotFoundError: if `model_path` does not exist.
    """
    # Check first so a missing checkpoint fails before the network is constructed.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"No saved model found at {model_path}")
    model = linenet_resnet50_fpn().to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded model from {model_path}")
    model.eval()
    return model


def preprocess_image(image_path):
    """Load an image and resize it to the 512x512 network input.

    Returns:
        (tensor, img): a CHW tensor of shape (3, 512, 512) whose values were
        scaled to [0, 1] by ToTensor, and the original full-resolution RGB
        PIL image.
    """
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    img_tensor = transforms.ToTensor()(img)
    resized_img = skimage.transform.resize(
        img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32), (512, 512)
    )
    return torch.tensor(resized_img).permute(2, 0, 1), img


def save_plot(output_path: str):
    """Save the current matplotlib figure to `output_path` and close it."""
    plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
    print(f"Saved plot to {output_path}")
    plt.close()


def get_colors():
    """Return a fixed list of hex colour strings used to colour boxes/lines."""
    return [
        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c', '#98df8a', '#d62728', '#ff9896',
        '#9467bd', '#c5b0d5', '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f', '#c7c7c7',
        '#bcbd22', '#dbdb8d', '#17becf', '#9edae5', '#8dd3c7', '#ffffb3', '#bebada', '#fb8072',
        '#80b1d3', '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5', '#ffed6f', '#8da0cb',
        '#e78ac3', '#e5c494', '#b3b3b3', '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173', '#ad494a', '#843c39', '#dd8452',
        '#f7f7f7', '#cccccc', '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6', '#99d8c9',
        '#66c2a4', '#2ca25f', '#008d4c', '#005a32', '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5',
        '#7bccc4', '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4', '#bfd3e6', '#9ebcda',
        '#8c96c6', '#8c6bb4', '#88419d', '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
        '#bfbfbf', '#969696', '#737373', '#525252', '#252525', '#000000', '#ffffff', '#ffeda0',
        '#fed976', '#feb24c', '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026', '#ff6f61',
        '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
    ]


def process_box(box, lines, scores):
    """For every box, find the longest line segment lying entirely inside it.

    Args:
        box: iterable of boxes, each indexed as (x0, y0, x1, y1): a segment
            endpoint's [1] coordinate is tested against box[0]/box[2] and its
            [0] coordinate against box[1]/box[3].
        lines: tensor or array indexed as lines[0, j] -> (2, 2) endpoints.
        scores: unused; kept for interface compatibility.

    Returns:
        (valid_lines, valid_scores), one entry per box. valid_lines[k] is the
        longest fully-contained segment (the first one on ties), or
        [[0.0, 0.0], [0.0, 0.0]] if none has positive length;
        valid_scores[k] is that segment's Euclidean length (0.0 if none).
        Note the "score" is a length, not a model confidence.
    """
    segs = _to_numpy(lines)[0]  # all segments of the first batch entry
    if segs.shape[0] == 0:
        return ([[[0.0, 0.0], [0.0, 0.0]] for _ in box],
                [0.0 for _ in box])

    first_coord = segs[:, :, 0]   # compared against box[1] / box[3]
    second_coord = segs[:, :, 1]  # compared against box[0] / box[2]
    lengths = np.linalg.norm(segs[:, 0] - segs[:, 1], axis=-1)

    valid_lines = []
    valid_scores = []
    for b in box:
        inside = (
            np.all((second_coord >= b[0]) & (second_coord <= b[2]), axis=1)
            & np.all((first_coord >= b[1]) & (first_coord <= b[3]), axis=1)
        )
        # Outside segments get 0 so they can never win; argmax returns the
        # first maximum, matching a strict ">" scan that starts at 0.0.
        candidates = np.where(inside, lengths, 0.0)
        best = int(np.argmax(candidates))
        if candidates[best] > 0.0:
            valid_lines.append(segs[best])
            valid_scores.append(lengths[best])
        else:
            valid_lines.append([[0.0, 0.0], [0.0, 0.0]])
            valid_scores.append(0.0)
    return valid_lines, valid_scores


def box_line_optimized_parallel(pred):
    """Attach the best-matching line segment to every box set in `pred`.

    `pred[:-1]` are dicts with a 'boxes' tensor; `pred[-1]['wires']` holds the
    'lines' and 'score' outputs. Each box dict that has at least one box gets
    'line' (K, 2, 2) and 'line_score' (K,) float32 tensors added in place.

    Returns:
        The list of box dicts that received matches.
    """
    # NOTE(review): the raw 'lines' output is compared directly against the
    # box coordinates here, while predict()/show_line() rescale the same output
    # by 512 / 128 before drawing it on the 512x512 image — confirm which scale
    # the model emits.
    # Convert to host numpy once. This avoids a GPU sync per segment and avoids
    # pickling CUDA tensors into the worker processes.
    lines = _to_numpy(pred[-1]['wires']['lines'])   # assumed [1, 2500, 2, 2]
    scores = _to_numpy(pred[-1]['wires']['score'][0])
    boxes = [box_['boxes'].cpu().numpy() for box_ in pred[0:-1]]

    if not boxes:
        return []  # mp.Pool(processes=0) would raise
    if len(boxes) == 1:
        # Spawning a process costs far more than matching a single box set.
        results = [process_box(boxes[0], lines, scores)]
    else:
        num_processes = min(mp.cpu_count(), len(boxes))
        with mp.Pool(processes=num_processes) as pool:
            results = pool.starmap(
                process_box, [(box, lines, scores) for box in boxes]
            )

    filtered_pred = []
    for idx_box, (valid_lines, valid_scores) in enumerate(results):
        if valid_lines:
            pred[idx_box]['line'] = torch.tensor(np.asarray(valid_lines, dtype=np.float32))
            pred[idx_box]['line_score'] = torch.tensor(np.asarray(valid_scores, dtype=np.float32))
            filtered_pred.append(pred[idx_box])
    return filtered_pred


def predict(image_path, model_path=DEFAULT_MODEL_PATH):
    """Detect boxes and line segments in one image, save plots, and return the lines.

    Side effects: writes temp_result_box.png, temp_result_line.png,
    temp_result.png and combined_result.png to the working directory. It also
    stores post-processed wires as predictions[-1]['wires']['lines_'/'score_'].

    Args:
        image_path: path of the image to process.
        model_path: checkpoint to load (defaults to DEFAULT_MODEL_PATH).

    Returns:
        A one-element list ``[lines]``, where ``lines`` is an (N, 2, 2) float
        array of segment endpoints. The raw wire output is divided by 128 and
        rescaled to the original image's (height, width).
    """
    start_time = time.time()

    # Upload handling for the (disabled) FastAPI endpoint:
    # os.makedirs("uploaded_images", exist_ok=True)
    # image_path = f"{file.filename}"
    # with open(image_path, "wb") as f:
    #     shutil.copyfileobj(file.file, f)

    model = load_model(model_path)

    img_tensor, img = preprocess_image(image_path)
    im = img_tensor.permute(1, 2, 0).cpu().numpy()  # HWC array for plotting

    with torch.no_grad():
        predictions = model([img_tensor.to(device)])

    # Post-process the wires at 512-pixel scale. Use .to(device) rather than
    # .cuda() so CPU-only machines work.
    lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
    line_scores = predictions[-1]['wires']['score'].cpu().numpy()[0]
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
    predictions[-1]['wires']['lines_'] = torch.from_numpy(nlines).float().to(device)
    predictions[-1]['wires']['score_'] = torch.from_numpy(nscores).float().to(device)
    print(predictions)

    # Match line segments to bounding boxes.
    t_start = time.time()
    filtered_pred = box_line_optimized_parallel(predictions)
    t_end = time.time()
    print(f'Matched boxes and lines used: {t_end - t_start:.2f} seconds')

    # Draw the individual visualisations.
    output_path_box = show_box(im, predictions, t_start)
    output_path_line = show_line(im, predictions, t_start)
    output_path_boxandline = show_predict(im, filtered_pred, t_start)

    # Combine them side by side.
    combined_image_path = "combined_result.png"
    combine_images(
        [output_path_boxandline, output_path_box, output_path_line],
        titles=["Box and Line", "Box", "Line"],
        output_path=combined_image_path,
    )

    end_time = time.time()
    print(f'Total time: {end_time - start_time:.2f} seconds')

    raw_lines = predictions[-1]['wires']['lines'][0].cpu().numpy()
    print(f"Initial lines shape: {raw_lines.shape}")
    print(f"Initial lines data type: {raw_lines.dtype}")

    # Normalise to (N, 2, 2) *before* scaling so the per-coordinate scale
    # broadcasts correctly.
    if raw_lines.ndim == 2 and raw_lines.shape[1] == 4:
        raw_lines = raw_lines.reshape(-1, 2, 2)
    elif raw_lines.ndim != 3:
        print(f"Warning: Unexpected lines shape: {raw_lines.shape}")
    print(f"After reshape - lines shape: {raw_lines.shape}")

    # Endpoints are plotted as (x=p[1], y=p[0]), so p[0] scales with the image
    # height and p[1] with its width. This was previously hard-coded as
    # [1328, 2112], which only fit one image size.
    scale = np.array([img.height, img.width])
    lines = raw_lines / 128 * scale

    # Keep exactly two [coord0, coord1] endpoints per segment.
    formatted_lines = np.array(lines[:, :2, :2])
    print(f"Final formatted_lines shape: {formatted_lines.shape}")
    print(f"Sample formatted line: {formatted_lines[0] if len(formatted_lines) > 0 else 'No lines'}")

    result = [formatted_lines]
    print(f"Final result type: {type(result)}")
    print(f"Final result[0] shape: {result[0].shape}")
    return result


def combine_images(image_paths: list, titles: list, output_path: str):
    """Place the images at `image_paths` side by side with `titles` and save to `output_path`."""
    # squeeze=False keeps `axes` 2-D even for a single image, so it is always iterable.
    fig, axes = plt.subplots(1, len(image_paths), figsize=(15, 5), squeeze=False)
    for ax, img_path, title in zip(axes[0], image_paths, titles):
        ax.imshow(plt.imread(img_path))
        ax.set_title(title)
        ax.axis("off")
    plt.savefig(output_path, bbox_inches="tight", pad_inches=0)
    plt.close()


def show_box(im, predictions, t_start):
    """Draw predictions[0]'s boxes scoring >= 0.7 on `im`; return the saved image path."""
    boxes = predictions[0]['boxes'].cpu().numpy()
    box_scores = predictions[0]['scores'].cpu().numpy()
    colors = get_colors()
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    for idx, (box, score) in enumerate(zip(boxes, box_scores)):
        if score < 0.7:
            continue
        x0, y0, x1, y1 = box
        ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False,
                                   edgecolor=colors[idx % len(colors)], linewidth=1))
    t_end = time.time()
    print(f'show_box used: {t_end - t_start:.2f} seconds')
    output_path = "temp_result_box.png"
    save_plot(output_path)
    return output_path


def show_line(im, predictions, t_start):
    """Draw post-processed wire segments scoring >= 0.9 on `im`; return the saved image path."""
    lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
    line_scores = predictions[-1]['wires']['score'].cpu().numpy()[0]
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    for (a, b), s in zip(nlines, nscores):
        if s < 0.9:
            continue
        # Endpoints are (row, col): plot x = p[1], y = p[0].
        ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
        ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
    t_end = time.time()
    print(f'show_line used: {t_end - t_start:.2f} seconds')
    output_path = "temp_result_line.png"
    save_plot(output_path)
    return output_path


def show_predict(im, filtered_pred, t_start):
    """Draw each box together with its matched line segment; return the saved image path."""
    colors = get_colors()
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5  # loop-invariant
    for idx, pred in enumerate(filtered_pred):
        boxes = pred['boxes'].cpu().numpy()
        box_scores = pred['scores'].cpu().numpy()
        lines = pred['line'].cpu().numpy()
        # line_score is the segment length assigned by process_box, not a confidence.
        line_scores = pred['line_score'].cpu().numpy()
        nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
        # NOTE(review): if postprocess drops or reorders segments, nlines/nscores
        # no longer line up with `boxes` in the zip below — confirm it preserves
        # order and length for this input.
        for box_idx, (box, line, box_score, line_score) in enumerate(
                zip(boxes, nlines, box_scores, nscores)):
            if box_score < 0.7 or line_score < 0.9:
                continue
            # Skip drawing when no valid segment was found.
            if line is None or len(line) == 0:
                continue
            x0, y0, x1, y1 = box
            a, b = line
            color = colors[(idx + box_idx) % len(colors)]  # a distinct colour per box
            ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False,
                                       edgecolor=color, linewidth=1))
            ax.scatter(a[1], a[0], c=color, s=10)
            ax.scatter(b[1], b[0], c=color, s=10)
            ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
    t_end = time.time()
    print(f'show_predict used: {t_end - t_start:.2f} seconds')
    output_path = "temp_result.png"
    save_plot(output_path)
    return output_path


if __name__ == "__main__":
    lines = predict(r'C:\Users\m2337\Desktop\p\9.jpg')
    print(f'lines:{lines}')