Quellcode durchsuchen

优化line_iou_loss

RenLiqiang vor 5 Monaten
Ursprung
Commit
3922df8740
2 geänderte Dateien mit 50 neuen und 59 gelöschten Zeilen
  1. 2 1
      models/line_detect/line_dataset.py
  2. 48 58
      models/line_detect/loi_heads.py

+ 2 - 1
models/line_detect/line_dataset.py

@@ -33,7 +33,7 @@ def validate_keypoints(keypoints, image_width, image_height):
 
 
 class LineDataset(BaseDataset):
-    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
+    def __init__(self, dataset_path, transforms=None, dataset_type=None,img_type='rgb', target_type='pixel'):
         super().__init__(dataset_path)
 
         self.data_path = dataset_path
@@ -44,6 +44,7 @@ class LineDataset(BaseDataset):
         self.imgs = os.listdir(self.img_path)
         self.lbls = os.listdir(self.lbl_path)
         self.target_type = target_type
+        self.img_type=img_type
         # self.default_transform = DefaultTransform()
 
     def __getitem__(self, index) -> T_co:

+ 48 - 58
models/line_detect/loi_heads.py

@@ -4,6 +4,7 @@ import matplotlib.pyplot as plt
 import torch
 import torch.nn.functional as F
 import torchvision
+from scipy.optimize import linear_sum_assignment
 from torch import nn, Tensor
 from  libs.vision_libs.ops import boxes as box_ops, roi_align
 
@@ -543,84 +544,73 @@ def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
     line_loss=F.cross_entropy(line_logits,gs_heatmaps)
 
     return line_loss
-def is_collinear(p1, p2, q1, q2, eps=1e-6):
-    v1 = p2 - p1
-    v2 = q1 - p1
-    cross_z = v1[0] * v2[1] - v1[1] * v2[0]
-    return abs(cross_z) < eps
+def line_to_box(line):
+    p1 = line[:, :2][0]
+    p2 = line[:, :2][1]
 
+    x_coords = torch.tensor([p1[0], p2[0]])
+    y_coords = torch.tensor([p1[1], p2[1]])
 
-def segment_intersection_length(line1, line2):
-    p1, p2 = line1
-    q1, q2 = line2
+    x_min = x_coords.min().clamp(min=0)
+    y_min = y_coords.min().clamp(min=0)
+    x_max = x_coords.max().clamp(min=0)
+    y_max = y_coords.max().clamp(min=0)
 
-    if not is_collinear(p1, p2, q1, q2):
-        return 0.0
+    x_min -= 1
+    y_min -= 1
+    x_max += 1
+    y_max += 1
 
-    dir_vec = p2 - p1
-    if torch.norm(dir_vec) == 0:
-        return 0.0
+    return torch.stack([x_min, y_min, x_max, y_max])
 
-    def project(point):
-        return torch.dot(point - p1, dir_vec)
 
-    t_p1 = 0.0
-    t_p2 = 1.0
-    t_q1 = project(q1)
-    t_q2 = project(q2)
+def box_iou(box1, box2):
+    # box: [x1, y1, x2, y2]
+    lt = torch.max(box1[:2], box2[:2])
+    rb = torch.min(box1[2:], box2[2:])
 
-    t_min = max(t_p1, min(t_q1, t_q2))
-    t_max = min(t_p2, max(t_q1, t_q2))
+    wh = (rb - lt).clamp(min=0)
+    inter_area = wh[0] * wh[1]
 
-    if t_min >= t_max:
-        return 0.0
+    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
+    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
 
-    length = torch.norm(dir_vec) * (t_max - t_min)
-    return length.item()
+    union_area = area1 + area2 - inter_area
+    iou = inter_area / (union_area + 1e-6)
 
+    return iou
 
-def line_iou(pred_line, target_line):
-    pred_line_coords = pred_line[:, :2]
-    target_line_coords = target_line[:, :2]
 
-    l1_len = torch.norm(pred_line_coords[1] - pred_line_coords[0])
-    l2_len = torch.norm(target_line_coords[1] - target_line_coords[0])
-
-    inter_len = segment_intersection_length(pred_line_coords, target_line_coords)
-    union_len = l1_len + l2_len - inter_len
-
-    if union_len <= 0:
-        return 0.0
-
-    return inter_len / union_len
-
-
-def line_iou_loss(x, boxes,gt_lines,matched_idx):
-    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
-    points_probs = []
-    points_scores = []
+def line_iou_loss(x, boxes, gt_lines, matched_idx):
     losses = []
-
     boxes_per_image = [box.size(0) for box in boxes]
     x2 = x.split(boxes_per_image, dim=0)
 
-    for xx, bb,gt_line,mid in zip(x2, boxes,gt_lines,matched_idx):
+    for xx, bb, gt_line, mid in zip(x2, boxes, gt_lines, matched_idx):
         p_prob, scores = heatmaps_to_lines(xx, bb)
-        points_probs.append(p_prob)
-        points_scores.append(scores)
-        gt_line_points=gt_line[mid]
-        print(f'gt_line_points:{gt_line_points.shape}')
-        # 匹配预测线段和真实线段(例如匈牙利匹配)
-        # 这里假设一对一匹配
-        pred_lines = p_prob  # shape: (num_pred_lines, 2, 2)
-        print(f'pred_lines:{pred_lines.shape}')
+        pred_lines = p_prob
+        gt_line_points = gt_line[mid]
 
-        for j in range(min(len(pred_lines), len(gt_line_points))):
-            iou = line_iou(pred_lines[j], gt_line_points[j])
-            losses.append(1.0 - iou)  # 损失为 1 - IoU
+        if len(pred_lines) == 0 or len(gt_line_points) == 0:
+            continue
 
-    total_loss = torch.mean(torch.stack(losses)) if losses else None
+        # 匈牙利匹配,避免顺序错位
+        cost_matrix = torch.zeros((len(pred_lines), len(gt_line_points)))
+        for i, pline in enumerate(pred_lines):
+            for j, gline in enumerate(gt_line_points):
+                box1 = line_to_box(pline)
+                box2 = line_to_box(gline)
+                cost_matrix[i, j] = 1.0 - box_iou(box1, box2)
 
+        row_ind, col_ind = linear_sum_assignment(cost_matrix.numpy())
+
+        for r, c in zip(row_ind, col_ind):
+            box1 = line_to_box(pred_lines[r])
+            box2 = line_to_box(gt_line_points[c])
+            iou = box_iou(box1, box2)
+            losses.append(1.0 - iou)
+
+    total_loss = torch.mean(torch.stack(losses)) if losses else None
     return total_loss
 
 def line_inference(x, boxes):