from typing import Optional, Any

import numpy as np
import torch
from tensorboardX import SummaryWriter
from torch import nn
import torch.nn.functional as F
from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models.detection import FasterRCNN, MaskRCNN_ResNet50_FPN_V2_Weights
from torchvision.models.detection._utils import overwrite_eps
from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
from torchvision.models.detection.faster_rcnn import TwoMLPHead, FastRCNNPredictor
from torchvision.models.detection.keypoint_rcnn import KeypointRCNN_ResNet50_FPN_Weights
from torchvision.ops import MultiScaleRoIAlign
from torchvision.ops import misc as misc_nn_ops

from models.config.config_tool import read_yaml
from models.wirenet.head import RoIHeads
from models.wirenet.wirepoint_dataset import WirePointDataset
from tools import utils


# Dimension of the per-line descriptor built in sample_lines():
# endpoint coords (2 + 2) + unit direction (2) + two flags (2).
FEATURE_DIM = 8


def non_maximum_suppression(a):
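    """Keep only the local maxima of heatmap `a` (NMS via 3x3 max-pooling).

    Positions that equal their 3x3 neighborhood maximum are kept; all other
    positions are zeroed, so only peak responses survive.
    """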
    ap = F.max_pool2d(a, 3, stride=1, padding=1)
    mask = (a == ap).float().clamp(min=0.0)
    return a * mask


class Bottleneck1D(nn.Module):
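    """1-D residual bottleneck block: BN-ReLU-Conv1d (1/3/1 kernels) plus a skip.

    The middle width is outplanes // 2; note that the residual add in
    forward() assumes inplanes == outplanes.
    """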
    def __init__(self, inplanes, outplanes):
        super().__init__()

        planes = outplanes // 2
        self.op = nn.Sequential(
            nn.BatchNorm1d(inplanes),
            nn.ReLU(inplace=True),
            nn.Conv1d(inplanes, planes, kernel_size=1),
            nn.BatchNorm1d(planes),
            nn.ReLU(inplace=True),
            nn.Conv1d(planes, planes, kernel_size=3, padding=1),
            nn.BatchNorm1d(planes),
            nn.ReLU(inplace=True),
            nn.Conv1d(planes, outplanes, kernel_size=1),
        )

    def forward(self, x):
        return x + self.op(x)


class WirepointRCNN(FasterRCNN):
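    """Faster R-CNN extended with a wirepoint (junction/line) branch.

    The standard detection pipeline is built by FasterRCNN.__init__; afterwards
    a custom RoIHeads is swapped in and the wirepoint RoI pool, head and
    predictor are attached to it.
    """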
    def __init__(
            self,
            backbone,
            num_classes=None,
            # transform parameters
            min_size=None,
            max_size=1333,
            image_mean=None,
            image_std=None,
            # RPN parameters
            rpn_anchor_generator=None,
            rpn_head=None,
            rpn_pre_nms_top_n_train=2000,
            rpn_pre_nms_top_n_test=1000,
            rpn_post_nms_top_n_train=2000,
            rpn_post_nms_top_n_test=1000,
            rpn_nms_thresh=0.7,
            rpn_fg_iou_thresh=0.7,
            rpn_bg_iou_thresh=0.3,
            rpn_batch_size_per_image=256,
            rpn_positive_fraction=0.5,
            rpn_score_thresh=0.0,
            # Box parameters
            box_roi_pool=None,
            box_head=None,
            box_predictor=None,
            box_score_thresh=0.05,
            box_nms_thresh=0.5,
            box_detections_per_img=100,
            box_fg_iou_thresh=0.5,
            box_bg_iou_thresh=0.5,
            box_batch_size_per_image=512,
            box_positive_fraction=0.25,
            bbox_reg_weights=None,
            # keypoint parameters
            keypoint_roi_pool=None,
            keypoint_head=None,
            keypoint_predictor=None,
            num_keypoints=None,
            wirepoint_roi_pool=None,
            wirepoint_head=None,
            wirepoint_predictor=None,
            **kwargs,
    ):
        if not isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))):
            raise TypeError(
                f"keypoint_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(keypoint_roi_pool)}"
            )
        if min_size is None:
            min_size = (640, 672, 704, 736, 768, 800)

        if num_keypoints is not None:
            if keypoint_predictor is not None:
                raise ValueError("num_keypoints should be None when keypoint_predictor is specified")
        else:
            num_keypoints = 17

        out_channels = backbone.out_channels

        if wirepoint_roi_pool is None:
            wirepoint_roi_pool = MultiScaleRoIAlign(
                featmap_names=["0", "1", "2", "3"], output_size=128, sampling_ratio=2
            )

        if wirepoint_head is None:
            keypoint_layers = tuple(512 for _ in range(8))
            print(f'WirepointHead in_channels: {out_channels}, layers: {keypoint_layers}')
            wirepoint_head = WirepointHead(out_channels, keypoint_layers)

        if wirepoint_predictor is None:
            wirepoint_predictor = WirepointPredictor()

        super().__init__(
            backbone,
            num_classes,
            # transform parameters
            min_size,
            max_size,
            image_mean,
            image_std,
            # RPN-specific parameters
            rpn_anchor_generator,
            rpn_head,
            rpn_pre_nms_top_n_train,
            rpn_pre_nms_top_n_test,
            rpn_post_nms_top_n_train,
            rpn_post_nms_top_n_test,
            rpn_nms_thresh,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_score_thresh,
            # Box parameters
            box_roi_pool,
            box_head,
            box_predictor,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            **kwargs,
        )

        # Rebuild box-branch defaults locally so a custom RoIHeads (carrying the
        # wirepoint branch) can replace the one created by FasterRCNN.__init__.
        if box_roi_pool is None:
            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)

        if box_head is None:
            resolution = box_roi_pool.output_size[0]
            representation_size = 1024
            box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)

        if box_predictor is None:
            representation_size = 1024
            box_predictor = FastRCNNPredictor(representation_size, num_classes)

        roi_heads = RoIHeads(
            # Box
            box_roi_pool,
            box_head,
            box_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
        )
        self.roi_heads = roi_heads

        # Attach the wirepoint branch as attributes on the custom RoIHeads.
        self.roi_heads.wirepoint_roi_pool = wirepoint_roi_pool
        self.roi_heads.wirepoint_head = wirepoint_head
        self.roi_heads.wirepoint_predictor = wirepoint_predictor


class WirepointHead(nn.Module):
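    """Multitask head: one Conv-ReLU-Conv branch per entry of head_size.

    With head_size = [[2], [1], [2]] the branches produce the junction map,
    line map and junction offsets; forward() concatenates them channel-wise
    and also returns the input feature map.
    """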
    def __init__(self, input_channels, num_class):
        super().__init__()
        # NOTE: num_class is currently unused; output channels per branch come
        # from the hard-coded head_size (jmap: 2, lmap: 1, joff: 2).
        self.head_size = [[2], [1], [2]]
        m = input_channels // 4
        heads = []
        for output_channels in sum(self.head_size, []):
            heads.append(
                nn.Sequential(
                    nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(m, output_channels, kernel_size=1),
                )
            )
        self.heads = nn.ModuleList(heads)

    def forward(self, x):

        outputs = torch.cat([head(x) for head in self.heads], dim=1)

        features = x
        return outputs, features


class WirepointPredictor(nn.Module):
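    """Line-of-interest predictor: samples candidate lines from junction
    heatmaps, pools backbone features along each line, and scores the lines
    with a small fully connected head. Hyper-parameters come from
    'wirenet.yaml'.
    """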

    def __init__(self):
        super().__init__()
        # Hyper-parameters are read from 'wirenet.yaml' in the working directory.
        self.cfg = read_yaml('wirenet.yaml')
        self.n_pts0 = self.cfg['model']['n_pts0']
        self.n_pts1 = self.cfg['model']['n_pts1']
        self.n_stc_posl = self.cfg['model']['n_stc_posl']
        self.dim_loi = self.cfg['model']['dim_loi']
        self.use_conv = self.cfg['model']['use_conv']
        self.dim_fc = self.cfg['model']['dim_fc']
        self.n_out_line = self.cfg['model']['n_out_line']
        self.n_out_junc = self.cfg['model']['n_out_junc']
        self.loss_weight = self.cfg['model']['loss_weight']
        self.n_dyn_junc = self.cfg['model']['n_dyn_junc']
        self.eval_junc_thres = self.cfg['model']['eval_junc_thres']
        self.n_dyn_posl = self.cfg['model']['n_dyn_posl']
        self.n_dyn_negl = self.cfg['model']['n_dyn_negl']
        self.n_dyn_othr = self.cfg['model']['n_dyn_othr']
        self.use_cood = self.cfg['model']['use_cood']
        self.use_slop = self.cfg['model']['use_slop']
        self.n_stc_negl = self.cfg['model']['n_stc_negl']
        self.head_size = self.cfg['model']['head_size']
        self.num_class = sum(sum(self.head_size, []))
        self.head_off = np.cumsum([sum(h) for h in self.head_size])

        # Interpolation weights used to sample n_pts0 points uniformly along each line.
        lambda_ = torch.linspace(0, 1, self.n_pts0)[:, None]
        self.register_buffer("lambda_", lambda_)
        self.do_static_sampling = self.n_stc_posl + self.n_stc_negl > 0

        self.fc1 = nn.Conv2d(256, self.dim_loi, 1)
        scale_factor = self.n_pts0 // self.n_pts1
        if self.use_conv:
            self.pooling = nn.Sequential(
                nn.MaxPool1d(scale_factor, scale_factor),
                Bottleneck1D(self.dim_loi, self.dim_loi),
            )
            self.fc2 = nn.Sequential(
                nn.ReLU(inplace=True), nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, 1)
            )
        else:
            self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
            self.fc2 = nn.Sequential(
                nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, self.dim_fc),
                nn.ReLU(inplace=True),
                nn.Linear(self.dim_fc, self.dim_fc),
                nn.ReLU(inplace=True),
                nn.Linear(self.dim_fc, 1),
            )
        self.loss = nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, inputs, features, targets=None):
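        """Score candidate lines.

        `inputs` is the multitask head output and `features` the feature map
        used for line-of-interest pooling. When `targets` is given, training
        mode is forced on and static ground-truth lines may be mixed in;
        otherwise placeholder meta is used for inference.
        """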

        batch, channel, row, col = inputs.shape
        print(f'inputs:{inputs.shape}')
        print(f'batch:{batch}, channel:{channel}, row:{row}, col:{col}')

        if targets is not None:
            # NOTE: self.training is toggled by the presence of targets rather
            # than by model.train()/model.eval().
            self.training = True
            wires_targets = [t["wires"] for t in targets]
            # Extract the 'junc_map', 'junc_offset' and 'line_map' tensors from every target.
            junc_maps = [d["junc_map"] for d in wires_targets]
            junc_offsets = [d["junc_offset"] for d in wires_targets]
            line_maps = [d["line_map"] for d in wires_targets]

            junc_map_tensor = torch.stack(junc_maps, dim=0)
            junc_offset_tensor = torch.stack(junc_offsets, dim=0)
            line_map_tensor = torch.stack(line_maps, dim=0)

            wires_meta = {
                "junc_map": junc_map_tensor,
                "junc_offset": junc_offset_tensor,
                # "line_map": line_map_tensor,
            }
        else:
            self.training = False
            # Placeholder targets/meta used when no ground truth is available.
            t = {
                "junc_coords": torch.zeros(1, 2),
                "jtyp": torch.zeros(1, dtype=torch.uint8),
                "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8),
                "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8),
                "junc_map": torch.zeros([1, 1, 128, 128]),
                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
            }
            wires_targets = [t for _ in range(inputs.size(0))]

            wires_meta = {
                "junc_map": torch.zeros([1, 1, 128, 128]),
                "junc_offset": torch.zeros([1, 1, 2, 128, 128]),
            }


        T = wires_meta.copy()
        n_jtyp = T["junc_map"].shape[1]
        offset = self.head_off
        result = {}
        for stack, output in enumerate([inputs]):
            output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
            print(f"Stack {stack} output shape: {output.shape}")  # print each stack's output shape
            # Split the multitask output into junction map, line map and junction offsets.
            jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
            lmap = output[offset[0]: offset[1]].squeeze(0)
            joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)

            if stack == 0:
                result["preds"] = {
                    "jmap": jmap.permute(2, 0, 1, 3, 4).softmax(2)[:, :, 1],
                    "lmap": lmap.sigmoid(),
                    "joff": joff.permute(2, 0, 1, 3, 4).sigmoid() - 0.5,
                }
                # visualize_feature_map(jmap[0, 0], title=f"jmap - Stack {stack}")
                # visualize_feature_map(lmap, title=f"lmap - Stack {stack}")
                # visualize_feature_map(joff[0, 0], title=f"joff - Stack {stack}")

        h = result["preds"]
        print(f'features shape:{features.shape}')
        x = self.fc1(features)

        print(f'x:{x.shape}')

        n_batch, n_channel, row, col = x.shape

        print(f'n_batch:{n_batch}, n_channel:{n_channel}, row:{row}, col:{col}')

        xs, ys, fs, ps, idx, jcs = [], [], [], [], [0], []

        for i, meta in enumerate(wires_targets):
            p, label, feat, jc = self.sample_lines(
                meta, h["jmap"][i], h["joff"][i],
            )
            print(f"p.shape:{p.shape},label:{label.shape},feat:{feat.shape},jc:{len(jc)}")
            ys.append(label)
            if self.training and self.do_static_sampling:
                p = torch.cat([p, meta["lpre"]])
                feat = torch.cat([feat, meta["lpre_feat"]])
                ys.append(meta["lpre_label"])
                del jc
            else:
                jcs.append(jc)
                ps.append(p)
            fs.append(feat)

            # Sample n_pts0 points uniformly along each line (via lambda_), then
            # bilinearly interpolate the feature map at those sub-pixel locations.
            p = p[:, 0:1, :] * self.lambda_ + p[:, 1:2, :] * (1 - self.lambda_) - 0.5
            p = p.reshape(-1, 2)  # [N_LINE x N_POINT, 2_XY]
            px, py = p[:, 0].contiguous(), p[:, 1].contiguous()
            px0 = px.floor().clamp(min=0, max=127)
            py0 = py.floor().clamp(min=0, max=127)
            px1 = (px0 + 1).clamp(min=0, max=127)
            py1 = (py0 + 1).clamp(min=0, max=127)
            px0l, py0l, px1l, py1l = px0.long(), py0.long(), px1.long(), py1.long()

            # xp: [N_LINE, N_CHANNEL, N_POINT]
            xp = (
                (
                    x[i, :, px0l, py0l] * (px1 - px) * (py1 - py)
                    + x[i, :, px1l, py0l] * (px - px0) * (py1 - py)
                    + x[i, :, px0l, py1l] * (px1 - px) * (py - py0)
                    + x[i, :, px1l, py1l] * (px - px0) * (py - py0)
                )
                .reshape(n_channel, -1, self.n_pts0)
                .permute(1, 0, 2)
            )
            xp = self.pooling(xp)
            print(f'xp.shape:{xp.shape}')
            xs.append(xp)
            idx.append(idx[-1] + xp.shape[0])
            print(f'idx__:{idx}')

        x, y = torch.cat(xs), torch.cat(ys)
        f = torch.cat(fs)
        x = x.reshape(-1, self.n_pts1 * self.dim_loi)

        # print("Weight dtype:", self.fc2.weight.dtype)
        x = torch.cat([x, f], 1)
        print("Input dtype:", x.dtype)
        x = x.to(dtype=torch.float32)
        print("Input dtype1:", x.dtype)
        x = self.fc2(x).flatten()

        return x, y, idx, jcs, n_batch, ps, self.n_out_line, self.n_out_junc

    # ---- deprecated: old inference path, kept for reference ----
    # def inference(self,input, idx, jcs, n_batch, ps):
    #     if not self.training:
    #         p = torch.cat(ps)
    #         s = torch.sigmoid(input)
    #         b = s > 0.5
    #         lines = []
    #         score = []
    #         print(f"n_batch:{n_batch}")
    #         for i in range(n_batch):
    #             print(f"idx:{idx}")
    #             p0 = p[idx[i]: idx[i + 1]]
    #             s0 = s[idx[i]: idx[i + 1]]
    #             mask = b[idx[i]: idx[i + 1]]
    #             p0 = p0[mask]
    #             s0 = s0[mask]
    #             if len(p0) == 0:
    #                 lines.append(torch.zeros([1, self.n_out_line, 2, 2], device=p.device))
    #                 score.append(torch.zeros([1, self.n_out_line], device=p.device))
    #             else:
    #                 arg = torch.argsort(s0, descending=True)
    #                 p0, s0 = p0[arg], s0[arg]
    #                 lines.append(p0[None, torch.arange(self.n_out_line) % len(p0)])
    #                 score.append(s0[None, torch.arange(self.n_out_line) % len(s0)])
    #             for j in range(len(jcs[i])):
    #                 if len(jcs[i][j]) == 0:
    #                     jcs[i][j] = torch.zeros([self.n_out_junc, 2], device=p.device)
    #                 jcs[i][j] = jcs[i][j][
    #                     None, torch.arange(self.n_out_junc) % len(jcs[i][j])
    #                 ]
    #         result["preds"]["lines"] = torch.cat(lines)
    #         result["preds"]["score"] = torch.cat(score)
    #         result["preds"]["juncs"] = torch.cat([jcs[i][0] for i in range(n_batch)])
    #
    #         if len(jcs[i]) > 1:
    #             result["preds"]["junts"] = torch.cat(
    #                 [jcs[i][1] for i in range(n_batch)]
    #             )
    #     if self.training:
    #         del result["preds"]

    def sample_lines(self, meta, jmap, joff):
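        """Sample candidate line segments from junction predictions.

        Top-K peaks are taken from the NMS-ed jmap per junction type, matched
        to ground-truth junctions, and every junction pair becomes a candidate
        line labelled via Lpos. During training, positive, negative and random
        pairs are sub-sampled; at inference all pairs with u < v are kept.
        Returns (line, label, feat, jcs).
        """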
        with torch.no_grad():
            junc = meta["junc_coords"]  # [N, 2]
            jtyp = meta["jtyp"]  # [N]
            Lpos = meta["line_pos_idx"]
            Lneg = meta["line_neg_idx"]

            n_type = jmap.shape[0]
            jmap = non_maximum_suppression(jmap).reshape(n_type, -1)
            joff = joff.reshape(n_type, 2, -1)
            max_K = self.n_dyn_junc // n_type
            N = len(junc)
            # if mode != "training":
            if not self.training:
                K = min(int((jmap > self.eval_junc_thres).float().sum().item()), max_K)
            else:
                K = min(int(N * 2 + 2), max_K)
            if K < 2:
                K = 2
            device = jmap.device

            # index: [N_TYPE, K]; recover sub-pixel junction coordinates from
            # the 128x128 grid indices plus the predicted offsets.
            score, index = torch.topk(jmap, k=K)
            y = (index // 128).float() + torch.gather(joff[:, 0], 1, index) + 0.5
            x = (index % 128).float() + torch.gather(joff[:, 1], 1, index) + 0.5

            # xy: [N_TYPE, K, 2]
            xy = torch.cat([y[..., None], x[..., None]], dim=-1)
            xy_ = xy[..., None, :]
            del x, y, index

            # dist: [N_TYPE, K, N]
            dist = torch.sum((xy_ - junc) ** 2, -1)
            cost, match = torch.min(dist, -1)

            # xy: [N_TYPE * K, 2]
            # match: [N_TYPE, K]
            for t in range(n_type):
                match[t, jtyp[match[t]] != t] = N
            match[cost > 1.5 * 1.5] = N
            match = match.flatten()

            rng = torch.arange(n_type * K, device=device)
            u, v = torch.meshgrid(rng, rng, indexing="ij")  # all junction pairs
            u, v = u.flatten(), v.flatten()
            up, vp = match[u], match[v]
            label = Lpos[up, vp]

            # if mode == "training":
            if self.training:
                c = torch.zeros_like(label, dtype=torch.bool)

                # sample positive lines
                cdx = label.nonzero().flatten()
                if len(cdx) > self.n_dyn_posl:
                    # print("too many positive lines")
                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_posl]
                    cdx = cdx[perm]
                c[cdx] = 1

                # sample negative lines
                cdx = Lneg[up, vp].nonzero().flatten()
                if len(cdx) > self.n_dyn_negl:
                    # print("too many negative lines")
                    perm = torch.randperm(len(cdx), device=device)[: self.n_dyn_negl]
                    cdx = cdx[perm]
                c[cdx] = 1

                # sample other (unmatched) lines
                cdx = torch.randint(len(c), (self.n_dyn_othr,), device=device)
                c[cdx] = 1
            else:
                c = (u < v).flatten()

            # sample lines
            u, v, label = u[c], v[c], label[c]
            xy = xy.reshape(n_type * K, 2)
            xyu, xyv = xy[u], xy[v]

            u2v = xyu - xyv
            u2v /= torch.sqrt((u2v ** 2).sum(-1, keepdim=True)).clamp(min=1e-6)
            # 8-D line descriptor (FEATURE_DIM): endpoint coords (2 + 2), unit
            # direction (2), and two junction-type index flags (2).
            feat = torch.cat(
                [
                    xyu / 128 * self.use_cood,
                    xyv / 128 * self.use_cood,
                    u2v * self.use_slop,
                    (u[:, None] > K).float(),
                    (v[:, None] > K).float(),
                ],
                1,
            )
            line = torch.cat([xyu[:, None], xyv[:, None]], 1)

            xy = xy.reshape(n_type, K, 2)
            jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
            return line, label.float(), feat, jcs



def wirepointrcnn_resnet50_fpn(
        *,
        weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
        progress: bool = True,
        num_classes: Optional[int] = None,
        num_keypoints: Optional[int] = None,
        weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
        trainable_backbone_layers: Optional[int] = None,
        **kwargs: Any,
) -> WirepointRCNN:
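    """Construct a WirepointRCNN with a ResNet-50 FPN backbone.

    Mirrors torchvision's keypointrcnn_resnet50_fpn builder: verifies the
    weight enums, freezes backbone layers accordingly, and optionally loads
    KeypointRCNN COCO weights.
    """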
    weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)

    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
    if num_classes is None:
        num_classes = 5  # previously hard-coded default
    model = WirepointRCNN(backbone, num_classes=num_classes, num_keypoints=num_keypoints, **kwargs)

    if weights is not None:
        # NOTE: KeypointRCNN COCO weights do not cover the wirepoint modules;
        # strict loading may fail unless the state dict keys match.
        model.load_state_dict(weights.get_state_dict(progress=progress))
        if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model


if __name__ == '__main__':
    cfg = 'wirenet.yaml'
    cfg = read_yaml(cfg)
    print(f'cfg:{cfg}')
    print(cfg['model']['n_dyn_negl'])
    # net = WirepointPredictor()

    dataset = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
    train_sampler = torch.utils.data.RandomSampler(dataset)
    # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=4, drop_last=True)
    train_collate_fn = utils.collate_fn_wirepoint
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_sampler=train_batch_sampler, num_workers=10, collate_fn=train_collate_fn
    )
    model = wirepointrcnn_resnet50_fpn()

    imgs, targets = next(iter(data_loader))

    model.train()
    pred = model(imgs, targets)
    print(f'pred:{pred}')
    # result, losses = model(imgs, targets)
    # print(f'result:{result}')
    # print(f'pred:{losses}')
'''
########### predict #############

    img_path=r"I:\wirenet_dateset\images\train\00030078_2.png"
    transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
    img = read_image(img_path)
    img = transforms(img)

    img = torch.ones((2, 3, 512, 512))
    # print(f'img shape:{img.shape}')
    model.eval()
    onnx_file_path = "./wirenet.onnx"

    # Export the model to ONNX format
    # torch.onnx.export(model, img, onnx_file_path, verbose=True, input_names=['input'],
    #                   output_names=['output'])
    # torch.save(model,'./wirenet.pt')



    # Specify the output ONNX file name
    # onnx_file_path = "./wirepoint_rcnn.onnx"

    # Prepare a sample input: Mask R-CNN expects a list of images, each tensor of shape [C, H, W]
    img = [torch.ones((3, 800, 800))]  # sample input: 800x800 image with 3 channels



    # Specify the output ONNX file name
    # onnx_file_path = "./mask_rcnn.onnx"



    # model_scripted = torch.jit.script(model)
    # torch.onnx.export(model_scripted, input, "model.onnx", verbose=True, input_names=["input"],
    #                   output_names=["output"])
    #
    # print(f"Model has been converted to ONNX and saved to {onnx_file_path}")

    pred=model(img)
    #
    print(f'pred:{pred}')



################################################## end predict



########## training ###################################
    # imgs, targets = next(iter(data_loader))

    # model.train()
    # pred = model(imgs, targets)

    # class WrapperModule(torch.nn.Module):
    #     def __init__(self, model):
    #         super(WrapperModule, self).__init__()
    #         self.model = model
    #
    #     def forward(self, img, targets):
    #         # Convert the complex input structure here into a form suitable for tracing
    #         return self.model(img, targets)

    # torch.save(model.state_dict(), './wire.pt')
    # Wrap the original model
    # wrapped_model = WrapperModule(model)
    # # model_scripted = torch.jit.trace(wrapped_model,img)
    # writer = SummaryWriter('./')
    # writer.add_graph(wrapped_model, (imgs,targets))
    # writer.close()


    #
    # print(f'pred:{pred}')
########## end training ###################################
    # for imgs,targets in data_loader:
    #     print(f'imgs:{imgs}')
    #     print(f'targets:{targets}')
'''