|
|
@@ -595,20 +595,14 @@ def line_iou_loss(x, boxes, gt_lines, matched_idx,img_size):
|
|
|
continue
|
|
|
|
|
|
# 匈牙利匹配,避免顺序错位
|
|
|
- cost_matrix = torch.zeros((len(pred_lines), len(gt_line_points)))
|
|
|
+ # cost_matrix = torch.zeros((len(pred_lines), len(gt_line_points)))
|
|
|
for i, pline in enumerate(pred_lines):
|
|
|
for j, gline in enumerate(gt_line_points):
|
|
|
box1 = line_to_box(pline,img_size)
|
|
|
box2 = line_to_box(gline,img_size)
|
|
|
- cost_matrix[i, j] = 1.0 - box_iou(box1, box2)
|
|
|
|
|
|
- row_ind, col_ind = linear_sum_assignment(cost_matrix.numpy())
|
|
|
-
|
|
|
- for r, c in zip(row_ind, col_ind):
|
|
|
- box1 = line_to_box(pred_lines[r],img_size)
|
|
|
- box2 = line_to_box(gt_line_points[c],img_size)
|
|
|
- iou = box_iou(box1, box2)
|
|
|
- losses.append(1.0 - iou)
|
|
|
+ iou = box_iou(box1, box2)
|
|
|
+ losses.append(1.0 - iou)
|
|
|
|
|
|
total_loss = torch.mean(torch.stack(losses)) if losses else None
|
|
|
return total_loss
|
|
|
@@ -1213,10 +1207,10 @@ class RoIHeads(nn.Module):
|
|
|
rcnn_loss_line = lines_point_pair_loss(
|
|
|
line_logits, line_proposals, gt_lines, pos_matched_idxs
|
|
|
)
|
|
|
- iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, pos_matched_idxs,img_size)
|
|
|
+ # iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, pos_matched_idxs,img_size)
|
|
|
|
|
|
loss_line = {"loss_line": rcnn_loss_line}
|
|
|
- loss_line_iou = {'loss_line_iou': iou_loss}
|
|
|
+ # loss_line_iou = {'loss_line_iou': iou_loss}
|
|
|
|
|
|
else:
|
|
|
if targets is not None:
|
|
|
@@ -1226,8 +1220,8 @@ class RoIHeads(nn.Module):
|
|
|
)
|
|
|
loss_line = {"loss_line": rcnn_loss_lines}
|
|
|
|
|
|
- iou_loss =line_iou_loss(line_logits, line_proposals,gt_lines,pos_matched_idxs,img_size)
|
|
|
- loss_line_iou={'loss_line_iou':iou_loss}
|
|
|
+ # iou_loss =line_iou_loss(line_logits, line_proposals,gt_lines,pos_matched_idxs,img_size)
|
|
|
+ # loss_line_iou={'loss_line_iou':iou_loss}
|
|
|
|
|
|
|
|
|
else:
|