import time

import skimage

from models.line_detect.postprocess import show_predict, show_box, show_box_or_line, show_box_and_line, \
    show_line_optimized, show_line
import os

import torch
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
from models.line_detect.line_net import linenet_resnet50_fpn
from torchvision import transforms

# from models.wirenet.postprocess import postprocess
from models.wirenet.postprocess import postprocess
from rtree import index

# Prefer the GPU when PyTorch can see one; otherwise everything runs on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def load_best_model(model, save_path, device):
    """Load weights from the checkpoint at ``save_path`` into ``model`` if the file exists.

    The checkpoint is expected to be a dict with ``'model_state_dict'``,
    ``'epoch'`` and ``'loss'`` entries. Optimizer state is not restored.

    Args:
        model: module whose parameters are overwritten in place.
        save_path: filesystem path of the checkpoint file.
        device: passed to ``torch.load`` as ``map_location``.

    Returns:
        The same ``model`` object, with weights loaded when the checkpoint was found.
    """
    if not os.path.exists(save_path):
        print(f"No saved model found at {save_path}")
        return model

    ckpt = torch.load(save_path, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    epoch, loss = ckpt['epoch'], ckpt['loss']
    print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
    return model


def box_line_(imgs, pred):
    """Attach to every detected box the highest-scoring line segment lying fully inside it.

    Args:
        imgs: the input image tensor. It is not used; it is kept so existing
            callers keep working.
        pred: model output list. ``pred[-1]['wires']`` holds ``'lines'`` and
            ``'score'``, indexed by the position of the box dict:
            ``lines[idx]`` is an (N, 2, 2) set of segment endpoints on a 128 grid
            (rescaled here by ``/ 128 * 512``), and ``score[idx]`` holds the N scores.
            Every other element ``pred[idx]`` is a dict with a ``'boxes'`` tensor
            of ``(x1, y1, x2, y2)`` rows.

    Returns:
        ``pred``, mutated in place. Each box dict gains:
            ``'line'``: float32 tensor of shape (num_boxes, 2, 2) with the chosen
                segment per box, all zeros when none qualifies.
            ``'line_score'``: float32 tensor of shape (num_boxes,) with its score,
                0 when none qualifies.
        A segment qualifies when both endpoints lie inside the box (inclusive) and
        its score is strictly positive. Among equal scores, the lowest line index wins.
    """
    wires = pred[-1]['wires']
    for idx, box_ in enumerate(pred[0:-1]):
        # Move everything to CPU numpy once instead of comparing numpy scalars
        # against (possibly CUDA) tensors element by element.
        boxes = box_['boxes'].detach().cpu().numpy().reshape(-1, 4)
        lines = wires['lines'][idx].detach().cpu().numpy() / 128 * 512
        scores = wires['score'][idx].detach().cpu().numpy().reshape(-1)

        best_lines, best_scores = _best_line_per_box(boxes, lines, scores)
        pred[idx]['line'] = torch.from_numpy(best_lines)
        pred[idx]['line_score'] = torch.from_numpy(best_scores)
    return pred


def _best_line_per_box(boxes, lines, scores):
    """Pick, for each (x1, y1, x2, y2) box, the best-scoring segment fully inside it."""
    n_boxes = boxes.shape[0]
    best_lines = np.zeros((n_boxes, 2, 2), dtype=np.float32)
    best_scores = np.zeros(n_boxes, dtype=np.float32)
    if n_boxes == 0 or lines.shape[0] == 0:
        return best_lines, best_scores

    # Endpoints are stored as (row, col) = (y, x). Column 1 is checked against the
    # box x-range and column 0 against its y-range.
    xs = lines[None, :, :, 1]  # (1, N, 2)
    ys = lines[None, :, :, 0]  # (1, N, 2)
    x1, y1, x2, y2 = (boxes[:, k, None, None] for k in range(4))  # each (B, 1, 1)
    inside = ((xs >= x1) & (xs <= x2) & (ys >= y1) & (ys <= y2)).all(axis=2)  # (B, N)

    # Only strictly positive scores can win. argmax returns the first maximum,
    # so ties resolve to the lowest line index.
    valid = inside & (scores[None, :] > 0)
    found = valid.any(axis=1)
    best = np.where(valid, scores[None, :], -np.inf).argmax(axis=1)
    best_lines[found] = lines[best[found]]
    best_scores[found] = scores[best[found]]
    return best_lines, best_scores


def box_line_optimized(pred):
    """Attach to every detected box the best-scoring line segment lying fully inside it.

    An R-tree over the segments' bounding boxes pre-filters the candidates for
    each detection box. Only the first image's wires (``lines[0]`` and
    ``score[0]``) are used, for every box dict.

    Args:
        pred: model output list.
            ``pred[-1]['wires']['lines']`` has shape (1, N, 2, 2): segment
                endpoints on a 128 grid, rescaled here by ``/ 128 * 512``.
            ``pred[-1]['wires']['score']`` has shape (1, N).
            Every other element is a dict with a ``'boxes'`` tensor of
                ``(x1, y1, x2, y2)`` rows.

    Returns:
        ``pred``, mutated in place. Each box dict gains:
            ``'line'``: float32 tensor (num_boxes, 2, 2), zeros when no segment qualifies.
            ``'line_score'``: float32 tensor (num_boxes,), 0 when no segment qualifies.
        A segment qualifies when both endpoints lie inside the box (inclusive) and
        its score is strictly positive. Among equal scores, the lowest line index wins.
    """
    wires = pred[-1]['wires']
    # Convert once up front instead of once per candidate inside the matching loop.
    lines = wires['lines'][0].detach().cpu().numpy() / 128 * 512  # (N, 2, 2)
    scores = wires['score'][0].detach().cpu().numpy().reshape(-1)  # (N,)

    # Endpoints are stored as (row, col) = (y, x), the opposite order of the
    # (x1, y1, x2, y2) boxes. Index each segment by (x_min, y_min, x_max, y_max),
    # so queries use the same axes as the boxes. No padding is needed: a segment
    # contained in a box has its bbox inside that box, so an intersection query
    # always returns it.
    rtree_idx = index.Index()
    for j, seg in enumerate(lines):
        ys, xs = seg[:, 0], seg[:, 1]
        rtree_idx.insert(j, (float(xs.min()), float(ys.min()), float(xs.max()), float(ys.max())))

    for idx_box, box_ in enumerate(pred[0:-1]):
        boxes = box_['boxes'].detach().cpu().numpy().reshape(-1, 4)
        line_ = np.zeros((len(boxes), 2, 2), dtype=np.float32)
        score_ = np.zeros(len(boxes), dtype=np.float32)

        for b, (x1, y1, x2, y2) in enumerate(boxes):
            score_max = 0.0
            query = (float(x1), float(y1), float(x2), float(y2))
            # Sort so ties resolve to the lowest line index, whatever order the R-tree yields.
            for j in sorted(rtree_idx.intersection(query)):
                ys, xs = lines[j][:, 0], lines[j][:, 1]
                inside = xs.min() >= x1 and xs.max() <= x2 and ys.min() >= y1 and ys.max() <= y2
                if inside and scores[j] > score_max:
                    score_max = float(scores[j])
                    line_[b] = lines[j]
                    score_[b] = scores[j]

        pred[idx_box]['line'] = torch.from_numpy(line_)
        pred[idx_box]['line_score'] = torch.from_numpy(score_)

    return pred


def predict(pt_path, model, img):
    """Load weights, run the model on one image, and visualise lines and matched boxes.

    Args:
        pt_path: checkpoint path handed to ``load_best_model``.
        model: the network to load weights into and run.
        img: either an image file path (opened with PIL and converted to RGB), or an
            image object accepted by ``torchvision.transforms.ToTensor``.
    """
    # Import the submodule explicitly: on older scikit-image releases a bare
    # ``import skimage`` does not make ``skimage.transform`` available.
    from skimage.transform import resize

    model = load_best_model(model, pt_path, device)
    model.eval()

    if isinstance(img, str):
        img = Image.open(img).convert("RGB")

    img_tensor = transforms.ToTensor()(img)  # [C, H, W], original image size

    # Resize to 512x512 before inference.
    t_start = time.time()
    im = img_tensor.permute(1, 2, 0)  # [H, W, C]
    im_resized = resize(im.cpu().numpy().astype(np.float32), (512, 512))  # (512, 512, C)
    img_ = torch.tensor(im_resized).permute(2, 0, 1)  # [C, 512, 512]
    t_end = time.time()
    print(f'switch img used:{t_end - t_start}')

    with torch.no_grad():
        predictions = model([img_.to(device)])

    # Draw lines only. Alternatives imported from models.line_detect.postprocess:
    #   show_line_optimized(img_, predictions, t_start)                      # lines only
    #   show_box(img_, predictions, t_start)                                 # boxes only
    #   show_box_or_line(img_, predictions, show_line=True, show_box=True)   # choose via flags
    #   show_box_and_line(img_, predictions, show_line=True, show_box=True)  # 1x2 side by side
    show_line(img_, predictions, t_start)

    t_start = time.time()
    # box_line_optimized(predictions) is the R-tree based alternative.
    pred = box_line_(img_, predictions)
    t_end = time.time()
    print(f'Matched boxes and lines used: {t_end - t_start:.4f} seconds')

    show_predict(img_, pred, t_start)


if __name__ == '__main__':
    started_at = time.time()
    print(f'start to predict:{started_at}')
    net = linenet_resnet50_fpn().to(device)

    # Other checkpoints that have been used:
    #   r'D:\python\PycharmProjects\20250214\weight\resnet50_best_e100.pth'
    #   r'D:\python\PycharmProjects\linenet_wts\r50fpn_wts_e350\best.pth'
    weights_path = r'D:\python\PycharmProjects\20250214\weight\linenet_wts\r50fpn_wts_e350\best.pth'

    # Other test images:
    #   r'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-43-13_SaveImage.png'  # workpiece image
    #   r'D:\python\PycharmProjects\data\images/train/00558656_3.png'  # wireframe image
    image_path = r'C:\Users\m2337\Desktop\p\2025-01-03-09-34-32_SaveImage_adjust_brightness_contrast.jpg'

    predict(weights_path, net, image_path)
    print(f'predict used:{time.time() - started_at}')