predict_zjf(1).py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. import time
  2. import skimage
  3. from models.line_detect.postprocess import show_predict
  4. import os
  5. import torch
  6. from PIL import Image
  7. import matplotlib.pyplot as plt
  8. import numpy as np
  9. from models.line_detect.line_net import linenet_resnet50_fpn
  10. from torchvision import transforms
  11. from rtree import index
  12. import multiprocessing as mp
  # Select the 'spawn' start method for multiprocessing; force=True replaces
  # any start method that has already been set.
  mp.set_start_method('spawn', force=True)

  # Run on CUDA when torch reports it as available, otherwise on the CPU.
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
  print(f'device:{device}')
  def load_best_model(model, save_path, device):
      """Load checkpoint weights into ``model`` when the checkpoint file exists.

      Args:
          model: Object whose ``load_state_dict`` receives the stored weights.
          save_path: Path of the checkpoint file.
          device: Forwarded to ``torch.load`` as ``map_location``.

      Returns:
          The ``model`` object that was passed in. When ``save_path`` does not
          exist a message is printed and the model is returned unchanged.
      """
      # Guard clause: no checkpoint on disk, so report it and return as-is.
      if not os.path.exists(save_path):
          print(f"No saved model found at {save_path}")
          return model

      ckpt = torch.load(save_path, map_location=device)
      model.load_state_dict(ckpt['model_state_dict'])
      # 'epoch' and 'loss' are read from the checkpoint only for the log line.
      epoch, loss = ckpt['epoch'], ckpt['loss']
      print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
      return model
  def process_box(box, lines, scores):
      """For every box, keep the longest line whose two endpoints lie inside it.

      The line endpoints are rescaled with ``/ 128 * 512`` before being
      compared with the box bounds. The conversion to numpy and the rescale
      are done once for all lines instead of once per (box, line) pair.

      Args:
          box: Iterable of boxes, each indexable as ``b[0]..b[3]``. Component 1
              of each endpoint is checked against ``[b[0], b[2]]`` and
              component 0 against ``[b[1], b[3]]`` (presumably boxes are
              ``[x_min, y_min, x_max, y_max]`` and endpoints ``(y, x)`` -
              confirm against the model output).
          lines: Tensor indexed as ``lines[0, j]`` that yields the two
              endpoints of line ``j`` (a 2x2 block).
          scores: Not used by this function; kept so existing callers that
              pass three arguments keep working.

      Returns:
          ``(valid_lines, valid_scores)``: for each box that contains at least
          one line of non-zero length, the rescaled 2x2 endpoint array of the
          longest such line and that line's length (the length is used as the
          score). Boxes without a matching line contribute nothing.
      """
      valid_lines = []
      valid_scores = []

      # Loop-invariant work hoisted out of the box loop: one host transfer and
      # one rescale for the whole set of lines.
      endpoints = lines[0].cpu().numpy() / 128 * 512
      if endpoints.shape[0] == 0:
          return valid_lines, valid_scores

      # Length of every line, computed once.
      lengths = np.linalg.norm(endpoints[:, 0] - endpoints[:, 1], axis=1)
      comp0 = endpoints[:, :, 0]  # component 0 of both endpoints, per line
      comp1 = endpoints[:, :, 1]  # component 1 of both endpoints, per line

      for bounds in box:
          # A line qualifies only when both endpoints are within the bounds.
          inside = (
              (comp1 >= bounds[0]).all(axis=1)
              & (comp1 <= bounds[2]).all(axis=1)
              & (comp0 >= bounds[1]).all(axis=1)
              & (comp0 <= bounds[3]).all(axis=1)
          )
          if not inside.any():
              continue

          # argmax returns the first maximum, matching a strict ">" scan.
          best = int(np.argmax(np.where(inside, lengths, -1.0)))
          # Zero-length lines are never selected (length must exceed 0.0).
          if lengths[best] > 0.0:
              valid_lines.append(endpoints[best].copy())
              valid_scores.append(lengths[best])

      return valid_lines, valid_scores
  def box_line_optimized_parallel(pred):
      """Attach the best line to each box prediction, one worker task per entry.

      ``pred[-1]['wires']`` supplies the lines and scores; every other element
      of ``pred`` supplies a ``'boxes'`` tensor. Each element's boxes are
      handed to ``process_box`` in a multiprocessing pool.

      Args:
          pred: Sequence of dicts as described above. Entries that receive at
              least one line are modified in place: keys ``'line'`` and
              ``'line_score'`` are added.

      Returns:
          List of the entries of ``pred`` for which at least one line was
          found. Empty when ``pred`` holds no box entries.
      """
      lines = pred[-1]['wires']['lines']
      scores = pred[-1]['wires']['score'][0]

      boxes = [entry['boxes'].cpu().numpy() for entry in pred[0:-1]]
      if not boxes:
          # Without this guard min(cpu_count, 0) == 0 and mp.Pool(processes=0)
          # raises ValueError.
          return []

      # Send host-memory tensors to the workers. process_box calls .cpu() on
      # the lines anyway, so the result is the same.
      if torch.is_tensor(lines):
          lines = lines.cpu()
      if torch.is_tensor(scores):
          scores = scores.cpu()

      # Never start more workers than there are tasks.
      num_processes = min(mp.cpu_count(), len(boxes))
      with mp.Pool(processes=num_processes) as pool:
          results = pool.starmap(
              process_box,
              [(box, lines, scores) for box in boxes],
          )

      filtered_pred = []
      # results[k] belongs to pred[k]; zip stops before the trailing wires entry.
      for entry, (valid_lines, valid_scores) in zip(pred, results):
          if valid_lines:
              # Build one ndarray first; torch.tensor on a list of ndarrays is
              # slow and emits a warning.
              entry['line'] = torch.tensor(np.asarray(valid_lines))
              entry['line_score'] = torch.tensor(valid_scores)
              filtered_pred.append(entry)
      return filtered_pred
  def predict(pt_path, model, img):
      """Run the model on one image, match lines to boxes, and show the result.

      Args:
          pt_path: Checkpoint path passed to ``load_best_model``.
          model: Network that receives the checkpoint weights and is run in
              eval mode.
          img: A path string (opened with PIL and converted to RGB) or an
              image object that ``transforms.ToTensor`` accepts.

      Returns:
          None. Timing and shape information is printed; when no box/line
          match is found, visualization is skipped.
      """
      model = load_best_model(model, pt_path, device)
      model.eval()

      if isinstance(img, str):
          img = Image.open(img).convert("RGB")

      # CHW tensor -> HWC float32 array, resize to 512x512, then back to CHW.
      hwc = transforms.ToTensor()(img).permute(1, 2, 0)
      resized = skimage.transform.resize(
          hwc.cpu().numpy().astype(np.float32), (512, 512)
      )
      img_tensor = torch.tensor(resized).permute(2, 0, 1)

      # Forward pass without gradient tracking, timed with wall-clock time.
      with torch.no_grad():
          t_start = time.time()
          predictions = model([img_tensor.to(device)])
          t_end = time.time()
          print(f'Prediction used: {t_end - t_start:.4f} seconds')

      wires = predictions[-1]['wires']
      boxes = predictions[0]['boxes'].shape
      lines = wires['lines'].shape
      lines_scores = wires['score'].shape
      print(f'Predictions - boxes: {boxes}, lines: {lines}, lines_scores: {lines_scores}')

      # Time the box/line matching step separately.
      t_start = time.time()
      pred = box_line_optimized_parallel(predictions)
      t_end = time.time()
      print(f'Matched boxes and lines used: {t_end - t_start:.4f} seconds')

      if pred:
          # Only entries that received a line are drawn. t_start here is the
          # start of the matching step.
          show_predict(img_tensor, pred, t_start)
      else:
          print("No valid predictions found. Skipping visualization.")
  if __name__ == '__main__':
      # Script entry point: build the network, run one prediction on a fixed
      # image with a fixed checkpoint, and report the total wall-clock time.
      t_start = time.time()
      print(f'Start to predict: {t_start}')

      pt_path = r'D:\python\PycharmProjects\20250214\weight\resnet50_best_e100.pth'
      img_path = r'C:\Users\m2337\Desktop\p\49.jpg'
      model = linenet_resnet50_fpn().to(device)

      predict(pt_path, model, img_path)

      t_end = time.time()
      print(f'Total prediction time: {t_end - t_start:.4f} seconds')