123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132 |
- import os
- import skimage
- import torch
- from PIL import Image
- import matplotlib.pyplot as plt
- import matplotlib as mpl
- import numpy as np
- from models.line_detect.line_net import linenet_resnet50_fpn
- from torchvision import transforms
- from models.wirenet.postprocess import postprocess
# Prefer the GPU when CUDA is usable; otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def load_best_model(model, save_path, device):
    """Restore model weights from a checkpoint file, if one exists.

    Args:
        model: Object exposing ``load_state_dict``; it is updated in place.
        save_path: Path of the checkpoint file.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The same ``model`` object. When ``save_path`` does not exist a
        message is printed and the model is returned untouched.
    """
    if not os.path.exists(save_path):
        print(f"No saved model found at {save_path}")
        return model

    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state = torch.load(save_path, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    # This loader does not restore any optimizer state.
    epoch = state['epoch']
    loss = state['loss']
    print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
    return model
# Shared colour scale for scores: values between 0.4 and 1.0 are normalised
# and mapped onto the "jet" colormap.
norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
cmap = plt.get_cmap("jet")
sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
# The mappable carries no image data of its own; give it an empty array.
sm.set_array([])
def c(x):
    """Return the RGBA colour that the module-level mappable ``sm`` assigns to ``x``."""
    rgba = sm.to_rgba(x)
    return rgba
def imshow(im):
    """Show ``im`` in a fresh pyplot figure together with the score colorbar.

    Args:
        im: Image array accepted by ``plt.imshow``. ``im.shape[0]`` is used
            as the height and ``im.shape[1]`` as the width.
    """
    plt.close()
    plt.tight_layout()
    plt.imshow(im)
    # ``sm`` is not attached to any Axes, so name the Axes the colorbar
    # should take space from; newer Matplotlib versions refuse to guess.
    plt.colorbar(sm, ax=plt.gca(), fraction=0.046)
    # x spans the width (shape[1]) and y the height (shape[0]), so that
    # non-square images are shown in full.
    plt.xlim([0, im.shape[1]])
    # Reversed y limits keep row 0 at the top of the figure.
    plt.ylim([im.shape[0], 0])
def show_line(img, pred, box_thresh=0.7, line_thresh=0.9):
    """Plot detected boxes and line segments on top of an image.

    Args:
        img: Image tensor. It is permuted with ``(1, 2, 0)`` before display,
            so it is presumably channel-first (C, H, W) -- confirm with caller.
        pred: Model output sequence. ``pred[0]`` must provide ``'boxes'`` and
            ``'scores'`` tensors, and ``pred[-1]`` must provide ``'wires'``,
            a mapping with ``"lines"`` and ``"score"`` tensors.
        box_thresh: Boxes scoring below this value are not drawn.
        line_thresh: Post-processed lines scoring below this value are not
            drawn.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    im = img.permute(1, 2, 0)

    # Create the figure and draw the source image.
    _, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(np.array(im))

    # Draw the bounding boxes that pass the score threshold.
    boxes = pred[0]['boxes'].cpu().numpy()
    boxes_scores = pred[0]['scores'].cpu().numpy()
    for box, box_score in zip(boxes, boxes_scores):
        if box_score < box_thresh:
            continue
        x0, y0, x1, y1 = box
        rect = plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor='yellow', linewidth=1)
        ax.add_patch(rect)

    endpoint_style = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}

    wires = pred[-1]['wires']
    # Coordinates are divided by 128 and scaled by the image height/width;
    # presumably the network emits them on a 128x128 grid -- TODO confirm.
    lines = wires["lines"][0].cpu().numpy() / 128 * im.shape[:2]
    scores = wires["score"][0].cpu().numpy()

    # Truncate at the first later entry that repeats line 0 (presumably the
    # start of padding -- verify against the model's output format).
    for i in range(1, len(lines)):
        if (lines[i] == lines[0]).all():
            lines = lines[:i]
            scores = scores[:i]
            break

    # Post-process the lines to drop overlapping ones; the distance argument
    # is 1% of the image diagonal.
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)

    # Draw each surviving line coloured by its score. Endpoints are indexed
    # as (y, x): element 1 goes on the x axis and element 0 on the y axis.
    for (a, b), s in zip(nlines, nscores):
        if s < line_thresh:
            continue
        ax.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
        ax.scatter(a[1], a[0], **endpoint_style)
        ax.scatter(b[1], b[0], **endpoint_style)

    # Hide the axes and remove all padding around the image.
    ax.set_axis_off()
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())

    # Display the figure.
    plt.show()
def predict(pt_path, model, img):
    """Load weights, run the model on one image and display the detections.

    Args:
        pt_path: Checkpoint path handed to ``load_best_model``.
        model: Network to evaluate; it is called with a list holding one
            image tensor.
        img: Image file path (opened with PIL and converted to RGB), or any
            input accepted by ``transforms.ToTensor``.
    """
    model = load_best_model(model, pt_path, device)
    # Put the model on the selected device so inference actually runs there
    # (GPU when available) and matches the input tensor's device.
    model = model.to(device)
    model.eval()

    if isinstance(img, str):
        img = Image.open(img).convert("RGB")
    transform = transforms.ToTensor()
    img_tensor = transform(img)

    # Resize the image to 512x512; the array is made channel-last for
    # skimage and converted back to channel-first afterwards.
    im = img_tensor.permute(1, 2, 0)
    im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512))  # (512, 512, 3)
    img_tensor = torch.tensor(im_resized).permute(2, 0, 1)

    with torch.no_grad():
        predictions = model([img_tensor.to(device)])

    # The CPU copy of the image is passed on because show_line converts it
    # to a NumPy array for plotting.
    show_line(img_tensor, predictions)
if __name__ == '__main__':
    # Checkpoint to evaluate and the image to run it on.
    pt_path = r'D:\python\PycharmProjects\20250214\weight\resnet50_best_e100.pth'
    img_path = r'C:\Users\m2337\Desktop\p\24.jpg'
    # Alternative sample inputs:
    # img_path = f'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-41-38_SaveImage.png'  # workpiece image
    # img_path = f'D:\python\PycharmProjects\data\images/train/00558656_3.png'  # wireframe image

    model = linenet_resnet50_fpn()
    predict(pt_path, model, img_path)
|