"""FastAPI service that runs the line/box detection network on an uploaded
image and returns one PNG that combines three visualisations: boxes + lines,
boxes only, and lines only."""
from typing import List
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import FileResponse
import os
import torch
import numpy as np
from PIL import Image
import skimage.io
import skimage.color
import skimage.transform  # needed by preprocess_image(); was previously missing
from torchvision import transforms
import shutil
import matplotlib.pyplot as plt
from models.line_detect.line_net import linenet_resnet50_fpn
from models.line_detect.postprocess import postprocess
import time
import multiprocessing as mp
from functools import lru_cache
from fastapi.middleware.cors import CORSMiddleware

# Default checkpoint location. Set the LINE_MODEL_PATH environment variable to
# deploy on another machine without editing the source.
DEFAULT_MODEL_PATH = r'D:\python\PycharmProjects\20250214\weight\merged_model_weights.pth'
MODEL_PATH = os.environ.get("LINE_MODEL_PATH", DEFAULT_MODEL_PATH)

# Directory where uploaded images are stored before inference.
UPLOAD_DIR = "uploaded_images"

# Initialise FastAPI.
app = FastAPI()

# Add CORS middleware.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; consider restricting allow_origins for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # allow every origin
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Use 'spawn' as the multiprocessing start method. Presumably this is so that
# child processes can use CUDA safely; nothing in this file currently starts
# a process.
mp.set_start_method('spawn', force=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def load_model(model_path):
    """Build the line-detection network and load its weights from a checkpoint.

    Args:
        model_path: Path to a checkpoint dict that contains a
            'model_state_dict' entry.

    Returns:
        The network, moved to ``device`` and switched to eval mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    # Check for the file first, so a missing checkpoint fails before the
    # (comparatively expensive) network construction.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"No saved model found at {model_path}")
    model = linenet_resnet50_fpn().to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded model from {model_path}")
    model.eval()
    return model


@lru_cache(maxsize=4)
def _get_model(model_path):
    """Return the model for ``model_path``, loading it only on first use.

    A load that raises is not cached, so a later request retries it. The
    checkpoint is not re-read if it changes on disk; restart the service to
    pick up new weights.
    """
    return load_model(model_path)


def _safe_filename(filename):
    """Reduce a client-supplied upload name to a bare, non-empty file name."""
    # Strip both separator styles regardless of host OS, so names such as
    # "../x.png" or "..\\x.png" cannot escape the upload directory.
    name = (filename or "").replace("\\", "/").rsplit("/", 1)[-1]
    # On Windows this also removes a drive prefix such as "C:".
    name = os.path.basename(name)
    if name in ("", ".", ".."):
        return "upload"
    return name


def preprocess_image(image_path):
    """Load an image file and resize it to the 512x512 network input.

    Returns:
        A tensor of shape (3, 512, 512). Channels come from the RGB
        conversion, and pixel values are scaled by ``ToTensor``.
    """
    # Close the file handle as soon as the RGB copy exists.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    img_tensor = transforms.ToTensor()(img)
    # skimage expects channels-last (H, W, C), so permute before resizing.
    hwc = img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32)
    resized_img = skimage.transform.resize(hwc, (512, 512))
    # Return to channels-first (C, H, W) for the network.
    return torch.tensor(resized_img).permute(2, 0, 1)


def save_plot(output_path: str):
    """Save the current pyplot figure to ``output_path`` and close it."""
    plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
    print(f"Saved plot to {output_path}")
    plt.close()


def get_colors():
    """Return a fresh list of 100 predefined hex colours used to tint boxes.

    Some entries are very light (e.g. '#ffffff', '#f7f7f7') and may be hard to
    see on bright images.
    """
    return [
        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c', '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
        '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f', '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3', '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3', '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173', '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
        '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6', '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4', '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
        '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d', '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
        '#bfbfbf', '#969696', '#737373', '#525252', '#252525', '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026', '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072',
    ]


def show_line(im, lines, line_scores, diag, t_start):
    """Draw detected line segments over ``im`` and save the figure.

    Segments scoring below 0.9 after ``postprocess`` are skipped. Endpoints
    are treated as (row, col) pairs, so column is plotted as x and row as y.

    Returns:
        The path of the saved PNG.
    """
    # NOTE(review): the meaning of postprocess's arguments (diag * 0.01, 0,
    # False) is defined in models.line_detect.postprocess - confirm there.
    nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
    _, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    for (a, b), s in zip(nlines, nscores):
        if s < 0.9:
            continue
        ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
        ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
    # Elapsed time is measured from the caller-supplied t_start, so it is
    # cumulative when the show_* helpers share one start time.
    t_end = time.time()
    print(f'show_line used: {t_end - t_start:.2f} seconds')
    output_path = "temp_result_line.png"
    save_plot(output_path)
    return output_path


def show_box(im, boxes, box_scores, colors, t_start):
    """Draw detected bounding boxes (score >= 0.7) over ``im`` and save it.

    Boxes are unpacked as (x0, y0, x1, y1). Colours are cycled by detection
    index.

    Returns:
        The path of the saved PNG.
    """
    _, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    for idx, (box, score) in enumerate(zip(boxes, box_scores)):
        if score < 0.7:
            continue
        x0, y0, x1, y1 = box
        ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False,
                                   edgecolor=colors[idx % len(colors)], linewidth=1))
    t_end = time.time()
    print(f'show_box used: {t_end - t_start:.2f} seconds')
    output_path = "temp_result_box.png"
    save_plot(output_path)
    return output_path


def show_predict(im, boxes, box_scores, lines, line_scores, diag, colors, t_start):
    """Draw boxes together with line segments and save the figure.

    A pair is drawn only when box score > 0.7 and line score > 0.9. Colours
    are cycled by index (as in ``show_box``), so detections beyond
    ``len(colors)`` are no longer dropped.

    Returns:
        The path of the saved PNG.
    """
    nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
    _, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    # NOTE(review): boxes and post-processed lines are paired purely by index;
    # confirm the two outputs are aligned, since postprocess may reorder or
    # drop lines.
    for idx, (box, line, box_score, line_score) in enumerate(
            zip(boxes, nlines, box_scores, nscores)):
        if not (box_score > 0.7 and line_score > 0.9):
            continue
        color = colors[idx % len(colors)]
        x0, y0, x1, y1 = box
        a, b = line
        ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False,
                                   edgecolor=color, linewidth=1))
        ax.scatter(a[1], a[0], c='#871F78', s=10)
        ax.scatter(b[1], b[0], c='#871F78', s=10)
        ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
    t_end = time.time()
    print(f'show_predict used: {t_end - t_start:.2f} seconds')
    output_path = "temp_result.png"
    save_plot(output_path)
    return output_path


def combine_images(image_paths: List[str], titles: List[str], output_path: str):
    """Place the given images side by side, one titled panel each, and save.

    Extra entries in the longer of ``image_paths`` / ``titles`` are ignored.
    """
    # squeeze=False always yields a 2-D array of Axes, so a single image works
    # too (otherwise subplots returns a bare, non-iterable Axes).
    fig, axes = plt.subplots(1, len(image_paths), figsize=(15, 5), squeeze=False)
    for ax, img_path, title in zip(axes[0], image_paths, titles):
        ax.imshow(plt.imread(img_path))
        ax.set_title(title)
        ax.axis("off")
    plt.savefig(output_path, bbox_inches="tight", pad_inches=0)
    plt.close(fig)


@app.get("/")
def read_root():
    """Serve the front-end page."""
    return FileResponse("static/index.html")


@app.post("/predict")
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Run detection on an uploaded image and return the combined PNG.

    Raises:
        HTTPException: Status 500, with the error message, if any step fails.
    """
    # NOTE(review): this handler is `async` but does blocking model and
    # plotting work, so it blocks the event loop while it runs. The fixed
    # temp/result file names below are shared by all requests, so overlapping
    # requests could overwrite each other's output.
    try:
        start_time = time.time()

        # Save the upload under a sanitized name inside UPLOAD_DIR.
        os.makedirs(UPLOAD_DIR, exist_ok=True)
        image_path = os.path.join(UPLOAD_DIR, _safe_filename(file.filename))
        with open(image_path, "wb") as f:
            shutil.copyfileobj(file.file, f)

        # Load the model once and reuse it across requests.
        model = _get_model(MODEL_PATH)

        # Preprocess the image; keep an (H, W, C) numpy copy for plotting.
        img_tensor = preprocess_image(image_path)
        im = img_tensor.permute(1, 2, 0).cpu().numpy()

        # Run inference.
        with torch.no_grad():
            predictions = model([img_tensor.to(device)])

        # Extract predictions.
        boxes = predictions[0]['boxes'].cpu().numpy()
        box_scores = predictions[0]['scores'].cpu().numpy()
        # Presumably the line head predicts on a 128-px grid; rescale its
        # coordinates to the 512-px input - confirm against the model.
        lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
        line_scores = predictions[-1]['wires']['score'].cpu().numpy()[0]

        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
        colors = get_colors()

        # Render the three visualisations.
        t_start = time.time()
        output_path_box = show_box(im, boxes, box_scores, colors, t_start)
        output_path_line = show_line(im, lines, line_scores, diag, t_start)
        output_path_boxandline = show_predict(im, boxes, box_scores, lines, line_scores, diag, colors, t_start)

        # Combine them into a single image.
        combined_image_path = "combined_result.png"
        combine_images(
            [output_path_boxandline, output_path_box, output_path_line],
            titles=["Box and Line", "Box", "Line"],
            output_path=combined_image_path,
        )

        end_time = time.time()
        print(f'Total time: {end_time - start_time:.2f} seconds')

        # Return the combined image.
        return FileResponse(combined_image_path, media_type="image/png", filename="combined_result.png")
    except Exception as e:
        # Top-level boundary: report any failure as a 500 and keep the cause.
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    import uvicorn

    # NOTE(review): 808 is a privileged port (< 1024) on Unix-like systems;
    # confirm it is not a typo for 8080/8008.
    uvicorn.run(app, host="0.0.0.0", port=808)