Explorar o código

重构line_rcnn

RenLiqiang hai 3 meses
pai
achega
95edfdc8ae

+ 6 - 0
models/base/base_detection_net.py

@@ -99,12 +99,18 @@ class BaseDetectionNet(nn.Module):
                     )
 
         features = self.backbone(images.tensors)
+
         if isinstance(features, torch.Tensor):
             features = OrderedDict([("0", features)])
         proposals, proposal_losses = self.rpn(images, features, targets)
+
         detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
         detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)  # type: ignore[operator]
 
+        # ->multi task head
+        # ->learner,->vectorize
+
+
         losses = {}
         losses.update(detector_losses)
         losses.update(proposal_losses)

+ 61 - 57
models/line_net/LineNet.py → models/line_detect/LineNet.py

@@ -7,32 +7,36 @@ from torchvision.ops import MultiScaleRoIAlign
 
 from  libs.vision_libs.ops import misc as misc_nn_ops
 from libs.vision_libs.transforms._presets import ObjectDetection
-from .._api import register_model, Weights, WeightsEnum
-from .._meta import _COCO_CATEGORIES
-from .._utils import _ovewrite_value_param, handle_legacy_interface
-from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights
-from ..resnet import resnet50, ResNet50_Weights
-from ._utils import overwrite_eps
-from .anchor_utils import AnchorGenerator
-from .backbone_utils import _mobilenet_extractor, _resnet_fpn_extractor, _validate_trainable_layers
-from .generalized_rcnn import GeneralizedRCNN
-from .roi_heads import RoIHeads
-from .rpn import RegionProposalNetwork, RPNHead
-from .transform import GeneralizedRCNNTransform
+from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
+from libs.vision_libs.models._meta import _COCO_CATEGORIES
+from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
+from libs.vision_libs.models.mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights
+from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
+from libs.vision_libs.models.detection._utils import overwrite_eps
+from libs.vision_libs.models.detection.anchor_utils import AnchorGenerator
+from libs.vision_libs.models.detection.backbone_utils import _mobilenet_extractor, _resnet_fpn_extractor, _validate_trainable_layers
+
+from libs.vision_libs.models.detection.rpn import RegionProposalNetwork, RPNHead
+from libs.vision_libs.models.detection.transform import GeneralizedRCNNTransform
 
+######## 弃用  ###########
 
 __all__ = [
-    "FasterRCNN",
-    "FasterRCNN_ResNet50_FPN_Weights",
-    "FasterRCNN_ResNet50_FPN_V2_Weights",
-    "FasterRCNN_MobileNet_V3_Large_FPN_Weights",
-    "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights",
-    "fasterrcnn_resnet50_fpn",
+    "LineNet",
+    "LineNet_ResNet50_FPN_Weights",
+    "LineNet_ResNet50_FPN_V2_Weights",
+    "LineNet_MobileNet_V3_Large_FPN_Weights",
+    "LineNet_MobileNet_V3_Large_320_FPN_Weights",
+    "linenet_resnet50_fpn",
     "fasterrcnn_resnet50_fpn_v2",
-    "fasterrcnn_mobilenet_v3_large_fpn",
-    "fasterrcnn_mobilenet_v3_large_320_fpn",
+    "linenet_mobilenet_v3_large_fpn",
+    "linenet_mobilenet_v3_large_320_fpn",
 ]
 
+from .roi_heads import RoIHeads
+
+from ..base.base_detection_net import BaseDetectionNet
+
 
 def _default_anchorgen():
     anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
@@ -40,7 +44,7 @@ def _default_anchorgen():
     return AnchorGenerator(anchor_sizes, aspect_ratios)
 
 
-class FasterRCNN(GeneralizedRCNN):
+class LineNet(BaseDetectionNet):
     """
     Implements Faster R-CNN.
 
@@ -254,7 +258,7 @@ class FasterRCNN(GeneralizedRCNN):
 
         if box_predictor is None:
             representation_size = 1024
-            box_predictor = FastRCNNPredictor(representation_size, num_classes)
+            box_predictor = BoxPredictor(representation_size, num_classes)
 
         roi_heads = RoIHeads(
             # Box
@@ -304,7 +308,7 @@ class TwoMLPHead(nn.Module):
         return x
 
 
-class FastRCNNConvFCHead(nn.Sequential):
+class LineNetConvFCHead(nn.Sequential):
     def __init__(
         self,
         input_size: Tuple[int, int, int],
@@ -341,7 +345,7 @@ class FastRCNNConvFCHead(nn.Sequential):
                     nn.init.zeros_(layer.bias)
 
 
-class FastRCNNPredictor(nn.Module):
+class BoxPredictor(nn.Module):
     """
     Standard classification + bounding box regression layers
     for Fast R-CNN.
@@ -375,7 +379,7 @@ _COMMON_META = {
 }
 
 
-class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
+class LineNet_ResNet50_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
         transforms=ObjectDetection,
@@ -396,7 +400,7 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
     DEFAULT = COCO_V1
 
 
-class FasterRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
+class LineNet_ResNet50_FPN_V2_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth",
         transforms=ObjectDetection,
@@ -417,7 +421,7 @@ class FasterRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
     DEFAULT = COCO_V1
 
 
-class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
+class LineNet_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
         transforms=ObjectDetection,
@@ -438,7 +442,7 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
     DEFAULT = COCO_V1
 
 
-class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
+class LineNet_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
         transforms=ObjectDetection,
@@ -461,18 +465,18 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
 
 @register_model()
 @handle_legacy_interface(
-    weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1),
+    weights=("pretrained", LineNet_ResNet50_FPN_Weights.COCO_V1),
     weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
 )
-def fasterrcnn_resnet50_fpn(
+def linenet_resnet50_fpn(
     *,
-    weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None,
+    weights: Optional[LineNet_ResNet50_FPN_Weights] = None,
     progress: bool = True,
     num_classes: Optional[int] = None,
     weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
     trainable_backbone_layers: Optional[int] = None,
     **kwargs: Any,
-) -> FasterRCNN:
+) -> LineNet:
     """
     Faster R-CNN model with a ResNet-50-FPN backbone from the `Faster R-CNN: Towards Real-Time Object
     Detection with Region Proposal Networks <https://arxiv.org/abs/1506.01497>`__
@@ -553,7 +557,7 @@ def fasterrcnn_resnet50_fpn(
     .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights
         :members:
     """
-    weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights)
+    weights = LineNet_ResNet50_FPN_Weights.verify(weights)
     weights_backbone = ResNet50_Weights.verify(weights_backbone)
 
     if weights is not None:
