predict.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. # import time
  2. # import torch
  3. # from PIL import Image
  4. # from torchvision import transforms
  5. # from skimage.transform import resize
  6. import time
  7. import cv2
  8. import skimage
  9. import os
  10. import torch
  11. from PIL import Image
  12. import matplotlib.pyplot as plt
  13. import matplotlib as mpl
  14. import numpy as np
  15. # from models.line_detect.line_net import linenet_resnet50_fpn
  16. from torchvision import transforms
  17. from models.wirenet.postprocess import postprocess
  18. from rtree import index
  19. from datetime import datetime
  20. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  21. def box_line_(imgs, pred): # 默认置信度
  22. im = imgs.permute(1, 2, 0).cpu().numpy()
  23. lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  24. scores = pred[-1]['wires']['score'].cpu().numpy()[0]
  25. # print(f'111:{len(lines)}')
  26. for i in range(1, len(lines)):
  27. if (lines[i] == lines[0]).all():
  28. lines = lines[:i]
  29. scores = scores[:i]
  30. break
  31. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  32. line, score = postprocess(lines, scores, diag * 0.01, 0, False)
  33. # print(f'333:{len(lines)}')
  34. for idx, box_ in enumerate(pred[0:-2]):
  35. box = box_['boxes'] # 是一个tensor
  36. line_ = []
  37. score_ = []
  38. for i in box:
  39. score_max = 0.0
  40. tmp = [[0.0, 0.0], [0.0, 0.0]]
  41. for j in range(len(line)):
  42. if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
  43. line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
  44. line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
  45. line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
  46. if score[j] > score_max:
  47. tmp = line[j]
  48. score_max = score[j]
  49. line_.append(tmp)
  50. score_.append(score_max)
  51. processed_list = torch.tensor(np.array(line_))
  52. pred[idx]['line'] = processed_list
  53. processed_s_list = torch.tensor(score_)
  54. pred[idx]['line_score'] = processed_s_list
  55. return pred
  56. def set_thresholds(threshold):
  57. if isinstance(threshold, list):
  58. if len(threshold) != 2:
  59. raise ValueError("Threshold list must contain exactly two elements.")
  60. a, b = threshold
  61. elif isinstance(threshold, (int, float)):
  62. a = b = threshold
  63. else:
  64. raise TypeError("Threshold must be either a list of two numbers or a single number.")
  65. return a, b
  66. def color():
  67. return [
  68. '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
  69. '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
  70. '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
  71. '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
  72. '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
  73. '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
  74. '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
  75. '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
  76. '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
  77. '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
  78. '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
  79. '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
  80. '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
  81. '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
  82. '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
  83. '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
  84. '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
  85. '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
  86. '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
  87. '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
  88. ]
  89. def show_all(imgs, pred, threshold, save_path):
  90. col = color()
  91. box_th, line_th = set_thresholds(threshold)
  92. im = imgs.permute(1, 2, 0)
  93. boxes = pred[0]['boxes'].cpu().numpy()
  94. box_scores = pred[0]['scores'].cpu().numpy()
  95. lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  96. scores = pred[-1]['wires']['score'].cpu().numpy()[0]
  97. for i in range(1, len(lines)):
  98. if (lines[i] == lines[0]).all():
  99. lines = lines[:i]
  100. scores = scores[:i]
  101. break
  102. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  103. line, line_score = postprocess(lines, scores, diag * 0.01, 0, False)
  104. fig, axs = plt.subplots(1, 3, figsize=(10, 10))
  105. axs[0].imshow(np.array(im))
  106. for idx, box in enumerate(boxes):
  107. if box_scores[idx] < box_th:
  108. continue
  109. x0, y0, x1, y1 = box
  110. axs[0].add_patch(
  111. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  112. axs[0].set_title('Boxes')
  113. axs[1].imshow(np.array(im))
  114. for idx, (a, b) in enumerate(line):
  115. if line_score[idx] < line_th:
  116. continue
  117. axs[1].scatter(a[1], a[0], c='#871F78', s=2)
  118. axs[1].scatter(b[1], b[0], c='#871F78', s=2)
  119. axs[1].plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
  120. axs[1].set_title('Lines')
  121. axs[2].imshow(np.array(im))
  122. lines = pred[0]['line'].cpu().numpy()
  123. line_scores = pred[0]['line_score'].cpu().numpy()
  124. idx = 0
  125. tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
  126. for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
  127. x0, y0, x1, y1 = box
  128. # 框中无线的跳过
  129. if np.array_equal(line, tmp):
  130. continue
  131. a, b = line
  132. if box_score >= 0.7 or line_score >= 0.9:
  133. axs[2].add_patch(
  134. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  135. axs[2].scatter(a[1], a[0], c='#871F78', s=10)
  136. axs[2].scatter(b[1], b[0], c='#871F78', s=10)
  137. axs[2].plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
  138. idx = idx + 1
  139. axs[2].set_title('Boxes and Lines')
  140. if save_path:
  141. save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'box_line.png')
  142. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  143. plt.savefig(save_path)
  144. print(f"Saved result image to {save_path}")
  145. # if show:
  146. # 调整子图之间的距离,防止标题和标签重叠
  147. plt.tight_layout()
  148. plt.show()
  149. def show_box_or_line(imgs, pred, threshold, save_path=None, show_line=False, show_box=False):
  150. col = color()
  151. box_th, line_th = set_thresholds(threshold)
  152. im = imgs.permute(1, 2, 0)
  153. # 可视化预测结
  154. fig, ax = plt.subplots(figsize=(10, 10))
  155. ax.imshow(np.array(im))
  156. if show_box:
  157. boxes = pred[0]['boxes'].cpu().numpy()
  158. box_scores = pred[0]['scores'].cpu().numpy()
  159. for idx, box in enumerate(boxes):
  160. if box_scores[idx] < box_th:
  161. continue
  162. x0, y0, x1, y1 = box
  163. ax.add_patch(
  164. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  165. if save_path:
  166. save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'box.png')
  167. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  168. plt.savefig(save_path)
  169. print(f"Saved result image to {save_path}")
  170. if show_line:
  171. lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  172. scores = pred[-1]['wires']['score'].cpu().numpy()[0]
  173. for i in range(1, len(lines)):
  174. if (lines[i] == lines[0]).all():
  175. lines = lines[:i]
  176. scores = scores[:i]
  177. break
  178. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  179. line, line_score = postprocess(lines, scores, diag * 0.01, 0, False)
  180. for idx, (a, b) in enumerate(line):
  181. if line_score[idx] < line_th:
  182. continue
  183. ax.scatter(a[1], a[0], c='#871F78', s=2)
  184. ax.scatter(b[1], b[0], c='#871F78', s=2)
  185. ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
  186. if save_path:
  187. save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'line.png')
  188. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  189. plt.savefig(save_path)
  190. print(f"Saved result image to {save_path}")
  191. plt.show()
  192. def show_predict(imgs, pred, threshold, t_start):
  193. col = color()
  194. box_th, line_th = set_thresholds(threshold)
  195. im = imgs.permute(1, 2, 0) # 处理为 [512, 512, 3]
  196. boxes = pred[0]['boxes'].cpu().numpy()
  197. box_scores = pred[0]['scores'].cpu().numpy()
  198. lines = pred[0]['line'].cpu().numpy()
  199. scores = pred[0]['line_score'].cpu().numpy()
  200. for i in range(1, len(lines)):
  201. if (lines[i] == lines[0]).all():
  202. lines = lines[:i]
  203. scores = scores[:i]
  204. break
  205. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  206. line1, line_score1 = postprocess(lines, scores, diag * 0.01, 0, False)
  207. # 可视化预测结
  208. fig, ax = plt.subplots(figsize=(10, 10))
  209. ax.imshow(np.array(im))
  210. idx = 0
  211. tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
  212. for box, line, box_score, line_score in zip(boxes, line1, box_scores, line_score1):
  213. x0, y0, x1, y1 = box
  214. # 框中无线的跳过
  215. if np.array_equal(line, tmp):
  216. continue
  217. a, b = line
  218. if box_score >= box_th or line_score >= line_th:
  219. ax.add_patch(
  220. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  221. ax.scatter(a[1], a[0], c='#871F78', s=10)
  222. ax.scatter(b[1], b[0], c='#871F78', s=10)
  223. ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
  224. idx = idx + 1
  225. t_end = time.time()
  226. print(f'predict used:{t_end - t_start}')
  227. plt.show()
  228. class Predict:
  229. def __init__(self, pt_path, model, img, type=0, threshold=0.5, save_path=None, show_line=False, show_box=False):
  230. """
  231. 初始化预测器。
  232. 参数:
  233. pt_path: 模型权重文件路径。
  234. model: 模型定义(未加载权重)。
  235. img: 输入图像(路径或 PIL 图像对象)。
  236. type: 预测类型(0: 全部显示,线图、框图、线框匹配图,1: 显示线图,2: 显示框图,3: 线框匹配图)。
  237. threshold: 阈值,用于过滤预测结果。
  238. save_path: 保存结果的路径(可选)。
  239. show: 是否显示结果。
  240. device: 运行设备(默认 'cuda')。
  241. """
  242. self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
  243. self.model = model
  244. self.pt_path = pt_path
  245. self.img = self.load_image(img)
  246. self.type = type
  247. self.threshold = threshold
  248. self.save_path = save_path
  249. self.show_line = show_line
  250. self.show_box = show_box
  251. def load_best_model(self, model, save_path, device):
  252. if os.path.exists(save_path):
  253. checkpoint = torch.load(save_path, map_location=device)
  254. model.load_state_dict(checkpoint['model_state_dict'])
  255. # if optimizer is not None:
  256. # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  257. # epoch = checkpoint['epoch']
  258. # loss = checkpoint['loss']
  259. # print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  260. else:
  261. print(f"No saved model found at {save_path}")
  262. return model
  263. def load_image(self, img):
  264. """加载图像"""
  265. if isinstance(img, str):
  266. img = Image.open(img).convert("RGB")
  267. return img
  268. def preprocess_image(self, img):
  269. """预处理图像"""
  270. transform = transforms.ToTensor()
  271. img_tensor = transform(img) # [3, H, W]
  272. # 调整大小为 512x512
  273. t_start = time.time()
  274. im = img_tensor.permute(1, 2, 0) # [H, W, 3]
  275. # im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512)) # (512, 512, 3)
  276. if im.shape != (512, 512, 3):
  277. im_resized = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_LINEAR)
  278. img_ = torch.tensor(im).permute(2, 0, 1) # [3, 512, 512]
  279. t_end = time.time()
  280. print(f"Image preprocessing used: {t_end - t_start:.4f} seconds")
  281. return img_
  282. def predict(self):
  283. """执行预测"""
  284. model = self.load_best_model(self.model, self.pt_path, device)
  285. model.eval()
  286. # 预处理图像
  287. img_ = self.preprocess_image(self.img)
  288. # 模型推理
  289. with torch.no_grad():
  290. predictions = model([img_.to(self.device)])
  291. print("Model predictions completed.")
  292. # 后处理
  293. t_start = time.time()
  294. pred = box_line_(img_, predictions) # 线框匹配
  295. t_end = time.time()
  296. print(f"Matched boxes and lines used: {t_end - t_start:.4f} seconds")
  297. # 根据类型显示或保存结果
  298. if self.type == 0:
  299. show_all(img_, pred, self.threshold, save_path=self.save_path)
  300. elif self.type == 1:
  301. show_box_or_line(img_, predictions, self.threshold, save_path=self.save_path, show_line=True)
  302. elif self.type == 2:
  303. show_box_or_line(img_, predictions, self.threshold, save_path=self.save_path, show_box=True)
  304. elif self.type == 3:
  305. show_predict(img_, pred, self.threshold, t_start)
  306. def run(self):
  307. """运行预测流程"""
  308. self.predict()
  309. class Predict1:
  310. def __init__(self, model, img, type=0, threshold=0.5, save_path=None, show_line=False, show_box=False):
  311. """
  312. 初始化预测器。
  313. 参数:
  314. pt_path: 模型权重文件路径。
  315. model: 模型定义(未加载权重)。
  316. img: 输入图像(路径或 PIL 图像对象)。
  317. type: 预测类型(0: 全部显示,线图、框图、线框匹配图,1: 显示线图,2: 显示框图,3: 线框匹配图)。
  318. threshold: 阈值,用于过滤预测结果。
  319. save_path: 保存结果的路径(可选)。
  320. show: 是否显示结果。
  321. device: 运行设备(默认 'cuda')。
  322. """
  323. self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
  324. self.model = model
  325. self.img = self.load_image(img)
  326. self.type = type
  327. self.threshold = threshold
  328. self.save_path = save_path
  329. self.show_line = show_line
  330. self.show_box = show_box
  331. def load_best_model(self, model, save_path, device):
  332. if os.path.exists(save_path):
  333. checkpoint = torch.load(save_path, map_location=device)
  334. model.load_state_dict(checkpoint['model_state_dict'])
  335. # if optimizer is not None:
  336. # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  337. # epoch = checkpoint['epoch']
  338. # loss = checkpoint['loss']
  339. # print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  340. else:
  341. print(f"No saved model found at {save_path}")
  342. return model
  343. def load_image(self, img):
  344. """加载图像"""
  345. if isinstance(img, str):
  346. img = Image.open(img).convert("RGB")
  347. return img
  348. def preprocess_image(self, img):
  349. """预处理图像"""
  350. transform = transforms.ToTensor()
  351. img_tensor = transform(img) # [3, H, W]
  352. # 调整大小为 512x512
  353. t_start = time.time()
  354. im = img_tensor.permute(1, 2, 0) # [H, W, 3]
  355. # im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512)) # (512, 512, 3)
  356. if im.shape != (512, 512, 3):
  357. im_resized = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_LINEAR)
  358. img_ = torch.tensor(im).permute(2, 0, 1) # [3, 512, 512]
  359. t_end = time.time()
  360. print(f"Image preprocessing used: {t_end - t_start:.4f} seconds")
  361. return img_
  362. def predict(self):
  363. """执行预测"""
  364. # model = self.load_best_model(self.model, self.pt_path, device)
  365. model = self.model
  366. model.eval()
  367. # 预处理图像
  368. img_ = self.preprocess_image(self.img)
  369. # 模型推理
  370. with torch.no_grad():
  371. predictions = model([img_.to(self.device)])
  372. print("Model predictions completed.")
  373. # 根据类型显示或保存结果
  374. if self.type == 0:
  375. # 后处理
  376. t_start = time.time()
  377. pred = box_line_(img_, predictions) # 线框匹配
  378. t_end = time.time()
  379. print(f"Matched boxes and lines used: {t_end - t_start:.4f} seconds")
  380. show_all(img_, pred, self.threshold, save_path=self.save_path)
  381. elif self.type == 1:
  382. show_box_or_line(img_, predictions, self.threshold, save_path=self.save_path, show_line=True)
  383. elif self.type == 2:
  384. show_box_or_line(img_, predictions, self.threshold, save_path=self.save_path, show_box=True)
  385. elif self.type == 3:
  386. # 后处理
  387. t_start = time.time()
  388. pred = box_line_(img_, predictions) # 线框匹配
  389. t_end = time.time()
  390. print(f"Matched boxes and lines used: {t_end - t_start:.4f} seconds")
  391. show_predict(img_, pred, self.threshold, t_start)
  392. def run(self):
  393. """运行预测流程"""
  394. self.predict()