predict.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. # import time
  2. # import torch
  3. # from PIL import Image
  4. # from torchvision import transforms
  5. # from skimage.transform import resize
  6. import time
  7. import cv2
  8. import skimage
  9. import os
  10. import torch
  11. from PIL import Image
  12. import matplotlib.pyplot as plt
  13. import numpy as np
  14. from torchvision import transforms
  15. from models.wirenet.postprocess import postprocess
  16. from rtree import index
  17. from datetime import datetime
  18. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  19. def box_line_(imgs, pred): # 默认置信度
  20. im = imgs.permute(1, 2, 0).cpu().numpy()
  21. lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * np.array([2000, 2000])
  22. scores = pred[-1]['wires']['score'].cpu().numpy()[0]
  23. # print(f'111:{len(lines)}')
  24. for i in range(1, len(lines)):
  25. if (lines[i] == lines[0]).all():
  26. lines = lines[:i]
  27. scores = scores[:i]
  28. break
  29. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  30. line, score = postprocess(lines, scores, diag * 0.01, 0, False)
  31. # print(f'333:{len(lines)}')
  32. for idx, box_ in enumerate(pred[0:-2]):
  33. box = box_['boxes'] # 是一个tensor
  34. line_ = []
  35. score_ = []
  36. for i in box:
  37. score_max = 0.0
  38. tmp = [[0.0, 0.0], [0.0, 0.0]]
  39. for j in range(len(line)):
  40. if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
  41. line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
  42. line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
  43. line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
  44. if score[j] > score_max:
  45. tmp = line[j]
  46. score_max = score[j]
  47. line_.append(tmp)
  48. score_.append(score_max)
  49. processed_list = torch.tensor(np.array(line_))
  50. pred[idx]['line'] = processed_list
  51. processed_s_list = torch.tensor(score_)
  52. pred[idx]['line_score'] = processed_s_list
  53. return pred
  54. def set_thresholds(threshold):
  55. if isinstance(threshold, list):
  56. if len(threshold) != 2:
  57. raise ValueError("Threshold list must contain exactly two elements.")
  58. a, b = threshold
  59. elif isinstance(threshold, (int, float)):
  60. a = b = threshold
  61. else:
  62. raise TypeError("Threshold must be either a list of two numbers or a single number.")
  63. return a, b
  64. def color():
  65. return [
  66. '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
  67. '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
  68. '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
  69. '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
  70. '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
  71. '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
  72. '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
  73. '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
  74. '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
  75. '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
  76. '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
  77. '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
  78. '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
  79. '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
  80. '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
  81. '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
  82. '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
  83. '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
  84. '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
  85. '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
  86. ]
  87. def show_all(imgs, pred, threshold, save_path):
  88. col = color()
  89. box_th, line_th = set_thresholds(threshold)
  90. im = imgs.permute(1, 2, 0)
  91. boxes = pred[0]['boxes'].cpu().numpy()
  92. box_scores = pred[0]['scores'].cpu().numpy()
  93. # lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  94. lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * np.array([2000, 2000])
  95. scores = pred[-1]['wires']['score'].cpu().numpy()[0]
  96. for i in range(1, len(lines)):
  97. if (lines[i] == lines[0]).all():
  98. lines = lines[:i]
  99. scores = scores[:i]
  100. break
  101. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  102. line, line_score = postprocess(lines, scores, diag * 0.01, 0, False)
  103. fig, axs = plt.subplots(1, 3, figsize=(10, 10))
  104. axs[0].imshow(np.array(im))
  105. for idx, box in enumerate(boxes):
  106. if box_scores[idx] < box_th:
  107. continue
  108. x0, y0, x1, y1 = box
  109. axs[0].add_patch(
  110. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  111. axs[0].set_title('Boxes')
  112. axs[1].imshow(np.array(im))
  113. for idx, (a, b) in enumerate(line):
  114. if line_score[idx] < line_th:
  115. continue
  116. axs[1].scatter(a[1], a[0], c='#871F78', s=2)
  117. axs[1].scatter(b[1], b[0], c='#871F78', s=2)
  118. axs[1].plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
  119. axs[1].set_title('Lines')
  120. axs[2].imshow(np.array(im))
  121. lines = pred[0]['line'].cpu().numpy()
  122. line_scores = pred[0]['line_score'].cpu().numpy()
  123. idx = 0
  124. tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
  125. for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
  126. x0, y0, x1, y1 = box
  127. # 框中无线的跳过
  128. if np.array_equal(line, tmp):
  129. continue
  130. a, b = line
  131. if box_score >= 0.7 or line_score >= 0.9:
  132. axs[2].add_patch(
  133. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  134. axs[2].scatter(a[1], a[0], c='#871F78', s=10)
  135. axs[2].scatter(b[1], b[0], c='#871F78', s=10)
  136. axs[2].plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
  137. idx = idx + 1
  138. axs[2].set_title('Boxes and Lines')
  139. if save_path:
  140. save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'box_line.png')
  141. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  142. plt.savefig(save_path)
  143. print(f"Saved result image to {save_path}")
  144. # if show:
  145. # 调整子图之间的距离,防止标题和标签重叠
  146. plt.tight_layout()
  147. plt.show()
  148. def show_box_or_line(imgs, pred, threshold, save_path=None, show_line=False, show_box=False):
  149. col = color()
  150. box_th, line_th = set_thresholds(threshold)
  151. im = imgs.permute(1, 2, 0)
  152. # 可视化预测结
  153. fig, ax = plt.subplots(figsize=(10, 10))
  154. ax.imshow(np.array(im))
  155. if show_box:
  156. boxes = pred[0]['boxes'].cpu().numpy()
  157. box_scores = pred[0]['scores'].cpu().numpy()
  158. for idx, box in enumerate(boxes):
  159. if box_scores[idx] < box_th:
  160. continue
  161. x0, y0, x1, y1 = box
  162. ax.add_patch(
  163. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  164. if save_path:
  165. save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'box.png')
  166. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  167. plt.savefig(save_path)
  168. print(f"Saved result image to {save_path}")
  169. if show_line:
  170. lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  171. scores = pred[-1]['wires']['score'].cpu().numpy()[0]
  172. for i in range(1, len(lines)):
  173. if (lines[i] == lines[0]).all():
  174. lines = lines[:i]
  175. scores = scores[:i]
  176. break
  177. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  178. line, line_score = postprocess(lines, scores, diag * 0.01, 0, False)
  179. for idx, (a, b) in enumerate(line):
  180. if line_score[idx] < line_th:
  181. continue
  182. ax.scatter(a[1], a[0], c='#871F78', s=2)
  183. ax.scatter(b[1], b[0], c='#871F78', s=2)
  184. ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
  185. if save_path:
  186. save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'line.png')
  187. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  188. plt.savefig(save_path)
  189. print(f"Saved result image to {save_path}")
  190. plt.show()
  191. def show_predict(imgs, pred, threshold, t_start):
  192. col = color()
  193. box_th, line_th = set_thresholds(threshold)
  194. im = imgs.permute(1, 2, 0) # 处理为 [512, 512, 3]
  195. boxes = pred[0]['boxes'].cpu().numpy()
  196. box_scores = pred[0]['scores'].cpu().numpy()
  197. lines = pred[0]['line'].cpu().numpy()
  198. scores = pred[0]['line_score'].cpu().numpy()
  199. for i in range(1, len(lines)):
  200. if (lines[i] == lines[0]).all():
  201. lines = lines[:i]
  202. scores = scores[:i]
  203. break
  204. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  205. line1, line_score1 = postprocess(lines, scores, diag * 0.01, 0, False)
  206. # 可视化预测结
  207. fig, ax = plt.subplots(figsize=(10, 10))
  208. ax.imshow(np.array(im))
  209. idx = 0
  210. tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
  211. for box, line, box_score, line_score in zip(boxes, line1, box_scores, line_score1):
  212. x0, y0, x1, y1 = box
  213. # 框中无线的跳过
  214. if np.array_equal(line, tmp):
  215. continue
  216. a, b = line
  217. if box_score >= box_th or line_score >= line_th:
  218. ax.add_patch(
  219. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  220. ax.scatter(a[1], a[0], c='#871F78', s=10)
  221. ax.scatter(b[1], b[0], c='#871F78', s=10)
  222. ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
  223. idx = idx + 1
  224. t_end = time.time()
  225. print(f'predict used:{t_end - t_start}')
  226. plt.show()
  227. class Predict:
  228. def __init__(self, model, img, type=0, threshold=0.5, save_path=None, show_line=False, show_box=False):
  229. """
  230. 初始化预测器。
  231. 参数:
  232. pt_path: 模型权重文件路径。
  233. model: 模型定义(未加载权重)。
  234. img: 输入图像(路径或 PIL 图像对象)。
  235. type: 预测类型(0: 全部显示,线图、框图、线框匹配图,1: 显示线图,2: 显示框图,3: 线框匹配图)。
  236. threshold: 阈值,用于过滤预测结果。
  237. save_path: 保存结果的路径(可选)。
  238. show: 是否显示结果。
  239. device: 运行设备(默认 'cuda')。
  240. """
  241. # self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
  242. self.model = model
  243. self.device = next(model.parameters()).device
  244. # self.pt_path = pt_path
  245. self.img = self.load_image(img)
  246. self.type = type
  247. self.threshold = threshold
  248. self.save_path = save_path
  249. self.show_line = show_line
  250. self.show_box = show_box
  251. def load_best_model(self, model, save_path, device):
  252. if os.path.exists(save_path):
  253. checkpoint = torch.load(save_path, map_location=device)
  254. model.load_state_dict(checkpoint['model_state_dict'])
  255. # if optimizer is not None:
  256. # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  257. # epoch = checkpoint['epoch']
  258. # loss = checkpoint['loss']
  259. # print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  260. else:
  261. print(f"No saved model found at {save_path}")
  262. return model
  263. def load_image(self, img):
  264. """加载图像"""
  265. if isinstance(img, str):
  266. img = Image.open(img).convert("RGB")
  267. return img
  268. def preprocess_image(self, img):
  269. """预处理图像"""
  270. transform = transforms.ToTensor()
  271. img_tensor = transform(img) # [3, H, W]
  272. # 调整大小为 512x512
  273. t_start = time.time()
  274. im = img_tensor.permute(1, 2, 0) # [H, W, 3]
  275. # im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512)) # (512, 512, 3)
  276. if im.shape != (512, 512, 3):
  277. im_resized = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_LINEAR)
  278. img_ = torch.tensor(im).permute(2, 0, 1) # [3, 512, 512]
  279. t_end = time.time()
  280. print(f"Image preprocessing used: {t_end - t_start:.4f} seconds")
  281. return img_
  282. def predict(self):
  283. """执行预测"""
  284. # model = self.load_best_model(self.model, self.pt_path, device)
  285. #
  286. # model.eval()
  287. # 预处理图像
  288. img_ = self.preprocess_image(self.img)
  289. # 模型推理
  290. with torch.no_grad():
  291. predictions =self.model([img_.to(self.device)])
  292. print("Model predictions completed.")
  293. # 后处理
  294. t_start = time.time()
  295. pred = box_line_(img_, predictions) # 线框匹配
  296. t_end = time.time()
  297. print(f"Matched boxes and lines used: {t_end - t_start:.4f} seconds")
  298. # 根据类型显示或保存结果
  299. if self.type == 0:
  300. show_all(img_, pred, self.threshold, save_path=self.save_path)
  301. elif self.type == 1:
  302. show_box_or_line(img_, predictions, self.threshold, save_path=self.save_path, show_line=True)
  303. elif self.type == 2:
  304. show_box_or_line(img_, predictions, self.threshold, save_path=self.save_path, show_box=True)
  305. elif self.type == 3:
  306. show_predict(img_, pred, self.threshold, t_start)
  307. def run(self):
  308. """运行预测流程"""
  309. self.predict()
  310. class Predict1:
  311. def __init__(self, model, img, type=0, threshold=0.5, save_path=None, show_line=False, show_box=False):
  312. """
  313. 初始化预测器。
  314. 参数:
  315. pt_path: 模型权重文件路径。
  316. model: 模型定义(未加载权重)。
  317. img: 输入图像(路径或 PIL 图像对象)。
  318. type: 预测类型(0: 全部显示,线图、框图、线框匹配图,1: 显示线图,2: 显示框图,3: 线框匹配图)。
  319. threshold: 阈值,用于过滤预测结果。
  320. save_path: 保存结果的路径(可选)。
  321. show: 是否显示结果。
  322. device: 运行设备(默认 'cuda')。
  323. """
  324. self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
  325. self.model = model
  326. self.img = self.load_image(img)
  327. self.type = type
  328. self.threshold = threshold
  329. self.save_path = save_path
  330. self.show_line = show_line
  331. self.show_box = show_box
  332. def load_best_model(self, model, save_path, device):
  333. if os.path.exists(save_path):
  334. checkpoint = torch.load(save_path, map_location=device)
  335. model.load_state_dict(checkpoint['model_state_dict'])
  336. # if optimizer is not None:
  337. # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  338. # epoch = checkpoint['epoch']
  339. # loss = checkpoint['loss']
  340. # print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  341. else:
  342. print(f"No saved model found at {save_path}")
  343. return model
  344. def load_image(self, img):
  345. """加载图像"""
  346. if isinstance(img, str):
  347. img = Image.open(img).convert("RGB")
  348. return img
  349. def preprocess_image(self, img):
  350. """预处理图像"""
  351. transform = transforms.ToTensor()
  352. img_tensor = transform(img) # [3, H, W]
  353. # 调整大小为 512x512
  354. t_start = time.time()
  355. im = img_tensor.permute(1, 2, 0) # [H, W, 3]
  356. # im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512)) # (512, 512, 3)
  357. if im.shape != (512, 512, 3):
  358. im_resized = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_LINEAR)
  359. img_ = torch.tensor(im).permute(2, 0, 1) # [3, 512, 512]
  360. t_end = time.time()
  361. print(f"Image preprocessing used: {t_end - t_start:.4f} seconds")
  362. return img_
  363. def predict(self):
  364. """执行预测"""
  365. # model = self.load_best_model(self.model, self.pt_path, device)
  366. model = self.model
  367. model.eval()
  368. # 预处理图像
  369. img_ = self.preprocess_image(self.img)
  370. # 模型推理
  371. with torch.no_grad():
  372. predictions = model([img_.to(self.device)])
  373. print("Model predictions completed.")
  374. # 根据类型显示或保存结果
  375. if self.type == 0:
  376. # 后处理
  377. t_start = time.time()
  378. pred = box_line_(img_, predictions) # 线框匹配
  379. t_end = time.time()
  380. print(f"Matched boxes and lines used: {t_end - t_start:.4f} seconds")
  381. show_all(img_, pred, self.threshold, save_path=self.save_path)
  382. elif self.type == 1:
  383. show_box_or_line(img_, predictions, self.threshold, save_path=self.save_path, show_line=True)
  384. elif self.type == 2:
  385. show_box_or_line(img_, predictions, self.threshold, save_path=self.save_path, show_box=True)
  386. elif self.type == 3:
  387. # 后处理
  388. t_start = time.time()
  389. pred = box_line_(img_, predictions) # 线框匹配
  390. t_end = time.time()
  391. print(f"Matched boxes and lines used: {t_end - t_start:.4f} seconds")
  392. show_predict(img_, pred, self.threshold, t_start)
  393. def run(self):
  394. """运行预测流程"""
  395. self.predict()