from typing import Any, Callable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from torchvision.ops import MultiScaleRoIAlign

from libs.vision_libs.models import MobileNet_V3_Large_Weights, mobilenet_v3_large
from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
from libs.vision_libs.models._meta import _COCO_CATEGORIES
from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
from libs.vision_libs.models.detection._utils import overwrite_eps
from libs.vision_libs.models.detection.anchor_utils import AnchorGenerator
from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
from libs.vision_libs.models.detection.rpn import RPNHead, RegionProposalNetwork
from libs.vision_libs.models.detection.ssdlite import _mobilenet_extractor
from libs.vision_libs.models.detection.transform import GeneralizedRCNNTransform
from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
from libs.vision_libs.ops import misc as misc_nn_ops
from libs.vision_libs.transforms._presets import ObjectDetection

from .line_head import LineRCNNHeads
from .line_predictor import LineRCNNPredictor
from .roi_heads import RoIHeads
from .trainer import Trainer
from ..base.base_detection_net import BaseDetectionNet
from ..config.config_tool import read_yaml

FEATURE_DIM = 8
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

__all__ = [
    "LineNet",
    "LineNet_ResNet50_FPN_Weights",
    "LineNet_ResNet50_FPN_V2_Weights",
    "LineNet_MobileNet_V3_Large_FPN_Weights",
    "LineNet_MobileNet_V3_Large_320_FPN_Weights",
    "linenet_resnet50_fpn",
    "linenet_resnet50_fpn_v2",
    "linenet_mobilenet_v3_large_fpn",
    "linenet_mobilenet_v3_large_320_fpn",
]


def _default_anchorgen():
    # One anchor size per FPN level, three aspect ratios per spatial location.
    anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    return AnchorGenerator(anchor_sizes, aspect_ratios)
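
# Quick sanity check of the generator above (illustrative only): with one size
# per level and three aspect ratios, every spatial location on each of the
# five FPN levels receives three anchors.
#
#   >>> gen = _default_anchorgen()
#   >>> gen.num_anchors_per_location()
#   [3, 3, 3, 3, 3]
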
class LineNet(BaseDetectionNet):
    def __init__(self, cfg, **kwargs):
        cfg = read_yaml(cfg)
        backbone_name = cfg['model']['backbone']
        num_classes = cfg['model']['num_classes']

        if backbone_name == 'resnet50_fpn':
            is_trained = False
            trainable_backbone_layers = _validate_trainable_layers(is_trained, None, 5, 3)
            norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
            backbone = resnet50(weights=None, progress=True, norm_layer=norm_layer)
            backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)

        out_channels = backbone.out_channels

        # transform parameters
        min_size = 512
        max_size = 1333
        # RPN parameters
        rpn_pre_nms_top_n_train = 2000
        rpn_pre_nms_top_n_test = 1000
        rpn_post_nms_top_n_train = 2000
        rpn_post_nms_top_n_test = 1000
        rpn_nms_thresh = 0.7
        rpn_fg_iou_thresh = 0.7
        rpn_bg_iou_thresh = 0.3
        rpn_batch_size_per_image = 256
        rpn_positive_fraction = 0.5
        rpn_score_thresh = 0.0
        # Box parameters
        box_score_thresh = 0.05
        box_nms_thresh = 0.5
        box_detections_per_img = 100
        box_fg_iou_thresh = 0.5
        box_bg_iou_thresh = 0.5
        box_batch_size_per_image = 512
        box_positive_fraction = 0.25
        bbox_reg_weights = None

        line_head = LineRCNNHeads(out_channels, 5)
        line_predictor = LineRCNNPredictor(cfg)

        rpn_anchor_generator = _default_anchorgen()
        rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])
        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)

        rpn = RegionProposalNetwork(
            rpn_anchor_generator,
            rpn_head,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_pre_nms_top_n,
            rpn_post_nms_top_n,
            rpn_nms_thresh,
            score_thresh=rpn_score_thresh,
        )

        box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)
        resolution = box_roi_pool.output_size[0]
        representation_size = 1024
        box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
        box_predictor = BoxPredictor(representation_size, num_classes)

        roi_heads = RoIHeads(
            # Box
            box_roi_pool,
            box_head,
            box_predictor,
            line_head,
            line_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
        )

        image_mean = [0.485, 0.456, 0.406]
        image_std = [0.229, 0.224, 0.225]
        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

        super().__init__(backbone, rpn, roi_heads, transform)
        self.roi_heads = roi_heads

    # def __init__(
    #     self,
    #     backbone,
    #     num_classes=None,
    #     # transform parameters
    #     min_size=512,
    #     max_size=1333,
    #     image_mean=None,
    #     image_std=None,
    #     # RPN parameters
    #     rpn_anchor_generator=None,
    #     rpn_head=None,
    #     rpn_pre_nms_top_n_train=2000,
    #     rpn_pre_nms_top_n_test=1000,
    #     rpn_post_nms_top_n_train=2000,
    #     rpn_post_nms_top_n_test=1000,
    #     rpn_nms_thresh=0.7,
    #     rpn_fg_iou_thresh=0.7,
    #     rpn_bg_iou_thresh=0.3,
    #     rpn_batch_size_per_image=256,
    #     rpn_positive_fraction=0.5,
    #     rpn_score_thresh=0.0,
    #     # Box parameters
    #     box_roi_pool=None,
    #     box_head=None,
    #     box_predictor=None,
    #     box_score_thresh=0.05,
    #     box_nms_thresh=0.5,
    #     box_detections_per_img=100,
    #     box_fg_iou_thresh=0.5,
    #     box_bg_iou_thresh=0.5,
    #     box_batch_size_per_image=512,
    #     box_positive_fraction=0.25,
    #     bbox_reg_weights=None,
    #     # line parameters
    #     line_head=None,
    #     line_predictor=None,
    #     **kwargs,
    # ):
    #
    #     if not hasattr(backbone, "out_channels"):
    #         raise ValueError(
    #             "backbone should contain an attribute out_channels "
    #             "specifying the number of output channels (assumed to be the "
    #             "same for all the levels)"
    #         )
    #
    #     if not isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))):
    #         raise TypeError(
    #             f"rpn_anchor_generator should be of type AnchorGenerator or None instead of {type(rpn_anchor_generator)}"
    #         )
    #     if not isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))):
    #         raise TypeError(
    #             f"box_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(box_roi_pool)}"
    #         )
    #
    #     if num_classes is not None:
    #         if box_predictor is not None:
    #             raise ValueError("num_classes should be None when box_predictor is specified")
    #     else:
    #         if box_predictor is None:
    #             raise ValueError("num_classes should not be None when box_predictor is not specified")
    #
    #     out_channels = backbone.out_channels
    #
    #     if line_head is None:
    #         num_class = 5
    #         line_head = LineRCNNHeads(out_channels, num_class)
    #
    #     if line_predictor is None:
    #         line_predictor = LineRCNNPredictor()
    #
    #     if rpn_anchor_generator is None:
    #         rpn_anchor_generator = _default_anchorgen()
    #     if rpn_head is None:
    #         rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])
    #
    #     rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
    #     rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)
    #
    #     rpn = RegionProposalNetwork(
    #         rpn_anchor_generator,
    #         rpn_head,
    #         rpn_fg_iou_thresh,
    #         rpn_bg_iou_thresh,
    #         rpn_batch_size_per_image,
    #         rpn_positive_fraction,
    #         rpn_pre_nms_top_n,
    #         rpn_post_nms_top_n,
    #         rpn_nms_thresh,
    #         score_thresh=rpn_score_thresh,
    #     )
    #
    #     if box_roi_pool is None:
    #         box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)
    #
    #     if box_head is None:
    #         resolution = box_roi_pool.output_size[0]
    #         representation_size = 1024
    #         box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
    #
    #     if box_predictor is None:
    #         representation_size = 1024
    #         box_predictor = BoxPredictor(representation_size, num_classes)
    #
    #     roi_heads = RoIHeads(
    #         # Box
    #         box_roi_pool,
    #         box_head,
    #         box_predictor,
    #         line_head,
    #         line_predictor,
    #         box_fg_iou_thresh,
    #         box_bg_iou_thresh,
    #         box_batch_size_per_image,
    #         box_positive_fraction,
    #         bbox_reg_weights,
    #         box_score_thresh,
    #         box_nms_thresh,
    #         box_detections_per_img,
    #     )
    #
    #     if image_mean is None:
    #         image_mean = [0.485, 0.456, 0.406]
    #     if image_std is None:
    #         image_std = [0.229, 0.224, 0.225]
    #     transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)
    #
    #     super().__init__(backbone, rpn, roi_heads, transform)
    #
    #     self.roi_heads = roi_heads
    #     self.roi_heads.line_head = line_head
    #     self.roi_heads.line_predictor = line_predictor
    def train(self, cfg=True):
        # nn.Module.train(mode) is invoked with a bool by .train()/.eval();
        # only kick off the trainer when an actual config is passed.
        if isinstance(cfg, bool):
            return super().train(cfg)
        self.trainer = Trainer()
        self.trainer.train_cfg(model=self, cfg=cfg)
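
# Minimal usage sketch (illustrative; the config path is hypothetical -- any
# YAML providing the model.backbone and model.num_classes keys read above
# will do):
#
#   >>> model = LineNet("config/line_net.yaml")
#   >>> model.eval()
#   >>> images = [torch.rand(3, 512, 512)]
#   >>> predictions = model(images)
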
class TwoMLPHead(nn.Module):
    """
    Standard heads for FPN-based models

    Args:
        in_channels (int): number of input channels
        representation_size (int): size of the intermediate representation
    """

    def __init__(self, in_channels, representation_size):
        super().__init__()
        self.fc6 = nn.Linear(in_channels, representation_size)
        self.fc7 = nn.Linear(representation_size, representation_size)

    def forward(self, x):
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc6(x))
        x = F.relu(self.fc7(x))
        return x

class LineNetConvFCHead(nn.Sequential):
    def __init__(
        self,
        input_size: Tuple[int, int, int],
        conv_layers: List[int],
        fc_layers: List[int],
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        """
        Args:
            input_size (Tuple[int, int, int]): the input size in CHW format.
            conv_layers (list): feature dimensions of each Convolution layer
            fc_layers (list): feature dimensions of each FCN layer
            norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
        """
        in_channels, in_height, in_width = input_size

        blocks = []
        previous_channels = in_channels
        for current_channels in conv_layers:
            blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, current_channels, norm_layer=norm_layer))
            previous_channels = current_channels
        blocks.append(nn.Flatten())
        previous_channels = previous_channels * in_height * in_width
        for current_channels in fc_layers:
            blocks.append(nn.Linear(previous_channels, current_channels))
            blocks.append(nn.ReLU(inplace=True))
            previous_channels = current_channels

        super().__init__(*blocks)
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
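
# For reference, the FPN-v2 builder below instantiates this head as
# LineNetConvFCHead((out_channels, 7, 7), [256, 256, 256, 256], [1024], ...).
# A shape sketch under the typical FPN channel count of 256 (assumed here):
#
#   >>> head = LineNetConvFCHead((256, 7, 7), [256, 256, 256, 256], [1024],
#   ...                          norm_layer=nn.BatchNorm2d)
#   >>> head(torch.rand(2, 256, 7, 7)).shape
#   torch.Size([2, 1024])
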

class BoxPredictor(nn.Module):
    """
    Standard classification + bounding box regression layers
    for Fast R-CNN.

    Args:
        in_channels (int): number of input channels
        num_classes (int): number of output classes (including background)
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.cls_score = nn.Linear(in_channels, num_classes)
        self.bbox_pred = nn.Linear(in_channels, num_classes * 4)

    def forward(self, x):
        if x.dim() == 4:
            torch._assert(
                list(x.shape[2:]) == [1, 1],
                f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}",
            )
        x = x.flatten(start_dim=1)
        scores = self.cls_score(x)
        bbox_deltas = self.bbox_pred(x)
        return scores, bbox_deltas
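
# Shape sketch (illustrative): with a 1024-d representation and 91 classes,
# the predictor emits per-class scores and 4 box deltas per class.
#
#   >>> pred = BoxPredictor(1024, 91)
#   >>> scores, deltas = pred(torch.rand(8, 1024))
#   >>> scores.shape, deltas.shape
#   (torch.Size([8, 91]), torch.Size([8, 364]))
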

_COMMON_META = {
    "categories": _COCO_CATEGORIES,
    "min_size": (1, 1),
}


class LineNet_ResNet50_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 41755286,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 37.0,
                }
            },
            "_ops": 134.38,
            "_file_size": 159.743,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


class LineNet_ResNet50_FPN_V2_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 43712278,
            "recipe": "https://github.com/pytorch/vision/pull/5763",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 46.7,
                }
            },
            "_ops": 280.371,
            "_file_size": 167.104,
            "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
        },
    )
    DEFAULT = COCO_V1


class LineNet_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 19386354,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 32.8,
                }
            },
            "_ops": 4.494,
            "_file_size": 74.239,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


class LineNet_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 19386354,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 22.8,
                }
            },
            "_ops": 0.719,
            "_file_size": 74.239,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1
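
# The enums above follow the torchvision Weights API, so recipe and metric
# metadata can be inspected without downloading a checkpoint:
#
#   >>> w = LineNet_ResNet50_FPN_Weights.DEFAULT
#   >>> w.meta["_metrics"]["COCO-val2017"]["box_map"]
#   37.0
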

@register_model()
@handle_legacy_interface(
    weights=("pretrained", LineNet_ResNet50_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def linenet_resnet50_fpn(
    *,
    weights: Optional[LineNet_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
    """
    LineNet (Faster R-CNN-style) model with a ResNet-50-FPN backbone from the `Faster R-CNN: Towards Real-Time Object
    Detection with Region Proposal Networks <https://arxiv.org/abs/1506.01497>`__
    paper.

    .. betastatus:: detection module

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses for both the RPN and the R-CNN.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows, where ``N`` is the number of detections:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the predicted labels for each detection
        - scores (``Tensor[N]``): the scores of each detection

    For more details on the output, you may refer to :ref:`instance_seg_output`.

    The model is exportable to ONNX for a fixed batch size with inputs images of fixed size.

    Example::

        >>> model = linenet_resnet50_fpn(weights=LineNet_ResNet50_FPN_Weights.DEFAULT)
        >>> # For training
        >>> images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4)
        >>> boxes[:, :, 2:4] = boxes[:, :, 0:2] + boxes[:, :, 2:4]
        >>> labels = torch.randint(1, 91, (4, 11))
        >>> images = list(image for image in images)
        >>> targets = []
        >>> for i in range(len(images)):
        >>>     d = {}
        >>>     d['boxes'] = boxes[i]
        >>>     d['labels'] = labels[i]
        >>>     targets.append(d)
        >>> output = model(images, targets)
        >>> # For inference
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
        >>>
        >>> # optionally, if you want to export the model to ONNX:
        >>> torch.onnx.export(model, x, "line_net.onnx", opset_version=11)

    Args:
        weights (:class:`LineNet_ResNet50_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`LineNet_ResNet50_FPN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``LineNet`` base class. Please refer to the
            source code for more details about this class.

    .. autoclass:: LineNet_ResNet50_FPN_Weights
        :members:
    """
    weights = LineNet_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
    model = LineNet(backbone, num_classes=num_classes, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if weights == LineNet_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model

@register_model()
@handle_legacy_interface(
    weights=("pretrained", LineNet_ResNet50_FPN_V2_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def linenet_resnet50_fpn_v2(
    *,
    weights: Optional[LineNet_ResNet50_FPN_V2_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = None,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
    """
    Constructs an improved LineNet (Faster R-CNN-style) model with a ResNet-50-FPN backbone from the
    `Benchmarking Detection Transfer Learning with Vision Transformers
    <https://arxiv.org/abs/2111.11429>`__ paper.

    .. betastatus:: detection module

    It works similarly to LineNet with ResNet-50 FPN backbone. See
    :func:`linenet_resnet50_fpn` for more details.

    Args:
        weights (:class:`LineNet_ResNet50_FPN_V2_Weights`, optional): The
            pretrained weights to use. See
            :class:`LineNet_ResNet50_FPN_V2_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``LineNet`` base class. Please refer to the
            source code for more details about this class.

    .. autoclass:: LineNet_ResNet50_FPN_V2_Weights
        :members:
    """
    weights = LineNet_ResNet50_FPN_V2_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)

    backbone = resnet50(weights=weights_backbone, progress=progress)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
    rpn_anchor_generator = _default_anchorgen()
    rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
    box_head = LineNetConvFCHead(
        (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
    )
    model = LineNet(
        backbone,
        num_classes=num_classes,
        rpn_anchor_generator=rpn_anchor_generator,
        rpn_head=rpn_head,
        box_head=box_head,
        **kwargs,
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model

def _linenet_mobilenet_v3_large_fpn(
    *,
    weights: Optional[Union[LineNet_MobileNet_V3_Large_FPN_Weights, LineNet_MobileNet_V3_Large_320_FPN_Weights]],
    progress: bool,
    num_classes: Optional[int],
    weights_backbone: Optional[MobileNet_V3_Large_Weights],
    trainable_backbone_layers: Optional[int],
    **kwargs: Any,
) -> LineNet:
    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
    anchor_sizes = ((32, 64, 128, 256, 512),) * 3
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    model = LineNet(
        backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
    return model
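
# Note on the anchor configuration above (illustrative): the same five sizes
# are repeated once per feature map the extractor exposes, each at three
# aspect ratios, i.e. 15 anchors per spatial location:
#
#   >>> gen = AnchorGenerator(((32, 64, 128, 256, 512),) * 3, ((0.5, 1.0, 2.0),) * 3)
#   >>> gen.num_anchors_per_location()
#   [15, 15, 15]
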

@register_model()
@handle_legacy_interface(
    weights=("pretrained", LineNet_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def linenet_mobilenet_v3_large_320_fpn(
    *,
    weights: Optional[LineNet_MobileNet_V3_Large_320_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
    """
    Low-resolution LineNet (Faster R-CNN-style) model with a MobileNetV3-Large backbone tuned for mobile use cases.

    .. betastatus:: detection module

    It works similarly to LineNet with ResNet-50 FPN backbone. See
    :func:`linenet_resnet50_fpn` for more details.

    Example::

        >>> model = linenet_mobilenet_v3_large_320_fpn(weights=LineNet_MobileNet_V3_Large_320_FPN_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        weights (:class:`LineNet_MobileNet_V3_Large_320_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`LineNet_MobileNet_V3_Large_320_FPN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``LineNet`` base class. Please refer to the
            source code for more details about this class.

    .. autoclass:: LineNet_MobileNet_V3_Large_320_FPN_Weights
        :members:
    """
    weights = LineNet_MobileNet_V3_Large_320_FPN_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

    defaults = {
        "min_size": 320,
        "max_size": 640,
        "rpn_pre_nms_top_n_test": 150,
        "rpn_post_nms_top_n_test": 150,
        "rpn_score_thresh": 0.05,
    }

    kwargs = {**defaults, **kwargs}
    return _linenet_mobilenet_v3_large_fpn(
        weights=weights,
        progress=progress,
        num_classes=num_classes,
        weights_backbone=weights_backbone,
        trainable_backbone_layers=trainable_backbone_layers,
        **kwargs,
    )

@register_model()
@handle_legacy_interface(
    weights=("pretrained", LineNet_MobileNet_V3_Large_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def linenet_mobilenet_v3_large_fpn(
    *,
    weights: Optional[LineNet_MobileNet_V3_Large_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
    """
    Constructs a high-resolution LineNet (Faster R-CNN-style) model with a MobileNetV3-Large FPN backbone.

    .. betastatus:: detection module

    It works similarly to LineNet with ResNet-50 FPN backbone. See
    :func:`linenet_resnet50_fpn` for more details.

    Example::

        >>> model = linenet_mobilenet_v3_large_fpn(weights=LineNet_MobileNet_V3_Large_FPN_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        weights (:class:`LineNet_MobileNet_V3_Large_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`LineNet_MobileNet_V3_Large_FPN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``LineNet`` base class. Please refer to the
            source code for more details about this class.

    .. autoclass:: LineNet_MobileNet_V3_Large_FPN_Weights
        :members:
    """
    weights = LineNet_MobileNet_V3_Large_FPN_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

    defaults = {
        "rpn_score_thresh": 0.05,
    }

    kwargs = {**defaults, **kwargs}
    return _linenet_mobilenet_v3_large_fpn(
        weights=weights,
        progress=progress,
        num_classes=num_classes,
        weights_backbone=weights_backbone,
        trainable_backbone_layers=trainable_backbone_layers,
        **kwargs,
    )