pridict.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import os
  2. import cv2
  3. import torch
  4. import json
  5. import numpy as np
  6. from PIL import Image
  7. import matplotlib.pyplot as plt
  8. from torchvision import transforms
  9. from models.line_detect.line_net import linenet_resnet50_fpn
  10. from models.wirenet.postprocess import postprocess
  11. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  12. def load_best_model(model, save_path, device):
  13. if os.path.exists(save_path):
  14. checkpoint = torch.load(save_path, map_location=device)
  15. model.load_state_dict(checkpoint['model_state_dict'])
  16. else:
  17. print(f"No saved model found at {save_path}")
  18. return model
  19. def show_line_and_box(img_tensor, pred, img_name="", save_path=None):
  20. im = img_tensor.permute(1, 2, 0).cpu().numpy()
  21. boxes = pred[0]['boxes'].cpu().numpy()
  22. lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  23. scores = pred[-1]['wires']['score'].cpu().numpy()[0]
  24. # 后处理线
  25. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  26. lines, line_scores = postprocess(lines, scores, diag * 0.01, 0, False)
  27. fig, ax = plt.subplots(figsize=(10, 10))
  28. ax.imshow(im)
  29. for box in boxes:
  30. x0, y0, x1, y1 = box
  31. ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0,
  32. fill=False, edgecolor='lime', linewidth=1))
  33. for (a, b) in lines:
  34. ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
  35. ax.scatter([a[1], b[1]], [a[0], b[0]], c='red', s=3)
  36. if save_path:
  37. fig.savefig(save_path, bbox_inches='tight')
  38. plt.close(fig)
  39. else:
  40. plt.show()
  41. # 返回线段信息
  42. line_list = []
  43. for (a, b) in lines:
  44. line_list.append({
  45. "start": [round(float(a[1]), 2), round(float(a[0]), 2)], # [x, y]
  46. "end": [round(float(b[1]), 2), round(float(b[0]), 2)]
  47. })
  48. return line_list
  49. def predict_and_save(img_path, model, output_path):
  50. image = Image.open(img_path).convert("RGB")
  51. transform = transforms.ToTensor()
  52. img_tensor = transform(image)
  53. im = img_tensor.permute(1, 2, 0)
  54. # Resize to 512x512
  55. if im.shape != (512, 512, 3):
  56. im = cv2.resize(im.cpu().numpy(), (512, 512), interpolation=cv2.INTER_LINEAR)
  57. img_tensor = torch.tensor(im).permute(2, 0, 1)
  58. img_tensor = img_tensor.to(device)
  59. with torch.no_grad():
  60. predictions = model([img_tensor])
  61. img_name = os.path.basename(img_path)
  62. lines = show_line_and_box(img_tensor, predictions, img_name=img_name, save_path=output_path)
  63. return {
  64. "image": img_name,
  65. "lines": lines
  66. }
  67. def process_folder(input_dir, output_dir, model, pt_path):
  68. model = load_best_model(model, pt_path, device)
  69. model.eval()
  70. os.makedirs(output_dir, exist_ok=True)
  71. supported_exts = ('.jpg', '.jpeg', '.png')
  72. img_list = [f for f in os.listdir(input_dir) if f.lower().endswith(supported_exts)]
  73. all_results = []
  74. for img_name in img_list:
  75. img_path = os.path.join(input_dir, img_name)
  76. save_path = os.path.join(output_dir, os.path.splitext(img_name)[0] + ".jpg")
  77. print(f"Processing {img_path}...")
  78. result = predict_and_save(img_path, model, save_path)
  79. all_results.append(result)
  80. # 保存为 JSON 文件
  81. json_path = os.path.join(output_dir, "predictions.json")
  82. with open(json_path, 'w', encoding='utf-8') as f:
  83. json.dump(all_results, f, indent=2, ensure_ascii=False)
  84. print(f"处理完成,结果保存在: {json_path}")
  85. def main():
  86. model = linenet_resnet50_fpn().to(device)
  87. input_dir = r"G:\python_ws_g\data\pcd2color_result\color_jpg"
  88. parent_dir = os.path.dirname(input_dir)
  89. output_dir = os.path.join(parent_dir, "a_predict_restnet50")
  90. pt_path = r"G:\python_ws_g\code\mulitivision汇总\转tiff\MultiVisionModels\weight\best_val.pth"
  91. process_folder(input_dir, output_dir, model, pt_path)
# Run the batch prediction only when executed as a script, not when imported.
if __name__ == '__main__':
    main()