import time

import numpy as np
import torch
from matplotlib import pyplot as plt
from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
from torchvision.io import ImageReadMode, decode_image, read_image
import torchvision.transforms.functional as F
from torchvision.utils import draw_keypoints, draw_bounding_boxes

from models.keypoint.kepointrcnn import KeypointRCNNModel


def show(imgs):
    """Plot one or more image tensors side by side in a single matplotlib row.

    Args:
        imgs: A single image tensor or a list of them. Each one is detached and
            converted with ``torchvision.transforms.functional.to_pil_image``,
            so it must be in a layout that function accepts (e.g. ``(C, H, W)``).

    The figure is only created here; call ``plt.show()`` to display it.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    # squeeze=False keeps the axes array 2-D (1 x N), even for a single image.
    _, axes = plt.subplots(ncols=len(images), squeeze=False)
    for ax, tensor in zip(axes[0], images):
        pil_image = F.to_pil_image(tensor.detach())
        ax.imshow(np.asarray(pil_image))
        # Hide ticks and tick labels on both axes.
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])


# Alternative test images:
# img_path=r"F:\DevTools\datasets\coco2017\val2017\000000000785.jpg"
# img_path=r"F:\DevTools\datasets\renyaun\1012\images\2024-09-23-09-58-42_SaveImage.png"
img_path = r"I:\datasets\wirenet_1000\images\train\00031644_0.png"
# Decode as 3-channel RGB regardless of the file's own layout. PNGs may be RGBA
# or grayscale, and draw_bounding_boxes rejects 4-channel images (the model
# presumably expects 3 channels as well — TODO confirm).
img_int = read_image(img_path, mode=ImageReadMode.RGB)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT

# Preprocessing preset bundled with the torchvision Keypoint R-CNN weights,
# applied exactly once to produce the model input.
transforms = weights.transforms()
print(f'transforms:{transforms}')
img = transforms(img_int).to(device)

model = KeypointRCNNModel(num_keypoints=2)

print('start to load pretraine weight!')
model.load_weight('./train_results/20241226_171710/weights/best.pt')
print('loaded weight !!!')

# Model and input must live on the same device. Assumes KeypointRCNNModel is a
# torch.nn.Module (it exposes eval()) — TODO confirm load_weight does not pin a device.
model.to(device)
model.eval()

# no_grad: this is pure inference, so skip building the autograd graph.
# CUDA kernels run asynchronously; synchronize around the forward pass so the
# measured time covers the GPU work, not just the kernel launches.
with torch.no_grad():
    if device.type == 'cuda':
        torch.cuda.synchronize()
    t1 = time.perf_counter()
    outputs = model([img])
    if device.type == 'cuda':
        torch.cuda.synchronize()
    t2 = time.perf_counter()
print(f'time:{t2 - t1}')

# The drawing utilities below combine these with the CPU tensor img_int,
# so bring the results back to the CPU.
output = outputs[0]
kpts = output['keypoints'].cpu()
scores = output['scores'].cpu()
boxes = output['boxes'].cpu()
print(f'kpts:{kpts}')
print(f'scores:{scores}')

# Score cutoff for display.
detect_threshold = 0.001
# Use one mask for keypoints AND boxes so both overlays show the same instances.
keep = scores > detect_threshold
keypoints = kpts[keep]
boxes = boxes[keep]

# Draw the keypoints first, then overlay the boxes on that result so both
# appear in one image.
# Assumes torchvision-style (N, K, 3) keypoints [x, y, visibility] — TODO confirm.
# draw_keypoints documents (N, K, 2) input; [..., :2] is a no-op if only x, y
# are returned.
res = draw_keypoints(img_int, keypoints[..., :2], colors="blue", radius=3)
res_box = draw_bounding_boxes(res, boxes)
show(res_box)
plt.show()