import numpy as np
import cv2


def draw_line_heatmap(image_shape, pt1, pt2, sigma=1):
    """
    根据给定的两个端点生成线段的热度图。

    参数:
    - image_shape: (height, width) 输出热度图的形状
    - pt1: (x1, y1) 线段的第一个端点
    - pt2: (x2, y2) 线段的第二个端点
    - sigma: 高斯核的标准差,用于控制热度扩散的程度

    返回:
    - heatmap: 生成的热度图
    """
    # 创建空白热度图
    heatmap = np.zeros(image_shape, dtype=np.float32)

    # 绘制线段
    cv2.line(heatmap, pt1, pt2, color=1, thickness=1)

    # 应用高斯模糊以生成热度效果
    if sigma > 0:
        heatmap = cv2.GaussianBlur(heatmap, (0, 0), sigmaX=sigma, sigmaY=sigma)

    # 归一化热度图
    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-6)

    return heatmap


# 测试函数
if __name__ == "__main__":
    # 定义图像尺寸和线段端点
    image_shape = (256, 256)  # 图像的高度和宽度
    pt1 = (50, 50)  # 第一个端点
    pt2 = (200, 200)  # 第二个端点
    sigma = 2  # 控制热度扩散程度

    # 生成热度图
    heatmap = draw_line_heatmap(image_shape, pt1, pt2, sigma)

    # 显示结果
    import matplotlib.pyplot as plt

    plt.imshow(heatmap, cmap='hot', interpolation='nearest')
    plt.colorbar()
    plt.show()