import math
import warnings
from fractions import Fraction
from typing import Dict, List, Optional, Tuple, Union

import torch

from ..extension import _load_library


# The "video_reader" extension is treated as optional: attempt to load it and
# record the outcome in _HAS_VIDEO_OPT instead of failing at import time.
try:
    _load_library("video_reader")
    _HAS_VIDEO_OPT = True
except (ImportError, OSError):
    # Loading raised ImportError or OSError: flag the extension as unavailable.
    _HAS_VIDEO_OPT = False

# Default for the ``video_timebase`` / ``audio_timebase`` arguments used below.
# NOTE(review): how the backend interprets a zero numerator is not visible here.
default_timebase = Fraction(0, 1)


# simple class for torch scripting
# the complex Fraction class from fractions module is not scriptable
class Timebase:
    """
    Plain holder for a rational number kept as an integer numerator and an
    integer denominator.

    It stands in for ``fractions.Fraction`` where TorchScript is involved,
    since ``Fraction`` is not scriptable. No normalisation or validation is
    performed on the two values.
    """

    # Attribute types are declared explicitly rather than via class-level annotations.
    __annotations__ = {"numerator": int, "denominator": int}
    # Only these two attributes can be set on an instance.
    __slots__ = ["numerator", "denominator"]

    def __init__(
        self,
        numerator: int,
        denominator: int,
    ) -> None:
        # Values are stored as given.
        self.numerator = numerator
        self.denominator = denominator


class VideoMetaData:
    """
    Container describing the video and audio streams of a media file.

    A freshly constructed instance reports no streams: both ``has_*`` flags are
    False, both timebases are ``Timebase(0, 1)`` and all numeric fields are 0.0.
    ``_fill_info`` overwrites the fields for which the decoder returned data.
    """

    # Explicit attribute types for each field.
    __annotations__ = {
        "has_video": bool,
        "video_timebase": Timebase,
        "video_duration": float,
        "video_fps": float,
        "has_audio": bool,
        "audio_timebase": Timebase,
        "audio_duration": float,
        "audio_sample_rate": float,
    }
    # Restrict instances to exactly the fields listed above.
    __slots__ = [
        "has_video",
        "video_timebase",
        "video_duration",
        "video_fps",
        "has_audio",
        "audio_timebase",
        "audio_duration",
        "audio_sample_rate",
    ]

    def __init__(self) -> None:
        # Video stream defaults ("no video stream found").
        self.has_video = False
        self.video_timebase = Timebase(0, 1)
        # Raw duration multiplied by the timebase fraction (presumably seconds).
        self.video_duration = 0.0
        self.video_fps = 0.0
        # Audio stream defaults ("no audio stream found").
        self.has_audio = False
        self.audio_timebase = Timebase(0, 1)
        # Raw duration multiplied by the timebase fraction (presumably seconds).
        self.audio_duration = 0.0
        self.audio_sample_rate = 0.0


def _validate_pts(pts_range: Tuple[int, int]) -> None:

    if pts_range[0] > pts_range[1] > 0:
        raise ValueError(
            f"Start pts should not be smaller than end pts, got start pts: {pts_range[0]} and end pts: {pts_range[1]}"
        )


def _fill_info(
    vtimebase: torch.Tensor,
    vfps: torch.Tensor,
    vduration: torch.Tensor,
    atimebase: torch.Tensor,
    asample_rate: torch.Tensor,
    aduration: torch.Tensor,
) -> VideoMetaData:
    """
    Assemble a VideoMetaData record from the tensors reported by the decoder.

    Any argument may be an empty tensor; the matching field then keeps the
    default assigned by ``VideoMetaData()``. A stream is flagged as present
    (``has_video`` / ``has_audio``) only when both its timebase and its
    duration are non-empty, and its duration is the raw duration multiplied by
    ``timebase[0] / timebase[1]``.
    """
    info = VideoMetaData()

    # --- video stream ---
    if vtimebase.numel() > 0:
        info.video_timebase = Timebase(int(vtimebase[0].item()), int(vtimebase[1].item()))
        video_tick = vtimebase[0].item() / float(vtimebase[1].item())
        if vduration.numel() > 0:
            info.has_video = True
            info.video_duration = float(vduration.item()) * video_tick
    if vfps.numel() > 0:
        info.video_fps = float(vfps.item())

    # --- audio stream ---
    if atimebase.numel() > 0:
        info.audio_timebase = Timebase(int(atimebase[0].item()), int(atimebase[1].item()))
        audio_tick = atimebase[0].item() / float(atimebase[1].item())
        if aduration.numel() > 0:
            info.has_audio = True
            info.audio_duration = float(aduration.item()) * audio_tick
    if asample_rate.numel() > 0:
        info.audio_sample_rate = float(asample_rate.item())

    return info


