def map_heatmap_keypoints_to_original_image(heatmap, rois, downsample_ratio=4, joff=None): """ 将热力图中的关键点映射回原始图像的位置。 参数: heatmap (torch.Tensor): 热力图,形状为 [H, W] rois (list of tuples): 每个ROI的坐标列表 [(x_min, y_min, x_max, y_max), ...] downsample_ratio (int): 下采样比例,默认为4 joff (torch.Tensor, optional): 偏移图,形状为 [2, H, W] 返回: list of tuples: 每个ROI对应的关键点在原始图像中的坐标 [(x, y), ...] """ keypoints_in_original_image = [] for i, (x_min, y_min, x_max, y_max) in enumerate(rois): roi_width = x_max - x_min roi_height = y_max - y_min # 获取热力图中的关键点位置 heatmap_roi = heatmap[i] if len(heatmap.shape) == 4 else heatmap y_prime, x_prime = torch.where(heatmap_roi == torch.max(heatmap_roi)) if len(y_prime) > 0 and len(x_prime) > 0: y_prime, x_prime = y_prime.item(), x_prime.item() # 如果有偏移图,则应用偏移修正 if joff is not None: offset_x = joff[0, y_prime, x_prime].item() offset_y = joff[1, y_prime, x_prime].item() x_prime += offset_x y_prime += offset_y # 计算ROI内的相对坐标 relative_x = x_prime / 128 * roi_width relative_y = y_prime / 128 * roi_height # 映射回原始图像坐标 final_x = relative_x + x_min final_y = relative_y + y_min keypoints_in_original_image.append((final_x.item(), final_y.item())) else: keypoints_in_original_image.append(None) # 如果没有找到关键点 return keypoints_in_original_image