import math
import os
import sys
from datetime import datetime

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights

from models.wirenet.postprocess import postprocess_keypoint
from torchvision.utils import draw_bounding_boxes
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl
from tools.coco_utils import get_coco_api_from_dataset
from tools.coco_eval import CocoEvaluator
import time

from models.config.config_tool import read_yaml
from models.ins.maskrcnn_dataset import MaskRCNNDataset
from models.keypoint.keypoint_dataset import KeypointDataset
from tools import utils, presets


def log_losses_to_tensorboard(writer, result, step):
    """Write the individual loss terms of one training step to TensorBoard.

    Args:
        writer: ``SummaryWriter`` (anything exposing ``add_scalar``).
        result: loss dict returned by the model in train mode; each value must
            support ``.item()``. Required keys: ``loss_classifier``,
            ``loss_box_reg``, ``loss_keypoint``, ``loss_objectness``,
            ``loss_rpn_box_reg``.
        step: global step used as the x-axis value.
    """
    tag_for_key = (
        ('Loss/classifier', 'loss_classifier'),
        ('Loss/box_reg', 'loss_box_reg'),
        ('Loss/keypoint', 'loss_keypoint'),
        ('Loss/objectness', 'loss_objectness'),
        ('Loss/rpn_box_reg', 'loss_rpn_box_reg'),
    )
    for tag, key in tag_for_key:
        writer.add_scalar(tag, result[key].item(), step)


