Browse Source

debug circle mask

admin 1 month ago
parent
commit
2e6990f2fd
2 changed files with 23 additions and 28 deletions
  1. 15 22
      models/line_detect/heads/head_losses.py
  2. 8 6
      models/line_detect/loi_heads.py

+ 15 - 22
models/line_detect/heads/head_losses.py

@@ -781,7 +781,7 @@ def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
 
     return line_loss
 
-def compute_arc_loss(feature_logits, proposals, gt_, pos_matched_idxs):
+def compute_mask_loss(feature_logits, proposals, gt_, pos_matched_idxs):
     print(f'compute_arc_loss:{feature_logits.shape}')
     N, K, H, W = feature_logits.shape
 
@@ -823,7 +823,7 @@ def compute_arc_loss(feature_logits, proposals, gt_, pos_matched_idxs):
         print(f'gt_kp_in_image:{gt_kp_in_image.shape}')
         if proposals_per_image.shape[0] > 0 and gt_kp_in_image.shape[0] > 0:
 
-            gs_heatmaps_per_img = arc_points_to_heatmap(kp, proposals_per_image, discretization_size)
+            gs_heatmaps_per_img = align_masks(kp, proposals_per_image, discretization_size)
             gs_heatmaps.append(gs_heatmaps_per_img)
 
     if len(gs_heatmaps)>0:
@@ -831,7 +831,7 @@ def compute_arc_loss(feature_logits, proposals, gt_, pos_matched_idxs):
         print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{feature_logits.shape}')
 
         line_logits = feature_logits.squeeze(1)
-        print(f'single_point_logits:{line_logits.shape}')
+        print(f'mask shape:{line_logits.shape}')
 
         # line_loss = F.binary_cross_entropy_with_logits(line_logits, gs_heatmaps)
 
@@ -846,7 +846,7 @@ def compute_arc_loss(feature_logits, proposals, gt_, pos_matched_idxs):
     return line_loss
 
 
-def arc_points_to_heatmap(keypoints, rois, heatmap_size):
+def align_masks(keypoints, rois, heatmap_size):
     print(f'rois:{rois.shape}')
     print(f'heatmap_size:{heatmap_size}')
 
@@ -863,24 +863,17 @@ def arc_points_to_heatmap(keypoints, rois, heatmap_size):
 
     num_points=x.shape[2]
     print(f'num_points:{num_points}')
-    gs = generate_mask_gaussian_heatmaps(x, y, num_points=num_points, heatmap_size=heatmap_size, sigma=10)
-    print(f'gs max :{gs.max()},gs.shape:{gs.shape}')
-    # show_heatmap(gs[0],'target')
-    all_roi_heatmap = []
-    for roi, heatmap in zip(rois, gs):
-        show_heatmap(heatmap, 'target')
-        print(f'heatmap.shape:{heatmap.shape}')
-        heatmap = heatmap.unsqueeze(0)
-        x1, y1, x2, y2 = map(int, roi)
-        roi_heatmap = torch.zeros_like(heatmap)
-        roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
-        # show_heatmap(roi_heatmap[0],'roi_heatmap')
-        all_roi_heatmap.append(roi_heatmap)
-
-    all_roi_heatmap = torch.cat(all_roi_heatmap)
-    print(f'all_roi_heatmap:{all_roi_heatmap.shape}')
-
-    return all_roi_heatmap
+    mask_4d = keypoints.unsqueeze(1).float()
+    resized_mask = F.interpolate(
+        mask_4d,
+    size = (heatmap_size, heatmap_size),
+    mode = 'bilinear',
+    align_corners = False
+    ).squeeze(1)  # [B,heatmap_size,heatmap_size]
+    # plt.imshow(resized_mask[0].cpu())
+    # plt.show()
+    print(f'resized_mask:{resized_mask.shape}')
+    return resized_mask
 
 def compute_point_loss(line_logits, proposals, gt_points, point_matched_idxs):
     # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor

