boxline.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
# from fastapi import FastAPI, File, UploadFile, HTTPException
# from fastapi.responses import FileResponse
# from fastapi.staticfiles import StaticFiles
# from fastapi.middleware.cors import CORSMiddleware
import multiprocessing as mp
import os
import shutil
import time

import matplotlib.pyplot as plt
import numpy as np
import skimage.color
import skimage.io
import skimage.transform
import torch
from PIL import Image
from rtree import index
from torchvision import transforms

from models.line_detect.line_net import linenet_resnet50_fpn
from models.wirenet.postprocess import postprocess
  19. # 初始化 FastAPI
  20. # app = FastAPI()
  21. # 添加 CORS 中间件
  22. # app.add_middleware(
  23. # CORSMiddleware,
  24. # allow_origins=["*"], # 允许所有源
  25. # allow_credentials=True,
  26. # allow_methods=["*"],
  27. # allow_headers=["*"],
  28. # )
  29. # 设置多进程启动方式为 'spawn'
  30. mp.set_start_method('spawn', force=True)
  31. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  32. def load_model(model_path):
  33. """加载模型并返回模型实例"""
  34. model = linenet_resnet50_fpn().to(device)
  35. if os.path.exists(model_path):
  36. checkpoint = torch.load(model_path, map_location=device)
  37. model.load_state_dict(checkpoint['model_state_dict'])
  38. print(f"Loaded model from {model_path}")
  39. else:
  40. raise FileNotFoundError(f"No saved model found at {model_path}")
  41. model.eval()
  42. return model
  43. def preprocess_image(image_path):
  44. """预处理上传的图片"""
  45. img = Image.open(image_path).convert("RGB")
  46. transform = transforms.ToTensor()
  47. img_tensor = transform(img)
  48. resized_img = skimage.transform.resize(
  49. img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32), (512, 512)
  50. )
  51. return torch.tensor(resized_img).permute(2, 0, 1),img
  52. def save_plot(output_path: str):
  53. """保存图像并关闭绘图"""
  54. plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
  55. print(f"Saved plot to {output_path}")
  56. plt.close()
  57. def get_colors():
  58. """返回一组预定义的颜色列表"""
  59. return [
  60. '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
  61. '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
  62. '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
  63. '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
  64. '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
  65. '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
  66. '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
  67. '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
  68. '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
  69. '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
  70. '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
  71. '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
  72. '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
  73. '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
  74. '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
  75. '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
  76. '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
  77. '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
  78. '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
  79. '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
  80. ]
  81. def process_box(box, lines, scores):
  82. """处理单个边界框,找到最佳匹配的线段"""
  83. valid_lines = [] # 存储有效的线段
  84. valid_scores = [] # 存储有效的分数
  85. # print(f'score:{len(scores)}')
  86. for i in box:
  87. best_line = None
  88. max_length = 0.0
  89. # 遍历所有线段
  90. for j in range(lines.shape[1]):
  91. # line_j = lines[0, j].cpu().numpy() / 128 * 512
  92. line_j = lines[0, j].cpu().numpy()
  93. # 检查线段是否完全在box内
  94. if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and
  95. line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
  96. line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
  97. line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
  98. # 计算线段长度
  99. length = np.linalg.norm(line_j[0] - line_j[1])
  100. # length = scores[j].cpu().numpy()
  101. # print(length)
  102. if length > max_length:
  103. best_line = line_j
  104. max_length = length
  105. # 如果找到有效的线段,则添加到结果中
  106. if best_line is not None:
  107. valid_lines.append(best_line)
  108. valid_scores.append(max_length) # 使用线段长度作为分数
  109. else:
  110. valid_lines.append([[0.0,0.0],[0.0,0.0]])
  111. valid_scores.append(0.0) # 使用线段置信度作为分数
  112. # print(f'valid_lines:{valid_lines}')
  113. # print(f'valid_scores:{valid_scores}')
  114. return valid_lines, valid_scores
  115. def box_line_optimized_parallel(pred):
  116. """并行处理边界框和线段的匹配"""
  117. lines = pred[-1]['wires']['lines'] # 形状为[1, 2500, 2, 2]
  118. scores = pred[-1]['wires']['score'][0] # 假设形状为[2500]
  119. boxes = [box_['boxes'].cpu().numpy() for box_ in pred[0:-1]] # 所有box
  120. num_processes = min(mp.cpu_count(), len(boxes)) # 使用可用的核心数
  121. with mp.Pool(processes=num_processes) as pool:
  122. results = pool.starmap(
  123. process_box,
  124. [(box, lines, scores) for box in boxes]
  125. )
  126. # 更新预测结果
  127. filtered_pred = []
  128. for idx_box, (valid_lines, valid_scores) in enumerate(results):
  129. if valid_lines:
  130. pred[idx_box]['line'] = torch.tensor(valid_lines)
  131. pred[idx_box]['line_score'] = torch.tensor(valid_scores)
  132. filtered_pred.append(pred[idx_box])
  133. return filtered_pred
  134. def predict(image_path):
  135. start_time = time.time()
  136. # 保存上传文件
  137. # os.makedirs("uploaded_images", exist_ok=True)
  138. # image_path = f"{file.filename}"
  139. # with open(image_path, "wb") as f:
  140. # shutil.copyfileobj(file.file, f)
  141. # 加载模型
  142. model_path = r'D:\python\PycharmProjects\20250214\weight\linenet_wts\r50fpn_wts_e350\best.pth'
  143. model = load_model(model_path)
  144. # 预处理图片
  145. img_tensor,img = preprocess_image(image_path)
  146. # W = img.shape[0]
  147. im = img_tensor.permute(1, 2, 0).cpu().numpy()
  148. # 模型推理
  149. with torch.no_grad():
  150. predictions = model([img_tensor.to(device)])
  151. lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  152. line_scores = predictions[-1]['wires']['score'].cpu().numpy()[0]
  153. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  154. nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
  155. predictions[-1]['wires']['lines_'] = torch.from_numpy(nlines).float().cuda()
  156. predictions[-1]['wires']['score_'] = torch.from_numpy(nscores).float().cuda()
  157. print(predictions)
  158. # 匹配线段和边界框
  159. t_start = time.time()
  160. filtered_pred = box_line_optimized_parallel(predictions)
  161. t_end = time.time()
  162. print(f'Matched boxes and lines used: {t_end - t_start:.2f} seconds')
  163. # 绘制图像
  164. output_path_box = show_box(im, predictions, t_start)
  165. output_path_line = show_line(im, predictions, t_start)
  166. output_path_boxandline = show_predict(im, filtered_pred, t_start)
  167. # 合并图像
  168. combined_image_path = "combined_result.png"
  169. combine_images(
  170. [output_path_boxandline, output_path_box, output_path_line],
  171. titles=["Box and Line", "Box", "Line"],
  172. output_path=combined_image_path
  173. )
  174. end_time = time.time()
  175. print(f'Total time: {end_time - start_time:.2f} seconds')
  176. # 获取线段数据并添加详细的调试信息
  177. lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * np.array([1328, 2112])
  178. print(f"Initial lines shape: {lines.shape}")
  179. print(f"Initial lines data type: {lines.dtype}")
  180. # 确保数据是正确的形状
  181. if len(lines.shape) != 3: # 如果不是 (N, 2, 2) 形状
  182. if len(lines.shape) == 2 and lines.shape[1] == 4:
  183. # 如果是 (N, 4) 形状,重塑为 (N, 2, 2)
  184. lines = lines.reshape(-1, 2, 2)
  185. else:
  186. print(f"Warning: Unexpected lines shape: {lines.shape}")
  187. print(f"After reshape - lines shape: {lines.shape}")
  188. # 确保每个点是 [x, y] 格式
  189. formatted_lines = []
  190. for line in lines:
  191. start_point = np.array([line[0][0], line[0][1]])
  192. end_point = np.array([line[1][0], line[1][1]])
  193. formatted_lines.append([start_point, end_point])
  194. formatted_lines = np.array(formatted_lines)
  195. print(f"Final formatted_lines shape: {formatted_lines.shape}")
  196. print(f"Sample formatted line: {formatted_lines[0] if len(formatted_lines) > 0 else 'No lines'}")
  197. # 确保返回的是三维数组:[lines_array]
  198. result = [formatted_lines]
  199. print(f"Final result type: {type(result)}")
  200. print(f"Final result[0] shape: {result[0].shape}")
  201. return result
  202. def combine_images(image_paths: list, titles: list, output_path: str):
  203. """将多个图像合并为一张图片"""
  204. fig, axes = plt.subplots(1, len(image_paths), figsize=(15, 5))
  205. for ax, img_path, title in zip(axes, image_paths, titles):
  206. ax.imshow(plt.imread(img_path))
  207. ax.set_title(title)
  208. ax.axis("off")
  209. plt.savefig(output_path, bbox_inches="tight", pad_inches=0)
  210. plt.close()
  211. def show_box(im, predictions, t_start):
  212. """绘制边界框并保存结果"""
  213. boxes = predictions[0]['boxes'].cpu().numpy()
  214. box_scores = predictions[0]['scores'].cpu().numpy()
  215. colors = get_colors()
  216. fig, ax = plt.subplots(figsize=(10, 10))
  217. ax.imshow(im)
  218. for idx, (box, score) in enumerate(zip(boxes, box_scores)):
  219. if score < 0.7:
  220. continue
  221. x0, y0, x1, y1 = box
  222. ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=colors[idx % len(colors)], linewidth=1))
  223. t_end = time.time()
  224. print(f'show_box used: {t_end - t_start:.2f} seconds')
  225. output_path = "temp_result_box.png"
  226. save_plot(output_path)
  227. return output_path
  228. def show_line(im, predictions, t_start):
  229. """绘制线段并保存结果"""
  230. lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  231. line_scores = predictions[-1]['wires']['score'].cpu().numpy()[0]
  232. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  233. nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
  234. fig, ax = plt.subplots(figsize=(10, 10))
  235. ax.imshow(im)
  236. for (a, b), s in zip(nlines, nscores):
  237. if s < 0.9:
  238. continue
  239. ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
  240. ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
  241. t_end = time.time()
  242. print(f'show_line used: {t_end - t_start:.2f} seconds')
  243. output_path = "temp_result_line.png"
  244. save_plot(output_path)
  245. return output_path
  246. def show_predict(im, filtered_pred, t_start):
  247. """绘制匹配后的边界框和线段并保存结果"""
  248. colors = get_colors()
  249. fig, ax = plt.subplots(figsize=(10, 10))
  250. ax.imshow(im)
  251. for idx, pred in enumerate(filtered_pred):
  252. boxes = pred['boxes'].cpu().numpy()
  253. box_scores = pred['scores'].cpu().numpy()
  254. lines = pred['line'].cpu().numpy()
  255. line_scores = pred['line_score'].cpu().numpy()
  256. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  257. nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
  258. for box_idx, (box, line, box_score, line_score) in enumerate(zip(boxes, nlines, box_scores, nscores)):
  259. if box_score < 0.7 or line_score < 0.9:
  260. continue
  261. # 如果线段为空(即没有找到有效线段),跳过绘制
  262. if line is None or len(line) == 0:
  263. continue
  264. x0, y0, x1, y1 = box
  265. a, b = line
  266. color = colors[(idx + box_idx) % len(colors)] # 每个边界框分配一个唯一颜色
  267. ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
  268. ax.scatter(a[1], a[0], c=color, s=10)
  269. ax.scatter(b[1], b[0], c=color, s=10)
  270. ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
  271. t_end = time.time()
  272. print(f'show_predict used: {t_end - t_start:.2f} seconds')
  273. output_path = "temp_result.png"
  274. save_plot(output_path)
  275. return output_path
  276. if __name__ == "__main__":
  277. lines = predict(r'C:\Users\m2337\Desktop\p\9.jpg')
  278. print(f'lines:{lines}')