from typing import Any, Optional

import torch
from torch import nn
from torchvision.ops import MultiScaleRoIAlign

from libs.vision_libs.ops import misc as misc_nn_ops
from libs.vision_libs.transforms._presets import ObjectDetection
from .roi_heads import RoIHeads
from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
from libs.vision_libs.models.detection._utils import overwrite_eps
from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor

from models.config.config_tool import read_yaml
import numpy as np
import torch.nn.functional as F

# Width of the hand-crafted per-line feature vector that
# LineRCNNPredictor.sample_lines appends to the pooled line features:
# 2 (start coords) + 2 (end coords) + 2 (direction) + 2 (junction-type flags).
# fc2's input size is dim_loi * n_pts1 + FEATURE_DIM, so the two must agree.
FEATURE_DIM = 8

def non_maximum_suppression(a, kernel_size=3):
    """Zero every element that is not the maximum of its local window.

    A value survives when it equals the max over the ``kernel_size`` x
    ``kernel_size`` window centred on it (all values on a flat plateau
    survive); every other element becomes 0. The spatial size is preserved.

    Args:
        a: tensor accepted by ``F.max_pool2d``, i.e. (C, H, W) or (N, C, H, W).
        kernel_size: odd side length of the window. The default of 3 keeps
            the original 3x3 behaviour.

    Returns:
        A tensor shaped like ``a`` (float32 mask multiplied in, so integer or
        half inputs are promoted exactly as before).

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """
    if kernel_size < 1 or kernel_size % 2 == 0:
        raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
    # Stride 1 with symmetric padding keeps the pooled map aligned with ``a``;
    # max_pool2d pads with -inf, so border cells only compete with real values.
    pooled = F.max_pool2d(a, kernel_size, stride=1, padding=kernel_size // 2)
    # The equality mask is already 0/1, so no clamping is needed.
    return a * (a == pooled).float()

class Bottleneck1D(nn.Module):
    """Pre-activation 1-D residual bottleneck computing ``x + op(x)``.

    ``op`` runs three BatchNorm -> ReLU -> Conv1d stages with kernels 1, 3, 1.
    The middle stage works at ``outplanes // 2`` channels. Because ``x`` is
    added back unchanged, ``inplanes`` must equal ``outplanes``.

    Args:
        inplanes: channel count of the input (the channel axis is the one
            ``nn.BatchNorm1d`` / ``nn.Conv1d`` normalise and convolve).
        outplanes: channel count produced by ``op``.
    """

    def __init__(self, inplanes, outplanes):
        super().__init__()
        mid = outplanes // 2
        stages = (
            # (in_channels, out_channels, kernel_size, padding)
            (inplanes, mid, 1, 0),
            (mid, mid, 3, 1),
            (mid, outplanes, 1, 0),
        )
        layers = []
        for c_in, c_out, kernel, pad in stages:
            layers.extend(
                [
                    nn.BatchNorm1d(c_in),
                    nn.ReLU(inplace=True),
                    nn.Conv1d(c_in, c_out, kernel_size=kernel, padding=pad),
                ]
            )
        self.op = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.op(x)
        return x + residual

class LineRCNNPredictor(nn.Module):
    """L-CNN-style line verification head.

    ``forward`` splits a head-output tensor into junction logits (``jmap``)
    and junction offsets (``joff``), then proposes candidate lines between
    the top-scoring junctions (``sample_lines``). It bilinearly samples
    ``n_pts0`` points of ``fc1(features)`` along every candidate and
    max-pools them down to ``n_pts1`` points. Finally it appends
    ``FEATURE_DIM`` geometric features and scores each candidate with
    ``fc2`` (one logit per line).

    Args:
        n_pts0: points sampled along each line (rows of the ``lambda_`` buffer).
        n_pts1: points kept after pooling; ``n_pts0 // n_pts1`` is the pooling
            window and stride.
        n_stc_posl, n_stc_negl: if their sum is positive, training also scores
            the precomputed lines supplied in ``meta["lpre"]``.
        dim_loi: channels produced by ``fc1`` (a 1x1 conv from 256 channels).
        use_conv: truthy -> pooling adds a ``Bottleneck1D`` and ``fc2`` is a
            single ReLU + Linear; falsy -> ``fc2`` is a 3-layer MLP.
        dim_fc: hidden width of the MLP variant of ``fc2``.
        n_out_line, n_out_junc: not used here; returned by ``forward`` for the
            caller.
        n_dyn_junc: upper bound on junction candidates, split evenly across
            junction types.
        eval_junc_thres: at inference, the candidate count per type is the
            number of NMS-surviving jmap values (over all types) above this
            threshold, capped at ``n_dyn_junc // n_type`` and floored at 2.
        n_dyn_posl, n_dyn_negl: training caps on sampled positive / negative
            candidate lines.
        n_dyn_othr: training count of extra uniformly random candidate pairs.
        use_cood, use_slop: multipliers (0 disables) for the endpoint
            coordinate and direction features.
        head_size: channel groups per head. ``head_off`` (their cumulative
            sums) delimits the jmap / line-map / joff channels of ``inputs``.
            Defaults to ``[[2], [1], [2]]``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, n_pts0=32,
                 n_pts1=8,
                 n_stc_posl=300,
                 dim_loi=128,
                 use_conv=0,
                 dim_fc=1024,
                 n_out_line=2500,
                 n_out_junc=250,
                 n_dyn_junc=300,
                 eval_junc_thres=0.008,
                 n_dyn_posl=300,
                 n_dyn_negl=80,
                 n_dyn_othr=600,
                 use_cood=0,
                 use_slop=0,
                 n_stc_negl=40,
                 head_size=None,
                 **kwargs):
        super().__init__()
        # A fresh list per instance instead of one shared mutable default.
        if head_size is None:
            head_size = [[2], [1], [2]]

        self.n_pts0 = n_pts0
        self.n_pts1 = n_pts1
        self.n_stc_posl = n_stc_posl
        self.dim_loi = dim_loi
        self.use_conv = use_conv
        self.dim_fc = dim_fc
        self.n_out_line = n_out_line
        self.n_out_junc = n_out_junc
        self.n_dyn_junc = n_dyn_junc
        self.eval_junc_thres = eval_junc_thres
        self.n_dyn_posl = n_dyn_posl
        self.n_dyn_negl = n_dyn_negl
        self.n_dyn_othr = n_dyn_othr
        self.use_cood = use_cood
        self.use_slop = use_slop
        self.n_stc_negl = n_stc_negl
        self.head_size = head_size

        self.num_class = sum(sum(self.head_size, []))
        # Cumulative channel offsets: [end of jmap, end of line map, end of joff].
        self.head_off = np.cumsum([sum(h) for h in self.head_size])

        # Interpolation weights (n_pts0, 1) running from 0 to 1 along a line.
        lambda_ = torch.linspace(0, 1, self.n_pts0)[:, None]
        self.register_buffer("lambda_", lambda_)
        self.do_static_sampling = self.n_stc_posl + self.n_stc_negl > 0

        # Submodules are created in the original order so parameter
        # initialisation and state_dict keys are unchanged.
        self.fc1 = nn.Conv2d(256, self.dim_loi, 1)
        scale_factor = self.n_pts0 // self.n_pts1
        fc_in = self.dim_loi * self.n_pts1 + FEATURE_DIM
        if self.use_conv:
            self.pooling = nn.Sequential(
                nn.MaxPool1d(scale_factor, scale_factor),
                Bottleneck1D(self.dim_loi, self.dim_loi),
            )
            self.fc2 = nn.Sequential(
                nn.ReLU(inplace=True), nn.Linear(fc_in, 1)
            )
        else:
            self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
            self.fc2 = nn.Sequential(
                nn.Linear(fc_in, self.dim_fc),
                nn.ReLU(inplace=True),
                nn.Linear(self.dim_fc, self.dim_fc),
                nn.ReLU(inplace=True),
                nn.Linear(self.dim_fc, 1),
            )
        # Not used inside this class; presumably applied by the caller to the
        # returned logits/labels.
        self.loss = nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, inputs, features, targets=None):
        """Propose and score candidate lines for a batch.

        Args:
            inputs: head output unpacked as (batch, channel, row, col). Its
                channels are split at ``self.head_off`` (see ``_decode_heads``).
            features: map fed to ``fc1`` (a 1x1 conv expecting 256 channels)
                and sampled along candidate lines.
            targets: optional list of per-image dicts. Each needs a ``"wires"``
                dict with ``junc_coords``, ``jtyp``, ``line_pos_idx`` and
                ``line_neg_idx``, plus ``lpre``, ``lpre_feat`` and
                ``lpre_label`` when static sampling is enabled. Passing
                targets selects training-style sampling.

        Returns:
            Tuple ``(logits, labels, idx, jcs, n_batch, ps, n_out_line,
            n_out_junc)``:

            - ``logits``: one ``fc2`` output per sampled line, flattened.
            - ``labels``: the matching targets.
            - ``idx``: ``idx[k]:idx[k + 1]`` is image ``k``'s slice of ``logits``.
            - ``jcs`` / ``ps``: per-image junction candidates and line
              endpoints. They are left empty when static sampling is used
              in training.
        """
        batch, _, row, col = inputs.shape

        # NOTE(review): the sampling mode is chosen by whether targets are
        # given, and stored in ``self.training`` because sample_lines reads
        # it. This overrides train()/eval() for this module only; children
        # (e.g. BatchNorm inside ``self.pooling`` when use_conv) keep their
        # own flag.
        self.training = targets is not None
        if self.training:
            wires_targets = [t["wires"] for t in targets]
        else:
            # Placeholder ground truth so sample_lines can run at inference:
            # one junction and all-zero 2x2 positive/negative line tables.
            dummy = {
                "junc_coords": torch.zeros(1, 2),
                "jtyp": torch.zeros(1, dtype=torch.uint8),
                "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8),
                "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8),
            }
            wires_targets = [dummy] * batch

        h = self._decode_heads(inputs)

        x = self.fc1(features)
        n_batch = x.shape[0]

        xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []
        for i, meta in enumerate(wires_targets):
            p, label, feat, jc = self.sample_lines(meta, h["jmap"][i], h["joff"][i])
            ys.append(label)
            if self.training and self.do_static_sampling:
                # Also score the precomputed ("static") lines from the targets.
                p = torch.cat([p, meta["lpre"]])
                feat = torch.cat([feat, meta["lpre_feat"]])
                ys.append(meta["lpre_label"])
            else:
                jcs.append(jc)
                ps.append(p)
            fs.append(feat)

            xp = self._sample_line_features(x[i], p)
            xs.append(xp)
            idx.append(idx[-1] + xp.shape[0])

        y = torch.cat(ys)
        x = torch.cat(xs).reshape(-1, self.n_pts1 * self.dim_loi)
        x = torch.cat([x, torch.cat(fs)], 1)
        # Cast to float32 before the MLP (presumably to match fc2's parameter
        # dtype when pieces arrive in another dtype).
        x = x.to(dtype=torch.float32)
        x = self.fc2(x).flatten()

        return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc

    def _decode_heads(self, inputs):
        """Split ``inputs`` channels at ``self.head_off`` and activate them.

        Returns a dict with:

        - ``"jmap"``: (batch, n_jtyp, row, col) junction probabilities
          (softmax over each type's 2 logits, class 1 kept).
        - ``"joff"``: (batch, n_jtyp, 2, row, col) offsets in (-0.5, 0.5).
        """
        batch, _, row, col = inputs.shape
        offset = self.head_off
        # jmap carries two logits per junction type, so its channel count
        # fixes the number of types (any other count fails the reshape below).
        n_jtyp = int(offset[0]) // 2
        output = inputs.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
        jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
        # Channels offset[0]:offset[1] hold the line map, unused by this head.
        joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
        return {
            "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
            "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
        }

    def _sample_line_features(self, fmap, lines):
        """Bilinearly sample ``fmap`` at ``n_pts0`` points per line, then pool.

        Args:
            fmap: (n_channel, row, col) feature map of one image.
            lines: (n_line, 2, 2) endpoints. The last axis is used as
                (row, col) indices into ``fmap``.

        Returns:
            ``self.pooling`` applied to the (n_line, n_channel, n_pts0)
            samples.
        """
        n_channel, row, col = fmap.shape
        # Blend the two endpoints with lambda_ (0 -> 1); -0.5 converts
        # pixel-centre coordinates to index coordinates.
        pts = lines[:, 0:1, :] * self.lambda_ + lines[:, 1:2, :] * (1 - self.lambda_) - 0.5
        pts = pts.reshape(-1, 2)  # [N_LINE x N_POINT, 2]
        pr, pc = pts[:, 0].contiguous(), pts[:, 1].contiguous()
        # Neighbouring integer indices, clamped to this feature map's bounds.
        pr0 = pr.floor().clamp(min=0, max=row - 1)
        pc0 = pc.floor().clamp(min=0, max=col - 1)
        pr1 = (pr0 + 1).clamp(min=0, max=row - 1)
        pc1 = (pc0 + 1).clamp(min=0, max=col - 1)
        pr0l, pc0l, pr1l, pc1l = pr0.long(), pc0.long(), pr1.long(), pc1.long()

        sampled = (
            fmap[:, pr0l, pc0l] * (pr1 - pr) * (pc1 - pc)
            + fmap[:, pr1l, pc0l] * (pr - pr0) * (pc1 - pc)
            + fmap[:, pr0l, pc1l] * (pr1 - pr) * (pc - pc0)
            + fmap[:, pr1l, pc1l] * (pr - pr0) * (pc - pc0)
        )
        # [N_CHANNEL, N_LINE * N_POINT] -> [N_LINE, N_CHANNEL, N_POINT]
        sampled = sampled.reshape(n_channel, -1, self.n_pts0).permute(1, 0, 2)
        return self.pooling(sampled)

    @staticmethod
    def _cap_indices(cdx, limit):
        """Return ``cdx``, or a uniformly random subset of ``limit`` entries if longer."""
        if len(cdx) > limit:
            perm = torch.randperm(len(cdx), device=cdx.device)[:limit]
            cdx = cdx[perm]
        return cdx

    def sample_lines(self, meta, jmap, joff):
        """Decode junction candidates and build candidate lines between them.

        Args:
            meta: per-image dict with ``junc_coords`` [N, 2], ``jtyp`` [N],
                ``line_pos_idx`` and ``line_neg_idx``. The two tables are
                indexed by matched junction ids in ``[0, N]``, where ``N``
                means "unmatched", so presumably they are (N + 1, N + 1).
            jmap: junction probabilities, (n_type, H, W) as passed by forward.
            joff: junction offsets, (n_type, 2, H, W) as passed by forward.

        Returns:
            ``(line, label, feat, jcs)``:

            - ``line``: [L, 2, 2] endpoints in (row, col).
            - ``label``: float [L].
            - ``feat``: [L, FEATURE_DIM] geometric features.
            - ``jcs``: per type, the candidate junctions with score > 0.03.
        """
        device = jmap.device
        with torch.no_grad():
            junc = meta["junc_coords"].to(device)  # [N, 2]
            jtyp = meta["jtyp"].to(device)  # [N]
            Lpos = meta["line_pos_idx"].to(device)
            Lneg = meta["line_neg_idx"].to(device)

            n_type = jmap.shape[0]
            height, width = jmap.shape[-2], jmap.shape[-1]
            jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
            joff = joff.reshape(n_type, 2, -1)
            max_K = self.n_dyn_junc // n_type
            N = len(junc)
            if self.training:
                K = min(int(N * 2 + 2), max_K)
            else:
                K = min(int((jmap > self.eval_junc_thres).float().sum().item()), max_K)
            K = max(K, 2)

            # index: [N_TYPE, K] flat (row-major) positions of the top-K peaks.
            score, index = torch.topk(jmap, k=K)
            y = (
                torch.div(index, width, rounding_mode="floor").float()
                + torch.gather(joff[:, 0], 1, index)
                + 0.5
            )
            x = (index % width).float() + torch.gather(joff[:, 1], 1, index) + 0.5

            # xy: [N_TYPE, K, 2] as (row, col)
            xy = torch.cat([y[..., None], x[..., None]], dim=-1)

            # Squared distance of every candidate to every junction: [N_TYPE, K, N]
            dist = torch.sum((xy[..., None, :] - junc) ** 2, -1)
            cost, match = torch.min(dist, -1)

            # Id N marks "unmatched": nearest junction has another type, or
            # is farther than 1.5 px.
            for t in range(n_type):
                match[t, jtyp[match[t]] != t] = N
            match[cost > 1.5 * 1.5] = N
            match = match.flatten()

            # All ordered pairs (u, v) of the n_type * K candidates.
            cand = torch.arange(n_type * K, device=device)
            u, v = torch.meshgrid(cand, cand, indexing="ij")
            u, v = u.flatten(), v.flatten()
            up, vp = match[u], match[v]
            label = Lpos[up, vp]

            if self.training:
                c = torch.zeros_like(label, dtype=torch.bool)
                # Positive lines, at most n_dyn_posl.
                c[self._cap_indices(label.nonzero().flatten(), self.n_dyn_posl)] = True
                # Negative lines, at most n_dyn_negl.
                c[self._cap_indices(Lneg[up, vp].nonzero().flatten(), self.n_dyn_negl)] = True
                # n_dyn_othr uniformly random pairs (may overlap the above).
                c[torch.randint(len(c), (self.n_dyn_othr,), device=device)] = True
            else:
                # Inference: every unordered pair exactly once.
                c = u < v

            u, v, label = u[c], v[c], label[c]
            xy = xy.reshape(n_type * K, 2)
            xyu, xyv = xy[u], xy[v]

            u2v = xyu - xyv
            u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
            # Normalise (row, col) coordinates by the heatmap size.
            scale = xy.new_tensor([height, width])
            feat = torch.cat(
                [
                    xyu / scale * self.use_cood,
                    xyv / scale * self.use_cood,
                    u2v * self.use_slop,
                    # Candidates are laid out type-major (t * K + k), so
                    # indices >= K are not of junction type 0.
                    (u[:, None] >= K).float(),
                    (v[:, None] >= K).float(),
                ],
                1,
            )
            line = torch.cat([xyu[:, None], xyv[:, None]], 1)

            xy = xy.reshape(n_type, K, 2)
            jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
            return line, label.float(), feat, jcs



# Metadata dict: COCO person categories, COCO person keypoint names, and a
# minimum input size of (1, 1).
# NOTE(review): not referenced in the visible part of this file; presumably
# merged into WeightsEnum ``meta`` entries defined further down — confirm.
_COMMON_META = {
    "categories": _COCO_PERSON_CATEGORIES,
    "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
    "min_size": (1, 1),
}