import time
import skimage
from models.line_detect.postprocess import show_predict
import os
import torch
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from models.line_detect.line_net import linenet_resnet50_fpn
from torchvision import transforms
from rtree import index
import multiprocessing as mp

# Use the 'spawn' start method for all multiprocessing in this process:
# CUDA state cannot be safely shared with forked children.
mp.set_start_method('spawn', force=True)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'device:{device}')
def load_best_model(model, save_path, device):
    """Restore ``model``'s weights from a training checkpoint, if present.

    The checkpoint is expected to be a dict with ``'model_state_dict'``,
    ``'epoch'`` and ``'loss'`` entries. When ``save_path`` does not exist the
    model is left untouched and only a message is printed.

    Args:
        model: module whose ``load_state_dict`` receives the stored weights.
        save_path: filesystem path of the checkpoint file.
        device: passed to ``torch.load`` as ``map_location``.

    Returns:
        The same ``model`` object (loaded in place or unchanged).
    """
    if not os.path.exists(save_path):
        print(f"No saved model found at {save_path}")
        return model

    state = torch.load(save_path, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    ep, best_loss = state['epoch'], state['loss']
    print(f"Loaded best model from {save_path} at epoch {ep} with loss {best_loss:.4f}")
    return model

def process_box(box, lines, scores):
    """For each box, pick the longest predicted segment lying fully inside it.

    Args:
        box: boxes, one row per box; columns 0..3 are used as
            (x_min, y_min, x_max, y_max). Column 0/2 bound endpoint
            coordinate 1, column 1/3 bound endpoint coordinate 0 (presumably
            endpoints are stored as (y, x) -- TODO confirm against the line
            head). Accepts a numpy array or a torch tensor.
        lines: tensor or array indexed as ``lines[0, j]`` -> one (2, 2)
            segment; coordinates are rescaled by ``/ 128 * 512`` here
            (assumed 128-px heatmap units -> 512-px image units; confirm).
        scores: unused; kept so existing callers keep working.

    Returns:
        ``(valid_lines, valid_scores)``. Each box that contains at least one
        segment of strictly positive length contributes one entry:
        - the longest such segment, as a (2, 2) array in rescaled units;
        - its Euclidean length, used as the score.
        If several segments tie for the maximum length, the first one wins.
        Boxes without a match are skipped, so both lists may be shorter than
        ``box`` and are not index-aligned with it.
    """
    if torch.is_tensor(lines):
        lines = lines.detach().cpu().numpy()
    if torch.is_tensor(box):
        box = box.detach().cpu().numpy()

    # Convert and rescale every segment once, instead of once per
    # (box, segment) pair as the loop version did.
    segs = np.asarray(lines)[0] / 128 * 512  # (num_segments, 2, 2)
    coord0 = segs[:, :, 0]  # compared with box columns 1 and 3
    coord1 = segs[:, :, 1]  # compared with box columns 0 and 2
    lengths = np.linalg.norm(segs[:, 0] - segs[:, 1], axis=-1)

    valid_lines = []
    valid_scores = []
    for b in box:
        inside = ((coord1 >= b[0]) & (coord1 <= b[2]) &
                  (coord0 >= b[1]) & (coord0 <= b[3])).all(axis=1)
        # Only strictly positive lengths qualify (the best length starts at
        # 0.0 and must be exceeded); NaN lengths fail this test too.
        candidates = inside & (lengths > 0)
        if not candidates.any():
            continue
        # argmax returns the first maximum, matching a strict '>' scan.
        best = int(np.argmax(np.where(candidates, lengths, -np.inf)))
        valid_lines.append(segs[best].copy())
        valid_scores.append(lengths[best])

    return valid_lines, valid_scores


def box_line_optimized_parallel(pred, max_workers=None):
    """Attach matched line segments to each detection dict in ``pred``.

    ``pred`` is the model output list.
    - Every element except the last is a detection dict with a ``'boxes'``
      tensor.
    - The last element holds the line output under ``pred[-1]['wires']``,
      with ``'lines'`` and ``'score'`` entries.
    Each detection dict's boxes are matched against the lines by
    ``process_box``.

    Args:
        pred: model output list. Matched detection dicts are mutated in place.
        max_workers: optional cap on worker processes; ``None`` means
            ``mp.cpu_count()``. When only one worker would be used, the
            matching runs in the current process: spawning an interpreter
            that re-imports torch costs far more than the matching itself.

    Returns:
        List of the detection dicts that received at least one segment. Each
        gains ``'line'`` (stacked segments) and ``'line_score'`` (segment
        lengths) tensors. Returns ``[]`` when ``pred`` has no detection
        entries.
    """
    wires = pred[-1]['wires']
    # Move to CPU once, so workers receive plain CPU tensors rather than
    # CUDA tensors that would need cross-process CUDA sharing.
    lines = wires['lines'].detach().cpu()  # presumably [1, N, 2, 2]
    scores = wires['score'][0].detach().cpu()  # presumably [N]

    boxes = [det['boxes'].detach().cpu().numpy() for det in pred[:-1]]
    if not boxes:
        # mp.Pool(processes=0) raises ValueError; nothing to match anyway.
        return []

    cap = mp.cpu_count() if max_workers is None else max_workers
    num_processes = max(1, min(cap, len(boxes)))
    jobs = [(box, lines, scores) for box in boxes]

    if num_processes == 1:
        results = [process_box(*job) for job in jobs]
    else:
        with mp.get_context('spawn').Pool(processes=num_processes) as pool:
            results = pool.starmap(process_box, jobs)

    filtered_pred = []
    # results[k] corresponds to pred[k]; zip stops before the wires entry.
    for det, (valid_lines, valid_scores) in zip(pred, results):
        if valid_lines:
            # Stack first: building a tensor from a list of ndarrays is slow.
            det['line'] = torch.tensor(np.stack(valid_lines))
            det['line_score'] = torch.tensor(np.asarray(valid_scores))
            filtered_pred.append(det)

    return filtered_pred


def predict(pt_path, model, img):
    """Run the line/box model on one image and visualise the matched result.

    Steps:
    1. Load the weights from ``pt_path`` into ``model`` and switch it to
       eval mode.
    2. Resize the image to 512x512 and run a single forward pass under
       ``torch.no_grad``.
    3. Match boxes with line segments via ``box_line_optimized_parallel``.
    4. Hand the result to ``show_predict``.
    Timings and prediction shapes are printed along the way. When no
    detection receives a segment, visualisation is skipped.

    Args:
        pt_path: checkpoint path forwarded to ``load_best_model``.
        model: the network instance to load and run.
        img: a file path (opened and converted to RGB), or an image object
            accepted by ``transforms.ToTensor``.

    Returns:
        None.
    """
    net = load_best_model(model, pt_path, device)
    net.eval()

    source = Image.open(img).convert("RGB") if isinstance(img, str) else img

    # ToTensor yields CHW; skimage's resize works on HWC, so permute around
    # it. The input can be any size -- only the output is fixed at 512x512.
    chw = transforms.ToTensor()(source)
    hwc = chw.permute(1, 2, 0).cpu().numpy().astype(np.float32)
    resized = skimage.transform.resize(hwc, (512, 512))
    img_tensor = torch.tensor(resized).permute(2, 0, 1)

    with torch.no_grad():
        began = time.time()
        predictions = net([img_tensor.to(device)])
        finished = time.time()
        print(f'Prediction used: {finished - began:.4f} seconds')

        box_shape = predictions[0]['boxes'].shape
        line_shape = predictions[-1]['wires']['lines'].shape
        score_shape = predictions[-1]['wires']['score'].shape
        print(f'Predictions - boxes: {box_shape}, lines: {line_shape}, lines_scores: {score_shape}')

    t_start = time.time()
    pred = box_line_optimized_parallel(predictions)
    print(f'Matched boxes and lines used: {time.time() - t_start:.4f} seconds')

    if pred:
        # NOTE(review): t_start here is the start of the matching step, not
        # of the whole prediction -- confirm which one show_predict expects.
        show_predict(img_tensor, pred, t_start)
    else:
        print("No valid predictions found. Skipping visualization.")


if __name__ == '__main__':
    # Machine-specific paths: point these at your own checkpoint and image.
    weights = r'D:\python\PycharmProjects\20250214\weight\resnet50_best_e100.pth'
    image = r'C:\Users\m2337\Desktop\p\49.jpg'

    run_begin = time.time()
    print(f'Start to predict: {run_begin}')
    net = linenet_resnet50_fpn().to(device)
    predict(weights, net, image)
    print(f'Total prediction time: {time.time() - run_begin:.4f} seconds')