import math
import os
import sys
from collections import OrderedDict
from datetime import datetime
from typing import Mapping
import cv2
import numpy as np
import torch
import torchvision
from PIL import Image
from matplotlib import pyplot as plt
from torch import nn
from torch.nn.modules.module import T
from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights, resnet18, ResNet18_Weights
from torchvision.models._utils import _ovewrite_value_param
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.backbone_utils import _validate_trainable_layers, _resnet_fpn_extractor
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor, KeypointRCNN, \
    KeypointRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.ops.feature_pyramid_network import LastLevelMaxPool
from torchvision.utils import draw_bounding_boxes
from torchvision.ops import misc as misc_nn_ops, FeaturePyramidNetwork
from typing import Optional, Any
from models.config.config_tool import read_yaml
from models.keypoint.trainer import train_cfg
from models.wirenet._utils import overwrite_eps
# from timm import create_model
from  torchvision.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
from tools import utils
# Force synchronous CUDA kernel launches so device-side errors are reported at
# the call that caused them. This is a debugging aid: it slows GPU execution and
# applies to everything in the process that imports this module. It only takes
# effect if set before CUDA is initialised.
os.environ.update(CUDA_LAUNCH_BLOCKING='1')

class KeypointRCNNModel(nn.Module):
    """Wrapper around a ResNet-18-FPN Keypoint R-CNN with train/predict helpers.

    The torchvision ``KeypointRCNN`` is held in a private attribute. ``forward``,
    ``eval`` and ``load_state_dict`` are delegated to it.

    NOTE: ``train`` is overridden with a config-driven training entry point, so it
    does NOT have ``nn.Module.train(mode)`` semantics. ``load_state_dict`` expects
    the wrapped model's keys (no ``_KeypointRCNNModel__model.`` prefix).
    """

    def __init__(self, num_classes=2, num_keypoints=2, transforms=None):
        """Build the wrapped keypoint detector.

        Args:
            num_classes: number of output classes, including background.
            num_keypoints: number of keypoints predicted per instance.
            transforms: callable applied to each image in ``predict``. When ``None``,
                the preset transforms of ``KeypointRCNN_ResNet50_FPN_Weights.DEFAULT``
                are used.
        """
        super().__init__()

        # weights=None: no pretrained detector checkpoint is loaded.
        self.__model = keypointrcnn_resnet18_fpn(weights=None,
                                                 num_classes=num_classes,
                                                 num_keypoints=num_keypoints,
                                                 progress=False)

        # Always set self.transforms: predict() relies on it, whether the caller
        # supplied transforms or not.
        if transforms is None:
            transforms = torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
        self.transforms = transforms

        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    def forward(self, inputs):
        """Run the wrapped KeypointRCNN on ``inputs`` and return its raw output."""
        outputs = self.__model(inputs)
        return outputs

    def train(self, cfg):
        """Train the wrapped model from a YAML config file.

        This shadows ``nn.Module.train(mode)``; it does not toggle training mode.

        Args:
            cfg: path to the YAML training config; it must contain ``num_keypoints``.
        """
        parameters = read_yaml(cfg)
        self.num_keypoints = parameters['num_keypoints']
        train_cfg(self.__model, cfg)

    def load_weight(self, pt_path):
        """Load a state dict saved with ``torch.save`` into the wrapped model.

        Tensors are mapped onto ``self.device``, so checkpoints saved on a GPU can be
        loaded on a CPU-only machine.
        """
        state_dict = torch.load(pt_path, map_location=self.device)
        self.__model.load_state_dict(state_dict)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        """Load ``state_dict`` into the wrapped model, honouring ``strict``.

        Returns:
            The ``(missing_keys, unexpected_keys)`` result of the wrapped model's
            ``load_state_dict``.
        """
        return self.__model.load_state_dict(state_dict, strict=strict)

    def eval(self: T) -> T:
        """Put the wrapped model in evaluation mode and return ``self`` (like nn.Module.eval)."""
        self.__model.eval()
        return self

    def predict(self, img, show=True, save=False, save_path=None):
        """Run keypoint detection on one image and optionally visualise the result.

        Args:
            img (str or PIL.Image): path to the input image, or a PIL image.
            show (bool): display the prediction with matplotlib. Defaults to True.
            save (bool): save the visualisation to ``save_path``. Defaults to False.
            save_path (str, optional): output file for the figure; required when
                ``save`` is True.

        Returns:
            dict: the wrapped model's prediction for this image (``boxes``,
            ``keypoints``, and the other fields it returns).

        Raises:
            ValueError: if ``save`` is True and ``save_path`` is None.
        """
        # Validate up front so we don't run inference only to fail on savefig(None).
        if save and save_path is None:
            raise ValueError("save_path must be provided when save=True")

        if isinstance(img, str):
            img = Image.open(img).convert("RGB")

        self.__model.eval()

        img_tensor = self.transforms(img)
        with torch.no_grad():
            predictions = self.__model([img_tensor])

        print(f'predictions:{predictions}')

        prediction = predictions[0]
        boxes = prediction['boxes'].cpu().numpy()
        keypoints = prediction['keypoints'].cpu().numpy()

        if show or save:
            fig, ax = plt.subplots(figsize=(10, 10))
            ax.imshow(np.array(img))

            for x0, y0, x1, y1 in boxes:
                ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor='yellow', linewidth=1))

            # Each instance row holds K keypoints as [x, y, v]. Join consecutive
            # keypoints with a line and mark each one. For K == 2 this draws a
            # single segment per instance; any K is supported.
            for instance_kps in keypoints:
                xs, ys = instance_kps[:, 0], instance_kps[:, 1]
                ax.plot(xs, ys, c='red', linewidth=1)
                ax.scatter(xs, ys, c='red', s=2)

            if show:
                plt.show()

            if save:
                fig.savefig(save_path)
                print(f"Prediction saved to {save_path}")
            plt.close(fig)

        return prediction