def _align_audio_frames(
    aframes: torch.Tensor, aframe_pts: torch.Tensor, audio_pts_range: Tuple[int, int]
) -> torch.Tensor:
    start, end = aframe_pts[0], aframe_pts[-1]
    num_samples = aframes.size(0)
    step_per_aframe = float(end - start + 1) / float(num_samples)
    s_idx = 0
    e_idx = num_samples
    if start < audio_pts_range[0]:
        s_idx = int((audio_pts_range[0] - start) / step_per_aframe)
    if audio_pts_range[1] != -1 and end > audio_pts_range[1]:
        e_idx = int((audio_pts_range[1] - end) / step_per_aframe)
    return aframes[s_idx:e_idx, :]


def _read_video_from_file(
    filename: str,
    seek_frame_margin: float = 0.25,
    read_video_stream: bool = True,
    video_width: int = 0,
    video_height: int = 0,
    video_min_dimension: int = 0,
    video_max_dimension: int = 0,
    video_pts_range: Tuple[int, int] = (0, -1),
    video_timebase: Fraction = default_timebase,
    read_audio_stream: bool = True,
    audio_samples: int = 0,
    audio_channels: int = 0,
    audio_pts_range: Tuple[int, int] = (0, -1),
    audio_timebase: Fraction = default_timebase,
) -> Tuple[torch.Tensor, torch.Tensor, VideoMetaData]:
    """
    Decode the video and audio streams of a file.

    Args:
        filename (str): path to the video file
        seek_frame_margin (float, optional): seeking a frame in the stream is imprecise, so
            when a start pts is specified the seek targets a pts that is earlier by
            ``seek_frame_margin`` seconds
        read_video_stream (bool, optional): whether to read the video stream
        video_width/video_height/video_min_dimension/video_max_dimension (int): together
            decide the size of the decoded frames. With the parameters that are not
            mentioned left at 0:

                - all four are 0: keep the original frame resolution
                - only video_min_dimension != 0: keep the aspect ratio and resize so that
                  the shorter edge equals video_min_dimension
                - only video_max_dimension != 0: keep the aspect ratio and resize so that
                  the longer edge equals video_max_dimension
                - video_min_dimension != 0 and video_max_dimension != 0: resize so that the
                  shorter edge equals video_min_dimension and the longer edge equals
                  video_max_dimension. The aspect ratio may not be preserved
                - only video_height != 0: keep the aspect ratio and resize so that the
                  frame height equals video_height
                - only video_width != 0: keep the aspect ratio and resize so that the
                  frame width equals video_width
                - video_width != 0 and video_height != 0: resize to exactly
                  video_width x video_height
        video_pts_range (Tuple[int, int], optional): start and end presentation timestamp
            of the video stream
        video_timebase (Fraction, optional): timebase of the video stream
        read_audio_stream (bool, optional): whether to read the audio stream
        audio_samples (int, optional): audio sampling rate
        audio_channels (int, optional): number of audio channels
        audio_pts_range (Tuple[int, int], optional): start and end presentation timestamp
            of the audio stream
        audio_timebase (Fraction, optional): timebase of the audio stream

    Returns:
        vframes (Tensor[T, H, W, C]): the `T` video frames
        aframes (Tensor[L, K]): the audio frames, where `L` is the number of points and
            `K` is the number of audio channels
        info (VideoMetaData): metadata for the video and audio streams

    Raises:
        ValueError: propagated from ``_validate_pts`` for an invalid pts range.
    """
    for pts_range in (video_pts_range, audio_pts_range):
        _validate_pts(pts_range)

    video_start_pts, video_end_pts = video_pts_range[0], video_pts_range[1]
    audio_start_pts, audio_end_pts = audio_pts_range[0], audio_pts_range[1]

    decoded = torch.ops.video_reader.read_video_from_file(
        filename,
        seek_frame_margin,
        0,  # getPtsOnly
        read_video_stream,
        video_width,
        video_height,
        video_min_dimension,
        video_max_dimension,
        video_start_pts,
        video_end_pts,
        video_timebase.numerator,
        video_timebase.denominator,
        read_audio_stream,
        audio_samples,
        audio_channels,
        audio_start_pts,
        audio_end_pts,
        audio_timebase.numerator,
        audio_timebase.denominator,
    )
    (
        vframes,
        _vframe_pts,
        vtimebase,
        vfps,
        vduration,
        aframes,
        aframe_pts,
        atimebase,
        asample_rate,
        aduration,
    ) = decoded

    metadata = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)

    # Audio was decoded: cut it down to the requested pts range.
    if aframes.numel() > 0:
        aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)

    return vframes, aframes, metadata


