# 2025/2/9
import os

import numpy as np
import torch

from models.config.config_tool import read_yaml
from models.line_detect.dataset_LD import WirePointDataset
from tools import utils

from torch.utils.tensorboard import SummaryWriter
import matplotlib as mpl

from models.line_detect.line_net import linenet_resnet50_fpn
from torchvision.utils import draw_bounding_boxes
from models.wirenet.postprocess import postprocess
from torchvision import transforms

from PIL import Image

from models.line_detect.postprocess import box_line_, show_
import matplotlib.pyplot as plt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _loss(losses):
    total_loss = 0
    for i in losses.keys():
        if i != "loss_wirepoint":
            total_loss += losses[i]
        else:
            loss_labels = losses[i]["losses"]
    loss_labels_k = list(loss_labels[0].keys())
    for j, name in enumerate(loss_labels_k):
        loss = loss_labels[0][name].mean()
        total_loss += loss

    return total_loss


# Shared colour scale for line-confidence scores: the "jet" colormap
# normalised over the score range [0.4, 1.0].
cmap = plt.get_cmap("jet")
norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
# The mappable carries no image data of its own. An empty array is attached,
# presumably so it can back the colorbar drawn in ``imshow``.
sm.set_array([])


def c(x):
    """Return the RGBA colour assigned to score ``x`` by the shared scale ``sm``."""
    rgba = sm.to_rgba(x)
    return rgba


def imshow(im):
    """Show ``im`` on a fresh pyplot figure with the score colorbar attached.

    Closes the current pyplot figure first, draws ``im``, and attaches a
    colorbar for the shared scale ``sm``. It then sets the view so it spans
    exactly the image, with the y axis inverted (row 0 at the top).

    Args:
        im: an array-like image indexed as (rows, cols, ...). Only
            ``im.shape[0]`` and ``im.shape[1]`` are read here.
    """
    plt.close()
    plt.tight_layout()
    plt.imshow(im)
    # ``sm`` is not attached to any Axes, so name the Axes to take space from
    # explicitly; without it Matplotlib cannot place the colorbar.
    plt.colorbar(sm, fraction=0.046, ax=plt.gca())
    # x runs over columns (width = shape[1]); y runs over rows (height = shape[0]).
    plt.xlim([0, im.shape[1]])
    plt.ylim([im.shape[0], 0])


def show_line(img, pred, epoch, writer, threshold=0.85):
    """Log the input image, predicted boxes and predicted lines to TensorBoard.

    Writes three images at step ``epoch``:

    * ``"ori"`` — the input image;
    * ``"boxes"`` — the input with ``pred[0]["boxes"]`` drawn in yellow;
    * ``"z-output"`` — a Matplotlib rendering of the post-processed lines
      whose score is at least ``threshold``, coloured by score.

    Args:
        img: a single image tensor laid out as (C, H, W); it is permuted to HWC
            here. It is multiplied by 255 before drawing boxes, so values are
            presumably floats in [0, 1] (TODO confirm against the dataset).
        pred: model output list. ``pred[0]["boxes"]`` holds boxes, and
            ``pred[-1]["wires"]`` holds ``"lines"`` and ``"score"`` tensors.
        epoch: step value used for every ``writer.add_image`` call.
        writer: a ``SummaryWriter``.
        threshold: minimum score for a line to be drawn (default keeps the
            previous hard-coded 0.85).
    """
    im = img.permute(1, 2, 0)
    writer.add_image("ori", im, epoch, dataformats="HWC")

    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
                                      colors="yellow", width=1)
    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")

    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
    H = pred[-1]['wires']
    # Rescale predicted coordinates by (H, W) / 128.
    # NOTE(review): 128 is assumed to be the prediction grid size — confirm.
    lines = H["lines"][0].cpu().numpy() / 128 * np.array(im.shape[:2])
    scores = H["score"][0].cpu().numpy()
    # Truncate at the first row that repeats row 0 (presumably padding).
    for i in range(1, len(lines)):
        if (lines[i] == lines[0]).all():
            lines = lines[:i]
            scores = scores[:i]
            break

    # postprocess lines to remove overlapped lines
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)

    # Matplotlib needs host memory; this is a no-op for CPU tensors.
    im_np = im.detach().cpu().numpy()

    # Draw on a dedicated figure and always close it. Reusing pyplot's implicit
    # current figure would accumulate lines and images across calls and leak.
    fig, ax = plt.subplots()
    try:
        ax.set_axis_off()
        fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        ax.margins(0, 0)
        for (a, b), s in zip(nlines, nscores):
            if s < threshold:
                continue
            # Points are indexed (row, col): index 1 is x, index 0 is y.
            ax.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
            ax.scatter(a[1], a[0], **PLTOPTS)
            ax.scatter(b[1], b[0], **PLTOPTS)
        ax.xaxis.set_major_locator(plt.NullLocator())
        ax.yaxis.set_major_locator(plt.NullLocator())
        ax.imshow(im_np)
        fig.tight_layout()
        fig.canvas.draw()
        # buffer_rgba() reports the true raster size. Computing it from
        # get_size_inches() * dpi can disagree (e.g. on HiDPI) and break the reshape.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        img_rgb = np.ascontiguousarray(rgba[:, :, :3])  # drop the alpha channel
    finally:
        plt.close(fig)

    img2 = transforms.ToTensor()(img_rgb)
    writer.add_image("z-output", img2, epoch)


