from torch import nn

from libs.vision_libs.models import mobilenet_v3_large
from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface  # noqa: F401
from libs.vision_libs.models.detection._utils import overwrite_eps  # noqa: F401
from libs.vision_libs.models.detection.backbone_utils import (
    _mobilenet_extractor as _mobilenet_fpn_extractor,
    _resnet_fpn_extractor,
    _validate_trainable_layers,
)
# NOTE: this SSDLite extractor is NOT the one used by get_mobilenet_v3_large_fpn
# below; the import is kept only so the module-level name stays available.
from libs.vision_libs.models.detection.ssdlite import _mobilenet_extractor  # noqa: F401
from libs.vision_libs.models.resnet import ResNet50_Weights, resnet18, resnet50  # noqa: F401
from libs.vision_libs.ops import misc as misc_nn_ops


def _build_resnet_fpn(resnet_builder) -> nn.Module:
    """Build a ResNet with ``resnet_builder`` (no weights) and wrap it with the FPN extractor.

    Shared implementation of :func:`get_resnet50_fpn` and :func:`get_resnet18_fpn`,
    which differ only in the ResNet constructor they use.

    Args:
        resnet_builder: A ResNet constructor such as ``resnet50`` or ``resnet18``,
            called as ``resnet_builder(weights=None, progress=True, norm_layer=...)``.

    Returns:
        Whatever ``_resnet_fpn_extractor`` returns for the freshly built ResNet.
    """
    # No weights are ever loaded here, so the backbone counts as untrained.
    is_trained = False
    # max_value=5, default_value=3. NOTE(review): in torchvision an untrained
    # backbone with trainable_backbone_layers=None resolves to max_value (all
    # layers trainable); presumably the vendored copy behaves the same.
    trainable_backbone_layers = _validate_trainable_layers(is_trained, None, 5, 3)
    # Frozen BN only makes sense for pretrained statistics; untrained -> BatchNorm2d.
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
    backbone = resnet_builder(weights=None, progress=True, norm_layer=norm_layer)
    return _resnet_fpn_extractor(backbone, trainable_backbone_layers)


def get_resnet50_fpn() -> nn.Module:
    """Return a ResNet-50 backbone (built with ``weights=None``) wrapped by the FPN extractor."""
    return _build_resnet_fpn(resnet50)


def get_resnet18_fpn() -> nn.Module:
    """Return a ResNet-18 backbone (built with ``weights=None``) wrapped by the FPN extractor."""
    return _build_resnet_fpn(resnet18)


def get_mobilenet_v3_large_fpn() -> nn.Module:
    """Return a MobileNetV3-Large backbone (built with ``weights=None``) wrapped for FPN use.

    Uses the ``backbone_utils`` MobileNet extractor with ``fpn=True``, mirroring
    torchvision's ``fasterrcnn_mobilenet_v3_large_fpn`` backbone construction.
    """
    # No weights are ever loaded here, so the backbone counts as untrained.
    is_trained = False
    # max_value=6, default_value=3 for MobileNetV3 stages.
    trainable_backbone_layers = _validate_trainable_layers(is_trained, None, 6, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
    backbone = mobilenet_v3_large(weights=None, progress=True, norm_layer=norm_layer)
    # The positional arguments (backbone, True, trainable_layers) match the
    # backbone_utils extractor, where True is the ``fpn`` flag. The previously
    # called SSDLite extractor (torchvision signature: backbone, trainable_layers,
    # norm_layer) would have received True as trainable_layers and the int
    # layer count as norm_layer.
    return _mobilenet_fpn_extractor(backbone, True, trainable_backbone_layers)