浏览代码

predict2_box内无线段时,选box内点组成线段最长的 两个点组成的线段返回

xue50 8 月之前
父节点
当前提交
21568f220b
共有 1 个文件被更改,包括 9 次插入5 次删除
  1. 9 5
      models/line_detect/predict2.py

+ 9 - 5
models/line_detect/predict2.py

@@ -188,7 +188,7 @@ def load_best_model(model, save_path, device):
     return model
 
 
-def box_line_(imgs, pred, length=False):    # 默认置信度
+def box_line_(imgs, pred, length=False):  # 默认置信度
     im = imgs.permute(1, 2, 0).cpu().numpy()
     line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
     line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
@@ -221,13 +221,14 @@ def box_line_(imgs, pred, length=False):    # 默认置信度
         pred[idx]['line_score'] = processed_s_list
     return pred
 
+
 # box内无线段时,选box内点组成线段最长的 两个点组成的线段返回
-def box_line1(imgs, pred, length=False):    # 默认置信度
+def box_line1(imgs, pred):  # 默认置信度
     im = imgs.permute(1, 2, 0).cpu().numpy()
     line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
     line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
 
-    points=pred[-1]['wires']['juncs'].cpu().numpy()[0]/ 128 * 512
+    points = pred[-1]['wires']['juncs'].cpu().numpy()[0] / 128 * 512
 
     diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
     line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)
@@ -280,6 +281,7 @@ def box_line1(imgs, pred, length=False):    # 默认置信度
         pred[idx]['line_score'] = processed_s_list
     return pred
 
+
 def show_box(imgs, pred, t_start):
     col = [
         '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
@@ -378,6 +380,7 @@ def show_predict(imgs, pred, t_start):
 
     plt.show()
 
+
 def show_line(imgs, pred, t_start):
     im = imgs.permute(1, 2, 0)
     line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
@@ -408,6 +411,7 @@ def show_line(imgs, pred, t_start):
 
     plt.show()
 
+
 def predict(pt_path, model, img):
     model = load_best_model(model, pt_path, device)
 
@@ -436,12 +440,12 @@ def predict(pt_path, model, img):
     # show_line_optimized(img_, predictions, t_start)   # 只画线
     show_line(img_, predictions, t_end1)
     t_end2 = time.time()
-    show_box(img_, predictions, t_end2)   # 只画kuang
+    show_box(img_, predictions, t_end2)  # 只画kuang
     # show_box_or_line(img_, predictions, show_line=True, show_box=True)   # 参数确定画什么
     # show_box_and_line(img_, predictions, show_line=True, show_box=True)  # 一起画 1x2 2张图
 
     t_start = time.time()
-    pred = box_line_(img_, predictions)
+    pred = box_line1(img_, predictions)
     t_end = time.time()
     print(f'Matched boxes and lines used: {t_end - t_start:.4f} seconds')