predict2.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. import time
  2. from models.line_detect.postprocess import show_predict
  3. import os
  4. import torch
  5. from PIL import Image
  6. import matplotlib.pyplot as plt
  7. import matplotlib as mpl
  8. import numpy as np
  9. from models.line_detect.line_net import linenet_resnet50_fpn
  10. from torchvision import transforms
  11. # from models.wirenet.postprocess import postprocess
  12. from models.wirenet.postprocess import postprocess
  13. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  14. def load_best_model(model, save_path, device):
  15. if os.path.exists(save_path):
  16. checkpoint = torch.load(save_path, map_location=device)
  17. model.load_state_dict(checkpoint['model_state_dict'])
  18. # if optimizer is not None:
  19. # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  20. epoch = checkpoint['epoch']
  21. loss = checkpoint['loss']
  22. print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  23. else:
  24. print(f"No saved model found at {save_path}")
  25. return model
  26. def box_line_(pred):
  27. for idx, box_ in enumerate(pred[0:-1]):
  28. box = box_['boxes'] # 是一个tensor
  29. line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
  30. score = pred[-1]['wires']['score'][idx]
  31. line_ = []
  32. score_ = []
  33. for i in box:
  34. score_max = 0.0
  35. tmp = [[0.0, 0.0], [0.0, 0.0]]
  36. for j in range(len(line)):
  37. if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
  38. line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
  39. line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
  40. line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
  41. if score[j] > score_max:
  42. tmp = line[j]
  43. score_max = score[j]
  44. line_.append(tmp)
  45. score_.append(score_max)
  46. processed_list = torch.tensor(line_)
  47. pred[idx]['line'] = processed_list
  48. processed_s_list = torch.tensor(score_)
  49. pred[idx]['line_score'] = processed_s_list
  50. return pred
  51. def predict(pt_path, model, img):
  52. model = load_best_model(model, pt_path, device)
  53. model.eval()
  54. if isinstance(img, str):
  55. img = Image.open(img).convert("RGB")
  56. transform = transforms.ToTensor()
  57. img_tensor = transform(img)
  58. with torch.no_grad():
  59. predictions = model([img_tensor.to(device)])
  60. # print(predictions)
  61. pred = box_line_(predictions)
  62. # print(f'pred:{pred[0]}')
  63. show_predict(img_tensor, pred, t_start)
  64. if __name__ == '__main__':
  65. t_start = time.time()
  66. print(f'start to predict:{t_start}')
  67. model = linenet_resnet50_fpn().to(device)
  68. pt_path = r'D:\python\PycharmProjects\lcnn-master\lcnn_\20250212\MultiVisionModels\models\line_detect\linenet_wts\resnet50_best_e8.pth'
  69. img_path = f'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-41-38_SaveImage.png' # 工件图
  70. # img_path = f'D:\python\PycharmProjects\data\images/train/00558656_3.png' # wireframe图
  71. predict(pt_path, model, img_path)
  72. t_end = time.time()
  73. print(f'predict used:{t_end - t_start}')