| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119 |
- import os
- import cv2
- import torch
- import json
- import numpy as np
- from PIL import Image
- import matplotlib.pyplot as plt
- from torchvision import transforms
- from models.line_detect.line_net import linenet_resnet50_fpn
- from models.wirenet.postprocess import postprocess
# Inference device shared by the functions below: the GPU when CUDA is
# available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def load_best_model(model, save_path, device):
    """Load checkpoint weights into ``model`` when a checkpoint file exists.

    Args:
        model: Module whose ``load_state_dict`` receives the stored weights.
        save_path: Path of the checkpoint file. The checkpoint is indexed
            with the ``'model_state_dict'`` key.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The same ``model`` object. If ``save_path`` does not exist, a message
        is printed and the model is returned with its weights unchanged.
    """
    if not os.path.exists(save_path):
        # NOTE(review): callers continue with whatever weights the model
        # already has; no exception is raised for a missing checkpoint.
        print(f"No saved model found at {save_path}")
        return model

    state_dict = torch.load(save_path, map_location=device)['model_state_dict']
    model.load_state_dict(state_dict)
    return model
def show_line_and_box(img_tensor, pred, img_name="", save_path=None):
    """Draw predicted boxes and line segments on an image and return the segments.

    Args:
        img_tensor: Image tensor; it is permuted with ``(1, 2, 0)`` before
            being handed to ``imshow`` (assumes a channels-first image —
            TODO confirm against the caller).
        pred: Model output. ``pred[0]['boxes']`` rows are unpacked as
            ``x0, y0, x1, y1``; ``pred[-1]['wires']`` supplies ``'lines'``
            and ``'score'``, of which only the first entry is used.
        img_name: Not used by this function; kept so existing callers that
            pass it keep working.
        save_path: When truthy, the figure is written there and closed.
            Otherwise the figure is shown with ``plt.show()``.

    Returns:
        A list of ``{"start": [x, y], "end": [x, y]}`` dicts, one per
        post-processed line, with coordinates rounded to two decimals.
    """
    im = img_tensor.permute(1, 2, 0).cpu().numpy()
    boxes = pred[0]['boxes'].cpu().numpy()
    # Line coordinates are rescaled by 512 / 128 (presumably from a 128-wide
    # prediction grid to a 512-pixel image — TODO confirm).
    lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
    scores = pred[-1]['wires']['score'].cpu().numpy()[0]

    # Post-process the lines; the threshold handed over is 1% of the image
    # diagonal. The returned scores are not needed here.
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    lines, _ = postprocess(lines, scores, diag * 0.01, 0, False)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)

    for x0, y0, x1, y1 in boxes:
        ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0,
                                   fill=False, edgecolor='lime', linewidth=1))

    # Index 1 of each endpoint is plotted on the x axis and index 0 on the
    # y axis, i.e. the endpoints are treated as (y, x) pairs.
    for a, b in lines:
        xs = [a[1], b[1]]
        ys = [a[0], b[0]]
        ax.plot(xs, ys, c='red', linewidth=1)
        ax.scatter(xs, ys, c='red', s=3)

    if save_path:
        try:
            fig.savefig(save_path, bbox_inches='tight')
        finally:
            # Release the figure even when saving fails, so a batch run
            # does not accumulate open figures.
            plt.close(fig)
    else:
        plt.show()

    # Report the segments as [x, y] endpoints.
    return [
        {
            "start": [round(float(a[1]), 2), round(float(a[0]), 2)],
            "end": [round(float(b[1]), 2), round(float(b[0]), 2)],
        }
        for a, b in lines
    ]
def predict_and_save(img_path, model, output_path):
    """Run the model on one image, save the visualization, and return its lines.

    The image is loaded as RGB, converted to a tensor, resized to 512x512
    with bilinear interpolation when it is not already that size, and moved
    to the module-level ``device`` before inference.

    Args:
        img_path: Path of the image to process.
        model: Callable invoked as ``model([img_tensor])`` under
            ``torch.no_grad()``.
        output_path: Forwarded to ``show_line_and_box`` as ``save_path``.

    Returns:
        A dict with ``"image"`` (the file's base name) and ``"lines"`` (the
        value returned by ``show_line_and_box``).
    """
    rgb_image = Image.open(img_path).convert("RGB")
    img_tensor = transforms.ToTensor()(rgb_image)

    # Work in height/width/channel layout for the size check and for OpenCV.
    hwc = img_tensor.permute(1, 2, 0)
    if hwc.shape != (512, 512, 3):
        resized = cv2.resize(hwc.cpu().numpy(), (512, 512),
                             interpolation=cv2.INTER_LINEAR)
        img_tensor = torch.tensor(resized).permute(2, 0, 1)

    img_tensor = img_tensor.to(device)
    with torch.no_grad():
        predictions = model([img_tensor])

    img_name = os.path.basename(img_path)
    detected_lines = show_line_and_box(
        img_tensor, predictions, img_name=img_name, save_path=output_path
    )
    return {"image": img_name, "lines": detected_lines}
def process_folder(input_dir, output_dir, model, pt_path):
    """Run prediction on every supported image in a folder and save the results.

    Loads weights from ``pt_path`` into ``model``, switches it to eval mode,
    then processes each ``.jpg`` / ``.jpeg`` / ``.png`` file (extension
    matched case-insensitively) found directly in ``input_dir``. For each
    image a visualization named ``<stem>.jpg`` is written to ``output_dir``,
    and all per-image results are dumped to ``predictions.json`` there.

    Args:
        input_dir: Directory that is listed for input images.
        output_dir: Directory for the outputs; created if it does not exist.
        model: Model to load weights into and run.
        pt_path: Checkpoint path handed to ``load_best_model``.
    """
    model = load_best_model(model, pt_path, device)
    model.eval()
    os.makedirs(output_dir, exist_ok=True)

    supported_exts = ('.jpg', '.jpeg', '.png')
    # os.listdir returns entries in arbitrary order; sort them so that the
    # processing log and the order of entries in predictions.json are
    # reproducible between runs and machines.
    img_list = sorted(
        f for f in os.listdir(input_dir) if f.lower().endswith(supported_exts)
    )

    all_results = []
    for img_name in img_list:
        img_path = os.path.join(input_dir, img_name)
        # NOTE(review): inputs that share a stem (e.g. a.png and a.jpg) map
        # to the same output file, so the later one overwrites the earlier.
        save_path = os.path.join(output_dir, os.path.splitext(img_name)[0] + ".jpg")
        print(f"Processing {img_path}...")
        all_results.append(predict_and_save(img_path, model, save_path))

    # Save all results as a JSON file.
    json_path = os.path.join(output_dir, "predictions.json")
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2, ensure_ascii=False)
    print(f"处理完成,结果保存在: {json_path}")
def main():
    """Build the model and run folder prediction with the hard-coded paths.

    The output directory ``a_predict_restnet50`` is placed next to the input
    directory, i.e. inside the input directory's parent.
    """
    input_dir = r"G:\python_ws_g\data\pcd2color_result\color_jpg"
    pt_path = r"G:\python_ws_g\code\mulitivision汇总\转tiff\MultiVisionModels\weight\best_val.pth"
    output_dir = os.path.join(os.path.dirname(input_dir), "a_predict_restnet50")

    model = linenet_resnet50_fpn().to(device)
    process_folder(input_dir, output_dir, model, pt_path)


if __name__ == '__main__':
    main()
|