predict.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. import os
  2. import torch
  3. from PIL import Image
  4. import matplotlib.pyplot as plt
  5. import matplotlib as mpl
  6. import numpy as np
  7. from models.line_detect.line_net import linenet_resnet50_fpn
  8. from torchvision import transforms
  9. from models.wirenet.postprocess import postprocess
  10. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  11. def load_best_model(model, save_path, device):
  12. if os.path.exists(save_path):
  13. checkpoint = torch.load(save_path, map_location=device)
  14. model.load_state_dict(checkpoint['model_state_dict'])
  15. # if optimizer is not None:
  16. # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  17. epoch = checkpoint['epoch']
  18. loss = checkpoint['loss']
  19. print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  20. else:
  21. print(f"No saved model found at {save_path}")
  22. return model
  23. cmap = plt.get_cmap("jet")
  24. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  25. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  26. sm.set_array([])
  27. def c(x):
  28. return sm.to_rgba(x)
  29. def imshow(im):
  30. plt.close()
  31. plt.tight_layout()
  32. plt.imshow(im)
  33. plt.colorbar(sm, fraction=0.046)
  34. plt.xlim([0, im.shape[0]])
  35. plt.ylim([im.shape[0], 0])
  36. def show_line(img, pred):
  37. im = img.permute(1, 2, 0)
  38. # 创建图形和坐标轴
  39. fig, ax = plt.subplots(figsize=(10, 10))
  40. # 绘制原始图像
  41. ax.imshow(np.array(im))
  42. # 绘制边界框
  43. boxes = pred[0]['boxes'].cpu().numpy()
  44. boxes_scores = pred[0]['scores'].cpu().numpy()
  45. # for box in boxes:
  46. # x0, y0, x1, y1 = box
  47. # rect = plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor='yellow', linewidth=1)
  48. # ax.add_patch(rect) # 将矩形添加到 Axes 对象上
  49. for b, s in zip(boxes, boxes_scores):
  50. # print(f'box:{b}, box_score:{s}')
  51. if s < 0.7:
  52. continue
  53. x0, y0, x1, y1 = b
  54. rect = plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor='yellow', linewidth=1)
  55. ax.add_patch(rect) # 将矩形添加到 Axes 对象上
  56. PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
  57. H = pred[-1]['wires']
  58. lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
  59. scores = H["score"][0].cpu().numpy()
  60. for i in range(1, len(lines)):
  61. if (lines[i] == lines[0]).all():
  62. lines = lines[:i]
  63. scores = scores[:i]
  64. break
  65. # 后处理线条以去除重叠的线条
  66. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  67. nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
  68. # 根据分数绘制线条
  69. for i, t in enumerate([0.9]):
  70. for (a, b), s in zip(nlines, nscores):
  71. if s < t:
  72. continue
  73. ax.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s) # 在 Axes 上绘制线条
  74. ax.scatter(a[1], a[0], **PLTOPTS) # 在 Axes 上绘制散点
  75. ax.scatter(b[1], b[0], **PLTOPTS) # 在 Axes 上绘制散点
  76. # 隐藏坐标轴
  77. ax.set_axis_off()
  78. plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
  79. plt.margins(0, 0)
  80. ax.xaxis.set_major_locator(plt.NullLocator())
  81. ax.yaxis.set_major_locator(plt.NullLocator())
  82. # 显示图像
  83. plt.show()
  84. def predict(pt_path, model, img):
  85. model = load_best_model(model, pt_path, device)
  86. model.eval()
  87. if isinstance(img, str):
  88. img = Image.open(img).convert("RGB")
  89. transform = transforms.ToTensor()
  90. img_tensor = transform(img)
  91. with torch.no_grad():
  92. predictions = model([img_tensor])
  93. print(predictions[0])
  94. show_line(img_tensor, predictions)
  95. if __name__ == '__main__':
  96. model = linenet_resnet50_fpn()
  97. pt_path = r'D:\python\PycharmProjects\lcnn-master\lcnn_\20250212\MultiVisionModels\models\line_detect\linenet_wts\resnet50_best_e8.pth'
  98. # img_path = f'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-41-38_SaveImage.png' # 工件图
  99. img_path = f'D:\python\PycharmProjects\data\images/train/00558656_3.png' # wireframe图
  100. predict(pt_path, model, img_path)