def keypointrcnn_resnet18_fpn(
        *,
        weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
        progress: bool = True,
        num_classes: Optional[int] = None,
        num_keypoints: Optional[int] = None,
        weights_backbone: Optional[ResNet18_Weights] = ResNet18_Weights.IMAGENET1K_V1,
        trainable_backbone_layers: Optional[int] = None,
        **kwargs: Any,
) -> KeypointRCNN:
    """
    Constructs a Keypoint R-CNN model with a ResNet-18-FPN backbone.

    This is a variant of torchvision's ``keypointrcnn_resnet50_fpn`` that uses a
    ResNet-18 body instead of ResNet-50.

    Reference: `Mask R-CNN <https://arxiv.org/abs/1703.06870>`__.

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
        - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the
          format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses for both the RPN and the R-CNN, and the keypoint loss.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows, where ``N`` is the number of detected instances:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the predicted labels for each instance
        - scores (``Tensor[N]``): the scores of each instance
        - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.

    Example::

        >>> model = keypointrcnn_resnet18_fpn(weights=None, num_keypoints=2)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        weights (:class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`, optional): must be
            ``None``. The published checkpoints for this enum target a ResNet-50-FPN body and cannot be loaded
            into this ResNet-18-FPN model, so any other value raises ``ValueError``.
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int, optional): number of output classes of the model (including the background).
            Defaults to 2.
        num_keypoints (int, optional): number of keypoints. Defaults to 17.
        weights_backbone (:class:`~torchvision.models.ResNet18_Weights`, optional): pretrained weights for the
            ResNet-18 backbone, or ``None`` for random initialisation. For backward compatibility a
            ``ResNet50_Weights`` value is accepted and mapped to ``ResNet18_Weights.IMAGENET1K_V1``.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
            passed (the default) this value is set to 3.

    Raises:
        ValueError: if ``weights`` is not ``None``.
    """
    weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
    if weights is not None:
        # Fail before downloading a checkpoint that load_state_dict would reject
        # (ResNet-50-FPN layer shapes do not match a ResNet-18-FPN model).
        raise ValueError(
            "keypointrcnn_resnet18_fpn cannot load KeypointRCNN_ResNet50_FPN_Weights; "
            "pass weights=None and use weights_backbone for a pretrained backbone."
        )

    # Older callers may still pass ResNet50_Weights (the former default). Map them
    # to the ResNet-18 ImageNet weights, which is what was actually loaded before.
    if isinstance(weights_backbone, ResNet50_Weights):
        weights_backbone = ResNet18_Weights.IMAGENET1K_V1
    weights_backbone = ResNet18_Weights.verify(weights_backbone)

    if num_classes is None:
        num_classes = 2
    if num_keypoints is None:
        num_keypoints = 17

    # A pretrained backbone gets frozen BatchNorm and partially frozen layers;
    # a randomly initialised one is fully trainable with regular BatchNorm.
    is_trained = weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet18(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
    return KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
if __name__ == '__main__':
    # Script entry point: build a two-keypoint model and train it from train.yaml.
    model = KeypointRCNNModel(num_keypoints=2)

    # Checkpoint from an earlier run. To run inference instead of training, call
    # model.load_weight(weights_path), then model.predict(<image path>).
    weights_path = './train_results/20241227_231659/weights/best.pt'

    model.train(cfg='train.yaml')