predict2.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
import argparse
import os
import time

import cv2
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import skimage
import torch
from PIL import Image
from rtree import index
from torchvision import transforms

from models.line_net.line_net import linenet_resnet50_fpn, get_line_net_efficientnetv2, get_line_net_convnext_fpn
from models.wirenet.postprocess import postprocess
  15. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  16. def load_best_model(model, save_path, device):
  17. if os.path.exists(save_path):
  18. checkpoint = torch.load(save_path, map_location=device)
  19. model.load_state_dict(checkpoint['model_state_dict'],strict=False)
  20. # if optimizer is not None:
  21. # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  22. # epoch = checkpoint['epoch']
  23. # loss = checkpoint['loss']
  24. # print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  25. else:
  26. print(f"No saved model found at {save_path}")
  27. return model
  28. def box_line_(imgs, pred, length=False): # 默认置信度
  29. im = imgs.permute(1, 2, 0).cpu().numpy()
  30. line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  31. line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
  32. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  33. line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)
  34. for idx, box_ in enumerate(pred[0:-1]):
  35. box = box_['boxes'] # 是一个tensor
  36. line_ = []
  37. score_ = []
  38. for i in box:
  39. score_max = 0.0
  40. tmp = [[0.0, 0.0], [0.0, 0.0]]
  41. for j in range(len(line)):
  42. if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
  43. line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
  44. line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
  45. line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
  46. if score[j] > score_max:
  47. tmp = line[j]
  48. score_max = score[j]
  49. line_.append(tmp)
  50. score_.append(score_max)
  51. processed_list = torch.tensor(line_)
  52. pred[idx]['line'] = processed_list
  53. processed_s_list = torch.tensor(score_)
  54. pred[idx]['line_score'] = processed_s_list
  55. return pred
  56. # box内无线段时,选box内点组成线段最长的 两个点组成的线段返回
  57. def box_line1(imgs, pred): # 默认置信度
  58. im = imgs.permute(1, 2, 0).cpu().numpy()
  59. line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  60. line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
  61. points = pred[-1]['wires']['juncs'].cpu().numpy()[0] / 128 * 512
  62. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  63. line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)
  64. for idx, box_ in enumerate(pred[0:-1]):
  65. box = box_['boxes'] # 是一个tensor
  66. line_ = []
  67. score_ = []
  68. for i in box:
  69. score_max = 0.0
  70. tmp = [[0.0, 0.0], [0.0, 0.0]]
  71. for j in range(len(line)):
  72. if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
  73. line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
  74. line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
  75. line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
  76. if score[j] > score_max:
  77. tmp = line[j]
  78. score_max = score[j]
  79. # # 如果 box 内无线段,则通过点坐标找最长线段
  80. # if score_max == 0.0: # 说明 box 内无线段
  81. # box_points = [
  82. # [x, y] for x, y in points
  83. # if i[0] <= y <= i[2] and i[1] <= x <= i[3]
  84. # ]
  85. #
  86. # if len(box_points) >= 2: # 至少需要两个点才能组成线段
  87. # max_distance = 0.0
  88. # longest_segment = [[0.0, 0.0], [0.0, 0.0]]
  89. #
  90. # # 找出 box 内点组成的最长线段
  91. # for p1 in box_points:
  92. # for p2 in box_points:
  93. # if p1 != p2:
  94. # distance = np.linalg.norm(np.array(p1) - np.array(p2))
  95. # if distance > max_distance:
  96. # max_distance = distance
  97. # longest_segment = [p1, p2]
  98. #
  99. # tmp = longest_segment
  100. # score_max = 0.0 # 默认分数为 0.0
  101. line_.append(tmp)
  102. score_.append(score_max)
  103. processed_list = torch.tensor(line_)
  104. pred[idx]['line'] = processed_list
  105. processed_s_list = torch.tensor(score_)
  106. pred[idx]['line_score'] = processed_s_list
  107. return pred
  108. def show_box(imgs, pred, t_start):
  109. col = [
  110. '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
  111. '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
  112. '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
  113. '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
  114. '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
  115. '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
  116. '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
  117. '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
  118. '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
  119. '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
  120. '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
  121. '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
  122. '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
  123. '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
  124. '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
  125. '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
  126. '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
  127. '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
  128. '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
  129. '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
  130. ]
  131. # print(len(col))
  132. im = imgs.permute(1, 2, 0)
  133. boxes = pred[0]['boxes'].cpu().numpy()
  134. box_scores = pred[0]['scores'].cpu().numpy()
  135. # 可视化预测结
  136. fig, ax = plt.subplots(figsize=(10, 10))
  137. ax.imshow(np.array(im))
  138. for idx, box in enumerate(boxes):
  139. # if box_scores[idx] < 0.7:
  140. # continue
  141. x0, y0, x1, y1 = box
  142. ax.add_patch(
  143. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  144. t_end = time.time()
  145. print(f'show_box used:{t_end - t_start}')
  146. plt.show()
  147. def show_predict(imgs, pred, t_start):
  148. col = [
  149. '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
  150. '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
  151. '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
  152. '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
  153. '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
  154. '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
  155. '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
  156. '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
  157. '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
  158. '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
  159. '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
  160. '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
  161. '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
  162. '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
  163. '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
  164. '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
  165. '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
  166. '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
  167. '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
  168. '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
  169. ]
  170. im = imgs.permute(1, 2, 0) # 处理为 [512, 512, 3]
  171. boxes = pred[0]['boxes'].cpu().numpy()
  172. box_scores = pred[0]['scores'].cpu().numpy()
  173. lines = pred[0]['line'].cpu().numpy()
  174. line_scores = pred[0]['line_score'].cpu().numpy()
  175. # 可视化预测结
  176. fig, ax = plt.subplots(figsize=(10, 10))
  177. ax.imshow(np.array(im))
  178. idx = 0
  179. tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
  180. for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
  181. x0, y0, x1, y1 = box
  182. # 框中无线的跳过
  183. if np.array_equal(line, tmp):
  184. continue
  185. a, b = line
  186. if box_score >= 0 or line_score >= 0:
  187. ax.add_patch(
  188. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  189. ax.scatter(a[1], a[0], c='#871F78', s=10)
  190. ax.scatter(b[1], b[0], c='#871F78', s=10)
  191. ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
  192. idx = idx + 1
  193. t_end = time.time()
  194. print(f'predict used:{t_end - t_start}')
  195. plt.show()
  196. def show_line(imgs, pred, t_start):
  197. im = imgs.permute(1, 2, 0)
  198. lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  199. # print(pred[-1]['wires']['score'])
  200. scores = pred[-1]['wires']['score'].cpu().numpy()[0]
  201. t1 = time.time()
  202. print(f't1:{t1 - t_start}')
  203. for i in range(1, len(lines)):
  204. if (lines[i] == lines[0]).all():
  205. lines = lines[:i]
  206. scores = scores[:i]
  207. break
  208. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  209. line, line_score = postprocess(lines, scores, diag * 0.01, 0, False)
  210. print(f'lines num:{len(line)}')
  211. t2 = time.time()
  212. print(f't1:{t2 - t1}')
  213. # 可视化预测结
  214. fig, ax = plt.subplots(figsize=(10, 10))
  215. ax.imshow(np.array(im))
  216. for idx, (a, b) in enumerate(line):
  217. # if line_score[idx] < 0.7:
  218. # continue
  219. ax.scatter(a[1], a[0], c='#871F78', s=2)
  220. ax.scatter(b[1], b[0], c='#871F78', s=2)
  221. ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
  222. t_end = time.time()
  223. print(f'show_line used:{t_end - t_start}')
  224. plt.show()
  225. def predict(pt_path, model, img):
  226. model = load_best_model(model, pt_path, device)
  227. model.eval()
  228. if isinstance(img, str):
  229. img = Image.open(img).convert("RGB")
  230. transform = transforms.ToTensor()
  231. img_tensor = transform(img) # [3, 512, 512]
  232. # 将图像调整为512x512大小
  233. t_start = time.time()
  234. # im = img_tensor.permute(1, 2, 0) # [512, 512, 3]
  235. # im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512)) # (512, 512, 3)
  236. # img_ = torch.tensor(im_resized).permute(2, 0, 1)
  237. im = img_tensor.permute(1, 2, 0) # [H, W, 3]
  238. if im.shape != (512, 512, 3):
  239. # im = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_NEAREST)
  240. im = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_NEAREST_EXACT)
  241. img_ = torch.tensor(im).permute(2, 0, 1) # [3, 512, 512]
  242. t_end = time.time()
  243. print(f'switch img used:{t_end - t_start}')
  244. with torch.no_grad():
  245. predictions = model([img_.to(device)])
  246. print(predictions)
  247. t_end1 = time.time()
  248. print(f'model test used:{t_end1 - t_end}')
  249. # show_line_optimized(img_, predictions, t_start) # 只画线
  250. show_line(img_, predictions, t_end1)
  251. t_end2 = time.time()
  252. show_box(img_, predictions, t_end2) # 只画kuang
  253. # show_box_or_line(img_, predictions, show_line=True, show_box=True) # 参数确定画什么
  254. # show_box_and_line(img_, predictions, show_line=True, show_box=True) # 一起画 1x2 2张图
  255. # t_start = time.time()
  256. # pred = box_line1(img_, predictions)
  257. # t_end = time.time()
  258. # print(f'Matched boxes and lines used: {t_end - t_start:.4f} seconds')
  259. #
  260. # show_predict(img_, pred, t_start)
  261. if __name__ == '__main__':
  262. t_start = time.time()
  263. print(f'start to predict:{t_start}')
  264. # model = linenet_resnet50_fpn().to(device)
  265. # model = get_line_net_efficientnetv2(2, pretrained_backbone=True).to(device)
  266. model=get_line_net_convnext_fpn(num_classes=2).to(device)
  267. # pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练24轮结果.pth"
  268. # pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练75轮.pth"
  269. pt_path = r"\\192.168.50.222\share\rlq\weights\convnext25051401.pth"
  270. # pt_path = r"C:\Users\m2337\Downloads\best_e20.pth"
  271. img_path = r"\\192.168.50.222\share\zyh\513\a_dataset\images\val\2025-05-13-08-56-03_LaserData_ID019504_color_scale.jpg"
  272. predict(pt_path, model, img_path)
  273. t_end = time.time()
  274. print(f'predict used:{t_end - t_start}')