faster_rcnn.py

from typing import Any, Callable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from torchvision.ops import MultiScaleRoIAlign

from ...ops import misc as misc_nn_ops
from ...transforms._presets import ObjectDetection
from .._api import register_model, Weights, WeightsEnum
from .._meta import _COCO_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights
from ..resnet import resnet50, ResNet50_Weights
from ._utils import overwrite_eps
from .anchor_utils import AnchorGenerator
from .backbone_utils import _mobilenet_extractor, _resnet_fpn_extractor, _validate_trainable_layers
from .generalized_rcnn import GeneralizedRCNN
from .roi_heads import RoIHeads
from .rpn import RegionProposalNetwork, RPNHead
from .transform import GeneralizedRCNNTransform

__all__ = [
    "FasterRCNN",
    "FasterRCNN_ResNet50_FPN_Weights",
    "FasterRCNN_ResNet50_FPN_V2_Weights",
    "FasterRCNN_MobileNet_V3_Large_FPN_Weights",
    "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights",
    "fasterrcnn_resnet50_fpn",
    "fasterrcnn_resnet50_fpn_v2",
    "fasterrcnn_mobilenet_v3_large_fpn",
    "fasterrcnn_mobilenet_v3_large_320_fpn",
]


def _default_anchorgen():
    anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    return AnchorGenerator(anchor_sizes, aspect_ratios)
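

# Illustrative sketch (not part of the original file): _default_anchorgen above assigns one
# anchor size to each of the five FPN levels, with three aspect ratios per location. A
# backbone that exposes a single feature map would instead use one tuple holding every
# anchor size, as in the FasterRCNN class docstring below. The helper name is hypothetical.
def _example_single_level_anchorgen() -> AnchorGenerator:
    # one feature map -> one entry in `sizes`, holding all anchor sizes for that map
    sizes = ((32, 64, 128, 256, 512),)
    aspect_ratios = ((0.5, 1.0, 2.0),)
    return AnchorGenerator(sizes, aspect_ratios)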


class FasterRCNN(GeneralizedRCNN):
    """
    Implements Faster R-CNN.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionaries),
    containing:
        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses for both the RPN and the R-CNN.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores of each prediction

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
            If box_predictor is specified, num_classes should be None.
        min_size (int): Images are rescaled before feeding them to the backbone:
            we attempt to preserve the aspect ratio and scale the shorter edge
            to ``min_size``. If the resulting longer edge exceeds ``max_size``,
            then downscale so that the longer edge does not exceed ``max_size``.
            This may result in the shorter edge being lower than ``min_size``.
        max_size (int): See ``min_size``.
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained.
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained.
        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        rpn_score_thresh (float): only return proposals with an objectness score greater than rpn_score_thresh
        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
            the locations indicated by the bounding boxes
        box_head (nn.Module): module that takes the cropped feature maps as input
        box_predictor (nn.Module): module that takes the output of box_head and returns the
            classification logits and box regression deltas.
        box_score_thresh (float): during inference, only return proposals with a classification score
            greater than box_score_thresh
        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
        box_detections_per_img (int): maximum number of detections per image, for all classes.
        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
            considered as positive during training of the classification head
        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
            considered as negative during training of the classification head
        box_batch_size_per_image (int): number of proposals that are sampled during training of the
            classification head
        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
            of the classification head
        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
            bounding boxes

    Example::

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models import MobileNet_V2_Weights
        >>> from torchvision.models.detection import FasterRCNN
        >>> from torchvision.models.detection.rpn import AnchorGenerator
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
        >>> # FasterRCNN needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the RPN generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
        >>>
        >>> # let's define which feature maps we will use to perform the
        >>> # region of interest cropping, as well as the size of the
        >>> # crop after rescaling.
        >>> # if your backbone returns a Tensor, featmap_names is expected to
        >>> # be ['0']. More generally, the backbone should return an
        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
        >>> # feature maps to use.
        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
        >>>                                                 output_size=7,
        >>>                                                 sampling_ratio=2)
        >>>
        >>> # put the pieces together inside a FasterRCNN model
        >>> model = FasterRCNN(backbone,
        >>>                    num_classes=2,
        >>>                    rpn_anchor_generator=anchor_generator,
        >>>                    box_roi_pool=roi_pooler)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    def __init__(
        self,
        backbone,
        num_classes=None,
        # transform parameters
        min_size=512,  # originally 800
        max_size=1333,
        image_mean=None,
        image_std=None,
        # RPN parameters
        rpn_anchor_generator=None,
        rpn_head=None,
        rpn_pre_nms_top_n_train=2000,
        rpn_pre_nms_top_n_test=1000,
        rpn_post_nms_top_n_train=2000,
        rpn_post_nms_top_n_test=1000,
        rpn_nms_thresh=0.7,
        rpn_fg_iou_thresh=0.7,
        rpn_bg_iou_thresh=0.3,
        rpn_batch_size_per_image=256,
        rpn_positive_fraction=0.5,
        rpn_score_thresh=0.0,
        # Box parameters
        box_roi_pool=None,
        box_head=None,
        box_predictor=None,
        box_score_thresh=0.05,
        box_nms_thresh=0.5,
        box_detections_per_img=100,
        box_fg_iou_thresh=0.5,
        box_bg_iou_thresh=0.5,
        box_batch_size_per_image=512,
        box_positive_fraction=0.25,
        bbox_reg_weights=None,
        **kwargs,
    ):
        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )

        if not isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))):
            raise TypeError(
                f"rpn_anchor_generator should be of type AnchorGenerator or None instead of {type(rpn_anchor_generator)}"
            )
        if not isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))):
            raise TypeError(
                f"box_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(box_roi_pool)}"
            )

        if num_classes is not None:
            if box_predictor is not None:
                raise ValueError("num_classes should be None when box_predictor is specified")
        else:
            if box_predictor is None:
                raise ValueError("num_classes should not be None when box_predictor is not specified")

        out_channels = backbone.out_channels

        if rpn_anchor_generator is None:
            rpn_anchor_generator = _default_anchorgen()
        if rpn_head is None:
            rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])

        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)

        rpn = RegionProposalNetwork(
            rpn_anchor_generator,
            rpn_head,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_pre_nms_top_n,
            rpn_post_nms_top_n,
            rpn_nms_thresh,
            score_thresh=rpn_score_thresh,
        )

        if box_roi_pool is None:
            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)

        if box_head is None:
            resolution = box_roi_pool.output_size[0]
            representation_size = 1024
            box_head = TwoMLPHead(out_channels * resolution**2, representation_size)

        if box_predictor is None:
            representation_size = 1024
            box_predictor = FastRCNNPredictor(representation_size, num_classes)

        roi_heads = RoIHeads(
            # Box
            box_roi_pool,
            box_head,
            box_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
        )

        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

        super().__init__(backbone, rpn, roi_heads, transform)


class TwoMLPHead(nn.Module):
    """
    Standard heads for FPN-based models

    Args:
        in_channels (int): number of input channels
        representation_size (int): size of the intermediate representation
    """

    def __init__(self, in_channels, representation_size):
        super().__init__()

        self.fc6 = nn.Linear(in_channels, representation_size)
        self.fc7 = nn.Linear(representation_size, representation_size)

    def forward(self, x):
        x = x.flatten(start_dim=1)

        x = F.relu(self.fc6(x))
        x = F.relu(self.fc7(x))

        return x
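

# Illustrative sketch (not in the original file): when TwoMLPHead is built inside
# FasterRCNN.__init__ above, its in_channels is backbone.out_channels * resolution**2,
# where resolution is the RoI pooler's output size. The helper name and the 256-channel
# FPN value below are assumptions for illustration only.
def _example_default_box_head() -> TwoMLPHead:
    out_channels = 256  # typical FPN output channels (assumption)
    resolution = 7  # MultiScaleRoIAlign output_size used in __init__ above
    return TwoMLPHead(out_channels * resolution**2, representation_size=1024)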


class FastRCNNConvFCHead(nn.Sequential):
    def __init__(
        self,
        input_size: Tuple[int, int, int],
        conv_layers: List[int],
        fc_layers: List[int],
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        """
        Args:
            input_size (Tuple[int, int, int]): the input size in CHW format.
            conv_layers (list): feature dimensions of each Convolution layer
            fc_layers (list): feature dimensions of each FCN layer
            norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
        """
        in_channels, in_height, in_width = input_size

        blocks = []
        previous_channels = in_channels
        for current_channels in conv_layers:
            blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, current_channels, norm_layer=norm_layer))
            previous_channels = current_channels
        blocks.append(nn.Flatten())
        previous_channels = previous_channels * in_height * in_width
        for current_channels in fc_layers:
            blocks.append(nn.Linear(previous_channels, current_channels))
            blocks.append(nn.ReLU(inplace=True))
            previous_channels = current_channels

        super().__init__(*blocks)
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
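

# Illustrative sketch (not in the original file): fasterrcnn_resnet50_fpn_v2 below builds
# this head with four 256-channel conv layers followed by a single 1024-wide FC layer on
# 7x7 pooled features. The helper name is hypothetical.
def _example_convfc_box_head() -> FastRCNNConvFCHead:
    return FastRCNNConvFCHead(
        (256, 7, 7),  # (channels, height, width) of the pooled RoI features
        [256, 256, 256, 256],  # conv layer widths
        [1024],  # fc layer widths
        norm_layer=nn.BatchNorm2d,
    )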


class FastRCNNPredictor(nn.Module):
    """
    Standard classification + bounding box regression layers
    for Fast R-CNN.

    Args:
        in_channels (int): number of input channels
        num_classes (int): number of output classes (including background)
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.cls_score = nn.Linear(in_channels, num_classes)
        self.bbox_pred = nn.Linear(in_channels, num_classes * 4)

    def forward(self, x):
        if x.dim() == 4:
            torch._assert(
                list(x.shape[2:]) == [1, 1],
                f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}",
            )
        x = x.flatten(start_dim=1)
        scores = self.cls_score(x)
        bbox_deltas = self.bbox_pred(x)

        return scores, bbox_deltas
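

# Illustrative sketch (not in the original file): a common fine-tuning pattern is to
# replace a pretrained model's box predictor with a fresh FastRCNNPredictor sized for a
# new dataset. The helper name and `new_num_classes` are hypothetical.
def _example_replace_box_predictor(model: FasterRCNN, new_num_classes: int) -> FasterRCNN:
    # input feature size of the existing classification layer
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # swap in a predictor for the new class count (background included)
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, new_num_classes)
    return model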


_COMMON_META = {
    "categories": _COCO_CATEGORIES,
    "min_size": (1, 1),
}


class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 41755286,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 37.0,
                }
            },
            "_ops": 134.38,
            "_file_size": 159.743,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


class FasterRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 43712278,
            "recipe": "https://github.com/pytorch/vision/pull/5763",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 46.7,
                }
            },
            "_ops": 280.371,
            "_file_size": 167.104,
            "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
        },
    )
    DEFAULT = COCO_V1


class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 19386354,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 32.8,
                }
            },
            "_ops": 4.494,
            "_file_size": 74.239,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 19386354,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 22.8,
                }
            },
            "_ops": 0.719,
            "_file_size": 74.239,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


@register_model()
@handle_legacy_interface(
    weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def fasterrcnn_resnet50_fpn(
    *,
    weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> FasterRCNN:
  419. """
  420. Faster R-CNN model with a ResNet-50-FPN backbone from the `Faster R-CNN: Towards Real-Time Object
  421. Detection with Region Proposal Networks <https://arxiv.org/abs/1506.01497>`__
  422. paper.
  423. .. betastatus:: detection module
  424. The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
  425. image, and should be in ``0-1`` range. Different images can have different sizes.
  426. The behavior of the model changes depending on if it is in training or evaluation mode.
  427. During training, the model expects both the input tensors and a targets (list of dictionary),
  428. containing:
  429. - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
  430. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  431. - labels (``Int64Tensor[N]``): the class label for each ground-truth box
  432. The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
  433. losses for both the RPN and the R-CNN.
  434. During inference, the model requires only the input tensors, and returns the post-processed
  435. predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
  436. follows, where ``N`` is the number of detections:
  437. - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
  438. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  439. - labels (``Int64Tensor[N]``): the predicted labels for each detection
  440. - scores (``Tensor[N]``): the scores of each detection
  441. For more details on the output, you may refer to :ref:`instance_seg_output`.
  442. Faster R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.
  443. Example::
  444. >>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
  445. >>> # For training
  446. >>> images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4)
  447. >>> boxes[:, :, 2:4] = boxes[:, :, 0:2] + boxes[:, :, 2:4]
  448. >>> labels = torch.randint(1, 91, (4, 11))
  449. >>> images = list(image for image in images)
  450. >>> targets = []
  451. >>> for i in range(len(images)):
  452. >>> d = {}
  453. >>> d['boxes'] = boxes[i]
  454. >>> d['labels'] = labels[i]
  455. >>> targets.append(d)
  456. >>> output = model(images, targets)
  457. >>> # For inference
  458. >>> model.eval()
  459. >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
  460. >>> predictions = model(x)
  461. >>>
  462. >>> # optionally, if you want to export the model to ONNX:
  463. >>> torch.onnx.export(model, x, "faster_rcnn.onnx", opset_version = 11)
  464. Args:
  465. weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights`, optional): The
  466. pretrained weights to use. See
  467. :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights` below for
  468. more details, and possible values. By default, no pre-trained
  469. weights are used.
  470. progress (bool, optional): If True, displays a progress bar of the
  471. download to stderr. Default is True.
  472. num_classes (int, optional): number of output classes of the model (including the background)
  473. weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
  474. pretrained weights for the backbone.
  475. trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
  476. final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
  477. trainable. If ``None`` is passed (the default) this value is set to 3.
  478. **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
  479. base class. Please refer to the `source code
  480. <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
  481. for more details about this class.
  482. .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights
  483. :members:
  484. """
    weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
    model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model
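

# Illustrative sketch (not in the original file): building the detector above for a custom
# dataset with only an ImageNet-pretrained backbone. The 3-class count and the helper name
# are hypothetical; `min_size`/`max_size` are forwarded to FasterRCNN via **kwargs.
def _example_custom_resnet50_fpn() -> FasterRCNN:
    return fasterrcnn_resnet50_fpn(
        weights=None,  # no COCO detection weights
        weights_backbone=ResNet50_Weights.IMAGENET1K_V1,  # ImageNet backbone only
        num_classes=3,  # background + 2 object classes (assumption)
        trainable_backbone_layers=3,
        min_size=512,
        max_size=1333,
    )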


@register_model()
@handle_legacy_interface(
    weights=("pretrained", FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def fasterrcnn_resnet50_fpn_v2(
    *,
    weights: Optional[FasterRCNN_ResNet50_FPN_V2_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = None,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> FasterRCNN:
    """
    Constructs an improved Faster R-CNN model with a ResNet-50-FPN backbone from the `Benchmarking Detection
    Transfer Learning with Vision Transformers <https://arxiv.org/abs/2111.11429>`__ paper.

    .. betastatus:: detection module

    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
    details.

    Args:
        weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights
        :members:
    """
    weights = FasterRCNN_ResNet50_FPN_V2_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)

    backbone = resnet50(weights=weights_backbone, progress=progress)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
    rpn_anchor_generator = _default_anchorgen()
    rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
    box_head = FastRCNNConvFCHead(
        (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
    )
    model = FasterRCNN(
        backbone,
        num_classes=num_classes,
        rpn_anchor_generator=rpn_anchor_generator,
        rpn_head=rpn_head,
        box_head=box_head,
        **kwargs,
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


def _fasterrcnn_mobilenet_v3_large_fpn(
    *,
    weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]],
    progress: bool,
    num_classes: Optional[int],
    weights_backbone: Optional[MobileNet_V3_Large_Weights],
    trainable_backbone_layers: Optional[int],
    **kwargs: Any,
) -> FasterRCNN:
    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
    anchor_sizes = ((32, 64, 128, 256, 512),) * 3
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)

    model = FasterRCNN(
        backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


@register_model()
@handle_legacy_interface(
    weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def fasterrcnn_mobilenet_v3_large_320_fpn(
    *,
    weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> FasterRCNN:
    """
    Low resolution Faster R-CNN model with a MobileNetV3-Large backbone tuned for mobile use cases.

    .. betastatus:: detection module

    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
    details.

    Example::

        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights
        :members:
    """
    weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

    defaults = {
        "min_size": 320,
        "max_size": 640,
        "rpn_pre_nms_top_n_test": 150,
        "rpn_post_nms_top_n_test": 150,
        "rpn_score_thresh": 0.05,
    }

    kwargs = {**defaults, **kwargs}
    return _fasterrcnn_mobilenet_v3_large_fpn(
        weights=weights,
        progress=progress,
        num_classes=num_classes,
        weights_backbone=weights_backbone,
        trainable_backbone_layers=trainable_backbone_layers,
        **kwargs,
    )
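

# Illustrative sketch (not in the original file): because the builder above merges
# `defaults` before the user-supplied kwargs, any keyword passed explicitly wins over the
# low-resolution defaults. The helper name and the 480/960 values are hypothetical.
def _example_320_fpn_with_larger_inputs() -> FasterRCNN:
    return fasterrcnn_mobilenet_v3_large_320_fpn(
        weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT,
        min_size=480,  # overrides the 320 default
        max_size=960,  # overrides the 640 default
    )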


@register_model()
@handle_legacy_interface(
    weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def fasterrcnn_mobilenet_v3_large_fpn(
    *,
    weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> FasterRCNN:
    """
    Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.

    .. betastatus:: detection module

    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
    details.

    Example::

        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights
        :members:
    """
    weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

    defaults = {
        "rpn_score_thresh": 0.05,
    }

    kwargs = {**defaults, **kwargs}
    return _fasterrcnn_mobilenet_v3_large_fpn(
        weights=weights,
        progress=progress,
        num_classes=num_classes,
        weights_backbone=weights_backbone,
        trainable_backbone_layers=trainable_backbone_layers,
        **kwargs,
    )