123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233 |
- # import multiprocessing
- #
- #
- # def worker(num):
- # """线程调用的函数"""
- # print(f"Worker: {num}")
- # return
- #
- #
- # if __name__ == '__main__':
- # jobs = []
- # for i in range(5):
- # p = multiprocessing.Process(target=worker, args=(i,))
- # jobs.append(p)
- # p.start()
- #
- # import sys
- # print(sys.version)
- from typing import List
- from fastapi import FastAPI, File, UploadFile, HTTPException
- from fastapi.responses import FileResponse
- import os
- import torch
- import numpy as np
- from PIL import Image
- import skimage.io
- import skimage.color
- from torchvision import transforms
- import shutil
- import matplotlib.pyplot as plt
- from models.line_detect.line_net import linenet_resnet50_fpn
- from models.line_detect.postprocess import postprocess
- import time
- import multiprocessing as mp
- from fastapi.middleware.cors import CORSMiddleware
# Create the FastAPI application instance.
app = FastAPI()

# CORS policy: every origin, method and header is accepted and credentials
# are allowed.
_cors_settings = {
    "allow_origins": ["*"],  # accept requests from any origin
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)

# Force the 'spawn' start method for multiprocessing in this process.
mp.set_start_method('spawn', force=True)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
def load_model(model_path):
    """Build the line-detection network and load trained weights into it.

    The checkpoint path is checked *before* the network is constructed, so a
    missing file fails fast. It does not first build ``linenet_resnet50_fpn()``
    and move it to ``device`` for nothing.

    Args:
        model_path: Path to a checkpoint loadable by ``torch.load``. The loaded
            object must hold the weights under the ``'model_state_dict'`` key.

    Returns:
        The model on the module-level ``device``, switched to evaluation mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"No saved model found at {model_path}")
    model = linenet_resnet50_fpn().to(device)
    # map_location remaps stored tensors onto the device chosen at startup.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded model from {model_path}")
    model.eval()
    return model
def preprocess_image(image_path, size=(512, 512)):
    """Load an image file and turn it into a resized channel-first tensor.

    The image is converted to RGB and scaled to [0, 1] by ``ToTensor``. It is
    then resized with ``skimage.transform.resize`` and returned as CHW.

    Args:
        image_path: Path of the image on disk.
        size: Target ``(height, width)``. Defaults to 512x512, the input size
            the rest of this module assumes.

    Returns:
        A float32 tensor of shape ``(3, size[0], size[1])``.
    """
    # Only skimage.io / skimage.color are imported at module level, so
    # skimage.transform is imported explicitly to be reliably available.
    import skimage.transform

    # The context manager releases the file handle even if decoding fails.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    hwc = transforms.ToTensor()(rgb).permute(1, 2, 0).cpu().numpy().astype(np.float32)
    resized = skimage.transform.resize(hwc, size)
    # Older skimage releases may return float64 here; the model input is
    # kept float32 regardless of the installed version.
    return torch.tensor(np.asarray(resized, dtype=np.float32)).permute(2, 0, 1)
def save_plot(output_path: str):
    """Write the active matplotlib figure to *output_path*, then close it.

    The saved image is cropped tightly around the content with no padding.
    """
    tight_crop = {"bbox_inches": 'tight', "pad_inches": 0}
    plt.savefig(output_path, **tight_crop)
    print(f"Saved plot to {output_path}")
    plt.close()
# Fixed drawing palette (100 hex colours; some values repeat).
_COLOR_PALETTE = (
    '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
    '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
    '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
    '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
    '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
    '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
    '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
    '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
    '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
    '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
    '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
    '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
    '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
    '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
    '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
    '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
    '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
    '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
    '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
    '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072',
)


def get_colors():
    """Return the drawing palette as a new list of hex colour strings.

    A fresh list is built on every call, so callers may mutate the result
    without affecting later calls.
    """
    return list(_COLOR_PALETTE)
def show_line(im, lines, line_scores, diag, t_start):
    """Draw post-processed line segments over *im* and save the figure.

    The raw lines go through ``postprocess`` (third argument ``diag * 0.01``).
    Only segments whose score is not below 0.9 are drawn: both endpoints as
    small markers plus a thin red stroke. Endpoint index 1 is plotted as x
    and index 0 as y.

    Returns:
        The path of the saved PNG, "temp_result_line.png".
    """
    nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
    _, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    kept = ((p, q) for (p, q), score in zip(nlines, nscores) if not score < 0.9)
    for p, q in kept:
        xs = [p[1], q[1]]
        ys = [p[0], q[0]]
        ax.scatter(xs, ys, c='#871F78', s=2)
        ax.plot(xs, ys, c='red', linewidth=1)
    elapsed = time.time() - t_start
    print(f'show_line used: {elapsed:.2f} seconds')
    result_path = "temp_result_line.png"
    save_plot(result_path)
    return result_path
def show_box(im, boxes, box_scores, colors, t_start):
    """Draw detected bounding boxes over *im* and save the figure.

    Boxes scoring below 0.7 are skipped. Each remaining box is unpacked as
    (x0, y0, x1, y1) and outlined with ``colors[i % len(colors)]``, where ``i``
    is the box's position. The palette therefore wraps when there are more
    boxes than colours.

    Returns:
        The path of the saved PNG, "temp_result_box.png".
    """
    _, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    for position, (box, score) in enumerate(zip(boxes, box_scores)):
        if not score < 0.7:
            x0, y0, x1, y1 = box
            edge = colors[position % len(colors)]
            outline = plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=edge, linewidth=1)
            ax.add_patch(outline)
    elapsed = time.time() - t_start
    print(f'show_box used: {elapsed:.2f} seconds')
    result_path = "temp_result_box.png"
    save_plot(result_path)
    return result_path
def show_predict(im, boxes, box_scores, lines, line_scores, diag, colors, t_start):
    """Overlay bounding boxes and line segments on *im* and save the figure.

    The raw lines first go through ``postprocess`` (third argument
    ``diag * 0.01``). Boxes and post-processed lines are then walked in
    lock-step by index. The i-th box and i-th line are drawn together, in the
    same colour, only when the box scores above 0.7 and the line above 0.9.
    Colours wrap around the palette, as in ``show_box``, so pairs beyond
    ``len(colors)`` are still drawn.

    NOTE(review): the lock-step pairing assumes box i and line i describe the
    same object. ``postprocess`` may drop or reorder lines, so confirm this
    correspondence really holds for the model's outputs.

    Returns:
        The path of the saved PNG, "temp_result.png".
    """
    import itertools

    nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    # Cycling the palette matches show_box's colors[idx % len(colors)].
    # Zipping the plain list would stop after len(colors) pairs. An empty
    # palette still yields nothing, so nothing is drawn in that case.
    pairs = zip(boxes, nlines, box_scores, nscores, itertools.cycle(colors))
    for box, line, box_score, line_score, color in pairs:
        if box_score > 0.7 and line_score > 0.9:
            x0, y0, x1, y1 = box
            a, b = line
            ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
            # Endpoints: index 1 is plotted as x, index 0 as y.
            ax.scatter(a[1], a[0], c='#871F78', s=10)
            ax.scatter(b[1], b[0], c='#871F78', s=10)
            ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
    t_end = time.time()
    print(f'show_predict used: {t_end - t_start:.2f} seconds')
    output_path = "temp_result.png"
    save_plot(output_path)
    return output_path
@app.get("/")
def read_root():
    """Serve the front-end page from the static directory."""
    index_file = "static/index.html"
    return FileResponse(index_file)
# Checkpoint used by /predict. Set the LINE_MODEL_PATH environment variable to
# point elsewhere instead of editing the source on each machine.
# (Earlier Linux location:
#  /home/dieu/PycharmProjects/MultiVisionModels/logs/pth/resnet50_best_e100.pth)
MODEL_PATH = os.environ.get(
    "LINE_MODEL_PATH",
    r'D:\python\PycharmProjects\20250214\weight\merged_model_weights.pth',
)

# Models already loaded, keyed by checkpoint path. Weights are then read from
# disk once per process instead of on every request.
_loaded_models = {}


def _get_model(model_path):
    """Return the model for *model_path*, loading it only on first use."""
    model = _loaded_models.get(model_path)
    if model is None:
        model = load_model(model_path)
        _loaded_models[model_path] = model
    return model


def _save_upload(file):
    """Write an uploaded file into ``uploaded_images/`` and return its path.

    Only the final path component of the client-supplied filename is kept.
    Names such as ``../../x`` therefore cannot escape the upload directory.
    """
    os.makedirs("uploaded_images", exist_ok=True)
    # Treat backslashes as separators too: on POSIX, os.path.basename only
    # splits on '/'.
    name = os.path.basename((file.filename or "").replace("\\", "/"))
    if name in ("", ".", ".."):
        name = "upload.png"
    image_path = f"uploaded_images/{name}"
    with open(image_path, "wb") as f:
        shutil.copyfileobj(file.file, f)
    return image_path


@app.post("/predict")
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Run box and line detection on an uploaded image; return a PNG summary.

    The response is one image with three panels ("Box and Line", "Box",
    "Line") assembled by ``combine_images``.

    NOTE(review): all work here is synchronous, so it runs on (and blocks) the
    event loop. That also serialises requests, which the shared pyplot state
    and the fixed temp filenames currently rely on. Keep this in mind before
    turning the endpoint into a plain ``def`` (thread-pool) handler.

    Raises:
        HTTPException: 500 carrying the error message if any step fails.
    """
    try:
        start_time = time.time()
        image_path = _save_upload(file)
        model = _get_model(MODEL_PATH)

        img_tensor = preprocess_image(image_path)
        # HWC numpy copy of the input pixels, used as the plot background.
        im = img_tensor.permute(1, 2, 0).cpu().numpy()

        with torch.no_grad():
            predictions = model([img_tensor.to(device)])

        boxes = predictions[0]['boxes'].cpu().numpy()
        box_scores = predictions[0]['scores'].cpu().numpy()
        # Presumably the wire head predicts on a 128x128 grid and this rescales
        # to the 512x512 input; TODO confirm against the model.
        lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
        line_scores = predictions[-1]['wires']['score'].cpu().numpy()[0]
        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
        colors = get_colors()

        t_start = time.time()
        output_path_box = show_box(im, boxes, box_scores, colors, t_start)
        output_path_line = show_line(im, lines, line_scores, diag, t_start)
        output_path_boxandline = show_predict(im, boxes, box_scores, lines, line_scores, diag, colors, t_start)

        combined_image_path = "combined_result.png"
        combine_images(
            [output_path_boxandline, output_path_box, output_path_line],
            titles=["Box and Line", "Box", "Line"],
            output_path=combined_image_path,
        )
        end_time = time.time()
        print(f'Total time: {end_time - start_time:.2f} seconds')
        return FileResponse(combined_image_path, media_type="image/png", filename="combined_result.png")
    except Exception as e:
        # Chain the cause so server-side tracebacks keep the original error.
        raise HTTPException(status_code=500, detail=str(e)) from e