def save_best_model(model, save_path, epoch, current_loss, best_loss, optimizer=None):
    """Save a checkpoint only if ``current_loss`` improves on ``best_loss``.

    The checkpoint dict holds ``'epoch'``, ``'model_state_dict'`` and ``'loss'``.
    It also holds ``'optimizer_state_dict'`` when ``optimizer`` is given.

    Args:
        model: module whose ``state_dict()`` is saved.
        save_path: destination file; missing parent directories are created.
        epoch: epoch number stored in the checkpoint.
        current_loss: loss of the current model.
        best_loss: best loss seen so far.
        optimizer: optional optimizer whose state is saved too.

    Returns:
        The new best loss: ``current_loss`` if a checkpoint was written,
        otherwise ``best_loss``.
    """
    save_dir = os.path.dirname(save_path)
    # A bare filename has no directory part, and os.makedirs('') would raise.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    if current_loss < best_loss:
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'loss': current_loss
        }

        if optimizer is not None:
            checkpoint['optimizer_state_dict'] = optimizer.state_dict()

        torch.save(checkpoint, save_path)
        print(f"Saved best model at epoch {epoch} with loss {current_loss:.4f}")

        return current_loss

    return best_loss


def save_latest_model(model, save_path, epoch, optimizer=None):
    """Unconditionally write a checkpoint for ``epoch`` to ``save_path``.

    The checkpoint dict holds ``'epoch'`` and ``'model_state_dict'``. It also
    holds ``'optimizer_state_dict'`` when ``optimizer`` is given. Unlike
    ``save_best_model`` it stores no ``'loss'``.

    Args:
        model: module whose ``state_dict()`` is saved.
        save_path: destination file; missing parent directories are created.
        epoch: epoch number stored in the checkpoint.
        optimizer: optional optimizer whose state is saved too.
    """
    save_dir = os.path.dirname(save_path)
    # A bare filename has no directory part, and os.makedirs('') would raise.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
    }

    if optimizer is not None:
        checkpoint['optimizer_state_dict'] = optimizer.state_dict()

    torch.save(checkpoint, save_path)


def load_best_model(model, optimizer, save_path, device):
    """Restore model (and optimizer) state from ``save_path`` if it exists.

    Accepts checkpoints written by either ``save_best_model`` or
    ``save_latest_model``. The optimizer state is restored only when both an
    optimizer is given and the checkpoint contains one. ``'loss'`` is optional.

    Args:
        model: module to load ``'model_state_dict'`` into.
        optimizer: optimizer to restore, or ``None``.
        save_path: checkpoint file path.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        ``(model, optimizer)`` — the same objects that were passed in.
    """
    if not os.path.exists(save_path):
        print(f"No saved model found at {save_path}")
        return model, optimizer

    checkpoint = torch.load(save_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Checkpoints saved without an optimizer have no optimizer state.
    optimizer_state = checkpoint.get('optimizer_state_dict')
    if optimizer is not None and optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)

    epoch = checkpoint.get('epoch')
    # save_latest_model does not record a loss.
    loss = checkpoint.get('loss')
    if loss is None:
        print(f"Loaded model from {save_path} at epoch {epoch}")
    else:
        print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
    return model, optimizer


def predict(self, img, show_boxes=True, show_keypoint=True, show_line=True, save=False, save_path=None):
    """Run inference on one image and optionally plot, display and save the result.

    Expects ``self`` to provide ``load_weight``, ``transforms`` and an
    attribute literally named ``__model``. This is a module-level function, so
    no name mangling applies.

    Args:
        img: a path (opened with PIL and converted to RGB) or an image that
            ``self.transforms`` accepts.
        show_boxes: draw ``predictions[0]['boxes']`` as yellow rectangles.
        show_keypoint: scatter the two endpoints of each keypoint pair.
        show_line: draw a red segment between the two endpoints.
        save: write the figure to ``save_path``.
        save_path: output path; required when ``save`` is true.

    Raises:
        ValueError: if ``save`` is true but ``save_path`` is ``None``.
    """
    # Fail before loading weights / running inference rather than at savefig.
    if save and save_path is None:
        raise ValueError("save_path must be provided when save=True")

    self.load_weight('weights/best.pt')
    self.__model.eval()

    if isinstance(img, str):
        img = Image.open(img).convert("RGB")

    # Preprocess the image
    img_tensor = self.transforms(img)
    with torch.no_grad():
        predictions = self.__model([img_tensor])

    # Post-process the predictions
    boxes = predictions[0]['boxes'].cpu().numpy()
    # NOTE(review): iterated below as pairs (a, b) of points indexed [x, y]
    # — confirm the keypoint layout against the model.
    keypoints = predictions[0]['keypoints'].cpu().numpy()

    # Visualise the predictions
    if show_boxes or show_keypoint or show_line or save:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(np.array(img))

        if show_boxes:
            for box in boxes:
                x0, y0, x1, y1 = box
                ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor='yellow', linewidth=1))

        for (a, b) in keypoints:
            if show_keypoint:
                ax.scatter(a[0], a[1], c='c', s=2)
                ax.scatter(b[0], b[1], c='c', s=2)
            if show_line:
                ax.plot([a[0], b[0]], [a[1], b[1]], c='red', linewidth=1)

        # Save before showing: with interactive backends the figure can be
        # torn down once the window is closed, producing a blank file.
        if save:
            fig.savefig(save_path)
            print(f"Prediction saved to {save_path}")

        if show_boxes or show_keypoint or show_line:
            plt.show()

        plt.close(fig)


