predict2.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. import time
  2. import skimage
  3. from models.line_detect.postprocess import show_predict, show_box, show_box_or_line, show_box_and_line, \
  4. show_line_optimized, show_line
  5. import os
  6. import torch
  7. from PIL import Image
  8. import matplotlib.pyplot as plt
  9. import matplotlib as mpl
  10. import numpy as np
  11. from models.line_detect.line_net import linenet_resnet50_fpn
  12. from torchvision import transforms
  13. # from models.wirenet.postprocess import postprocess
  14. from models.wirenet.postprocess import postprocess
  15. from rtree import index
  # Inference device: CUDA when torch reports it available, otherwise the CPU.
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
  def load_best_model(model, save_path, device):
      """Restore ``model`` from a training checkpoint if one exists.

      The file at ``save_path`` is read with ``torch.load`` (storages mapped onto
      ``device``) and is expected to be a dict holding ``'model_state_dict'``,
      ``'epoch'`` and ``'loss'``. The state dict is loaded into ``model`` in place
      and the epoch/loss are printed. If the file does not exist, a notice is
      printed and ``model`` is returned with its current weights.

      Args:
          model: The ``torch.nn.Module`` to load weights into.
          save_path: Path of the checkpoint file.
          device: ``map_location`` passed to ``torch.load``.

      Returns:
          The same ``model`` object that was passed in.
      """
      if not os.path.exists(save_path):
          print(f"No saved model found at {save_path}")
          return model

      ckpt = torch.load(save_path, map_location=device)
      model.load_state_dict(ckpt['model_state_dict'])
      # Optimizer state is not restored here (this function takes no optimizer).
      epoch, loss = ckpt['epoch'], ckpt['loss']
      print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
      return model
  def box_line_(imgs, pred):
      """Attach to every detected box the best line segment lying inside it.

      For each box, a segment qualifies when both of its endpoints are inside the
      box (borders inclusive) and its score is > 0. The qualifying segment with the
      highest score is kept; ties go to the lowest segment index. A box with no
      qualifying segment gets an all-zero line and a score of 0.

      Args:
          imgs: Unused; kept so existing call sites keep working.
          pred: Model output list. ``pred[:-1]`` are per-image dicts with a
              ``'boxes'`` tensor whose rows are read as ``(x1, y1, x2, y2)``;
              ``pred[-1]['wires']`` holds ``'lines'`` and ``'score'`` tensors
              indexed by image first (assumed ``[num_images, N, 2, 2]`` and
              ``[num_images, N]`` - TODO confirm against the model head).

      Returns:
          ``pred`` itself, with ``'line'`` (float32, ``[num_boxes, 2, 2]``) and
          ``'line_score'`` (float32, ``[num_boxes]``) added to each detection
          dict. With no boxes, ``'line'`` has shape ``[0, 2, 2]``.
      """
      wires = pred[-1]['wires']
      for img_idx, det in enumerate(pred[:-1]):
          # Move everything to host numpy once instead of comparing device tensors
          # element by element inside Python loops.
          boxes = det['boxes'].detach().cpu().numpy().reshape(-1, 4)
          # Line coordinates are rescaled by 512/128 into the 512x512 image frame.
          lines = wires['lines'][img_idx].detach().cpu().numpy() / 128 * 512  # [N, 2, 2]
          scores = wires['score'][img_idx].detach().cpu().numpy()            # [N]

          best_lines = np.zeros((len(boxes), 2, 2), dtype=np.float32)
          best_scores = np.zeros(len(boxes), dtype=np.float32)
          if len(boxes) and len(lines):
              # Endpoint index 1 is tested against the box x-range and index 0
              # against the y-range (x and y are swapped relative to the boxes).
              px = lines[None, :, :, 1]  # [1, N, 2]
              py = lines[None, :, :, 0]  # [1, N, 2]
              x1, y1, x2, y2 = (boxes[:, k, None, None] for k in range(4))  # each [B, 1, 1]
              inside = ((px >= x1) & (px <= x2) & (py >= y1) & (py <= y2)).all(axis=2)  # [B, N]
              # Only positive scores can win (the running maximum starts at 0.0);
              # NaN scores fail this test and are excluded too.
              eligible = inside & (scores > 0.0)[None, :]
              # argmax returns the first maximum, i.e. the lowest index on ties.
              best = np.where(eligible, scores[None, :], -np.inf).argmax(axis=1)
              rows = np.flatnonzero(eligible.any(axis=1))
              best_lines[rows] = lines[best[rows]]
              best_scores[rows] = scores[best[rows]]

          det['line'] = torch.from_numpy(best_lines)
          det['line_score'] = torch.from_numpy(best_scores)
      return pred
  def box_line_optimized(pred):
      """Attach to every detected box its best contained line, using an R-tree.

      For each box, a segment qualifies when both endpoints are inside the box
      (borders inclusive) and its score is > 0. The highest-scoring qualifying
      segment is kept, with ties going to the lowest segment index. Boxes with no
      qualifying segment get an all-zero line and a score of 0. The R-tree only
      narrows the candidates; the exact containment test decides.

      Args:
          pred: Model output list. ``pred[:-1]`` are per-image dicts with a
              ``'boxes'`` tensor whose rows are read as ``(x1, y1, x2, y2)``;
              ``pred[-1]['wires']`` holds ``'lines'`` and ``'score'`` tensors
              indexed by image first (previously noted as ``[1, 2500, 2, 2]`` and
              a 2500-long score vector - TODO confirm).

      Returns:
          ``pred`` itself, with ``'line'`` (float32, ``[num_boxes, 2, 2]``) and
          ``'line_score'`` (float32, ``[num_boxes]``) added to each detection dict.
      """
      wires = pred[-1]['wires']
      for img_idx, det in enumerate(pred[:-1]):
          # Convert once per image instead of once per segment / per candidate.
          # Line coordinates are rescaled by 512/128 into the 512x512 image frame.
          lines = wires['lines'][img_idx].detach().cpu().numpy() / 128 * 512  # [N, 2, 2]
          scores = wires['score'][img_idx].detach().cpu().numpy()            # [N]

          # The containment test reads endpoint index 1 as x and index 0 as y
          # (x and y are swapped relative to the boxes). The R-tree must use the
          # same (x, y) convention; indexing by (index 0, index 1) made the
          # candidate query drop segments that really lie inside a box.
          tree = index.Index()
          for j, seg in enumerate(lines):
              if not np.isfinite(seg).all():
                  continue  # non-finite endpoints can never pass the containment test
              xs, ys = seg[:, 1], seg[:, 0]
              tree.insert(j, (float(xs.min()), float(ys.min()), float(xs.max()), float(ys.max())))

          boxes = det['boxes'].detach().cpu().numpy().reshape(-1, 4)
          best_lines = np.zeros((len(boxes), 2, 2), dtype=np.float32)
          best_scores = np.zeros(len(boxes), dtype=np.float32)
          for b, (x1, y1, x2, y2) in enumerate(boxes):
              best = 0.0
              # A segment fully inside the box has its bounding box inside the box,
              # so it is always an intersection candidate. Sorting makes the
              # tie-break deterministic (lowest segment index wins).
              for j in sorted(tree.intersection((float(x1), float(y1), float(x2), float(y2)))):
                  xs, ys = lines[j, :, 1], lines[j, :, 0]
                  if (xs.min() >= x1 and xs.max() <= x2 and ys.min() >= y1 and ys.max() <= y2
                          and scores[j] > best):
                      best = float(scores[j])
                      best_lines[b] = lines[j]
                      best_scores[b] = best

          det['line'] = torch.from_numpy(best_lines)
          det['line_score'] = torch.from_numpy(best_scores)
      return pred
  def predict(pt_path, model, img):
      """Load weights, run the model on one image and show the results.

      Args:
          pt_path: Checkpoint path handed to ``load_best_model``.
          model: Network to run; it is put in eval mode before inference.
          img: Image file path (opened with PIL and converted to RGB), or an
              image object accepted by ``transforms.ToTensor`` (PIL image or
              HWC numpy array).

      Returns:
          The prediction list after ``box_line_`` has attached ``'line'`` and
          ``'line_score'`` to each detection dict. Earlier versions returned
          ``None``; callers that ignore the result are unaffected.
      """
      # ``import skimage`` alone does not load the ``transform`` subpackage on
      # scikit-image releases without lazy submodule loading.
      import skimage.transform

      model = load_best_model(model, pt_path, device)
      model.eval()

      if isinstance(img, str):
          # ``with`` guarantees the underlying file handle is released.
          with Image.open(img) as opened:
              img = opened.convert("RGB")
      img_tensor = transforms.ToTensor()(img)  # [C, H, W]

      # Resize to 512x512; box_line_ rescales line coordinates by 512/128,
      # i.e. it assumes this input size.
      t_start = time.time()
      hwc = img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32)
      resized = skimage.transform.resize(hwc, (512, 512))  # (512, 512, C)
      img_ = torch.tensor(resized).permute(2, 0, 1)
      t_end = time.time()
      print(f'switch img used:{t_end - t_start}')

      with torch.no_grad():
          predictions = model([img_.to(device)])

      # Alternative visualisers in models.line_detect.postprocess:
      # show_line_optimized, show_box, show_box_or_line, show_box_and_line.
      show_line(img_, predictions, t_start)

      t_start = time.time()
      # box_line_optimized(predictions) is the R-tree based alternative.
      pred = box_line_(img_, predictions)
      t_end = time.time()
      print(f'Matched boxes and lines used: {t_end - t_start:.4f} seconds')
      show_predict(img_, pred, t_start)
      return pred
  if __name__ == '__main__':
      import argparse

      # Defaults are the previously hard-coded (machine-specific) paths, so running
      # the script without arguments behaves exactly as before.
      DEFAULT_PT_PATH = r'D:\python\PycharmProjects\20250214\weight\linenet_wts\r50fpn_wts_e350\best.pth'
      DEFAULT_IMG_PATH = r'C:\Users\m2337\Desktop\p\2025-01-03-09-34-32_SaveImage_adjust_brightness_contrast.jpg'

      parser = argparse.ArgumentParser(description='Run line/box prediction on a single image.')
      parser.add_argument('--pt-path', default=DEFAULT_PT_PATH, help='checkpoint file to load')
      parser.add_argument('--img-path', default=DEFAULT_IMG_PATH, help='image file to predict on')
      args = parser.parse_args()

      t_start = time.time()
      print(f'start to predict:{t_start}')
      model = linenet_resnet50_fpn().to(device)
      predict(args.pt_path, model, args.img_path)
      t_end = time.time()
      print(f'predict used:{t_end - t_start}')