Browse Source

添加box_line_optimized(有bug) — adds an R-tree-accelerated box/line matcher to models/line_detect/predict2.py; known to be buggy

RenLiqiang 10 months ago
parent
commit
5a396c22b5
2 changed files with 85 additions and 25 deletions
  1. 2 2
      .gitignore
  2. 83 23
      models/line_detect/predict2.py

+ 2 - 2
.gitignore

@@ -1,5 +1,7 @@
 .idea
 *.pt
+*.pth
+
 *.log
 *.onnx
 runs
@@ -29,5 +31,3 @@ checkpoint
 
 __pycache__
 train_results
-
-models/line_detect/linenet_wts

+ 83 - 23
models/line_detect/predict2.py

@@ -10,7 +10,7 @@ import matplotlib as mpl
 import numpy as np
 from models.line_detect.line_net import linenet_resnet50_fpn
 from torchvision import transforms
-
+from rtree import index
 # from models.wirenet.postprocess import postprocess
 from models.wirenet.postprocess import postprocess
 
@@ -30,11 +30,26 @@ def load_best_model(model, save_path, device):
         print(f"No saved model found at {save_path}")
     return model
 
-def box_line_(pred):
-    for idx, box_ in enumerate(pred[0:-1]):
-        box = box_['boxes']  # 是一个tensor
-        line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
-        score = pred[-1]['wires']['score'][idx]
+
+def box_line_optimized(pred):
+    # 创建R-tree索引
+    idx = index.Index()
+
+    # 将所有线段添加到R-tree中
+    lines = pred[-1]['wires']['lines']  # 形状为[1, 2500, 2, 2]
+    scores = pred[-1]['wires']['score'][0]  # 假设形状为[2500]
+
+    # 提取并处理所有线段
+    for idx_line in range(lines.shape[1]):  # 遍历2500条线段
+        line_tensor = lines[0, idx_line].cpu().numpy() / 128 * 512  # 转换为numpy数组并调整比例
+        x_min = float(min(line_tensor[0][0], line_tensor[1][0]))
+        y_min = float(min(line_tensor[0][1], line_tensor[1][1]))
+        x_max = float(max(line_tensor[0][0], line_tensor[1][0]))
+        y_max = float(max(line_tensor[0][1], line_tensor[1][1]))
+        idx.insert(idx_line, (x_min, y_min, x_max, y_max))
+
+    for idx_box, box_ in enumerate(pred[0:-1]):
+        box = box_['boxes'].cpu().numpy()  # 确保将张量转换为numpy数组
         line_ = []
         score_ = []
 
@@ -42,24 +57,61 @@ def box_line_(pred):
             score_max = 0.0
             tmp = [[0.0, 0.0], [0.0, 0.0]]
 
-            for j in range(len(line)):
-                if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
-                        line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
-                        line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
-                        line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
+            # 获取与当前box可能相交的所有线段
+            possible_matches = list(idx.intersection((i[0], i[1], i[2], i[3])))
+
+            for j in possible_matches:
+                line_j = lines[0, j].cpu().numpy() / 128 * 512  # 调整比例
+                if (line_j[0][0] >= i[0] and line_j[1][0] >= i[0] and
+                        line_j[0][0] <= i[2] and line_j[1][0] <= i[2] and
+                        line_j[0][1] >= i[1] and line_j[1][1] >= i[1] and
+                        line_j[0][1] <= i[3] and line_j[1][1] <= i[3]):
+
+                    if scores[j] > score_max:
+                        tmp = line_j
+                        score_max = scores[j]
 
-                    if score[j] > score_max:
-                        tmp = line[j]
-                        score_max = score[j]
             line_.append(tmp)
             score_.append(score_max)
+
         processed_list = torch.tensor(line_)
-        pred[idx]['line'] = processed_list
+        pred[idx_box]['line'] = processed_list
 
         processed_s_list = torch.tensor(score_)
-        pred[idx]['line_score'] = processed_s_list
+        pred[idx_box]['line_score'] = processed_s_list
+
     return pred
 
+# def box_line_(pred):
+#     for idx, box_ in enumerate(pred[0:-1]):
+#         box = box_['boxes']  # 是一个tensor
+#         line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
+#         score = pred[-1]['wires']['score'][idx]
+#         line_ = []
+#         score_ = []
+#
+#         for i in box:
+#             score_max = 0.0
+#             tmp = [[0.0, 0.0], [0.0, 0.0]]
+#
+#             for j in range(len(line)):
+#                 if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
+#                         line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
+#                         line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
+#                         line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
+#
+#                     if score[j] > score_max:
+#                         tmp = line[j]
+#                         score_max = score[j]
+#             line_.append(tmp)
+#             score_.append(score_max)
+#         processed_list = torch.tensor(line_)
+#         pred[idx]['line'] = processed_list
+#
+#         processed_s_list = torch.tensor(score_)
+#         pred[idx]['line_score'] = processed_s_list
+#     return pred
+
 
 def predict(pt_path, model, img):
     model = load_best_model(model, pt_path, device)
@@ -73,10 +125,19 @@ def predict(pt_path, model, img):
     img_tensor = transform(img)
 
     with torch.no_grad():
+        t_start = time.time()
         predictions = model([img_tensor.to(device)])
-        # print(predictions)
-
-    pred = box_line_(predictions)
+        t_end=time.time()
+        print(f'predict used:{t_end-t_start}')
+        # print(f'predictions:{predictions}')
+        boxes=predictions[0]['boxes'].shape
+        lines=predictions[-1]['wires']['lines'].shape
+        lines_scores=predictions[-1]['wires']['score'].shape
+        print(f'predictions boxes:{boxes},lines:{lines},lines_scores:{lines_scores}')
+    t_start=time.time()
+    pred = box_line_optimized(predictions)
+    t_end=time.time()
+    print(f'matched boxes and lines used:{t_end - t_start}')
     # print(f'pred:{pred[0]}')
     show_predict(img_tensor, pred, t_start)
 
@@ -85,9 +146,8 @@ if __name__ == '__main__':
     t_start = time.time()
     print(f'start to predict:{t_start}')
     model = linenet_resnet50_fpn().to(device)
-    pt_path = r'D:\python\PycharmProjects\lcnn-master\lcnn_\20250212\MultiVisionModels\models\line_detect\linenet_wts\resnet50_best_e8.pth'
-    img_path = f'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-41-38_SaveImage.png'  # 工件图
-    # img_path = f'D:\python\PycharmProjects\data\images/train/00558656_3.png'  # wireframe图
+    pt_path = r"F:\BaiduNetdiskDownload\resnet50_best_e8.pth"
+    img_path = r"I:\datasets\wirenet_1000\images\val\00037040_0.png"
     predict(pt_path, model, img_path)
     t_end = time.time()
-    print(f'predict used:{t_end - t_start}')
+    # print(f'predict used:{t_end - t_start}')