predict2.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. import time
  2. from models.line_detect.postprocess import show_predict
  3. import os
  4. import torch
  5. from PIL import Image
  6. import matplotlib.pyplot as plt
  7. import matplotlib as mpl
  8. import numpy as np
  9. from models.line_detect.line_net import linenet_resnet50_fpn
  10. from torchvision import transforms
  11. from rtree import index
  12. # from models.wirenet.postprocess import postprocess
  13. from models.wirenet.postprocess import postprocess
  14. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  15. def load_best_model(model, save_path, device):
  16. if os.path.exists(save_path):
  17. checkpoint = torch.load(save_path, map_location=device)
  18. model.load_state_dict(checkpoint['model_state_dict'])
  19. # if optimizer is not None:
  20. # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  21. epoch = checkpoint['epoch']
  22. loss = checkpoint['loss']
  23. print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  24. else:
  25. print(f"No saved model found at {save_path}")
  26. return model
  27. def box_line_optimized(pred):
  28. # 创建R-tree索引
  29. idx = index.Index()
  30. # 将所有线段添加到R-tree中
  31. lines = pred[-1]['wires']['lines'] # 形状为[1, 2500, 2, 2]
  32. scores = pred[-1]['wires']['score'][0] # 假设形状为[2500]
  33. # 提取并处理所有线段
  34. for idx_line in range(lines.shape[1]): # 遍历2500条线段
  35. line_tensor = lines[0, idx_line].cpu().numpy() / 128 * 512 # 转换为numpy数组并调整比例
  36. x_min = float(min(line_tensor[0][0], line_tensor[1][0]))
  37. y_min = float(min(line_tensor[0][1], line_tensor[1][1]))
  38. x_max = float(max(line_tensor[0][0], line_tensor[1][0]))
  39. y_max = float(max(line_tensor[0][1], line_tensor[1][1]))
  40. idx.insert(idx_line, (x_min, y_min, x_max, y_max))
  41. for idx_box, box_ in enumerate(pred[0:-1]):
  42. box = box_['boxes'].cpu().numpy() # 确保将张量转换为numpy数组
  43. line_ = []
  44. score_ = []
  45. for i in box:
  46. score_max = 0.0
  47. tmp = [[0.0, 0.0], [0.0, 0.0]]
  48. # 获取与当前box可能相交的所有线段
  49. possible_matches = list(idx.intersection((i[0], i[1], i[2], i[3])))
  50. for j in possible_matches:
  51. line_j = lines[0, j].cpu().numpy() / 128 * 512 # 调整比例
  52. if (line_j[0][0] >= i[0] and line_j[1][0] >= i[0] and
  53. line_j[0][0] <= i[2] and line_j[1][0] <= i[2] and
  54. line_j[0][1] >= i[1] and line_j[1][1] >= i[1] and
  55. line_j[0][1] <= i[3] and line_j[1][1] <= i[3]):
  56. if scores[j] > score_max:
  57. tmp = line_j
  58. score_max = scores[j]
  59. line_.append(tmp)
  60. score_.append(score_max)
  61. processed_list = torch.tensor(line_)
  62. pred[idx_box]['line'] = processed_list
  63. processed_s_list = torch.tensor(score_)
  64. pred[idx_box]['line_score'] = processed_s_list
  65. return pred
  66. # def box_line_(pred):
  67. # for idx, box_ in enumerate(pred[0:-1]):
  68. # box = box_['boxes'] # 是一个tensor
  69. # line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
  70. # score = pred[-1]['wires']['score'][idx]
  71. # line_ = []
  72. # score_ = []
  73. #
  74. # for i in box:
  75. # score_max = 0.0
  76. # tmp = [[0.0, 0.0], [0.0, 0.0]]
  77. #
  78. # for j in range(len(line)):
  79. # if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
  80. # line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
  81. # line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
  82. # line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
  83. #
  84. # if score[j] > score_max:
  85. # tmp = line[j]
  86. # score_max = score[j]
  87. # line_.append(tmp)
  88. # score_.append(score_max)
  89. # processed_list = torch.tensor(line_)
  90. # pred[idx]['line'] = processed_list
  91. #
  92. # processed_s_list = torch.tensor(score_)
  93. # pred[idx]['line_score'] = processed_s_list
  94. # return pred
  95. def predict(pt_path, model, img):
  96. model = load_best_model(model, pt_path, device)
  97. model.eval()
  98. if isinstance(img, str):
  99. img = Image.open(img).convert("RGB")
  100. transform = transforms.ToTensor()
  101. img_tensor = transform(img)
  102. with torch.no_grad():
  103. t_start = time.time()
  104. predictions = model([img_tensor.to(device)])
  105. t_end=time.time()
  106. print(f'predict used:{t_end-t_start}')
  107. # print(f'predictions:{predictions}')
  108. boxes=predictions[0]['boxes'].shape
  109. lines=predictions[-1]['wires']['lines'].shape
  110. lines_scores=predictions[-1]['wires']['score'].shape
  111. print(f'predictions boxes:{boxes},lines:{lines},lines_scores:{lines_scores}')
  112. t_start=time.time()
  113. pred = box_line_optimized(predictions)
  114. t_end=time.time()
  115. print(f'matched boxes and lines used:{t_end - t_start}')
  116. # print(f'pred:{pred[0]}')
  117. show_predict(img_tensor, pred, t_start)
  118. if __name__ == '__main__':
  119. t_start = time.time()
  120. print(f'start to predict:{t_start}')
  121. model = linenet_resnet50_fpn().to(device)
  122. pt_path = r"F:\BaiduNetdiskDownload\resnet50_best_e8.pth"
  123. img_path = r"I:\datasets\wirenet_1000\images\val\00037040_0.png"
  124. predict(pt_path, model, img_path)
  125. t_end = time.time()
  126. # print(f'predict used:{t_end - t_start}')