|
|
@@ -10,7 +10,7 @@ import matplotlib as mpl
|
|
|
import numpy as np
|
|
|
from models.line_detect.line_net import linenet_resnet50_fpn
|
|
|
from torchvision import transforms
|
|
|
-
|
|
|
+from rtree import index
|
|
|
# from models.wirenet.postprocess import postprocess
|
|
|
from models.wirenet.postprocess import postprocess
|
|
|
|
|
|
@@ -30,11 +30,26 @@ def load_best_model(model, save_path, device):
|
|
|
print(f"No saved model found at {save_path}")
|
|
|
return model
|
|
|
|
|
|
-def box_line_(pred):
|
|
|
- for idx, box_ in enumerate(pred[0:-1]):
|
|
|
- box = box_['boxes'] # 是一个tensor
|
|
|
- line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
|
|
|
- score = pred[-1]['wires']['score'][idx]
|
|
|
+
|
|
|
+def box_line_optimized(pred):
|
|
|
+ # 创建R-tree索引
|
|
|
+ idx = index.Index()
|
|
|
+
|
|
|
+ # 将所有线段添加到R-tree中
|
|
|
+ lines = pred[-1]['wires']['lines'] # 形状为[1, 2500, 2, 2]
|
|
|
+ scores = pred[-1]['wires']['score'][0] # 假设形状为[2500]
|
|
|
+
|
|
|
+ # 提取并处理所有线段
|
|
|
+ for idx_line in range(lines.shape[1]): # 遍历2500条线段
|
|
|
+ line_tensor = lines[0, idx_line].cpu().numpy() / 128 * 512 # 转换为numpy数组并调整比例
|
|
|
+ x_min = float(min(line_tensor[0][0], line_tensor[1][0]))
|
|
|
+ y_min = float(min(line_tensor[0][1], line_tensor[1][1]))
|
|
|
+ x_max = float(max(line_tensor[0][0], line_tensor[1][0]))
|
|
|
+ y_max = float(max(line_tensor[0][1], line_tensor[1][1]))
|
|
|
+ idx.insert(idx_line, (x_min, y_min, x_max, y_max))
|
|
|
+
|
|
|
+ for idx_box, box_ in enumerate(pred[0:-1]):
|
|
|
+ box = box_['boxes'].cpu().numpy() # 确保将张量转换为numpy数组
|
|
|
line_ = []
|
|
|
score_ = []
|
|
|
|
|
|
@@ -42,24 +57,61 @@ def box_line_(pred):
|
|
|
score_max = 0.0
|
|
|
tmp = [[0.0, 0.0], [0.0, 0.0]]
|
|
|
|
|
|
- for j in range(len(line)):
|
|
|
- if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
|
|
|
- line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
|
|
|
- line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
|
|
|
- line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
|
|
|
+ # 获取与当前box可能相交的所有线段
|
|
|
+ possible_matches = list(idx.intersection((i[0], i[1], i[2], i[3])))
|
|
|
+
|
|
|
+ for j in possible_matches:
|
|
|
+ line_j = lines[0, j].cpu().numpy() / 128 * 512 # 调整比例
|
|
|
+ if (line_j[0][0] >= i[0] and line_j[1][0] >= i[0] and
|
|
|
+ line_j[0][0] <= i[2] and line_j[1][0] <= i[2] and
|
|
|
+ line_j[0][1] >= i[1] and line_j[1][1] >= i[1] and
|
|
|
+ line_j[0][1] <= i[3] and line_j[1][1] <= i[3]):
|
|
|
+
|
|
|
+ if scores[j] > score_max:
|
|
|
+ tmp = line_j
|
|
|
+ score_max = scores[j]
|
|
|
|
|
|
- if score[j] > score_max:
|
|
|
- tmp = line[j]
|
|
|
- score_max = score[j]
|
|
|
line_.append(tmp)
|
|
|
score_.append(score_max)
|
|
|
+
|
|
|
processed_list = torch.tensor(line_)
|
|
|
- pred[idx]['line'] = processed_list
|
|
|
+ pred[idx_box]['line'] = processed_list
|
|
|
|
|
|
processed_s_list = torch.tensor(score_)
|
|
|
- pred[idx]['line_score'] = processed_s_list
|
|
|
+ pred[idx_box]['line_score'] = processed_s_list
|
|
|
+
|
|
|
return pred
|
|
|
|
|
|
+# def box_line_(pred):
|
|
|
+# for idx, box_ in enumerate(pred[0:-1]):
|
|
|
+# box = box_['boxes'] # 是一个tensor
|
|
|
+# line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
|
|
|
+# score = pred[-1]['wires']['score'][idx]
|
|
|
+# line_ = []
|
|
|
+# score_ = []
|
|
|
+#
|
|
|
+# for i in box:
|
|
|
+# score_max = 0.0
|
|
|
+# tmp = [[0.0, 0.0], [0.0, 0.0]]
|
|
|
+#
|
|
|
+# for j in range(len(line)):
|
|
|
+# if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
|
|
|
+# line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
|
|
|
+# line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
|
|
|
+# line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
|
|
|
+#
|
|
|
+# if score[j] > score_max:
|
|
|
+# tmp = line[j]
|
|
|
+# score_max = score[j]
|
|
|
+# line_.append(tmp)
|
|
|
+# score_.append(score_max)
|
|
|
+# processed_list = torch.tensor(line_)
|
|
|
+# pred[idx]['line'] = processed_list
|
|
|
+#
|
|
|
+# processed_s_list = torch.tensor(score_)
|
|
|
+# pred[idx]['line_score'] = processed_s_list
|
|
|
+# return pred
|
|
|
+
|
|
|
|
|
|
def predict(pt_path, model, img):
|
|
|
model = load_best_model(model, pt_path, device)
|
|
|
@@ -73,10 +125,19 @@ def predict(pt_path, model, img):
|
|
|
img_tensor = transform(img)
|
|
|
|
|
|
with torch.no_grad():
|
|
|
+ t_start = time.time()
|
|
|
predictions = model([img_tensor.to(device)])
|
|
|
- # print(predictions)
|
|
|
-
|
|
|
- pred = box_line_(predictions)
|
|
|
+ t_end=time.time()
|
|
|
+ print(f'predict used:{t_end-t_start}')
|
|
|
+ # print(f'predictions:{predictions}')
|
|
|
+ boxes=predictions[0]['boxes'].shape
|
|
|
+ lines=predictions[-1]['wires']['lines'].shape
|
|
|
+ lines_scores=predictions[-1]['wires']['score'].shape
|
|
|
+ print(f'predictions boxes:{boxes},lines:{lines},lines_scores:{lines_scores}')
|
|
|
+ t_start=time.time()
|
|
|
+ pred = box_line_optimized(predictions)
|
|
|
+ t_end=time.time()
|
|
|
+ print(f'matched boxes and lines used:{t_end - t_start}')
|
|
|
# print(f'pred:{pred[0]}')
|
|
|
show_predict(img_tensor, pred, t_start)
|
|
|
|
|
|
@@ -85,9 +146,8 @@ if __name__ == '__main__':
|
|
|
t_start = time.time()
|
|
|
print(f'start to predict:{t_start}')
|
|
|
model = linenet_resnet50_fpn().to(device)
|
|
|
- pt_path = r'D:\python\PycharmProjects\lcnn-master\lcnn_\20250212\MultiVisionModels\models\line_detect\linenet_wts\resnet50_best_e8.pth'
|
|
|
- img_path = f'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-41-38_SaveImage.png' # 工件图
|
|
|
- # img_path = f'D:\python\PycharmProjects\data\images/train/00558656_3.png' # wireframe图
|
|
|
+ pt_path = r"F:\BaiduNetdiskDownload\resnet50_best_e8.pth"
|
|
|
+ img_path = r"I:\datasets\wirenet_1000\images\val\00037040_0.png"
|
|
|
predict(pt_path, model, img_path)
|
|
|
t_end = time.time()
|
|
|
- print(f'predict used:{t_end - t_start}')
|
|
|
+ # print(f'predict used:{t_end - t_start}')
|