@@ -568,11 +572,11 @@ def fasterrcnn_resnet50_fpn(
 
     backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
     backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
-    model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)
+    model = LineNet(backbone, num_classes=num_classes, **kwargs)
 
     if weights is not None:
         model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
-        if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
+        if weights == LineNet_ResNet50_FPN_Weights.COCO_V1:
             overwrite_eps(model, 0.0)
 
     return model
@@ -580,18 +584,18 @@ def fasterrcnn_resnet50_fpn(
 
 @register_model()
 @handle_legacy_interface(
-    weights=("pretrained", FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1),
+    weights=("pretrained", LineNet_ResNet50_FPN_V2_Weights.COCO_V1),
     weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
 )
 def fasterrcnn_resnet50_fpn_v2(
     *,
-    weights: Optional[FasterRCNN_ResNet50_FPN_V2_Weights] = None,
+    weights: Optional[LineNet_ResNet50_FPN_V2_Weights] = None,
     progress: bool = True,
     num_classes: Optional[int] = None,
     weights_backbone: Optional[ResNet50_Weights] = None,
     trainable_backbone_layers: Optional[int] = None,
     **kwargs: Any,
-) -> FasterRCNN:
+) -> LineNet:
     """
     Constructs an improved Faster R-CNN model with a ResNet-50-FPN backbone from `Benchmarking Detection
     Transfer Learning with Vision Transformers <https://arxiv.org/abs/2111.11429>`__ paper.
@@ -624,7 +628,7 @@ def fasterrcnn_resnet50_fpn_v2(
     .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights
         :members:
     """
-    weights = FasterRCNN_ResNet50_FPN_V2_Weights.verify(weights)
+    weights = LineNet_ResNet50_FPN_V2_Weights.verify(weights)
     weights_backbone = ResNet50_Weights.verify(weights_backbone)
 
     if weights is not None:
