Pārlūkot izejas kodu

版本3 box和line绘图完成

xue50 3 mēneši atpakaļ
vecāks
revīzija
8d2e1ead8f
1 mainītis faili ar 136 papildinājumiem un 0 dzēšanām
  1. 136 0
      models/line_detect/predict_zjf.py

+ 136 - 0
models/line_detect/predict_zjf.py

@@ -0,0 +1,136 @@
+import time
+import skimage
+from models.line_detect.postprocess import show_predict
+import os
+import torch
+from PIL import Image
+import matplotlib.pyplot as plt
+import numpy as np
+from models.line_detect.line_net import linenet_resnet50_fpn
+from torchvision import transforms
+from rtree import index
+import multiprocessing as mp
+
+# 设置多进程启动方式为 'spawn'
+mp.set_start_method('spawn', force=True)
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+def load_best_model(model, save_path, device):
+    if os.path.exists(save_path):
+        checkpoint = torch.load(save_path, map_location=device)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        epoch = checkpoint['epoch']
+        loss = checkpoint['loss']
+        print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
+    else:
+        print(f"No saved model found at {save_path}")
+    return model
+
+def process_box(box, lines, scores):
+    valid_lines = []  # 存储有效的线段
+    valid_scores = []  # 存储有效的分数
+
+    for i in box:
+        best_line = None
+        max_length = 0.0
+
+        # 遍历所有线段
+        for j in range(lines.shape[1]):
+            line_j = lines[0, j].cpu().numpy() / 128 * 512
+            # 检查线段是否完全在box内
+            if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and
+                    line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
+                    line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
+                    line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
+                # 计算线段长度
+                length = np.linalg.norm(line_j[0] - line_j[1])
+                if length > max_length:
+                    best_line = line_j
+                    max_length = length
+
+        # 如果找到有效的线段,则添加到结果中
+        if best_line is not None:
+            valid_lines.append(best_line)
+            valid_scores.append(max_length)  # 使用线段长度作为分数
+
+    return valid_lines, valid_scores
+
+
+def box_line_optimized_parallel(pred):
+    # 提取所有线段和分数
+    lines = pred[-1]['wires']['lines']  # 形状为[1, 2500, 2, 2]
+    scores = pred[-1]['wires']['score'][0]  # 假设形状为[2500]
+
+    # 初始化存储结果的列表
+    filtered_pred = []
+
+    # 使用多进程并行处理每个box
+    boxes = [box_['boxes'].cpu().numpy() for box_ in pred[0:-1]]  # 所有box
+    num_processes = min(mp.cpu_count(), len(boxes))  # 使用可用的核心数
+
+    with mp.Pool(processes=num_processes) as pool:
+        results = pool.starmap(
+            process_box,
+            [(box, lines, scores) for box in boxes]
+        )
+
+    # 更新预测结果
+    for idx_box, (valid_lines, valid_scores) in enumerate(results):
+        if valid_lines:
+            pred[idx_box]['line'] = torch.tensor(valid_lines)
+            pred[idx_box]['line_score'] = torch.tensor(valid_scores)
+            filtered_pred.append(pred[idx_box])
+
+    return filtered_pred
+
+
+def predict(pt_path, model, img):
+    model = load_best_model(model, pt_path, device)
+    model.eval()
+
+    if isinstance(img, str):
+        img = Image.open(img).convert("RGB")
+
+    transform = transforms.ToTensor()
+    img_tensor = transform(img)
+    im = img_tensor.permute(1, 2, 0)  # [512, 512, 3]
+    im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512))  # (512, 512, 3)
+    img_tensor = torch.tensor(im_resized).permute(2, 0, 1)
+
+    with torch.no_grad():
+        t_start = time.time()
+        predictions = model([img_tensor.to(device)])
+        t_end = time.time()
+        print(f'Prediction used: {t_end - t_start:.4f} seconds')
+
+        boxes = predictions[0]['boxes'].shape
+        lines = predictions[-1]['wires']['lines'].shape
+        lines_scores = predictions[-1]['wires']['score'].shape
+        print(f'Predictions - boxes: {boxes}, lines: {lines}, lines_scores: {lines_scores}')
+
+    t_start = time.time()
+    pred = box_line_optimized_parallel(predictions)
+    t_end = time.time()
+    print(f'Matched boxes and lines used: {t_end - t_start:.4f} seconds')
+
+    # 检查 pred 是否为空
+    if not pred:
+        print("No valid predictions found. Skipping visualization.")
+        return
+
+    # 只绘制有效的线段
+    show_predict(img_tensor, pred, t_start)
+
+
+if __name__ == '__main__':
+    t_start = time.time()
+    print(f'Start to predict: {t_start}')
+    model = linenet_resnet50_fpn().to(device)
+    pt_path = r'D:\python\PycharmProjects\20250214\weight\resnet50_best_e100.pth'
+    # img_path = f'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-41-38_SaveImage.png'  # 工件图
+    # img_path = f'D:\python\PycharmProjects\data\images/train/00558656_3.png'  # wireframe图
+    img_path = r'C:\Users\m2337\Desktop\9.jpg'
+    predict(pt_path, model, img_path)
+    t_end = time.time()
+    print(f'Total prediction time: {t_end - t_start:.4f} seconds')