Bläddra i källkod

Merge branch 'keypoint' of https://dev.lstznkj.com/DevLibs/MultiVisionModels into keypoint

lstrlq 5 månader sedan
förälder
incheckning
1abc5b198b
1 ändrade filer med 13 tillägg och 12 borttagningar
  1. 13 12
      models/line_detect/loi_heads.py

+ 13 - 12
models/line_detect/loi_heads.py

@@ -544,7 +544,7 @@ def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
     line_loss=F.cross_entropy(line_logits,gs_heatmaps)
 
     return line_loss
-def line_to_box(line):
+def line_to_box(line,img_size):
     p1 = line[:, :2][0]
     p2 = line[:, :2][1]
 
@@ -556,10 +556,10 @@ def line_to_box(line):
     x_max = x_coords.max().clamp(min=0)
     y_max = y_coords.max().clamp(min=0)
 
-    x_min -= 1
-    y_min -= 1
-    x_max += 1
-    y_max += 1
+    x_min = (x_min - 1).clamp(min=0)
+    y_min = (y_min - 1).clamp(min=0)
+    x_max = (x_max + 1).clamp(max=img_size)
+    y_max = (y_max + 1).clamp(max=img_size)
 
     return torch.stack([x_min, y_min, x_max, y_max])
 
@@ -581,7 +581,7 @@ def box_iou(box1, box2):
     return iou
 
 
-def line_iou_loss(x, boxes, gt_lines, matched_idx):
+def line_iou_loss(x, boxes, gt_lines, matched_idx,img_size):
     losses = []
     boxes_per_image = [box.size(0) for box in boxes]
     x2 = x.split(boxes_per_image, dim=0)
@@ -598,15 +598,15 @@ def line_iou_loss(x, boxes, gt_lines, matched_idx):
         cost_matrix = torch.zeros((len(pred_lines), len(gt_line_points)))
         for i, pline in enumerate(pred_lines):
             for j, gline in enumerate(gt_line_points):
-                box1 = line_to_box(pline)
-                box2 = line_to_box(gline)
+                box1 = line_to_box(pline,img_size)
+                box2 = line_to_box(gline,img_size)
                 cost_matrix[i, j] = 1.0 - box_iou(box1, box2)
 
         row_ind, col_ind = linear_sum_assignment(cost_matrix.numpy())
 
         for r, c in zip(row_ind, col_ind):
-            box1 = line_to_box(pred_lines[r])
-            box2 = line_to_box(gt_line_points[c])
+            box1 = line_to_box(pred_lines[r],img_size)
+            box2 = line_to_box(gt_line_points[c],img_size)
             iou = box_iou(box1, box2)
             losses.append(1.0 - iou)
 
@@ -1204,6 +1204,7 @@ class RoIHeads(nn.Module):
 
             loss_line = {}
             loss_line_iou={}
+            img_size=512
             if self.training:
                 if targets is None or pos_matched_idxs is None:
                     raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
@@ -1212,7 +1213,7 @@ class RoIHeads(nn.Module):
                 rcnn_loss_line = lines_point_pair_loss(
                     line_logits, line_proposals, gt_lines, pos_matched_idxs
                 )
-                iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, pos_matched_idxs)
+                iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, pos_matched_idxs,img_size)
 
                 loss_line = {"loss_line": rcnn_loss_line}
                 loss_line_iou = {'loss_line_iou': iou_loss}
@@ -1225,7 +1226,7 @@ class RoIHeads(nn.Module):
                     )
                     loss_line = {"loss_line": rcnn_loss_lines}
 
-                    iou_loss =line_iou_loss(line_logits, line_proposals,gt_lines,pos_matched_idxs)
+                    iou_loss =line_iou_loss(line_logits, line_proposals,gt_lines,pos_matched_idxs,img_size)
                     loss_line_iou={'loss_line_iou':iou_loss}