Browse Source

修改获取lines points 的逻辑

RenLiqiang 5 months ago
parent
commit
9cabf6b25b
2 changed files with 68 additions and 1 deletions
  1. 67 1
      models/line_detect/roi_heads.py
  2. 1 0
      models/line_net/line_predictor.py

+ 67 - 1
models/line_detect/roi_heads.py

@@ -429,6 +429,72 @@ def heatmaps_to_keypoints(maps, rois):
         end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
         end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
 
 
     return xy_preds.permute(0, 2, 1), end_scores
     return xy_preds.permute(0, 2, 1), end_scores
+def non_maximum_suppression(a):
+    """Keep only the entries of ``a`` that are maxima of their 3x3 window.
+
+    Every entry is compared with the max-pooled value of its 3x3
+    neighbourhood (stride 1, padding 1, so the output keeps the input's
+    spatial size). Entries equal to that maximum are returned unchanged;
+    all other entries are multiplied by zero.
+
+    Note that suppressed entries become 0 rather than a very small value,
+    so for inputs containing negative numbers a suppressed entry compares
+    greater than a surviving negative maximum.
+    """
+    window_max = F.max_pool2d(a, kernel_size=3, stride=1, padding=1)
+    # 1.0 where the entry is the maximum of its window, 0.0 elsewhere.
+    is_local_max = (a == window_max).float()
+    return a * is_local_max
+
+def heatmaps_to_lines(maps, rois):
+    """Pick the two endpoints of one line per ROI from its heatmap.
+
+    For every ROI the heatmap is resized (bicubic) to the ROI's size in
+    pixels, the local maxima of each 3x3 neighbourhood are located, and the
+    two strongest maxima are reported as the line's endpoints in image
+    coordinates.
+
+    Only channel 0 of ``maps`` is read; the line heatmap is expected to be
+    single-channel.
+
+    Args:
+        maps: heatmaps indexed as (roi, channel, height, width).
+        rois: one row per ROI whose first four columns are
+            (x1, y1, x2, y2).
+
+    Returns:
+        A tuple ``(points, end_scores)``:
+            points: float32 tensor of shape (#rois, 2, 3); each of the two
+                rows per ROI is (x, y, 1).
+            end_scores: float32 tensor of shape (#rois, 2) holding the
+                resized heatmap value at each returned endpoint.
+    """
+    # Discrete heatmap coordinates are converted to continuous image
+    # coordinates with the convention from Heckbert 1990: c = d + 0.5,
+    # where d is a discrete coordinate and c is a continuous coordinate.
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+
+    widths = (rois[:, 2] - rois[:, 0]).clamp(min=1)
+    heights = (rois[:, 3] - rois[:, 1]).clamp(min=1)
+    widths_ceil = widths.ceil()
+    heights_ceil = heights.ceil()
+
+    xy_preds = torch.zeros((len(rois), 3, 2), dtype=torch.float32, device=maps.device)
+    end_scores = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
+
+    for i in range(len(rois)):
+        roi_map_width = int(widths_ceil[i].item())
+        roi_map_height = int(heights_ceil[i].item())
+        width_correction = widths[i] / roi_map_width
+        height_correction = heights[i] / roi_map_height
+        roi_map = F.interpolate(
+            maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
+        )[:, 0]
+
+        # Work on the single line channel so that flat indices map back to
+        # valid (row, column) positions of this ROI.
+        line_map = roi_map[0]
+        window_max = F.max_pool2d(line_map[None, None], 3, stride=1, padding=1)[0, 0]
+        flat_map = line_map.reshape(-1)
+        flat_peaks = (line_map == window_max).reshape(-1)
+
+        # Rank only local maxima. Non-peaks are pushed to -inf instead of 0
+        # so that they can never outrank a peak whose value is negative.
+        # With fewer than two peaks, fall back to the raw map so the second
+        # endpoint is still the next-strongest location.
+        if int(flat_peaks.sum()) >= 2:
+            candidates = flat_map.masked_fill(~flat_peaks, float("-inf"))
+        else:
+            candidates = flat_map
+
+        # A 1x1 ROI map has a single cell; topk(k=2) would raise there.
+        k = min(2, flat_map.numel())
+        pos = torch.topk(candidates, k=k).indices
+        if k < 2:
+            pos = pos.expand(2)
+
+        x_int = pos % roi_map_width
+        y_int = torch.div(pos, roi_map_width, rounding_mode="floor")
+
+        x = (x_int.float() + 0.5) * width_correction
+        y = (y_int.float() + 0.5) * height_correction
+        xy_preds[i, 0, :] = x + offset_x[i]
+        xy_preds[i, 1, :] = y + offset_y[i]
+        xy_preds[i, 2, :] = 1
+        end_scores[i, :] = line_map[y_int, x_int]
+
+    return xy_preds.permute(0, 2, 1), end_scores
+
 def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
 def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
     # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
     # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
     N, K, H, W = line_logits.shape
     N, K, H, W = line_logits.shape
@@ -487,7 +553,7 @@ def line_inference(x, boxes):
     x2 = x.split(boxes_per_image, dim=0)
     x2 = x.split(boxes_per_image, dim=0)
 
 
     for xx, bb in zip(x2, boxes):
     for xx, bb in zip(x2, boxes):
-        kp_prob, scores = heatmaps_to_keypoints(xx, bb)
+        kp_prob, scores = heatmaps_to_lines(xx, bb)
         kp_probs.append(kp_prob)
         kp_probs.append(kp_prob)
         kp_scores.append(scores)
         kp_scores.append(scores)
 
 

+ 1 - 0
models/line_net/line_predictor.py

@@ -265,6 +265,7 @@ class LineRCNNPredictor(nn.Module):
 
 
             n_type = jmap.shape[0]
             n_type = jmap.shape[0]
             jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
             jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
+
             joff = joff.reshape(n_type, 2, -1)
             joff = joff.reshape(n_type, 2, -1)
             max_K = self.n_dyn_junc // n_type
             max_K = self.n_dyn_junc // n_type
             N = len(junc)
             N = len(junc)