|
@@ -544,44 +544,64 @@ def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
|
|
|
line_loss=F.cross_entropy(line_logits,gs_heatmaps)
|
|
line_loss=F.cross_entropy(line_logits,gs_heatmaps)
|
|
|
|
|
|
|
|
return line_loss
|
|
return line_loss
|
|
|
-def line_to_box(line,img_size):
|
|
|
|
|
- p1 = line[:, :2][0]
|
|
|
|
|
- p2 = line[:, :2][1]
|
|
|
|
|
|
|
|
|
|
- x_coords = torch.tensor([p1[0], p2[0]])
|
|
|
|
|
- y_coords = torch.tensor([p1[1], p2[1]])
|
|
|
|
|
-
|
|
|
|
|
- x_min = x_coords.min().clamp(min=0)
|
|
|
|
|
- y_min = y_coords.min().clamp(min=0)
|
|
|
|
|
- x_max = x_coords.max().clamp(min=0)
|
|
|
|
|
- y_max = y_coords.max().clamp(min=0)
|
|
|
|
|
-
|
|
|
|
|
- x_min = (x_min - 1).clamp(min=0)
|
|
|
|
|
- y_min = (y_min - 1).clamp(min=0)
|
|
|
|
|
- x_max = (x_max + 1).clamp(max=img_size)
|
|
|
|
|
- y_max = (y_max + 1).clamp(max=img_size)
|
|
|
|
|
-
|
|
|
|
|
- return torch.stack([x_min, y_min, x_max, y_max])
|
|
|
|
|
|
|
|
|
|
|
|
+def lines_to_boxes(lines, img_size=511):
|
|
|
|
|
+ """
|
|
|
|
|
+ è¾å
¥:
|
|
|
|
|
+ lines: Tensor of shape (N, 2, 2)ï¼è¡¨ç¤º N æ¡çº¿æ®µï¼æ¯ä¸ªçº¿æ®µæä¸¤ä¸ªç«¯ç¹ (x, y)
|
|
|
|
|
+ img_size: intï¼å¾å尺寸ï¼ç¨äº clamp è¾¹ç
|
|
|
|
|
|
|
|
-def box_iou(box1, box2):
|
|
|
|
|
- # box: [x1, y1, x2, y2]
|
|
|
|
|
- lt = torch.max(box1[:2], box2[:2])
|
|
|
|
|
- rb = torch.min(box1[2:], box2[2:])
|
|
|
|
|
|
|
+ è¾åº:
|
|
|
|
|
+ boxes: Tensor of shape (N, 4)ï¼è¡¨ç¤º N 个å
å´ç [x_min, y_min, x_max, y_max]
|
|
|
|
|
+ """
|
|
|
|
|
+ # æåææçº¿æ®µç两个端ç¹
|
|
|
|
|
+ p1 = lines[:, 0] # (N, 2)
|
|
|
|
|
+ p2 = lines[:, 1] # (N, 2)
|
|
|
|
|
+
|
|
|
|
|
+ # æ¯æ¡çº¿æ®µç x å y åæ
|
|
|
|
|
+ x_coords = torch.stack([p1[:, 0], p2[:, 0]], dim=1) # (N, 2)
|
|
|
|
|
+ y_coords = torch.stack([p1[:, 1], p2[:, 1]], dim=1) # (N, 2)
|
|
|
|
|
+
|
|
|
|
|
+ # 计ç®å
å´çè¾¹ç
|
|
|
|
|
+ x_min = x_coords.min(dim=1).values
|
|
|
|
|
+ y_min = y_coords.min(dim=1).values
|
|
|
|
|
+ x_max = x_coords.max(dim=1).values
|
|
|
|
|
+ y_max = y_coords.max(dim=1).values
|
|
|
|
|
+
|
|
|
|
|
+ # æ©å±è¾¹çå¹¶éå¶å¨å¾åèå´å
|
|
|
|
|
+ x_min = (x_min - 1).clamp(min=0, max=img_size)
|
|
|
|
|
+ y_min = (y_min - 1).clamp(min=0, max=img_size)
|
|
|
|
|
+ x_max = (x_max + 1).clamp(min=0, max=img_size)
|
|
|
|
|
+ y_max = (y_max + 1).clamp(min=0, max=img_size)
|
|
|
|
|
+
|
|
|
|
|
+ # åæå
å´ç
|
|
|
|
|
+ boxes = torch.stack([x_min, y_min, x_max, y_max], dim=1) # (N, 4)
|
|
|
|
|
+ return boxes
|
|
|
|
|
+
|
|
|
|
|
+def box_iou_pairwise(box1, box2):
|
|
|
|
|
+ """
|
|
|
|
|
+ è¾å
¥ï¼
|
|
|
|
|
+ box1: shape (N, 4)
|
|
|
|
|
+ box2: shape (M, 4)
|
|
|
|
|
+ è¾åºï¼
|
|
|
|
|
+ ious: shape (min(N, M), ), åªè®¡ç® i = j çé
对
|
|
|
|
|
+ """
|
|
|
|
|
+ N = min(len(box1), len(box2))
|
|
|
|
|
+ lt = torch.max(box1[:N, :2], box2[:N, :2]) # å·¦ä¸è§
|
|
|
|
|
+ rb = torch.min(box1[:N, 2:], box2[:N, 2:]) # å³ä¸è§
|
|
|
|
|
|
|
|
- wh = (rb - lt).clamp(min=0)
|
|
|
|
|
- inter_area = wh[0] * wh[1]
|
|
|
|
|
|
|
+ wh = (rb - lt).clamp(min=0) # 宽é«
|
|
|
|
|
+ inter_area = wh[:, 0] * wh[:, 1] # 交éé¢ç§¯
|
|
|
|
|
|
|
|
- area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
|
|
|
|
- area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
|
|
|
|
|
|
+ area1 = (box1[:N, 2] - box1[:N, 0]) * (box1[:N, 3] - box1[:N, 1])
|
|
|
|
|
+ area2 = (box2[:N, 2] - box2[:N, 0]) * (box2[:N, 3] - box2[:N, 1])
|
|
|
|
|
|
|
|
union_area = area1 + area2 - inter_area
|
|
union_area = area1 + area2 - inter_area
|
|
|
- iou = inter_area / (union_area + 1e-6)
|
|
|
|
|
-
|
|
|
|
|
- return iou
|
|
|
|
|
-
|
|
|
|
|
|
|
+ ious = inter_area / (union_area + 1e-6)
|
|
|
|
|
|
|
|
-def line_iou_loss(x, boxes, gt_lines, matched_idx,img_size):
|
|
|
|
|
|
|
+ return ious
|
|
|
|
|
+def line_iou_loss(x, boxes, gt_lines, matched_idx, img_size=511):
|
|
|
losses = []
|
|
losses = []
|
|
|
boxes_per_image = [box.size(0) for box in boxes]
|
|
boxes_per_image = [box.size(0) for box in boxes]
|
|
|
x2 = x.split(boxes_per_image, dim=0)
|
|
x2 = x.split(boxes_per_image, dim=0)
|
|
@@ -594,17 +614,20 @@ def line_iou_loss(x, boxes, gt_lines, matched_idx,img_size):
|
|
|
if len(pred_lines) == 0 or len(gt_line_points) == 0:
|
|
if len(pred_lines) == 0 or len(gt_line_points) == 0:
|
|
|
continue
|
|
continue
|
|
|
|
|
|
|
|
- # 匈牙利匹配,避免顺序错位
|
|
|
|
|
- # cost_matrix = torch.zeros((len(pred_lines), len(gt_line_points)))
|
|
|
|
|
- for i, pline in enumerate(pred_lines):
|
|
|
|
|
- for j, gline in enumerate(gt_line_points):
|
|
|
|
|
- box1 = line_to_box(pline,img_size)
|
|
|
|
|
- box2 = line_to_box(gline,img_size)
|
|
|
|
|
|
|
+ # ==== ä½¿ç¨æ°çæ¹éç lines_to_boxes ====
|
|
|
|
|
+ pred_boxes = lines_to_boxes(pred_lines, img_size)
|
|
|
|
|
+ gt_boxes = lines_to_boxes(gt_line_points, img_size)
|
|
|
|
|
+
|
|
|
|
|
+ # ==== æå¯¹ IoU è®¡ç® ====
|
|
|
|
|
+ ious = box_iou_pairwise(pred_boxes, gt_boxes)
|
|
|
|
|
+
|
|
|
|
|
+ loss = 1.0 - ious
|
|
|
|
|
+ losses.append(loss)
|
|
|
|
|
|
|
|
- iou = box_iou(box1, box2)
|
|
|
|
|
- losses.append(1.0 - iou)
|
|
|
|
|
|
|
+ if not losses:
|
|
|
|
|
+ return None
|
|
|
|
|
|
|
|
- total_loss = torch.mean(torch.stack(losses)) if losses else None
|
|
|
|
|
|
|
+ total_loss = torch.mean(torch.cat(losses))
|
|
|
return total_loss
|
|
return total_loss
|
|
|
|
|
|
|
|
def line_inference(x, boxes):
|
|
def line_inference(x, boxes):
|
|
@@ -1207,10 +1230,10 @@ class RoIHeads(nn.Module):
|
|
|
rcnn_loss_line = lines_point_pair_loss(
|
|
rcnn_loss_line = lines_point_pair_loss(
|
|
|
line_logits, line_proposals, gt_lines, pos_matched_idxs
|
|
line_logits, line_proposals, gt_lines, pos_matched_idxs
|
|
|
)
|
|
)
|
|
|
- # iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, pos_matched_idxs,img_size)
|
|
|
|
|
|
|
+ iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, pos_matched_idxs,img_size)
|
|
|
|
|
|
|
|
loss_line = {"loss_line": rcnn_loss_line}
|
|
loss_line = {"loss_line": rcnn_loss_line}
|
|
|
- # loss_line_iou = {'loss_line_iou': iou_loss}
|
|
|
|
|
|
|
+ loss_line_iou = {'loss_line_iou': iou_loss}
|
|
|
|
|
|
|
|
else:
|
|
else:
|
|
|
if targets is not None:
|
|
if targets is not None:
|
|
@@ -1220,8 +1243,8 @@ class RoIHeads(nn.Module):
|
|
|
)
|
|
)
|
|
|
loss_line = {"loss_line": rcnn_loss_lines}
|
|
loss_line = {"loss_line": rcnn_loss_lines}
|
|
|
|
|
|
|
|
- # iou_loss =line_iou_loss(line_logits, line_proposals,gt_lines,pos_matched_idxs,img_size)
|
|
|
|
|
- # loss_line_iou={'loss_line_iou':iou_loss}
|
|
|
|
|
|
|
+ iou_loss =line_iou_loss(line_logits, line_proposals,gt_lines,pos_matched_idxs,img_size)
|
|
|
|
|
+ loss_line_iou={'loss_line_iou':iou_loss}
|
|
|
|
|
|
|
|
|
|
|
|
|
else:
|
|
else:
|