
Modify line_iou_loss to support parallel computation

lstrlq 5 months ago
parent
commit
0f19a393c5
1 changed file with 66 additions and 43 deletions
  1. models/line_detect/loi_heads.py  +66 -43

+ 66 - 43
models/line_detect/loi_heads.py
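The commit replaces the per-pair Python loops in line_iou_loss with batched tensor operations over whole sets of line segments. A minimal standalone sketch of how the new helpers compose (not part of the commit; the import path, tensor sizes, and random inputs are assumptions for illustration only):

import torch
# Assumed import: lines_to_boxes and box_iou_pairwise are the module-level helpers added in the diff below.
from models.line_detect.loi_heads import lines_to_boxes, box_iou_pairwise

pred_lines = torch.rand(8, 2, 2) * 511  # 8 predicted segments, endpoints (x, y) in pixel coordinates
gt_lines = torch.rand(8, 2, 2) * 511    # matched ground-truth segments, index-aligned with pred_lines

pred_boxes = lines_to_boxes(pred_lines, img_size=511)  # (8, 4) boxes around each predicted segment
gt_boxes = lines_to_boxes(gt_lines, img_size=511)      # (8, 4) boxes around each ground-truth segment
ious = box_iou_pairwise(pred_boxes, gt_boxes)          # (8,) IoU of box i with box i
loss = (1.0 - ious).mean()                             # scalar loss, as in the updated line_iou_loss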

@@ -544,44 +544,64 @@ def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
     line_loss=F.cross_entropy(line_logits,gs_heatmaps)
 
     return line_loss
-def line_to_box(line,img_size):
-    p1 = line[:, :2][0]
-    p2 = line[:, :2][1]
 
-    x_coords = torch.tensor([p1[0], p2[0]])
-    y_coords = torch.tensor([p1[1], p2[1]])
-
-    x_min = x_coords.min().clamp(min=0)
-    y_min = y_coords.min().clamp(min=0)
-    x_max = x_coords.max().clamp(min=0)
-    y_max = y_coords.max().clamp(min=0)
-
-    x_min = (x_min - 1).clamp(min=0)
-    y_min = (y_min - 1).clamp(min=0)
-    x_max = (x_max + 1).clamp(max=img_size)
-    y_max = (y_max + 1).clamp(max=img_size)
-
-    return torch.stack([x_min, y_min, x_max, y_max])
 
+def lines_to_boxes(lines, img_size=511):
+    """
+    Args:
+        lines: Tensor of shape (N, 2, 2), N line segments, each with two endpoints (x, y)
+        img_size: int, image size used to clamp box coordinates
 
-def box_iou(box1, box2):
-    # box: [x1, y1, x2, y2]
-    lt = torch.max(box1[:2], box2[:2])
-    rb = torch.min(box1[2:], box2[2:])
+    Returns:
+        boxes: Tensor of shape (N, 4), N bounding boxes [x_min, y_min, x_max, y_max]
+    """
+    # Extract the two endpoints of every line segment
+    p1 = lines[:, 0]  # (N, 2)
+    p2 = lines[:, 1]  # (N, 2)
+
+    # Per-segment x and y coordinates
+    x_coords = torch.stack([p1[:, 0], p2[:, 0]], dim=1)  # (N, 2)
+    y_coords = torch.stack([p1[:, 1], p2[:, 1]], dim=1)  # (N, 2)
+
+    # Compute the bounding-box extents
+    x_min = x_coords.min(dim=1).values
+    y_min = y_coords.min(dim=1).values
+    x_max = x_coords.max(dim=1).values
+    y_max = y_coords.max(dim=1).values
+
+    # Expand the box by one pixel and clamp to the image bounds
+    x_min = (x_min - 1).clamp(min=0, max=img_size)
+    y_min = (y_min - 1).clamp(min=0, max=img_size)
+    x_max = (x_max + 1).clamp(min=0, max=img_size)
+    y_max = (y_max + 1).clamp(min=0, max=img_size)
+
+    # Assemble the bounding boxes
+    boxes = torch.stack([x_min, y_min, x_max, y_max], dim=1)  # (N, 4)
+    return boxes
+
+def box_iou_pairwise(box1, box2):
+    """
+    Args:
+        box1: shape (N, 4)
+        box2: shape (M, 4)
+    Returns:
+        ious: shape (min(N, M),), IoU is computed only for index-aligned pairs i == j
+    """
+    N = min(len(box1), len(box2))
+    lt = torch.max(box1[:N, :2], box2[:N, :2])  # top-left corner of the intersection
+    rb = torch.min(box1[:N, 2:], box2[:N, 2:])  # bottom-right corner of the intersection
 
-    wh = (rb - lt).clamp(min=0)
-    inter_area = wh[0] * wh[1]
+    wh = (rb - lt).clamp(min=0)  # intersection width and height
+    inter_area = wh[:, 0] * wh[:, 1]  # intersection area
 
-    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
-    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
+    area1 = (box1[:N, 2] - box1[:N, 0]) * (box1[:N, 3] - box1[:N, 1])
+    area2 = (box2[:N, 2] - box2[:N, 0]) * (box2[:N, 3] - box2[:N, 1])
 
     union_area = area1 + area2 - inter_area
-    iou = inter_area / (union_area + 1e-6)
-
-    return iou
-
+    ious = inter_area / (union_area + 1e-6)
 
-def line_iou_loss(x, boxes, gt_lines, matched_idx,img_size):
+    return ious
+def line_iou_loss(x, boxes, gt_lines, matched_idx, img_size=511):
     losses = []
     boxes_per_image = [box.size(0) for box in boxes]
     x2 = x.split(boxes_per_image, dim=0)
@@ -594,17 +614,20 @@ def line_iou_loss(x, boxes, gt_lines, matched_idx,img_size):
         if len(pred_lines) == 0 or len(gt_line_points) == 0:
             continue
 
-        # Hungarian matching to avoid order mismatch
-        # cost_matrix = torch.zeros((len(pred_lines), len(gt_line_points)))
-        for i, pline in enumerate(pred_lines):
-            for j, gline in enumerate(gt_line_points):
-                box1 = line_to_box(pline,img_size)
-                box2 = line_to_box(gline,img_size)
+        # ==== Use the new batched lines_to_boxes ====
+        pred_boxes = lines_to_boxes(pred_lines, img_size)
+        gt_boxes = lines_to_boxes(gt_line_points, img_size)
+
+        # ==== Pairwise IoU computation ====
+        ious = box_iou_pairwise(pred_boxes, gt_boxes)
+
+        loss = 1.0 - ious
+        losses.append(loss)
 
-                iou = box_iou(box1, box2)
-                losses.append(1.0 - iou)
+    if not losses:
+        return None
 
-    total_loss = torch.mean(torch.stack(losses)) if losses else None
+    total_loss = torch.mean(torch.cat(losses))
     return total_loss
 
 def line_inference(x, boxes):
@@ -1207,10 +1230,10 @@ class RoIHeads(nn.Module):
                 rcnn_loss_line = lines_point_pair_loss(
                     line_logits, line_proposals, gt_lines, pos_matched_idxs
                 )
-                # iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, pos_matched_idxs,img_size)
+                iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, pos_matched_idxs, img_size)
 
                 loss_line = {"loss_line": rcnn_loss_line}
-                # loss_line_iou = {'loss_line_iou': iou_loss}
+                loss_line_iou = {'loss_line_iou': iou_loss}
 
             else:
                 if targets is not None:
@@ -1220,8 +1243,8 @@ class RoIHeads(nn.Module):
                     )
                     loss_line = {"loss_line": rcnn_loss_lines}
 
-                    # iou_loss =line_iou_loss(line_logits, line_proposals,gt_lines,pos_matched_idxs,img_size)
-                    # loss_line_iou={'loss_line_iou':iou_loss}
+                    iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, pos_matched_idxs, img_size)
+                    loss_line_iou = {'loss_line_iou': iou_loss}
 
 
                 else: