浏览代码

修改合并特征图代码

RenLiqiang 5 月之前
父节点
当前提交
79435ef46d
共有 2 个文件被更改,包括 65 次插入3 次删除
  1. 46 0
      models/wirenet/TestPointMap.py
  2. 19 3
      models/wirenet/head.py

+ 46 - 0
models/wirenet/TestPointMap.py

@@ -0,0 +1,46 @@
+def map_heatmap_keypoints_to_original_image(heatmap, rois, downsample_ratio=4, joff=None):
+    """
+    将热力图中的关键点映射回原始图像的位置。
+
+    参数:
+    heatmap (torch.Tensor): 热力图,形状为 [H, W]
+    rois (list of tuples): 每个ROI的坐标列表 [(x_min, y_min, x_max, y_max), ...]
+    downsample_ratio (int): 下采样比例,默认为4
+    joff (torch.Tensor, optional): 偏移图,形状为 [2, H, W]
+
+    返回:
+    list of tuples: 每个ROI对应的关键点在原始图像中的坐标 [(x, y), ...]
+    """
+    keypoints_in_original_image = []
+
+    for i, (x_min, y_min, x_max, y_max) in enumerate(rois):
+        roi_width = x_max - x_min
+        roi_height = y_max - y_min
+
+        # 获取热力图中的关键点位置
+        heatmap_roi = heatmap[i] if len(heatmap.shape) == 4 else heatmap
+        y_prime, x_prime = torch.where(heatmap_roi == torch.max(heatmap_roi))
+
+        if len(y_prime) > 0 and len(x_prime) > 0:
+            y_prime, x_prime = y_prime.item(), x_prime.item()
+
+            # 如果有偏移图,则应用偏移修正
+            if joff is not None:
+                offset_x = joff[0, y_prime, x_prime].item()
+                offset_y = joff[1, y_prime, x_prime].item()
+                x_prime += offset_x
+                y_prime += offset_y
+
+            # 计算ROI内的相对坐标
+            relative_x = x_prime / 128 * roi_width
+            relative_y = y_prime / 128 * roi_height
+
+            # 映射回原始图像坐标
+            final_x = relative_x + x_min
+            final_y = relative_y + y_min
+
+            keypoints_in_original_image.append((final_x.item(), final_y.item()))
+        else:
+            keypoints_in_original_image.append(None)  # 如果没有找到关键点
+
+    return keypoints_in_original_image

+ 19 - 3
models/wirenet/head.py

@@ -1128,7 +1128,12 @@ class RoIHeads(nn.Module):
             # print(f'keypoint_features from roi_pool:{wirepoint_features.shape}')
             outputs, wirepoint_features = self.wirepoint_head(wirepoint_features)
 
+
+
             outputs = merge_features(outputs, wirepoint_proposals)
+
+
+
             wirepoint_features = merge_features(wirepoint_features, wirepoint_proposals)
 
             print(f'outpust:{outputs.shape}')
@@ -1204,6 +1209,7 @@ class RoIHeads(nn.Module):
 
 def merge_features(features, proposals):
     print(f'features:{features.shape}')
+    print(f'proposals:{len(proposals)}')
     def diagnose_input(features, proposals):
         """诊断输入数据"""
         print("Input Diagnostics:")
@@ -1225,11 +1231,19 @@ def merge_features(features, proposals):
                 f"Proposals count ({proposals_count}) must match features batch size ({features_size})"
             )
 
-    def safe_max_reduction(features_per_img):
+    def safe_max_reduction(features_per_img,proposals):
+
+        print(f'proposal:{proposals.shape},features_per_img:{features_per_img.shape}')
         """安全的最大值压缩"""
         if features_per_img.numel() == 0:
             return torch.zeros_like(features_per_img).unsqueeze(0)
 
+        for feature_map,roi in zip(features_per_img,proposals):
+            print(f'feature_map:{feature_map.shape},roi:{roi}')
+            roi_off_x=roi[0]
+            roi_off_y=roi[1]
+
+
         try:
             # 沿着第0维求最大值,保持维度
             max_features, _ = torch.max(features_per_img, dim=0, keepdim=True)
@@ -1257,8 +1271,10 @@ def merge_features(features, proposals):
 
         # 每张图像特征压缩
         features_imgs = []
-        for features_per_img in split_features:
-            compressed_features = safe_max_reduction(features_per_img)
+
+        print(f'split_features:{len(split_features)}')
+        for features_per_img,proposal in zip(split_features,proposals):
+            compressed_features = safe_max_reduction(features_per_img,proposal)
             features_imgs.append(compressed_features)
 
         # 合并特征