def _read_video_timestamps_from_file(filename: str) -> Tuple[List[int], List[int], VideoMetaData]:
    """
    Collect the presentation timestamps (pts) of every video and audio frame
    in a file.

    The decoder is invoked in pts-only mode, so no frame pixel data is copied,
    which makes this much faster than read_video(...).

    Returns:
        The video pts list, the audio pts list and the stream metadata.
    """
    decoded = torch.ops.video_reader.read_video_from_file(
        filename,
        0,  # seek_frame_margin
        1,  # getPtsOnly
        1,  # read_video_stream
        0,  # video_width
        0,  # video_height
        0,  # video_min_dimension
        0,  # video_max_dimension
        0,  # video_start_pts
        -1,  # video_end_pts
        0,  # video_timebase_num
        1,  # video_timebase_den
        1,  # read_audio_stream
        0,  # audio_samples
        0,  # audio_channels
        0,  # audio_start_pts
        -1,  # audio_end_pts
        0,  # audio_timebase_num
        1,  # audio_timebase_den
    )
    (
        _vframes,
        video_pts,
        vtimebase,
        vfps,
        vduration,
        _aframes,
        audio_pts,
        atimebase,
        asample_rate,
        aduration,
    ) = decoded

    metadata = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)

    # Hand the timestamps back as plain Python lists.
    video_pts_list = video_pts.numpy().tolist()
    audio_pts_list = audio_pts.numpy().tolist()
    return video_pts_list, audio_pts_list, metadata


def _probe_video_from_file(filename: str) -> VideoMetaData:
    """
    Probe the file at ``filename`` and describe its streams as a VideoMetaData.
    """
    probed = torch.ops.video_reader.probe_video_from_file(filename)
    video_tb, video_fps, video_dur, audio_tb, audio_rate, audio_dur = probed
    return _fill_info(video_tb, video_fps, video_dur, audio_tb, audio_rate, audio_dur)


