123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136 |
- import time
- import skimage
- from models.line_detect.postprocess import show_predict
- import os
- import torch
- from PIL import Image
- import matplotlib.pyplot as plt
- import numpy as np
- from models.line_detect.line_net import linenet_resnet50_fpn
- from torchvision import transforms
- from rtree import index
- import multiprocessing as mp
# Force the 'spawn' start method for multiprocessing, overriding any method
# that may already have been selected.
mp.set_start_method('spawn', force=True)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def load_best_model(model, save_path, device):
    """Load checkpoint weights into ``model`` if a checkpoint file exists.

    The checkpoint is read as a mapping with a ``'model_state_dict'`` entry.
    ``'epoch'`` and ``'loss'`` are used only for the log message, so a
    checkpoint that lacks them still loads and is reported as ``unknown``.

    Args:
        model: Object exposing ``load_state_dict``; updated in place.
        save_path: Path of the checkpoint file.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The same ``model`` object: with weights loaded when the file exists,
        unchanged otherwise.
    """
    if not os.path.exists(save_path):
        # Best effort: a missing checkpoint is reported, not raised.
        print(f"No saved model found at {save_path}")
        return model

    # NOTE(review): depending on the torch version / ``weights_only`` setting,
    # torch.load may unpickle arbitrary objects - only load trusted files.
    checkpoint = torch.load(save_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Metadata is informational only; do not fail after the weights loaded.
    epoch = checkpoint.get('epoch', 'unknown')
    loss = checkpoint.get('loss')
    loss_text = 'unknown' if loss is None else f"{loss:.4f}"
    print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss_text}")
    return model
def process_box(box, lines, scores):
    """Pick, for every box, the longest line segment lying fully inside it.

    Segment coordinates are rescaled by ``/ 128 * 512`` before the test.
    For each endpoint, coordinate index 1 is checked against ``box[0]`` /
    ``box[2]`` and coordinate index 0 against ``box[1]`` / ``box[3]``
    (presumably (y, x) endpoints against (x1, y1, x2, y2) boxes - confirm
    against the model output).

    Args:
        box: Iterable of 4-element boxes.
        lines: Tensor indexed as ``lines[0, j]`` -> one segment of two
            2-D endpoints (assumed shape ``[1, N, 2, 2]``).
        scores: Unused; kept for interface compatibility.

    Returns:
        ``(valid_lines, valid_scores)``: one rescaled ``(2, 2)`` array and its
        length per box that contains at least one segment of non-zero length.
        Boxes without such a segment contribute nothing. On equal lengths the
        segment with the lowest index wins.
    """
    valid_lines = []   # best segment per matching box
    valid_scores = []  # its length, used as the score

    # Convert and rescale all segments once instead of once per
    # (box, segment) pair.
    segments = lines[0].detach().cpu().numpy() / 128 * 512
    if segments.shape[0] == 0:
        return valid_lines, valid_scores

    lengths = np.linalg.norm(segments[:, 0] - segments[:, 1], axis=-1)
    coord0 = segments[:, :, 0]  # checked against box[1] .. box[3]
    coord1 = segments[:, :, 1]  # checked against box[0] .. box[2]

    for bounds in box:
        # A segment counts only when BOTH endpoints are inside the box.
        inside = (
            (coord1 >= bounds[0]).all(axis=1)
            & (coord1 <= bounds[2]).all(axis=1)
            & (coord0 >= bounds[1]).all(axis=1)
            & (coord0 <= bounds[3]).all(axis=1)
        )
        candidates = np.where(inside, lengths, 0.0)
        best = int(candidates.argmax())  # first index on ties
        # Strictly positive: zero-length segments are never selected.
        if candidates[best] > 0.0:
            valid_lines.append(segments[best])
            valid_scores.append(lengths[best])

    return valid_lines, valid_scores
