# predict_zjf.py (4.8 KB)
  1. import time
  2. import skimage
  3. from models.line_detect.postprocess import show_predict
  4. import os
  5. import torch
  6. from PIL import Image
  7. import matplotlib.pyplot as plt
  8. import numpy as np
  9. from models.line_detect.line_net import linenet_resnet50_fpn
  10. from torchvision import transforms
  11. from rtree import index
  12. import multiprocessing as mp
  13. # 设置多进程启动方式为 'spawn'
  14. mp.set_start_method('spawn', force=True)
  15. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  16. def load_best_model(model, save_path, device):
  17. if os.path.exists(save_path):
  18. checkpoint = torch.load(save_path, map_location=device)
  19. model.load_state_dict(checkpoint['model_state_dict'])
  20. epoch = checkpoint['epoch']
  21. loss = checkpoint['loss']
  22. print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  23. else:
  24. print(f"No saved model found at {save_path}")
  25. return model
  26. def process_box(box, lines, scores):
  27. valid_lines = [] # 存储有效的线段
  28. valid_scores = [] # 存储有效的分数
  29. for i in box:
  30. best_line = None
  31. max_length = 0.0
  32. # 遍历所有线段
  33. for j in range(lines.shape[1]):
  34. line_j = lines[0, j].cpu().numpy() / 128 * 512
  35. # 检查线段是否完全在box内
  36. if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and
  37. line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
  38. line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
  39. line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
  40. # 计算线段长度
  41. length = np.linalg.norm(line_j[0] - line_j[1])
  42. if length > max_length:
  43. best_line = line_j
  44. max_length = length
  45. # 如果找到有效的线段,则添加到结果中
  46. if best_line is not None:
  47. valid_lines.append(best_line)
  48. valid_scores.append(max_length) # 使用线段长度作为分数
  49. return valid_lines, valid_scores
  50. def box_line_optimized_parallel(pred):
  51. # 提取所有线段和分数
  52. lines = pred[-1]['wires']['lines'] # 形状为[1, 2500, 2, 2]
  53. scores = pred[-1]['wires']['score'][0] # 假设形状为[2500]
  54. # 初始化存储结果的列表
  55. filtered_pred = []
  56. # 使用多进程并行处理每个box
  57. boxes = [box_['boxes'].cpu().numpy() for box_ in pred[0:-1]] # 所有box
  58. num_processes = min(mp.cpu_count(), len(boxes)) # 使用可用的核心数
  59. with mp.Pool(processes=num_processes) as pool:
  60. results = pool.starmap(
  61. process_box,
  62. [(box, lines, scores) for box in boxes]
  63. )
  64. # 更新预测结果
  65. for idx_box, (valid_lines, valid_scores) in enumerate(results):
  66. if valid_lines:
  67. pred[idx_box]['line'] = torch.tensor(valid_lines)
  68. pred[idx_box]['line_score'] = torch.tensor(valid_scores)
  69. filtered_pred.append(pred[idx_box])
  70. return filtered_pred
  71. def predict(pt_path, model, img):
  72. model = load_best_model(model, pt_path, device)
  73. model.eval()
  74. if isinstance(img, str):
  75. img = Image.open(img).convert("RGB")
  76. transform = transforms.ToTensor()
  77. img_tensor = transform(img)
  78. im = img_tensor.permute(1, 2, 0) # [512, 512, 3]
  79. im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512)) # (512, 512, 3)
  80. img_tensor = torch.tensor(im_resized).permute(2, 0, 1)
  81. with torch.no_grad():
  82. t_start = time.time()
  83. predictions = model([img_tensor.to(device)])
  84. t_end = time.time()
  85. print(f'Prediction used: {t_end - t_start:.4f} seconds')
  86. boxes = predictions[0]['boxes'].shape
  87. lines = predictions[-1]['wires']['lines'].shape
  88. lines_scores = predictions[-1]['wires']['score'].shape
  89. print(f'Predictions - boxes: {boxes}, lines: {lines}, lines_scores: {lines_scores}')
  90. t_start = time.time()
  91. pred = box_line_optimized_parallel(predictions)
  92. t_end = time.time()
  93. print(f'Matched boxes and lines used: {t_end - t_start:.4f} seconds')
  94. # 检查 pred 是否为空
  95. if not pred:
  96. print("No valid predictions found. Skipping visualization.")
  97. return
  98. # 只绘制有效的线段
  99. show_predict(img_tensor, pred, t_start)
  100. if __name__ == '__main__':
  101. t_start = time.time()
  102. print(f'Start to predict: {t_start}')
  103. model = linenet_resnet50_fpn().to(device)
  104. pt_path = r'D:\python\PycharmProjects\20250214\weight\resnet50_best_e100.pth'
  105. # img_path = f'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-41-38_SaveImage.png' # 工件图
  106. # img_path = f'D:\python\PycharmProjects\data\images/train/00558656_3.png' # wireframe图
  107. img_path = r'C:\Users\m2337\Desktop\9.jpg'
  108. predict(pt_path, model, img_path)
  109. t_end = time.time()
  110. print(f'Total prediction time: {t_end - t_start:.4f} seconds')