def _read_video_from_memory(
    video_data: torch.Tensor,
    seek_frame_margin: float = 0.25,
    read_video_stream: int = 1,
    video_width: int = 0,
    video_height: int = 0,
    video_min_dimension: int = 0,
    video_max_dimension: int = 0,
    video_pts_range: Tuple[int, int] = (0, -1),
    video_timebase_numerator: int = 0,
    video_timebase_denominator: int = 1,
    read_audio_stream: int = 1,
    audio_samples: int = 0,
    audio_channels: int = 0,
    audio_pts_range: Tuple[int, int] = (0, -1),
    audio_timebase_numerator: int = 0,
    audio_timebase_denominator: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Decode the video and audio streams of a video held in memory.
    This function is torchscriptable.

    Args:
        video_data (torch.Tensor or python bytes): compressed video content. Anything
            that is not already a tensor is wrapped with ``torch.frombuffer`` as a
            ``torch.uint8`` tensor
        seek_frame_margin (float, optional): seeking a frame in the stream is imprecise, so
            when a start pts is specified the seek targets a pts that is earlier by
            ``seek_frame_margin`` seconds
        read_video_stream (int, optional): whether to read the video stream. 1 to read it,
            0 otherwise
        video_width/video_height/video_min_dimension/video_max_dimension (int): together
            decide the size of the decoded frames. With the parameters that are not
            mentioned left at 0:

                - all four are 0: keep the original frame resolution
                - only video_min_dimension != 0: keep the aspect ratio and resize so that
                  the shorter edge equals video_min_dimension
                - only video_max_dimension != 0: keep the aspect ratio and resize so that
                  the longer edge equals video_max_dimension
                - video_min_dimension != 0 and video_max_dimension != 0: resize so that the
                  shorter edge equals video_min_dimension and the longer edge equals
                  video_max_dimension. The aspect ratio may not be preserved
                - only video_height != 0: keep the aspect ratio and resize so that the
                  frame height equals video_height
                - only video_width != 0: keep the aspect ratio and resize so that the
                  frame width equals video_width
                - video_width != 0 and video_height != 0: resize to exactly
                  video_width x video_height
        video_pts_range (Tuple[int, int], optional): start and end presentation timestamp
            of the video stream
        video_timebase_numerator / video_timebase_denominator (int, optional): the rational
            number which denotes the timebase of the video stream
        read_audio_stream (int, optional): whether to read the audio stream. 1 to read it,
            0 otherwise
        audio_samples (int, optional): audio sampling rate
        audio_channels (int, optional): number of audio channels
        audio_pts_range (Tuple[int, int], optional): start and end presentation timestamp
            of the audio stream
        audio_timebase_numerator / audio_timebase_denominator (int, optional): the rational
            number which denotes the timebase of the audio stream

    Returns:
        vframes (Tensor[T, H, W, C]): the `T` video frames
        aframes (Tensor[L, K]): the audio frames, where `L` is the number of points and
            `K` is the number of channels

    Raises:
        ValueError: propagated from ``_validate_pts`` for an invalid pts range.
    """
    _validate_pts(video_pts_range)
    _validate_pts(audio_pts_range)

    if not isinstance(video_data, torch.Tensor):
        with warnings.catch_warnings():
            # The buffer is only read here, so the "not writable" warning is just noise.
            warnings.filterwarnings("ignore", message="The given buffer is not writable")
            video_data = torch.frombuffer(video_data, dtype=torch.uint8)

    video_start_pts = video_pts_range[0]
    video_end_pts = video_pts_range[1]
    audio_start_pts = audio_pts_range[0]
    audio_end_pts = audio_pts_range[1]

    decoded = torch.ops.video_reader.read_video_from_memory(
        video_data,
        seek_frame_margin,
        0,  # getPtsOnly
        read_video_stream,
        video_width,
        video_height,
        video_min_dimension,
        video_max_dimension,
        video_start_pts,
        video_end_pts,
        video_timebase_numerator,
        video_timebase_denominator,
        read_audio_stream,
        audio_samples,
        audio_channels,
        audio_start_pts,
        audio_end_pts,
        audio_timebase_numerator,
        audio_timebase_denominator,
    )

    # Only the frames and the audio pts are needed; the metadata tensors are ignored.
    vframes, _vframe_pts, _vtimebase, _vfps, _vduration, aframes, aframe_pts, _atimebase, _asample_rate, _aduration = (
        decoded
    )

    # Audio was decoded: cut it down to the requested pts range.
    if aframes.numel() > 0:
        aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)

    return vframes, aframes


def _read_video_timestamps_from_memory(
    video_data: torch.Tensor,
) -> Tuple[List[int], List[int], VideoMetaData]:
    """
    Collect the presentation timestamps (pts) of every frame of a video held
    in memory.

    The decoder is invoked in pts-only mode, so no frame pixel data is copied,
    which makes read_video_timestamps(...) much faster than read_video(...).

    Returns:
        The video pts list, the audio pts list and the stream metadata.
    """
    if not isinstance(video_data, torch.Tensor):
        with warnings.catch_warnings():
            # The buffer is only read here, so the "not writable" warning is just noise.
            warnings.filterwarnings("ignore", message="The given buffer is not writable")
            video_data = torch.frombuffer(video_data, dtype=torch.uint8)

    decoded = torch.ops.video_reader.read_video_from_memory(
        video_data,
        0,  # seek_frame_margin
        1,  # getPtsOnly
        1,  # read_video_stream
        0,  # video_width
        0,  # video_height
        0,  # video_min_dimension
        0,  # video_max_dimension
        0,  # video_start_pts
        -1,  # video_end_pts
        0,  # video_timebase_num
        1,  # video_timebase_den
        1,  # read_audio_stream
        0,  # audio_samples
        0,  # audio_channels
        0,  # audio_start_pts
        -1,  # audio_end_pts
        0,  # audio_timebase_num
        1,  # audio_timebase_den
    )
    (
        _vframes,
        video_pts,
        vtimebase,
        vfps,
        vduration,
        _aframes,
        audio_pts,
        atimebase,
        asample_rate,
        aduration,
    ) = decoded

    metadata = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)

    # Hand the timestamps back as plain Python lists.
    video_pts_list = video_pts.numpy().tolist()
    audio_pts_list = audio_pts.numpy().tolist()
    return video_pts_list, audio_pts_list, metadata


def _probe_video_from_memory(
    video_data: torch.Tensor,
) -> VideoMetaData:
    """
    Probe a video held in memory and describe its streams as a VideoMetaData.
    This function is torchscriptable
    """
    if not isinstance(video_data, torch.Tensor):
        with warnings.catch_warnings():
            # The buffer is only read here, so the "not writable" warning is just noise.
            warnings.filterwarnings("ignore", message="The given buffer is not writable")
            video_data = torch.frombuffer(video_data, dtype=torch.uint8)

    probed = torch.ops.video_reader.probe_video_from_memory(video_data)
    video_tb, video_fps, video_dur, audio_tb, audio_rate, audio_dur = probed
    return _fill_info(video_tb, video_fps, video_dur, audio_tb, audio_rate, audio_dur)


def _read_video(
    filename: str,
    start_pts: Union[float, Fraction] = 0,
    end_pts: Optional[Union[float, Fraction]] = None,
    pts_unit: str = "pts",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]:
    """
    Read the video and audio frames of ``filename`` between two timestamps.

    Args:
        filename (str): path to the video file
        start_pts (float or Fraction, optional): where reading starts
        end_pts (float or Fraction, optional): where reading stops; ``None``
            means no upper bound
        pts_unit (str, optional): "sec" converts ``start_pts`` / ``end_pts``
            to stream pts using each stream's timebase; any other value passes
            them through unchanged. "pts" additionally emits a warning.

    Returns:
        The video frames, the audio frames, and a dict that holds "video_fps"
        when a video stream was found and "audio_fps" when an audio stream was
        found.
    """
    unbounded = float("inf")
    if end_pts is None:
        end_pts = unbounded

    if pts_unit == "pts":
        warnings.warn(
            "The pts_unit 'pts' gives wrong results and will be removed in a "
            "follow-up version. Please use pts_unit 'sec'."
        )

    # Probe first: the stream timebases are needed to convert seconds into pts.
    probed = _probe_video_from_file(filename)
    has_video, has_audio = probed.has_video, probed.has_audio

    def get_pts(time_base):
        # Express the requested window in the pts units of one stream.
        if pts_unit == "sec":
            ticks_per_second = 1 / time_base
            first = int(math.floor(start_pts * ticks_per_second))
            if end_pts != unbounded:
                last = int(math.ceil(end_pts * ticks_per_second))
            else:
                last = end_pts
        else:
            first, last = start_pts, end_pts
        # The reader expects -1 for "read until the end".
        if last == unbounded:
            last = -1
        return first, last

    video_timebase = default_timebase
    video_pts_range = (0, -1)
    if has_video:
        video_timebase = Fraction(probed.video_timebase.numerator, probed.video_timebase.denominator)
        video_pts_range = get_pts(video_timebase)

    audio_timebase = default_timebase
    audio_pts_range = (0, -1)
    if has_audio:
        audio_timebase = Fraction(probed.audio_timebase.numerator, probed.audio_timebase.denominator)
        audio_pts_range = get_pts(audio_timebase)

    vframes, aframes, metadata = _read_video_from_file(
        filename,
        read_video_stream=True,
        video_pts_range=video_pts_range,
        video_timebase=video_timebase,
        read_audio_stream=True,
        audio_pts_range=audio_pts_range,
        audio_timebase=audio_timebase,
    )

    # Only report rates for the streams that the probe found.
    summary = {}
    if has_video:
        summary["video_fps"] = metadata.video_fps
    if has_audio:
        summary["audio_fps"] = metadata.audio_sample_rate

    return vframes, aframes, summary


def _read_video_timestamps(
    filename: str, pts_unit: str = "pts"
) -> Tuple[Union[List[int], List[Fraction]], Optional[float]]:
    """
    List the timestamps of the video frames in ``filename``.

    Args:
        filename (str): path to the video file
        pts_unit (str, optional): "sec" multiplies every pts by the video
            timebase (yielding ``Fraction`` values); any other value returns
            the pts unchanged. "pts" additionally emits a warning.

    Returns:
        The list of timestamps and the video fps, or ``None`` for the fps when
        no video stream was found.
    """
    if pts_unit == "pts":
        warnings.warn(
            "The pts_unit 'pts' gives wrong results and will be removed in a "
            "follow-up version. Please use pts_unit 'sec'."
        )

    raw_pts, _, metadata = _read_video_timestamps_from_file(filename)

    timestamps: Union[List[int], List[Fraction]] = raw_pts
    if pts_unit == "sec":
        tick = Fraction(metadata.video_timebase.numerator, metadata.video_timebase.denominator)
        timestamps = [value * tick for value in raw_pts]

    if metadata.has_video:
        fps = metadata.video_fps
    else:
        fps = None

    return timestamps, fps