|
|
@@ -4,6 +4,7 @@ import matplotlib.pyplot as plt
|
|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
|
import torchvision
|
|
|
+from scipy.optimize import linear_sum_assignment
|
|
|
from torch import nn, Tensor
|
|
|
from libs.vision_libs.ops import boxes as box_ops, roi_align
|
|
|
|
|
|
@@ -543,84 +544,73 @@ def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
|
|
|
line_loss=F.cross_entropy(line_logits,gs_heatmaps)
|
|
|
|
|
|
return line_loss
|
|
|
-def is_collinear(p1, p2, q1, q2, eps=1e-6):
|
|
|
- v1 = p2 - p1
|
|
|
- v2 = q1 - p1
|
|
|
- cross_z = v1[0] * v2[1] - v1[1] * v2[0]
|
|
|
- return abs(cross_z) < eps
|
|
|
+def line_to_box(line):
|
|
|
+ p1 = line[:, :2][0]
|
|
|
+ p2 = line[:, :2][1]
|
|
|
|
|
|
+ x_coords = torch.tensor([p1[0], p2[0]])
|
|
|
+ y_coords = torch.tensor([p1[1], p2[1]])
|
|
|
|
|
|
-def segment_intersection_length(line1, line2):
|
|
|
- p1, p2 = line1
|
|
|
- q1, q2 = line2
|
|
|
+ x_min = x_coords.min().clamp(min=0)
|
|
|
+ y_min = y_coords.min().clamp(min=0)
|
|
|
+ x_max = x_coords.max().clamp(min=0)
|
|
|
+ y_max = y_coords.max().clamp(min=0)
|
|
|
|
|
|
- if not is_collinear(p1, p2, q1, q2):
|
|
|
- return 0.0
|
|
|
+ x_min -= 1
|
|
|
+ y_min -= 1
|
|
|
+ x_max += 1
|
|
|
+ y_max += 1
|
|
|
|
|
|
- dir_vec = p2 - p1
|
|
|
- if torch.norm(dir_vec) == 0:
|
|
|
- return 0.0
|
|
|
+ return torch.stack([x_min, y_min, x_max, y_max])
|
|
|
|
|
|
- def project(point):
|
|
|
- return torch.dot(point - p1, dir_vec)
|
|
|
|
|
|
- t_p1 = 0.0
|
|
|
- t_p2 = 1.0
|
|
|
- t_q1 = project(q1)
|
|
|
- t_q2 = project(q2)
|
|
|
+def box_iou(box1, box2):
|
|
|
+ # box: [x1, y1, x2, y2]
|
|
|
+ lt = torch.max(box1[:2], box2[:2])
|
|
|
+ rb = torch.min(box1[2:], box2[2:])
|
|
|
|
|
|
- t_min = max(t_p1, min(t_q1, t_q2))
|
|
|
- t_max = min(t_p2, max(t_q1, t_q2))
|
|
|
+ wh = (rb - lt).clamp(min=0)
|
|
|
+ inter_area = wh[0] * wh[1]
|
|
|
|
|
|
- if t_min >= t_max:
|
|
|
- return 0.0
|
|
|
+ area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
|
|
+ area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
|
|
|
|
|
- length = torch.norm(dir_vec) * (t_max - t_min)
|
|
|
- return length.item()
|
|
|
+ union_area = area1 + area2 - inter_area
|
|
|
+ iou = inter_area / (union_area + 1e-6)
|
|
|
|
|
|
+ return iou
|
|
|
|
|
|
-def line_iou(pred_line, target_line):
|
|
|
- pred_line_coords = pred_line[:, :2]
|
|
|
- target_line_coords = target_line[:, :2]
|
|
|
|
|
|
- l1_len = torch.norm(pred_line_coords[1] - pred_line_coords[0])
|
|
|
- l2_len = torch.norm(target_line_coords[1] - target_line_coords[0])
|
|
|
-
|
|
|
- inter_len = segment_intersection_length(pred_line_coords, target_line_coords)
|
|
|
- union_len = l1_len + l2_len - inter_len
|
|
|
-
|
|
|
- if union_len <= 0:
|
|
|
- return 0.0
|
|
|
-
|
|
|
- return inter_len / union_len
|
|
|
-
|
|
|
-
|
|
|
-def line_iou_loss(x, boxes,gt_lines,matched_idx):
|
|
|
- # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
|
|
|
- points_probs = []
|
|
|
- points_scores = []
|
|
|
+def line_iou_loss(x, boxes, gt_lines, matched_idx):
|
|
|
losses = []
|
|
|
-
|
|
|
boxes_per_image = [box.size(0) for box in boxes]
|
|
|
x2 = x.split(boxes_per_image, dim=0)
|
|
|
|
|
|
- for xx, bb,gt_line,mid in zip(x2, boxes,gt_lines,matched_idx):
|
|
|
+ for xx, bb, gt_line, mid in zip(x2, boxes, gt_lines, matched_idx):
|
|
|
p_prob, scores = heatmaps_to_lines(xx, bb)
|
|
|
- points_probs.append(p_prob)
|
|
|
- points_scores.append(scores)
|
|
|
- gt_line_points=gt_line[mid]
|
|
|
- print(f'gt_line_points:{gt_line_points.shape}')
|
|
|
- # 匹配预测线段和真实线段(例如匈牙利匹配)
|
|
|
- # 这里假设一对一匹配
|
|
|
- pred_lines = p_prob # shape: (num_pred_lines, 2, 2)
|
|
|
- print(f'pred_lines:{pred_lines.shape}')
|
|
|
+ pred_lines = p_prob
|
|
|
+ gt_line_points = gt_line[mid]
|
|
|
|
|
|
- for j in range(min(len(pred_lines), len(gt_line_points))):
|
|
|
- iou = line_iou(pred_lines[j], gt_line_points[j])
|
|
|
- losses.append(1.0 - iou) # 损失为 1 - IoU
|
|
|
+ if len(pred_lines) == 0 or len(gt_line_points) == 0:
|
|
|
+ continue
|
|
|
|
|
|
- total_loss = torch.mean(torch.stack(losses)) if losses else None
|
|
|
+ # 匈牙利匹配,避免顺序错位
|
|
|
+ cost_matrix = torch.zeros((len(pred_lines), len(gt_line_points)))
|
|
|
+ for i, pline in enumerate(pred_lines):
|
|
|
+ for j, gline in enumerate(gt_line_points):
|
|
|
+ box1 = line_to_box(pline)
|
|
|
+ box2 = line_to_box(gline)
|
|
|
+ cost_matrix[i, j] = 1.0 - box_iou(box1, box2)
|
|
|
|
|
|
+ row_ind, col_ind = linear_sum_assignment(cost_matrix.numpy())
|
|
|
+
|
|
|
+ for r, c in zip(row_ind, col_ind):
|
|
|
+ box1 = line_to_box(pred_lines[r])
|
|
|
+ box2 = line_to_box(gt_line_points[c])
|
|
|
+ iou = box_iou(box1, box2)
|
|
|
+ losses.append(1.0 - iou)
|
|
|
+
|
|
|
+ total_loss = torch.mean(torch.stack(losses)) if losses else None
|
|
|
return total_loss
|
|
|
|
|
|
def line_inference(x, boxes):
|