|
@@ -0,0 +1,136 @@
|
|
|
|
+import time
|
|
|
|
+import skimage
|
|
|
|
+from models.line_detect.postprocess import show_predict
|
|
|
|
+import os
|
|
|
|
+import torch
|
|
|
|
+from PIL import Image
|
|
|
|
+import matplotlib.pyplot as plt
|
|
|
|
+import numpy as np
|
|
|
|
+from models.line_detect.line_net import linenet_resnet50_fpn
|
|
|
|
+from torchvision import transforms
|
|
|
|
+from rtree import index
|
|
|
|
+import multiprocessing as mp
|
|
|
|
+
|
|
|
|
+# 设置多进程启动方式为 'spawn'
|
|
|
|
+mp.set_start_method('spawn', force=True)
|
|
|
|
+
|
|
|
|
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
+
|
|
|
|
+def load_best_model(model, save_path, device):
|
|
|
|
+ if os.path.exists(save_path):
|
|
|
|
+ checkpoint = torch.load(save_path, map_location=device)
|
|
|
|
+ model.load_state_dict(checkpoint['model_state_dict'])
|
|
|
|
+ epoch = checkpoint['epoch']
|
|
|
|
+ loss = checkpoint['loss']
|
|
|
|
+ print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
|
|
|
|
+ else:
|
|
|
|
+ print(f"No saved model found at {save_path}")
|
|
|
|
+ return model
|
|
|
|
+
|
|
|
|
+def process_box(box, lines, scores):
|
|
|
|
+ valid_lines = [] # 存储有效的线段
|
|
|
|
+ valid_scores = [] # 存储有效的分数
|
|
|
|
+
|
|
|
|
+ for i in box:
|
|
|
|
+ best_line = None
|
|
|
|
+ max_length = 0.0
|
|
|
|
+
|
|
|
|
+ # 遍历所有线段
|
|
|
|
+ for j in range(lines.shape[1]):
|
|
|
|
+ line_j = lines[0, j].cpu().numpy() / 128 * 512
|
|
|
|
+ # 检查线段是否完全在box内
|
|
|
|
+ if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and
|
|
|
|
+ line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
|
|
|
|
+ line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
|
|
|
|
+ line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
|
|
|
|
+ # 计算线段长度
|
|
|
|
+ length = np.linalg.norm(line_j[0] - line_j[1])
|
|
|
|
+ if length > max_length:
|
|
|
|
+ best_line = line_j
|
|
|
|
+ max_length = length
|
|
|
|
+
|
|
|
|
+ # 如果找到有效的线段,则添加到结果中
|
|
|
|
+ if best_line is not None:
|
|
|
|
+ valid_lines.append(best_line)
|
|
|
|
+ valid_scores.append(max_length) # 使用线段长度作为分数
|
|
|
|
+
|
|
|
|
+ return valid_lines, valid_scores
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def box_line_optimized_parallel(pred):
|
|
|
|
+ # 提取所有线段和分数
|
|
|
|
+ lines = pred[-1]['wires']['lines'] # 形状为[1, 2500, 2, 2]
|
|
|
|
+ scores = pred[-1]['wires']['score'][0] # 假设形状为[2500]
|
|
|
|
+
|
|
|
|
+ # 初始化存储结果的列表
|
|
|
|
+ filtered_pred = []
|
|
|
|
+
|
|
|
|
+ # 使用多进程并行处理每个box
|
|
|
|
+ boxes = [box_['boxes'].cpu().numpy() for box_ in pred[0:-1]] # 所有box
|
|
|
|
+ num_processes = min(mp.cpu_count(), len(boxes)) # 使用可用的核心数
|
|
|
|
+
|
|
|
|
+ with mp.Pool(processes=num_processes) as pool:
|
|
|
|
+ results = pool.starmap(
|
|
|
|
+ process_box,
|
|
|
|
+ [(box, lines, scores) for box in boxes]
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ # 更新预测结果
|
|
|
|
+ for idx_box, (valid_lines, valid_scores) in enumerate(results):
|
|
|
|
+ if valid_lines:
|
|
|
|
+ pred[idx_box]['line'] = torch.tensor(valid_lines)
|
|
|
|
+ pred[idx_box]['line_score'] = torch.tensor(valid_scores)
|
|
|
|
+ filtered_pred.append(pred[idx_box])
|
|
|
|
+
|
|
|
|
+ return filtered_pred
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def predict(pt_path, model, img):
|
|
|
|
+ model = load_best_model(model, pt_path, device)
|
|
|
|
+ model.eval()
|
|
|
|
+
|
|
|
|
+ if isinstance(img, str):
|
|
|
|
+ img = Image.open(img).convert("RGB")
|
|
|
|
+
|
|
|
|
+ transform = transforms.ToTensor()
|
|
|
|
+ img_tensor = transform(img)
|
|
|
|
+ im = img_tensor.permute(1, 2, 0) # [512, 512, 3]
|
|
|
|
+ im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512)) # (512, 512, 3)
|
|
|
|
+ img_tensor = torch.tensor(im_resized).permute(2, 0, 1)
|
|
|
|
+
|
|
|
|
+ with torch.no_grad():
|
|
|
|
+ t_start = time.time()
|
|
|
|
+ predictions = model([img_tensor.to(device)])
|
|
|
|
+ t_end = time.time()
|
|
|
|
+ print(f'Prediction used: {t_end - t_start:.4f} seconds')
|
|
|
|
+
|
|
|
|
+ boxes = predictions[0]['boxes'].shape
|
|
|
|
+ lines = predictions[-1]['wires']['lines'].shape
|
|
|
|
+ lines_scores = predictions[-1]['wires']['score'].shape
|
|
|
|
+ print(f'Predictions - boxes: {boxes}, lines: {lines}, lines_scores: {lines_scores}')
|
|
|
|
+
|
|
|
|
+ t_start = time.time()
|
|
|
|
+ pred = box_line_optimized_parallel(predictions)
|
|
|
|
+ t_end = time.time()
|
|
|
|
+ print(f'Matched boxes and lines used: {t_end - t_start:.4f} seconds')
|
|
|
|
+
|
|
|
|
+ # 检查 pred 是否为空
|
|
|
|
+ if not pred:
|
|
|
|
+ print("No valid predictions found. Skipping visualization.")
|
|
|
|
+ return
|
|
|
|
+
|
|
|
|
+ # 只绘制有效的线段
|
|
|
|
+ show_predict(img_tensor, pred, t_start)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+if __name__ == '__main__':
|
|
|
|
+ t_start = time.time()
|
|
|
|
+ print(f'Start to predict: {t_start}')
|
|
|
|
+ model = linenet_resnet50_fpn().to(device)
|
|
|
|
+ pt_path = r'D:\python\PycharmProjects\20250214\weight\resnet50_best_e100.pth'
|
|
|
|
+ # img_path = f'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-41-38_SaveImage.png' # 工件图
|
|
|
|
+ # img_path = f'D:\python\PycharmProjects\data\images/train/00558656_3.png' # wireframe图
|
|
|
|
+ img_path = r'C:\Users\m2337\Desktop\9.jpg'
|
|
|
|
+ predict(pt_path, model, img_path)
|
|
|
|
+ t_end = time.time()
|
|
|
|
+ print(f'Total prediction time: {t_end - t_start:.4f} seconds')
|