if __name__ == '__main__':
    cfg = r'./config/wireframe.yaml'
    cfg = read_yaml(cfg)
    print(f'cfg:{cfg}')
    print(cfg['model']['n_dyn_negl'])

    def build_loader(dataset_type):
        """Return a DataLoader over the ``dataset_type`` split.

        Samples are drawn in random order, in batches of 2, and the last
        partial batch is dropped.
        """
        dataset = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type=dataset_type)
        sampler = torch.utils.data.RandomSampler(dataset)
        batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size=2, drop_last=True)
        return torch.utils.data.DataLoader(
            dataset, batch_sampler=batch_sampler, num_workers=8, collate_fn=utils.collate_fn_wirepoint
        )

    data_loader_train = build_loader('train')
    data_loader_val = build_loader('val')

    model = linenet_resnet50_fpn().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
    writer = SummaryWriter(cfg['io']['logdir'])

    # Checkpoints are written under <logdir>/pth.
    logdir_with_pth = os.path.join(cfg['io']['logdir'], 'pth')
    os.makedirs(logdir_with_pth, exist_ok=True)  # create the directory if it does not exist
    latest_model_path = os.path.join(logdir_with_pth, 'latest_model.pth')
    best_model_path = os.path.join(logdir_with_pth, 'best_model.pth')

    # Load weights: resume from the same best-model file this run writes to,
    # rather than a hard-coded 'logs/pth/best_model.pth' that ignored
    # cfg['io']['logdir'].
    model, optimizer = load_best_model(model, optimizer, best_model_path, device)

    def move_to_device(data, device):
        """Recursively move every tensor inside lists/tuples/dicts to ``device``."""
        if isinstance(data, (list, tuple)):
            return type(data)(move_to_device(item, device) for item in data)
        elif isinstance(data, dict):
            return {key: move_to_device(value, device) for key, value in data.items()}
        elif isinstance(data, torch.Tensor):
            return data.to(device)
        else:
            return data  # leave non-tensor data unchanged

    def writer_loss(writer, losses, step):
        """Log every loss term as a ``loss/<name>`` scalar at ``step`` (best effort)."""
        try:
            for key, value in losses.items():
                if key == 'loss_wirepoint':
                    for subdict in value['losses']:
                        for subkey, subvalue in subdict.items():
                            if isinstance(subvalue, torch.Tensor):
                                # Wirepoint terms may hold several elements
                                # (_loss reduces them with .mean()); .item()
                                # alone fails on multi-element tensors.
                                subvalue = subvalue.detach().float().mean().item()
                            elif hasattr(subvalue, 'item'):
                                subvalue = subvalue.item()
                            writer.add_scalar(f'loss/{subkey}', subvalue, step)
                elif isinstance(value, torch.Tensor):
                    writer.add_scalar(f'loss/{key}', value.item(), step)
        except Exception as e:
            print(f"TensorBoard logging error: {e}")

    # Initialised once, before the loop, so "best" really means best over all
    # epochs. It used to be reset every epoch, so every epoch overwrote the
    # best checkpoint.
    best_loss = float('inf')
    # Per-batch step for loss logging. Logging at `epoch` made every batch in
    # an epoch overwrite the same point.
    global_step = 0

    for epoch in range(cfg['optim']['max_epoch']):
        print(f"epoch:{epoch}")
        model.train()
        total_train_loss = 0.0

        for imgs, targets in data_loader_train:
            losses = model(move_to_device(imgs, device), move_to_device(targets, device))
            loss = _loss(losses)
            total_train_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            writer_loss(writer, losses, global_step)
            global_step += 1

        model.eval()
        with torch.no_grad():
            # Only the first validation batch is visualised; no validation loss is computed.
            for imgs, _targets in data_loader_val:
                pred = model(move_to_device(imgs, device))
                show_line(imgs[0], pred, epoch, writer)
                break

        avg_train_loss = total_train_loss / max(len(data_loader_train), 1)
        writer.add_scalar('loss/train', avg_train_loss, epoch)
        save_latest_model(
            model,
            latest_model_path,
            epoch,
            optimizer
        )
        best_loss = save_best_model(
            model,
            best_model_path,
            epoch,
            avg_train_loss,
            best_loss,
            optimizer
        )

    # Flush pending events to disk.
    writer.close()