import os
import time
from datetime import datetime

import torch
from torch.utils.tensorboard import SummaryWriter

from models.base.base_model import BaseModel
from models.base.base_trainer import BaseTrainer
from models.config.config_tool import read_yaml
from models.line_detect.dataset_LD import WirePointDataset
from models.line_detect.postprocess import box_line_, show_
from utils.log_util import show_line, save_last_model, save_best_model
from tools import utils


def _loss(losses):
    """Sum the scalar detection losses plus the per-head wirepoint losses.

    `losses` is a dict of scalar tensors, optionally including a
    'loss_wirepoint' entry of the form {'losses': [{head_name: tensor, ...}]}.
    """
    total_loss = 0
    loss_labels = None
    for name, value in losses.items():
        if name != "loss_wirepoint":
            total_loss += value
        else:
            loss_labels = value["losses"]
    if loss_labels:  # guard against outputs without wirepoint losses
        for head_name in loss_labels[0]:
            total_loss += loss_labels[0][head_name].mean()

    return total_loss


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def move_to_device(data, device):
    if isinstance(data, (list, tuple)):
        return type(data)(move_to_device(item, device) for item in data)
    elif isinstance(data, dict):
        return {key: move_to_device(value, device) for key, value in data.items()}
    elif isinstance(data, torch.Tensor):
        return data.to(device)
    else:
        return data  # leave non-tensor data unchanged


class Trainer(BaseTrainer):
    def __init__(self, model=None,
                 dataset=None,
                 device='cuda',
                 **kwargs):
        super().__init__(model, dataset, device, **kwargs)

    def move_to_device(self, data, device):
        # Delegate to the module-level helper so the recursion lives in one place.
        return move_to_device(data, device)

    def load_best_model(self, model, optimizer, save_path, device):
        """Restore model (and optionally optimizer) state from a saved checkpoint."""
        if os.path.exists(save_path):
            checkpoint = torch.load(save_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            if optimizer is not None:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            epoch = checkpoint['epoch']
            loss = checkpoint['loss']
            print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
        else:
            print(f"No saved model found at {save_path}")
        return model, optimizer

    def writer_loss(self, writer, losses, epoch):
        try:
            for key, value in losses.items():
                if key == 'loss_wirepoint':
                    for subdict in value['losses']:
                        for subkey, subvalue in subdict.items():
                            writer.add_scalar(f'loss/{subkey}',
                                              subvalue.item() if hasattr(subvalue, 'item') else subvalue,
                                              epoch)
                elif isinstance(value, torch.Tensor):
                    writer.add_scalar(f'loss/{key}', value.item(), epoch)
        except Exception as e:
            print(f"TensorBoard logging error: {e}")

    def train_cfg(self, model: BaseModel, cfg):
        # e.g. cfg = r'./config/wireframe.yaml'
        cfg = read_yaml(cfg)
        print(f'cfg: {cfg}')
        self.train(model, **cfg)
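
    # A minimal sketch of the YAML consumed by train(), inferred from the keys
    # this file actually reads; the real wireframe.yaml may contain more:
    #
    #   io:
    #     datadir: /path/to/wirepoint/dataset   # hypothetical path
    #   optim:
    #     lr: 0.0001        # illustrative values, not the repo defaults
    #     max_epoch: 100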

    def train(self, model, **kwargs):
        # Batch size (64) and worker count (8) are hard-coded for both splits.
        dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
        train_sampler = torch.utils.data.RandomSampler(dataset_train)
        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=64, drop_last=True)
        train_collate_fn = utils.collate_fn_wirepoint
        data_loader_train = torch.utils.data.DataLoader(
            dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
        )

        dataset_val = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='val')
        val_sampler = torch.utils.data.RandomSampler(dataset_val)
        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=64, drop_last=True)
        val_collate_fn = utils.collate_fn_wirepoint
        data_loader_val = torch.utils.data.DataLoader(
            dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
        )

        train_result_path = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
        wts_path = os.path.join(train_result_path, 'weights')
        os.makedirs(wts_path, exist_ok=True)  # ensure the weights directory exists before saving
        tb_path = os.path.join(train_result_path, 'logs')
        writer = SummaryWriter(tb_path)
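        # Scalars land under 'loss/*'; the run can be watched with, e.g.,
        # `tensorboard --logdir train_results` (assuming TensorBoard is installed).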

        optimizer = torch.optim.Adam(model.parameters(), lr=kwargs['optim']['lr'])
        model.to(device)

        # # Optionally resume from a saved checkpoint:
        # save_path = 'logs/pth/best_model.pth'
        # model, optimizer = self.load_best_model(model, optimizer, save_path, device)

        last_model_path = os.path.join(wts_path, 'last.pth')
        best_model_path = os.path.join(wts_path, 'best.pth')
        global_step = 0
        best_loss = float('inf')  # ensures the first epoch records a 'best' checkpoint

        for epoch in range(kwargs['optim']['max_epoch']):
            print(f"epoch:{epoch}")
            total_train_loss = 0.0

            model.train()

            for imgs, targets in data_loader_train:
                imgs = move_to_device(imgs, device)
                targets = move_to_device(targets, device)
                losses = model(imgs, targets)
                loss = _loss(losses)
                total_train_loss += loss.item()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                self.writer_loss(writer, losses, global_step)
                global_step += 1

            avg_train_loss = total_train_loss / len(data_loader_train)
            writer.add_scalar('loss/train', avg_train_loss, epoch)

            # Overwrite the rolling 'last' checkpoint and refresh 'best' if improved.
            save_last_model(model, last_model_path, epoch, optimizer)
            best_loss = save_best_model(model, best_model_path, epoch, avg_train_loss, best_loss, optimizer)

            # Qualitative check: visualize predictions on the first validation
            # batch only, then move on to the next epoch.
            model.eval()
            with torch.no_grad():
                for imgs, _ in data_loader_val:
                    t_start = time.time()
                    pred = model(self.move_to_device(imgs, self.device))
                    print(f'predict used: {time.time() - t_start:.3f}s')
                    show_line(imgs[0], pred, epoch, writer)
                    break
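

# Example entry point (a minimal sketch, not part of the original file):
# `build_model` is a hypothetical factory -- substitute the concrete
# BaseModel subclass this repo trains.
#
# if __name__ == '__main__':
#     model = build_model()  # hypothetical: returns a BaseModel instance
#     trainer = Trainer(model=model, device='cuda')
#     trainer.train_cfg(model, r'./config/wireframe.yaml')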