@@ -640,10 +644,10 @@ def fasterrcnn_resnet50_fpn_v2(
     backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
     rpn_anchor_generator = _default_anchorgen()
     rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
-    box_head = FastRCNNConvFCHead(
+    box_head = LineNetConvFCHead(
         (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
     )
-    model = FasterRCNN(
+    model = LineNet(
         backbone,
         num_classes=num_classes,
         rpn_anchor_generator=rpn_anchor_generator,
@@ -658,15 +662,15 @@ def fasterrcnn_resnet50_fpn_v2(
     return model
 
 
-def _fasterrcnn_mobilenet_v3_large_fpn(
+def _linenet_mobilenet_v3_large_fpn(
     *,
-    weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]],
+    weights: Optional[Union[LineNet_MobileNet_V3_Large_FPN_Weights, LineNet_MobileNet_V3_Large_320_FPN_Weights]],
     progress: bool,
     num_classes: Optional[int],
     weights_backbone: Optional[MobileNet_V3_Large_Weights],
     trainable_backbone_layers: Optional[int],
     **kwargs: Any,
-) -> FasterRCNN:
+) -> LineNet:
     if weights is not None:
         weights_backbone = None
         num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
@@ -689,7 +693,7 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
         ),
     ) * 3
     aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
-    model = FasterRCNN(
+    model = LineNet(
         backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
     )
 
@@ -701,18 +705,18 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
 
 @register_model()
 @handle_legacy_interface(
-    weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
+    weights=("pretrained", LineNet_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
     weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
 )
-def fasterrcnn_mobilenet_v3_large_320_fpn(
+def linenet_mobilenet_v3_large_320_fpn(
     *,
-    weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None,
+    weights: Optional[LineNet_MobileNet_V3_Large_320_FPN_Weights] = None,
     progress: bool = True,
     num_classes: Optional[int] = None,
     weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
     trainable_backbone_layers: Optional[int] = None,
     **kwargs: Any,
-) -> FasterRCNN:
+) -> LineNet:
     """
     Low resolution Faster R-CNN model with a MobileNetV3-Large backbone tuned for mobile use cases.
 
@@ -751,7 +755,7 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
     .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights
         :members:
     """
-    weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights)
+    weights = LineNet_MobileNet_V3_Large_320_FPN_Weights.verify(weights)
     weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
 
     defaults = {
@@ -763,7 +767,7 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
     }
 
     kwargs = {**defaults, **kwargs}
-    return _fasterrcnn_mobilenet_v3_large_fpn(
+    return _linenet_mobilenet_v3_large_fpn(
         weights=weights,
         progress=progress,
         num_classes=num_classes,
@@ -775,18 +779,18 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
 
 @register_model()
 @handle_legacy_interface(
-    weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1),
+    weights=("pretrained", LineNet_MobileNet_V3_Large_FPN_Weights.COCO_V1),
     weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
 )
-def fasterrcnn_mobilenet_v3_large_fpn(
+def linenet_mobilenet_v3_large_fpn(
     *,
-    weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None,
+    weights: Optional[LineNet_MobileNet_V3_Large_FPN_Weights] = None,
     progress: bool = True,
     num_classes: Optional[int] = None,
     weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
     trainable_backbone_layers: Optional[int] = None,
     **kwargs: Any,
-) -> FasterRCNN:
+) -> LineNet:
     """
     Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.
 
@@ -825,7 +829,7 @@ def fasterrcnn_mobilenet_v3_large_fpn(
     .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights
         :members:
     """
-    weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights)
+    weights = LineNet_MobileNet_V3_Large_FPN_Weights.verify(weights)
     weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
 
     defaults = {
@@ -833,7 +837,7 @@ def fasterrcnn_mobilenet_v3_large_fpn(
     }
 
     kwargs = {**defaults, **kwargs}
-    return _fasterrcnn_mobilenet_v3_large_fpn(
+    return _linenet_mobilenet_v3_large_fpn(
         weights=weights,
         progress=progress,
         num_classes=num_classes,

+ 0 - 0
models/line_net/__init__.py → models/line_detect/__init__.py


+ 0 - 0
models/line_net/fasterrcnn_resnet50.py → models/line_detect/fasterrcnn_resnet50.py


+ 490 - 0
models/line_detect/line_rcnn.py

@@ -0,0 +1,490 @@
+from typing import Any, Optional
+
+import torch
+from torch import nn
+from torchvision.ops import MultiScaleRoIAlign
+
+from libs.vision_libs.ops import misc as misc_nn_ops
+from libs.vision_libs.transforms._presets import ObjectDetection
+from .roi_heads import RoIHeads
+from libs.vision_libs.models._api  import register_model, Weights, WeightsEnum
+from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
+from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
+from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
+from libs.vision_libs.models.detection._utils import overwrite_eps
+from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
+from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN
+
+__all__ = [
+    "LineRCNN",
+    "LineRCNN_ResNet50_FPN_Weights",
+    "linercnn_resnet50_fpn",
+]
+
+
+class LineRCNN(FasterRCNN):
+    """
+    Implements Keypoint R-CNN.
+
+    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
+    image, and should be in 0-1 range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (list of dictionary),
+    containing:
+
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+            ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (Int64Tensor[N]): the class label for each ground-truth box
+        - keypoints (FloatTensor[N, K, 3]): the K keypoints location for each of the N instances, in the
+          format [x, y, visibility], where visibility=0 means that the keypoint is not visible.
+
+    The model returns a Dict[Tensor] during training, containing the classification and regression
+    losses for both the RPN and the R-CNN, and the keypoint loss.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
+    follows:
+
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+            ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (Int64Tensor[N]): the predicted labels for each image
+        - scores (Tensor[N]): the scores or each prediction
+        - keypoints (FloatTensor[N, K, 3]): the locations of the predicted keypoints, in [x, y, v] format.
+
+    Args:
+        backbone (nn.Module): the network used to compute the features for the model.
+            It should contain an out_channels attribute, which indicates the number of output
+            channels that each feature map has (and it should be the same for all feature maps).
+            The backbone should return a single Tensor or and OrderedDict[Tensor].
+        num_classes (int): number of output classes of the model (including the background).
+            If box_predictor is specified, num_classes should be None.
+        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
+        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
+        image_mean (Tuple[float, float, float]): mean values used for input normalization.
+            They are generally the mean values of the dataset on which the backbone has been trained
+            on
+        image_std (Tuple[float, float, float]): std values used for input normalization.
+            They are generally the std values of the dataset on which the backbone has been trained on
+        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
+            maps.
+        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
+        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
+        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
+        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
+        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
+        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
+        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
+            considered as positive during training of the RPN.
+        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
+            considered as negative during training of the RPN.
+        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
+            for computing the loss
+        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
+            of the RPN
+        rpn_score_thresh (float): during inference, only return proposals with a classification score
+            greater than rpn_score_thresh
+        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
+            the locations indicated by the bounding boxes
+        box_head (nn.Module): module that takes the cropped feature maps as input
+        box_predictor (nn.Module): module that takes the output of box_head and returns the
+            classification logits and box regression deltas.
+        box_score_thresh (float): during inference, only return proposals with a classification score
+            greater than box_score_thresh
+        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
+        box_detections_per_img (int): maximum number of detections per image, for all classes.
+        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
+            considered as positive during training of the classification head
+        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
+            considered as negative during training of the classification head
+        box_batch_size_per_image (int): number of proposals that are sampled during training of the
+            classification head
+        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
+            of the classification head
+        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
+            bounding boxes
+        keypoint_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
+             the locations indicated by the bounding boxes, which will be used for the keypoint head.
+        keypoint_head (nn.Module): module that takes the cropped feature maps as input
+        keypoint_predictor (nn.Module): module that takes the output of the keypoint_head and returns the
+            heatmap logits
+
+    Example::
+
+        >>> import torch
+        >>> import torchvision
+        >>> from torchvision.models.detection import KeypointRCNN
+        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
+        >>>
+        >>> # load a pre-trained model for classification and return
+        >>> # only the features
+        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
+        >>> # KeypointRCNN needs to know the number of
+        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
+        >>> # so we need to add it here
+        >>> backbone.out_channels = 1280
+        >>>
+        >>> # let's make the RPN generate 5 x 3 anchors per spatial
+        >>> # location, with 5 different sizes and 3 different aspect
+        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
+        >>> # map could potentially have different sizes and
+        >>> # aspect ratios
+        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
+        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
+        >>>
+        >>> # let's define what are the feature maps that we will
+        >>> # use to perform the region of interest cropping, as well as
+        >>> # the size of the crop after rescaling.
+        >>> # if your backbone returns a Tensor, featmap_names is expected to
+        >>> # be ['0']. More generally, the backbone should return an
+        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
+        >>> # feature maps to use.
+        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
+        >>>                                                 output_size=7,
+        >>>                                                 sampling_ratio=2)
+        >>>
+        >>> keypoint_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
+        >>>                                                          output_size=14,
+        >>>                                                          sampling_ratio=2)
+        >>> # put the pieces together inside a KeypointRCNN model
+        >>> model = KeypointRCNN(backbone,
+        >>>                      num_classes=2,
+        >>>                      rpn_anchor_generator=anchor_generator,
+        >>>                      box_roi_pool=roi_pooler,
+        >>>                      keypoint_roi_pool=keypoint_roi_pooler)
+        >>> model.eval()
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+    """
+
+    def __init__(
+            self,
+            backbone,
+            num_classes=None,
+            # transform parameters
+            min_size=None,
+            max_size=1333,
+            image_mean=None,
+            image_std=None,
+            # RPN parameters
+            rpn_anchor_generator=None,
+            rpn_head=None,
+            rpn_pre_nms_top_n_train=2000,
+            rpn_pre_nms_top_n_test=1000,
+            rpn_post_nms_top_n_train=2000,
+            rpn_post_nms_top_n_test=1000,
+            rpn_nms_thresh=0.7,
+            rpn_fg_iou_thresh=0.7,
+            rpn_bg_iou_thresh=0.3,
+            rpn_batch_size_per_image=256,
+            rpn_positive_fraction=0.5,
+            rpn_score_thresh=0.0,
+            # Box parameters
+            box_roi_pool=None,
+            box_head=None,
+            box_predictor=None,
+            box_score_thresh=0.05,
+            box_nms_thresh=0.5,
+            box_detections_per_img=100,
+            box_fg_iou_thresh=0.5,
+            box_bg_iou_thresh=0.5,
+            box_batch_size_per_image=512,
+            box_positive_fraction=0.25,
+            bbox_reg_weights=None,
+            # line parameters
+            line_head=None,
+            line_predictor=None,
+
+            **kwargs,
+    ):
+
+        # if not isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))):
+        #     raise TypeError(
+        #         "keypoint_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(keypoint_roi_pool)}"
+        #     )
+        # if min_size is None:
+        #     min_size = (640, 672, 704, 736, 768, 800)
+        #
+        # if num_keypoints is not None:
+        #     if keypoint_predictor is not None:
+        #         raise ValueError("num_keypoints should be None when keypoint_predictor is specified")
+        # else:
+        #     num_keypoints = 17
+
+        out_channels = backbone.out_channels
+
+        if line_head is None:
+            keypoint_layers = tuple(512 for _ in range(8))
+            line_head = LineRCNNHeads(out_channels, keypoint_layers)
+
+        if line_predictor is None:
+            keypoint_dim_reduced = 512  # == keypoint_layers[-1]
+            line_predictor = LineRCNNPredictor(keypoint_dim_reduced)
+
+        super().__init__(
+            backbone,
+            num_classes,
+            # transform parameters
+            min_size,
+            max_size,
+            image_mean,
+            image_std,
+            # RPN-specific parameters
+            rpn_anchor_generator,
+            rpn_head,
+            rpn_pre_nms_top_n_train,
+            rpn_pre_nms_top_n_test,
+            rpn_post_nms_top_n_train,
+            rpn_post_nms_top_n_test,
+            rpn_nms_thresh,
+            rpn_fg_iou_thresh,
+            rpn_bg_iou_thresh,
+            rpn_batch_size_per_image,
+            rpn_positive_fraction,
+            rpn_score_thresh,
+            # Box parameters
+            box_roi_pool,
+            box_head,
+            box_predictor,
+            box_score_thresh,
+            box_nms_thresh,
+            box_detections_per_img,
+            box_fg_iou_thresh,
+            box_bg_iou_thresh,
+            box_batch_size_per_image,
+            box_positive_fraction,
+            bbox_reg_weights,
+            **kwargs,
+        )
+
+        roi_heads = RoIHeads(
+            # Box
+            box_roi_pool,
+            box_head,
+            box_predictor,
+            line_head,
+            line_predictor,
+            box_fg_iou_thresh,
+            box_bg_iou_thresh,
+            box_batch_size_per_image,
+            box_positive_fraction,
+            bbox_reg_weights,
+            box_score_thresh,
+            box_nms_thresh,
+            box_detections_per_img,
+
+        )
+        super().roi_heads = roi_heads
+        # self.roi_heads = roi_heads
+
+        # self.roi_heads.line_head = line_head
+        # self.roi_heads.line_predictor = line_predictor
+
+
+class LineRCNNHeads(nn.Sequential):
+    pass
+    # def __init__(self, in_channels, layers):
+    #     d = []
+    #     next_feature = in_channels
+    #     for out_channels in layers:
+    #         d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
+    #         d.append(nn.ReLU(inplace=True))
+    #         next_feature = out_channels
+    #     super().__init__(*d)
+    #     for m in self.children():
+    #         if isinstance(m, nn.Conv2d):
+    #             nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+    #             nn.init.constant_(m.bias, 0)
+
+
+class LineRCNNPredictor(nn.Module):
+    pass
+    # def __init__(self, in_channels, num_keypoints):
+    #     super().__init__()
+    #     input_features = in_channels
+    #     deconv_kernel = 4
+    #     self.kps_score_lowres = nn.ConvTranspose2d(
+    #         input_features,
+    #         num_keypoints,
+    #         deconv_kernel,
+    #         stride=2,
+    #         padding=deconv_kernel // 2 - 1,
+    #     )
+    #     nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
+    #     nn.init.constant_(self.kps_score_lowres.bias, 0)
+    #     self.up_scale = 2
+    #     self.out_channels = num_keypoints
+    #
+    # def forward(self, x):
+    #     x = self.kps_score_lowres(x)
+    #     return torch.nn.functional.interpolate(
+    #         x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
+    #     )
+
+
+_COMMON_META = {
+    "categories": _COCO_PERSON_CATEGORIES,
+    "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
+    "min_size": (1, 1),
+}
+
+
+class LineRCNN_ResNet50_FPN_Weights(WeightsEnum):
+    COCO_LEGACY = Weights(
+        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 59137258,
+            "recipe": "https://github.com/pytorch/vision/issues/1606",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 50.6,
+                    "kp_map": 61.1,
+                }
+            },
+            "_ops": 133.924,
+            "_file_size": 226.054,
+            "_docs": """
+                These weights were produced by following a similar training recipe as on the paper but use a checkpoint
+                from an early epoch.
+            """,
+        },
+    )
+    COCO_V1 = Weights(
+        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
+        transforms=ObjectDetection,
+        meta={
+            **_COMMON_META,
+            "num_params": 59137258,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
+            "_metrics": {
+                "COCO-val2017": {
+                    "box_map": 54.6,
+                    "kp_map": 65.0,
+                }
+            },
+            "_ops": 137.42,
+            "_file_size": 226.054,
+            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
+        },
+    )
+    DEFAULT = COCO_V1
+
+
+@register_model()
+@handle_legacy_interface(
+    weights=(
+            "pretrained",
+            lambda kwargs: LineRCNN_ResNet50_FPN_Weights.COCO_LEGACY
+            if kwargs["pretrained"] == "legacy"
+            else LineRCNN_ResNet50_FPN_Weights.COCO_V1,
+    ),
+    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
+)
+def linercnn_resnet50_fpn(
+        *,
+        weights: Optional[LineRCNN_ResNet50_FPN_Weights] = None,
+        progress: bool = True,
+        num_classes: Optional[int] = None,
+        num_keypoints: Optional[int] = None,
+        weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
+        trainable_backbone_layers: Optional[int] = None,
+        **kwargs: Any,
+) -> LineRCNN:
+    """
+    Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.
+
+    .. betastatus:: detection module
+
+    Reference: `Mask R-CNN <https://arxiv.org/abs/1703.06870>`__.
+
+    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
+    image, and should be in ``0-1`` range. Different images can have different sizes.
+
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and targets (list of dictionary),
+    containing:
+
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
+        - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the
+          format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.
+
+    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
+    losses for both the RPN and the R-CNN, and the keypoint loss.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
+    follows, where ``N`` is the number of detected instances:
+
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+        - labels (``Int64Tensor[N]``): the predicted labels for each instance
+        - scores (``Tensor[N]``): the scores or each instance
+        - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.
+
+    For more details on the output, you may refer to :ref:`instance_seg_output`.
+
+    Keypoint R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.
+
+    Example::
+
+        >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=KeypointRCNN_ResNet50_FPN_Weights.DEFAULT)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+        >>>
+        >>> # optionally, if you want to export the model to ONNX:
+        >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11)
+
+    Args:
+        weights (:class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`
+            below for more details, and possible values. By default, no
+            pre-trained weights are used.
+        progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int, optional): number of output classes of the model (including the background)
+        num_keypoints (int, optional): number of keypoints
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
+            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
+            passed (the default) this value is set to 3.
+
+    .. autoclass:: torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights
+        :members:
+    """
+    weights = LineRCNN_ResNet50_FPN_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+        num_keypoints = _ovewrite_value_param("num_keypoints", num_keypoints, len(weights.meta["keypoint_names"]))
+    else:
+        if num_classes is None:
+            num_classes = 2
+        if num_keypoints is None:
+            num_keypoints = 17
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+    model = LineRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+        if weights == LineRCNN_ResNet50_FPN_Weights.COCO_V1:
+            overwrite_eps(model, 0.0)
+
+    return model

+ 903 - 0
models/line_detect/roi_heads.py

@@ -0,0 +1,903 @@
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import nn, Tensor
+from torchvision.ops import boxes as box_ops, roi_align
+
+from libs.vision_libs.models.detection  import _utils as det_utils
+
+###计算多头损失
+def line_loss():
+    pass
+
+def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
+    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
+    """
+    Computes the loss for Faster R-CNN.
+
+    Args:
+        class_logits (Tensor)
+        box_regression (Tensor)
+        labels (list[BoxList])
+        regression_targets (Tensor)
+
+    Returns:
+        classification_loss (Tensor)
+        box_loss (Tensor)
+    """
+
+    labels = torch.cat(labels, dim=0)
+    regression_targets = torch.cat(regression_targets, dim=0)
+
+    classification_loss = F.cross_entropy(class_logits, labels)
+
+    # get indices that correspond to the regression targets for
+    # the corresponding ground truth labels, to be used with
+    # advanced indexing
+    sampled_pos_inds_subset = torch.where(labels > 0)[0]
+    labels_pos = labels[sampled_pos_inds_subset]
+    N, num_classes = class_logits.shape
+    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
+
+    box_loss = F.smooth_l1_loss(
+        box_regression[sampled_pos_inds_subset, labels_pos],
+        regression_targets[sampled_pos_inds_subset],
+        beta=1 / 9,
+        reduction="sum",
+    )
+    box_loss = box_loss / labels.numel()
+
+    return classification_loss, box_loss
+
+
+def maskrcnn_inference(x, labels):
+    # type: (Tensor, List[Tensor]) -> List[Tensor]
+    """
+    From the results of the CNN, post process the masks
+    by taking the mask corresponding to the class with max
+    probability (which are of fixed size and directly output
+    by the CNN) and return the masks in the mask field of the BoxList.
+
+    Args:
+        x (Tensor): the mask logits
+        labels (list[BoxList]): bounding boxes that are used as
+            reference, one for ech image
+
+    Returns:
+        results (list[BoxList]): one BoxList for each image, containing
+            the extra field mask
+    """
+    mask_prob = x.sigmoid()
+
+    # select masks corresponding to the predicted classes
+    num_masks = x.shape[0]
+    boxes_per_image = [label.shape[0] for label in labels]
+    labels = torch.cat(labels)
+    index = torch.arange(num_masks, device=labels.device)
+    mask_prob = mask_prob[index, labels][:, None]
+    mask_prob = mask_prob.split(boxes_per_image, dim=0)
+
+    return mask_prob
+
+
+def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
+    # type: (Tensor, Tensor, Tensor, int) -> Tensor
+    """
+    Given segmentation masks and the bounding boxes corresponding
+    to the location of the masks in the image, this function
+    crops and resizes the masks in the position defined by the
+    boxes. This prepares the masks for them to be fed to the
+    loss computation as the targets.
+    """
+    matched_idxs = matched_idxs.to(boxes)
+    rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
+    gt_masks = gt_masks[:, None].to(rois)
+    return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
+
+
+def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    """
+    Args:
+        proposals (list[BoxList])
+        mask_logits (Tensor)
+        targets (list[BoxList])
+
+    Return:
+        mask_loss (Tensor): scalar tensor containing the loss
+    """
+
+    discretization_size = mask_logits.shape[-1]
+    labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
+    mask_targets = [
+        project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
+    ]
+
+    labels = torch.cat(labels, dim=0)
+    mask_targets = torch.cat(mask_targets, dim=0)
+
+    # torch.mean (in binary_cross_entropy_with_logits) doesn't
+    # accept empty tensors, so handle it separately
+    if mask_targets.numel() == 0:
+        return mask_logits.sum() * 0
+
+    mask_loss = F.binary_cross_entropy_with_logits(
+        mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
+    )
+    return mask_loss
+
+
+def keypoints_to_heatmap(keypoints, rois, heatmap_size):
+    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
+    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
+
+    offset_x = offset_x[:, None]
+    offset_y = offset_y[:, None]
+    scale_x = scale_x[:, None]
+    scale_y = scale_y[:, None]
+
+    x = keypoints[..., 0]
+    y = keypoints[..., 1]
+
+    x_boundary_inds = x == rois[:, 2][:, None]
+    y_boundary_inds = y == rois[:, 3][:, None]
+
+    x = (x - offset_x) * scale_x
+    x = x.floor().long()
+    y = (y - offset_y) * scale_y
+    y = y.floor().long()
+
+    x[x_boundary_inds] = heatmap_size - 1
+    y[y_boundary_inds] = heatmap_size - 1
+
+    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
+    vis = keypoints[..., 2] > 0
+    valid = (valid_loc & vis).long()
+
+    lin_ind = y * heatmap_size + x
+    heatmaps = lin_ind * valid
+
+    return heatmaps, valid
+
+
+def _onnx_heatmaps_to_keypoints(
+    maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
+):
+    num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
+
+    width_correction = widths_i / roi_map_width
+    height_correction = heights_i / roi_map_height
+
+    roi_map = F.interpolate(
+        maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
+    )[:, 0]
+
+    w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
+    pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+    x_int = pos % w
+    y_int = (pos - x_int) // w
+
+    x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
+        dtype=torch.float32
+    )
+    y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
+        dtype=torch.float32
+    )
+
+    xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
+    xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
+    xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
+    xy_preds_i = torch.stack(
+        [
+            xy_preds_i_0.to(dtype=torch.float32),
+            xy_preds_i_1.to(dtype=torch.float32),
+            xy_preds_i_2.to(dtype=torch.float32),
+        ],
+        0,
+    )
+
+    # TODO: simplify when indexing without rank will be supported by ONNX
+    base = num_keypoints * num_keypoints + num_keypoints + 1
+    ind = torch.arange(num_keypoints)
+    ind = ind.to(dtype=torch.int64) * base
+    end_scores_i = (
+        roi_map.index_select(1, y_int.to(dtype=torch.int64))
+        .index_select(2, x_int.to(dtype=torch.int64))
+        .view(-1)
+        .index_select(0, ind.to(dtype=torch.int64))
+    )
+
+    return xy_preds_i, end_scores_i
+
+
+@torch.jit._script_if_tracing
+def _onnx_heatmaps_to_keypoints_loop(
+    maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
+):
+    xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
+    end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
+
+    for i in range(int(rois.size(0))):
+        xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
+            maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
+        )
+        xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
+        end_scores = torch.cat(
+            (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
+        )
+    return xy_preds, end_scores
+
+
+def heatmaps_to_keypoints(maps, rois):
+    """Extract predicted keypoint locations from heatmaps. Output has shape
+    (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
+    for each keypoint.
+    """
+    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
+    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
+    # consistency with keypoints_to_heatmap_labels by using the conversion from
+    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
+    # continuous coordinate.
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+
+    widths = rois[:, 2] - rois[:, 0]
+    heights = rois[:, 3] - rois[:, 1]
+    widths = widths.clamp(min=1)
+    heights = heights.clamp(min=1)
+    widths_ceil = widths.ceil()
+    heights_ceil = heights.ceil()
+
+    num_keypoints = maps.shape[1]
+
+    if torchvision._is_tracing():
+        xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
+            maps,
+            rois,
+            widths_ceil,
+            heights_ceil,
+            widths,
+            heights,
+            offset_x,
+            offset_y,
+            torch.scalar_tensor(num_keypoints, dtype=torch.int64),
+        )
+        return xy_preds.permute(0, 2, 1), end_scores
+
+    xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
+    end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
+    for i in range(len(rois)):
+        roi_map_width = int(widths_ceil[i].item())
+        roi_map_height = int(heights_ceil[i].item())
+        width_correction = widths[i] / roi_map_width
+        height_correction = heights[i] / roi_map_height
+        roi_map = F.interpolate(
+            maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
+        )[:, 0]
+        # roi_map_probs = scores_to_probs(roi_map.copy())
+        w = roi_map.shape[2]
+        pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+        x_int = pos % w
+        y_int = torch.div(pos - x_int, w, rounding_mode="floor")
+        # assert (roi_map_probs[k, y_int, x_int] ==
+        #         roi_map_probs[k, :, :].max())
+        x = (x_int.float() + 0.5) * width_correction
+        y = (y_int.float() + 0.5) * height_correction
+        xy_preds[i, 0, :] = x + offset_x[i]
+        xy_preds[i, 1, :] = y + offset_y[i]
+        xy_preds[i, 2, :] = 1
+        end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
+
+    return xy_preds.permute(0, 2, 1), end_scores
+
+
+def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
+    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+    N, K, H, W = keypoint_logits.shape
+    if H != W:
+        raise ValueError(
+            f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
+        )
+    discretization_size = H
+    heatmaps = []
+    valid = []
+    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
+        kp = gt_kp_in_image[midx]
+        heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
+        heatmaps.append(heatmaps_per_image.view(-1))
+        valid.append(valid_per_image.view(-1))
+
+    keypoint_targets = torch.cat(heatmaps, dim=0)
+    valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
+    valid = torch.where(valid)[0]
+
+    # torch.mean (in binary_cross_entropy_with_logits) doesn't
+    # accept empty tensors, so handle it sepaartely
+    if keypoint_targets.numel() == 0 or len(valid) == 0:
+        return keypoint_logits.sum() * 0
+
+    keypoint_logits = keypoint_logits.view(N * K, H * W)
+
+    keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
+    return keypoint_loss
+
+
+def keypointrcnn_inference(x, boxes):
+    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+    kp_probs = []
+    kp_scores = []
+
+    boxes_per_image = [box.size(0) for box in boxes]
+    x2 = x.split(boxes_per_image, dim=0)
+
+    for xx, bb in zip(x2, boxes):
+        kp_prob, scores = heatmaps_to_keypoints(xx, bb)
+        kp_probs.append(kp_prob)
+        kp_scores.append(scores)
+
+    return kp_probs, kp_scores
+
+
+def _onnx_expand_boxes(boxes, scale):
+    # type: (Tensor, float) -> Tensor
+    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+    w_half = w_half.to(dtype=torch.float32) * scale
+    h_half = h_half.to(dtype=torch.float32) * scale
+
+    boxes_exp0 = x_c - w_half
+    boxes_exp1 = y_c - h_half
+    boxes_exp2 = x_c + w_half
+    boxes_exp3 = y_c + h_half
+    boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
+    return boxes_exp
+
+
+# the next two functions should be merged inside Masker
+# but are kept here for the moment while we need them
+# temporarily for paste_mask_in_image
+def expand_boxes(boxes, scale):
+    # type: (Tensor, float) -> Tensor
+    if torchvision._is_tracing():
+        return _onnx_expand_boxes(boxes, scale)
+    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+    w_half *= scale
+    h_half *= scale
+
+    boxes_exp = torch.zeros_like(boxes)
+    boxes_exp[:, 0] = x_c - w_half
+    boxes_exp[:, 2] = x_c + w_half
+    boxes_exp[:, 1] = y_c - h_half
+    boxes_exp[:, 3] = y_c + h_half
+    return boxes_exp
+
+
+@torch.jit.unused
+def expand_masks_tracing_scale(M, padding):
+    # type: (int, int) -> float
+    return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
+
+
+def expand_masks(mask, padding):
+    # type: (Tensor, int) -> Tuple[Tensor, float]
+    M = mask.shape[-1]
+    if torch._C._get_tracing_state():  # could not import is_tracing(), not sure why
+        scale = expand_masks_tracing_scale(M, padding)
+    else:
+        scale = float(M + 2 * padding) / M
+    padded_mask = F.pad(mask, (padding,) * 4)
+    return padded_mask, scale
+
+
+def paste_mask_in_image(mask, box, im_h, im_w):
+    # type: (Tensor, Tensor, int, int) -> Tensor
+    TO_REMOVE = 1
+    w = int(box[2] - box[0] + TO_REMOVE)
+    h = int(box[3] - box[1] + TO_REMOVE)
+    w = max(w, 1)
+    h = max(h, 1)
+
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, -1, -1))
+
+    # Resize mask
+    mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
+    mask = mask[0][0]
+
+    im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
+    x_0 = max(box[0], 0)
+    x_1 = min(box[2] + 1, im_w)
+    y_0 = max(box[1], 0)
+    y_1 = min(box[3] + 1, im_h)
+
+    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
+    return im_mask
+
+
+def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
+    one = torch.ones(1, dtype=torch.int64)
+    zero = torch.zeros(1, dtype=torch.int64)
+
+    w = box[2] - box[0] + one
+    h = box[3] - box[1] + one
+    w = torch.max(torch.cat((w, one)))
+    h = torch.max(torch.cat((h, one)))
+
+    # Set shape to [batchxCxHxW]
+    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
+
+    # Resize mask
+    mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
+    mask = mask[0][0]
+
+    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
+    x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
+    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
+    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
+
+    unpaded_im_mask = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
+
+    # TODO : replace below with a dynamic padding when support is added in ONNX
+
+    # pad y
+    zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
+    zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
+    concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
+    # pad x
+    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
+    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
+    im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
+    return im_mask
+
+
+@torch.jit._script_if_tracing
+def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
+    res_append = torch.zeros(0, im_h, im_w)
+    for i in range(masks.size(0)):
+        mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
+        mask_res = mask_res.unsqueeze(0)
+        res_append = torch.cat((res_append, mask_res))
+    return res_append
+
+
+def paste_masks_in_image(masks, boxes, img_shape, padding=1):
+    # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
+    masks, scale = expand_masks(masks, padding=padding)
+    boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
+    im_h, im_w = img_shape
+
+    if torchvision._is_tracing():
+        return _onnx_paste_masks_in_image_loop(
+            masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
+        )[:, None]
+    res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
+    if len(res) > 0:
+        ret = torch.stack(res, dim=0)[:, None]
+    else:
+        ret = masks.new_empty((0, 1, im_h, im_w))
+    return ret
+
+
+class RoIHeads(nn.Module):
+    __annotations__ = {
+        "box_coder": det_utils.BoxCoder,
+        "proposal_matcher": det_utils.Matcher,
+        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
+    }
+
+    def __init__(
+        self,
+        box_roi_pool,
+        box_head,
+        box_predictor,
+        line_head,
+        line_predictor,
+        # Faster R-CNN training
+        fg_iou_thresh,
+        bg_iou_thresh,
+        batch_size_per_image,
+        positive_fraction,
+        bbox_reg_weights,
+        # Faster R-CNN inference
+        score_thresh,
+        nms_thresh,
+        detections_per_img,
+        # Mask
+        mask_roi_pool=None,
+        mask_head=None,
+        mask_predictor=None,
+        keypoint_roi_pool=None,
+        keypoint_head=None,
+        keypoint_predictor=None,
+    ):
+        super().__init__()
+
+        self.box_similarity = box_ops.box_iou
+        # assign ground-truth boxes for each proposal
+        self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)
+
+        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
+
+        if bbox_reg_weights is None:
+            bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
+        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)
+
+        self.box_roi_pool = box_roi_pool
+        self.box_head = box_head
+        self.box_predictor = box_predictor
+
+        self.line_head=line_head
+        self.line_predictor=line_predictor
+
+        self.score_thresh = score_thresh
+        self.nms_thresh = nms_thresh
+        self.detections_per_img = detections_per_img
+
+        self.mask_roi_pool = mask_roi_pool
+        self.mask_head = mask_head
+        self.mask_predictor = mask_predictor
+
+        self.keypoint_roi_pool = keypoint_roi_pool
+        self.keypoint_head = keypoint_head
+        self.keypoint_predictor = keypoint_predictor
+
+    def has_line(self):
+        pass
+    def has_mask(self):
+        if self.mask_roi_pool is None:
+            return False
+        if self.mask_head is None:
+            return False
+        if self.mask_predictor is None:
+            return False
+        return True
+
+    def has_keypoint(self):
+        if self.keypoint_roi_pool is None:
+            return False
+        if self.keypoint_head is None:
+            return False
+        if self.keypoint_predictor is None:
+            return False
+        return True
+
+    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
+        # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+        matched_idxs = []
+        labels = []
+        for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
+
+            if gt_boxes_in_image.numel() == 0:
+                # Background image
+                device = proposals_in_image.device
+                clamped_matched_idxs_in_image = torch.zeros(
+                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
+                )
+                labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
+            else:
+                #  set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
+                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
+                matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
+
+                clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
+
+                labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
+                labels_in_image = labels_in_image.to(dtype=torch.int64)
+
+                # Label background (below the low threshold)
+                bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
+                labels_in_image[bg_inds] = 0
+
+                # Label ignore proposals (between low and high thresholds)
+                ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
+                labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler
+
+            matched_idxs.append(clamped_matched_idxs_in_image)
+            labels.append(labels_in_image)
+        return matched_idxs, labels
+
+    def subsample(self, labels):
+        # type: (List[Tensor]) -> List[Tensor]
+        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
+        sampled_inds = []
+        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
+            img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
+            sampled_inds.append(img_sampled_inds)
+        return sampled_inds
+
+    def add_gt_proposals(self, proposals, gt_boxes):
+        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
+        proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
+
+        return proposals
+
+    def check_targets(self, targets):
+        # type: (Optional[List[Dict[str, Tensor]]]) -> None
+        if targets is None:
+            raise ValueError("targets should not be None")
+        if not all(["boxes" in t for t in targets]):
+            raise ValueError("Every element of targets should have a boxes key")
+        if not all(["labels" in t for t in targets]):
+            raise ValueError("Every element of targets should have a labels key")
+        if self.has_mask():
+            if not all(["masks" in t for t in targets]):
+                raise ValueError("Every element of targets should have a masks key")
+
+    def select_training_samples(
+        self,
+        proposals,  # type: List[Tensor]
+        targets,  # type: Optional[List[Dict[str, Tensor]]]
+    ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
+        self.check_targets(targets)
+        if targets is None:
+            raise ValueError("targets should not be None")
+        dtype = proposals[0].dtype
+        device = proposals[0].device
+
+        gt_boxes = [t["boxes"].to(dtype) for t in targets]
+        gt_labels = [t["labels"] for t in targets]
+
+        # append ground-truth bboxes to propos
+        proposals = self.add_gt_proposals(proposals, gt_boxes)
+
+        # get matching gt indices for each proposal
+        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
+        # sample a fixed proportion of positive-negative proposals
+        sampled_inds = self.subsample(labels)
+        matched_gt_boxes = []
+        num_images = len(proposals)
+        for img_id in range(num_images):
+            img_sampled_inds = sampled_inds[img_id]
+            proposals[img_id] = proposals[img_id][img_sampled_inds]
+            labels[img_id] = labels[img_id][img_sampled_inds]
+            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
+
+            gt_boxes_in_image = gt_boxes[img_id]
+            if gt_boxes_in_image.numel() == 0:
+                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
+            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
+
+        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
+        return proposals, matched_idxs, labels, regression_targets
+
+    def postprocess_detections(
+        self,
+        class_logits,  # type: Tensor
+        box_regression,  # type: Tensor
+        proposals,  # type: List[Tensor]
+        image_shapes,  # type: List[Tuple[int, int]]
+    ):
+        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
+        device = class_logits.device
+        num_classes = class_logits.shape[-1]
+
+        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
+        pred_boxes = self.box_coder.decode(box_regression, proposals)
+
+        pred_scores = F.softmax(class_logits, -1)
+
+        pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
+        pred_scores_list = pred_scores.split(boxes_per_image, 0)
+
+        all_boxes = []
+        all_scores = []
+        all_labels = []
+        for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
+            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
+
+            # create labels for each prediction
+            labels = torch.arange(num_classes, device=device)
+            labels = labels.view(1, -1).expand_as(scores)
+
+            # remove predictions with the background label
+            boxes = boxes[:, 1:]
+            scores = scores[:, 1:]
+            labels = labels[:, 1:]
+
+            # batch everything, by making every class prediction be a separate instance
+            boxes = boxes.reshape(-1, 4)
+            scores = scores.reshape(-1)
+            labels = labels.reshape(-1)
+
+            # remove low scoring boxes
+            inds = torch.where(scores > self.score_thresh)[0]
+            boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
+
+            # remove empty boxes
+            keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
+            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+            # non-maximum suppression, independently done per class
+            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
+            # keep only topk scoring predictions
+            keep = keep[: self.detections_per_img]
+            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+            all_boxes.append(boxes)
+            all_scores.append(scores)
+            all_labels.append(labels)
+
+        return all_boxes, all_scores, all_labels
+
+
+    def forward(
+        self,
+        features,  # type: Dict[str, Tensor]
+        proposals,  # type: List[Tensor]
+        image_shapes,  # type: List[Tuple[int, int]]
+        targets=None,  # type: Optional[List[Dict[str, Tensor]]]
+    ):
+        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
+        """
+        Args:
+            features (List[Tensor])
+            proposals (List[Tensor[N, 4]])
+            image_shapes (List[Tuple[H, W]])
+            targets (List[Dict])
+        """
+        if targets is not None:
+            for t in targets:
+                # TODO: https://github.com/pytorch/pytorch/issues/26731
+                floating_point_types = (torch.float, torch.double, torch.half)
+                if not t["boxes"].dtype in floating_point_types:
+                    raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
+                if not t["labels"].dtype == torch.int64:
+                    raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")
+                if self.has_keypoint():
+                    if not t["keypoints"].dtype == torch.float32:
+                        raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}")
+
+        if self.training:
+            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
+        else:
+            labels = None
+            regression_targets = None
+            matched_idxs = None
+
+        box_features = self.box_roi_pool(features, proposals, image_shapes)
+        box_features = self.box_head(box_features)
+
+
+
+
+        class_logits, box_regression = self.box_predictor(box_features)
+
+        result: List[Dict[str, torch.Tensor]] = []
+        losses = {}
+        if self.training:
+            if labels is None:
+                raise ValueError("labels cannot be None")
+            if regression_targets is None:
+                raise ValueError("regression_targets cannot be None")
+            loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
+
+
+            losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
+        else:
+            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
+            num_images = len(boxes)
+            for i in range(num_images):
+                result.append(
+                    {
+                        "boxes": boxes[i],
+                        "labels": labels[i],
+                        "scores": scores[i],
+                    }
+                )
+        if self.has_line():
+
+            line_features = self.line_head(features)
+            _ = self.line_predictor(line_features)
+            ### line_loss(multitasklearner)
+
+
+            ### infer
+
+
+            pass
+        if self.has_mask():
+            mask_proposals = [p["boxes"] for p in result]
+            if self.training:
+                if matched_idxs is None:
+                    raise ValueError("if in training, matched_idxs should not be None")
+
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                mask_proposals = []
+                pos_matched_idxs = []
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    mask_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
+
+            if self.mask_roi_pool is not None:
+                mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
+                mask_features = self.mask_head(mask_features)
+                mask_logits = self.mask_predictor(mask_features)
+            else:
+                raise Exception("Expected mask_roi_pool to be not None")
+
+            loss_mask = {}
+            if self.training:
+                if targets is None or pos_matched_idxs is None or mask_logits is None:
+                    raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
+
+                gt_masks = [t["masks"] for t in targets]
+                gt_labels = [t["labels"] for t in targets]
+                rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
+                loss_mask = {"loss_mask": rcnn_loss_mask}
+            else:
+                labels = [r["labels"] for r in result]
+                masks_probs = maskrcnn_inference(mask_logits, labels)
+                for mask_prob, r in zip(masks_probs, result):
+                    r["masks"] = mask_prob
+
+            losses.update(loss_mask)
+
+        # keep none checks in if conditional so torchscript will conditionally
+        # compile each branch
+        if (
+            self.keypoint_roi_pool is not None
+            and self.keypoint_head is not None
+            and self.keypoint_predictor is not None
+        ):
+            keypoint_proposals = [p["boxes"] for p in result]
+            if self.training:
+                # during training, only focus on positive boxes
+                num_images = len(proposals)
+                keypoint_proposals = []
+                pos_matched_idxs = []
+                if matched_idxs is None:
+                    raise ValueError("if in trainning, matched_idxs should not be None")
+
+                for img_id in range(num_images):
+                    pos = torch.where(labels[img_id] > 0)[0]
+                    keypoint_proposals.append(proposals[img_id][pos])
+                    pos_matched_idxs.append(matched_idxs[img_id][pos])
+            else:
+                pos_matched_idxs = None
+
+            keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
+            keypoint_features = self.keypoint_head(keypoint_features)
+            keypoint_logits = self.keypoint_predictor(keypoint_features)
+
+            loss_keypoint = {}
+            if self.training:
+                if targets is None or pos_matched_idxs is None:
+                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
+
+                gt_keypoints = [t["keypoints"] for t in targets]
+                rcnn_loss_keypoint = keypointrcnn_loss(
+                    keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
+                )
+                loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
+            else:
+                if keypoint_logits is None or keypoint_proposals is None:
+                    raise ValueError(
+                        "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
+                    )
+
+                keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
+                for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
+                    r["keypoints"] = keypoint_prob
+                    r["keypoints_scores"] = kps
+            losses.update(loss_keypoint)
+
+        return result, losses

+ 0 - 0
models/obj/__init__.py → models/obj_detect/__init__.py