Browse Source

修复single point 映射ROI失败bug

RenLiqiang 5 months ago
parent
commit
fcf18376b5
2 changed files with 14 additions and 3 deletions
  1. 9 0
      libs/vision_libs/models/detection/transform.py
  2. 5 3
      models/line_detect/loi_heads.py

+ 9 - 0
libs/vision_libs/models/detection/transform.py

@@ -206,6 +206,11 @@ class GeneralizedRCNNTransform(nn.Module):
             lines = target["lines"]
             lines = resize_keypoints(lines, (h, w), image.shape[-2:])
             target["lines"] = lines
+
+        if "points" in target:
+            points = target["points"]
+            points = resize_keypoints(points, (h, w), image.shape[-2:])
+            target["points"] = points
         return image, target
 
     # _onnx_batch_images() is an implementation of
@@ -284,6 +289,10 @@ class GeneralizedRCNNTransform(nn.Module):
                 keypoints = pred["lines"]
                 keypoints = resize_keypoints(keypoints, im_s, o_im_s)
                 result[i]["lines"] = keypoints
+            if "points" in pred:
+                keypoints = pred["points"]
+                keypoints = resize_keypoints(keypoints, im_s, o_im_s)
+                result[i]["points"] = keypoints
         return result
 
     def __repr__(self) -> str:

+ 5 - 3
models/line_detect/loi_heads.py

@@ -201,16 +201,17 @@ def single_point_to_heatmap(keypoints, rois, heatmap_size):
     y = keypoints[..., 1].unsqueeze(1)
 
 
-    gs = generate_gaussian_heatmaps(x, y,num_points=1, heatmap_size=heatmap_size, sigma=1.0)
+    gs = generate_gaussian_heatmaps(x, y,num_points=1, heatmap_size=heatmap_size, sigma=2.0)
     # show_heatmap(gs[0],'target')
     all_roi_heatmap = []
     for roi, heatmap in zip(rois, gs):
+        # show_heatmap(heatmap, 'target')
         # print(f'heatmap:{heatmap.shape}')
         heatmap = heatmap.unsqueeze(0)
         x1, y1, x2, y2 = map(int, roi)
         roi_heatmap = torch.zeros_like(heatmap)
         roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
-        # show_heatmap(roi_heatmap,'roi_heatmap')
+        # show_heatmap(roi_heatmap[0],'roi_heatmap')
         all_roi_heatmap.append(roi_heatmap)
 
     all_roi_heatmap = torch.cat(all_roi_heatmap)
@@ -327,7 +328,7 @@ def generate_gaussian_heatmaps(xs, ys, heatmap_size,num_points=2, sigma=2.0, dev
     assert xs.shape == ys.shape, "x and y must have the same shape"
     print(f'xs:{xs.shape}')
     N = xs.shape[0]
-    print(f'N:{N}')
+    print(f'N:{N},num_points:{num_points}')
 
     # 创建网格
     grid_y, grid_x = torch.meshgrid(
@@ -345,6 +346,7 @@ def generate_gaussian_heatmaps(xs, ys, heatmap_size,num_points=2, sigma=2.0, dev
         for j in range(num_points):
             mu_x1 = xs[i, j].clamp(0, heatmap_size - 1).item()
             mu_y1 = ys[i, j].clamp(0, heatmap_size - 1).item()
+            # print(f'mu_x1,mu_y1:{mu_x1},{mu_y1}')
 
             # 计算距离平方
             dist1 = (grid_x - mu_x1) ** 2 + (grid_y - mu_y1) ** 2