backbone_utils.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. import warnings
  2. from typing import Callable, Dict, List, Optional, Union
  3. from torch import nn, Tensor
  4. from torchvision.ops import misc as misc_nn_ops
  5. from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool
  6. from .. import mobilenet, resnet
  7. from .._api import _get_enum_from_fn, WeightsEnum
  8. from .._utils import handle_legacy_interface, IntermediateLayerGetter
  9. class BackboneWithFPN(nn.Module):
  10. """
  11. Adds a FPN on top of a model.
  12. Internally, it uses torchvision.models._utils.IntermediateLayerGetter to
  13. extract a submodel that returns the feature maps specified in return_layers.
  14. The same limitations of IntermediateLayerGetter apply here.
  15. Args:
  16. backbone (nn.Module)
  17. return_layers (Dict[name, new_name]): a dict containing the names
  18. of the modules for which the activations will be returned as
  19. the key of the dict, and the value of the dict is the name
  20. of the returned activation (which the user can specify).
  21. in_channels_list (List[int]): number of channels for each feature map
  22. that is returned, in the order they are present in the OrderedDict
  23. out_channels (int): number of channels in the FPN.
  24. norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
  25. Attributes:
  26. out_channels (int): the number of channels in the FPN
  27. """
  28. def __init__(
  29. self,
  30. backbone: nn.Module,
  31. return_layers: Dict[str, str],
  32. in_channels_list: List[int],
  33. out_channels: int,
  34. extra_blocks: Optional[ExtraFPNBlock] = None,
  35. norm_layer: Optional[Callable[..., nn.Module]] = None,
  36. ) -> None:
  37. super().__init__()
  38. if extra_blocks is None:
  39. extra_blocks = LastLevelMaxPool()
  40. self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
  41. self.fpn = FeaturePyramidNetwork(
  42. in_channels_list=in_channels_list,
  43. out_channels=out_channels,
  44. extra_blocks=extra_blocks,
  45. norm_layer=norm_layer,
  46. )
  47. self.out_channels = out_channels
  48. def forward(self, x: Tensor) -> Dict[str, Tensor]:
  49. x = self.body(x)
  50. x = self.fpn(x)
  51. return x
  52. @handle_legacy_interface(
  53. weights=(
  54. "pretrained",
  55. lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
  56. ),
  57. )
  58. def resnet_fpn_backbone(
  59. *,
  60. backbone_name: str,
  61. weights: Optional[WeightsEnum],
  62. norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
  63. trainable_layers: int = 3,
  64. returned_layers: Optional[List[int]] = None,
  65. extra_blocks: Optional[ExtraFPNBlock] = None,
  66. ) -> BackboneWithFPN:
  67. """
  68. Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.
  69. Examples::
  70. >>> import torch
  71. >>> from torchvision.models import ResNet50_Weights
  72. >>> from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
  73. >>> backbone = resnet_fpn_backbone(backbone_name='resnet50', weights=ResNet50_Weights.DEFAULT, trainable_layers=3)
  74. >>> # get some dummy image
  75. >>> x = torch.rand(1,3,64,64)
  76. >>> # compute the output
  77. >>> output = backbone(x)
  78. >>> print([(k, v.shape) for k, v in output.items()])
  79. >>> # returns
  80. >>> [('0', torch.Size([1, 256, 16, 16])),
  81. >>> ('1', torch.Size([1, 256, 8, 8])),
  82. >>> ('2', torch.Size([1, 256, 4, 4])),
  83. >>> ('3', torch.Size([1, 256, 2, 2])),
  84. >>> ('pool', torch.Size([1, 256, 1, 1]))]
  85. Args:
  86. backbone_name (string): resnet architecture. Possible values are 'resnet18', 'resnet34', 'resnet50',
  87. 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'
  88. weights (WeightsEnum, optional): The pretrained weights for the model
  89. norm_layer (callable): it is recommended to use the default value. For details visit:
  90. (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
  91. trainable_layers (int): number of trainable (not frozen) layers starting from final block.
  92. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
  93. returned_layers (list of int): The layers of the network to return. Each entry must be in ``[1, 4]``.
  94. By default, all layers are returned.
  95. extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
  96. be performed. It is expected to take the fpn features, the original
  97. features and the names of the original features as input, and returns
  98. a new list of feature maps and their corresponding names. By
  99. default, a ``LastLevelMaxPool`` is used.
  100. """
  101. backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
  102. return _resnet_fpn_extractor(backbone, trainable_layers, returned_layers, extra_blocks)
  103. def _resnet_fpn_extractor(
  104. backbone: resnet.ResNet,
  105. trainable_layers: int,
  106. returned_layers: Optional[List[int]] = None,
  107. extra_blocks: Optional[ExtraFPNBlock] = None,
  108. norm_layer: Optional[Callable[..., nn.Module]] = None,
  109. ) -> BackboneWithFPN:
  110. # select layers that won't be frozen
  111. if trainable_layers < 0 or trainable_layers > 5:
  112. raise ValueError(f"Trainable layers should be in the range [0,5], got {trainable_layers}")
  113. layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
  114. if trainable_layers == 5:
  115. layers_to_train.append("bn1")
  116. for name, parameter in backbone.named_parameters():
  117. if all([not name.startswith(layer) for layer in layers_to_train]):
  118. parameter.requires_grad_(False)
  119. if extra_blocks is None:
  120. extra_blocks = LastLevelMaxPool()
  121. if returned_layers is None:
  122. returned_layers = [1, 2, 3, 4]
  123. if min(returned_layers) <= 0 or max(returned_layers) >= 5:
  124. raise ValueError(f"Each returned layer should be in the range [1,4]. Got {returned_layers}")
  125. return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)}
  126. in_channels_stage2 = backbone.inplanes // 8
  127. in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
  128. out_channels = 256
  129. return BackboneWithFPN(
  130. backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
  131. )
  132. def _validate_trainable_layers(
  133. is_trained: bool,
  134. trainable_backbone_layers: Optional[int],
  135. max_value: int,
  136. default_value: int,
  137. ) -> int:
  138. # don't freeze any layers if pretrained model or backbone is not used
  139. if not is_trained:
  140. if trainable_backbone_layers is not None:
  141. warnings.warn(
  142. "Changing trainable_backbone_layers has no effect if "
  143. "neither pretrained nor pretrained_backbone have been set to True, "
  144. f"falling back to trainable_backbone_layers={max_value} so that all layers are trainable"
  145. )
  146. trainable_backbone_layers = max_value
  147. # by default freeze first blocks
  148. if trainable_backbone_layers is None:
  149. trainable_backbone_layers = default_value
  150. if trainable_backbone_layers < 0 or trainable_backbone_layers > max_value:
  151. raise ValueError(
  152. f"Trainable backbone layers should be in the range [0,{max_value}], got {trainable_backbone_layers} "
  153. )
  154. return trainable_backbone_layers
  155. @handle_legacy_interface(
  156. weights=(
  157. "pretrained",
  158. lambda kwargs: _get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
  159. ),
  160. )
  161. def mobilenet_backbone(
  162. *,
  163. backbone_name: str,
  164. weights: Optional[WeightsEnum],
  165. fpn: bool,
  166. norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
  167. trainable_layers: int = 2,
  168. returned_layers: Optional[List[int]] = None,
  169. extra_blocks: Optional[ExtraFPNBlock] = None,
  170. ) -> nn.Module:
  171. backbone = mobilenet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
  172. return _mobilenet_extractor(backbone, fpn, trainable_layers, returned_layers, extra_blocks)
  173. def _mobilenet_extractor(
  174. backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3],
  175. fpn: bool,
  176. trainable_layers: int,
  177. returned_layers: Optional[List[int]] = None,
  178. extra_blocks: Optional[ExtraFPNBlock] = None,
  179. norm_layer: Optional[Callable[..., nn.Module]] = None,
  180. ) -> nn.Module:
  181. backbone = backbone.features
  182. # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
  183. # The first and last blocks are always included because they are the C0 (conv1) and Cn.
  184. stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
  185. num_stages = len(stage_indices)
  186. # find the index of the layer from which we won't freeze
  187. if trainable_layers < 0 or trainable_layers > num_stages:
  188. raise ValueError(f"Trainable layers should be in the range [0,{num_stages}], got {trainable_layers} ")
  189. freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
  190. for b in backbone[:freeze_before]:
  191. for parameter in b.parameters():
  192. parameter.requires_grad_(False)
  193. out_channels = 256
  194. if fpn:
  195. if extra_blocks is None:
  196. extra_blocks = LastLevelMaxPool()
  197. if returned_layers is None:
  198. returned_layers = [num_stages - 2, num_stages - 1]
  199. if min(returned_layers) < 0 or max(returned_layers) >= num_stages:
  200. raise ValueError(f"Each returned layer should be in the range [0,{num_stages - 1}], got {returned_layers} ")
  201. return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)}
  202. in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
  203. return BackboneWithFPN(
  204. backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
  205. )
  206. else:
  207. m = nn.Sequential(
  208. backbone,
  209. # depthwise linear combination of channels to reduce their size
  210. nn.Conv2d(backbone[-1].out_channels, out_channels, 1),
  211. )
  212. m.out_channels = out_channels # type: ignore[assignment]
  213. return m