def box_line_optimized_parallel(pred):
    """Attach to each box prediction the best line found inside its boxes.

    ``pred`` is read as a sequence whose last element provides
    ``['wires']['lines']`` and ``['wires']['score']`` and whose earlier
    elements each provide a ``'boxes'`` tensor. Every box entry is matched
    against the lines by ``process_box``.

    Args:
        pred: Model output as described above. Matching entries are mutated
            in place: ``'line'`` and ``'line_score'`` tensors are added.

    Returns:
        List of the box entries of ``pred`` for which at least one line was
        matched; empty when there are no box entries or no matches.
    """
    wires = pred[-1]['wires']
    # Worker processes receive pickled arguments; send CPU copies so nothing
    # depends on sharing device tensors across processes.
    lines = wires['lines'].detach().cpu()        # assumed [1, N, 2, 2] - TODO confirm
    scores = wires['score'][0].detach().cpu()    # assumed [N] - TODO confirm

    boxes = [entry['boxes'].detach().cpu().numpy() for entry in pred[:-1]]
    if not boxes:
        # mp.Pool(processes=0) raises ValueError; nothing to match anyway.
        return []

    tasks = [(box, lines, scores) for box in boxes]
    if len(tasks) == 1:
        # A one-task pool only adds process start-up cost; run it inline.
        results = [process_box(*tasks[0])]
    else:
        num_processes = min(mp.cpu_count(), len(tasks))
        with mp.Pool(processes=num_processes) as pool:
            results = pool.starmap(process_box, tasks)

    # results[k] belongs to pred[k]; zip stops before the trailing wires entry.
    filtered_pred = []
    for entry, (valid_lines, valid_scores) in zip(pred, results):
        if valid_lines:
            # Stack first: building a tensor from a list of arrays is slow.
            entry['line'] = torch.tensor(np.stack(valid_lines))
            entry['line_score'] = torch.tensor(valid_scores)
            filtered_pred.append(entry)
    return filtered_pred
def predict(pt_path, model, img):
    """Run the detector on one image and visualise the matched lines.

    Loads weights via ``load_best_model``, converts the image to a tensor,
    resizes it to 512x512, runs the model without gradients, matches lines to
    boxes with ``box_line_optimized_parallel`` and passes the result to
    ``show_predict``. Timing and shape information is printed along the way.

    Args:
        pt_path: Checkpoint path forwarded to ``load_best_model``.
        model: Model to load weights into and run.
        img: Image file path (``str``), or an object that
            ``transforms.ToTensor()`` accepts directly.

    Returns:
        None. When no box has a matching line, visualisation is skipped.
    """
    model = load_best_model(model, pt_path, device)
    model.eval()

    if isinstance(img, str):
        # Release the file handle as soon as the pixels are decoded.
        with Image.open(img) as opened:
            img = opened.convert("RGB")

    img_tensor = transforms.ToTensor()(img)
    # Channel-first -> channel-last for the resize, then back again.
    im = img_tensor.permute(1, 2, 0)
    im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512))
    img_tensor = torch.tensor(im_resized).permute(2, 0, 1)

    with torch.no_grad():
        t_start = time.time()
        predictions = model([img_tensor.to(device)])
        t_end = time.time()
    print(f'Prediction used: {t_end - t_start:.4f} seconds')

    boxes = predictions[0]['boxes'].shape
    lines = predictions[-1]['wires']['lines'].shape
    lines_scores = predictions[-1]['wires']['score'].shape
    print(f'Predictions - boxes: {boxes}, lines: {lines}, lines_scores: {lines_scores}')

    t_start = time.time()
    pred = box_line_optimized_parallel(predictions)
    t_end = time.time()
    print(f'Matched boxes and lines used: {t_end - t_start:.4f} seconds')

    # Nothing to draw when no box received a line.
    if not pred:
        print("No valid predictions found. Skipping visualization.")
        return

    # NOTE(review): t_start here is the start of the matching step, not of
    # inference - confirm which one show_predict expects.
    show_predict(img_tensor, pred, t_start)
if __name__ == '__main__':
    # Wall-clock start of the whole run; the total is reported at the end.
    t_start = time.time()
    print(f'Start to predict: {t_start}')

    # Checkpoint and input image are hard-coded, machine-specific paths.
    pt_path = r'D:\python\PycharmProjects\20250214\weight\resnet50_best_e100.pth'
    # Alternative inputs kept for reference:
    # img_path = f'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-41-38_SaveImage.png'  # workpiece image
    # img_path = f'D:\python\PycharmProjects\data\images/train/00558656_3.png'  # wireframe image
    img_path = r'C:\Users\m2337\Desktop\9.jpg'

    model = linenet_resnet50_fpn().to(device)
    predict(pt_path, model, img_path)

    elapsed = time.time() - t_start
    print(f'Total prediction time: {elapsed:.4f} seconds')
|