|
@@ -544,7 +544,7 @@ def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
|
|
|
line_loss=F.cross_entropy(line_logits,gs_heatmaps)
|
|
line_loss=F.cross_entropy(line_logits,gs_heatmaps)
|
|
|
|
|
|
|
|
return line_loss
|
|
return line_loss
|
|
|
-def line_to_box(line):
|
|
|
|
|
|
|
+def line_to_box(line,img_size):
|
|
|
p1 = line[:, :2][0]
|
|
p1 = line[:, :2][0]
|
|
|
p2 = line[:, :2][1]
|
|
p2 = line[:, :2][1]
|
|
|
|
|
|
|
@@ -556,10 +556,10 @@ def line_to_box(line):
|
|
|
x_max = x_coords.max().clamp(min=0)
|
|
x_max = x_coords.max().clamp(min=0)
|
|
|
y_max = y_coords.max().clamp(min=0)
|
|
y_max = y_coords.max().clamp(min=0)
|
|
|
|
|
|
|
|
- x_min -= 1
|
|
|
|
|
- y_min -= 1
|
|
|
|
|
- x_max += 1
|
|
|
|
|
- y_max += 1
|
|
|
|
|
|
|
+ x_min = (x_min - 1).clamp(min=0)
|
|
|
|
|
+ y_min = (y_min - 1).clamp(min=0)
|
|
|
|
|
+ x_max = (x_max + 1).clamp(max=img_size)
|
|
|
|
|
+ y_max = (y_max + 1).clamp(max=img_size)
|
|
|
|
|
|
|
|
return torch.stack([x_min, y_min, x_max, y_max])
|
|
return torch.stack([x_min, y_min, x_max, y_max])
|
|
|
|
|
|
|
@@ -581,7 +581,7 @@ def box_iou(box1, box2):
|
|
|
return iou
|
|
return iou
|
|
|
|
|
|
|
|
|
|
|
|
|
-def line_iou_loss(x, boxes, gt_lines, matched_idx):
|
|
|
|
|
|
|
+def line_iou_loss(x, boxes, gt_lines, matched_idx,img_size):
|
|
|
losses = []
|
|
losses = []
|
|
|
boxes_per_image = [box.size(0) for box in boxes]
|
|
boxes_per_image = [box.size(0) for box in boxes]
|
|
|
x2 = x.split(boxes_per_image, dim=0)
|
|
x2 = x.split(boxes_per_image, dim=0)
|
|
@@ -598,15 +598,15 @@ def line_iou_loss(x, boxes, gt_lines, matched_idx):
|
|
|
cost_matrix = torch.zeros((len(pred_lines), len(gt_line_points)))
|
|
cost_matrix = torch.zeros((len(pred_lines), len(gt_line_points)))
|
|
|
for i, pline in enumerate(pred_lines):
|
|
for i, pline in enumerate(pred_lines):
|
|
|
for j, gline in enumerate(gt_line_points):
|
|
for j, gline in enumerate(gt_line_points):
|
|
|
- box1 = line_to_box(pline)
|
|
|
|
|
- box2 = line_to_box(gline)
|
|
|
|
|
|
|
+ box1 = line_to_box(pline,img_size)
|
|
|
|
|
+ box2 = line_to_box(gline,img_size)
|
|
|
cost_matrix[i, j] = 1.0 - box_iou(box1, box2)
|
|
cost_matrix[i, j] = 1.0 - box_iou(box1, box2)
|
|
|
|
|
|
|
|
row_ind, col_ind = linear_sum_assignment(cost_matrix.numpy())
|
|
row_ind, col_ind = linear_sum_assignment(cost_matrix.numpy())
|
|
|
|
|
|
|
|
for r, c in zip(row_ind, col_ind):
|
|
for r, c in zip(row_ind, col_ind):
|
|
|
- box1 = line_to_box(pred_lines[r])
|
|
|
|
|
- box2 = line_to_box(gt_line_points[c])
|
|
|
|
|
|
|
+ box1 = line_to_box(pred_lines[r],img_size)
|
|
|
|
|
+ box2 = line_to_box(gt_line_points[c],img_size)
|
|
|
iou = box_iou(box1, box2)
|
|
iou = box_iou(box1, box2)
|
|
|
losses.append(1.0 - iou)
|
|
losses.append(1.0 - iou)
|
|
|
|
|
|
|
@@ -1204,6 +1204,7 @@ class RoIHeads(nn.Module):
|
|
|
|
|
|
|
|
loss_line = {}
|
|
loss_line = {}
|
|
|
loss_line_iou={}
|
|
loss_line_iou={}
|
|
|
|
|
+ img_size=512
|
|
|
if self.training:
|
|
if self.training:
|
|
|
if targets is None or pos_matched_idxs is None:
|
|
if targets is None or pos_matched_idxs is None:
|
|
|
raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
|
|
raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
|
|
@@ -1212,7 +1213,7 @@ class RoIHeads(nn.Module):
|
|
|
rcnn_loss_line = lines_point_pair_loss(
|
|
rcnn_loss_line = lines_point_pair_loss(
|
|
|
line_logits, line_proposals, gt_lines, pos_matched_idxs
|
|
line_logits, line_proposals, gt_lines, pos_matched_idxs
|
|
|
)
|
|
)
|
|
|
- iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, pos_matched_idxs)
|
|
|
|
|
|
|
+ iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, pos_matched_idxs,img_size)
|
|
|
|
|
|
|
|
loss_line = {"loss_line": rcnn_loss_line}
|
|
loss_line = {"loss_line": rcnn_loss_line}
|
|
|
loss_line_iou = {'loss_line_iou': iou_loss}
|
|
loss_line_iou = {'loss_line_iou': iou_loss}
|
|
@@ -1225,7 +1226,7 @@ class RoIHeads(nn.Module):
|
|
|
)
|
|
)
|
|
|
loss_line = {"loss_line": rcnn_loss_lines}
|
|
loss_line = {"loss_line": rcnn_loss_lines}
|
|
|
|
|
|
|
|
- iou_loss =line_iou_loss(line_logits, line_proposals,gt_lines,pos_matched_idxs)
|
|
|
|
|
|
|
+ iou_loss =line_iou_loss(line_logits, line_proposals,gt_lines,pos_matched_idxs,img_size)
|
|
|
loss_line_iou={'loss_line_iou':iou_loss}
|
|
loss_line_iou={'loss_line_iou':iou_loss}
|
|
|
|
|
|
|
|
|
|
|