RenLiqiang 5 months ago
Parent
Commit
43a1845375

+ 83 - 3
models/keypoint/kepointrcnn.py

@@ -1,12 +1,15 @@
 import math
 import os
 import sys
+from collections import OrderedDict
 from datetime import datetime
-from typing import Mapping, Any
+from typing import Mapping
 import cv2
 import numpy as np
 import torch
 import torchvision
+from PIL import Image
+from matplotlib import pyplot as plt
 from torch import nn
 from torch.nn.modules.module import T
 from torchvision.io import read_image
@@ -19,12 +22,14 @@ from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
 from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor, KeypointRCNN, \
     KeypointRCNN_ResNet50_FPN_Weights
 from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
+from torchvision.ops.feature_pyramid_network import LastLevelMaxPool
 from torchvision.utils import draw_bounding_boxes
-from torchvision.ops import misc as misc_nn_ops
+from torchvision.ops import misc as misc_nn_ops, FeaturePyramidNetwork
 from typing import Optional, Any
 from models.config.config_tool import read_yaml
 from models.keypoint.trainer import train_cfg
 from models.wirenet._utils import overwrite_eps
+# from timm import create_model
 from torchvision.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
 from tools import utils
 os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
@@ -43,12 +48,32 @@ class KeypointRCNNModel(nn.Module):
        # self.__model= KeypointRCNN(backbone, num_classes=2, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler,keypoint_roi_pool=keypoint_roi_pooler)
         ####
 
+        # Load an EfficientNet backbone and drop its classification head
+        # backbone = create_model('tf_efficientnet_b0', pretrained=True, features_only=True)
+        # backbone_out_channels = backbone.feature_info.channels()  # channel counts for every stage
+        #
+        #
+        # # Build the FPN
+        # fpn = FeaturePyramidNetwork(
+        #     in_channels_list=backbone_out_channels,
+        #     out_channels=256,
+        #     extra_blocks=LastLevelMaxPool()
+        # )
+        #
+        # # Combine EfficientNet and the FPN into a new backbone
+        # self.body = nn.Sequential(
+        #     backbone,
+        #     fpn
+        # )
         default_weights = torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
 
         self.__model = keypointrcnn_resnet18_fpn(weights=None, num_classes=num_classes,
                                                  num_keypoints=num_keypoints,
                                                  progress=False)
-
+        # self.__model.backbone.body = nn.Sequential(OrderedDict([
+        #     ('body', self.body),
+        #     ('fpn', fpn)
+        # ]))
 
         if transforms is None:
             self.transforms = torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
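The commented-out block above outlines swapping the ResNet body for a timm EfficientNet feeding a FeaturePyramidNetwork. A minimal self-contained sketch of that idea, assuming timm is installed (TimmFPNBackbone is an illustrative name, not part of this commit):

# Sketch only: adapts a timm features_only backbone to the
# Dict[str, Tensor] interface torchvision detection models expect.
from collections import OrderedDict

from timm import create_model
from torch import nn
from torchvision.ops import FeaturePyramidNetwork
from torchvision.ops.feature_pyramid_network import LastLevelMaxPool


class TimmFPNBackbone(nn.Module):  # hypothetical helper, not in this commit
    def __init__(self, name='tf_efficientnet_b0', out_channels=256):
        super().__init__()
        self.body = create_model(name, pretrained=True, features_only=True)
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=self.body.feature_info.channels(),  # per-stage channels
            out_channels=out_channels,
            extra_blocks=LastLevelMaxPool(),
        )
        self.out_channels = out_channels  # read by torchvision's detection heads

    def forward(self, x):
        feats = self.body(x)  # list of per-stage feature maps
        return self.fpn(OrderedDict((str(i), f) for i, f in enumerate(feats)))

Exposing self.out_channels is what would let torchvision's RPN and ROI heads consume the wrapped backbone directly.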
@@ -93,7 +118,55 @@ class KeypointRCNNModel(nn.Module):
     def eval(self: T) -> T:
         self.__model.eval()
         # return super().eval()
+    def predict(self, img, show=True, save=False, save_path=None):
+        """
+         对输入图像进行关键点检测预测。
+
+         参数:
+             img (str or PIL.Image): 输入图像的路径或 PIL.Image 对象。
+             show (bool): 是否显示预测结果,默认为 True。
+             save (bool): 是否保存预测结果,默认为 False。
+
+         返回:
+             dict: 包含预测结果的字典。
+         """
+        if isinstance(img, str):
+            img = Image.open(img).convert("RGB")
+
+        self.__model.eval()
+
+        # Preprocess the image
+        img_tensor = self.transforms(img)
+        with torch.no_grad():
+            predictions = self.__model([img_tensor])
+
+        print(f'predictions:{predictions}')
+
+        # Extract boxes and keypoints from the prediction
+        boxes = predictions[0]['boxes'].cpu().numpy()
+        keypoints = predictions[0]['keypoints'].cpu().numpy()
 
+        # Visualize the prediction
+        if show or save:
+            fig, ax = plt.subplots(figsize=(10, 10))
+            ax.imshow(np.array(img))
+
+            for box in boxes:
+                x0, y0, x1, y1 = box
+                ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor='yellow', linewidth=1))
+
+            # with num_keypoints=2, each instance's keypoints are treated as
+            # the two endpoints of a line segment
+            for (a, b) in keypoints:
+                ax.plot([a[0], b[0]], [a[1], b[1]], c='red', linewidth=1)
+                ax.scatter(a[0], a[1], c='red', s=2)
+                ax.scatter(b[0], b[1], c='red', s=2)
+
+            if show:
+                plt.show()
+
+            if save:
+                if save_path is None:
+                    raise ValueError("save_path must be provided when save=True")
+                fig.savefig(save_path)
+                print(f"Prediction saved to {save_path}")
+            plt.close(fig)
+
+        return predictions[0]
 
 def keypointrcnn_resnet18_fpn(
         *,
@@ -205,6 +278,13 @@ def keypointrcnn_resnet18_fpn(
 if __name__ == '__main__':
     # ins_model = MaskRCNNModel(num_classes=5)
     keypoint_model = KeypointRCNNModel(num_keypoints=2)
+    wts_path = './train_results/20241227_231659/weights/best.pt'
+
     # data_path = r'F:\DevTools\datasets\renyaun\1012\spilt'
     # ins_model.train(data_dir=data_path,epochs=5000,target_type='pixel',batch_size=6,num_workers=10,num_classes=5)
     keypoint_model.train(cfg='train.yaml')
+
+    # keypoint_model.load_weight(wts_path)
+    # img_path=r"F:\DevTools\datasets\renyaun\1012\images\2024-09-23-10-02-15_SaveImage.png"
+    # keypoint_model.predict(img_path)
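For reference, the new predict() would presumably be exercised like the commented-out lines above; a hypothetical end-to-end sketch (the image path is a placeholder, and load_weight is assumed to take the checkpoint path as in the comment):

model = KeypointRCNNModel(num_keypoints=2)
model.load_weight('./train_results/20241227_231659/weights/best.pt')
result = model.predict('sample.png', show=False, save=True, save_path='pred.png')
print(result['keypoints'].shape)  # (num_instances, num_keypoints, 3)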

+ 548 - 0
models/wirenet2/_utils.py

@@ -0,0 +1,548 @@
+import math
+from collections import OrderedDict
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+from torchvision.ops import complete_box_iou_loss, distance_box_iou_loss, FrozenBatchNorm2d, generalized_box_iou_loss
+
+
+class BalancedPositiveNegativeSampler:
+    """
+    This class samples batches, ensuring that they contain a fixed proportion of positives
+    """
+
+    def __init__(self, batch_size_per_image: int, positive_fraction: float) -> None:
+        """
+        Args:
+            batch_size_per_image (int): number of elements to be selected per image
+            positive_fraction (float): percentage of positive elements per batch
+        """
+        self.batch_size_per_image = batch_size_per_image
+        self.positive_fraction = positive_fraction
+
+    def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
+        """
+        Args:
+            matched_idxs: list of tensors containing -1, 0 or positive values.
+                Each tensor corresponds to a specific image.
+                -1 values are ignored, 0 are considered as negatives and > 0 as
+                positives.
+
+        Returns:
+            pos_idx (list[Tensor])
+            neg_idx (list[Tensor])
+
+        Returns two lists of binary masks for each image.
+        The first list contains the positive elements that were selected,
+        and the second list the negative examples.
+        """
+        pos_idx = []
+        neg_idx = []
+        for matched_idxs_per_image in matched_idxs:
+            positive = torch.where(matched_idxs_per_image >= 1)[0]
+            negative = torch.where(matched_idxs_per_image == 0)[0]
+
+            num_pos = int(self.batch_size_per_image * self.positive_fraction)
+            # protect against not enough positive examples
+            num_pos = min(positive.numel(), num_pos)
+            num_neg = self.batch_size_per_image - num_pos
+            # protect against not enough negative examples
+            num_neg = min(negative.numel(), num_neg)
+
+            # randomly select positive and negative examples
+            perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
+            perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]
+
+            pos_idx_per_image = positive[perm1]
+            neg_idx_per_image = negative[perm2]
+
+            # create binary mask from indices
+            pos_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
+            neg_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
+
+            pos_idx_per_image_mask[pos_idx_per_image] = 1
+            neg_idx_per_image_mask[neg_idx_per_image] = 1
+
+            pos_idx.append(pos_idx_per_image_mask)
+            neg_idx.append(neg_idx_per_image_mask)
+
+        return pos_idx, neg_idx
+
+
+@torch.jit._script_if_tracing
+def encode_boxes(reference_boxes: Tensor, proposals: Tensor, weights: Tensor) -> Tensor:
+    """
+    Encode a set of proposals with respect to some
+    reference boxes
+
+    Args:
+        reference_boxes (Tensor): reference boxes
+        proposals (Tensor): boxes to be encoded
+        weights (Tensor[4]): the weights for ``(x, y, w, h)``
+    """
+
+    # perform some unpacking to make it JIT-fusion friendly
+    wx = weights[0]
+    wy = weights[1]
+    ww = weights[2]
+    wh = weights[3]
+
+    proposals_x1 = proposals[:, 0].unsqueeze(1)
+    proposals_y1 = proposals[:, 1].unsqueeze(1)
+    proposals_x2 = proposals[:, 2].unsqueeze(1)
+    proposals_y2 = proposals[:, 3].unsqueeze(1)
+
+    reference_boxes_x1 = reference_boxes[:, 0].unsqueeze(1)
+    reference_boxes_y1 = reference_boxes[:, 1].unsqueeze(1)
+    reference_boxes_x2 = reference_boxes[:, 2].unsqueeze(1)
+    reference_boxes_y2 = reference_boxes[:, 3].unsqueeze(1)
+
+    # implementation starts here
+    ex_widths = proposals_x2 - proposals_x1
+    ex_heights = proposals_y2 - proposals_y1
+    ex_ctr_x = proposals_x1 + 0.5 * ex_widths
+    ex_ctr_y = proposals_y1 + 0.5 * ex_heights
+
+    gt_widths = reference_boxes_x2 - reference_boxes_x1
+    gt_heights = reference_boxes_y2 - reference_boxes_y1
+    gt_ctr_x = reference_boxes_x1 + 0.5 * gt_widths
+    gt_ctr_y = reference_boxes_y1 + 0.5 * gt_heights
+
+    targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
+    targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
+    targets_dw = ww * torch.log(gt_widths / ex_widths)
+    targets_dh = wh * torch.log(gt_heights / ex_heights)
+
+    targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
+    return targets
+
+
+class BoxCoder:
+    """
+    This class encodes and decodes a set of bounding boxes into
+    the representation used for training the regressors.
+    """
+
+    def __init__(
+        self, weights: Tuple[float, float, float, float], bbox_xform_clip: float = math.log(1000.0 / 16)
+    ) -> None:
+        """
+        Args:
+            weights (4-element tuple): scaling weights applied to the
+                ``(dx, dy, dw, dh)`` deltas
+            bbox_xform_clip (float): clamp applied to ``dw``/``dh`` before ``torch.exp``
+        """
+        self.weights = weights
+        self.bbox_xform_clip = bbox_xform_clip
+
+    def encode(self, reference_boxes: List[Tensor], proposals: List[Tensor]) -> List[Tensor]:
+        boxes_per_image = [len(b) for b in reference_boxes]
+        reference_boxes = torch.cat(reference_boxes, dim=0)
+        proposals = torch.cat(proposals, dim=0)
+        targets = self.encode_single(reference_boxes, proposals)
+        return targets.split(boxes_per_image, 0)
+
+    def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
+        """
+        Encode a set of proposals with respect to some
+        reference boxes
+
+        Args:
+            reference_boxes (Tensor): reference boxes
+            proposals (Tensor): boxes to be encoded
+        """
+        dtype = reference_boxes.dtype
+        device = reference_boxes.device
+        weights = torch.as_tensor(self.weights, dtype=dtype, device=device)
+        targets = encode_boxes(reference_boxes, proposals, weights)
+
+        return targets
+
+    def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
+        torch._assert(
+            isinstance(boxes, (list, tuple)),
+            "This function expects boxes of type list or tuple.",
+        )
+        torch._assert(
+            isinstance(rel_codes, torch.Tensor),
+            "This function expects rel_codes of type torch.Tensor.",
+        )
+        boxes_per_image = [b.size(0) for b in boxes]
+        concat_boxes = torch.cat(boxes, dim=0)
+        box_sum = 0
+        for val in boxes_per_image:
+            box_sum += val
+        if box_sum > 0:
+            rel_codes = rel_codes.reshape(box_sum, -1)
+        pred_boxes = self.decode_single(rel_codes, concat_boxes)
+        if box_sum > 0:
+            pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
+        return pred_boxes
+
+    def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
+        """
+        From a set of original boxes and encoded relative box offsets,
+        get the decoded boxes.
+
+        Args:
+            rel_codes (Tensor): encoded boxes
+            boxes (Tensor): reference boxes.
+        """
+
+        boxes = boxes.to(rel_codes.dtype)
+
+        widths = boxes[:, 2] - boxes[:, 0]
+        heights = boxes[:, 3] - boxes[:, 1]
+        ctr_x = boxes[:, 0] + 0.5 * widths
+        ctr_y = boxes[:, 1] + 0.5 * heights
+
+        wx, wy, ww, wh = self.weights
+        dx = rel_codes[:, 0::4] / wx
+        dy = rel_codes[:, 1::4] / wy
+        dw = rel_codes[:, 2::4] / ww
+        dh = rel_codes[:, 3::4] / wh
+
+        # Prevent sending too large values into torch.exp()
+        dw = torch.clamp(dw, max=self.bbox_xform_clip)
+        dh = torch.clamp(dh, max=self.bbox_xform_clip)
+
+        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
+        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
+        pred_w = torch.exp(dw) * widths[:, None]
+        pred_h = torch.exp(dh) * heights[:, None]
+
+        # Distance from center to box's corner.
+        c_to_c_h = torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
+        c_to_c_w = torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w
+
+        pred_boxes1 = pred_ctr_x - c_to_c_w
+        pred_boxes2 = pred_ctr_y - c_to_c_h
+        pred_boxes3 = pred_ctr_x + c_to_c_w
+        pred_boxes4 = pred_ctr_y + c_to_c_h
+        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1)
+        return pred_boxes
+
+
+class BoxLinearCoder:
+    """
+    The linear box-to-box transform defined in FCOS. The transformation is parameterized
+    by the distance from the center of (square) src box to 4 edges of the target box.
+    """
+
+    def __init__(self, normalize_by_size: bool = True) -> None:
+        """
+        Args:
+            normalize_by_size (bool): normalize deltas by the size of src (anchor) boxes.
+        """
+        self.normalize_by_size = normalize_by_size
+
+    def encode(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
+        """
+        Encode a set of proposals with respect to some reference boxes
+
+        Args:
+            reference_boxes (Tensor): reference boxes
+            proposals (Tensor): boxes to be encoded
+
+        Returns:
+            Tensor: the encoded relative box offsets that can be used to
+            decode the boxes.
+
+        """
+
+        # get the center of reference_boxes
+        reference_boxes_ctr_x = 0.5 * (reference_boxes[..., 0] + reference_boxes[..., 2])
+        reference_boxes_ctr_y = 0.5 * (reference_boxes[..., 1] + reference_boxes[..., 3])
+
+        # get box regression transformation deltas
+        target_l = reference_boxes_ctr_x - proposals[..., 0]
+        target_t = reference_boxes_ctr_y - proposals[..., 1]
+        target_r = proposals[..., 2] - reference_boxes_ctr_x
+        target_b = proposals[..., 3] - reference_boxes_ctr_y
+
+        targets = torch.stack((target_l, target_t, target_r, target_b), dim=-1)
+
+        if self.normalize_by_size:
+            reference_boxes_w = reference_boxes[..., 2] - reference_boxes[..., 0]
+            reference_boxes_h = reference_boxes[..., 3] - reference_boxes[..., 1]
+            reference_boxes_size = torch.stack(
+                (reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=-1
+            )
+            targets = targets / reference_boxes_size
+        return targets
+
+    def decode(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
+
+        """
+        From a set of original boxes and encoded relative box offsets,
+        get the decoded boxes.
+
+        Args:
+            rel_codes (Tensor): encoded boxes
+            boxes (Tensor): reference boxes.
+
+        Returns:
+            Tensor: the predicted boxes with the encoded relative box offsets.
+
+        .. note::
+            This method assumes that ``rel_codes`` and ``boxes`` have same size for 0th dimension. i.e. ``len(rel_codes) == len(boxes)``.
+
+        """
+
+        boxes = boxes.to(dtype=rel_codes.dtype)
+
+        ctr_x = 0.5 * (boxes[..., 0] + boxes[..., 2])
+        ctr_y = 0.5 * (boxes[..., 1] + boxes[..., 3])
+
+        if self.normalize_by_size:
+            boxes_w = boxes[..., 2] - boxes[..., 0]
+            boxes_h = boxes[..., 3] - boxes[..., 1]
+
+            list_box_size = torch.stack((boxes_w, boxes_h, boxes_w, boxes_h), dim=-1)
+            rel_codes = rel_codes * list_box_size
+
+        pred_boxes1 = ctr_x - rel_codes[..., 0]
+        pred_boxes2 = ctr_y - rel_codes[..., 1]
+        pred_boxes3 = ctr_x + rel_codes[..., 2]
+        pred_boxes4 = ctr_y + rel_codes[..., 3]
+
+        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=-1)
+        return pred_boxes
+
+
+class Matcher:
+    """
+    This class assigns to each predicted "element" (e.g., a box) a ground-truth
+    element. Each predicted element will have exactly zero or one matches; each
+    ground-truth element may be assigned to zero or more predicted elements.
+
+    Matching is based on the MxN match_quality_matrix, that characterizes how well
+    each (ground-truth, predicted)-pair match. For example, if the elements are
+    boxes, the matrix may contain box IoU overlap values.
+
+    The matcher returns a tensor of size N containing the index of the ground-truth
+    element m that matches to prediction n. If there is no match, a negative value
+    is returned.
+    """
+
+    BELOW_LOW_THRESHOLD = -1
+    BETWEEN_THRESHOLDS = -2
+
+    __annotations__ = {
+        "BELOW_LOW_THRESHOLD": int,
+        "BETWEEN_THRESHOLDS": int,
+    }
+
+    def __init__(self, high_threshold: float, low_threshold: float, allow_low_quality_matches: bool = False) -> None:
+        """
+        Args:
+            high_threshold (float): quality values greater than or equal to
+                this value are candidate matches.
+            low_threshold (float): a lower quality threshold used to stratify
+                matches into three levels:
+                1) matches >= high_threshold
+                2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold)
+                3) BELOW_LOW_THRESHOLD matches in [0, low_threshold)
+            allow_low_quality_matches (bool): if True, produce additional matches
+                for predictions that have only low-quality match candidates. See
+                set_low_quality_matches_ for more details.
+        """
+        self.BELOW_LOW_THRESHOLD = -1
+        self.BETWEEN_THRESHOLDS = -2
+        torch._assert(low_threshold <= high_threshold, "low_threshold should be <= high_threshold")
+        self.high_threshold = high_threshold
+        self.low_threshold = low_threshold
+        self.allow_low_quality_matches = allow_low_quality_matches
+
+    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
+        """
+        Args:
+            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
+            pairwise quality between M ground-truth elements and N predicted elements.
+
+        Returns:
+            matches (Tensor[int64]): an N tensor where N[i] is a matched gt in
+            [0, M - 1] or a negative value indicating that prediction i could not
+            be matched.
+        """
+        if match_quality_matrix.numel() == 0:
+            # empty targets or proposals not supported during training
+            if match_quality_matrix.shape[0] == 0:
+                raise ValueError("No ground-truth boxes available for one of the images during training")
+            else:
+                raise ValueError("No proposal boxes available for one of the images during training")
+
+        # match_quality_matrix is M (gt) x N (predicted)
+        # Max over gt elements (dim 0) to find best gt candidate for each prediction
+        matched_vals, matches = match_quality_matrix.max(dim=0)
+        if self.allow_low_quality_matches:
+            all_matches = matches.clone()
+        else:
+            all_matches = None  # type: ignore[assignment]
+
+        # Assign candidate matches with low quality to negative (unassigned) values
+        below_low_threshold = matched_vals < self.low_threshold
+        between_thresholds = (matched_vals >= self.low_threshold) & (matched_vals < self.high_threshold)
+        matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD
+        matches[between_thresholds] = self.BETWEEN_THRESHOLDS
+
+        if self.allow_low_quality_matches:
+            if all_matches is None:
+                torch._assert(False, "all_matches should not be None")
+            else:
+                self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)
+
+        return matches
+
+    def set_low_quality_matches_(self, matches: Tensor, all_matches: Tensor, match_quality_matrix: Tensor) -> None:
+        """
+        Produce additional matches for predictions that have only low-quality matches.
+        Specifically, for each ground-truth find the set of predictions that have
+        maximum overlap with it (including ties); for each prediction in that set, if
+        it is unmatched, then match it to the ground-truth with which it has the highest
+        quality value.
+        """
+        # For each gt, find the prediction with which it has the highest quality
+        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
+        # Find the highest quality match available, even if it is low, including ties
+        gt_pred_pairs_of_highest_quality = torch.where(match_quality_matrix == highest_quality_foreach_gt[:, None])
+        # Example gt_pred_pairs_of_highest_quality:
+        #   tensor([[    0, 39796],
+        #           [    1, 32055],
+        #           [    1, 32070],
+        #           [    2, 39190],
+        #           [    2, 40255],
+        #           [    3, 40390],
+        #           [    3, 41455],
+        #           [    4, 45470],
+        #           [    5, 45325],
+        #           [    5, 46390]])
+        # Each row is a (gt index, prediction index)
+        # Note how gt items 1, 2, 3, and 5 each have two ties
+
+        pred_inds_to_update = gt_pred_pairs_of_highest_quality[1]
+        matches[pred_inds_to_update] = all_matches[pred_inds_to_update]
+
+
+class SSDMatcher(Matcher):
+    def __init__(self, threshold: float) -> None:
+        super().__init__(threshold, threshold, allow_low_quality_matches=False)
+
+    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
+        matches = super().__call__(match_quality_matrix)
+
+        # For each gt, find the prediction with which it has the highest quality
+        _, highest_quality_pred_foreach_gt = match_quality_matrix.max(dim=1)
+        matches[highest_quality_pred_foreach_gt] = torch.arange(
+            highest_quality_pred_foreach_gt.size(0), dtype=torch.int64, device=highest_quality_pred_foreach_gt.device
+        )
+
+        return matches
+
+
+def overwrite_eps(model: nn.Module, eps: float) -> None:
+    """
+    This method overwrites the default eps values of all the
+    FrozenBatchNorm2d layers of the model with the provided value.
+    This is necessary to address the BC-breaking change introduced
+    by the bug-fix at pytorch/vision#2933. The overwrite is applied
+    only when the pretrained weights are loaded to maintain compatibility
+    with previous versions.
+
+    Args:
+        model (nn.Module): The model on which we perform the overwrite.
+        eps (float): The new value of eps.
+    """
+    for module in model.modules():
+        if isinstance(module, FrozenBatchNorm2d):
+            module.eps = eps
+
+
+def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
+    """
+    This method retrieves the number of output channels of a specific model.
+
+    Args:
+        model (nn.Module): The model for which we estimate the out_channels.
+            It should return a single Tensor or an OrderedDict[Tensor].
+        size (Tuple[int, int]): The size (wxh) of the input.
+
+    Returns:
+        out_channels (List[int]): A list of the output channels of the model.
+    """
+    in_training = model.training
+    model.eval()
+
+    with torch.no_grad():
+        # Use dummy data to retrieve the feature map sizes to avoid hard-coding their values
+        device = next(model.parameters()).device
+        tmp_img = torch.zeros((1, 3, size[1], size[0]), device=device)
+        features = model(tmp_img)
+        if isinstance(features, torch.Tensor):
+            features = OrderedDict([("0", features)])
+        out_channels = [x.size(1) for x in features.values()]
+
+    if in_training:
+        model.train()
+
+    return out_channels
+
+
+@torch.jit.unused
+def _fake_cast_onnx(v: Tensor) -> int:
+    return v  # type: ignore[return-value]
+
+
+def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int:
+    """
+    ONNX spec requires the k-value to be less than or equal to the number of inputs along
+    provided dim. Certain models use the number of elements along a particular axis instead of K
+    if K exceeds the number of elements along that axis. Previously, python's min() function was
+    used to determine whether to use the provided k-value or the specified dim axis value.
+
+    However, in cases where the model is being exported in tracing mode, python min() is
+    static causing the model to be traced incorrectly and eventually fail at the topk node.
+    In order to avoid this situation, in tracing mode, torch.min() is used instead.
+
+    Args:
+        input (Tensor): The original input tensor.
+        orig_kval (int): The provided k-value.
+        axis(int): Axis along which we retrieve the input size.
+
+    Returns:
+        min_kval (int): Appropriately selected k-value.
+    """
+    if not torch.jit.is_tracing():
+        return min(orig_kval, input.size(axis))
+    axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0)
+    min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
+    return _fake_cast_onnx(min_kval)
+
+
+def _box_loss(
+    type: str,
+    box_coder: BoxCoder,
+    anchors_per_image: Tensor,
+    matched_gt_boxes_per_image: Tensor,
+    bbox_regression_per_image: Tensor,
+    cnf: Optional[Dict[str, float]] = None,
+) -> Tensor:
+    torch._assert(type in ["l1", "smooth_l1", "ciou", "diou", "giou"], f"Unsupported loss: {type}")
+
+    if type == "l1":
+        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
+        return F.l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
+    elif type == "smooth_l1":
+        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
+        beta = cnf["beta"] if cnf is not None and "beta" in cnf else 1.0
+        return F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum", beta=beta)
+    else:
+        bbox_per_image = box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
+        eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7
+        if type == "ciou":
+            return complete_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
+        if type == "diou":
+            return distance_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
+        # otherwise giou
+        return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)

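BoxCoder.encode and decode above should be mutual inverses up to floating-point error; a quick round-trip check (illustrative values, not part of the commit):

import torch
from models.wirenet2._utils import BoxCoder

coder = BoxCoder(weights=(10.0, 10.0, 5.0, 5.0))
gt = [torch.tensor([[10.0, 10.0, 50.0, 80.0]])]      # reference (ground-truth) boxes
anchors = [torch.tensor([[12.0, 8.0, 48.0, 90.0]])]  # proposals
rel = coder.encode(gt, anchors)                      # list of (N, 4) deltas
decoded = coder.decode(torch.cat(rel, dim=0), anchors)  # (N, 1, 4)
assert torch.allclose(decoded.squeeze(1), gt[0], atol=1e-4)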
+ 879 - 0
models/wirenet2/roi_heads.py

@@ -0,0 +1,879 @@
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import nn, Tensor
+from torchvision.ops import boxes as box_ops, roi_align
+
+from . import _utils as det_utils
+
+
+def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
+    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
+    """
+    Computes the loss for Faster R-CNN.
+
+    Args:
+        class_logits (Tensor)
+        box_regression (Tensor)
+        labels (list[Tensor])
+        regression_targets (list[Tensor])
+
+    Returns:
+        classification_loss (Tensor)
+        box_loss (Tensor)
+    """
+
+    labels = torch.cat(labels, dim=0)
+    regression_targets = torch.cat(regression_targets, dim=0)
+
+    classification_loss = F.cross_entropy(class_logits, labels)
+
+    # get indices that correspond to the regression targets for
+    # the corresponding ground truth labels, to be used with
+    # advanced indexing
+    sampled_pos_inds_subset = torch.where(labels > 0)[0]
+    labels_pos = labels[sampled_pos_inds_subset]
+    N, num_classes = class_logits.shape
+    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
+
+    box_loss = F.smooth_l1_loss(
+        box_regression[sampled_pos_inds_subset, labels_pos],
+        regression_targets[sampled_pos_inds_subset],
+        beta=1 / 9,
+        reduction="sum",
+    )
+    box_loss = box_loss / labels.numel()
+
+    return classification_loss, box_loss
+
+
+def maskrcnn_inference(x, labels):
+    # type: (Tensor, List[Tensor]) -> List[Tensor]
+    """
+    From the results of the CNN, post process the masks
+    by taking the mask corresponding to the class with max
+    probability (which are of fixed size and directly output
+    by the CNN) and return the masks in the mask field of the BoxList.
+
+    Args:
+        x (Tensor): the mask logits
+        labels (list[Tensor]): the predicted class labels, one tensor
+            per image
+
+    Returns:
+        results (list[BoxList]): one BoxList for each image, containing
+            the extra field mask
+    """
+    mask_prob = x.sigmoid()
+
+    # select masks corresponding to the predicted classes
+    num_masks = x.shape[0]
+    boxes_per_image = [label.shape[0] for label in labels]
+    labels = torch.cat(labels)
+    index = torch.arange(num_masks, device=labels.device)
+    mask_prob = mask_prob[index, labels][:, None]
+    mask_prob = mask_prob.split(boxes_per_image, dim=0)
+
+    return mask_prob
+
+
+def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
+    # type: (Tensor, Tensor, Tensor, int) -> Tensor
+    """
+    Given segmentation masks and the bounding boxes corresponding
+    to the location of the masks in the image, this function
+    crops and resizes the masks in the position defined by the
+    boxes. This prepares the masks for them to be fed to the
+    loss computation as the targets.
+    """
+    matched_idxs = matched_idxs.to(boxes)
+    rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
+    gt_masks = gt_masks[:, None].to(rois)
+    return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
+
+
+def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    """
+    Args:
+        mask_logits (Tensor)
+        proposals (list[Tensor])
+        gt_masks (list[Tensor])
+        gt_labels (list[Tensor])
+        mask_matched_idxs (list[Tensor])
+
+    Returns:
+        mask_loss (Tensor): scalar tensor containing the loss
+    """
+
+    discretization_size = mask_logits.shape[-1]
+    labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
+    mask_targets = [
+        project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
+    ]
+
+    labels = torch.cat(labels, dim=0)
+    mask_targets = torch.cat(mask_targets, dim=0)
+
+    # torch.mean (in binary_cross_entropy_with_logits) doesn't
+    # accept empty tensors, so handle it separately
+    if mask_targets.numel() == 0:
+        return mask_logits.sum() * 0
+
+    mask_loss = F.binary_cross_entropy_with_logits(
+        mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
+    )
+    return mask_loss
+
+
+def keypoints_to_heatmap(keypoints, rois, heatmap_size):
+    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
+    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
+
+    offset_x = offset_x[:, None]
+    offset_y = offset_y[:, None]
+    scale_x = scale_x[:, None]
+    scale_y = scale_y[:, None]
+
+    x = keypoints[..., 0]
+    y = keypoints[..., 1]
+
+    x_boundary_inds = x == rois[:, 2][:, None]
+    y_boundary_inds = y == rois[:, 3][:, None]
+
+    x = (x - offset_x) * scale_x
+    x = x.floor().long()
+    y = (y - offset_y) * scale_y
+    y = y.floor().long()
+
+    x[x_boundary_inds] = heatmap_size - 1
+    y[y_boundary_inds] = heatmap_size - 1
+
+    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
+    vis = keypoints[..., 2] > 0
+    valid = (valid_loc & vis).long()
+
+    lin_ind = y * heatmap_size + x
+    heatmaps = lin_ind * valid
+
+    return heatmaps, valid
+
+
+def _onnx_heatmaps_to_keypoints(
+    maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
+):
+    num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
+
+    width_correction = widths_i / roi_map_width
+    height_correction = heights_i / roi_map_height
+
+    roi_map = F.interpolate(
+        maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
+    )[:, 0]
+
+    w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
+    pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+    x_int = pos % w
+    y_int = (pos - x_int) // w
+
+    x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
+        dtype=torch.float32
+    )
+    y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
+        dtype=torch.float32
+    )
+
+    xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
+    xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
+    xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
+    xy_preds_i = torch.stack(
+        [
+            xy_preds_i_0.to(dtype=torch.float32),
+            xy_preds_i_1.to(dtype=torch.float32),
+            xy_preds_i_2.to(dtype=torch.float32),
+        ],
+        0,
+    )
+
+    # TODO: simplify when indexing without rank will be supported by ONNX
+    base = num_keypoints * num_keypoints + num_keypoints + 1
+    ind = torch.arange(num_keypoints)
+    ind = ind.to(dtype=torch.int64) * base
+    end_scores_i = (
+        roi_map.index_select(1, y_int.to(dtype=torch.int64))
+        .index_select(2, x_int.to(dtype=torch.int64))
+        .view(-1)
+        .index_select(0, ind.to(dtype=torch.int64))
+    )
+
+    return xy_preds_i, end_scores_i
+
+
+@torch.jit._script_if_tracing
+def _onnx_heatmaps_to_keypoints_loop(
+    maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
+):
+    xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
+    end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
+
+    for i in range(int(rois.size(0))):
+        xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
+            maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
+        )
+        xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
+        end_scores = torch.cat(
+            (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
+        )
+    return xy_preds, end_scores
+
+
+def heatmaps_to_keypoints(maps, rois):
+    """Extract predicted keypoint locations from heatmaps. Output has shape
+    (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
+    for each keypoint.
+    """
+    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
+    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
+    # consistency with keypoints_to_heatmap_labels by using the conversion from
+    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
+    # continuous coordinate.
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+
+    widths = rois[:, 2] - rois[:, 0]
+    heights = rois[:, 3] - rois[:, 1]
+    widths = widths.clamp(min=1)
+    heights = heights.clamp(min=1)
+    widths_ceil = widths.ceil()
+    heights_ceil = heights.ceil()
+
+    num_keypoints = maps.shape[1]
+
+    if torchvision._is_tracing():
+        xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
+            maps,
+            rois,
+            widths_ceil,
+            heights_ceil,
+            widths,
+            heights,
+            offset_x,
+            offset_y,
+            torch.scalar_tensor(num_keypoints, dtype=torch.int64),
+        )
+        return xy_preds.permute(0, 2, 1), end_scores
+
+    xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
+    end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
+    for i in range(len(rois)):
+        roi_map_width = int(widths_ceil[i].item())
+        roi_map_height = int(heights_ceil[i].item())
+        width_correction = widths[i] / roi_map_width
+        height_correction = heights[i] / roi_map_height
+        roi_map = F.interpolate(
+            maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
+        )[:, 0]
+        # roi_map_probs = scores_to_probs(roi_map.copy())
+        w = roi_map.shape[2]
+        pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+        x_int = pos % w
+        y_int = torch.div(pos - x_int, w, rounding_mode="floor")
+        # assert (roi_map_probs[k, y_int, x_int] ==
+        #         roi_map_probs[k, :, :].max())
+        x = (x_int.float() + 0.5) * width_correction
+        y = (y_int.float() + 0.5) * height_correction
+        xy_preds[i, 0, :] = x + offset_x[i]
+        xy_preds[i, 1, :] = y + offset_y[i]
+        xy_preds[i, 2, :] = 1
+        end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
+
+    return xy_preds.permute(0, 2, 1), end_scores
+
+def roi_line_loss(keypoints, rois, heatmap_size):
+    # Placeholder for a planned line-segment ROI loss; not yet implemented.
+    pass
+
+def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    N, K, H, W = keypoint_logits.shape
+    if H != W:
+        raise ValueError(
+            f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
+        )
+    discretization_size = H
+    heatmaps = []
+    valid = []
+    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
+        kp = gt_kp_in_image[midx]
+        heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
+        heatmaps.append(heatmaps_per_image.view(-1))
+        valid.append(valid_per_image.view(-1))
+
+    keypoint_targets = torch.cat(heatmaps, dim=0)
+    valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
+    valid = torch.where(valid)[0]
+
+    # torch.mean (in binary_cross_entropy_with_logits) doesn't
+    # accept empty tensors, so handle it separately
+    if keypoint_targets.numel() == 0 or len(valid) == 0:
+        return keypoint_logits.sum() * 0
+
+    keypoint_logits = keypoint_logits.view(N * K, H * W)
+
+    keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
+    return keypoint_loss
+
+
+def keypointrcnn_inference(x, boxes):
+    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+    kp_probs = []
+    kp_scores = []
+
+    boxes_per_image = [box.size(0) for box in boxes]
+    x2 = x.split(boxes_per_image, dim=0)
+
+    for xx, bb in zip(x2, boxes):
+        kp_prob, scores = heatmaps_to_keypoints(xx, bb)
+        kp_probs.append(kp_prob)
+        kp_scores.append(scores)
+
+    return kp_probs, kp_scores
+
+
+def _onnx_expand_boxes(boxes, scale):
+    # type: (Tensor, float) -> Tensor
+    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+    w_half = w_half.to(dtype=torch.float32) * scale
+    h_half = h_half.to(dtype=torch.float32) * scale
+
+    boxes_exp0 = x_c - w_half
+    boxes_exp1 = y_c - h_half
+    boxes_exp2 = x_c + w_half
+    boxes_exp3 = y_c + h_half
+    boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
+    return boxes_exp
+
+
+# the next two functions should be merged inside Masker
+# but are kept here for the moment while we need them
+# temporarily for paste_mask_in_image
+def expand_boxes(boxes, scale):
+    # type: (Tensor, float) -> Tensor
+    if torchvision._is_tracing():
+        return _onnx_expand_boxes(boxes, scale)
+    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+    w_half *= scale
+    h_half *= scale
+
+    boxes_exp = torch.zeros_like(boxes)
+    boxes_exp[:, 0] = x_c - w_half
+    boxes_exp[:, 2] = x_c + w_half
+    boxes_exp[:, 1] = y_c - h_half
+    boxes_exp[:, 3] = y_c + h_half
+    return boxes_exp
+
+
+@torch.jit.unused
+def expand_masks_tracing_scale(M, padding):
+    # type: (int, int) -> float
+    return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
+
+
+def expand_masks(mask, padding):
+    # type: (Tensor, int) -> Tuple[Tensor, float]
+    M = mask.shape[-1]
+    if torch._C._get_tracing_state():  # could not import is_tracing(), not sure why
+        scale = expand_masks_tracing_scale(M, padding)
+    else:
+        scale = float(M + 2 * padding) / M
+    padded_mask = F.pad(mask, (padding,) * 4)
+    return padded_mask, scale
+
+
+def paste_mask_in_image(mask, box, im_h, im_w):
+    # type: (Tensor, Tensor, int, int) -> Tensor
+    TO_REMOVE = 1
+    w = int(box[2] - box[0] + TO_REMOVE)
+    h = int(box[3] - box[1] + TO_REMOVE)
+    w = max(w, 1)
+    h = max(h, 1)
+
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, -1, -1))
+
+    # Resize mask
+    mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
+    mask = mask[0][0]
+
+    im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
+    x_0 = max(box[0], 0)
+    x_1 = min(box[2] + 1, im_w)
+    y_0 = max(box[1], 0)
+    y_1 = min(box[3] + 1, im_h)
+
+    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
+    return im_mask
+
+
+def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
+    one = torch.ones(1, dtype=torch.int64)
+    zero = torch.zeros(1, dtype=torch.int64)
+
+    w = box[2] - box[0] + one
+    h = box[3] - box[1] + one
+    w = torch.max(torch.cat((w, one)))
+    h = torch.max(torch.cat((h, one)))
+
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
+
+    # Resize mask
+    mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
+    mask = mask[0][0]
+
+    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
+    x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
+    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
+    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
+
+    unpaded_im_mask = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
+
+    # TODO : replace below with a dynamic padding when support is added in ONNX
+
+    # pad y
+    zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
+    zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
+    concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
+    # pad x
+    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
+    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
+    im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
+    return im_mask
+
+
+@torch.jit._script_if_tracing
+def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
+    res_append = torch.zeros(0, im_h, im_w)
+    for i in range(masks.size(0)):
+        mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
+        mask_res = mask_res.unsqueeze(0)
+        res_append = torch.cat((res_append, mask_res))
+    return res_append
+
+
+def paste_masks_in_image(masks, boxes, img_shape, padding=1):
+    # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
+    masks, scale = expand_masks(masks, padding=padding)
+    boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
+    im_h, im_w = img_shape
+
+    if torchvision._is_tracing():
+        return _onnx_paste_masks_in_image_loop(
+            masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
+        )[:, None]
+    res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
+    if len(res) > 0:
+        ret = torch.stack(res, dim=0)[:, None]
+    else:
+        ret = masks.new_empty((0, 1, im_h, im_w))
+    return ret
+
+
+class RoIHeads(nn.Module):
+    __annotations__ = {
+        "box_coder": det_utils.BoxCoder,
+        "proposal_matcher": det_utils.Matcher,
+        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
+    }
+
+    def __init__(
+        self,
+        box_roi_pool,
+        box_head,
+        box_predictor,
+        # Faster R-CNN training
+        fg_iou_thresh,
+        bg_iou_thresh,
+        batch_size_per_image,
+        positive_fraction,
+        bbox_reg_weights,
+        # Faster R-CNN inference
+        score_thresh,
+        nms_thresh,
+        detections_per_img,
+        # Mask
+        mask_roi_pool=None,
+        mask_head=None,
+        mask_predictor=None,
+        keypoint_roi_pool=None,
+        keypoint_head=None,
+        keypoint_predictor=None,
+    ):
+        super().__init__()
+
+        self.box_similarity = box_ops.box_iou
+        # assign ground-truth boxes for each proposal
+        self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)
+
+        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
+
+        if bbox_reg_weights is None:
+            bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
+        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)
+
+        self.box_roi_pool = box_roi_pool
+        self.box_head = box_head
+        self.box_predictor = box_predictor
+
+        self.score_thresh = score_thresh
+        self.nms_thresh = nms_thresh
+        self.detections_per_img = detections_per_img
+
+        self.mask_roi_pool = mask_roi_pool
+        self.mask_head = mask_head
+        self.mask_predictor = mask_predictor
+
+        self.keypoint_roi_pool = keypoint_roi_pool
+        self.keypoint_head = keypoint_head
+        self.keypoint_predictor = keypoint_predictor
+
+    def has_mask(self):
+        if self.mask_roi_pool is None:
+            return False
+        if self.mask_head is None:
+            return False
+        if self.mask_predictor is None:
+            return False
+        return True
+
+    def has_keypoint(self):
+        if self.keypoint_roi_pool is None:
+            return False
+        if self.keypoint_head is None:
+            return False
+        if self.keypoint_predictor is None:
+            return False
+        return True
+
+    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
+        # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+        matched_idxs = []
+        labels = []
+        for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
+
+            if gt_boxes_in_image.numel() == 0:
+                # Background image
+                device = proposals_in_image.device
+                clamped_matched_idxs_in_image = torch.zeros(
+                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
+                )
+                labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
+            else:
+                #  set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
+                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
+                matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
+
+                clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
+
+                labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
+                labels_in_image = labels_in_image.to(dtype=torch.int64)
+
+                # Label background (below the low threshold)
+                bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
+                labels_in_image[bg_inds] = 0
+
+                # Label ignore proposals (between low and high thresholds)
+                ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
+                labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler
+
+            matched_idxs.append(clamped_matched_idxs_in_image)
+            labels.append(labels_in_image)
+        return matched_idxs, labels
+
+    def subsample(self, labels):
+        # type: (List[Tensor]) -> List[Tensor]
+        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
+        sampled_inds = []
+        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
+            img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
+            sampled_inds.append(img_sampled_inds)
+        return sampled_inds
+
+    def add_gt_proposals(self, proposals, gt_boxes):
+        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
+        proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
+
+        return proposals
+
+    def check_targets(self, targets):
+        # type: (Optional[List[Dict[str, Tensor]]]) -> None
+        if targets is None:
+            raise ValueError("targets should not be None")
+        if not all(["boxes" in t for t in targets]):
+            raise ValueError("Every element of targets should have a boxes key")
+        if not all(["labels" in t for t in targets]):
+            raise ValueError("Every element of targets should have a labels key")
+        if self.has_mask():
+            if not all(["masks" in t for t in targets]):
+                raise ValueError("Every element of targets should have a masks key")
+
+    def select_training_samples(
+        self,
+        proposals,  # type: List[Tensor]
+        targets,  # type: Optional[List[Dict[str, Tensor]]]
+    ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
+        self.check_targets(targets)
+        if targets is None:
+            raise ValueError("targets should not be None")
+        dtype = proposals[0].dtype
+        device = proposals[0].device
+
+        gt_boxes = [t["boxes"].to(dtype) for t in targets]
+        gt_labels = [t["labels"] for t in targets]
+
+        # append ground-truth bboxes to proposals
+        proposals = self.add_gt_proposals(proposals, gt_boxes)
+
+        # get matching gt indices for each proposal
+        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
+        # sample a fixed proportion of positive-negative proposals
+        sampled_inds = self.subsample(labels)
+        matched_gt_boxes = []
+        num_images = len(proposals)
+        for img_id in range(num_images):
+            img_sampled_inds = sampled_inds[img_id]
+            proposals[img_id] = proposals[img_id][img_sampled_inds]
+            labels[img_id] = labels[img_id][img_sampled_inds]
+            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
+
+            gt_boxes_in_image = gt_boxes[img_id]
+            if gt_boxes_in_image.numel() == 0:
+                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
+            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
+
+        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
+        return proposals, matched_idxs, labels, regression_targets
+
+    def postprocess_detections(
+        self,
+        class_logits,  # type: Tensor
+        box_regression,  # type: Tensor
+        proposals,  # type: List[Tensor]
+        image_shapes,  # type: List[Tuple[int, int]]
+    ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
+        device = class_logits.device
+        num_classes = class_logits.shape[-1]
+
+        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
+        pred_boxes = self.box_coder.decode(box_regression, proposals)
+
+        pred_scores = F.softmax(class_logits, -1)
+
+        pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
+        pred_scores_list = pred_scores.split(boxes_per_image, 0)
+
+        all_boxes = []
+        all_scores = []
+        all_labels = []
+        for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
+            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
+
+            # create labels for each prediction
+            labels = torch.arange(num_classes, device=device)
+            labels = labels.view(1, -1).expand_as(scores)
+
+            # remove predictions with the background label
+            boxes = boxes[:, 1:]
+            scores = scores[:, 1:]
+            labels = labels[:, 1:]
+
+            # batch everything, by making every class prediction be a separate instance
+            boxes = boxes.reshape(-1, 4)
+            scores = scores.reshape(-1)
+            labels = labels.reshape(-1)
+
+            # remove low scoring boxes
+            inds = torch.where(scores > self.score_thresh)[0]
+            boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
+
+            # remove empty boxes
+            keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
+            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+            # non-maximum suppression, independently done per class
+            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
+            # keep only topk scoring predictions
+            keep = keep[: self.detections_per_img]
+            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+            all_boxes.append(boxes)
+            all_scores.append(scores)
+            all_labels.append(labels)
+
+        return all_boxes, all_scores, all_labels
+
+    def forward(
+        self,
+        features,  # type: Dict[str, Tensor]
+        proposals,  # type: List[Tensor]
+        image_shapes,  # type: List[Tuple[int, int]]
+        targets=None,  # type: Optional[List[Dict[str, Tensor]]]
+    ):
+        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
+        """
+        Args:
+            features (Dict[str, Tensor])
+            proposals (List[Tensor[N, 4]])
+            image_shapes (List[Tuple[H, W]])
+            targets (List[Dict])
+        """
+        if targets is not None:
+            for t in targets:
+                # TODO: https://github.com/pytorch/pytorch/issues/26731
+                floating_point_types = (torch.float, torch.double, torch.half)
+                if not t["boxes"].dtype in floating_point_types:
+                    raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
+                if not t["labels"].dtype == torch.int64:
+                    raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")
+                if self.has_keypoint():
+                    if not t["keypoints"].dtype == torch.float32:
+                        raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}")
+
+        if self.training:
+            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
+        else:
+            labels = None
+            regression_targets = None
+            matched_idxs = None
+
+        box_features = self.box_roi_pool(features, proposals, image_shapes)
+        box_features = self.box_head(box_features)
+        class_logits, box_regression = self.box_predictor(box_features)
+
+        result: List[Dict[str, torch.Tensor]] = []
+        losses = {}
+        if self.training:
+            if labels is None:
+                raise ValueError("labels cannot be None")
+            if regression_targets is None:
+                raise ValueError("regression_targets cannot be None")
+            loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
+            losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
+        else:
+            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
+            num_images = len(boxes)
+            for i in range(num_images):
+                result.append(
+                    {
+                        "boxes": boxes[i],
+                        "labels": labels[i],
+                        "scores": scores[i],
+                    }
+                )
+
+        if self.has_mask():
+            mask_proposals = [p["boxes"] for p in result]
+            if self.training:
+                if matched_idxs is None:
+                    raise ValueError("if in training, matched_idxs should not be None")
+
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                mask_proposals = []
+                pos_matched_idxs = []
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    mask_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
+
+            if self.mask_roi_pool is not None:
+                mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
+                mask_features = self.mask_head(mask_features)
+                mask_logits = self.mask_predictor(mask_features)
+            else:
+                raise Exception("Expected mask_roi_pool to be not None")
+
+            loss_mask = {}
+            if self.training:
+                if targets is None or pos_matched_idxs is None or mask_logits is None:
+                    raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
+
+                gt_masks = [t["masks"] for t in targets]
+                gt_labels = [t["labels"] for t in targets]
+                rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
+                loss_mask = {"loss_mask": rcnn_loss_mask}
+            else:
+                labels = [r["labels"] for r in result]
+                masks_probs = maskrcnn_inference(mask_logits, labels)
+                for mask_prob, r in zip(masks_probs, result):
+                    r["masks"] = mask_prob
+
+            losses.update(loss_mask)
+
+        # keep none checks in if conditional so torchscript will conditionally
+        # compile each branch
+        if (
+            self.keypoint_roi_pool is not None
+            and self.keypoint_head is not None
+            and self.keypoint_predictor is not None
+        ):
+            keypoint_proposals = [p["boxes"] for p in result]
+            if self.training:
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                keypoint_proposals = []
+                pos_matched_idxs = []
+                if matched_idxs is None:
+                    raise ValueError("if in trainning, matched_idxs should not be None")
+
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    keypoint_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
+
+            keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
+            keypoint_features = self.keypoint_head(keypoint_features)
+            keypoint_logits = self.keypoint_predictor(keypoint_features)
+
+            loss_keypoint = {}
+            if self.training:
+                if targets is None or pos_matched_idxs is None:
+                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
+
+                gt_keypoints = [t["keypoints"] for t in targets]
+                rcnn_loss_keypoint = keypointrcnn_loss(
+                    keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
+                )
+                loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
+            else:
+                if keypoint_logits is None or keypoint_proposals is None:
+                    raise ValueError(
+                        "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
+                    )
+
+                keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
+                for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
+                    r["keypoints"] = keypoint_prob
+                    r["keypoints_scores"] = kps
+            losses.update(loss_keypoint)
+
+        return result, losses

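For reference, the filtering chain inside postprocess_detections above (softmax over class logits, clipping to the image, dropping the background column, score thresholding, small-box removal, per-class NMS, top-k) can be exercised in isolation. A minimal, self-contained sketch on dummy tensors; the thresholds below are illustrative defaults, not values taken from this commit:

import torch
import torch.nn.functional as F
from torchvision.ops import boxes as box_ops

def filter_detections(class_logits, pred_boxes, image_shape,
                      score_thresh=0.05, nms_thresh=0.5, detections_per_img=100):
    # class_logits: [R, C], pred_boxes: [R, C, 4] for a single image
    num_classes = class_logits.shape[-1]
    scores = F.softmax(class_logits, -1)
    boxes = box_ops.clip_boxes_to_image(pred_boxes, image_shape)

    # one label column per class, then drop the background column (index 0)
    labels = torch.arange(num_classes).view(1, -1).expand_as(scores)
    boxes, scores, labels = boxes[:, 1:], scores[:, 1:], labels[:, 1:]

    # flatten so every (proposal, class) pair becomes its own instance
    boxes, scores, labels = boxes.reshape(-1, 4), scores.reshape(-1), labels.reshape(-1)

    keep = torch.where(scores > score_thresh)[0]
    boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

    keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
    boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

    # per-class NMS, then keep only the top-k detections
    keep = box_ops.batched_nms(boxes, scores, labels, nms_thresh)[:detections_per_img]
    return boxes[keep], scores[keep], labels[keep]

if __name__ == "__main__":
    R, C = 8, 3  # 8 proposals, background + 2 foreground classes
    boxes, scores, labels = filter_detections(
        torch.randn(R, C), torch.rand(R, C, 4) * 256, (256, 256)
    )
    print(boxes.shape, scores.shape, labels.shape)
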
+ 50 - 0
models/wirenet2/test_linemap.py

@@ -0,0 +1,50 @@
+import numpy as np
+import cv2
+
+
+def draw_line_heatmap(image_shape, pt1, pt2, sigma=1):
+    """
+    根据给定的两个端点生成线段的热度图。
+
+    参数:
+    - image_shape: (height, width) 输出热度图的形状
+    - pt1: (x1, y1) 线段的第一个端点
+    - pt2: (x2, y2) 线段的第二个端点
+    - sigma: 高斯核的标准差,用于控制热度扩散的程度
+
+    返回:
+    - heatmap: 生成的热度图
+    """
+    # 创建空白热度图
+    heatmap = np.zeros(image_shape, dtype=np.float32)
+
+    # 绘制线段
+    cv2.line(heatmap, pt1, pt2, color=1, thickness=1)
+
+    # 应用高斯模糊以生成热度效果
+    if sigma > 0:
+        heatmap = cv2.GaussianBlur(heatmap, (0, 0), sigmaX=sigma, sigmaY=sigma)
+
+    # 归一化热度图
+    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-6)
+
+    return heatmap
+
+
+# Quick test
+if __name__ == "__main__":
+    # Define the image size and the segment endpoints
+    image_shape = (256, 256)  # image height and width
+    pt1 = (50, 50)  # first endpoint
+    pt2 = (200, 200)  # second endpoint
+    sigma = 2  # controls how far the heat spreads
+
+    # Generate the heatmap
+    heatmap = draw_line_heatmap(image_shape, pt1, pt2, sigma)
+
+    # Display the result
+    import matplotlib.pyplot as plt
+
+    plt.imshow(heatmap, cmap='hot', interpolation='nearest')
+    plt.colorbar()
+    plt.show()

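A wireframe-style training target usually contains many segments, not one. Assuming draw_line_heatmap above is importable (the module path below presumes running from the repository root), the per-segment maps can be merged with an element-wise maximum so overlapping lines saturate at 1.0 instead of adding up; this merge strategy is an illustrative choice, not part of this commit:

import numpy as np
from models.wirenet2.test_linemap import draw_line_heatmap

def draw_lines_heatmap(image_shape, segments, sigma=1):
    # Merge per-segment heatmaps with an element-wise max so that
    # overlapping lines do not exceed 1.0
    heatmap = np.zeros(image_shape, dtype=np.float32)
    for pt1, pt2 in segments:
        heatmap = np.maximum(heatmap, draw_line_heatmap(image_shape, pt1, pt2, sigma))
    return heatmap

if __name__ == "__main__":
    segments = [((50, 50), (200, 200)), ((50, 200), (200, 50))]
    heatmap = draw_lines_heatmap((256, 256), segments, sigma=2)
    print(heatmap.shape, heatmap.min(), heatmap.max())
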
+ 122 - 4
models/wirenet2/wirenet_rcnn.py

@@ -8,7 +8,7 @@ from torch import nn
 import torch.nn.functional as F
 # from torchinfo import summary
 from torchvision.io import read_image
-from torchvision.models import resnet50, ResNet50_Weights, WeightsEnum, Weights
+from torchvision.models import resnet50, ResNet50_Weights, WeightsEnum, Weights, resnet18, ResNet18_Weights
 from torchvision.models._api import register_model
 from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param
 from torchvision.models.detection import FasterRCNN, MaskRCNN_ResNet50_FPN_V2_Weights
@@ -169,7 +169,7 @@ class WirenetRCNN(FasterRCNN):
             backbone,
             num_classes=None,
             # transform parameters
-            min_size=None,
+            min_size=512,
             max_size=1333,
             image_mean=None,
             image_std=None,
@@ -274,7 +274,8 @@ class WirenetRCNN(FasterRCNN):
 
 
 class WirenetRCNNHeads(nn.Module):
-    def __init__(self, in_channels, layers, num_keypoints):
+    def __init__(self, in_channels, layers, num_keypoints=3):
+        super().__init__()
         d = []
         next_feature = in_channels
         for out_channels in layers:
@@ -329,9 +330,11 @@ class WirenetRCNNPredictor(nn.Module):
 
     def forward(self, x):
         x = self.kps_score_lowres(x)
-        return torch.nn.functional.interpolate(
+        x = torch.nn.functional.interpolate(
             x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
         )
+        # print(f'x.shape:{x.shape}')  # debug trace, disabled so forward() stays quiet
+        return x
 
 
 _COMMON_META = {
@@ -497,3 +500,118 @@ def wirenetrcnn_resnet50_fpn(
             overwrite_eps(model, 0.0)
 
     return model
+
+def wirenetrcnn_resnet18_fpn(
+        *,
+        weights: Optional[WirenetRCNN_ResNet50_FPN_Weights] = None,
+        progress: bool = True,
+        num_classes: Optional[int] = None,
+        num_keypoints: Optional[int] = None,
+        weights_backbone: Optional[ResNet18_Weights] = ResNet18_Weights.IMAGENET1K_V1,
+        trainable_backbone_layers: Optional[int] = None,
+        **kwargs: Any,
+) -> WirenetRCNN:
+    """
+    Constructs a Wirenet R-CNN model with a ResNet-18-FPN backbone.
+
+    .. betastatus:: detection module
+
+    Reference: `Mask R-CNN <https://arxiv.org/abs/1703.06870>`__.
+
+    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
+    image, and should be in ``0-1`` range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (list of dictionary),
+    containing:
+
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
+        - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the
+          format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.
+
+    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
+    losses for both the RPN and the R-CNN, and the keypoint loss.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
+    follows, where ``N`` is the number of detected instances:
+
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the predicted labels for each instance
+        - scores (``Tensor[N]``): the scores for each instance
+        - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.
+
+    For more details on the output, you may refer to :ref:`instance_seg_output`.
+
+    Keypoint R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.
+
+    Example::
+
+        >>> model = wirenetrcnn_resnet18_fpn(weights=WirenetRCNN_ResNet50_FPN_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+        >>>
+        >>> # optionally, if you want to export the model to ONNX:
+        >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11)
+
+    Args:
+        weights (:class:`WirenetRCNN_ResNet50_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`WirenetRCNN_ResNet50_FPN_Weights`
+            below for more details, and possible values. By default, no
+            pre-trained weights are used.
+        progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int, optional): number of output classes of the model (including the background)
+        num_keypoints (int, optional): number of keypoints
+        weights_backbone (:class:`~torchvision.models.ResNet18_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
+            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
+            passed (the default) this value is set to 3.
+
+    .. autoclass:: torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights
+        :members:
+    """
+    weights = WirenetRCNN_ResNet50_FPN_Weights.verify(weights)
+    weights_backbone = ResNet18_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        # weights_backbone = None  # upstream torchvision clears backbone weights here; kept enabled in this variant
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+        num_keypoints = _ovewrite_value_param("num_keypoints", num_keypoints, len(weights.meta["keypoint_names"]))
+    else:
+        if num_classes is None:
+            num_classes = 3
+        if num_keypoints is None:
+            num_keypoints = 17
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = resnet18(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+    model = WirenetRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
+
+    if weights is not None:
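+        # NOTE: the published checkpoints were trained with a ResNet-50 backbone,
+        # so loading them into this ResNet-18 model is likely to fail with shape mismatches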
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+        if weights == WirenetRCNN_ResNet50_FPN_Weights.COCO_V1:
+            overwrite_eps(model, 0.0)
+
+    return model
+
+if __name__ == '__main__':
+    model = wirenetrcnn_resnet18_fpn(num_keypoints=3)
+    imgs = [torch.rand(3, 512, 512) for _ in range(3)]  # detection models expect a list of [C, H, W] tensors
+    model.eval()
+    predictions = model(imgs)
+    print(predictions[0].keys())
+    # model.train(cfg='train.yaml')
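
The trainable_backbone_layers argument controls how many ResNet stages stay unfrozen during fine-tuning. A quick way to sanity-check its effect, assuming wirenetrcnn_resnet18_fpn is importable from the repository root (illustrative snippet, not part of this commit; it downloads the ImageNet backbone weights on first run):

from models.wirenet2.wirenet_rcnn import wirenetrcnn_resnet18_fpn

def count_trainable(model):
    # Frozen parameters have requires_grad=False, so they are excluded here
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

if __name__ == "__main__":
    for layers in (0, 3, 5):
        model = wirenetrcnn_resnet18_fpn(num_keypoints=3, trainable_backbone_layers=layers)
        print(f"trainable_backbone_layers={layers}: {count_trainable(model):,} trainable parameters")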