predict2.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478
  1. # import time
  2. #
  3. # from models.line_detect.postprocess import show_predict
  4. # import os
  5. #
  6. # import torch
  7. # from PIL import Image
  8. # import matplotlib.pyplot as plt
  9. # import matplotlib as mpl
  10. # import numpy as np
  11. # from models.line_detect.line_net import linenet_resnet50_fpn
  12. # from torchvision import transforms
  13. # from rtree import index
  14. # # from models.wirenet.postprocess import postprocess
  15. # from models.wirenet.postprocess import postprocess
  16. #
  17. # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  18. #
  19. #
  20. # def load_best_model(model, save_path, device):
  21. # if os.path.exists(save_path):
  22. # checkpoint = torch.load(save_path, map_location=device)
  23. # model.load_state_dict(checkpoint['model_state_dict'])
  24. # # if optimizer is not None:
  25. # # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  26. # epoch = checkpoint['epoch']
  27. # loss = checkpoint['loss']
  28. # print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  29. # else:
  30. # print(f"No saved model found at {save_path}")
  31. # return model
  32. #
  33. #
  34. # def box_line_optimized(pred):
  35. # # 创建R-tree索引
  36. # idx = index.Index()
  37. #
  38. # # 将所有线段添加到R-tree中
  39. # lines = pred[-1]['wires']['lines'] # 形状为[1, 2500, 2, 2]
  40. # scores = pred[-1]['wires']['score'][0] # 假设形状为[2500]
  41. #
  42. # # 提取并处理所有线段
  43. # for idx_line in range(lines.shape[1]): # 遍历2500条线段
  44. # line_tensor = lines[0, idx_line].cpu().numpy() / 128 * 512 # 转换为numpy数组并调整比例
  45. # x_min = float(min(line_tensor[0][0], line_tensor[1][0]))
  46. # y_min = float(min(line_tensor[0][1], line_tensor[1][1]))
  47. # x_max = float(max(line_tensor[0][0], line_tensor[1][0]))
  48. # y_max = float(max(line_tensor[0][1], line_tensor[1][1]))
  49. # idx.insert(idx_line, (x_min, y_min, x_max, y_max))
  50. #
  51. # for idx_box, box_ in enumerate(pred[0:-1]):
  52. # box = box_['boxes'].cpu().numpy() # 确保将张量转换为numpy数组
  53. # line_ = []
  54. # score_ = []
  55. #
  56. # for i in box:
  57. # score_max = 0.0
  58. # tmp = [[0.0, 0.0], [0.0, 0.0]]
  59. #
  60. # # 获取与当前box可能相交的所有线段
  61. # possible_matches = list(idx.intersection((i[0], i[1], i[2], i[3])))
  62. #
  63. # for j in possible_matches:
  64. # line_j = lines[0, j].cpu().numpy() / 128 * 512
  65. # if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and # 注意这里交换了x和y
  66. # line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
  67. # line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
  68. # line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
  69. #
  70. # if scores[j] > score_max:
  71. # tmp = line_j
  72. # score_max = scores[j]
  73. #
  74. # line_.append(tmp)
  75. # score_.append(score_max)
  76. #
  77. # processed_list = torch.tensor(line_)
  78. # pred[idx_box]['line'] = processed_list
  79. #
  80. # processed_s_list = torch.tensor(score_)
  81. # pred[idx_box]['line_score'] = processed_s_list
  82. #
  83. # return pred
  84. #
  85. # # def box_line_(pred):
  86. # # for idx, box_ in enumerate(pred[0:-1]):
  87. # # box = box_['boxes'] # 是一个tensor
  88. # # line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
  89. # # score = pred[-1]['wires']['score'][idx]
  90. # # line_ = []
  91. # # score_ = []
  92. # #
  93. # # for i in box:
  94. # # score_max = 0.0
  95. # # tmp = [[0.0, 0.0], [0.0, 0.0]]
  96. # #
  97. # # for j in range(len(line)):
  98. # # if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
  99. # # line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
  100. # # line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
  101. # # line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
  102. # #
  103. # # if score[j] > score_max:
  104. # # tmp = line[j]
  105. # # score_max = score[j]
  106. # # line_.append(tmp)
  107. # # score_.append(score_max)
  108. # # processed_list = torch.tensor(line_)
  109. # # pred[idx]['line'] = processed_list
  110. # #
  111. # # processed_s_list = torch.tensor(score_)
  112. # # pred[idx]['line_score'] = processed_s_list
  113. # # return pred
  114. #
  115. #
  116. # def predict(pt_path, model, img):
  117. # model = load_best_model(model, pt_path, device)
  118. #
  119. # model.eval()
  120. #
  121. # if isinstance(img, str):
  122. # img = Image.open(img).convert("RGB")
  123. #
  124. # transform = transforms.ToTensor()
  125. # img_tensor = transform(img)
  126. #
  127. # with torch.no_grad():
  128. # t_start = time.time()
  129. # predictions = model([img_tensor.to(device)])
  130. # t_end=time.time()
  131. # print(f'predict used:{t_end-t_start}')
  132. # # print(f'predictions:{predictions}')
  133. # boxes=predictions[0]['boxes'].shape
  134. # lines=predictions[-1]['wires']['lines'].shape
  135. # lines_scores=predictions[-1]['wires']['score'].shape
  136. # print(f'predictions boxes:{boxes},lines:{lines},lines_scores:{lines_scores}')
  137. # t_start=time.time()
  138. # pred = box_line_optimized(predictions)
  139. # t_end=time.time()
  140. # print(f'matched boxes and lines used:{t_end - t_start}')
  141. # # print(f'pred:{pred[0]}')
  142. # show_predict(img_tensor, pred, t_start)
  143. #
  144. #
  145. # if __name__ == '__main__':
  146. # t_start = time.time()
  147. # print(f'start to predict:{t_start}')
  148. # model = linenet_resnet50_fpn().to(device)
  149. # pt_path = r"F:\BaiduNetdiskDownload\resnet50_best_e8.pth"
  150. # img_path = r"I:\datasets\wirenet_1000\images\val\00035148_0.png"
  151. # predict(pt_path, model, img_path)
  152. # t_end = time.time()
  153. # # print(f'predict used:{t_end - t_start}')
  154. import time
  155. import cv2
  156. import skimage
  157. import os
  158. import torch
  159. from PIL import Image
  160. import matplotlib.pyplot as plt
  161. import matplotlib as mpl
  162. import numpy as np
  163. from models.line_detect.line_net import linenet_resnet50_fpn
  164. from torchvision import transforms
  165. # from models.wirenet.postprocess import postprocess
  166. from models.wirenet.postprocess import postprocess
  167. from rtree import index
  168. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  169. def load_best_model(model, save_path, device):
  170. if os.path.exists(save_path):
  171. checkpoint = torch.load(save_path, map_location=device)
  172. model.load_state_dict(checkpoint['model_state_dict'])
  173. # if optimizer is not None:
  174. # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  175. # epoch = checkpoint['epoch']
  176. # loss = checkpoint['loss']
  177. # print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  178. else:
  179. print(f"No saved model found at {save_path}")
  180. return model
  181. def box_line_(imgs, pred, length=False): # 默认置信度
  182. im = imgs.permute(1, 2, 0).cpu().numpy()
  183. line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  184. line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
  185. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  186. line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)
  187. for idx, box_ in enumerate(pred[0:-1]):
  188. box = box_['boxes'] # 是一个tensor
  189. line_ = []
  190. score_ = []
  191. for i in box:
  192. score_max = 0.0
  193. tmp = [[0.0, 0.0], [0.0, 0.0]]
  194. for j in range(len(line)):
  195. if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
  196. line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
  197. line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
  198. line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
  199. if score[j] > score_max:
  200. tmp = line[j]
  201. score_max = score[j]
  202. line_.append(tmp)
  203. score_.append(score_max)
  204. processed_list = torch.tensor(line_)
  205. pred[idx]['line'] = processed_list
  206. processed_s_list = torch.tensor(score_)
  207. pred[idx]['line_score'] = processed_s_list
  208. return pred
  209. # box内无线段时,选box内点组成线段最长的 两个点组成的线段返回
  210. def box_line1(imgs, pred): # 默认置信度
  211. im = imgs.permute(1, 2, 0).cpu().numpy()
  212. line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  213. line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
  214. points = pred[-1]['wires']['juncs'].cpu().numpy()[0] / 128 * 512
  215. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  216. line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)
  217. for idx, box_ in enumerate(pred[0:-1]):
  218. box = box_['boxes'] # 是一个tensor
  219. line_ = []
  220. score_ = []
  221. for i in box:
  222. score_max = 0.0
  223. tmp = [[0.0, 0.0], [0.0, 0.0]]
  224. for j in range(len(line)):
  225. if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
  226. line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
  227. line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
  228. line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
  229. if score[j] > score_max:
  230. tmp = line[j]
  231. score_max = score[j]
  232. # # 如果 box 内无线段,则通过点坐标找最长线段
  233. # if score_max == 0.0: # 说明 box 内无线段
  234. # box_points = [
  235. # [x, y] for x, y in points
  236. # if i[0] <= y <= i[2] and i[1] <= x <= i[3]
  237. # ]
  238. #
  239. # if len(box_points) >= 2: # 至少需要两个点才能组成线段
  240. # max_distance = 0.0
  241. # longest_segment = [[0.0, 0.0], [0.0, 0.0]]
  242. #
  243. # # 找出 box 内点组成的最长线段
  244. # for p1 in box_points:
  245. # for p2 in box_points:
  246. # if p1 != p2:
  247. # distance = np.linalg.norm(np.array(p1) - np.array(p2))
  248. # if distance > max_distance:
  249. # max_distance = distance
  250. # longest_segment = [p1, p2]
  251. #
  252. # tmp = longest_segment
  253. # score_max = 0.0 # 默认分数为 0.0
  254. line_.append(tmp)
  255. score_.append(score_max)
  256. processed_list = torch.tensor(line_)
  257. pred[idx]['line'] = processed_list
  258. processed_s_list = torch.tensor(score_)
  259. pred[idx]['line_score'] = processed_s_list
  260. return pred
  261. def show_box(imgs, pred, t_start):
  262. col = [
  263. '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
  264. '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
  265. '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
  266. '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
  267. '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
  268. '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
  269. '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
  270. '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
  271. '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
  272. '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
  273. '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
  274. '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
  275. '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
  276. '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
  277. '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
  278. '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
  279. '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
  280. '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
  281. '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
  282. '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
  283. ]
  284. # print(len(col))
  285. im = imgs.permute(1, 2, 0)
  286. boxes = pred[0]['boxes'].cpu().numpy()
  287. box_scores = pred[0]['scores'].cpu().numpy()
  288. # 可视化预测结
  289. fig, ax = plt.subplots(figsize=(10, 10))
  290. ax.imshow(np.array(im))
  291. for idx, box in enumerate(boxes):
  292. # if box_scores[idx] < 0.7:
  293. # continue
  294. x0, y0, x1, y1 = box
  295. ax.add_patch(
  296. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  297. t_end = time.time()
  298. print(f'show_box used:{t_end - t_start}')
  299. plt.show()
  300. def show_predict(imgs, pred, t_start):
  301. col = [
  302. '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
  303. '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
  304. '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
  305. '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
  306. '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
  307. '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
  308. '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
  309. '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
  310. '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
  311. '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
  312. '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
  313. '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
  314. '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
  315. '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
  316. '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
  317. '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
  318. '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
  319. '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
  320. '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
  321. '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
  322. ]
  323. im = imgs.permute(1, 2, 0) # 处理为 [512, 512, 3]
  324. boxes = pred[0]['boxes'].cpu().numpy()
  325. box_scores = pred[0]['scores'].cpu().numpy()
  326. lines = pred[0]['line'].cpu().numpy()
  327. line_scores = pred[0]['line_score'].cpu().numpy()
  328. # 可视化预测结
  329. fig, ax = plt.subplots(figsize=(10, 10))
  330. ax.imshow(np.array(im))
  331. idx = 0
  332. tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
  333. for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
  334. x0, y0, x1, y1 = box
  335. # 框中无线的跳过
  336. if np.array_equal(line, tmp):
  337. continue
  338. a, b = line
  339. if box_score >= 0 or line_score >= 0:
  340. ax.add_patch(
  341. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  342. ax.scatter(a[1], a[0], c='#871F78', s=10)
  343. ax.scatter(b[1], b[0], c='#871F78', s=10)
  344. ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
  345. idx = idx + 1
  346. t_end = time.time()
  347. print(f'predict used:{t_end - t_start}')
  348. plt.show()
  349. def show_line(imgs, pred, t_start):
  350. im = imgs.permute(1, 2, 0)
  351. lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  352. # print(pred[-1]['wires']['score'])
  353. scores = pred[-1]['wires']['score'].cpu().numpy()[0]
  354. t1 = time.time()
  355. print(f't1:{t1 - t_start}')
  356. for i in range(1, len(lines)):
  357. if (lines[i] == lines[0]).all():
  358. lines = lines[:i]
  359. scores = scores[:i]
  360. break
  361. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  362. line, line_score = postprocess(lines, scores, diag * 0.01, 0, False)
  363. print(f'lines num:{len(line)}')
  364. t2 = time.time()
  365. print(f't1:{t2 - t1}')
  366. # 可视化预测结
  367. fig, ax = plt.subplots(figsize=(10, 10))
  368. ax.imshow(np.array(im))
  369. for idx, (a, b) in enumerate(line):
  370. # if line_score[idx] < 0.7:
  371. # continue
  372. ax.scatter(a[1], a[0], c='#871F78', s=2)
  373. ax.scatter(b[1], b[0], c='#871F78', s=2)
  374. ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
  375. t_end = time.time()
  376. print(f'show_line used:{t_end - t_start}')
  377. plt.show()
  378. def predict(pt_path, model, img):
  379. model = load_best_model(model, pt_path, device)
  380. model.eval()
  381. if isinstance(img, str):
  382. img = Image.open(img).convert("RGB")
  383. transform = transforms.ToTensor()
  384. img_tensor = transform(img) # [3, 512, 512]
  385. # 将图像调整为512x512大小
  386. t_start = time.time()
  387. # im = img_tensor.permute(1, 2, 0) # [512, 512, 3]
  388. # im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512)) # (512, 512, 3)
  389. # img_ = torch.tensor(im_resized).permute(2, 0, 1)
  390. im = img_tensor.permute(1, 2, 0) # [H, W, 3]
  391. if im.shape != (512, 512, 3):
  392. im = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_LINEAR)
  393. img_ = torch.tensor(im).permute(2, 0, 1) # [3, 512, 512]
  394. t_end = time.time()
  395. print(f'switch img used:{t_end - t_start}')
  396. with torch.no_grad():
  397. predictions = model([img_.to(device)])
  398. print(predictions)
  399. t_end1 = time.time()
  400. print(f'model test used:{t_end1 - t_end}')
  401. # show_line_optimized(img_, predictions, t_start) # 只画线
  402. show_line(img_, predictions, t_end1)
  403. t_end2 = time.time()
  404. show_box(img_, predictions, t_end2) # 只画kuang
  405. # show_box_or_line(img_, predictions, show_line=True, show_box=True) # 参数确定画什么
  406. # show_box_and_line(img_, predictions, show_line=True, show_box=True) # 一起画 1x2 2张图
  407. # t_start = time.time()
  408. # pred = box_line1(img_, predictions)
  409. # t_end = time.time()
  410. # print(f'Matched boxes and lines used: {t_end - t_start:.4f} seconds')
  411. #
  412. # show_predict(img_, pred, t_start)
  413. if __name__ == '__main__':
  414. t_start = time.time()
  415. print(f'start to predict:{t_start}')
  416. model = linenet_resnet50_fpn().to(device)
  417. # pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练24轮结果.pth"
  418. # pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练75轮.pth"
  419. pt_path = r"\\192.168.50.222\share\lm\weight\20250425_112601\weights\best.pth"
  420. # pt_path = r"C:\Users\m2337\Downloads\best_e20.pth"
  421. img_path = r"C:\Users\m2337\Desktop\p\140502.png"
  422. predict(pt_path, model, img_path)
  423. t_end = time.time()
  424. print(f'predict used:{t_end - t_start}')