backbone_factory.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334
  1. from libs.vision_libs.models import mobilenet_v3_large
  2. from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
  3. from libs.vision_libs.models.detection.ssdlite import _mobilenet_extractor
  4. from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights, resnet18
  5. from libs.vision_libs.models.detection._utils import overwrite_eps
  6. from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
  7. from libs.vision_libs.ops import misc as misc_nn_ops
  8. from torch import nn
  9. def get_resnet50_fpn():
  10. is_trained = False
  11. trainable_backbone_layers = _validate_trainable_layers(is_trained, None, 5, 3)
  12. norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
  13. backbone = resnet50(weights=None, progress=True, norm_layer=norm_layer)
  14. backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
  15. return backbone
  16. def get_resnet18_fpn():
  17. is_trained = False
  18. trainable_backbone_layers = _validate_trainable_layers(is_trained, None, 5, 3)
  19. norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
  20. backbone = resnet18(weights=None, progress=True, norm_layer=norm_layer)
  21. backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
  22. return backbone
  23. def get_mobilenet_v3_large_fpn():
  24. is_trained =False
  25. trainable_backbone_layers = _validate_trainable_layers(is_trained, None, 6, 3)
  26. norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
  27. backbone = mobilenet_v3_large(weights=None, progress=True, norm_layer=norm_layer)
  28. backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
  29. return backbone