# squeezenet.py (8.6 KB)
  1. from functools import partial
  2. from typing import Any, Optional
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.init as init
  6. from ..transforms._presets import ImageClassification
  7. from ..utils import _log_api_usage_once
  8. from ._api import register_model, Weights, WeightsEnum
  9. from ._meta import _IMAGENET_CATEGORIES
  10. from ._utils import _ovewrite_named_param, handle_legacy_interface
  # Public names of this module (what ``from <module> import *`` exposes).
  __all__ = [
      "SqueezeNet",
      "SqueezeNet1_0_Weights",
      "SqueezeNet1_1_Weights",
      "squeezenet1_0",
      "squeezenet1_1",
  ]
  class Fire(nn.Module):
      """Squeeze-then-expand convolution block.

      The input goes through a 1x1 "squeeze" convolution and a ReLU. The result
      feeds two parallel "expand" branches, a 1x1 convolution and a 3x3
      convolution with ``padding=1``, each followed by its own ReLU. The two
      branch outputs are concatenated along dimension 1, so the block produces
      ``expand1x1_planes + expand3x3_planes`` channels.

      Args:
          inplanes: Number of input channels of the squeeze convolution.
          squeeze_planes: Number of output channels of the squeeze convolution.
          expand1x1_planes: Number of output channels of the 1x1 expand branch.
          expand3x3_planes: Number of output channels of the 3x3 expand branch.
      """

      def __init__(self, inplanes: int, squeeze_planes: int, expand1x1_planes: int, expand3x3_planes: int) -> None:
          super().__init__()
          self.inplanes = inplanes

          # Squeeze stage: 1x1 convolution + ReLU.
          self.squeeze = nn.Conv2d(in_channels=inplanes, out_channels=squeeze_planes, kernel_size=1)
          self.squeeze_activation = nn.ReLU(inplace=True)

          # Expand stage, 1x1 branch.
          self.expand1x1 = nn.Conv2d(in_channels=squeeze_planes, out_channels=expand1x1_planes, kernel_size=1)
          self.expand1x1_activation = nn.ReLU(inplace=True)

          # Expand stage, 3x3 branch; padding=1 keeps its spatial size equal to
          # the 1x1 branch so the two outputs can be concatenated.
          self.expand3x3 = nn.Conv2d(
              in_channels=squeeze_planes, out_channels=expand3x3_planes, kernel_size=3, padding=1
          )
          self.expand3x3_activation = nn.ReLU(inplace=True)

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          """Apply squeeze, run both expand branches, and concatenate them on dim 1."""
          squeezed = self.squeeze_activation(self.squeeze(x))
          branch_1x1 = self.expand1x1_activation(self.expand1x1(squeezed))
          branch_3x3 = self.expand3x3_activation(self.expand3x3(squeezed))
          return torch.cat([branch_1x1, branch_3x3], dim=1)
  class SqueezeNet(nn.Module):
      """SqueezeNet classifier: a stack of Fire blocks followed by a convolutional head.

      ``features`` is a convolution / ReLU / max-pool stem followed by Fire
      blocks and further max-pools; its exact layout depends on ``version``.
      ``classifier`` is dropout, a 1x1 convolution with ``num_classes`` output
      channels, ReLU, and adaptive average pooling to 1x1.

      Args:
          version: Architecture variant, either ``"1_0"`` or ``"1_1"``.
          num_classes: Number of output channels of the final 1x1 convolution.
          dropout: Probability passed to the dropout layer in the classifier.

      Raises:
          ValueError: If ``version`` is neither ``"1_0"`` nor ``"1_1"``.
      """

      def __init__(self, version: str = "1_0", num_classes: int = 1000, dropout: float = 0.5) -> None:
          super().__init__()
          _log_api_usage_once(self)
          self.num_classes = num_classes

          def pool() -> nn.MaxPool2d:
              # Every down-sampling stage gets its own module instance.
              return nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

          if version == "1_0":
              feature_layers = [
                  nn.Conv2d(3, 96, kernel_size=7, stride=2),
                  nn.ReLU(inplace=True),
                  pool(),
                  Fire(96, 16, 64, 64),
                  Fire(128, 16, 64, 64),
                  Fire(128, 32, 128, 128),
                  pool(),
                  Fire(256, 32, 128, 128),
                  Fire(256, 48, 192, 192),
                  Fire(384, 48, 192, 192),
                  Fire(384, 64, 256, 256),
                  pool(),
                  Fire(512, 64, 256, 256),
              ]
          elif version == "1_1":
              feature_layers = [
                  nn.Conv2d(3, 64, kernel_size=3, stride=2),
                  nn.ReLU(inplace=True),
                  pool(),
                  Fire(64, 16, 64, 64),
                  Fire(128, 16, 64, 64),
                  pool(),
                  Fire(128, 32, 128, 128),
                  Fire(256, 32, 128, 128),
                  pool(),
                  Fire(256, 48, 192, 192),
                  Fire(384, 48, 192, 192),
                  Fire(384, 64, 256, 256),
                  Fire(512, 64, 256, 256),
              ]
          else:
              # FIXME: Is this needed? SqueezeNet should only be called from the
              # FIXME: squeezenet1_x() functions
              # FIXME: This checking is not done for the other models
              raise ValueError(f"Unsupported SqueezeNet version {version}: 1_0 or 1_1 expected")
          self.features = nn.Sequential(*feature_layers)

          # The last convolution is kept in a local so the init loop below can
          # recognise it and initialise it differently from every other conv.
          final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
          self.classifier = nn.Sequential(
              nn.Dropout(p=dropout),
              final_conv,
              nn.ReLU(inplace=True),
              nn.AdaptiveAvgPool2d((1, 1)),
          )

          for module in self.modules():
              if not isinstance(module, nn.Conv2d):
                  continue
              if module is final_conv:
                  init.normal_(module.weight, mean=0.0, std=0.01)
              else:
                  init.kaiming_uniform_(module.weight)
              if module.bias is not None:
                  init.constant_(module.bias, 0)

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          """Run the feature extractor and classifier, then flatten from dim 1 onwards."""
          feature_maps = self.features(x)
          pooled_scores = self.classifier(feature_maps)
          return torch.flatten(pooled_scores, start_dim=1)
  def _squeezenet(
      version: str,
      weights: Optional[WeightsEnum],
      progress: bool,
      **kwargs: Any,
  ) -> SqueezeNet:
      """Build a :class:`SqueezeNet` and optionally load weights into it.

      Args:
          version: Variant string forwarded to :class:`SqueezeNet`.
          weights: Weights entry to load, or ``None`` to keep the fresh initialisation.
          progress: Forwarded to ``weights.get_state_dict`` as ``progress``.
          **kwargs: Extra keyword arguments forwarded to :class:`SqueezeNet`.

      Returns:
          The constructed model, with the state dict loaded when ``weights`` is given.
      """
      if weights is None:
          return SqueezeNet(version, **kwargs)

      # With weights, ``num_classes`` is handed to ``_ovewrite_named_param`` together
      # with the number of categories recorded in the weights' metadata.
      category_count = len(weights.meta["categories"])
      _ovewrite_named_param(kwargs, "num_classes", category_count)

      model = SqueezeNet(version, **kwargs)
      state_dict = weights.get_state_dict(progress=progress, check_hash=True)
      model.load_state_dict(state_dict)
      return model
  # Metadata entries shared by the weight definitions in this file; each one
  # unpacks this dict into its own ``meta`` mapping.
  _COMMON_META = {
      "categories": _IMAGENET_CATEGORIES,
      "recipe": "https://github.com/pytorch/vision/pull/49#issuecomment-277560717",
      "_docs": "These weights reproduce closely the results of the paper using a simple training recipe.",
  }
  class SqueezeNet1_0_Weights(WeightsEnum):
      # Weight entries accepted by ``squeezenet1_0``.
      IMAGENET1K_V1 = Weights(
          url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
          transforms=partial(ImageClassification, crop_size=224),
          meta={
              **_COMMON_META,
              # NOTE(review): presumably the smallest supported input (height, width) -- confirm.
              "min_size": (21, 21),
              "num_params": 1_248_424,
              "_metrics": {"ImageNet-1K": {"acc@1": 58.092, "acc@5": 80.420}},
              # NOTE(review): units of the next two values are not stated here -- confirm.
              "_ops": 0.819,
              "_file_size": 4.778,
          },
      )
      # DEFAULT is bound to the same entry as IMAGENET1K_V1.
      DEFAULT = IMAGENET1K_V1
  class SqueezeNet1_1_Weights(WeightsEnum):
      # Weight entries accepted by ``squeezenet1_1``.
      IMAGENET1K_V1 = Weights(
          url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
          transforms=partial(ImageClassification, crop_size=224),
          meta={
              **_COMMON_META,
              # NOTE(review): presumably the smallest supported input (height, width) -- confirm.
              "min_size": (17, 17),
              "num_params": 1_235_496,
              "_metrics": {"ImageNet-1K": {"acc@1": 58.178, "acc@5": 80.624}},
              # NOTE(review): units of the next two values are not stated here -- confirm.
              "_ops": 0.349,
              "_file_size": 4.729,
          },
      )
      # DEFAULT is bound to the same entry as IMAGENET1K_V1.
      DEFAULT = IMAGENET1K_V1
  @register_model()
  @handle_legacy_interface(weights=("pretrained", SqueezeNet1_0_Weights.IMAGENET1K_V1))
  def squeezenet1_0(
      *, weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any
  ) -> SqueezeNet:
      """SqueezeNet model architecture from the `SqueezeNet: AlexNet-level
      accuracy with 50x fewer parameters and <0.5MB model size
      <https://arxiv.org/abs/1602.07360>`_ paper.

      Args:
          weights (:class:`~torchvision.models.SqueezeNet1_0_Weights`, optional): The
              pretrained weights to use. See
              :class:`~torchvision.models.SqueezeNet1_0_Weights` below for
              more details, and possible values. By default, no pre-trained
              weights are used.
          progress (bool, optional): If True, displays a progress bar of the
              download to stderr. Default is True.
          **kwargs: parameters passed to the ``torchvision.models.squeezenet.SqueezeNet``
              base class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/squeezenet.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.SqueezeNet1_0_Weights
          :members:
      """
      # Run the argument through the enum's ``verify`` before building the "1_0" model.
      verified_weights = SqueezeNet1_0_Weights.verify(weights)
      return _squeezenet("1_0", verified_weights, progress, **kwargs)
  @register_model()
  @handle_legacy_interface(weights=("pretrained", SqueezeNet1_1_Weights.IMAGENET1K_V1))
  def squeezenet1_1(
      *, weights: Optional[SqueezeNet1_1_Weights] = None, progress: bool = True, **kwargs: Any
  ) -> SqueezeNet:
      """SqueezeNet 1.1 model from the `official SqueezeNet repo
      <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.

      SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
      than SqueezeNet 1.0, without sacrificing accuracy.

      Args:
          weights (:class:`~torchvision.models.SqueezeNet1_1_Weights`, optional): The
              pretrained weights to use. See
              :class:`~torchvision.models.SqueezeNet1_1_Weights` below for
              more details, and possible values. By default, no pre-trained
              weights are used.
          progress (bool, optional): If True, displays a progress bar of the
              download to stderr. Default is True.
          **kwargs: parameters passed to the ``torchvision.models.squeezenet.SqueezeNet``
              base class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/squeezenet.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.SqueezeNet1_1_Weights
          :members:
      """
      # Run the argument through the enum's ``verify`` before building the "1_1" model.
      verified_weights = SqueezeNet1_1_Weights.verify(weights)
      return _squeezenet("1_1", verified_weights, progress, **kwargs)
  191. return _squeezenet("1_1", weights, progress, **kwargs)