Bladeren bron

修改lines_iou_loss为并行计算

RenLiqiang 5 maanden geleden
bovenliggende
commit
866fc733c3
1 gewijzigde bestanden met toevoegingen van 13 en 15 verwijderingen
  1. 13 15
      models/line_detect/loi_heads.py

+ 13 - 15
models/line_detect/loi_heads.py

@@ -548,44 +548,44 @@ def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
 
 def lines_to_boxes(lines, img_size=511):
     """
-    输入:
-        lines: Tensor of shape (N, 2, 2),表示 N 条线段,每个线段有两个端点 (x, y)
-        img_size: int,图像尺寸,用于 clamp 边界
+    输入:
+        lines: Tensor of shape (N, 2, 2),表示 N 条线段,每个线段有两个端点 (x, y)
+        img_size: int,图像尺寸,用于 clamp 边界
 
-    输出:
-        boxes: Tensor of shape (N, 4),表示 N 个包围盒 [x_min, y_min, x_max, y_max]
+    输出:
+        boxes: Tensor of shape (N, 4),表示 N 个包围盒 [x_min, y_min, x_max, y_max]
     """
-    # 提取所有线段的两个端点
+    # 提取所有线段的两个端点
     p1 = lines[:, 0]  # (N, 2)
     p2 = lines[:, 1]  # (N, 2)
 
-    # 每条线段的 x 和 y 坐标
+    # 每条线段的 x 和 y 坐标
     x_coords = torch.stack([p1[:, 0], p2[:, 0]], dim=1)  # (N, 2)
     y_coords = torch.stack([p1[:, 1], p2[:, 1]], dim=1)  # (N, 2)
 
-    # 计算包围盒边界
+    # 计算包围盒边界
     x_min = x_coords.min(dim=1).values
     y_min = y_coords.min(dim=1).values
     x_max = x_coords.max(dim=1).values
     y_max = y_coords.max(dim=1).values
 
-    # 扩展边界并限制在图像范围内
+    # 扩展边界并限制在图像范围内
     x_min = (x_min - 1).clamp(min=0, max=img_size)
     y_min = (y_min - 1).clamp(min=0, max=img_size)
     x_max = (x_max + 1).clamp(min=0, max=img_size)
     y_max = (y_max + 1).clamp(min=0, max=img_size)
 
-    # 合成包围盒
+    # 合成包围盒
     boxes = torch.stack([x_min, y_min, x_max, y_max], dim=1)  # (N, 4)
     return boxes
 
 def box_iou_pairwise(box1, box2):
     """
-    输入:
+    输入:
         box1: shape (N, 4)
         box2: shape (M, 4)
-    输出:
-        ious: shape (min(N, M), ), 只计算 i = j 的配对
+    输出:
+        ious: shape (min(N, M), ), 只计算 i = j 的配对
     """
     N = min(len(box1), len(box2))
     lt = torch.max(box1[:N, :2], box2[:N, :2])  # 左上角
@@ -614,11 +614,9 @@ def line_iou_loss(x, boxes, gt_lines, matched_idx, img_size=511):
         if len(pred_lines) == 0 or len(gt_line_points) == 0:
             continue
 
-        # ==== 使用新的批量版 lines_to_boxes ====
         pred_boxes = lines_to_boxes(pred_lines, img_size)
         gt_boxes = lines_to_boxes(gt_line_points, img_size)
 
-        # ==== 成对 IoU 计算 ====
         ious = box_iou_pairwise(pred_boxes, gt_boxes)
 
         loss = 1.0 - ious