aaa.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. # import multiprocessing
  2. #
  3. #
  4. # def worker(num):
  5. # """线程调用的函数"""
  6. # print(f"Worker: {num}")
  7. # return
  8. #
  9. #
  10. # if __name__ == '__main__':
  11. # jobs = []
  12. # for i in range(5):
  13. # p = multiprocessing.Process(target=worker, args=(i,))
  14. # jobs.append(p)
  15. # p.start()
  16. #
  17. # import sys
  18. # print(sys.version)
  19. from typing import List
  20. from fastapi import FastAPI, File, UploadFile, HTTPException
  21. from fastapi.responses import FileResponse
  22. import os
  23. import torch
  24. import numpy as np
  25. from PIL import Image
  26. import skimage.io
  27. import skimage.color
  28. from torchvision import transforms
  29. import shutil
  30. import matplotlib.pyplot as plt
  31. from models.line_detect.line_net import linenet_resnet50_fpn
  32. from models.line_detect.postprocess import postprocess
  33. import time
  34. import multiprocessing as mp
  35. from fastapi.middleware.cors import CORSMiddleware
  36. # 初始化 FastAPI
  37. app = FastAPI()
  38. # 添加 CORS 中间件
  39. app.add_middleware(
  40. CORSMiddleware,
  41. allow_origins=["*"], # 允许所有源
  42. allow_credentials=True,
  43. allow_methods=["*"],
  44. allow_headers=["*"],
  45. )
  46. # 设置多进程启动方式为 'spawn'
  47. mp.set_start_method('spawn', force=True)
  48. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  49. def load_model(model_path):
  50. """加载模型并返回模型实例"""
  51. model = linenet_resnet50_fpn().to(device)
  52. if os.path.exists(model_path):
  53. checkpoint = torch.load(model_path, map_location=device)
  54. model.load_state_dict(checkpoint['model_state_dict'])
  55. print(f"Loaded model from {model_path}")
  56. else:
  57. raise FileNotFoundError(f"No saved model found at {model_path}")
  58. model.eval()
  59. return model
  60. def preprocess_image(image_path):
  61. """预处理上传的图片"""
  62. img = Image.open(image_path).convert("RGB")
  63. transform = transforms.ToTensor()
  64. img_tensor = transform(img)
  65. resized_img = skimage.transform.resize(
  66. img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32), (512, 512)
  67. )
  68. return torch.tensor(resized_img).permute(2, 0, 1)
  69. def save_plot(output_path: str):
  70. """保存图像并关闭绘图"""
  71. plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
  72. print(f"Saved plot to {output_path}")
  73. plt.close()
  74. def get_colors():
  75. """返回一组预定义的颜色列表"""
  76. return [
  77. '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
  78. '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
  79. '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
  80. '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
  81. '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
  82. '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
  83. '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
  84. '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
  85. '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
  86. '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
  87. '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
  88. '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
  89. '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
  90. '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
  91. '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
  92. '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
  93. '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
  94. '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
  95. '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
  96. '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
  97. ]
  98. def show_line(im, lines, line_scores, diag, t_start):
  99. """绘制线段并保存结果"""
  100. nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
  101. fig, ax = plt.subplots(figsize=(10, 10))
  102. ax.imshow(im)
  103. for (a, b), s in zip(nlines, nscores):
  104. if s < 0.9:
  105. continue
  106. ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
  107. ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
  108. t_end = time.time()
  109. print(f'show_line used: {t_end - t_start:.2f} seconds')
  110. output_path = "temp_result_line.png"
  111. save_plot(output_path)
  112. return output_path
  113. def show_box(im, boxes, box_scores, colors, t_start):
  114. """绘制边界框并保存结果"""
  115. fig, ax = plt.subplots(figsize=(10, 10))
  116. ax.imshow(im)
  117. for idx, (box, score) in enumerate(zip(boxes, box_scores)):
  118. if score < 0.7:
  119. continue
  120. x0, y0, x1, y1 = box
  121. ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=colors[idx % len(colors)], linewidth=1))
  122. t_end = time.time()
  123. print(f'show_box used: {t_end - t_start:.2f} seconds')
  124. output_path = "temp_result_box.png"
  125. save_plot(output_path)
  126. return output_path
  127. def show_predict(im, boxes, box_scores, lines, line_scores, diag, colors, t_start):
  128. """绘制边界框和线段并保存结果"""
  129. nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
  130. fig, ax = plt.subplots(figsize=(10, 10))
  131. ax.imshow(im)
  132. for box, line, box_score, line_score, color in zip(boxes, nlines, box_scores, nscores, colors):
  133. if box_score > 0.7 and line_score > 0.9:
  134. x0, y0, x1, y1 = box
  135. a, b = line
  136. ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
  137. ax.scatter(a[1], a[0], c='#871F78', s=10)
  138. ax.scatter(b[1], b[0], c='#871F78', s=10)
  139. ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
  140. t_end = time.time()
  141. print(f'show_predict used: {t_end - t_start:.2f} seconds')
  142. output_path = "temp_result.png"
  143. save_plot(output_path)
  144. return output_path
  145. @app.get("/")
  146. def read_root():
  147. """返回前端页面"""
  148. return FileResponse("static/index.html")
  149. @app.post("/predict")
  150. @app.post("/predict/")
  151. async def predict(file: UploadFile = File(...)):
  152. try:
  153. start_time = time.time()
  154. # 保存上传文件
  155. os.makedirs("uploaded_images", exist_ok=True)
  156. image_path = f"uploaded_images/{file.filename}"
  157. with open(image_path, "wb") as f:
  158. shutil.copyfileobj(file.file, f)
  159. # 加载模型
  160. # model_path = "/home/dieu/PycharmProjects/MultiVisionModels/logs/pth/resnet50_best_e100.pth"
  161. model_path = r'D:\python\PycharmProjects\20250214\weight\merged_model_weights.pth'
  162. model = load_model(model_path)
  163. # 预处理图片
  164. img_tensor = preprocess_image(image_path)
  165. im = img_tensor.permute(1, 2, 0).cpu().numpy()
  166. # 模型推理
  167. with torch.no_grad():
  168. predictions = model([img_tensor.to(device)])
  169. # 提取预测结果
  170. boxes = predictions[0]['boxes'].cpu().numpy()
  171. box_scores = predictions[0]['scores'].cpu().numpy()
  172. lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  173. line_scores = predictions[-1]['wires']['score'].cpu().numpy()[0]
  174. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  175. colors = get_colors()
  176. # 绘制图像
  177. t_start = time.time()
  178. output_path_box = show_box(im, boxes, box_scores, colors, t_start)
  179. output_path_line = show_line(im, lines, line_scores, diag, t_start)
  180. output_path_boxandline = show_predict(im, boxes, box_scores, lines, line_scores, diag, colors, t_start)
  181. # 合并图像
  182. combined_image_path = "combined_result.png"
  183. combine_images(
  184. [output_path_boxandline, output_path_box, output_path_line],
  185. titles=["Box and Line", "Box", "Line"],
  186. output_path=combined_image_path
  187. )
  188. end_time = time.time()
  189. print(f'Total time: {end_time - start_time:.2f} seconds')
  190. # 返回合成的图片
  191. return FileResponse(combined_image_path, media_type="image/png", filename="combined_result.png")
  192. except Exception as e:
  193. raise HTTPException(status_code=500, detail=str(e))
  194. def combine_images(image_paths: List[str], titles: List[str], output_path: str):
  195. """将多个图像合并为一张图片"""
  196. fig, axes = plt.subplots(1, len(image_paths), figsize=(15, 5))
  197. for ax, img_path, title in zip(axes, image_paths, titles):
  198. ax.imshow(plt.imread(img_path))
  199. ax.set_title(title)
  200. ax.axis("off")
  201. plt.savefig(output_path, bbox_inches="tight", pad_inches=0)
  202. plt.close()
  203. if __name__ == "__main__":
  204. import uvicorn
  205. uvicorn.run(app, host="0.0.0.0", port=808)