|
@@ -15,7 +15,7 @@ from .line_predictor import LineRCNNPredictor
|
|
from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
|
|
from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
|
|
from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES, _COCO_CATEGORIES
|
|
from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES, _COCO_CATEGORIES
|
|
from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
|
|
from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
|
|
-from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
|
|
|
|
|
|
+from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights, ResNet18_Weights, resnet18
|
|
from libs.vision_libs.models.detection._utils import overwrite_eps
|
|
from libs.vision_libs.models.detection._utils import overwrite_eps
|
|
from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
|
|
from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
|
|
from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
|
|
from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
|
|
@@ -396,6 +396,43 @@ class LineNet_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
|
|
weights=("pretrained", LineNet_ResNet50_FPN_Weights.COCO_V1),
|
|
weights=("pretrained", LineNet_ResNet50_FPN_Weights.COCO_V1),
|
|
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
|
|
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
|
|
)
|
|
)
|
|
|
|
+
|
|
|
|
+def linenet_resnet18_fpn(
|
|
|
|
+ *,
|
|
|
|
+ weights: Optional[LineNet_ResNet50_FPN_Weights] = None,
|
|
|
|
+ progress: bool = True,
|
|
|
|
+ num_classes: Optional[int] = None,
|
|
|
|
+ weights_backbone: Optional[ResNet18_Weights] = ResNet18_Weights.IMAGENET1K_V1,
|
|
|
|
+ trainable_backbone_layers: Optional[int] = None,
|
|
|
|
+ **kwargs: Any,
|
|
|
|
+) -> LineNet:
|
|
|
|
+
|
|
|
|
+ # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
|
|
|
|
+ # weights_backbone = ResNet50_Weights.verify(weights_backbone)
|
|
|
|
+
|
|
|
|
+ if weights is not None:
|
|
|
|
+ weights_backbone = None
|
|
|
|
+ num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
|
|
|
|
+ elif num_classes is None:
|
|
|
|
+ num_classes = 91
|
|
|
|
+ if weights_backbone is not None:
|
|
|
|
+ print(f'resnet50 weights is not None')
|
|
|
|
+
|
|
|
|
+ is_trained = weights is not None or weights_backbone is not None
|
|
|
|
+ trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
|
|
|
|
+ norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
|
|
|
|
+
|
|
|
|
+ backbone = resnet18(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
|
|
|
|
+ backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
|
|
|
|
+ model = LineNet(backbone, num_classes=num_classes, **kwargs)
|
|
|
|
+
|
|
|
|
+ if weights is not None:
|
|
|
|
+ model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
|
|
|
|
+ if weights == LineNet_ResNet50_FPN_Weights.COCO_V1:
|
|
|
|
+ overwrite_eps(model, 0.0)
|
|
|
|
+
|
|
|
|
+ return model
|
|
|
|
+
|
|
def linenet_resnet50_fpn(
|
|
def linenet_resnet50_fpn(
|
|
*,
|
|
*,
|
|
weights: Optional[LineNet_ResNet50_FPN_Weights] = None,
|
|
weights: Optional[LineNet_ResNet50_FPN_Weights] = None,
|