# predict.py
  1. import os
  2. import skimage
  3. import torch
  4. from PIL import Image
  5. import matplotlib.pyplot as plt
  6. import matplotlib as mpl
  7. import numpy as np
  8. from models.line_detect.line_net import linenet_resnet50_fpn
  9. from torchvision import transforms
  10. from models.wirenet.postprocess import postprocess
  11. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  12. def load_best_model(model, save_path, device):
  13. if os.path.exists(save_path):
  14. checkpoint = torch.load(save_path, map_location=device)
  15. model.load_state_dict(checkpoint['model_state_dict'])
  16. # if optimizer is not None:
  17. # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  18. epoch = checkpoint['epoch']
  19. loss = checkpoint['loss']
  20. print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  21. else:
  22. print(f"No saved model found at {save_path}")
  23. return model
  24. cmap = plt.get_cmap("jet")
  25. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  26. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  27. sm.set_array([])
  28. def c(x):
  29. return sm.to_rgba(x)
  30. def imshow(im):
  31. plt.close()
  32. plt.tight_layout()
  33. plt.imshow(im)
  34. plt.colorbar(sm, fraction=0.046)
  35. plt.xlim([0, im.shape[0]])
  36. plt.ylim([im.shape[0], 0])
  37. def show_line(img, pred):
  38. im = img.permute(1, 2, 0)
  39. # 创建图形和坐标轴
  40. fig, ax = plt.subplots(figsize=(10, 10))
  41. # 绘制原始图像
  42. ax.imshow(np.array(im))
  43. # 绘制边界框
  44. boxes = pred[0]['boxes'].cpu().numpy()
  45. boxes_scores = pred[0]['scores'].cpu().numpy()
  46. for b, s in zip(boxes, boxes_scores):
  47. # print(f'box:{b}, box_score:{s}')
  48. if s < 0.7:
  49. continue
  50. x0, y0, x1, y1 = b
  51. rect = plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor='yellow', linewidth=1)
  52. ax.add_patch(rect) # 将矩形添加到 Axes 对象上
  53. PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
  54. H = pred[-1]['wires']
  55. lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
  56. scores = H["score"][0].cpu().numpy()
  57. for i in range(1, len(lines)):
  58. if (lines[i] == lines[0]).all():
  59. lines = lines[:i]
  60. scores = scores[:i]
  61. break
  62. # 后处理线条以去除重叠的线条
  63. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  64. nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
  65. # 根据分数绘制线条
  66. for i, t in enumerate([0.9]):
  67. for (a, b), s in zip(nlines, nscores):
  68. if s < t:
  69. continue
  70. ax.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s) # 在 Axes 上绘制线条
  71. ax.scatter(a[1], a[0], **PLTOPTS) # 在 Axes 上绘制散点
  72. ax.scatter(b[1], b[0], **PLTOPTS) # 在 Axes 上绘制散点
  73. # 隐藏坐标轴
  74. ax.set_axis_off()
  75. plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
  76. plt.margins(0, 0)
  77. ax.xaxis.set_major_locator(plt.NullLocator())
  78. ax.yaxis.set_major_locator(plt.NullLocator())
  79. # 显示图像
  80. plt.show()
  81. def predict(pt_path, model, img):
  82. model = load_best_model(model, pt_path, device)
  83. model.eval()
  84. if isinstance(img, str):
  85. img = Image.open(img).convert("RGB")
  86. transform = transforms.ToTensor()
  87. img_tensor = transform(img)
  88. # 将图像调整为512x512大小
  89. im = img_tensor.permute(1, 2, 0) # [512, 512, 3]
  90. im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512)) # (512, 512, 3)
  91. img_tensor = torch.tensor(im_resized).permute(2, 0, 1)
  92. with torch.no_grad():
  93. predictions = model([img_tensor])
  94. # print(predictions[0])
  95. show_line(img_tensor, predictions)
  96. if __name__ == '__main__':
  97. model = linenet_resnet50_fpn()
  98. pt_path = r'D:\python\PycharmProjects\20250214\weight\resnet50_best_e100.pth'
  99. # img_path = f'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-41-38_SaveImage.png' # 工件图
  100. # img_path = f'D:\python\PycharmProjects\data\images/train/00558656_3.png' # wireframe图
  101. img_path = r'C:\Users\m2337\Desktop\p\24.jpg'
  102. predict(pt_path, model, img_path)