predict2.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. # import time
  2. #
  3. # from models.line_detect.postprocess import show_predict
  4. # import os
  5. #
  6. # import torch
  7. # from PIL import Image
  8. # import matplotlib.pyplot as plt
  9. # import matplotlib as mpl
  10. # import numpy as np
  11. # from models.line_detect.line_net import linenet_resnet50_fpn
  12. # from torchvision import transforms
  13. # from rtree import index
  14. # # from models.wirenet.postprocess import postprocess
  15. # from models.wirenet.postprocess import postprocess
  16. #
  17. # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  18. #
  19. #
  20. # def load_best_model(model, save_path, device):
  21. # if os.path.exists(save_path):
  22. # checkpoint = torch.load(save_path, map_location=device)
  23. # model.load_state_dict(checkpoint['model_state_dict'])
  24. # # if optimizer is not None:
  25. # # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  26. # epoch = checkpoint['epoch']
  27. # loss = checkpoint['loss']
  28. # print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  29. # else:
  30. # print(f"No saved model found at {save_path}")
  31. # return model
  32. #
  33. #
  34. # def box_line_optimized(pred):
  35. # # 创建R-tree索引
  36. # idx = index.Index()
  37. #
  38. # # 将所有线段添加到R-tree中
  39. # lines = pred[-1]['wires']['lines'] # 形状为[1, 2500, 2, 2]
  40. # scores = pred[-1]['wires']['score'][0] # 假设形状为[2500]
  41. #
  42. # # 提取并处理所有线段
  43. # for idx_line in range(lines.shape[1]): # 遍历2500条线段
  44. # line_tensor = lines[0, idx_line].cpu().numpy() / 128 * 512 # 转换为numpy数组并调整比例
  45. # x_min = float(min(line_tensor[0][0], line_tensor[1][0]))
  46. # y_min = float(min(line_tensor[0][1], line_tensor[1][1]))
  47. # x_max = float(max(line_tensor[0][0], line_tensor[1][0]))
  48. # y_max = float(max(line_tensor[0][1], line_tensor[1][1]))
  49. # idx.insert(idx_line, (x_min, y_min, x_max, y_max))
  50. #
  51. # for idx_box, box_ in enumerate(pred[0:-1]):
  52. # box = box_['boxes'].cpu().numpy() # 确保将张量转换为numpy数组
  53. # line_ = []
  54. # score_ = []
  55. #
  56. # for i in box:
  57. # score_max = 0.0
  58. # tmp = [[0.0, 0.0], [0.0, 0.0]]
  59. #
  60. # # 获取与当前box可能相交的所有线段
  61. # possible_matches = list(idx.intersection((i[0], i[1], i[2], i[3])))
  62. #
  63. # for j in possible_matches:
  64. # line_j = lines[0, j].cpu().numpy() / 128 * 512
  65. # if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and # 注意这里交换了x和y
  66. # line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
  67. # line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
  68. # line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
  69. #
  70. # if scores[j] > score_max:
  71. # tmp = line_j
  72. # score_max = scores[j]
  73. #
  74. # line_.append(tmp)
  75. # score_.append(score_max)
  76. #
  77. # processed_list = torch.tensor(line_)
  78. # pred[idx_box]['line'] = processed_list
  79. #
  80. # processed_s_list = torch.tensor(score_)
  81. # pred[idx_box]['line_score'] = processed_s_list
  82. #
  83. # return pred
  84. #
  85. # # def box_line_(pred):
  86. # # for idx, box_ in enumerate(pred[0:-1]):
  87. # # box = box_['boxes'] # 是一个tensor
  88. # # line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
  89. # # score = pred[-1]['wires']['score'][idx]
  90. # # line_ = []
  91. # # score_ = []
  92. # #
  93. # # for i in box:
  94. # # score_max = 0.0
  95. # # tmp = [[0.0, 0.0], [0.0, 0.0]]
  96. # #
  97. # # for j in range(len(line)):
  98. # # if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
  99. # # line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
  100. # # line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
  101. # # line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
  102. # #
  103. # # if score[j] > score_max:
  104. # # tmp = line[j]
  105. # # score_max = score[j]
  106. # # line_.append(tmp)
  107. # # score_.append(score_max)
  108. # # processed_list = torch.tensor(line_)
  109. # # pred[idx]['line'] = processed_list
  110. # #
  111. # # processed_s_list = torch.tensor(score_)
  112. # # pred[idx]['line_score'] = processed_s_list
  113. # # return pred
  114. #
  115. #
  116. # def predict(pt_path, model, img):
  117. # model = load_best_model(model, pt_path, device)
  118. #
  119. # model.eval()
  120. #
  121. # if isinstance(img, str):
  122. # img = Image.open(img).convert("RGB")
  123. #
  124. # transform = transforms.ToTensor()
  125. # img_tensor = transform(img)
  126. #
  127. # with torch.no_grad():
  128. # t_start = time.time()
  129. # predictions = model([img_tensor.to(device)])
  130. # t_end=time.time()
  131. # print(f'predict used:{t_end-t_start}')
  132. # # print(f'predictions:{predictions}')
  133. # boxes=predictions[0]['boxes'].shape
  134. # lines=predictions[-1]['wires']['lines'].shape
  135. # lines_scores=predictions[-1]['wires']['score'].shape
  136. # print(f'predictions boxes:{boxes},lines:{lines},lines_scores:{lines_scores}')
  137. # t_start=time.time()
  138. # pred = box_line_optimized(predictions)
  139. # t_end=time.time()
  140. # print(f'matched boxes and lines used:{t_end - t_start}')
  141. # # print(f'pred:{pred[0]}')
  142. # show_predict(img_tensor, pred, t_start)
  143. #
  144. #
  145. # if __name__ == '__main__':
  146. # t_start = time.time()
  147. # print(f'start to predict:{t_start}')
  148. # model = linenet_resnet50_fpn().to(device)
  149. # pt_path = r"F:\BaiduNetdiskDownload\resnet50_best_e8.pth"
  150. # img_path = r"I:\datasets\wirenet_1000\images\val\00035148_0.png"
  151. # predict(pt_path, model, img_path)
  152. # t_end = time.time()
  153. # # print(f'predict used:{t_end - t_start}')
  154. import time
  155. import skimage
  156. import os
  157. import torch
  158. from PIL import Image
  159. import matplotlib.pyplot as plt
  160. import matplotlib as mpl
  161. import numpy as np
  162. from models.line_detect.line_net import linenet_resnet50_fpn
  163. from torchvision import transforms
  164. # from models.wirenet.postprocess import postprocess
  165. from models.wirenet.postprocess import postprocess
  166. from rtree import index
  167. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  168. def load_best_model(model, save_path, device):
  169. if os.path.exists(save_path):
  170. checkpoint = torch.load(save_path, map_location=device)
  171. model.load_state_dict(checkpoint['model_state_dict'])
  172. # if optimizer is not None:
  173. # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  174. # epoch = checkpoint['epoch']
  175. # loss = checkpoint['loss']
  176. # print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  177. else:
  178. print(f"No saved model found at {save_path}")
  179. return model
  180. def box_line_(imgs, pred, length=False): # 默认置信度
  181. im = imgs.permute(1, 2, 0).cpu().numpy()
  182. line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  183. line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
  184. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  185. line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)
  186. for idx, box_ in enumerate(pred[0:-1]):
  187. box = box_['boxes'] # 是一个tensor
  188. line_ = []
  189. score_ = []
  190. for i in box:
  191. score_max = 0.0
  192. tmp = [[0.0, 0.0], [0.0, 0.0]]
  193. for j in range(len(line)):
  194. if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
  195. line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
  196. line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
  197. line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
  198. if score[j] > score_max:
  199. tmp = line[j]
  200. score_max = score[j]
  201. line_.append(tmp)
  202. score_.append(score_max)
  203. processed_list = torch.tensor(line_)
  204. pred[idx]['line'] = processed_list
  205. processed_s_list = torch.tensor(score_)
  206. pred[idx]['line_score'] = processed_s_list
  207. return pred
  208. # box内无线段时,选box内点组成线段最长的 两个点组成的线段返回
  209. def box_line1(imgs, pred, length=False): # 默认置信度
  210. im = imgs.permute(1, 2, 0).cpu().numpy()
  211. line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  212. line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
  213. points=pred[-1]['wires']['juncs'].cpu().numpy()[0]/ 128 * 512
  214. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  215. line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)
  216. for idx, box_ in enumerate(pred[0:-1]):
  217. box = box_['boxes'] # 是一个tensor
  218. line_ = []
  219. score_ = []
  220. for i in box:
  221. score_max = 0.0
  222. tmp = [[0.0, 0.0], [0.0, 0.0]]
  223. for j in range(len(line)):
  224. if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
  225. line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
  226. line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
  227. line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
  228. if score[j] > score_max:
  229. tmp = line[j]
  230. score_max = score[j]
  231. # 如果 box 内无线段,则通过点坐标找最长线段
  232. if score_max == 0.0: # 说明 box 内无线段
  233. box_points = [
  234. [x, y] for x, y in points
  235. if i[0] <= y <= i[2] and i[1] <= x <= i[3]
  236. ]
  237. if len(box_points) >= 2: # 至少需要两个点才能组成线段
  238. max_distance = 0.0
  239. longest_segment = [[0.0, 0.0], [0.0, 0.0]]
  240. # 找出 box 内点组成的最长线段
  241. for p1 in box_points:
  242. for p2 in box_points:
  243. if p1 != p2:
  244. distance = np.linalg.norm(np.array(p1) - np.array(p2))
  245. if distance > max_distance:
  246. max_distance = distance
  247. longest_segment = [p1, p2]
  248. tmp = longest_segment
  249. score_max = 0.0 # 默认分数为 0.0
  250. line_.append(tmp)
  251. score_.append(score_max)
  252. processed_list = torch.tensor(line_)
  253. pred[idx]['line'] = processed_list
  254. processed_s_list = torch.tensor(score_)
  255. pred[idx]['line_score'] = processed_s_list
  256. return pred
  257. def show_box(imgs, pred, t_start):
  258. col = [
  259. '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
  260. '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
  261. '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
  262. '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
  263. '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
  264. '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
  265. '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
  266. '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
  267. '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
  268. '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
  269. '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
  270. '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
  271. '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
  272. '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
  273. '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
  274. '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
  275. '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
  276. '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
  277. '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
  278. '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
  279. ]
  280. # print(len(col))
  281. im = imgs.permute(1, 2, 0)
  282. boxes = pred[0]['boxes'].cpu().numpy()
  283. box_scores = pred[0]['scores'].cpu().numpy()
  284. # 可视化预测结
  285. fig, ax = plt.subplots(figsize=(10, 10))
  286. ax.imshow(np.array(im))
  287. for idx, box in enumerate(boxes):
  288. # if box_scores[idx] < 0.7:
  289. # continue
  290. x0, y0, x1, y1 = box
  291. ax.add_patch(
  292. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  293. t_end = time.time()
  294. print(f'show_box used:{t_end - t_start}')
  295. plt.show()
  296. def show_predict(imgs, pred, t_start):
  297. col = [
  298. '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
  299. '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
  300. '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
  301. '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
  302. '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
  303. '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
  304. '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
  305. '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
  306. '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
  307. '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
  308. '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
  309. '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
  310. '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
  311. '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
  312. '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
  313. '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
  314. '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
  315. '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
  316. '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
  317. '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
  318. ]
  319. im = imgs.permute(1, 2, 0) # 处理为 [512, 512, 3]
  320. boxes = pred[0]['boxes'].cpu().numpy()
  321. box_scores = pred[0]['scores'].cpu().numpy()
  322. lines = pred[0]['line'].cpu().numpy()
  323. line_scores = pred[0]['line_score'].cpu().numpy()
  324. # 可视化预测结
  325. fig, ax = plt.subplots(figsize=(10, 10))
  326. ax.imshow(np.array(im))
  327. idx = 0
  328. tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
  329. for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
  330. x0, y0, x1, y1 = box
  331. # 框中无线的跳过
  332. if np.array_equal(line, tmp):
  333. continue
  334. a, b = line
  335. if box_score >= 0 or line_score >= 0:
  336. ax.add_patch(
  337. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  338. ax.scatter(a[1], a[0], c='#871F78', s=10)
  339. ax.scatter(b[1], b[0], c='#871F78', s=10)
  340. ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
  341. idx = idx + 1
  342. t_end = time.time()
  343. print(f'predict used:{t_end - t_start}')
  344. plt.show()
  345. def show_line(imgs, pred, t_start):
  346. im = imgs.permute(1, 2, 0)
  347. line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  348. # print(pred[-1]['wires']['score'])
  349. line_score = pred[-1]['wires']['score'].cpu().numpy()[0]
  350. t1 = time.time()
  351. print(f't1:{t1 - t_start}')
  352. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  353. line, line_score = postprocess(line, line_score, diag * 0.01, 0, False)
  354. print(f'lines num:{len(line)}')
  355. t2 = time.time()
  356. print(f't1:{t2 - t1}')
  357. # 可视化预测结
  358. fig, ax = plt.subplots(figsize=(10, 10))
  359. ax.imshow(np.array(im))
  360. for idx, (a, b) in enumerate(line):
  361. # if line_score[idx] < 0.7:
  362. # continue
  363. ax.scatter(a[1], a[0], c='#871F78', s=2)
  364. ax.scatter(b[1], b[0], c='#871F78', s=2)
  365. ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
  366. t_end = time.time()
  367. print(f'show_line used:{t_end - t_start}')
  368. plt.show()
  369. def predict(pt_path, model, img):
  370. model = load_best_model(model, pt_path, device)
  371. model.eval()
  372. if isinstance(img, str):
  373. img = Image.open(img).convert("RGB")
  374. transform = transforms.ToTensor()
  375. img_tensor = transform(img) # [3, 512, 512]
  376. # 将图像调整为512x512大小
  377. t_start = time.time()
  378. im = img_tensor.permute(1, 2, 0) # [512, 512, 3]
  379. im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512)) # (512, 512, 3)
  380. img_ = torch.tensor(im_resized).permute(2, 0, 1)
  381. t_end = time.time()
  382. print(f'switch img used:{t_end - t_start}')
  383. with torch.no_grad():
  384. predictions = model([img_.to(device)])
  385. print(predictions)
  386. t_end1 = time.time()
  387. print(f'model test used:{t_end1 - t_end}')
  388. # show_line_optimized(img_, predictions, t_start) # 只画线
  389. show_line(img_, predictions, t_end1)
  390. t_end2 = time.time()
  391. show_box(img_, predictions, t_end2) # 只画kuang
  392. # show_box_or_line(img_, predictions, show_line=True, show_box=True) # 参数确定画什么
  393. # show_box_and_line(img_, predictions, show_line=True, show_box=True) # 一起画 1x2 2张图
  394. t_start = time.time()
  395. pred = box_line_(img_, predictions)
  396. t_end = time.time()
  397. print(f'Matched boxes and lines used: {t_end - t_start:.4f} seconds')
  398. show_predict(img_, pred, t_start)
  399. if __name__ == '__main__':
  400. t_start = time.time()
  401. print(f'start to predict:{t_start}')
  402. model = linenet_resnet50_fpn().to(device)
  403. # pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练24轮结果.pth"
  404. # pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练75轮.pth"
  405. pt_path = r"C:\Users\m2337\Downloads\best_e150.pth"
  406. # pt_path = r"C:\Users\m2337\Downloads\best_e20.pth"
  407. img_path = r"C:\Users\m2337\Desktop\p\新建文件夹\2025-03-25-16-10-00_SaveLeftImage.png"
  408. predict(pt_path, model, img_path)
  409. t_end = time.time()
  410. print(f'predict used:{t_end - t_start}')