import torch
import matplotlib.pyplot as plt
import numpy as np
from torchvision import transforms


def box_line(pred):
    '''
    Pair each predicted box with the highest-scoring predicted line that lies
    entirely inside it.

    :param pred: prediction results. Every element except the last is a
        per-image dict with a ``'boxes'`` tensor whose rows are read as
        ``(x0, y0, x1, y1)``. The last element carries the line predictions:
        ``pred[-1]['wires']['lines'][idx]`` (endpoint pairs for image ``idx``)
        and ``pred[-1]['wires']['score'][idx]`` (one score per line).
    :return: one list per image, holding one dict per box in box order, e.g.

        {'box': [0.0, 34.23157501220703, 151.70858764648438, 125.10173797607422],
         'line': array([[ 1.9720564, 81.73457  ],
                        [ 1.9933801, 41.730167 ]], dtype=float32)}

        ``'line'`` is ``[]`` when no line with a positive score has both
        endpoints inside the box (box edges inclusive).
    '''
    result = [[] for _ in range(len(pred) - 1)]
    for idx, image_pred in enumerate(pred[0:-1]):
        boxes = image_pred['boxes']
        # Rescale by 512 / 128; presumably line coordinates come from a 128-px
        # grid while boxes are in 512-px image space -- TODO confirm.
        lines = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
        scores = pred[-1]['wires']['score'][idx]
        for box in boxes:
            coords = box.tolist()
            x0, y0, x1, y1 = coords[:4]
            entry = {'box': coords, 'line': []}
            score_max = 0.0
            for j, seg in enumerate(lines):
                a, b = seg[0], seg[1]
                inside = (x0 <= a[0] <= x1 and x0 <= b[0] <= x1
                          and y0 <= a[1] <= y1 and y0 <= b[1] <= y1)
                if not inside:
                    continue
                s = float(scores[j])
                # Strict ">" keeps the first line on score ties.
                if s > score_max:
                    entry['line'] = seg
                    score_max = s
            result[idx].append(entry)
    return result


def box_line_(pred):
    '''
    Attach to each image's prediction dict the highest-scoring line lying
    entirely inside each of its boxes; returns ``pred`` in the same form.

    :param pred: prediction results. Every element except the last is a
        per-image dict with a ``'boxes'`` tensor whose rows are read as
        ``(x0, y0, x1, y1)``. The last element carries
        ``pred[-1]['wires']['lines'][idx]`` (endpoint pairs for image ``idx``,
        assumed shaped ``(num_lines, 2, 2)`` -- confirm) and
        ``pred[-1]['wires']['score'][idx]`` (one score per line).
    :return: the same ``pred`` list, mutated in place: ``pred[idx]['line']``
        is set to a CPU tensor of shape ``(len(boxes), 2, 2)`` holding, per
        box, the endpoints of the best line with a positive score whose two
        endpoints are inside the box (edges inclusive), or all zeros when
        there is none.
    '''
    for idx, image_pred in enumerate(pred[0:-1]):
        boxes = image_pred['boxes']
        # Rescale by 512 / 128; presumably line coordinates come from a 128-px
        # grid while boxes are in 512-px image space -- TODO confirm.
        lines = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
        scores = pred[-1]['wires']['score'][idx]
        # Preallocate so an image with zero boxes still yields (0, 2, 2) and
        # the tensor is built from one ndarray, not a mixed Python list.
        matched = np.zeros((len(boxes), 2, 2), dtype=lines.dtype)
        for k, box in enumerate(boxes):
            x0, y0, x1, y1 = box.tolist()[:4]
            score_max = 0.0
            for j, seg in enumerate(lines):
                a, b = seg[0], seg[1]
                inside = (x0 <= a[0] <= x1 and x0 <= b[0] <= x1
                          and y0 <= a[1] <= y1 and y0 <= b[1] <= y1)
                if not inside:
                    continue
                s = float(scores[j])
                # Strict ">" keeps the first line on score ties.
                if s > score_max:
                    matched[k] = seg
                    score_max = s
        pred[idx]['line'] = torch.from_numpy(matched)
    return pred


# Fixed colour cycle used by show_; indices wrap so any number of boxes and
# lines can be drawn.
_SHOW_PALETTE = (
    '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
    '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
    '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
    '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
    '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
    '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
    '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
    '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
    '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
    '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
    '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
    '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
    '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
    '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
    '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
    '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
    '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
    '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
    '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
    '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072',
)


def show_(imgs, pred, epoch, writer):
    '''
    Draw the first image's predicted boxes and matched lines and log the
    rendered figure to ``writer`` under the tag ``"all"``.

    :param imgs: batch of image tensors; ``imgs[0]`` is drawn after
        ``permute(1, 2, 0)``, so presumably channel-first ``(C, H, W)`` --
        confirm.
    :param pred: per-image prediction dicts; ``pred[0]['boxes']`` rows are
        read as ``(x0, y0, x1, y1)`` and ``pred[0]['line']`` rows as endpoint
        pairs ``[[xa, ya], [xb, yb]]`` (presumably attached by ``box_line_``).
    :param epoch: step value passed to ``writer.add_image``.
    :param writer: object with an ``add_image(tag, img, step)`` method,
        presumably a TensorBoard ``SummaryWriter``.
    '''
    n_colors = len(_SHOW_PALETTE)
    # detach().cpu() so GPU tensors or tensors tracking gradients can be
    # converted to NumPy.
    im = imgs[0].permute(1, 2, 0).detach().cpu().numpy()
    boxes = pred[0]['boxes'].detach().cpu().numpy()
    lines = pred[0]['line'].detach().cpu().numpy()

    # Visualize the predictions.
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)

    for idx, box in enumerate(boxes):
        x0, y0, x1, y1 = box
        ax.add_patch(
            plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False,
                          edgecolor=_SHOW_PALETTE[idx % n_colors], linewidth=1))

    for idx, (a, b) in enumerate(lines):
        # Endpoints take colours from the end of the palette, the segment from
        # the start; the modulo reproduces the old indexing for idx < 100 and
        # avoids an IndexError beyond that.
        point_color = _SHOW_PALETTE[(n_colors - 1 - idx) % n_colors]
        ax.scatter(a[0], a[1], c=point_color, s=2)
        ax.scatter(b[0], b[1], c=point_color, s=2)
        ax.plot([a[0], b[0]], [a[1], b[1]],
                c=_SHOW_PALETTE[idx % n_colors], linewidth=1)

    # Render the Matplotlib figure and convert it to a tensor. buffer_rgba()
    # replaces tostring_rgb() (deprecated in Matplotlib 3.8, removed in 3.10),
    # and its shape is the real pixel-buffer size, so no manual reshape.
    fig.canvas.draw()
    rgba = np.array(fig.canvas.buffer_rgba())  # copy: safe after close
    plt.close(fig)
    image_from_plot = np.ascontiguousarray(rgba[..., :3])
    img2 = transforms.ToTensor()(image_from_plot)

    writer.add_image("all", img2, epoch)