predict.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473
  1. # import time
  2. # import torch
  3. # from PIL import Image
  4. # from torchvision import transforms
  5. # from skimage.transform import resize
  6. import time
  7. import cv2
  8. import skimage
  9. import os
  10. import torch
  11. from PIL import Image
  12. import matplotlib.pyplot as plt
  13. import numpy as np
  14. from torchvision import transforms
  15. from models.wirenet.postprocess import postprocess
  16. from rtree import index
  17. from datetime import datetime
# Module-wide default device: CUDA when torch reports it available, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  19. def box_line_(imgs, pred): # 默认置信度
  20. im = imgs.permute(1, 2, 0).cpu().numpy()
  21. lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  22. scores = pred[-1]['wires']['score'].cpu().numpy()[0]
  23. # print(f'111:{len(lines)}')
  24. for i in range(1, len(lines)):
  25. if (lines[i] == lines[0]).all():
  26. lines = lines[:i]
  27. scores = scores[:i]
  28. break
  29. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  30. line, score = postprocess(lines, scores, diag * 0.01, 0, False)
  31. # print(f'333:{len(lines)}')
  32. for idx, box_ in enumerate(pred[0:-2]):
  33. box = box_['boxes'] # 是一个tensor
  34. line_ = []
  35. score_ = []
  36. for i in box:
  37. score_max = 0.0
  38. tmp = [[0.0, 0.0], [0.0, 0.0]]
  39. for j in range(len(line)):
  40. if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
  41. line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
  42. line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
  43. line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
  44. if score[j] > score_max:
  45. tmp = line[j]
  46. score_max = score[j]
  47. line_.append(tmp)
  48. score_.append(score_max)
  49. processed_list = torch.tensor(np.array(line_))
  50. pred[idx]['line'] = processed_list
  51. processed_s_list = torch.tensor(score_)
  52. pred[idx]['line_score'] = processed_s_list
  53. return pred
  54. def set_thresholds(threshold):
  55. if isinstance(threshold, list):
  56. if len(threshold) != 2:
  57. raise ValueError("Threshold list must contain exactly two elements.")
  58. a, b = threshold
  59. elif isinstance(threshold, (int, float)):
  60. a = b = threshold
  61. else:
  62. raise TypeError("Threshold must be either a list of two numbers or a single number.")
  63. return a, b
  64. def color():
  65. return [
  66. '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
  67. '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
  68. '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
  69. '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
  70. '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
  71. '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
  72. '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
  73. '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
  74. '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
  75. '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
  76. '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
  77. '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
  78. '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
  79. '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
  80. '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
  81. '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
  82. '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
  83. '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
  84. '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
  85. '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
  86. ]
  87. def show_all(imgs, pred, threshold, save_path):
  88. col = color()
  89. box_th, line_th = set_thresholds(threshold)
  90. im = imgs.permute(1, 2, 0)
  91. boxes = pred[0]['boxes'].cpu().numpy()
  92. box_scores = pred[0]['scores'].cpu().numpy()
  93. lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  94. scores = pred[-1]['wires']['score'].cpu().numpy()[0]
  95. for i in range(1, len(lines)):
  96. if (lines[i] == lines[0]).all():
  97. lines = lines[:i]
  98. scores = scores[:i]
  99. break
  100. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  101. line, line_score = postprocess(lines, scores, diag * 0.01, 0, False)
  102. fig, axs = plt.subplots(1, 3, figsize=(10, 10))
  103. axs[0].imshow(np.array(im))
  104. for idx, box in enumerate(boxes):
  105. if box_scores[idx] < box_th:
  106. continue
  107. x0, y0, x1, y1 = box
  108. axs[0].add_patch(
  109. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  110. axs[0].set_title('Boxes')
  111. axs[1].imshow(np.array(im))
  112. for idx, (a, b) in enumerate(line):
  113. if line_score[idx] < line_th:
  114. continue
  115. axs[1].scatter(a[1], a[0], c='#871F78', s=2)
  116. axs[1].scatter(b[1], b[0], c='#871F78', s=2)
  117. axs[1].plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
  118. axs[1].set_title('Lines')
  119. axs[2].imshow(np.array(im))
  120. lines = pred[0]['line'].cpu().numpy()
  121. line_scores = pred[0]['line_score'].cpu().numpy()
  122. idx = 0
  123. tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
  124. for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
  125. x0, y0, x1, y1 = box
  126. # 框中无线的跳过
  127. if np.array_equal(line, tmp):
  128. continue
  129. a, b = line
  130. if box_score >= 0.7 or line_score >= 0.9:
  131. axs[2].add_patch(
  132. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  133. axs[2].scatter(a[1], a[0], c='#871F78', s=10)
  134. axs[2].scatter(b[1], b[0], c='#871F78', s=10)
  135. axs[2].plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
  136. idx = idx + 1
  137. axs[2].set_title('Boxes and Lines')
  138. if save_path:
  139. save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'box_line.png')
  140. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  141. plt.savefig(save_path)
  142. print(f"Saved result image to {save_path}")
  143. # if show:
  144. # 调整子图之间的距离,防止标题和标签重叠
  145. plt.tight_layout()
  146. plt.show()
  147. def show_box_or_line(imgs, pred, threshold, save_path=None, show_line=False, show_box=False):
  148. col = color()
  149. box_th, line_th = set_thresholds(threshold)
  150. im = imgs.permute(1, 2, 0)
  151. # 可视化预测结
  152. fig, ax = plt.subplots(figsize=(10, 10))
  153. ax.imshow(np.array(im))
  154. if show_box:
  155. boxes = pred[0]['boxes'].cpu().numpy()
  156. box_scores = pred[0]['scores'].cpu().numpy()
  157. for idx, box in enumerate(boxes):
  158. if box_scores[idx] < box_th:
  159. continue
  160. x0, y0, x1, y1 = box
  161. ax.add_patch(
  162. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  163. if save_path:
  164. save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'box.png')
  165. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  166. plt.savefig(save_path)
  167. print(f"Saved result image to {save_path}")
  168. if show_line:
  169. lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  170. scores = pred[-1]['wires']['score'].cpu().numpy()[0]
  171. for i in range(1, len(lines)):
  172. if (lines[i] == lines[0]).all():
  173. lines = lines[:i]
  174. scores = scores[:i]
  175. break
  176. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  177. line, line_score = postprocess(lines, scores, diag * 0.01, 0, False)
  178. for idx, (a, b) in enumerate(line):
  179. if line_score[idx] < line_th:
  180. continue
  181. ax.scatter(a[1], a[0], c='#871F78', s=2)
  182. ax.scatter(b[1], b[0], c='#871F78', s=2)
  183. ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
  184. if save_path:
  185. save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'line.png')
  186. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  187. plt.savefig(save_path)
  188. print(f"Saved result image to {save_path}")
  189. plt.show()
  190. def show_predict(imgs, pred, threshold, t_start):
  191. col = color()
  192. box_th, line_th = set_thresholds(threshold)
  193. im = imgs.permute(1, 2, 0) # 处理为 [512, 512, 3]
  194. boxes = pred[0]['boxes'].cpu().numpy()
  195. box_scores = pred[0]['scores'].cpu().numpy()
  196. lines = pred[0]['line'].cpu().numpy()
  197. scores = pred[0]['line_score'].cpu().numpy()
  198. for i in range(1, len(lines)):
  199. if (lines[i] == lines[0]).all():
  200. lines = lines[:i]
  201. scores = scores[:i]
  202. break
  203. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  204. line1, line_score1 = postprocess(lines, scores, diag * 0.01, 0, False)
  205. # 可视化预测结
  206. fig, ax = plt.subplots(figsize=(10, 10))
  207. ax.imshow(np.array(im))
  208. idx = 0
  209. tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
  210. for box, line, box_score, line_score in zip(boxes, line1, box_scores, line_score1):
  211. x0, y0, x1, y1 = box
  212. # 框中无线的跳过
  213. if np.array_equal(line, tmp):
  214. continue
  215. a, b = line
  216. if box_score >= box_th or line_score >= line_th:
  217. ax.add_patch(
  218. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  219. ax.scatter(a[1], a[0], c='#871F78', s=10)
  220. ax.scatter(b[1], b[0], c='#871F78', s=10)
  221. ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
  222. idx = idx + 1
  223. t_end = time.time()
  224. print(f'predict used:{t_end - t_start}')
  225. plt.show()
  226. class Predict:
  227. def __init__(self, model, img, type=0, threshold=0.5, save_path=None, show_line=False, show_box=False):
  228. """
  229. 初始化预测器。
  230. 参数:
  231. pt_path: 模型权重文件路径。
  232. model: 模型定义(未加载权重)。
  233. img: 输入图像(路径或 PIL 图像对象)。
  234. type: 预测类型(0: 全部显示,线图、框图、线框匹配图,1: 显示线图,2: 显示框图,3: 线框匹配图)。
  235. threshold: 阈值,用于过滤预测结果。
  236. save_path: 保存结果的路径(可选)。
  237. show: 是否显示结果。
  238. device: 运行设备(默认 'cuda')。
  239. """
  240. # self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
  241. self.model = model
  242. self.device = next(model.parameters()).device
  243. # self.pt_path = pt_path
  244. self.img = self.load_image(img)
  245. self.type = type
  246. self.threshold = threshold
  247. self.save_path = save_path
  248. self.show_line = show_line
  249. self.show_box = show_box
  250. def load_best_model(self, model, save_path, device):
  251. if os.path.exists(save_path):
  252. checkpoint = torch.load(save_path, map_location=device)
  253. model.load_state_dict(checkpoint['model_state_dict'])
  254. # if optimizer is not None:
  255. # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  256. # epoch = checkpoint['epoch']
  257. # loss = checkpoint['loss']
  258. # print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  259. else:
  260. print(f"No saved model found at {save_path}")
  261. return model
  262. def load_image(self, img):
  263. """加载图像"""
  264. if isinstance(img, str):
  265. img = Image.open(img).convert("RGB")
  266. return img
  267. def preprocess_image(self, img):
  268. """预处理图像"""
  269. transform = transforms.ToTensor()
  270. img_tensor = transform(img) # [3, H, W]
  271. # 调整大小为 512x512
  272. t_start = time.time()
  273. im = img_tensor.permute(1, 2, 0) # [H, W, 3]
  274. # im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512)) # (512, 512, 3)
  275. if im.shape != (512, 512, 3):
  276. im = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_LINEAR)
  277. img_ = torch.tensor(im).permute(2, 0, 1) # [3, 512, 512]
  278. t_end = time.time()
  279. print(f"Image preprocessing used: {t_end - t_start:.4f} seconds")
  280. return img_
  281. def predict(self):
  282. """执行预测"""
  283. # model = self.load_best_model(self.model, self.pt_path, device)
  284. #
  285. # model.eval()
  286. # 预处理图像
  287. img_ = self.preprocess_image(self.img)
  288. # 模型推理
  289. with torch.no_grad():
  290. predictions =self.model([img_.to(self.device)])
  291. print("Model predictions completed.")
  292. # 后处理
  293. t_start = time.time()
  294. pred = box_line_(img_, predictions) # 线框匹配
  295. t_end = time.time()
  296. print(f"Matched boxes and lines used: {t_end - t_start:.4f} seconds")
  297. # 根据类型显示或保存结果
  298. if self.type == 0:
  299. show_all(img_, pred, self.threshold, save_path=self.save_path)
  300. elif self.type == 1:
  301. show_box_or_line(img_, predictions, self.threshold, save_path=self.save_path, show_line=True)
  302. elif self.type == 2:
  303. show_box_or_line(img_, predictions, self.threshold, save_path=self.save_path, show_box=True)
  304. elif self.type == 3:
  305. show_predict(img_, pred, self.threshold, t_start)
  306. def run(self):
  307. """运行预测流程"""
  308. self.predict()
  309. class Predict1:
  310. def __init__(self, model, img, type=0, threshold=0.5, save_path=None, show_line=False, show_box=False):
  311. """
  312. 初始化预测器。
  313. 参数:
  314. pt_path: 模型权重文件路径。
  315. model: 模型定义(未加载权重)。
  316. img: 输入图像(路径或 PIL 图像对象)。
  317. type: 预测类型(0: 全部显示,线图、框图、线框匹配图,1: 显示线图,2: 显示框图,3: 线框匹配图)。
  318. threshold: 阈值,用于过滤预测结果。
  319. save_path: 保存结果的路径(可选)。
  320. show: 是否显示结果。
  321. device: 运行设备(默认 'cuda')。
  322. """
  323. self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
  324. self.model = model
  325. self.img = self.load_image(img)
  326. self.type = type
  327. self.threshold = threshold
  328. self.save_path = save_path
  329. self.show_line = show_line
  330. self.show_box = show_box
  331. def load_best_model(self, model, save_path, device):
  332. if os.path.exists(save_path):
  333. checkpoint = torch.load(save_path, map_location=device)
  334. model.load_state_dict(checkpoint['model_state_dict'])
  335. # if optimizer is not None:
  336. # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  337. # epoch = checkpoint['epoch']
  338. # loss = checkpoint['loss']
  339. # print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  340. else:
  341. print(f"No saved model found at {save_path}")
  342. return model
  343. def load_image(self, img):
  344. """加载图像"""
  345. if isinstance(img, str):
  346. img = Image.open(img).convert("RGB")
  347. return img
  348. def preprocess_image(self, img):
  349. """预处理图像"""
  350. transform = transforms.ToTensor()
  351. img_tensor = transform(img) # [3, H, W]
  352. # 调整大小为 512x512
  353. t_start = time.time()
  354. im = img_tensor.permute(1, 2, 0) # [H, W, 3]
  355. # im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512)) # (512, 512, 3)
  356. if im.shape != (512, 512, 3):
  357. im = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_LINEAR)
  358. img_ = torch.tensor(im).permute(2, 0, 1) # [3, 512, 512]
  359. t_end = time.time()
  360. print(f"Image preprocessing used: {t_end - t_start:.4f} seconds")
  361. return img_
  362. def predict(self):
  363. """执行预测"""
  364. # model = self.load_best_model(self.model, self.pt_path, device)
  365. model = self.model
  366. model.eval()
  367. # 预处理图像
  368. img_ = self.preprocess_image(self.img)
  369. # 模型推理
  370. with torch.no_grad():
  371. predictions = model([img_.to(self.device)])
  372. print("Model predictions completed.")
  373. # 根据类型显示或保存结果
  374. if self.type == 0:
  375. # 后处理
  376. t_start = time.time()
  377. pred = box_line_(img_, predictions) # 线框匹配
  378. t_end = time.time()
  379. print(f"Matched boxes and lines used: {t_end - t_start:.4f} seconds")
  380. show_all(img_, pred, self.threshold, save_path=self.save_path)
  381. elif self.type == 1:
  382. show_box_or_line(img_, predictions, self.threshold, save_path=self.save_path, show_line=True)
  383. elif self.type == 2:
  384. show_box_or_line(img_, predictions, self.threshold, save_path=self.save_path, show_box=True)
  385. elif self.type == 3:
  386. # 后处理
  387. t_start = time.time()
  388. pred = box_line_(img_, predictions) # 线框匹配
  389. t_end = time.time()
  390. print(f"Matched boxes and lines used: {t_end - t_start:.4f} seconds")
  391. show_predict(img_, pred, self.threshold, t_start)
  392. def run(self):
  393. """运行预测流程"""
  394. self.predict()