predict2.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488
  1. # import time
  2. #
  3. # from models.line_detect.postprocess import show_predict
  4. # import os
  5. #
  6. # import torch
  7. # from PIL import Image
  8. # import matplotlib.pyplot as plt
  9. # import matplotlib as mpl
  10. # import numpy as np
  11. # from models.line_detect.line_net import linenet_resnet50_fpn
  12. # from torchvision import transforms
  13. # from rtree import index
  14. # # from models.wirenet.postprocess import postprocess
  15. # from models.wirenet.postprocess import postprocess
  16. #
  17. # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  18. #
  19. #
  20. # def load_best_model(model, save_path, device):
  21. # if os.path.exists(save_path):
  22. # checkpoint = torch.load(save_path, map_location=device)
  23. # model.load_state_dict(checkpoint['model_state_dict'])
  24. # # if optimizer is not None:
  25. # # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  26. # epoch = checkpoint['epoch']
  27. # loss = checkpoint['loss']
  28. # print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  29. # else:
  30. # print(f"No saved model found at {save_path}")
  31. # return model
  32. #
  33. #
  34. # def box_line_optimized(pred):
  35. # # 创建R-tree索引
  36. # idx = index.Index()
  37. #
  38. # # 将所有线段添加到R-tree中
  39. # lines = pred[-1]['wires']['lines'] # 形状为[1, 2500, 2, 2]
  40. # scores = pred[-1]['wires']['score'][0] # 假设形状为[2500]
  41. #
  42. # # 提取并处理所有线段
  43. # for idx_line in range(lines.shape[1]): # 遍历2500条线段
  44. # line_tensor = lines[0, idx_line].cpu().numpy() / 128 * 512 # 转换为numpy数组并调整比例
  45. # x_min = float(min(line_tensor[0][0], line_tensor[1][0]))
  46. # y_min = float(min(line_tensor[0][1], line_tensor[1][1]))
  47. # x_max = float(max(line_tensor[0][0], line_tensor[1][0]))
  48. # y_max = float(max(line_tensor[0][1], line_tensor[1][1]))
  49. # idx.insert(idx_line, (x_min, y_min, x_max, y_max))
  50. #
  51. # for idx_box, box_ in enumerate(pred[0:-1]):
  52. # box = box_['boxes'].cpu().numpy() # 确保将张量转换为numpy数组
  53. # line_ = []
  54. # score_ = []
  55. #
  56. # for i in box:
  57. # score_max = 0.0
  58. # tmp = [[0.0, 0.0], [0.0, 0.0]]
  59. #
  60. # # 获取与当前box可能相交的所有线段
  61. # possible_matches = list(idx.intersection((i[0], i[1], i[2], i[3])))
  62. #
  63. # for j in possible_matches:
  64. # line_j = lines[0, j].cpu().numpy() / 128 * 512
  65. # if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and # 注意这里交换了x和y
  66. # line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
  67. # line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
  68. # line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
  69. #
  70. # if scores[j] > score_max:
  71. # tmp = line_j
  72. # score_max = scores[j]
  73. #
  74. # line_.append(tmp)
  75. # score_.append(score_max)
  76. #
  77. # processed_list = torch.tensor(line_)
  78. # pred[idx_box]['line'] = processed_list
  79. #
  80. # processed_s_list = torch.tensor(score_)
  81. # pred[idx_box]['line_score'] = processed_s_list
  82. #
  83. # return pred
  84. #
  85. # # def box_line_(pred):
  86. # # for idx, box_ in enumerate(pred[0:-1]):
  87. # # box = box_['boxes'] # 是一个tensor
  88. # # line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
  89. # # score = pred[-1]['wires']['score'][idx]
  90. # # line_ = []
  91. # # score_ = []
  92. # #
  93. # # for i in box:
  94. # # score_max = 0.0
  95. # # tmp = [[0.0, 0.0], [0.0, 0.0]]
  96. # #
  97. # # for j in range(len(line)):
  98. # # if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
  99. # # line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
  100. # # line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
  101. # # line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
  102. # #
  103. # # if score[j] > score_max:
  104. # # tmp = line[j]
  105. # # score_max = score[j]
  106. # # line_.append(tmp)
  107. # # score_.append(score_max)
  108. # # processed_list = torch.tensor(line_)
  109. # # pred[idx]['line'] = processed_list
  110. # #
  111. # # processed_s_list = torch.tensor(score_)
  112. # # pred[idx]['line_score'] = processed_s_list
  113. # # return pred
  114. #
  115. #
  116. # def predict(pt_path, model, img):
  117. # model = load_best_model(model, pt_path, device)
  118. #
  119. # model.eval()
  120. #
  121. # if isinstance(img, str):
  122. # img = Image.open(img).convert("RGB")
  123. #
  124. # transform = transforms.ToTensor()
  125. # img_tensor = transform(img)
  126. #
  127. # with torch.no_grad():
  128. # t_start = time.time()
  129. # predictions = model([img_tensor.to(device)])
  130. # t_end=time.time()
  131. # print(f'predict used:{t_end-t_start}')
  132. # # print(f'predictions:{predictions}')
  133. # boxes=predictions[0]['boxes'].shape
  134. # lines=predictions[-1]['wires']['lines'].shape
  135. # lines_scores=predictions[-1]['wires']['score'].shape
  136. # print(f'predictions boxes:{boxes},lines:{lines},lines_scores:{lines_scores}')
  137. # t_start=time.time()
  138. # pred = box_line_optimized(predictions)
  139. # t_end=time.time()
  140. # print(f'matched boxes and lines used:{t_end - t_start}')
  141. # # print(f'pred:{pred[0]}')
  142. # show_predict(img_tensor, pred, t_start)
  143. #
  144. #
  145. # if __name__ == '__main__':
  146. # t_start = time.time()
  147. # print(f'start to predict:{t_start}')
  148. # model = linenet_resnet50_fpn().to(device)
  149. # pt_path = r"F:\BaiduNetdiskDownload\resnet50_best_e8.pth"
  150. # img_path = r"I:\datasets\wirenet_1000\images\val\00035148_0.png"
  151. # predict(pt_path, model, img_path)
  152. # t_end = time.time()
  153. # # print(f'predict used:{t_end - t_start}')
  154. import time
  155. import cv2
  156. import skimage
  157. import os
  158. import torch
  159. from PIL import Image
  160. import matplotlib.pyplot as plt
  161. import matplotlib as mpl
  162. import numpy as np
  163. from models.line_detect.line_net import linenet_resnet50_fpn
  164. from torchvision import transforms
  165. # from models.wirenet.postprocess import postprocess
  166. from models.wirenet.postprocess import postprocess
  167. from rtree import index
  168. import imageio
  169. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  170. def load_best_model(model, save_path, device):
  171. if os.path.exists(save_path):
  172. checkpoint = torch.load(save_path, map_location=device)
  173. model.load_state_dict(checkpoint['model_state_dict'])
  174. # if optimizer is not None:
  175. # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  176. # epoch = checkpoint['epoch']
  177. # loss = checkpoint['loss']
  178. # print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  179. else:
  180. print(f"No saved model found at {save_path}")
  181. return model
  182. def box_line_(imgs, pred, length=False): # 默认置信度
  183. im = imgs.permute(1, 2, 0).cpu().numpy()
  184. line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  185. line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
  186. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  187. line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)
  188. for idx, box_ in enumerate(pred[0:-1]):
  189. box = box_['boxes'] # 是一个tensor
  190. line_ = []
  191. score_ = []
  192. for i in box:
  193. score_max = 0.0
  194. tmp = [[0.0, 0.0], [0.0, 0.0]]
  195. for j in range(len(line)):
  196. if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
  197. line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
  198. line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
  199. line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
  200. if score[j] > score_max:
  201. tmp = line[j]
  202. score_max = score[j]
  203. line_.append(tmp)
  204. score_.append(score_max)
  205. processed_list = torch.tensor(line_)
  206. pred[idx]['line'] = processed_list
  207. processed_s_list = torch.tensor(score_)
  208. pred[idx]['line_score'] = processed_s_list
  209. return pred
  210. # box内无线段时,选box内点组成线段最长的 两个点组成的线段返回
  211. def box_line1(imgs, pred): # 默认置信度
  212. im = imgs.permute(1, 2, 0).cpu().numpy()
  213. line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  214. line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
  215. points = pred[-1]['wires']['juncs'].cpu().numpy()[0] / 128 * 512
  216. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  217. line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)
  218. for idx, box_ in enumerate(pred[0:-1]):
  219. box = box_['boxes'] # 是一个tensor
  220. line_ = []
  221. score_ = []
  222. for i in box:
  223. score_max = 0.0
  224. tmp = [[0.0, 0.0], [0.0, 0.0]]
  225. for j in range(len(line)):
  226. if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
  227. line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
  228. line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
  229. line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
  230. if score[j] > score_max:
  231. tmp = line[j]
  232. score_max = score[j]
  233. # # 如果 box 内无线段,则通过点坐标找最长线段
  234. # if score_max == 0.0: # 说明 box 内无线段
  235. # box_points = [
  236. # [x, y] for x, y in points
  237. # if i[0] <= y <= i[2] and i[1] <= x <= i[3]
  238. # ]
  239. #
  240. # if len(box_points) >= 2: # 至少需要两个点才能组成线段
  241. # max_distance = 0.0
  242. # longest_segment = [[0.0, 0.0], [0.0, 0.0]]
  243. #
  244. # # 找出 box 内点组成的最长线段
  245. # for p1 in box_points:
  246. # for p2 in box_points:
  247. # if p1 != p2:
  248. # distance = np.linalg.norm(np.array(p1) - np.array(p2))
  249. # if distance > max_distance:
  250. # max_distance = distance
  251. # longest_segment = [p1, p2]
  252. #
  253. # tmp = longest_segment
  254. # score_max = 0.0 # 默认分数为 0.0
  255. line_.append(tmp)
  256. score_.append(score_max)
  257. processed_list = torch.tensor(line_)
  258. pred[idx]['line'] = processed_list
  259. processed_s_list = torch.tensor(score_)
  260. pred[idx]['line_score'] = processed_s_list
  261. return pred
  262. def show_box(imgs, pred, t_start):
  263. col = [
  264. '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
  265. '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
  266. '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
  267. '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
  268. '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
  269. '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
  270. '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
  271. '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
  272. '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
  273. '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
  274. '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
  275. '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
  276. '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
  277. '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
  278. '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
  279. '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
  280. '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
  281. '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
  282. '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
  283. '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
  284. ]
  285. # print(len(col))
  286. im = imgs.permute(1, 2, 0)
  287. boxes = pred[0]['boxes'].cpu().numpy()
  288. box_scores = pred[0]['scores'].cpu().numpy()
  289. # 可视化预测结
  290. fig, ax = plt.subplots(figsize=(10, 10))
  291. ax.imshow(np.array(im))
  292. for idx, box in enumerate(boxes):
  293. # if box_scores[idx] < 0.7:
  294. # continue
  295. x0, y0, x1, y1 = box
  296. ax.add_patch(
  297. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  298. t_end = time.time()
  299. print(f'show_box used:{t_end - t_start}')
  300. plt.show()
  301. def show_predict(imgs, pred, t_start):
  302. col = [
  303. '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
  304. '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
  305. '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
  306. '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
  307. '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
  308. '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
  309. '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
  310. '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
  311. '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
  312. '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
  313. '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
  314. '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
  315. '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
  316. '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
  317. '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
  318. '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
  319. '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
  320. '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
  321. '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
  322. '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
  323. ]
  324. im = imgs.permute(1, 2, 0) # 处理为 [512, 512, 3]
  325. boxes = pred[0]['boxes'].cpu().numpy()
  326. box_scores = pred[0]['scores'].cpu().numpy()
  327. lines = pred[0]['line'].cpu().numpy()
  328. line_scores = pred[0]['line_score'].cpu().numpy()
  329. # 可视化预测结
  330. fig, ax = plt.subplots(figsize=(10, 10))
  331. ax.imshow(np.array(im))
  332. idx = 0
  333. tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
  334. for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
  335. x0, y0, x1, y1 = box
  336. # 框中无线的跳过
  337. if np.array_equal(line, tmp):
  338. continue
  339. a, b = line
  340. if box_score >= 0 or line_score >= 0:
  341. ax.add_patch(
  342. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  343. ax.scatter(a[1], a[0], c='#871F78', s=10)
  344. ax.scatter(b[1], b[0], c='#871F78', s=10)
  345. ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
  346. idx = idx + 1
  347. t_end = time.time()
  348. print(f'predict used:{t_end - t_start}')
  349. plt.show()
  350. def show_line(imgs, pred, t_start):
  351. im = imgs.permute(1, 2, 0)
  352. lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  353. # print(pred[-1]['wires']['score'])
  354. scores = pred[-1]['wires']['score'].cpu().numpy()[0]
  355. t1 = time.time()
  356. print(f't1:{t1 - t_start}')
  357. for i in range(1, len(lines)):
  358. if (lines[i] == lines[0]).all():
  359. lines = lines[:i]
  360. scores = scores[:i]
  361. break
  362. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  363. line, line_score = postprocess(lines, scores, diag * 0.01, 0, False)
  364. print(f'lines num:{len(line)}')
  365. t2 = time.time()
  366. print(f't1:{t2 - t1}')
  367. # 可视化预测结
  368. fig, ax = plt.subplots(figsize=(10, 10))
  369. ax.imshow(np.array(im))
  370. for idx, (a, b) in enumerate(line):
  371. # if line_score[idx] < 0.7:
  372. # continue
  373. ax.scatter(a[1], a[0], c='#871F78', s=2)
  374. ax.scatter(b[1], b[0], c='#871F78', s=2)
  375. ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
  376. t_end = time.time()
  377. print(f'show_line used:{t_end - t_start}')
  378. plt.show()
  379. def predict(pt_path, model, img):
  380. model = load_best_model(model, pt_path, device)
  381. model.eval()
  382. # if isinstance(img, str):
  383. # img = Image.open(img).convert("RGB")
  384. print(imageio.v3.imread(img_path).shape)
  385. img = imageio.v3.imread(img_path).reshape(2114, 1332, 1)
  386. img_3channel = np.zeros((2114, 1332, 3), dtype=img.dtype)
  387. img_3channel[:, :, 2] = img[:, :, 0]
  388. img = torch.from_numpy(img_3channel).permute(2, 0, 1)
  389. img_tensor = img
  390. # transform = transforms.ToTensor()
  391. # img_tensor = transform(img) # [3, 512, 512]
  392. # 将图像调整为512x512大小
  393. t_start = time.time()
  394. # im = img_tensor.permute(1, 2, 0) # [512, 512, 3]
  395. # im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512)) # (512, 512, 3)
  396. # img_ = torch.tensor(im_resized).permute(2, 0, 1)
  397. im = img_tensor.permute(1, 2, 0) # [H, W, 3]
  398. if im.shape != (512, 512, 3):
  399. im = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_NEAREST)
  400. img_ = torch.tensor(im).permute(2, 0, 1) # [3, 512, 512]
  401. t_end = time.time()
  402. print(f'switch img used:{t_end - t_start}')
  403. with torch.no_grad():
  404. predictions = model([img_.to(device)])
  405. print(predictions)
  406. t_end1 = time.time()
  407. print(f'model test used:{t_end1 - t_end}')
  408. # show_line_optimized(img_, predictions, t_start) # 只画线
  409. show_line(img_, predictions, t_end1)
  410. t_end2 = time.time()
  411. show_box(img_, predictions, t_end2) # 只画kuang
  412. # show_box_or_line(img_, predictions, show_line=True, show_box=True) # 参数确定画什么
  413. # show_box_and_line(img_, predictions, show_line=True, show_box=True) # 一起画 1x2 2张图
  414. # t_start = time.time()
  415. # pred = box_line1(img_, predictions)
  416. # t_end = time.time()
  417. # print(f'Matched boxes and lines used: {t_end - t_start:.4f} seconds')
  418. #
  419. # show_predict(img_, pred, t_start)
  420. from models.line_detect.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn, get_line_net_efficientnetv2
  421. if __name__ == '__main__':
  422. t_start = time.time()
  423. print(f'start to predict:{t_start}')
  424. # model = linenet_resnet50_fpn().to(device)
  425. model = get_line_net_efficientnetv2(2, pretrained_backbone=True).to(device)
  426. # pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练24轮结果.pth"
  427. # pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练75轮.pth"
  428. pt_path = r"\\192.168.50.222\share\lm\weight\20250510_155941\weights\best.pth"
  429. # pt_path = r"C:\Users\m2337\Downloads\best_e20.pth"
  430. img_path = r"D:\python\PycharmProjects\20250214\cloud\新建文件夹\depth_map.tiff"
  431. predict(pt_path, model, img_path)
  432. t_end = time.time()
  433. print(f'predict used:{t_end - t_start}')