def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, writer, scaler=None):
    """Run one training epoch.

    During epoch 0 a per-iteration linear learning-rate warm-up is applied
    (start factor 1/1000, at most 1000 iterations).

    Args:
        model: detection model; called as ``model(images, targets)`` in train
            mode and expected to return a dict of loss tensors.
        optimizer: optimizer over ``model``'s parameters.
        data_loader: yields ``(images, targets)`` batches; tensor values in each
            target dict are moved to ``device``.
        device: device every batch is moved to.
        epoch: zero-based epoch index (selects warm-up, offsets the global step).
        print_freq: console logging interval passed to ``MetricLogger.log_every``.
        writer: TensorBoard ``SummaryWriter`` receiving per-step loss terms.
        scaler: optional AMP ``GradScaler``; mixed precision is enabled iff given.

    Returns:
        ``(metric_logger, total_train_loss)`` where ``total_train_loss`` is the
        sum of this process's (not cross-GPU reduced) per-batch total losses.

    Exits the process with status 1 if the reduced loss becomes non-finite.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = f"Epoch: [{epoch}]"
    num_batches = len(data_loader)

    lr_scheduler = None
    if epoch == 0 and num_batches > 0:
        warmup_factor = 1.0 / 1000
        # LinearLR needs total_iters >= 1. With total_iters == 0 (a single-batch
        # loader) it scales the LR by warmup_factor on construction and never
        # scales it back, leaving the optimizer at lr/1000 for the whole run.
        warmup_iters = max(1, min(1000, num_batches - 1))
        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
        )

    total_train_loss = 0
    for batch_idx, (images, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        global_step = epoch * num_batches + batch_idx
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]

        with torch.cuda.amp.autocast(enabled=scaler is not None):
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

        # Bookkeeping only; no need to run it under autocast.
        total_train_loss += losses.item()
        log_losses_to_tensorboard(writer, loss_dict, global_step)

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        loss_value = losses_reduced.item()

        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            losses.backward()
            optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    return metric_logger, total_train_loss


# Shared colour scale: scores in [0.4, 1.0] are normalised onto the "jet" colormap.
cmap = plt.get_cmap("jet")
norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
# Attach an empty data array; the mappable is only used as a colour lookup here.
sm.set_array([])


def c(x):
    """Return the RGBA colour that the shared scale ``sm`` assigns to ``x``."""
    rgba = sm.to_rgba(x)
    return rgba


def show_line(img, pred, epoch, writer, score_threshold=0.5):
    """Log a visual summary of one prediction to TensorBoard.

    Three images are written for ``epoch``:
      * ``"ori"``: the input image;
      * ``"boxes"``: the input with the predicted boxes drawn in yellow;
      * ``"output"``: a Matplotlib rendering of the predicted line segments
        (after ``postprocess_keypoint`` de-duplication) whose score is at least
        ``score_threshold``, drawn over the input image.

    Args:
        img: ``(3, H, W)`` image tensor. It is multiplied by 255 and cast to
            uint8 for box drawing, so values are presumably in ``[0, 1]``
            (TODO confirm).
        pred: single-image model output with ``"boxes"``, ``"keypoints"`` and
            ``"keypoints_scores"`` entries.
        epoch: step value used for all three TensorBoard images.
        writer: TensorBoard ``SummaryWriter``.
        score_threshold: lines scoring below this are not drawn (default 0.5,
            the previously hard-coded value).
    """
    im = img.permute(1, 2, 0)  # (H, W, 3)
    writer.add_image("ori", im, epoch, dataformats="HWC")

    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred["boxes"],
                                      colors="yellow", width=1)
    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")

    point_opts = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
    lines = pred["keypoints"].detach().cpu().numpy()
    scores = pred["keypoints_scores"].detach().cpu().numpy()

    # Remove overlapping lines; the distance threshold is 1% of the image diagonal.
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    nlines, nscores = postprocess_keypoint(lines[:, :, :2], scores, diag * 0.01, 0, False)

    # Use a dedicated figure instead of pyplot's "current" one, and always close it.
    fig, ax = plt.subplots()
    try:
        ax.set_axis_off()
        fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        ax.margins(0, 0)
        for (a, b), s in zip(nlines, nscores):
            if s < score_threshold:
                continue
            # a[0] is plotted on the horizontal axis, a[1] on the vertical axis.
            ax.plot([a[0], b[0]], [a[1], b[1]], c='red', linewidth=2, zorder=s)
            ax.scatter(a[0], a[1], **point_opts)
            ax.scatter(b[0], b[1], **point_opts)
        ax.xaxis.set_major_locator(plt.NullLocator())
        ax.yaxis.set_major_locator(plt.NullLocator())
        ax.imshow(im.cpu())
        fig.tight_layout()
        fig.canvas.draw()
        # tostring_rgb() was deprecated in Matplotlib 3.8 and removed in 3.10;
        # buffer_rgba() also carries the true (DPI-scaled) buffer shape, so no
        # manual reshape from get_width_height() is needed.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        image_from_plot = np.ascontiguousarray(rgba[..., :3])
    finally:
        plt.close(fig)

    img2 = transforms.ToTensor()(image_from_plot)
    writer.add_image("output", img2, epoch)



def _get_iou_types(model):
    """Return the COCO IoU types applicable to ``model``.

    ``"bbox"`` is always present; ``"segm"`` is appended for a Mask R-CNN and
    ``"keypoints"`` for a Keypoint R-CNN. A ``DistributedDataParallel`` wrapper
    is unwrapped before the type checks.
    """
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        core = model.module
    else:
        core = model
    detection = torchvision.models.detection
    candidates = ((detection.MaskRCNN, "segm"), (detection.KeypointRCNN, "keypoints"))
    return ["bbox"] + [kind for cls, kind in candidates if isinstance(core, cls)]


def evaluate(model, data_loader, epoch, writer, device):
    """Run the model over ``data_loader`` and visualise the first prediction.

    Only the first image of the first batch is logged, via :func:`show_line`.
    COCO metric computation is disabled in this script, so nothing is returned.
    The COCO API and ``CocoEvaluator`` are therefore no longer built here; they
    were constructed every call but never used.

    Args:
        model: detection model; switched to eval mode here (callers must switch
            it back to train mode themselves; ``train_one_epoch`` does).
        data_loader: yields ``(images, targets)``; targets are ignored.
        epoch: step value for the TensorBoard images.
        writer: TensorBoard ``SummaryWriter``.
        device: device the images are moved to.
    """
    n_threads = torch.get_num_threads()
    # FIXME remove this and make paste_masks_in_image run on the GPU
    torch.set_num_threads(1)
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = "Test:"

    print('start to evaluate!!!')
    try:
        # Inference only: without no_grad every forward pass would build an
        # autograd graph and keep its activations alive.
        with torch.no_grad():
            for batch_idx, (images, targets) in enumerate(metric_logger.log_every(data_loader, 10, header)):
                images = [img.to(device) for img in images]
                outputs = model(images)
                if batch_idx == 0:
                    show_line(images[0], outputs[0], epoch, writer)
    finally:
        # The thread count is process-wide; restore it so subsequent training
        # epochs do not keep running with a single intra-op thread.
        torch.set_num_threads(n_threads)


def train_cfg(model, cfg):
    """Load hyper-parameters from the YAML file ``cfg`` and pass them to :func:`train`.

    Every key returned by ``read_yaml`` is forwarded to ``train`` as a keyword
    argument, so keys ``train`` does not recognise make it raise ``ValueError``.
    """
    params = read_yaml(cfg)
    print(f'train parameters:{params}')
    train(model, **params)


def train(model, **kwargs):
    """Train ``model`` on a ``KeypointDataset`` and log progress to TensorBoard.

    Outputs go to ``train_results/<YYYYmmdd_HHMMSS>/``:
      * ``weights/last.pt``: state dict saved after every epoch;
      * ``weights/best.pt``: state dict of the epoch with the lowest mean
        training loss so far;
      * ``logs/``: TensorBoard event files;
      * ``img_results/``: created but not written to by this function.

    Keyword arguments override the defaults below. ``num_keypoints``, ``opt``,
    ``momentum``, ``lr_step_size``, ``lr_gamma``, ``enable_logs`` and
    ``checkpoint`` are accepted but currently unused. The optimizer is always
    AdamW with no step LR schedule.

    Raises:
        ValueError: on an unknown keyword argument, or if the training loader
            yields no batches (fewer samples than ``batch_size`` with
            ``drop_last=True``).
    """
    # Default parameters
    default_params = {
        'dataset_path': '/path/to/dataset',
        'num_classes': 2,
        'num_keypoints': 2,
        'opt': 'adamw',
        'batch_size': 2,
        'epochs': 10,
        'lr': 0.005,
        'momentum': 0.9,
        'weight_decay': 1e-4,
        'lr_step_size': 3,
        'lr_gamma': 0.1,
        'num_workers': 4,
        'print_freq': 10,
        'target_type': 'polygon',
        'enable_logs': True,
        'augmentation': False,
        'checkpoint': None
    }
    # Apply caller overrides, rejecting unknown names
    for key, value in kwargs.items():
        if key in default_params:
            default_params[key] = value
        else:
            raise ValueError(f"Unknown argument: {key}")

    # Unpack the parameters that are actually used
    dataset_path = default_params['dataset_path']
    batch_size = default_params['batch_size']
    epochs = default_params['epochs']
    lr = default_params['lr']
    weight_decay = default_params['weight_decay']
    num_workers = default_params['num_workers']
    print_freq = default_params['print_freq']
    target_type = default_params['target_type']
    augmentation = default_params['augmentation']

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    train_transforms = None
    if augmentation:
        train_transforms = get_transform(is_train=True)
        print(f'transforms:{train_transforms}')

    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    dataset = KeypointDataset(dataset_path=dataset_path, transforms=train_transforms,
                              dataset_type='train', target_type=target_type)
    # Same target_type as the training split so both are parsed consistently.
    dataset_test = KeypointDataset(dataset_path=dataset_path, transforms=None,
                                   dataset_type='val', target_type=target_type)

    train_batch_sampler = torch.utils.data.BatchSampler(
        torch.utils.data.RandomSampler(dataset), batch_size, drop_last=True
    )
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_sampler=train_batch_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
    )
    # Random order for validation: the image evaluate() visualises varies between epochs.
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=torch.utils.data.RandomSampler(dataset_test),
        num_workers=num_workers, collate_fn=utils.collate_fn
    )
    if len(data_loader) == 0:
        raise ValueError(
            f"Training set at {dataset_path!r} has fewer than batch_size={batch_size} samples; "
            "with drop_last=True no batch can be formed."
        )

    train_result_path = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
    wts_path = os.path.join(train_result_path, 'weights')
    tb_path = os.path.join(train_result_path, 'logs')
    img_results_path = os.path.join(train_result_path, 'img_results')
    for path in (wts_path, tb_path, img_results_path):
        os.makedirs(path, exist_ok=True)
    last_wts = os.path.join(wts_path, 'last.pt')
    best_wts = os.path.join(wts_path, 'best.pt')

    writer = SummaryWriter(tb_path)
    try:
        best_loss = math.inf
        for epoch in range(epochs):
            metric_logger, total_train_loss = train_one_epoch(
                model, optimizer, data_loader, device, epoch, print_freq, writer, None
            )
            losses = metric_logger.meters['loss'].global_avg
            print(f'epoch {epoch}:loss:{losses}')

            # torch.save overwrites in place; deleting first would leave a
            # window with no checkpoint on disk.
            torch.save(model.state_dict(), last_wts)
            if losses <= best_loss:
                best_loss = losses
                torch.save(model.state_dict(), best_wts)

            evaluate(model, data_loader_test, epoch, writer, device=device)
            avg_train_loss = total_train_loss / len(data_loader)
            writer.add_scalar('Loss/train', avg_train_loss, epoch)
    finally:
        # Flush pending events even if training stops early.
        writer.close()


def get_transform(is_train, **kwargs):
    """Return a detection preprocessing preset from ``tools.presets``.

    Args:
        is_train: truthy selects ``DetectionPresetTrain`` (with
            ``data_augmentation``); falsy selects ``DetectionPresetEval``.
        **kwargs: optional overrides for ``augmentation`` (default
            ``'multiscale'``, used for training only), ``backend`` (default
            ``'tensor'``) and ``use_v2`` (default ``False``).

    Raises:
        ValueError: if any other keyword is passed.
    """
    options = {'augmentation': 'multiscale', 'backend': 'tensor', 'use_v2': False}
    for name in kwargs:
        if name not in options:
            raise ValueError(f"Unknown argument: {name}")
        options[name] = kwargs[name]

    shared = {'backend': options['backend'], 'use_v2': options['use_v2']}
    if not is_train:
        return presets.DetectionPresetEval(**shared)
    return presets.DetectionPresetTrain(data_augmentation=options['augmentation'], **shared)


def write_metric_logs(epoch, metric_logger, writer):
    """Log the epoch average of each loss meter to TensorBoard.

    Args:
        epoch: x-axis value for all scalars.
        metric_logger: object whose ``meters`` mapping holds, per loss name
            (as filled by ``train_one_epoch``), a meter exposing ``global_avg``.
        writer: TensorBoard ``SummaryWriter``.
    """
    meters = metric_logger.meters
    # 'loss_keypoint' was previously written under 'Loss/box_reg', the tag
    # used for the per-step box-regression loss, mixing the two curves. It now
    # follows the same naming as its sibling terms.
    for name in ('loss_classifier', 'loss_box_reg', 'loss_keypoint',
                 'loss_objectness', 'loss_rpn_box_reg'):
        writer.add_scalar(f'{name}:', meters[name].global_avg, epoch)
    writer.add_scalar('train loss:', meters['loss'].global_avg, epoch)

# def log_losses_to_tensorboard(writer, result, step):
#     writer.add_scalar('Loss/classifier', result['loss_classifier'].item(), step)
#     writer.add_scalar('Loss/box_reg', result['loss_box_reg'].item(), step)
#     writer.add_scalar('Loss/box_reg', result['loss_keypoint'].item(), step)
#     writer.add_scalar('Loss/objectness', result['loss_objectness'].item(), step)
#     writer.add_scalar('Loss/rpn_box_reg', result['loss_rpn_box_reg'].item(), step)