def combine_images(image_paths: List[str], titles: List[str], output_path: str):
    """Place saved images side by side in one figure and write it to disk.

    Args:
        image_paths: Image files shown left to right, one panel each.
        titles: Panel titles, matched to ``image_paths`` by position.
        output_path: Destination of the combined image.

    Raises:
        ValueError: If ``image_paths`` is empty.
    """
    count = len(image_paths)
    if count == 0:
        raise ValueError("combine_images needs at least one image path")
    # squeeze=False keeps ``axes`` a 2-D array even for a single panel. With
    # the default, plt.subplots(1, 1) returns a bare Axes that cannot be
    # iterated. Width is 5 inches per panel, i.e. (15, 5) for three images.
    fig, axes = plt.subplots(1, count, figsize=(5 * count, 5), squeeze=False)
    for ax, img_path, title in zip(axes[0], image_paths, titles):
        ax.imshow(plt.imread(img_path))
        ax.set_title(title)
        ax.axis("off")
    fig.savefig(output_path, bbox_inches="tight", pad_inches=0)
    # Close this specific figure so repeated requests do not accumulate figures.
    plt.close(fig)
if __name__ == "__main__":
    import uvicorn

    # HOST / PORT environment variables override the defaults below.
    # NOTE(review): 808 is below 1024, so binding it needs elevated privileges
    # on Linux/macOS. It may be a typo for 8080; confirm before changing.
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", "808"))
    uvicorn.run(app, host=host, port=port)
|