import os
import json

import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms

from models.line_detect.line_net import linenet_resnet50_fpn
from models.wirenet.postprocess import postprocess

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Side length (pixels) that every input image is resized to before inference.
MODEL_INPUT_SIZE = 512
# Raw predicted line endpoints are rescaled by MODEL_INPUT_SIZE / LINE_COORD_SIZE
# into image pixels.
# NOTE(review): presumably the resolution of the line head's output map — confirm.
LINE_COORD_SIZE = 128
# File extensions (compared case-insensitively) that process_folder will pick up.
SUPPORTED_EXTS = ('.jpg', '.jpeg', '.png')


def load_best_model(model, save_path, device):
    """Load trained weights from ``save_path`` into ``model`` if the file exists.

    Args:
        model: the ``torch.nn.Module`` to populate (modified in place).
        save_path: path to a file written by ``torch.save``. This is either a dict
            holding the weights under ``'model_state_dict'`` or a bare state dict.
        device: device that checkpoint tensors are mapped onto while loading.

    Returns:
        The same ``model`` object.

    If nothing exists at ``save_path``, a message is printed and the model is
    returned with whatever weights it already had.
    """
    if not os.path.exists(save_path):
        print(f"No saved model found at {save_path}")
        return model

    checkpoint = torch.load(save_path, map_location=device)
    # Training checkpoints wrap the weights; also accept a plain state dict.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    return model


def show_line_and_box(img_tensor, pred, img_name="", save_path=None):
    """Draw predicted boxes and post-processed line segments over an image.

    Args:
        img_tensor: channel-first (C, H, W) image tensor (it is permuted to
            H, W, C for display).
        pred: raw model output. Boxes are read from ``pred[0]['boxes']``. Line
            endpoints are read from ``pred[-1]['wires']['lines'][0]``, and their
            scores from ``pred[-1]['wires']['score'][0]``.
        img_name: currently unused; kept for API compatibility.
        save_path: if truthy, the figure is written to this path and closed;
            otherwise it is displayed with ``plt.show()``.

    Returns:
        A list of ``{"start": [x, y], "end": [x, y]}`` dicts, rounded to 2
        decimals. Coordinates are in a MODEL_INPUT_SIZE-pixel frame (the size
        ``predict_and_save`` resizes to).
    """
    im = img_tensor.permute(1, 2, 0).cpu().numpy()
    boxes = pred[0]['boxes'].cpu().numpy()
    # Rescale raw endpoints into image pixels.
    lines = pred[-1]['wires']['lines'][0].cpu().numpy() / LINE_COORD_SIZE * MODEL_INPUT_SIZE
    scores = pred[-1]['wires']['score'].cpu().numpy()[0]

    # Post-process the lines. The threshold is 1% of the image diagonal.
    # NOTE(review): the meaning of the remaining postprocess arguments
    # (0, False) is defined in models.wirenet.postprocess — confirm there.
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    lines, _line_scores = postprocess(lines, scores, diag * 0.01, 0, False)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    for x0, y0, x1, y1 in boxes:
        ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0,
                                   fill=False, edgecolor='lime', linewidth=1))
    # Endpoints are stored as (row, col), i.e. (y, x), so index 1 is plotted on x.
    for a, b in lines:
        ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
        ax.scatter([a[1], b[1]], [a[0], b[0]], c='red', s=3)

    if save_path:
        # Always release the figure, even if writing the file fails; otherwise
        # figures accumulate across a batch run.
        try:
            fig.savefig(save_path, bbox_inches='tight')
        finally:
            plt.close(fig)
    else:
        plt.show()

    # Return the segment information as [x, y] pairs (swapping from (y, x)).
    return [
        {
            "start": [round(float(a[1]), 2), round(float(a[0]), 2)],  # [x, y]
            "end": [round(float(b[1]), 2), round(float(b[0]), 2)],
        }
        for a, b in lines
    ]


