|
|
@@ -16,7 +16,7 @@ from .faster_rcnn import _default_anchorgen, FasterRCNN, FastRCNNConvFCHead, RPN
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
- "MaskRCNN",
|
|
|
+ "InsDetectNet",
|
|
|
"MaskRCNN_ResNet50_FPN_Weights",
|
|
|
"MaskRCNN_ResNet50_FPN_V2_Weights",
|
|
|
"maskrcnn_resnet50_fpn",
|
|
|
@@ -24,7 +24,7 @@ __all__ = [
|
|
|
]
|
|
|
|
|
|
|
|
|
-class MaskRCNN(FasterRCNN):
|
|
|
+class InsDetectNet(FasterRCNN):
|
|
|
"""
|
|
|
Implements Mask R-CNN.
|
|
|
|
|
|
@@ -413,7 +413,7 @@ def maskrcnn_resnet50_fpn(
|
|
|
weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
|
|
|
trainable_backbone_layers: Optional[int] = None,
|
|
|
**kwargs: Any,
|
|
|
-) -> MaskRCNN:
|
|
|
+) -> InsDetectNet:
|
|
|
"""Mask R-CNN model with a ResNet-50-FPN backbone from the `Mask R-CNN
|
|
|
<https://arxiv.org/abs/1703.06870>`_ paper.
|
|
|
|
|
|
@@ -498,7 +498,7 @@ def maskrcnn_resnet50_fpn(
|
|
|
|
|
|
backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
|
|
|
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
|
|
|
- model = MaskRCNN(backbone, num_classes=num_classes, **kwargs)
|
|
|
+ model = InsDetectNet(backbone, num_classes=num_classes, **kwargs)
|
|
|
|
|
|
if weights is not None:
|
|
|
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
|
|
|
@@ -521,7 +521,7 @@ def maskrcnn_resnet50_fpn_v2(
|
|
|
weights_backbone: Optional[ResNet50_Weights] = None,
|
|
|
trainable_backbone_layers: Optional[int] = None,
|
|
|
**kwargs: Any,
|
|
|
-) -> MaskRCNN:
|
|
|
+) -> InsDetectNet:
|
|
|
"""Improved Mask R-CNN model with a ResNet-50-FPN backbone from the `Benchmarking Detection Transfer
|
|
|
Learning with Vision Transformers <https://arxiv.org/abs/2111.11429>`_ paper.
|
|
|
|
|
|
@@ -571,7 +571,7 @@ def maskrcnn_resnet50_fpn_v2(
|
|
|
(backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
|
|
|
)
|
|
|
mask_head = MaskRCNNHeads(backbone.out_channels, [256, 256, 256, 256], 1, norm_layer=nn.BatchNorm2d)
|
|
|
- model = MaskRCNN(
|
|
|
+ model = InsDetectNet(
|
|
|
backbone,
|
|
|
num_classes=num_classes,
|
|
|
rpn_anchor_generator=rpn_anchor_generator,
|