Bladeren bron

调试single point,解决部分报错,可进行backward

RenLiqiang 5 maanden geleden
bovenliggende
commit
eb2fc524a4
1 gewijzigde bestanden met toevoegingen van 26 en 11 verwijderingen
  1. 26 11
      models/line_detect/loi_heads.py

+ 26 - 11
models/line_detect/loi_heads.py

@@ -197,8 +197,9 @@ def single_point_to_heatmap(keypoints, rois, heatmap_size):
     print(f'keypoints.shape:{keypoints.shape}')
     # batch_size, num_keypoints, _ = keypoints.shape
 
-    x = keypoints[..., 0]
-    y = keypoints[..., 1]
+    x = keypoints[..., 0].unsqueeze(1)
+    y = keypoints[..., 1].unsqueeze(1)
+
 
     gs = generate_gaussian_heatmaps(x, y,num_points=1, heatmap_size=heatmap_size, sigma=1.0)
     # show_heatmap(gs[0],'target')
@@ -710,7 +711,8 @@ def compute_point_loss(line_logits, proposals, gt_points, point_matched_idxs):
     gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
     print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{line_logits.squeeze(1).shape}')
 
-    line_logits = line_logits.squeeze(1)
+    line_logits = line_logits[:,0]
+    print(f'single_point_logits:{line_logits.shape}')
 
     line_loss = F.cross_entropy(line_logits, gs_heatmaps)
 
@@ -1487,18 +1489,24 @@ class RoIHeads(nn.Module):
                 print(f'gt_lines_tensor:{gt_lines_tensor.shape}')
                 print(f'gt_points_tensor:{gt_points_tensor.shape}')
                 if gt_lines_tensor.shape[0]>0:
-                    rcnn_loss_line = lines_point_pair_loss(
+                    loss_line = lines_point_pair_loss(
                         line_logits, line_proposals, gt_lines, line_pos_matched_idxs
                     )
-                    iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, line_pos_matched_idxs, img_size)
+                    loss_line_iou = line_iou_loss(line_logits, line_proposals, gt_lines, line_pos_matched_idxs, img_size)
 
                 if gt_points_tensor.shape[0]>0:
                     model_loss_point = compute_point_loss(
                         line_logits, point_proposals, gt_points, point_pos_matched_idxs
                     )
 
-                loss_line = {"loss_line": rcnn_loss_line}
-                loss_line_iou = {'loss_line_iou': iou_loss}
+                if not loss_line:
+                    loss_line = torch.tensor(0.0, device=line_features.device)
+
+                if not loss_line_iou:
+                    loss_line_iou = torch.tensor(0.0, device=line_features.device)
+
+                loss_line = {"loss_line": loss_line}
+                loss_line_iou = {'loss_line_iou': loss_line_iou}
                 loss_point = {"loss_point": model_loss_point}
 
             else:
@@ -1508,17 +1516,23 @@ class RoIHeads(nn.Module):
                     gt_lines = [t["lines"] for t in targets]
                     gt_points = [t["points"] for t in targets]
 
-                    rcnn_loss_line = lines_point_pair_loss(
+                    loss_line = lines_point_pair_loss(
                         line_logits, line_proposals, gt_lines, line_pos_matched_idxs
                     )
-                    iou_loss = line_iou_loss(line_logits, line_proposals, gt_lines, line_pos_matched_idxs, img_size)
+                    loss_line_iou = line_iou_loss(line_logits, line_proposals, gt_lines, line_pos_matched_idxs, img_size)
 
                     model_loss_point = compute_point_loss(
                         line_logits, point_proposals, gt_points, point_pos_matched_idxs
                     )
 
-                    loss_line = {"loss_line": rcnn_loss_line}
-                    loss_line_iou = {'loss_line_iou': iou_loss}
+                    if not loss_line :
+                        loss_line=torch.tensor(0.0,device=line_features.device)
+
+                    if not loss_line_iou :
+                        loss_line_iou=torch.tensor(0.0,device=line_features.device)
+
+                    loss_line = {"loss_line": loss_line}
+                    loss_line_iou = {'loss_line_iou': loss_line_iou}
                     loss_point={"loss_point":model_loss_point}
 
 
@@ -1537,6 +1551,7 @@ class RoIHeads(nn.Module):
             losses.update(loss_line)
             losses.update(loss_line_iou)
             losses.update(loss_point)
+            print(f'losses:{losses}')
 
         if self.has_mask():
             mask_proposals = [p["boxes"] for p in result]