def predict_and_save(img_path, model, output_path):
    """Run the model on one image, save a visualization, and return its lines.

    The image is converted to RGB. It is resized to MODEL_INPUT_SIZE x
    MODEL_INPUT_SIZE unless it already has that size. It is then moved to
    ``device`` and fed to ``model`` under ``torch.no_grad()``.

    Args:
        img_path: path of the input image.
        model: model called as ``model([img_tensor])``.
        output_path: where ``show_line_and_box`` writes the overlay figure.

    Returns:
        ``{"image": <file name>, "lines": <list from show_line_and_box>}``.
        NOTE: line coordinates refer to the resized MODEL_INPUT_SIZE frame, not
        the original image resolution.
    """
    # Close the underlying file handle once the pixels have been read.
    with Image.open(img_path) as pil_image:
        image = pil_image.convert("RGB")
    img_tensor = transforms.ToTensor()(image)  # (C, H, W)

    im = img_tensor.permute(1, 2, 0)  # (H, W, C) for OpenCV
    if im.shape != (MODEL_INPUT_SIZE, MODEL_INPUT_SIZE, 3):
        im = cv2.resize(im.cpu().numpy(), (MODEL_INPUT_SIZE, MODEL_INPUT_SIZE),
                        interpolation=cv2.INTER_LINEAR)
        img_tensor = torch.tensor(im).permute(2, 0, 1)
    img_tensor = img_tensor.to(device)

    with torch.no_grad():
        predictions = model([img_tensor])

    img_name = os.path.basename(img_path)
    lines = show_line_and_box(img_tensor, predictions, img_name=img_name, save_path=output_path)
    return {
        "image": img_name,
        "lines": lines,
    }


def _unique_save_path(output_dir, img_name, used_paths):
    """Return a ``<stem>.jpg`` path in ``output_dir`` that is not yet in ``used_paths``.

    Files like ``a.png`` and ``a.jpg`` would otherwise both map to ``a.jpg``.
    The later one would silently overwrite the earlier visualization. Collisions
    get a numeric suffix instead (``a_1.jpg``, ``a_2.jpg``, ...). ``used_paths``
    holds ``os.path.normcase``-normalized paths and is updated in place.
    """
    stem = os.path.splitext(img_name)[0]
    candidate = os.path.join(output_dir, stem + ".jpg")
    counter = 0
    while os.path.normcase(candidate) in used_paths:
        counter += 1
        candidate = os.path.join(output_dir, f"{stem}_{counter}.jpg")
    used_paths.add(os.path.normcase(candidate))
    return candidate


def process_folder(input_dir, output_dir, model, pt_path):
    """Run inference on every supported image in ``input_dir``.

    The model is loaded from ``pt_path`` and switched to eval mode. Each
    ``.jpg/.jpeg/.png`` file (case-insensitive, processed in sorted name order)
    gets an overlay written to ``output_dir``. All per-image results are
    written to ``output_dir/predictions.json``.
    """
    model = load_best_model(model, pt_path, device)
    model.eval()
    os.makedirs(output_dir, exist_ok=True)

    # Sort so the processing order and the JSON layout are reproducible
    # (os.listdir order is arbitrary).
    img_list = sorted(f for f in os.listdir(input_dir) if f.lower().endswith(SUPPORTED_EXTS))

    all_results = []
    used_save_paths = set()
    for img_name in img_list:
        img_path = os.path.join(input_dir, img_name)
        save_path = _unique_save_path(output_dir, img_name, used_save_paths)
        print(f"Processing {img_path}...")
        all_results.append(predict_and_save(img_path, model, save_path))

    # Save all results as a JSON file.
    json_path = os.path.join(output_dir, "predictions.json")
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2, ensure_ascii=False)
    print(f"处理完成,结果保存在: {json_path}")


def main():
    """Build the model and run it over the hard-coded input folder."""
    model = linenet_resnet50_fpn().to(device)
    input_dir = r"G:\python_ws_g\data\pcd2color_result\color_jpg"
    parent_dir = os.path.dirname(input_dir)
    output_dir = os.path.join(parent_dir, "a_predict_restnet50")
    pt_path = r"G:\python_ws_g\code\mulitivision汇总\转tiff\MultiVisionModels\weight\best_val.pth"
    process_folder(input_dir, output_dir, model, pt_path)


if __name__ == '__main__':
    main()