predict_0221.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. # zjf的是将线段长度作为得分,这里是将最大置信度作为得分
  2. # from fastapi import FastAPI, File, UploadFile, HTTPException
  3. # from fastapi.responses import FileResponse
  4. # from fastapi.staticfiles import StaticFiles
  5. import os
  6. import torch
  7. import numpy as np
  8. from PIL import Image
  9. import skimage.io
  10. import skimage.color
  11. from torchvision import transforms
  12. import shutil
  13. import matplotlib.pyplot as plt
  14. from models.line_detect.line_net import linenet_resnet50_fpn
  15. from models.line_detect.postprocess import postprocess
  16. from rtree import index
  17. import time
  18. import multiprocessing as mp
  19. # from fastapi.middleware.cors import CORSMiddleware
  20. # 初始化 FastAPI
  21. # app = FastAPI()
  22. # 添加 CORS 中间件
  23. # app.add_middleware(
  24. # CORSMiddleware,
  25. # allow_origins=["*"], # 允许所有源
  26. # allow_credentials=True,
  27. # allow_methods=["*"],
  28. # allow_headers=["*"],
  29. # )
  30. # 设置多进程启动方式为 'spawn'
  31. mp.set_start_method('spawn', force=True)
  32. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  33. def load_model(model_path):
  34. """加载模型并返回模型实例"""
  35. model = linenet_resnet50_fpn().to(device)
  36. if os.path.exists(model_path):
  37. checkpoint = torch.load(model_path, map_location=device)
  38. model.load_state_dict(checkpoint['model_state_dict'])
  39. print(f"Loaded model from {model_path}")
  40. else:
  41. raise FileNotFoundError(f"No saved model found at {model_path}")
  42. model.eval()
  43. return model
  44. def preprocess_image(image_path):
  45. """预处理上传的图片"""
  46. img = Image.open(image_path).convert("RGB")
  47. transform = transforms.ToTensor()
  48. img_tensor = transform(img)
  49. resized_img = skimage.transform.resize(
  50. img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32), (512, 512)
  51. )
  52. return torch.tensor(resized_img).permute(2, 0, 1)
  53. def save_plot(output_path: str):
  54. """保存图像并关闭绘图"""
  55. plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
  56. print(f"Saved plot to {output_path}")
  57. plt.close()
  58. def get_colors():
  59. """返回一组预定义的颜色列表"""
  60. return [
  61. '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
  62. '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
  63. '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
  64. '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
  65. '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
  66. '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
  67. '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
  68. '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
  69. '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
  70. '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
  71. '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
  72. '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
  73. '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
  74. '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
  75. '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
  76. '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
  77. '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
  78. '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
  79. '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
  80. '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
  81. ]
  82. def process_box(box, lines, scores):
  83. """处理单个边界框,找到最佳匹配的线段"""
  84. valid_lines = [] # 存储有效的线段
  85. valid_scores = [] # 存储有效的分数
  86. # print(f'score:{len(scores)}')
  87. for i in box:
  88. best_line = None
  89. max_length = 0.0
  90. # 遍历所有线段
  91. for j in range(lines.shape[1]):
  92. line_j = lines[0, j].cpu().numpy() / 128 * 512
  93. # 检查线段是否完全在box内
  94. if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and
  95. line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
  96. line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
  97. line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
  98. # length = np.linalg.norm(line_j[0] - line_j[1])
  99. length = scores[j].item()
  100. # print(length)
  101. if length > max_length:
  102. best_line = line_j
  103. max_length = length
  104. if best_line is not None:
  105. valid_lines.append(best_line)
  106. valid_scores.append(max_length)
  107. else:
  108. valid_lines.append([[0.0,0.0],[0.0,0.0]])
  109. valid_scores.append(0.0) # 使用线段置信度作为分数
  110. # print(f'valid_lines:{valid_lines}')
  111. # print(f'valid_scores:{valid_scores}')
  112. return valid_lines, valid_scores
  113. def box_line_optimized_parallel(pred):
  114. """并行处理边界框和线段的匹配"""
  115. lines = pred[-1]['wires']['lines'] # 形状为[1, 2500, 2, 2]
  116. scores = pred[-1]['wires']['score'][0] # 假设形状为[2500]
  117. boxes = [box_['boxes'].cpu().numpy() for box_ in pred[0:-1]] # 所有box
  118. num_processes = min(mp.cpu_count(), len(boxes)) # 使用可用的核心数
  119. with mp.Pool(processes=num_processes) as pool:
  120. results = pool.starmap(
  121. process_box,
  122. [(box, lines, scores) for box in boxes]
  123. )
  124. # 更新预测结果
  125. filtered_pred = []
  126. for idx_box, (valid_lines, valid_scores) in enumerate(results):
  127. if valid_lines:
  128. pred[idx_box]['line'] = torch.tensor(valid_lines)
  129. pred[idx_box]['line_score'] = torch.tensor(valid_scores)
  130. filtered_pred.append(pred[idx_box])
  131. return filtered_pred
  132. def predict(image_path):
  133. start_time = time.time()
  134. # 保存上传文件
  135. # os.makedirs("uploaded_images", exist_ok=True)
  136. # image_path = f"{file.filename}"
  137. # with open(image_path, "wb") as f:
  138. # shutil.copyfileobj(file.file, f)
  139. # 加载模型
  140. model_path = "/data/lm/resnet50_best_e280.pth"
  141. model = load_model(model_path)
  142. # 预处理图片
  143. img_tensor = preprocess_image(image_path)
  144. im = img_tensor.permute(1, 2, 0).cpu().numpy()
  145. # 模型推理
  146. with torch.no_grad():
  147. predictions = model([img_tensor.to(device)])
  148. # 匹配线段和边界框
  149. t_start = time.time()
  150. filtered_pred = box_line_optimized_parallel(predictions)
  151. t_end = time.time()
  152. print(f'Matched boxes and lines used: {t_end - t_start:.2f} seconds')
  153. # 绘制图像
  154. output_path_box = show_box(im, predictions, t_start)
  155. output_path_line = show_line(im, predictions, t_start)
  156. output_path_boxandline = show_predict(im, filtered_pred, t_start)
  157. # 合并图像
  158. combined_image_path = "combined_result.png"
  159. combine_images(
  160. [output_path_boxandline, output_path_box, output_path_line],
  161. titles=["Box and Line", "Box", "Line"],
  162. output_path=combined_image_path
  163. )
  164. end_time = time.time()
  165. print(f'Total time: {end_time - start_time:.2f} seconds')
  166. # # 返回合成的图片
  167. # return FileResponse(combined_image_path, media_type="image/png", filename="combined_result.png")
  168. # except Exception as e:
  169. # raise HTTPException(status_code=500, detail=str(e))
  170. def combine_images(image_paths: list, titles: list, output_path: str):
  171. """将多个图像合并为一张图片"""
  172. fig, axes = plt.subplots(1, len(image_paths), figsize=(15, 5))
  173. for ax, img_path, title in zip(axes, image_paths, titles):
  174. ax.imshow(plt.imread(img_path))
  175. ax.set_title(title)
  176. ax.axis("off")
  177. plt.savefig(output_path, bbox_inches="tight", pad_inches=0)
  178. plt.close()
  179. def show_box(im, predictions, t_start):
  180. """绘制边界框并保存结果"""
  181. boxes = predictions[0]['boxes'].cpu().numpy()
  182. box_scores = predictions[0]['scores'].cpu().numpy()
  183. colors = get_colors()
  184. fig, ax = plt.subplots(figsize=(10, 10))
  185. ax.imshow(im)
  186. for idx, (box, score) in enumerate(zip(boxes, box_scores)):
  187. if score < 0.7:
  188. continue
  189. x0, y0, x1, y1 = box
  190. ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=colors[idx % len(colors)], linewidth=1))
  191. t_end = time.time()
  192. print(f'show_box used: {t_end - t_start:.2f} seconds')
  193. output_path = "temp_result_box.png"
  194. save_plot(output_path)
  195. return output_path
  196. def show_line(im, predictions, t_start):
  197. """绘制线段并保存结果"""
  198. lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  199. line_scores = predictions[-1]['wires']['score'].cpu().numpy()[0]
  200. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  201. nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
  202. fig, ax = plt.subplots(figsize=(10, 10))
  203. ax.imshow(im)
  204. for (a, b), s in zip(nlines, nscores):
  205. if s < 0.9:
  206. continue
  207. ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
  208. ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
  209. t_end = time.time()
  210. print(f'show_line used: {t_end - t_start:.2f} seconds')
  211. output_path = "temp_result_line.png"
  212. save_plot(output_path)
  213. return output_path
  214. def show_predict(im, filtered_pred, t_start):
  215. """绘制匹配后的边界框和线段并保存结果"""
  216. colors = get_colors()
  217. fig, ax = plt.subplots(figsize=(10, 10))
  218. ax.imshow(im)
  219. for idx, pred in enumerate(filtered_pred):
  220. boxes = pred['boxes'].cpu().numpy()
  221. box_scores = pred['scores'].cpu().numpy()
  222. lines = pred['line'].cpu().numpy()
  223. line_scores = pred['line_score'].cpu().numpy()
  224. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  225. nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
  226. for box_idx, (box, line, box_score, line_score) in enumerate(zip(boxes, nlines, box_scores, nscores)):
  227. if box_score < 0.7 or line_score < 0.9:
  228. continue
  229. # 如果线段为空(即没有找到有效线段),跳过绘制
  230. if line is None or len(line) == 0:
  231. continue
  232. x0, y0, x1, y1 = box
  233. a, b = line
  234. color = colors[(idx + box_idx) % len(colors)] # 每个边界框分配一个唯一颜色
  235. ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
  236. ax.scatter(a[1], a[0], c=color, s=10)
  237. ax.scatter(b[1], b[0], c=color, s=10)
  238. ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
  239. t_end = time.time()
  240. print(f'show_predict used: {t_end - t_start:.2f} seconds')
  241. output_path = "temp_result.png"
  242. save_plot(output_path)
  243. return output_path
  244. if __name__ == "__main__":
  245. predict('/data/lm/wirenet_1000/images/train/00031643_1.png')