12345678910111213141516171819202122232425262728293031323334353637383940414243444546 |
- def map_heatmap_keypoints_to_original_image(heatmap, rois, downsample_ratio=4, joff=None):
- """
- 将热力图中的关键点映射回原始图像的位置。
- 参数:
- heatmap (torch.Tensor): 热力图,形状为 [H, W]
- rois (list of tuples): 每个ROI的坐标列表 [(x_min, y_min, x_max, y_max), ...]
- downsample_ratio (int): 下采样比例,默认为4
- joff (torch.Tensor, optional): 偏移图,形状为 [2, H, W]
- 返回:
- list of tuples: 每个ROI对应的关键点在原始图像中的坐标 [(x, y), ...]
- """
- keypoints_in_original_image = []
- for i, (x_min, y_min, x_max, y_max) in enumerate(rois):
- roi_width = x_max - x_min
- roi_height = y_max - y_min
- # 获取热力图中的关键点位置
- heatmap_roi = heatmap[i] if len(heatmap.shape) == 4 else heatmap
- y_prime, x_prime = torch.where(heatmap_roi == torch.max(heatmap_roi))
- if len(y_prime) > 0 and len(x_prime) > 0:
- y_prime, x_prime = y_prime.item(), x_prime.item()
- # 如果有偏移图,则应用偏移修正
- if joff is not None:
- offset_x = joff[0, y_prime, x_prime].item()
- offset_y = joff[1, y_prime, x_prime].item()
- x_prime += offset_x
- y_prime += offset_y
- # 计算ROI内的相对坐标
- relative_x = x_prime / 128 * roi_width
- relative_y = y_prime / 128 * roi_height
- # 映射回原始图像坐标
- final_x = relative_x + x_min
- final_y = relative_y + y_min
- keypoints_in_original_image.append((final_x.item(), final_y.item()))
- else:
- keypoints_in_original_image.append(None) # 如果没有找到关键点
- return keypoints_in_original_image
|