backbone_factory.py 794 B

123456789101112131415
  1. from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
  2. from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
  3. from libs.vision_libs.models.detection._utils import overwrite_eps
  4. from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
  5. from libs.vision_libs.ops import misc as misc_nn_ops
  6. from torch import nn
  7. def get_resnet50_fpn():
  8. is_trained = False
  9. trainable_backbone_layers = _validate_trainable_layers(is_trained, None, 5, 3)
  10. norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
  11. backbone = resnet50(weights=None, progress=True, norm_layer=norm_layer)
  12. backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
  13. return backbone