TestPointMap.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. def map_heatmap_keypoints_to_original_image(heatmap, rois, downsample_ratio=4, joff=None):
  2. """
  3. 将热力图中的关键点映射回原始图像的位置。
  4. 参数:
  5. heatmap (torch.Tensor): 热力图,形状为 [H, W]
  6. rois (list of tuples): 每个ROI的坐标列表 [(x_min, y_min, x_max, y_max), ...]
  7. downsample_ratio (int): 下采样比例,默认为4
  8. joff (torch.Tensor, optional): 偏移图,形状为 [2, H, W]
  9. 返回:
  10. list of tuples: 每个ROI对应的关键点在原始图像中的坐标 [(x, y), ...]
  11. """
  12. keypoints_in_original_image = []
  13. for i, (x_min, y_min, x_max, y_max) in enumerate(rois):
  14. roi_width = x_max - x_min
  15. roi_height = y_max - y_min
  16. # 获取热力图中的关键点位置
  17. heatmap_roi = heatmap[i] if len(heatmap.shape) == 4 else heatmap
  18. y_prime, x_prime = torch.where(heatmap_roi == torch.max(heatmap_roi))
  19. if len(y_prime) > 0 and len(x_prime) > 0:
  20. y_prime, x_prime = y_prime.item(), x_prime.item()
  21. # 如果有偏移图,则应用偏移修正
  22. if joff is not None:
  23. offset_x = joff[0, y_prime, x_prime].item()
  24. offset_y = joff[1, y_prime, x_prime].item()
  25. x_prime += offset_x
  26. y_prime += offset_y
  27. # 计算ROI内的相对坐标
  28. relative_x = x_prime / 128 * roi_width
  29. relative_y = y_prime / 128 * roi_height
  30. # 映射回原始图像坐标
  31. final_x = relative_x + x_min
  32. final_y = relative_y + y_min
  33. keypoints_in_original_image.append((final_x.item(), final_y.item()))
  34. else:
  35. keypoints_in_original_image.append(None) # 如果没有找到关键点
  36. return keypoints_in_original_image