"""Run a trained line-detection model on one image and visualize the result.

Loads a checkpoint into ``linenet_resnet50_fpn``, runs inference on a single
image and draws the predicted bounding boxes and post-processed wireframe
lines on top of the image with matplotlib.
"""
import os

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from models.line_detect.line_net import linenet_resnet50_fpn
from models.wirenet.postprocess import postprocess

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def load_best_model(model, save_path, device):
    """Load weights from a training checkpoint into ``model`` if the file exists.

    Args:
        model: The ``torch.nn.Module`` to load weights into (modified in place).
        save_path: Path to a checkpoint dict containing at least the keys
            ``'model_state_dict'``, ``'epoch'`` and ``'loss'``.
        device: ``map_location`` used when deserializing the checkpoint.

    Returns:
        The same ``model`` object. If ``save_path`` does not exist, a message
        is printed and the model is returned with its current weights.
    """
    if not os.path.exists(save_path):
        print(f"No saved model found at {save_path}")
        return model

    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(save_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
    return model


# Colour mapping for line confidence scores: scores are normalized over
# [0.4, 1.0] and coloured with the "jet" colour map.
cmap = plt.get_cmap("jet")
norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])


def c(x):
    """Return the RGBA colour for score ``x`` under the module-level mapping ``sm``."""
    return sm.to_rgba(x)


def imshow(im):
    """Show ``im`` in a fresh figure with the score colour bar attached.

    ``im`` is an array indexed as (height, width[, channels]).
    """
    plt.close()
    plt.tight_layout()
    plt.imshow(im)
    plt.colorbar(sm, fraction=0.046)
    # x spans the image width and y the height (inverted so row 0 is on top).
    # Using shape[0] for the x limit would crop/pad non-square images.
    plt.xlim([0, im.shape[1]])
    plt.ylim([im.shape[0], 0])


def show_line(img, pred, box_thresh=0.7, line_thresh=0.9):
    """Draw predicted boxes and wireframe lines over ``img`` and display it.

    Args:
        img: Image tensor in (C, H, W) layout, as produced by
            ``transforms.ToTensor``. May live on any device.
        pred: Model output list. ``pred[0]`` must provide ``'boxes'`` and
            ``'scores'``; ``pred[-1]['wires']`` must provide ``'lines'`` and
            ``'score'``.
        box_thresh: Boxes scoring below this value are not drawn.
        line_thresh: Post-processed lines scoring below this value are not drawn.
    """
    # (C, H, W) -> (H, W, C) NumPy array on the CPU, as matplotlib needs.
    im = img.detach().cpu().permute(1, 2, 0).numpy()

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)

    # Bounding boxes, unpacked as (x0, y0, x1, y1).
    boxes = pred[0]['boxes'].cpu().numpy()
    boxes_scores = pred[0]['scores'].cpu().numpy()
    for box, score in zip(boxes, boxes_scores):
        if score < box_thresh:
            continue
        x0, y0, x1, y1 = box
        rect = plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False,
                             edgecolor='yellow', linewidth=1)
        ax.add_patch(rect)

    wires = pred[-1]['wires']
    # Endpoints are (y, x) pairs (index 0 is plotted on the y axis below).
    # They are divided by 128 and scaled by (H, W) to get pixel coordinates;
    # presumably the model predicts on a 128x128 grid -- TODO confirm.
    lines = wires["lines"][0].cpu().numpy() / 128 * im.shape[:2]
    scores = wires["score"][0].cpu().numpy()
    # The line list appears to be padded with copies of the first line;
    # truncate at the first repeat.
    for i in range(1, len(lines)):
        if (lines[i] == lines[0]).all():
            lines = lines[:i]
            scores = scores[:i]
            break

    # Post-process to remove overlapping lines; the tolerance passed in is
    # 1% of the image diagonal.
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)

    # Draw each sufficiently confident line coloured by score, plus endpoints.
    point_opts = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
    for (a, b), s in zip(nlines, nscores):
        if s < line_thresh:
            continue
        ax.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
        ax.scatter(a[1], a[0], **point_opts)
        ax.scatter(b[1], b[0], **point_opts)

    # Hide the axes and remove all padding so only the image is shown.
    ax.set_axis_off()
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())
    plt.show()


def predict(pt_path, model, img):
    """Load ``pt_path`` into ``model``, run it on ``img`` and visualize the result.

    Args:
        pt_path: Checkpoint path handed to ``load_best_model``.
        model: The detection model.
        img: A path (``str`` or ``os.PathLike``) to an image file, or a
            ``PIL.Image``.
    """
    model = load_best_model(model, pt_path, device)
    # Run on the selected device; ``device`` otherwise only affected checkpoint
    # deserialization and inference silently stayed on the CPU.
    model.to(device)
    model.eval()

    if isinstance(img, (str, os.PathLike)):
        img = Image.open(img).convert("RGB")
    img_tensor = transforms.ToTensor()(img)

    with torch.no_grad():
        predictions = model([img_tensor.to(device)])

    print(predictions[0])
    # Visualize on the original CPU tensor.
    show_line(img_tensor, predictions)


if __name__ == '__main__':
    model = linenet_resnet50_fpn()
    pt_path = r'D:\python\PycharmProjects\lcnn-master\lcnn_\20250212\MultiVisionModels\models\line_detect\linenet_wts\resnet50_best_e8.pth'
    # img_path = r'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-41-38_SaveImage.png'  # workpiece image
    # Raw string: a plain/f-string here contains invalid escapes like "\p".
    img_path = r'D:\python\PycharmProjects\data\images/train/00558656_3.png'  # wireframe image
    predict(pt_path, model, img_path)