predict_0226.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371
  1. import time
  2. import skimage
  3. # from models.line_detect.postprocess import show_predict, show_box, show_box_or_line, show_box_and_line, \
  4. # show_line_optimized, show_line, show_all
  5. import os
  6. import torch
  7. from PIL import Image
  8. import matplotlib.pyplot as plt
  9. import matplotlib as mpl
  10. import numpy as np
  11. from models.line_detect.line_net import linenet_resnet50_fpn
  12. from torchvision import transforms
  13. # from models.wirenet.postprocess import postprocess
  14. from models.wirenet.postprocess import postprocess
  15. from rtree import index
  16. from datetime import datetime
  17. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  18. def load_best_model(model, save_path, device):
  19. if os.path.exists(save_path):
  20. checkpoint = torch.load(save_path, map_location=device)
  21. model.load_state_dict(checkpoint['model_state_dict'])
  22. # if optimizer is not None:
  23. # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  24. epoch = checkpoint['epoch']
  25. loss = checkpoint['loss']
  26. print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  27. else:
  28. print(f"No saved model found at {save_path}")
  29. return model
  30. def box_line_(imgs, pred, length=False): # 默认置信度
  31. im = imgs.permute(1, 2, 0).cpu().numpy()
  32. line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  33. line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
  34. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  35. line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)
  36. for idx, box_ in enumerate(pred[0:-1]):
  37. box = box_['boxes'] # 是一个tensor
  38. # line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
  39. # score = pred[-1]['wires']['score'][idx]
  40. # diag = (512 ** 2 + 512 ** 2) ** 0.5
  41. # line, score = postprocess(line, score, diag * 0.01, 0, False)
  42. line_ = []
  43. score_ = []
  44. for i in box:
  45. score_max = 0.0
  46. tmp = [[0.0, 0.0], [0.0, 0.0]]
  47. for j in range(len(line)):
  48. if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
  49. line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
  50. line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
  51. line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
  52. # # 计算线段长度
  53. # length = np.linalg.norm(line[j][0] - line[j][1])
  54. # if length > score_max:
  55. # tmp = line[j]
  56. # score_max = score[j]
  57. if score[j] > score_max:
  58. tmp = line[j]
  59. score_max = score[j]
  60. line_.append(tmp)
  61. score_.append(score_max)
  62. processed_list = torch.tensor(line_)
  63. pred[idx]['line'] = processed_list
  64. processed_s_list = torch.tensor(score_)
  65. pred[idx]['line_score'] = processed_s_list
  66. return pred
  67. def box_line_optimized(pred):
  68. # 创建R-tree索引
  69. idx = index.Index()
  70. # 将所有线段添加到R-tree中
  71. lines = pred[-1]['wires']['lines'] # 形状为[1, 2500, 2, 2]
  72. scores = pred[-1]['wires']['score'][0] # 假设形状为[2500]
  73. # 提取并处理所有线段
  74. for idx_line in range(lines.shape[1]): # 遍历2500条线段
  75. line_tensor = lines[0, idx_line].cpu().numpy() / 128 * 512 # 转换为numpy数组并调整比例
  76. x_min = float(min(line_tensor[0][0], line_tensor[1][0]))
  77. y_min = float(min(line_tensor[0][1], line_tensor[1][1]))
  78. x_max = float(max(line_tensor[0][0], line_tensor[1][0]))
  79. y_max = float(max(line_tensor[0][1], line_tensor[1][1]))
  80. idx.insert(idx_line, (max(0, x_min - 256), max(0, y_min - 256), min(512, x_max + 256), min(512, y_max + 256)))
  81. for idx_box, box_ in enumerate(pred[0:-1]):
  82. box = box_['boxes'].cpu().numpy() # 确保将张量转换为numpy数组
  83. line_ = []
  84. score_ = []
  85. for i in box:
  86. score_max = 0.0
  87. tmp = [[0.0, 0.0], [0.0, 0.0]]
  88. # 获取与当前box可能相交的所有线段
  89. possible_matches = list(idx.intersection((i[0], i[1], i[2], i[3])))
  90. for j in possible_matches:
  91. line_j = lines[0, j].cpu().numpy() / 128 * 512
  92. if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and # 注意这里交换了x和y
  93. line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
  94. line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
  95. line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
  96. if scores[j] > score_max:
  97. tmp = line_j
  98. score_max = scores[j]
  99. line_.append(tmp)
  100. score_.append(score_max)
  101. processed_list = torch.tensor(line_)
  102. pred[idx_box]['line'] = processed_list
  103. processed_s_list = torch.tensor(score_)
  104. pred[idx_box]['line_score'] = processed_s_list
  105. return pred
  106. def set_thresholds(threshold):
  107. if isinstance(threshold, list):
  108. if len(threshold) != 2:
  109. raise ValueError("Threshold list must contain exactly two elements.")
  110. a, b = threshold
  111. elif isinstance(threshold, (int, float)):
  112. a = b = threshold
  113. else:
  114. raise TypeError("Threshold must be either a list of two numbers or a single number.")
  115. return a, b
  116. def color():
  117. return [
  118. '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
  119. '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
  120. '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
  121. '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
  122. '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
  123. '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
  124. '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
  125. '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
  126. '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
  127. '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
  128. '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
  129. '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
  130. '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
  131. '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
  132. '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
  133. '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
  134. '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
  135. '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
  136. '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
  137. '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
  138. ]
  139. def show_all(imgs, pred, threshold, save_path, show):
  140. col = color()
  141. box_th, line_th = set_thresholds(threshold)
  142. im = imgs.permute(1, 2, 0)
  143. boxes = pred[0]['boxes'].cpu().numpy()
  144. box_scores = pred[0]['scores'].cpu().numpy()
  145. line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  146. line_score = pred[-1]['wires']['score'].cpu().numpy()[0]
  147. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  148. line, line_score = postprocess(line, line_score, diag * 0.01, 0, False)
  149. fig, axs = plt.subplots(1, 3, figsize=(10, 10))
  150. axs[0].imshow(np.array(im))
  151. for idx, box in enumerate(boxes):
  152. if box_scores[idx] < box_th:
  153. continue
  154. x0, y0, x1, y1 = box
  155. axs[0].add_patch(
  156. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  157. axs[0].set_title('Boxes')
  158. axs[1].imshow(np.array(im))
  159. for idx, (a, b) in enumerate(line):
  160. if line_score[idx] < line_th:
  161. continue
  162. axs[1].scatter(a[1], a[0], c='#871F78', s=2)
  163. axs[1].scatter(b[1], b[0], c='#871F78', s=2)
  164. axs[1].plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
  165. axs[1].set_title('Lines')
  166. axs[2].imshow(np.array(im))
  167. lines = pred[0]['line'].cpu().numpy()
  168. line_scores = pred[0]['line_score'].cpu().numpy()
  169. idx = 0
  170. tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
  171. for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
  172. x0, y0, x1, y1 = box
  173. # 框中无线的跳过
  174. if np.array_equal(line, tmp):
  175. continue
  176. a, b = line
  177. if box_score >= 0.7 or line_score >= 0.9:
  178. axs[2].add_patch(
  179. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  180. axs[2].scatter(a[1], a[0], c='#871F78', s=10)
  181. axs[2].scatter(b[1], b[0], c='#871F78', s=10)
  182. axs[2].plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
  183. idx = idx + 1
  184. axs[2].set_title('Boxes and Lines')
  185. if save_path:
  186. save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'box_line.png')
  187. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  188. plt.savefig(save_path)
  189. print(f"Saved result image to {save_path}")
  190. if show:
  191. # 调整子图之间的距离,防止标题和标签重叠
  192. plt.tight_layout()
  193. plt.show()
  194. def show_box_or_line(imgs, pred, threshold, save_path = None, show_line=False, show_box=False):
  195. col = color()
  196. box_th, line_th = set_thresholds(threshold)
  197. im = imgs.permute(1, 2, 0)
  198. boxes = pred[0]['boxes'].cpu().numpy()
  199. box_scores = pred[0]['scores'].cpu().numpy()
  200. line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  201. line_score = pred[-1]['wires']['score'].cpu().numpy()[0]
  202. # 可视化预测结
  203. fig, ax = plt.subplots(figsize=(10, 10))
  204. ax.imshow(np.array(im))
  205. if show_box:
  206. for idx, box in enumerate(boxes):
  207. if box_scores[idx] < box_th:
  208. continue
  209. x0, y0, x1, y1 = box
  210. ax.add_patch(
  211. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  212. if save_path:
  213. save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'box.png')
  214. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  215. plt.savefig(save_path)
  216. print(f"Saved result image to {save_path}")
  217. if show_line:
  218. for idx, (a, b) in enumerate(line):
  219. if line_score[idx] < line_th:
  220. continue
  221. ax.scatter(a[1], a[0], c='#871F78', s=2)
  222. ax.scatter(b[1], b[0], c='#871F78', s=2)
  223. ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
  224. if save_path:
  225. save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'line.png')
  226. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  227. plt.savefig(save_path)
  228. print(f"Saved result image to {save_path}")
  229. plt.show()
  230. def show_predict(imgs, pred, threshold, t_start):
  231. col = color()
  232. box_th, line_th = set_thresholds(threshold)
  233. im = imgs.permute(1, 2, 0) # 处理为 [512, 512, 3]
  234. boxes = pred[0]['boxes'].cpu().numpy()
  235. box_scores = pred[0]['scores'].cpu().numpy()
  236. lines = pred[0]['line'].cpu().numpy()
  237. line_scores = pred[0]['line_score'].cpu().numpy()
  238. # 可视化预测结
  239. fig, ax = plt.subplots(figsize=(10, 10))
  240. ax.imshow(np.array(im))
  241. idx = 0
  242. tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
  243. for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
  244. x0, y0, x1, y1 = box
  245. # 框中无线的跳过
  246. if np.array_equal(line, tmp):
  247. continue
  248. a, b = line
  249. if box_score >= box_th or line_score >= line_th:
  250. ax.add_patch(
  251. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  252. ax.scatter(a[1], a[0], c='#871F78', s=10)
  253. ax.scatter(b[1], b[0], c='#871F78', s=10)
  254. ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
  255. idx = idx + 1
  256. t_end = time.time()
  257. print(f'predict used:{t_end - t_start}')
  258. plt.show()
  259. def predict(pt_path, model, img, type=0, threshold=0.5, save_path=None, show=False):
  260. model = load_best_model(model, pt_path, device)
  261. model.eval()
  262. if isinstance(img, str):
  263. img = Image.open(img).convert("RGB")
  264. transform = transforms.ToTensor()
  265. img_tensor = transform(img) # [3, 512, 512]
  266. # 将图像调整为512x512大小
  267. t_start = time.time()
  268. im = img_tensor.permute(1, 2, 0) # [512, 512, 3]
  269. im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512)) # (512, 512, 3)
  270. img_ = torch.tensor(im_resized).permute(2, 0, 1)
  271. t_end = time.time()
  272. print(f'switch img used:{t_end - t_start}')
  273. with torch.no_grad():
  274. predictions = model([img_.to(device)])
  275. # print(predictions)
  276. t_start = time.time()
  277. pred = box_line_(img_, predictions) # 线框匹配
  278. t_end = time.time()
  279. print(f'Matched boxes and lines used: {t_end - t_start:.4f} seconds')
  280. if type == 0:
  281. show_all(img_, pred, threshold, save_path=True, show=True)
  282. elif type == 1:
  283. show_box_or_line(img_, predictions, threshold, save_path=True, show_line=True) # 参数确定画什么
  284. elif type == 2:
  285. show_box_or_line(img_, predictions, threshold, save_path=True, show_box=True) # 参数确定画什么
  286. elif type == 3:
  287. show_predict(img_, pred, threshold, t_start)
  288. if __name__ == '__main__':
  289. t_start = time.time()
  290. print(f'start to predict:{t_start}')
  291. model = linenet_resnet50_fpn().to(device)
  292. pt_path = r'D:\python\PycharmProjects\20250214\weight\best.pth'
  293. img_path = r'C:\Users\m2337\Desktop\p\20250226142919.png'
  294. # predict(pt_path, model, img_path)
  295. predict(pt_path, model, img_path, type=2, threshold=0.5, save_path=None, show=False)
  296. t_end = time.time()
  297. print(f'predict used:{t_end - t_start}')