
Refactor wirenetrcnn2

RenLiqiang 5 months ago
Commit
f3e14bfc0f

+ 0 - 0
models/wirenet2/__init__.py


+ 82 - 0
models/wirenet2/kepointrcnn.py

@@ -0,0 +1,82 @@
+import math
+import os
+import sys
+from datetime import datetime
+from typing import Mapping, Any
+import cv2
+import numpy as np
+import torch
+import torchvision
+from torch import nn
+from torchvision.io import read_image
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
+from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
+from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
+from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
+from torchvision.utils import draw_bounding_boxes
+
+from models.config.config_tool import read_yaml
+from models.keypoint.trainer import train_cfg
+
+from tools import utils
+os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
+
+
+
+
+
+class KeypointRCNNModel(nn.Module):
+
+    def __init__(self, num_classes=2, num_keypoints=2, transforms=None):
+        super(KeypointRCNNModel, self).__init__()
+        default_weights = torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
+        # weights=None: the COCO checkpoint assumes 17 keypoints, so the head here is trained from scratch
+        self.__model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=None, num_classes=num_classes,
+                                                                              num_keypoints=num_keypoints,
+                                                                              progress=False)
+        # fall back to the default preprocessing when no transforms are supplied
+        self.transforms = transforms if transforms is not None else default_weights.transforms()
+
+        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+    def forward(self, inputs):
+        # inference-style call; training runs through train() below, which hands
+        # the inner torchvision model to the trainer
+        outputs = self.__model(inputs)
+        return outputs
+
+    def train(self, cfg):
+        # note: this shadows nn.Module.train(mode), so model.eval() cannot be
+        # called on the wrapper; put the inner model in eval mode instead
+        parameters = read_yaml(cfg)
+        self.num_keypoints = parameters['num_keypoints']
+        train_cfg(self.__model, cfg)
+
+    # def set_num_classes(self, num_classes):
+    #     in_features = self.__model.roi_heads.box_predictor.cls_score.in_features
+    #     self.__model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=num_classes)
+    #     in_channels = self.__model.roi_heads.keypoint_predictor.kps_score_lowres.in_channels
+    #     self.__model.roi_heads.keypoint_predictor = KeypointRCNNPredictor(in_channels, num_keypoints=self.num_keypoints)
+
+    def load_weight(self, pt_path):
+        state_dict = torch.load(pt_path, map_location=self.device)
+        self.__model.load_state_dict(state_dict)
+
+    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+        # delegate to the wrapped model so checkpoints saved from it load directly
+        return self.__model.load_state_dict(state_dict, strict)
+
+
+if __name__ == '__main__':
+    # ins_model = MaskRCNNModel(num_classes=5)
+    keypoint_model = KeypointRCNNModel(num_keypoints=2)
+    # data_path = r'F:\DevTools\datasets\renyaun\1012\spilt'
+    # ins_model.train(data_dir=data_path,epochs=5000,target_type='pixel',batch_size=6,num_workers=10,num_classes=5)
+    keypoint_model.train(cfg='train.yaml')
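
A minimal inference sketch for the wrapper above (editorial note, not part of the commit). The checkpoint and image paths are hypothetical, and because train(cfg) shadows nn.Module.train, eval mode must be set on the wrapped detector directly:

    import torch
    from torchvision.io import read_image
    from models.wirenet2.kepointrcnn import KeypointRCNNModel

    model = KeypointRCNNModel(num_classes=2, num_keypoints=2)
    model.load_weight('train_results/<run>/weights/best.pt')  # hypothetical checkpoint path
    # model.eval() would call the shadowed train(cfg), so switch the inner model instead
    model._KeypointRCNNModel__model.eval()

    img = read_image('example.png').float() / 255.0  # CHW tensor in [0, 1]
    with torch.no_grad():
        outputs = model([img])
    print(outputs[0]['keypoints'].shape)  # (N, 2, 3): two endpoints per detected line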

+ 203 - 0
models/wirenet2/keypoint_dataset.py

@@ -0,0 +1,203 @@
+from torch.utils.data.dataset import T_co
+
+from models.base.base_dataset import BaseDataset
+
+import glob
+import json
+import math
+import os
+import random
+import cv2
+import PIL
+
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from torchvision.utils import draw_bounding_boxes
+
+import numpy as np
+import numpy.linalg as LA
+import torch
+from skimage import io
+from torch.utils.data import Dataset
+from torch.utils.data.dataloader import default_collate
+
+from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
+
+def validate_keypoints(keypoints, image_width, image_height):
+    for kp in keypoints:
+        x, y, v = kp
+        if not (0 <= x < image_width and 0 <= y < image_height):
+            raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
+
+
+class KeypointDataset(BaseDataset):
+    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
+        super().__init__(dataset_path)
+
+        self.data_path = dataset_path
+        print(f'data_path:{dataset_path}')
+        self.transforms = transforms
+        self.img_path = os.path.join(dataset_path, "images", dataset_type)
+        self.lbl_path = os.path.join(dataset_path, "labels", dataset_type)
+        self.imgs = os.listdir(self.img_path)
+        self.lbls = os.listdir(self.lbl_path)
+        self.target_type = target_type
+        # self.default_transform = DefaultTransform()
+
+    def __getitem__(self, index) -> T_co:
+        img_path = os.path.join(self.img_path, self.imgs[index])
+        lbl_path = os.path.join(self.lbl_path, os.path.splitext(self.imgs[index])[0] + '.json')
+
+        img = PIL.Image.open(img_path).convert('RGB')
+        w, h = img.size
+
+        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        if self.transforms:
+            img, target = self.transforms(img, target)
+        else:
+            img = self.default_transform(img)
+
+        # print(f'img:{img}')
+        return img, target
+
+    def __len__(self):
+        return len(self.imgs)
+
+    def read_target(self, item, lbl_path, shape, extra=None):
+        # print(f'shape:{shape}')
+        # print(f'lbl_path:{lbl_path}')
+        with open(lbl_path, 'r') as file:
+            label_all = json.load(file)
+
+        n_stc_posl = 300
+        n_stc_negl = 40
+        use_cood = 0
+        use_slop = 0
+
+        wire = label_all["wires"][0]  # dict of wireframe annotations
+        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # take up to n_stc_posl; fewer is fine
+        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
+        npos, nneg = len(line_pos_coords), len(line_neg_coords)
+        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # positive and negative line samples stacked together
+        for i in range(len(lpre)):
+            if random.random() > 0.5:
+                lpre[i] = lpre[i, ::-1]
+        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
+        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
+        feat = [
+            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
+            ldir * use_slop,
+            lpre[:, :, 2],
+        ]
+        feat = np.concatenate(feat, 1)
+
+        wire_labels = {
+            "junc_coords": torch.tensor(wire["junc_coords"]["content"]),
+            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
+            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
+            # 真实存在线条的邻接矩阵
+            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
+
+            "lpre": torch.tensor(lpre)[:, :, :2],
+            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # 样本对应标签 1,0
+            "lpre_feat": torch.from_numpy(feat),
+            "junc_map": torch.tensor(wire['junc_map']["content"]),
+            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
+            "line_map": torch.tensor(wire['line_map']["content"]),
+        }
+
+        labels = []
+        if self.target_type == 'polygon':
+            labels, masks = read_masks_from_txt_wire(lbl_path, shape)
+        elif self.target_type == 'pixel':
+            labels = read_masks_from_pixels_wire(lbl_path, shape)
+
+        # print(torch.stack(masks).shape)    # [num_lines, 512, 512]
+        target = {}
+
+        target["image_id"] = torch.tensor(item)
+        # return wire_labels, target
+        target["wires"] = wire_labels
+
+        target["labels"] = torch.stack(labels)
+        # print(f'labels:{target["labels"]}')
+        # target["boxes"] = line_boxes(target)
+        target["boxes"], keypoints = line_boxes(target)
+        # visibility_flags = torch.ones((wire_labels["junc_coords"].shape[0], 1))
+
+        # keypoints= wire_labels["junc_coords"]
+        # append visibility flag v=2 (labeled and visible) to every endpoint
+        a = torch.full((keypoints.shape[0],), 2).unsqueeze(1)
+        keypoints = torch.cat((keypoints, a), dim=1)
+        target["keypoints"] = keypoints.to(torch.float32).view(-1, 2, 3)
+        # print(f'boxes:{target["boxes"].shape}')
+        # sanity-check the endpoints; shape is (h, w), so width is shape[1]
+        validate_keypoints(keypoints, shape[1], shape[0])
+        # print(f'keypoints:{target["keypoints"].shape}')
+        return target
+
+    def show(self, idx):
+        image, target = self.__getitem__(idx)
+
+        cmap = plt.get_cmap("jet")
+        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+        sm.set_array([])
+
+        def imshow(im):
+            plt.close()
+            plt.tight_layout()
+            plt.imshow(im)
+            plt.colorbar(sm, fraction=0.046)
+            plt.xlim([0, im.shape[1]])
+            plt.ylim([im.shape[0], 0])
+
+        def draw_vecl(lines, sline, juncs, junts, fn=None):
+            img_path = os.path.join(self.img_path, self.imgs[idx])
+            imshow(io.imread(img_path))
+            if len(lines) > 0 and not (lines[0] == 0).all():
+                for i, ((a, b), s) in enumerate(zip(lines, sline)):
+                    if i > 0 and (lines[i] == lines[0]).all():
+                        break
+                    plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # endpoint order is arbitrary
+            if not (juncs[0] == 0).all():
+                for i, j in enumerate(juncs):
+                    if i > 0 and (j == juncs[0]).all():
+                        break
+                    plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # originally s=64
+
+            img_path = os.path.join(self.img_path, self.imgs[idx])
+            img = PIL.Image.open(img_path).convert('RGB')
+            boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
+                                              colors="yellow", width=1)
+            plt.imshow(boxed_image.permute(1, 2, 0).numpy())
+            if fn is not None:
+                plt.savefig(fn)  # save before show(), which clears the current figure
+            plt.show()
+
+        junc = target['wires']['junc_coords'].cpu().numpy() * 4
+        jtyp = target['wires']['jtyp'].cpu().numpy()
+        juncs = junc[jtyp == 0]
+        junts = junc[jtyp == 1]
+
+        lpre = target['wires']["lpre"].cpu().numpy() * 4
+        vecl_target = target['wires']["lpre_label"].cpu().numpy()
+        lpre = lpre[vecl_target == 1]
+
+        # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
+        draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
+
+
+    def show_img(self, img_path):
+        pass
+
+
+
+if __name__ == '__main__':
+    path = r"I:\wirenet_dateset"
+    dataset = KeypointDataset(dataset_path=path, dataset_type='train')
+    dataset.show(0)
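
A sketch of batching this dataset the way trainer.py does (editorial note). It assumes BaseDataset supplies the default_transform that __getitem__ falls back on, and that tools.utils.collate_fn behaves like torchvision's detection-reference collate:

    from torch.utils.data import DataLoader

    dataset = KeypointDataset(dataset_path=r"I:\wirenet_dateset", dataset_type='train', target_type='pixel')
    # stand-in for tools.utils.collate_fn: keep variable-sized targets as tuples
    loader = DataLoader(dataset, batch_size=2, shuffle=True,
                        collate_fn=lambda batch: tuple(zip(*batch)))

    images, targets = next(iter(loader))
    print(len(images), targets[0]['keypoints'].shape)  # 2 torch.Size([num_lines, 2, 3])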

+ 65 - 0
models/wirenet2/test.py

@@ -0,0 +1,65 @@
+import time
+
+import numpy as np
+import torch
+from matplotlib import pyplot as plt
+from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
+from torchvision.io import decode_image, read_image
+import torchvision.transforms.functional as F
+from torchvision.utils import draw_keypoints
+def show(imgs):
+    if not isinstance(imgs, list):
+        imgs = [imgs]
+    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
+    for i, img in enumerate(imgs):
+        img = img.detach()
+        img = F.to_pil_image(img)
+        axs[0, i].imshow(np.asarray(img))
+        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
+
+img_path=r"F:\DevTools\datasets\coco2017\val2017\000000000785.jpg"
+# img_path=r"F:\DevTools\datasets\renyaun\1012\images\2024-09-23-09-58-42_SaveImage.png"
+img_int = read_image(img_path)
+
+
+# person_int = decode_image(r"F:\DevTools\datasets\coco2017\val2017\000000000785.jpg")
+
+weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
+transforms = weights.transforms()
+print(f'transforms:{transforms}')
+img = transforms(img_int)
+
+model = keypointrcnn_resnet50_fpn(weights=weights, progress=False)  # use the COCO weights loaded above; weights=None would run a randomly initialized model
+model = model.eval()
+t1 = time.time()
+# img = torch.ones((3, 3, 512, 512))
+
+
+outputs = model([img])
+t2 = time.time()
+print(f'time: {t2 - t1}')
+# print(f'outputs:{outputs}')
+
+kpts = outputs[0]['keypoints']
+scores = outputs[0]['scores']
+
+print(f'kpts:{kpts}')
+print(f'scores:{scores}')
+
+detect_threshold = 0.75
+idx = torch.where(scores > detect_threshold)
+keypoints = kpts[idx]
+
+# print(f'keypoints:{keypoints}')
+
+
+
+res = draw_keypoints(img_int, keypoints, colors="blue", radius=3)
+show(res)
+plt.show()
+
+
+
+
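
Since the wirenet configuration elsewhere in this commit predicts two endpoints per instance, the endpoints could be joined into segments via draw_keypoints' connectivity argument. A sketch (the (0, 1) index pair assumes K=2 keypoints, unlike the 17-keypoint COCO model above):

    # join keypoint 0 to keypoint 1 within each detected instance
    res_lines = draw_keypoints(img_int, keypoints, connectivity=[(0, 1)],
                               colors="blue", radius=3, width=2)
    show(res_lines)
    plt.show()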

+ 32 - 0
models/wirenet2/train.yaml

@@ -0,0 +1,32 @@
+
+
+dataset_path: I:/wirenet_dateset
+
+# train parameters
+num_classes: 2
+num_keypoints: 2
+opt: 'adamw'
+batch_size: 2
+epochs: 10
+lr: 0.005
+momentum: 0.9
+weight_decay: 0.0001
+lr_step_size: 3
+lr_gamma: 0.1
+num_workers: 4
+print_freq: 10
+target_type: pixel
+enable_logs: True
+augmentation: False
+checkpoint: null  # YAML null; a bare None would be read back as the string "None"
+
+
+## Classes
+#names:
+#  0: fire
+#  1: dust
+#  2: move_machine
+#  3: open_machine
+#  4: close_machine
+
+
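
How this file is consumed: models.config.config_tool.read_yaml is not part of this commit. A minimal stand-in, assuming it is a thin PyYAML wrapper, would be:

    import yaml

    def read_yaml(path):
        # assumed behavior of models.config.config_tool.read_yaml
        with open(path, 'r', encoding='utf-8') as f:
            return yaml.safe_load(f)

    params = read_yaml('models/wirenet2/train.yaml')
    print(params['num_keypoints'], params['lr'], params['checkpoint'])  # 2 0.005 None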

+ 212 - 0
models/wirenet2/trainer.py

@@ -0,0 +1,212 @@
+import math
+import os
+import sys
+from datetime import datetime
+
+import torch
+import torchvision
+from torch.utils.tensorboard import SummaryWriter
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
+
+from models.config.config_tool import read_yaml
+from models.ins.maskrcnn_dataset import MaskRCNNDataset
+from models.keypoint.keypoint_dataset import KeypointDataset
+from tools import utils, presets
+
+
+def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
+    model.train()
+    metric_logger = utils.MetricLogger(delimiter="  ")
+    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
+    header = f"Epoch: [{epoch}]"
+
+    lr_scheduler = None
+    if epoch == 0:
+        warmup_factor = 1.0 / 1000
+        warmup_iters = min(1000, len(data_loader) - 1)
+
+        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
+            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
+        )
+
+    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
+        # print(f'images:{images}')
+        images = list(image.to(device) for image in images)
+        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            loss_dict = model(images, targets)
+            # print(f'loss_dict:{loss_dict}')
+            losses = sum(loss for loss in loss_dict.values())
+
+        # reduce losses over all GPUs for logging purposes
+        loss_dict_reduced = utils.reduce_dict(loss_dict)
+        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
+
+        loss_value = losses_reduced.item()
+
+        if not math.isfinite(loss_value):
+            print(f"Loss is {loss_value}, stopping training")
+            print(loss_dict_reduced)
+            sys.exit(1)
+
+        optimizer.zero_grad()
+        if scaler is not None:
+            scaler.scale(losses).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            losses.backward()
+            optimizer.step()
+
+        if lr_scheduler is not None:
+            lr_scheduler.step()
+
+        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
+        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+    return metric_logger
+
+def train_cfg(model, cfg):
+    parameters = read_yaml(cfg)
+    print(f'train parameters:{parameters}')
+    train(model, **parameters)
+
+def train(model, **kwargs):
+    # default parameters
+    default_params = {
+        'dataset_path': '/path/to/dataset',
+        'num_classes': 2,
+        'num_keypoints':2,
+        'opt': 'adamw',
+        'batch_size': 2,
+        'epochs': 10,
+        'lr': 0.005,
+        'momentum': 0.9,
+        'weight_decay': 1e-4,
+        'lr_step_size': 3,
+        'lr_gamma': 0.1,
+        'num_workers': 4,
+        'print_freq': 10,
+        'target_type': 'polygon',
+        'enable_logs': True,
+        'augmentation': False,
+        'checkpoint':None
+    }
+    # override defaults with any provided kwargs
+    for key, value in kwargs.items():
+        if key in default_params:
+            default_params[key] = value
+        else:
+            raise ValueError(f"Unknown argument: {key}")
+
+    # unpack parameters
+    dataset_path = default_params['dataset_path']
+    num_classes = default_params['num_classes']
+    batch_size = default_params['batch_size']
+    epochs = default_params['epochs']
+    lr = default_params['lr']
+    momentum = default_params['momentum']
+    weight_decay = default_params['weight_decay']
+    lr_step_size = default_params['lr_step_size']
+    lr_gamma = default_params['lr_gamma']
+    num_workers = default_params['num_workers']
+    print_freq = default_params['print_freq']
+    target_type = default_params['target_type']
+    augmentation = default_params['augmentation']
+    # select device
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+    train_result_path = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
+    wts_path = os.path.join(train_result_path, 'weights')
+    tb_path = os.path.join(train_result_path, 'logs')
+    writer = SummaryWriter(tb_path)
+
+    transforms = None
+    # default_transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
+    if augmentation:
+        transforms = get_transform(is_train=True)
+        print(f'transforms:{transforms}')
+    if not os.path.exists('train_results'):
+        os.mkdir('train_results')
+
+    model.to(device)
+    # note: 'opt', 'momentum', 'lr_step_size' and 'lr_gamma' are accepted by the
+    # config, but only AdamW with weight decay is wired up here
+    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
+
+    dataset = KeypointDataset(dataset_path=dataset_path,
+                              transforms=transforms, dataset_type='train', target_type=target_type)
+    dataset_test = KeypointDataset(dataset_path=dataset_path, transforms=None,
+                                   dataset_type='val')
+
+    train_sampler = torch.utils.data.RandomSampler(dataset)
+    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
+    train_collate_fn = utils.collate_fn
+    data_loader = torch.utils.data.DataLoader(
+        dataset, batch_sampler=train_batch_sampler, num_workers=num_workers, collate_fn=train_collate_fn
+    )
+    # data_loader_test = torch.utils.data.DataLoader(
+    #     dataset_test, batch_size=1, sampler=test_sampler, num_workers=num_workers, collate_fn=utils.collate_fn
+    # )
+
+    img_results_path = os.path.join(train_result_path, 'img_results')
+    os.makedirs(wts_path, exist_ok=True)  # also creates train_result_path
+    os.makedirs(img_results_path, exist_ok=True)
+
+    for epoch in range(epochs):
+        metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, None)
+        losses = metric_logger.meters['loss'].global_avg
+        print(f'epoch {epoch}:loss:{losses}')
+        if os.path.exists(f'{wts_path}/last.pt'):
+            os.remove(f'{wts_path}/last.pt')
+        torch.save(model.state_dict(), f'{wts_path}/last.pt')
+        write_metric_logs(epoch, metric_logger, writer)
+        if epoch == 0 or losses <= best_loss:
+            best_loss = losses
+            if os.path.exists(f'{wts_path}/best.pt'):
+                os.remove(f'{wts_path}/best.pt')
+            torch.save(model.state_dict(), f'{wts_path}/best.pt')
+
+def get_transform(is_train, **kwargs):
+    default_params = {
+        'augmentation': 'multiscale',
+        'backend': 'tensor',
+        'use_v2': False,
+
+    }
+    # override defaults with any provided kwargs
+    for key, value in kwargs.items():
+        if key in default_params:
+            default_params[key] = value
+        else:
+            raise ValueError(f"Unknown argument: {key}")
+
+    # unpack parameters
+    augmentation = default_params['augmentation']
+    backend = default_params['backend']
+    use_v2 = default_params['use_v2']
+    if is_train:
+        return presets.DetectionPresetTrain(
+            data_augmentation=augmentation, backend=backend, use_v2=use_v2
+        )
+    # elif weights and test_only:
+    #     weights = torchvision.models.get_weight(args.weights)
+    #     trans = weights.transforms()
+    #     return lambda img, target: (trans(img), target)
+    else:
+        return presets.DetectionPresetEval(backend=backend, use_v2=use_v2)
+
+
+def write_metric_logs(epoch, metric_logger, writer):
+    writer.add_scalar('loss_classifier', metric_logger.meters['loss_classifier'].global_avg, epoch)
+    writer.add_scalar('loss_box_reg', metric_logger.meters['loss_box_reg'].global_avg, epoch)
+    # a keypoint model reports loss_keypoint, not loss_mask
+    writer.add_scalar('loss_keypoint', metric_logger.meters['loss_keypoint'].global_avg, epoch)
+    writer.add_scalar('loss_objectness', metric_logger.meters['loss_objectness'].global_avg, epoch)
+    writer.add_scalar('loss_rpn_box_reg', metric_logger.meters['loss_rpn_box_reg'].global_avg, epoch)
+    writer.add_scalar('train_loss', metric_logger.meters['loss'].global_avg, epoch)
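
A usage sketch for this trainer (editorial note; paths are hypothetical). The checkpoint key in train.yaml is accepted but not consumed above, so resuming means loading weights before calling train:

    from models.wirenet2.kepointrcnn import KeypointRCNNModel

    model = KeypointRCNNModel(num_classes=2, num_keypoints=2)
    # optional resume: load a previous run's weights first (hypothetical path)
    model.load_weight('train_results/<run>/weights/last.pt')
    model.train(cfg='models/wirenet2/train.yaml')  # reads the YAML, then hands the inner model to train_cfg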

+ 499 - 0
models/wirenet2/wirenet_rcnn.py

@@ -0,0 +1,499 @@
+import os
+from typing import Optional, Any
+
+import cv2
+import numpy as np
+import torch
+from torch import nn
+import torch.nn.functional as F
+# from torchinfo import summary
+from torchvision.io import read_image
+from torchvision.models import resnet50, ResNet50_Weights, WeightsEnum, Weights
+from torchvision.models._api import register_model
+from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param
+from torchvision.models.detection import FasterRCNN, MaskRCNN_ResNet50_FPN_V2_Weights
+from torchvision.models.detection._utils import overwrite_eps
+from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
+from torchvision.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
+from torchvision.ops import MultiScaleRoIAlign
+from torchvision.ops import misc as misc_nn_ops
+
+__all__ = [
+    "WirenetRCNN",
+    "WirenetRCNN_ResNet50_FPN_Weights",
+    "wirenetrcnn_resnet50_fpn",
+]
+
+from torchvision.transforms._presets import ObjectDetection
+
+
+class WirenetRCNN(FasterRCNN):
+    """
+    Implements WirenetRCNN, a Keypoint R-CNN variant for wireframe (line-endpoint) detection.
+
+    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
+    image, and should be in 0-1 range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (list of dictionary),
+    containing:
+
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+            ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (Int64Tensor[N]): the class label for each ground-truth box
+        - keypoints (FloatTensor[N, K, 3]): the K keypoints location for each of the N instances, in the
+          format [x, y, visibility], where visibility=0 means that the keypoint is not visible.
+
+    The model returns a Dict[Tensor] during training, containing the classification and regression
+    losses for both the RPN and the R-CNN, and the keypoint loss.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
+    follows:
+
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+            ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (Int64Tensor[N]): the predicted labels for each image
+        - scores (Tensor[N]): the scores of each prediction
+        - keypoints (FloatTensor[N, K, 3]): the locations of the predicted keypoints, in [x, y, v] format.
+
+    Args:
+        backbone (nn.Module): the network used to compute the features for the model.
+            It should contain an out_channels attribute, which indicates the number of output
+            channels that each feature map has (and it should be the same for all feature maps).
+            The backbone should return a single Tensor or an OrderedDict[Tensor].
+        num_classes (int): number of output classes of the model (including the background).
+            If box_predictor is specified, num_classes should be None.
+        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
+        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
+        image_mean (Tuple[float, float, float]): mean values used for input normalization.
+            They are generally the mean values of the dataset on which the backbone has been trained
+            on
+        image_std (Tuple[float, float, float]): std values used for input normalization.
+            They are generally the std values of the dataset on which the backbone has been trained on
+        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
+            maps.
+        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
+        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
+        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
+        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
+        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
+        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
+        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
+            considered as positive during training of the RPN.
+        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
+            considered as negative during training of the RPN.
+        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
+            for computing the loss
+        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
+            of the RPN
+        rpn_score_thresh (float): during inference, only return proposals with a classification score
+            greater than rpn_score_thresh
+        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
+            the locations indicated by the bounding boxes
+        box_head (nn.Module): module that takes the cropped feature maps as input
+        box_predictor (nn.Module): module that takes the output of box_head and returns the
+            classification logits and box regression deltas.
+        box_score_thresh (float): during inference, only return proposals with a classification score
+            greater than box_score_thresh
+        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
+        box_detections_per_img (int): maximum number of detections per image, for all classes.
+        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
+            considered as positive during training of the classification head
+        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
+            considered as negative during training of the classification head
+        box_batch_size_per_image (int): number of proposals that are sampled during training of the
+            classification head
+        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
+            of the classification head
+        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
+            bounding boxes
+        wirenet_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
+             the locations indicated by the bounding boxes, which will be used for the keypoint head.
+        wirenet_head (nn.Module): module that takes the cropped feature maps as input
+        wirenet_predictor (nn.Module): module that takes the output of the keypoint_head and returns the
+            heatmap logits
+
+    Example::
+
+        >>> import torch
+        >>> import torchvision
+        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
+        >>>
+        >>> # load a pre-trained model for classification and return
+        >>> # only the features
+        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
+        >>> # WirenetRCNN needs to know the number of
+        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
+        >>> # so we need to add it here
+        >>> backbone.out_channels = 1280
+        >>>
+        >>> # let's make the RPN generate 5 x 3 anchors per spatial
+        >>> # location, with 5 different sizes and 3 different aspect
+        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
+        >>> # map could potentially have different sizes and
+        >>> # aspect ratios
+        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
+        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
+        >>>
+        >>> # let's define what are the feature maps that we will
+        >>> # use to perform the region of interest cropping, as well as
+        >>> # the size of the crop after rescaling.
+        >>> # if your backbone returns a Tensor, featmap_names is expected to
+        >>> # be ['0']. More generally, the backbone should return an
+        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
+        >>> # feature maps to use.
+        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
+        >>>                                                 output_size=7,
+        >>>                                                 sampling_ratio=2)
+        >>>
+        >>> keypoint_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
+        >>>                                                          output_size=14,
+        >>>                                                          sampling_ratio=2)
+        >>> # put the pieces together inside a WirenetRCNN model
+        >>> model = WirenetRCNN(backbone,
+        >>>                     num_classes=2,
+        >>>                     rpn_anchor_generator=anchor_generator,
+        >>>                     box_roi_pool=roi_pooler,
+        >>>                     wirenet_roi_pool=keypoint_roi_pooler)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+    """
+
+    def __init__(
+            self,
+            backbone,
+            num_classes=None,
+            # transform parameters
+            min_size=None,
+            max_size=1333,
+            image_mean=None,
+            image_std=None,
+            # RPN parameters
+            rpn_anchor_generator=None,
+            rpn_head=None,
+            rpn_pre_nms_top_n_train=2000,
+            rpn_pre_nms_top_n_test=1000,
+            rpn_post_nms_top_n_train=2000,
+            rpn_post_nms_top_n_test=1000,
+            rpn_nms_thresh=0.7,
+            rpn_fg_iou_thresh=0.7,
+            rpn_bg_iou_thresh=0.3,
+            rpn_batch_size_per_image=256,
+            rpn_positive_fraction=0.5,
+            rpn_score_thresh=0.0,
+            # Box parameters
+            box_roi_pool=None,
+            box_head=None,
+            box_predictor=None,
+            box_score_thresh=0.05,
+            box_nms_thresh=0.5,
+            box_detections_per_img=100,
+            box_fg_iou_thresh=0.5,
+            box_bg_iou_thresh=0.5,
+            box_batch_size_per_image=512,
+            box_positive_fraction=0.25,
+            bbox_reg_weights=None,
+            # keypoint parameters
+            wirenet_roi_pool=None,
+            wirenet_head=None,
+            wirenet_predictor=None,
+            num_keypoints=None,
+            **kwargs,
+    ):
+
+        if not isinstance(wirenet_roi_pool, (MultiScaleRoIAlign, type(None))):
+            raise TypeError(
+                f"wirenet_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(wirenet_roi_pool)}"
+            )
+        if min_size is None:
+            min_size = (640, 672, 704, 736, 768, 800)
+
+        if num_keypoints is not None:
+            if wirenet_predictor is not None:
+                raise ValueError("num_keypoints should be None when wirenet_predictor is specified")
+        else:
+            num_keypoints = 2
+
+        out_channels = backbone.out_channels
+
+        if wirenet_roi_pool is None:
+            wirenet_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
+
+        if wirenet_head is None:
+            keypoint_layers = tuple(512 for _ in range(8))
+            wirenet_head = WirenetRCNNHeads(out_channels, keypoint_layers)
+
+        if wirenet_predictor is None:
+            keypoint_dim_reduced = 512  # == keypoint_layers[-1]
+            wirenet_predictor = WirenetRCNNPredictor(keypoint_dim_reduced, num_keypoints)
+
+        super().__init__(
+            backbone,
+            num_classes,
+            # transform parameters
+            min_size,
+            max_size,
+            image_mean,
+            image_std,
+            # RPN-specific parameters
+            rpn_anchor_generator,
+            rpn_head,
+            rpn_pre_nms_top_n_train,
+            rpn_pre_nms_top_n_test,
+            rpn_post_nms_top_n_train,
+            rpn_post_nms_top_n_test,
+            rpn_nms_thresh,
+            rpn_fg_iou_thresh,
+            rpn_bg_iou_thresh,
+            rpn_batch_size_per_image,
+            rpn_positive_fraction,
+            rpn_score_thresh,
+            # Box parameters
+            box_roi_pool,
+            box_head,
+            box_predictor,
+            box_score_thresh,
+            box_nms_thresh,
+            box_detections_per_img,
+            box_fg_iou_thresh,
+            box_bg_iou_thresh,
+            box_batch_size_per_image,
+            box_positive_fraction,
+            bbox_reg_weights,
+            **kwargs,
+        )
+
+        self.roi_heads.keypoint_roi_pool = wirenet_roi_pool
+        self.roi_heads.keypoint_head = wirenet_head
+        self.roi_heads.keypoint_predictor = wirenet_predictor
+
+
+class WirenetRCNNHeads(nn.Module):
+    def __init__(self, in_channels, layers):
+        super().__init__()
+        d = []
+        next_feature = in_channels
+        for out_channels in layers:
+            d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
+            d.append(nn.ReLU(inplace=True))
+            next_feature = out_channels
+        self.feature_layers = nn.Sequential(*d)
+        self.out_channels = next_feature
+        for m in self.feature_layers.children():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        # the deconvolution and upsampling live in WirenetRCNNPredictor below,
+        # as in torchvision's Keypoint R-CNN; the head only stacks 3x3 convs,
+        # so its 512-channel output matches the predictor's in_channels
+        return self.feature_layers(x)
+
+
+class WirenetRCNNPredictor(nn.Module):
+    def __init__(self, in_channels, num_keypoints):
+        super().__init__()
+        input_features = in_channels
+        deconv_kernel = 4
+        self.kps_score_lowres = nn.ConvTranspose2d(
+            input_features,
+            num_keypoints,
+            deconv_kernel,
+            stride=2,
+            padding=deconv_kernel // 2 - 1,
+        )
+        nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
+        nn.init.constant_(self.kps_score_lowres.bias, 0)
+        self.up_scale = 2
+        self.out_channels = num_keypoints
+
+    def forward(self, x):
+        x = self.kps_score_lowres(x)
+        return torch.nn.functional.interpolate(
+            x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
+        )
+
+
+_COMMON_META = {
+    "categories": _COCO_PERSON_CATEGORIES,
+    "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
+    "min_size": (1, 1),
+}
+
+
+class WirenetRCNN_ResNet50_FPN_Weights(WeightsEnum):
+    COCO_LEGACY = Weights(
+        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 59137258,
+            "recipe": "https://github.com/pytorch/vision/issues/1606",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 50.6,
+                    "kp_map": 61.1,
+                }
+            },
+            "_ops": 133.924,
+            "_file_size": 226.054,
+            "_docs": """
+                These weights were produced by following a similar training recipe as on the paper but use a checkpoint
+                from an early epoch.
+            """,
+        },
+    )
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 59137258,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 54.6,
+                    "kp_map": 65.0,
+                }
+            },
+            "_ops": 137.42,
+            "_file_size": 226.054,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=(
+            "pretrained",
+            lambda kwargs: WirenetRCNN_ResNet50_FPN_Weights.COCO_LEGACY
+            if kwargs["pretrained"] == "legacy"
+            else WirenetRCNN_ResNet50_FPN_Weights.COCO_V1,
+    ),
+    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
+)
+def wirenetrcnn_resnet50_fpn(
+        *,
+        weights: Optional[WirenetRCNN_ResNet50_FPN_Weights] = None,
+        progress: bool = True,
+        num_classes: Optional[int] = None,
+        num_keypoints: Optional[int] = None,
+        weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
+        trainable_backbone_layers: Optional[int] = None,
+        **kwargs: Any,
+) -> WirenetRCNN:
+    """
+    Constructs a WirenetRCNN model (a Keypoint R-CNN variant) with a ResNet-50-FPN backbone.
+
+    .. betastatus:: detection module
+
+    Reference: `Mask R-CNN <https://arxiv.org/abs/1703.06870>`__.
+
+    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
+    image, and should be in ``0-1`` range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (list of dictionary),
+    containing:
+
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
+        - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the
+          format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.
+
+    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
+    losses for both the RPN and the R-CNN, and the keypoint loss.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
+    follows, where ``N`` is the number of detected instances:
+
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the predicted labels for each instance
+        - scores (``Tensor[N]``): the scores of each instance
+        - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.
+
+    For more details on the output, you may refer to :ref:`instance_seg_output`.
+
+    Keypoint R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.
+
+    Example::
+
+        >>> model = wirenetrcnn_resnet50_fpn(weights=WirenetRCNN_ResNet50_FPN_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+        >>>
+        >>> # optionally, if you want to export the model to ONNX:
+        >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11)
+
+    Args:
+        weights (:class:`WirenetRCNN_ResNet50_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`WirenetRCNN_ResNet50_FPN_Weights`
+            below for more details, and possible values. By default, no
+            pre-trained weights are used.
+        progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int, optional): number of output classes of the model (including the background)
+        num_keypoints (int, optional): number of keypoints
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
+            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
+            passed (the default) this value is set to 3.
+
+    .. autoclass:: WirenetRCNN_ResNet50_FPN_Weights
+        :members:
+    """
+    weights = WirenetRCNN_ResNet50_FPN_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+        num_keypoints = _ovewrite_value_param("num_keypoints", num_keypoints, len(weights.meta["keypoint_names"]))
+    else:
+        if num_classes is None:
+            num_classes = 2
+        if num_keypoints is None:
+            num_keypoints = 17
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+    model = WirenetRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+        if weights == WirenetRCNN_ResNet50_FPN_Weights.COCO_V1:
+            overwrite_eps(model, 0.0)
+
+    return model
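
Construction sketch for the builder above (editorial note): a random-initialized two-keypoint model, mirroring the defaults used elsewhere in this commit. Loading the COCO checkpoints listed in WirenetRCNN_ResNet50_FPN_Weights would require the 17-keypoint head and matching parameter names, so weights=None is assumed here:

    import torch

    model = wirenetrcnn_resnet50_fpn(weights=None, num_classes=2, num_keypoints=2)
    model.eval()
    x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
    with torch.no_grad():
        predictions = model(x)
    print(predictions[0].keys())  # boxes, labels, scores, keypoints, keypoints_scores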