+ 8 - 6
models/line_detect/loi_heads.py

@@ -13,7 +13,7 @@ import libs.vision_libs.models.detection._utils as det_utils
 from collections import OrderedDict
 
 from models.line_detect.heads.head_losses import point_inference, compute_point_loss, line_iou_loss, \
-    lines_point_pair_loss, features_align, line_inference, compute_arc_loss, arc_inference, compute_circle_loss, \
+    lines_point_pair_loss, features_align, line_inference, compute_mask_loss, arc_inference, compute_circle_loss, \
     circle_inference
 
 
@@ -135,6 +135,7 @@ def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs
 
 
 
+
 def keypoints_to_heatmap(keypoints, rois, heatmap_size):
     # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
     offset_x = rois[:, 0]
@@ -1253,7 +1254,7 @@ class RoIHeads(nn.Module):
                     # if gt_arcs_tensor.shape[0] > 0:
                     #     print(f'start to compute point_loss')
                     if len(gt_arcs) > 0 and feature_logits is not None:
-                        loss_arc = compute_arc_loss(feature_logits, arc_proposals, gt_arcs, arc_pos_matched_idxs)
+                        loss_arc = compute_mask_loss(feature_logits, arc_proposals, gt_arcs, arc_pos_matched_idxs)
 
                     if loss_arc is None:
                         print(f'loss_arc is None111')
@@ -1281,7 +1282,7 @@ class RoIHeads(nn.Module):
 
                         if len(gt_arcs) > 0 and feature_logits is not None:
                             print(f'start to compute arc_loss')
-                            loss_arc = compute_arc_loss(feature_logits, arc_proposals, gt_arcs, arc_pos_matched_idxs)
+                            loss_arc = compute_mask_loss(feature_logits, arc_proposals, gt_arcs, arc_pos_matched_idxs)
 
 
                         if loss_arc is None:
@@ -1377,6 +1378,7 @@ class RoIHeads(nn.Module):
                         raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
 
                     gt_circles = [t["circle_masks"] for t in targets if "circle_masks" in t]
+                    gt_labels = [t["labels"] for t in targets]
 
                     print(f'gt_circle:{gt_circles[0].shape}')
                     h, w = targets[0]["img_size"]
@@ -1390,7 +1392,7 @@ class RoIHeads(nn.Module):
                     if gt_circles_tensor.shape[0] > 0:
                         print(f'start to compute circle_loss')
 
-                        loss_circle = compute_arc_loss(feature_logits, circle_proposals, gt_circles, circle_pos_matched_idxs)
+                        loss_circle = compute_mask_loss(feature_logits, circle_proposals, gt_circles, circle_pos_matched_idxs)
 
                         # loss_circle_extra=compute_circle_extra_losses(feature_logits, circle_proposals, gt_circles, circle_pos_matched_idxs)
 
@@ -1410,7 +1412,7 @@ class RoIHeads(nn.Module):
                         h, w = targets[0]["img_size"]
                         img_size = h
                         gt_circles = [t["circle_masks"] for t in targets if "circle_masks" in t]
-
+                        gt_labels = [t["labels"] for t in targets]
                         gt_circles_tensor = torch.zeros(0, 0)
                         if len(gt_circles) > 0:
                             gt_circles_tensor = torch.cat(gt_circles)
@@ -1419,7 +1421,7 @@ class RoIHeads(nn.Module):
                         if gt_circles_tensor.shape[0] > 0:
                             print(f'start to compute circle_loss')
 
-                            loss_circle = maskrcnn_loss(feature_logits, circle_proposals, gt_circles,
+                            loss_circle = compute_mask_loss(feature_logits, circle_proposals, gt_circles,
                                                             circle_pos_matched_idxs)
 
                             # loss_circle_extra = compute_circle_extra_losses(feature_logits, circle_proposals, gt_circles,circle_pos_matched_idxs)