lm_0223.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. import os
  2. import torch
  3. import numpy as np
  4. from PIL import Image
  5. import skimage.io
  6. import skimage.color
  7. from torchvision import transforms
  8. import shutil
  9. import matplotlib.pyplot as plt
  10. from models.line_detect.line_net import linenet_resnet50_fpn
  11. from models.line_detect.postprocess import postprocess
  12. from rtree import index
  13. import time
  14. import multiprocessing as mp
  15. mp.set_start_method('spawn', force=True)
  16. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  17. def load_model(model_path):
  18. model = linenet_resnet50_fpn().to(device)
  19. if os.path.exists(model_path):
  20. checkpoint = torch.load(model_path, map_location=device)
  21. model.load_state_dict(checkpoint['model_state_dict'])
  22. print(f"Loaded model from {model_path}")
  23. else:
  24. raise FileNotFoundError(f"No saved model found at {model_path}")
  25. model.eval()
  26. return model
  27. def preprocess_image(image_path):
  28. img = Image.open(image_path).convert("RGB")
  29. transform = transforms.ToTensor()
  30. img_tensor = transform(img)
  31. resized_img = skimage.transform.resize(
  32. img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32), (512, 512)
  33. )
  34. return torch.tensor(resized_img).permute(2, 0, 1)
  35. def save_plot(output_path: str):
  36. plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
  37. print(f"Saved plot to {output_path}")
  38. plt.close()
  39. def get_colors():
  40. """返回一组预定义的颜色列表"""
  41. return [
  42. '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
  43. '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
  44. '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
  45. '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
  46. '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
  47. '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
  48. '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
  49. '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
  50. '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
  51. '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
  52. '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
  53. '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
  54. '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
  55. '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
  56. '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
  57. '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
  58. '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
  59. '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
  60. '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
  61. '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
  62. ]
  63. def process_box(box, lines, scores):
  64. """处理单个边界框,找到最佳匹配的线段"""
  65. valid_lines = [] # 存储有效的线段
  66. valid_scores = [] # 存储有效的分数
  67. # print(f'score:{len(scores)}')
  68. for i in box:
  69. best_line = None
  70. max_length = 0.0
  71. # 遍历所有线段
  72. for j in range(lines.shape[1]):
  73. line_j = lines[0, j].cpu().numpy() / 128 * 512
  74. # 检查线段是否完全在box内
  75. if (all(line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
  76. line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and line_j[0][0] <= i[3] and line_j[1][0] <= i[3])):
  77. # length = np.linalg.norm(line_j[0] - line_j[1])
  78. length = scores[j].item()
  79. # print(length)
  80. if length > max_length:
  81. best_line = line_j
  82. max_length = length
  83. if best_line is not None:
  84. valid_lines.append(best_line)
  85. valid_scores.append(max_length)
  86. else:
  87. valid_lines.append([[0.0, 0.0], [0.0, 0.0]])
  88. valid_scores.append(0.0)
  89. return valid_lines, valid_scores
  90. def box_line_optimized_parallel(pred):
  91. """并行处理边界框和线段的匹配"""
  92. lines = pred[-1]['wires']['lines'] # 形状为[1, 2500, 2, 2]
  93. scores = pred[-1]['wires']['score'][0] # 假设形状为[2500]
  94. boxes = [box_['boxes'].cpu().numpy() for box_ in pred[0:-1]] # 所有box
  95. # num_processes = min(mp.cpu_count(), len(boxes)) # 使用可用的核心数
  96. # with mp.Pool(processes=num_processes) as pool:
  97. # results = pool.starmap(
  98. # process_box,
  99. # [(box, lines, scores) for box in boxes]
  100. # )
  101. results = process_box(boxes, lines, scores)
  102. # 更新预测结果
  103. filtered_pred = []
  104. for idx_box, (valid_lines, valid_scores) in enumerate(results):
  105. if valid_lines:
  106. pred[idx_box]['line'] = torch.tensor(valid_lines)
  107. pred[idx_box]['line_score'] = torch.tensor(valid_scores)
  108. filtered_pred.append(pred[idx_box])
  109. return filtered_pred
  110. def predict(image_path):
  111. start_time = time.time()
  112. model_path = r'D:\python\PycharmProjects\20250214\weight\linenet_wts\r50fpn_wts_e350\best.pth'
  113. model = load_model(model_path)
  114. img_tensor = preprocess_image(image_path)
  115. im = img_tensor.permute(1, 2, 0).cpu().numpy()
  116. with torch.no_grad():
  117. predictions = model([img_tensor.to(device)])
  118. t_start = time.time()
  119. filtered_pred = box_line_optimized_parallel(predictions)
  120. t_end = time.time()
  121. print(f'Matched boxes and lines used: {t_end - t_start:.2f} seconds')
  122. output_path_box = show_box(im, predictions, t_start)
  123. output_path_line = show_line(im, predictions, t_start)
  124. show_predict(im, filtered_pred, t_start)
  125. # combined_image_path = "combined_result.png"
  126. # combine_images(
  127. # [output_path_boxandline, output_path_box, output_path_line],
  128. # titles=["Box and Line", "Box", "Line"],
  129. # output_path=combined_image_path
  130. # )
  131. end_time = time.time()
  132. print(f'Total time: {end_time - start_time:.2f} seconds')
  133. def combine_images(image_paths: list, titles: list, output_path: str):
  134. """将多个图像合并为一张图片"""
  135. fig, axes = plt.subplots(1, len(image_paths), figsize=(15, 5))
  136. for ax, img_path, title in zip(axes, image_paths, titles):
  137. ax.imshow(plt.imread(img_path))
  138. ax.set_title(title)
  139. ax.axis("off")
  140. plt.savefig(output_path, bbox_inches="tight", pad_inches=0)
  141. plt.close()
  142. def show_box(im, predictions, t_start):
  143. """绘制边界框并保存结果"""
  144. boxes = predictions[0]['boxes'].cpu().numpy()
  145. box_scores = predictions[0]['scores'].cpu().numpy()
  146. colors = get_colors()
  147. fig, ax = plt.subplots(figsize=(10, 10))
  148. ax.imshow(im)
  149. for idx, (box, score) in enumerate(zip(boxes, box_scores)):
  150. if score < 0.7:
  151. continue
  152. x0, y0, x1, y1 = box
  153. ax.add_patch(
  154. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=colors[idx % len(colors)], linewidth=1))
  155. t_end = time.time()
  156. print(f'show_box used: {t_end - t_start:.2f} seconds')
  157. plt.show()
  158. output_path = "temp_result_box.png"
  159. save_plot(output_path)
  160. return output_path
  161. def show_line(im, predictions, t_start):
  162. """绘制线段并保存结果"""
  163. lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  164. line_scores = predictions[-1]['wires']['score'].cpu().numpy()[0]
  165. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  166. nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
  167. fig, ax = plt.subplots(figsize=(10, 10))
  168. ax.imshow(im)
  169. for (a, b), s in zip(nlines, nscores):
  170. if s < 0.9:
  171. continue
  172. ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
  173. ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
  174. t_end = time.time()
  175. print(f'show_line used: {t_end - t_start:.2f} seconds')
  176. plt.show()
  177. output_path = "temp_result_line.png"
  178. save_plot(output_path)
  179. return output_path
  180. # def show_predict(im, filtered_pred, t_start):
  181. # colors = get_colors()
  182. # fig, ax = plt.subplots(figsize=(10, 10))
  183. # ax.imshow(im)
  184. # for idx, pred in enumerate(filtered_pred):
  185. # boxes = pred['boxes'].cpu().numpy()
  186. # box_scores = pred['scores'].cpu().numpy()
  187. # lines = pred['line'].cpu().numpy()
  188. # line_scores = pred['line_score'].cpu().numpy()
  189. # diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  190. # nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
  191. # for box_idx, (box, line, box_score, line_score) in enumerate(zip(boxes, nlines, box_scores, nscores)):
  192. # if box_score < 0.7 or line_score < 0.9:
  193. # continue
  194. #
  195. # if line is None or len(line) == 0:
  196. # continue
  197. #
  198. # x0, y0, x1, y1 = box
  199. # a, b = line
  200. # color = colors[(idx + box_idx) % len(colors)] # 每个边界框分配一个唯一颜色
  201. # ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
  202. # ax.scatter(a[1], a[0], c=color, s=10)
  203. # ax.scatter(b[1], b[0], c=color, s=10)
  204. # ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
  205. # t_end = time.time()
  206. # print(f'show_predict used: {t_end - t_start:.2f} seconds')
  207. # plt.show()
  208. # output_path = "temp_result.png"
  209. # save_plot(output_path)
  210. # return output_path
  211. def show_predict(imgs, pred, t_start):
  212. col = [
  213. '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
  214. '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
  215. '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
  216. '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
  217. '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
  218. '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
  219. '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
  220. '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
  221. '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
  222. '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
  223. '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
  224. '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
  225. '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
  226. '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
  227. '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
  228. '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
  229. '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
  230. '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
  231. '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
  232. '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
  233. ]
  234. print(imgs.shape)
  235. # im = imgs.permute(1, 2, 0) # 处理为 [512, 512, 3]
  236. boxes = pred[0]['boxes'].cpu().numpy()
  237. box_scores = pred[0]['scores'].cpu().numpy()
  238. lines = pred[0]['line'].cpu().numpy()
  239. line_scores = pred[0]['line_score'].cpu().numpy()
  240. # 可视化预测结
  241. fig, ax = plt.subplots(figsize=(10, 10))
  242. ax.imshow(np.array(imgs))
  243. idx = 0
  244. tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
  245. for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
  246. x0, y0, x1, y1 = box
  247. # 框中无线的跳过
  248. if np.array_equal(line, tmp):
  249. continue
  250. a, b = line
  251. if box_score >= 0.7 or line_score >= 0.9:
  252. ax.add_patch(
  253. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  254. ax.scatter(a[1], a[0], c='#871F78', s=10)
  255. ax.scatter(b[1], b[0], c='#871F78', s=10)
  256. ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
  257. idx = idx + 1
  258. t_end = time.time()
  259. print(f'predict used:{t_end - t_start}')
  260. plt.show()
  261. if __name__ == "__main__":
  262. predict(r'C:\Users\m2337\Desktop\p\22.png')