Parcourir la source

修改line_iou_loss

lstrlq il y a 5 mois
Parent
commit
f378b1a7a9
1 fichiers modifiés avec 7 ajouts et 13 suppressions
  1. 7 13
      models/line_detect/loi_heads.py

+ 7 - 13
models/line_detect/loi_heads.py

@@ -595,20 +595,14 @@ def line_iou_loss(x, boxes, gt_lines, matched_idx,img_size):
             continue
 
         # 匈牙利匹配,避免顺序错位
-        cost_matrix = torch.zeros((len(pred_lines), len(gt_line_points)))
+        # cost_matrix = torch.zeros((len(pred_lines), len(gt_line_points)))
         for i, pline in enumerate(pred_lines):
             for j, gline in enumerate(gt_line_points):
                 box1 = line_to_box(pline,img_size)
                 box2 = line_to_box(gline,img_size)
-                cost_matrix[i, j] = 1.0 - box_iou(box1, box2)
 
-        row_ind, col_ind = linear_sum_assignment(cost_matrix.numpy())
-
-        for r, c in zip(row_ind, col_ind):
-            box1 = line_to_box(pred_lines[r],img_size)
-            box2 = line_to_box(gt_line_points[c],img_size)
-            iou = box_iou(box1, box2)
-            losses.append(1.0 - iou)
+                iou = box_iou(box1, box2)
+                losses.append(1.0 - iou)
 
     total_loss = torch.mean(torch.stack(losses)) if losses else None
     return total_loss
@@ -1213,10 +1207,10 @@ class RoIHeads(nn.Module):
                 rcnn_loss_line = lines_point_pair_loss(
                     line_logits, line_proposals, gt_lines, pos_matched_idxs
                 )
-                iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, pos_matched_idxs,img_size)
+                # iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, pos_matched_idxs,img_size)
 
                 loss_line = {"loss_line": rcnn_loss_line}
-                loss_line_iou = {'loss_line_iou': iou_loss}
+                # loss_line_iou = {'loss_line_iou': iou_loss}
 
             else:
                 if targets is not None:
@@ -1226,8 +1220,8 @@ class RoIHeads(nn.Module):
                     )
                     loss_line = {"loss_line": rcnn_loss_lines}
 
-                    iou_loss =line_iou_loss(line_logits, line_proposals,gt_lines,pos_matched_idxs,img_size)
-                    loss_line_iou={'loss_line_iou':iou_loss}
+                    # iou_loss =line_iou_loss(line_logits, line_proposals,gt_lines,pos_matched_idxs,img_size)
+                    # loss_line_iou={'loss_line_iou':iou_loss}
 
 
                 else: