|
@@ -1,5 +1,7 @@
|
|
|
+from libs.vision_libs.models import mobilenet_v3_large
|
|
|
from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
|
|
|
-from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
|
|
|
+from libs.vision_libs.models.detection.ssdlite import _mobilenet_extractor
|
|
|
+from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights, resnet18
|
|
|
from libs.vision_libs.models.detection._utils import overwrite_eps
|
|
|
from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
|
|
|
from libs.vision_libs.ops import misc as misc_nn_ops
|
|
@@ -12,4 +14,21 @@ def get_resnet50_fpn():
|
|
|
norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
|
|
|
backbone = resnet50(weights=None, progress=True, norm_layer=norm_layer)
|
|
|
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
|
|
|
+ return backbone
|
|
|
+
|
|
|
+def get_resnet18_fpn():
|
|
|
+ is_trained = False
|
|
|
+ trainable_backbone_layers = _validate_trainable_layers(is_trained, None, 5, 3)
|
|
|
+ norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
|
|
|
+ backbone = resnet18(weights=None, progress=True, norm_layer=norm_layer)
|
|
|
+ backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
|
|
|
+ return backbone
|
|
|
+
|
|
|
+def get_mobilenet_v3_large_fpn():
|
|
|
+ is_trained =False
|
|
|
+ trainable_backbone_layers = _validate_trainable_layers(is_trained, None, 6, 3)
|
|
|
+ norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
|
|
|
+
|
|
|
+ backbone = mobilenet_v3_large(weights=None, progress=True, norm_layer=norm_layer)
|
|
|
+ backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
|
|
|
return backbone
|