from typing import Any, Callable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from torchvision.ops import MultiScaleRoIAlign

from libs.vision_libs.ops import misc as misc_nn_ops
from libs.vision_libs.transforms._presets import ObjectDetection
from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
from libs.vision_libs.models._meta import _COCO_CATEGORIES
from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
from libs.vision_libs.models.mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights
from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
from libs.vision_libs.models.detection._utils import overwrite_eps
from libs.vision_libs.models.detection.anchor_utils import AnchorGenerator
from libs.vision_libs.models.detection.backbone_utils import _mobilenet_extractor, _resnet_fpn_extractor, _validate_trainable_layers
from libs.vision_libs.models.detection.rpn import RegionProposalNetwork, RPNHead
from libs.vision_libs.models.detection.transform import GeneralizedRCNNTransform
######## Deprecated ###########
__all__ = [
    "LineNet",
    "LineNet_ResNet50_FPN_Weights",
    "LineNet_ResNet50_FPN_V2_Weights",
    "LineNet_MobileNet_V3_Large_FPN_Weights",
    "LineNet_MobileNet_V3_Large_320_FPN_Weights",
    "linenet_resnet50_fpn",
    "fasterrcnn_resnet50_fpn_v2",
    "linenet_mobilenet_v3_large_fpn",
    "linenet_mobilenet_v3_large_320_fpn",
]

from .roi_heads import RoIHeads
from ..base.base_detection_net import BaseDetectionNet


def _default_anchorgen():
    anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    return AnchorGenerator(anchor_sizes, aspect_ratios)


class LineNet(BaseDetectionNet):
    """
    Implements Faster R-CNN.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on whether it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionaries),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses for both the RPN and the R-CNN.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores of each prediction

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
            If box_predictor is specified, num_classes should be None.
        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been
            trained on
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained on
        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        rpn_score_thresh (float): during inference, only return proposals with a classification score
            greater than rpn_score_thresh
        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
            the locations indicated by the bounding boxes
        box_head (nn.Module): module that takes the cropped feature maps as input
        box_predictor (nn.Module): module that takes the output of box_head and returns the
            classification logits and box regression deltas.
        box_score_thresh (float): during inference, only return proposals with a classification score
            greater than box_score_thresh
        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
        box_detections_per_img (int): maximum number of detections per image, for all classes.
        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
            considered as positive during training of the classification head
        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
            considered as negative during training of the classification head
        box_batch_size_per_image (int): number of proposals that are sampled during training of the
            classification head
        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
            of the classification head
        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
            bounding boxes

    Example::

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models import MobileNet_V2_Weights
        >>> from torchvision.models.detection import FasterRCNN
        >>> from torchvision.models.detection.rpn import AnchorGenerator
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
        >>> # FasterRCNN needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the RPN generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
        >>>
        >>> # let's define what are the feature maps that we will
        >>> # use to perform the region of interest cropping, as well as
        >>> # the size of the crop after rescaling.
        >>> # if your backbone returns a Tensor, featmap_names is expected to
        >>> # be ['0']. More generally, the backbone should return an
        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
        >>> # feature maps to use.
        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
        >>>                                                 output_size=7,
        >>>                                                 sampling_ratio=2)
        >>>
        >>> # put the pieces together inside a FasterRCNN model
        >>> model = FasterRCNN(backbone,
        >>>                    num_classes=2,
        >>>                    rpn_anchor_generator=anchor_generator,
        >>>                    box_roi_pool=roi_pooler)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    def __init__(
        self,
        backbone,
        num_classes=None,
        # transform parameters
        min_size=512,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # RPN parameters
        rpn_anchor_generator=None,
        rpn_head=None,
        rpn_pre_nms_top_n_train=2000,
        rpn_pre_nms_top_n_test=1000,
        rpn_post_nms_top_n_train=2000,
        rpn_post_nms_top_n_test=1000,
        rpn_nms_thresh=0.7,
        rpn_fg_iou_thresh=0.7,
        rpn_bg_iou_thresh=0.3,
        rpn_batch_size_per_image=256,
        rpn_positive_fraction=0.5,
        rpn_score_thresh=0.0,
        # Box parameters
        box_roi_pool=None,
        box_head=None,
        box_predictor=None,
        box_score_thresh=0.05,
        box_nms_thresh=0.5,
        box_detections_per_img=100,
        box_fg_iou_thresh=0.5,
        box_bg_iou_thresh=0.5,
        box_batch_size_per_image=512,
        box_positive_fraction=0.25,
        bbox_reg_weights=None,
        **kwargs,
    ):
        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )

        if not isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))):
            raise TypeError(
                f"rpn_anchor_generator should be of type AnchorGenerator or None instead of {type(rpn_anchor_generator)}"
            )
        if not isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))):
            raise TypeError(
                f"box_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(box_roi_pool)}"
            )

        if num_classes is not None:
            if box_predictor is not None:
                raise ValueError("num_classes should be None when box_predictor is specified")
        else:
            if box_predictor is None:
                raise ValueError("num_classes should not be None when box_predictor is not specified")

        out_channels = backbone.out_channels

        if rpn_anchor_generator is None:
            rpn_anchor_generator = _default_anchorgen()
        if rpn_head is None:
            rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])

        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)

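        # Assemble the RPN: it scores and refines every anchor, applies NMS, and
        # keeps the configured number of proposals per image for the box head.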
        rpn = RegionProposalNetwork(
            rpn_anchor_generator,
            rpn_head,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_pre_nms_top_n,
            rpn_post_nms_top_n,
            rpn_nms_thresh,
            score_thresh=rpn_score_thresh,
        )

        if box_roi_pool is None:
            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)

        if box_head is None:
            resolution = box_roi_pool.output_size[0]
            representation_size = 1024
            box_head = TwoMLPHead(out_channels * resolution**2, representation_size)

        if box_predictor is None:
            representation_size = 1024
            box_predictor = BoxPredictor(representation_size, num_classes)

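        # Assemble the RoI heads: pool features for each proposal, run the box
        # head and predictor, then filter detections by score threshold and NMS.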
        roi_heads = RoIHeads(
            # Box
            box_roi_pool,
            box_head,
            box_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
        )

        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

        super().__init__(backbone, rpn, roi_heads, transform)


class TwoMLPHead(nn.Module):
    """
    Standard heads for FPN-based models

    Args:
        in_channels (int): number of input channels
        representation_size (int): size of the intermediate representation
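
    Example (a minimal shape check; ``256 * 7 * 7`` assumes the default FPN
    channel count and the 7x7 RoI pooling output wired up in ``LineNet``)::

        >>> import torch
        >>> head = TwoMLPHead(256 * 7 * 7, 1024)
        >>> pooled = torch.rand(8, 256, 7, 7)  # features for 8 pooled proposals
        >>> head(pooled).shape
        torch.Size([8, 1024])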
  258. """
  259. def __init__(self, in_channels, representation_size):
  260. super().__init__()
  261. self.fc6 = nn.Linear(in_channels, representation_size)
  262. self.fc7 = nn.Linear(representation_size, representation_size)
  263. def forward(self, x):
  264. x = x.flatten(start_dim=1)
  265. x = F.relu(self.fc6(x))
  266. x = F.relu(self.fc7(x))
  267. return x


class LineNetConvFCHead(nn.Sequential):
    def __init__(
        self,
        input_size: Tuple[int, int, int],
        conv_layers: List[int],
        fc_layers: List[int],
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        """
        Args:
            input_size (Tuple[int, int, int]): the input size in CHW format.
            conv_layers (list): feature dimensions of each Convolution layer
            fc_layers (list): feature dimensions of each FCN layer
            norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
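
        Example (an illustrative sketch mirroring the configuration used by
        ``fasterrcnn_resnet50_fpn_v2`` below; the 256-channel 7x7 input is an
        assumption about the RoI pooling output)::

            >>> import torch
            >>> from torch import nn
            >>> head = LineNetConvFCHead((256, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d)
            >>> head(torch.rand(8, 256, 7, 7)).shape
            torch.Size([8, 1024])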
  282. """
  283. in_channels, in_height, in_width = input_size
  284. blocks = []
  285. previous_channels = in_channels
  286. for current_channels in conv_layers:
  287. blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, current_channels, norm_layer=norm_layer))
  288. previous_channels = current_channels
  289. blocks.append(nn.Flatten())
  290. previous_channels = previous_channels * in_height * in_width
  291. for current_channels in fc_layers:
  292. blocks.append(nn.Linear(previous_channels, current_channels))
  293. blocks.append(nn.ReLU(inplace=True))
  294. previous_channels = current_channels
  295. super().__init__(*blocks)
  296. for layer in self.modules():
  297. if isinstance(layer, nn.Conv2d):
  298. nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
  299. if layer.bias is not None:
  300. nn.init.zeros_(layer.bias)


class BoxPredictor(nn.Module):
    """
    Standard classification + bounding box regression layers
    for Fast R-CNN.

    Args:
        in_channels (int): number of input channels
        num_classes (int): number of output classes (including background)
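
    Example (a minimal sketch; 1024 matches the representation size produced by
    ``TwoMLPHead`` above, and 91 is the COCO class count including background)::

        >>> import torch
        >>> predictor = BoxPredictor(1024, num_classes=91)
        >>> scores, bbox_deltas = predictor(torch.rand(8, 1024))
        >>> scores.shape, bbox_deltas.shape
        (torch.Size([8, 91]), torch.Size([8, 364]))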
  308. """
  309. def __init__(self, in_channels, num_classes):
  310. super().__init__()
  311. self.cls_score = nn.Linear(in_channels, num_classes)
  312. self.bbox_pred = nn.Linear(in_channels, num_classes * 4)
  313. def forward(self, x):
  314. if x.dim() == 4:
  315. torch._assert(
  316. list(x.shape[2:]) == [1, 1],
  317. f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}",
  318. )
  319. x = x.flatten(start_dim=1)
  320. scores = self.cls_score(x)
  321. bbox_deltas = self.bbox_pred(x)
  322. return scores, bbox_deltas


_COMMON_META = {
    "categories": _COCO_CATEGORIES,
    "min_size": (1, 1),
}


class LineNet_ResNet50_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 41755286,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 37.0,
                }
            },
            "_ops": 134.38,
            "_file_size": 159.743,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


class LineNet_ResNet50_FPN_V2_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 43712278,
            "recipe": "https://github.com/pytorch/vision/pull/5763",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 46.7,
                }
            },
            "_ops": 280.371,
            "_file_size": 167.104,
            "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
        },
    )
    DEFAULT = COCO_V1


class LineNet_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 19386354,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 32.8,
                }
            },
            "_ops": 4.494,
            "_file_size": 74.239,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


class LineNet_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 19386354,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 22.8,
                }
            },
            "_ops": 0.719,
            "_file_size": 74.239,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


@register_model()
@handle_legacy_interface(
    weights=("pretrained", LineNet_ResNet50_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def linenet_resnet50_fpn(
    *,
    weights: Optional[LineNet_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
    """
    Faster R-CNN model with a ResNet-50-FPN backbone from the `Faster R-CNN: Towards Real-Time Object
    Detection with Region Proposal Networks <https://arxiv.org/abs/1506.01497>`__
    paper.

    .. betastatus:: detection module

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on whether it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionaries),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses for both the RPN and the R-CNN.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows, where ``N`` is the number of detections:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the predicted labels for each detection
        - scores (``Tensor[N]``): the scores of each detection

    For more details on the output, you may refer to :ref:`instance_seg_output`.

    Faster R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.

    Example::

        >>> model = linenet_resnet50_fpn(weights=LineNet_ResNet50_FPN_Weights.DEFAULT)
        >>> # For training
        >>> images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4)
        >>> boxes[:, :, 2:4] = boxes[:, :, 0:2] + boxes[:, :, 2:4]
        >>> labels = torch.randint(1, 91, (4, 11))
        >>> images = list(image for image in images)
        >>> targets = []
        >>> for i in range(len(images)):
        >>>     d = {}
        >>>     d['boxes'] = boxes[i]
        >>>     d['labels'] = labels[i]
        >>>     targets.append(d)
        >>> output = model(images, targets)
        >>> # For inference
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
        >>>
        >>> # optionally, if you want to export the model to ONNX:
        >>> torch.onnx.export(model, x, "faster_rcnn.onnx", opset_version=11)

    Args:
        weights (:class:`LineNet_ResNet50_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`LineNet_ResNet50_FPN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``LineNet`` base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
            for more details about this class.

    .. autoclass:: LineNet_ResNet50_FPN_Weights
        :members:
    """
    weights = LineNet_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
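    # Freeze BatchNorm statistics when starting from pretrained weights.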
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
    model = LineNet(backbone, num_classes=num_classes, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if weights == LineNet_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model


@register_model()
@handle_legacy_interface(
    weights=("pretrained", LineNet_ResNet50_FPN_V2_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def fasterrcnn_resnet50_fpn_v2(
    *,
    weights: Optional[LineNet_ResNet50_FPN_V2_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = None,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
    """
    Constructs an improved Faster R-CNN model with a ResNet-50-FPN backbone from the `Benchmarking Detection
    Transfer Learning with Vision Transformers <https://arxiv.org/abs/2111.11429>`__ paper.

    .. betastatus:: detection module

    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
    details.
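
    Example (usage mirrors :func:`linenet_resnet50_fpn` above; the weights enum
    is the one defined in this module)::

        >>> model = fasterrcnn_resnet50_fpn_v2(weights=LineNet_ResNet50_FPN_V2_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)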

    Args:
        weights (:class:`LineNet_ResNet50_FPN_V2_Weights`, optional): The
            pretrained weights to use. See
            :class:`LineNet_ResNet50_FPN_V2_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``LineNet`` base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
            for more details about this class.

    .. autoclass:: LineNet_ResNet50_FPN_V2_Weights
        :members:
    """
    weights = LineNet_ResNet50_FPN_V2_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)

    backbone = resnet50(weights=weights_backbone, progress=progress)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
    rpn_anchor_generator = _default_anchorgen()
    rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
    box_head = LineNetConvFCHead(
        (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
    )
    model = LineNet(
        backbone,
        num_classes=num_classes,
        rpn_anchor_generator=rpn_anchor_generator,
        rpn_head=rpn_head,
        box_head=box_head,
        **kwargs,
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


def _linenet_mobilenet_v3_large_fpn(
    *,
    weights: Optional[Union[LineNet_MobileNet_V3_Large_FPN_Weights, LineNet_MobileNet_V3_Large_320_FPN_Weights]],
    progress: bool,
    num_classes: Optional[int],
    weights_backbone: Optional[MobileNet_V3_Large_Weights],
    trainable_backbone_layers: Optional[int],
    **kwargs: Any,
) -> LineNet:
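    # Shared builder for the MobileNetV3-Large FPN variants: it constructs the
    # backbone (with frozen BatchNorm when pretrained weights are used), pairs
    # it with an anchor generator that reuses the same five sizes (32-512) at
    # every feature level, and loads the requested checkpoint if one is given.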
    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
    anchor_sizes = ((32, 64, 128, 256, 512),) * 3
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)

    model = LineNet(
        backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


@register_model()
@handle_legacy_interface(
    weights=("pretrained", LineNet_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def linenet_mobilenet_v3_large_320_fpn(
    *,
    weights: Optional[LineNet_MobileNet_V3_Large_320_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
  619. """
  620. Low resolution Faster R-CNN model with a MobileNetV3-Large backbone tuned for mobile use cases.
  621. .. betastatus:: detection module
  622. It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
  623. :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
  624. details.
  625. Example::
  626. >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT)
  627. >>> model.eval()
  628. >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
  629. >>> predictions = model(x)
  630. Args:
  631. weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights`, optional): The
  632. pretrained weights to use. See
  633. :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights` below for
  634. more details, and possible values. By default, no pre-trained
  635. weights are used.
  636. progress (bool, optional): If True, displays a progress bar of the
  637. download to stderr. Default is True.
  638. num_classes (int, optional): number of output classes of the model (including the background)
  639. weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
  640. pretrained weights for the backbone.
  641. trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
  642. final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
  643. trainable. If ``None`` is passed (the default) this value is set to 3.
  644. **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
  645. base class. Please refer to the `source code
  646. <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
  647. for more details about this class.
  648. .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights
  649. :members:
  650. """
    weights = LineNet_MobileNet_V3_Large_320_FPN_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

    defaults = {
        "min_size": 320,
        "max_size": 640,
        "rpn_pre_nms_top_n_test": 150,
        "rpn_post_nms_top_n_test": 150,
        "rpn_score_thresh": 0.05,
    }

    kwargs = {**defaults, **kwargs}
    return _linenet_mobilenet_v3_large_fpn(
        weights=weights,
        progress=progress,
        num_classes=num_classes,
        weights_backbone=weights_backbone,
        trainable_backbone_layers=trainable_backbone_layers,
        **kwargs,
    )


@register_model()
@handle_legacy_interface(
    weights=("pretrained", LineNet_MobileNet_V3_Large_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def linenet_mobilenet_v3_large_fpn(
    *,
    weights: Optional[LineNet_MobileNet_V3_Large_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
  683. """
  684. Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.
  685. .. betastatus:: detection module
  686. It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
  687. :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
  688. details.
  689. Example::
  690. >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT)
  691. >>> model.eval()
  692. >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
  693. >>> predictions = model(x)
  694. Args:
  695. weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights`, optional): The
  696. pretrained weights to use. See
  697. :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights` below for
  698. more details, and possible values. By default, no pre-trained
  699. weights are used.
  700. progress (bool, optional): If True, displays a progress bar of the
  701. download to stderr. Default is True.
  702. num_classes (int, optional): number of output classes of the model (including the background)
  703. weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
  704. pretrained weights for the backbone.
  705. trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
  706. final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
  707. trainable. If ``None`` is passed (the default) this value is set to 3.
  708. **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
  709. base class. Please refer to the `source code
  710. <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
  711. for more details about this class.
  712. .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights
  713. :members:
  714. """
    weights = LineNet_MobileNet_V3_Large_FPN_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

    defaults = {
        "rpn_score_thresh": 0.05,
    }

    kwargs = {**defaults, **kwargs}
    return _linenet_mobilenet_v3_large_fpn(
        weights=weights,
        progress=progress,
        num_classes=num_classes,
        weights_backbone=weights_backbone,
        trainable_backbone_layers=trainable_backbone_layers,
        **kwargs,
    )