from typing import Any, Callable, List, Optional, Tuple, Union

import torch
from torch import nn
from torchvision.ops import MultiScaleRoIAlign

from libs.vision_libs.models import MobileNet_V3_Large_Weights, mobilenet_v3_large
from libs.vision_libs.models.detection.anchor_utils import AnchorGenerator
from libs.vision_libs.models.detection.rpn import RPNHead, RegionProposalNetwork
from libs.vision_libs.models.detection.ssdlite import _mobilenet_extractor
from libs.vision_libs.models.detection.transform import GeneralizedRCNNTransform
from libs.vision_libs.ops import misc as misc_nn_ops
from libs.vision_libs.transforms._presets import ObjectDetection

from .line_head import LineRCNNHeads
from .line_predictor import LineRCNNPredictor
from .roi_heads import RoIHeads

from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES, _COCO_CATEGORIES
from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
from libs.vision_libs.models.detection._utils import overwrite_eps
from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor

from models.config.config_tool import read_yaml
import numpy as np
import torch.nn.functional as F

FEATURE_DIM = 8
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

__all__ = [
    "LineNet",
    "LineNet_ResNet50_FPN_Weights",
    "LineNet_ResNet50_FPN_V2_Weights",
    "LineNet_MobileNet_V3_Large_FPN_Weights",
    "LineNet_MobileNet_V3_Large_320_FPN_Weights",
    "linenet_resnet50_fpn",
    "linenet_resnet50_fpn_v2",
    "linenet_mobilenet_v3_large_fpn",
    "linenet_mobilenet_v3_large_320_fpn",
]

# __all__ = [
#     "LineNet",
#     "LineRCNN_ResNet50_FPN_Weights",
#     "linercnn_resnet50_fpn",
# ]

def non_maximum_suppression(a):
    # Keep only local maxima of the heatmap: a position survives if it equals the
    # maximum of its 3x3 neighbourhood; every other position is zeroed out.
    ap = F.max_pool2d(a, 3, stride=1, padding=1)
    mask = (a == ap).float().clamp(min=0.0)
    return a * mask
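
# A minimal sanity-check sketch for the heatmap NMS above (illustrative only, not part
# of the original file); the input is expected in NCHW layout:
# >>> a = torch.tensor([[[[0.1, 0.9, 0.2],
# ...                     [0.3, 0.4, 0.5],
# ...                     [0.6, 0.7, 0.8]]]])
# >>> non_maximum_suppression(a)[0, 0]
# tensor([[0.0000, 0.9000, 0.0000],
#         [0.0000, 0.0000, 0.0000],
#         [0.0000, 0.0000, 0.8000]])
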
# class Bottleneck1D(nn.Module):
#     def __init__(self, inplanes, outplanes):
#         super(Bottleneck1D, self).__init__()
#
#         planes = outplanes // 2
#         self.op = nn.Sequential(
#             nn.BatchNorm1d(inplanes),
#             nn.ReLU(inplace=True),
#             nn.Conv1d(inplanes, planes, kernel_size=1),
#             nn.BatchNorm1d(planes),
#             nn.ReLU(inplace=True),
#             nn.Conv1d(planes, planes, kernel_size=3, padding=1),
#             nn.BatchNorm1d(planes),
#             nn.ReLU(inplace=True),
#             nn.Conv1d(planes, outplanes, kernel_size=1),
#         )
#
#     def forward(self, x):
#         return x + self.op(x)

from .roi_heads import RoIHeads
from ..base.base_detection_net import BaseDetectionNet


def _default_anchorgen():
    anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    return AnchorGenerator(anchor_sizes, aspect_ratios)
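
# Quick sanity-check sketch (illustrative only; assumes AnchorGenerator behaves like
# torchvision's implementation): one anchor size and three aspect ratios per FPN level,
# i.e. 3 anchors per spatial location on each of the five feature maps.
# >>> _default_anchorgen().num_anchors_per_location()
# [3, 3, 3, 3, 3]
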
class LineNet(BaseDetectionNet):
    """
    Implements Faster R-CNN.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses for both the RPN and the R-CNN.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores of each prediction

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
            If box_predictor is specified, num_classes should be None.
        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained
            on
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained on
        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        rpn_score_thresh (float): during inference, only return proposals with a classification score
            greater than rpn_score_thresh
        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
            the locations indicated by the bounding boxes
        box_head (nn.Module): module that takes the cropped feature maps as input
        box_predictor (nn.Module): module that takes the output of box_head and returns the
            classification logits and box regression deltas.
        box_score_thresh (float): during inference, only return proposals with a classification score
            greater than box_score_thresh
        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
        box_detections_per_img (int): maximum number of detections per image, for all classes.
        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
            considered as positive during training of the classification head
        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
            considered as negative during training of the classification head
        box_batch_size_per_image (int): number of proposals that are sampled during training of the
            classification head
        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
            of the classification head
        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
            bounding boxes
    Example::

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models import MobileNet_V2_Weights
        >>> from torchvision.models.detection import FasterRCNN
        >>> from torchvision.models.detection.rpn import AnchorGenerator
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
        >>> # FasterRCNN needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the RPN generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
        >>>
        >>> # let's define what are the feature maps that we will
        >>> # use to perform the region of interest cropping, as well as
        >>> # the size of the crop after rescaling.
        >>> # if your backbone returns a Tensor, featmap_names is expected to
        >>> # be ['0']. More generally, the backbone should return an
        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
        >>> # feature maps to use.
        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
        >>>                                                 output_size=7,
        >>>                                                 sampling_ratio=2)
        >>>
        >>> # put the pieces together inside a FasterRCNN model
        >>> model = FasterRCNN(backbone,
        >>>                    num_classes=2,
        >>>                    rpn_anchor_generator=anchor_generator,
        >>>                    box_roi_pool=roi_pooler)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """
    def __init__(
        self,
        backbone,
        num_classes=None,
        # transform parameters
        min_size=512,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # RPN parameters
        rpn_anchor_generator=None,
        rpn_head=None,
        rpn_pre_nms_top_n_train=2000,
        rpn_pre_nms_top_n_test=1000,
        rpn_post_nms_top_n_train=2000,
        rpn_post_nms_top_n_test=1000,
        rpn_nms_thresh=0.7,
        rpn_fg_iou_thresh=0.7,
        rpn_bg_iou_thresh=0.3,
        rpn_batch_size_per_image=256,
        rpn_positive_fraction=0.5,
        rpn_score_thresh=0.0,
        # Box parameters
        box_roi_pool=None,
        box_head=None,
        box_predictor=None,
        box_score_thresh=0.05,
        box_nms_thresh=0.5,
        box_detections_per_img=100,
        box_fg_iou_thresh=0.5,
        box_bg_iou_thresh=0.5,
        box_batch_size_per_image=512,
        box_positive_fraction=0.25,
        bbox_reg_weights=None,
        # line parameters
        line_head=None,
        line_predictor=None,
        **kwargs,
    ):
        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )
        if not isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))):
            raise TypeError(
                f"rpn_anchor_generator should be of type AnchorGenerator or None instead of {type(rpn_anchor_generator)}"
            )
        if not isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))):
            raise TypeError(
                f"box_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(box_roi_pool)}"
            )

        if num_classes is not None:
            if box_predictor is not None:
                raise ValueError("num_classes should be None when box_predictor is specified")
        else:
            if box_predictor is None:
                raise ValueError("num_classes should not be None when box_predictor is not specified")
        out_channels = backbone.out_channels

        if line_head is None:
            num_class = 5
            line_head = LineRCNNHeads(out_channels, num_class)

        if line_predictor is None:
            line_predictor = LineRCNNPredictor()

        if rpn_anchor_generator is None:
            rpn_anchor_generator = _default_anchorgen()
        if rpn_head is None:
            rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])

        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)

        rpn = RegionProposalNetwork(
            rpn_anchor_generator,
            rpn_head,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_pre_nms_top_n,
            rpn_post_nms_top_n,
            rpn_nms_thresh,
            score_thresh=rpn_score_thresh,
        )

        if box_roi_pool is None:
            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)

        if box_head is None:
            resolution = box_roi_pool.output_size[0]
            representation_size = 1024
            box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)

        if box_predictor is None:
            representation_size = 1024
            box_predictor = BoxPredictor(representation_size, num_classes)

        roi_heads = RoIHeads(
            # Box
            box_roi_pool,
            box_head,
            box_predictor,
            line_head,
            line_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
        )

        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

        super().__init__(backbone, rpn, roi_heads, transform)

        self.roi_heads = roi_heads
        # self.roi_heads.line_head = line_head
        # self.roi_heads.line_predictor = line_predictor
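
# A minimal construction sketch for LineNet itself (illustrative only; the backbone choice
# and its out_channels value below are assumptions, mirroring the FasterRCNN example in the
# class docstring rather than the project's training setup):
# >>> backbone = mobilenet_v3_large(weights=None).features
# >>> backbone.out_channels = 960
# >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
# ...                                    aspect_ratios=((0.5, 1.0, 2.0),))
# >>> roi_pooler = MultiScaleRoIAlign(featmap_names=["0"], output_size=7, sampling_ratio=2)
# >>> model = LineNet(backbone, num_classes=2,
# ...                 rpn_anchor_generator=anchor_generator,
# ...                 box_roi_pool=roi_pooler)
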
class TwoMLPHead(nn.Module):
    """
    Standard heads for FPN-based models

    Args:
        in_channels (int): number of input channels
        representation_size (int): size of the intermediate representation
    """

    def __init__(self, in_channels, representation_size):
        super().__init__()

        self.fc6 = nn.Linear(in_channels, representation_size)
        self.fc7 = nn.Linear(representation_size, representation_size)

    def forward(self, x):
        x = x.flatten(start_dim=1)

        x = F.relu(self.fc6(x))
        x = F.relu(self.fc7(x))

        return x
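
# Shape sketch (illustrative only, not from the original file): with the default 7x7 RoIAlign
# output and a 256-channel FPN, each pooled RoI of shape [256, 7, 7] is flattened to 12544
# features and mapped to a 1024-d representation.
# >>> head = TwoMLPHead(256 * 7 * 7, 1024)
# >>> head(torch.rand(8, 256, 7, 7)).shape
# torch.Size([8, 1024])
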
class LineNetConvFCHead(nn.Sequential):
    def __init__(
        self,
        input_size: Tuple[int, int, int],
        conv_layers: List[int],
        fc_layers: List[int],
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        """
        Args:
            input_size (Tuple[int, int, int]): the input size in CHW format.
            conv_layers (list): feature dimensions of each Convolution layer
            fc_layers (list): feature dimensions of each FCN layer
            norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
        """
        in_channels, in_height, in_width = input_size

        blocks = []
        previous_channels = in_channels
        for current_channels in conv_layers:
            blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, current_channels, norm_layer=norm_layer))
            previous_channels = current_channels
        blocks.append(nn.Flatten())
        previous_channels = previous_channels * in_height * in_width
        for current_channels in fc_layers:
            blocks.append(nn.Linear(previous_channels, current_channels))
            blocks.append(nn.ReLU(inplace=True))
            previous_channels = current_channels

        super().__init__(*blocks)
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

class BoxPredictor(nn.Module):
    """
    Standard classification + bounding box regression layers
    for Fast R-CNN.

    Args:
        in_channels (int): number of input channels
        num_classes (int): number of output classes (including background)
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.cls_score = nn.Linear(in_channels, num_classes)
        self.bbox_pred = nn.Linear(in_channels, num_classes * 4)

    def forward(self, x):
        if x.dim() == 4:
            torch._assert(
                list(x.shape[2:]) == [1, 1],
                f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}",
            )
        x = x.flatten(start_dim=1)
        scores = self.cls_score(x)
        bbox_deltas = self.bbox_pred(x)

        return scores, bbox_deltas
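
# Shape sketch (illustrative only): for a batch of 8 pooled RoI features and 5 classes
# (background included), the predictor returns per-class logits and 4 box deltas per class.
# >>> predictor = BoxPredictor(in_channels=1024, num_classes=5)
# >>> scores, deltas = predictor(torch.rand(8, 1024))
# >>> scores.shape, deltas.shape
# (torch.Size([8, 5]), torch.Size([8, 20]))
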
_COMMON_META = {
    "categories": _COCO_CATEGORIES,
    "min_size": (1, 1),
}

class LineNet_ResNet50_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 41755286,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 37.0,
                }
            },
            "_ops": 134.38,
            "_file_size": 159.743,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


class LineNet_ResNet50_FPN_V2_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 43712278,
            "recipe": "https://github.com/pytorch/vision/pull/5763",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 46.7,
                }
            },
            "_ops": 280.371,
            "_file_size": 167.104,
            "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
        },
    )
    DEFAULT = COCO_V1


class LineNet_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 19386354,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 32.8,
                }
            },
            "_ops": 4.494,
            "_file_size": 74.239,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


class LineNet_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 19386354,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 22.8,
                }
            },
            "_ops": 0.719,
            "_file_size": 74.239,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1

@register_model()
@handle_legacy_interface(
    weights=("pretrained", LineNet_ResNet50_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def linenet_resnet50_fpn(
    *,
    weights: Optional[LineNet_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
  466. """
  467. Faster R-CNN model with a ResNet-50-FPN backbone from the `Faster R-CNN: Towards Real-Time Object
  468. Detection with Region Proposal Networks <https://arxiv.org/abs/1506.01497>`__
  469. paper.
  470. .. betastatus:: detection module
  471. The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
  472. image, and should be in ``0-1`` range. Different images can have different sizes.
  473. The behavior of the model changes depending on if it is in training or evaluation mode.
  474. During training, the model expects both the input tensors and a targets (list of dictionary),
  475. containing:
  476. - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
  477. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  478. - labels (``Int64Tensor[N]``): the class label for each ground-truth box
  479. The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
  480. losses for both the RPN and the R-CNN.
  481. During inference, the model requires only the input tensors, and returns the post-processed
  482. predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
  483. follows, where ``N`` is the number of detections:
  484. - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
  485. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  486. - labels (``Int64Tensor[N]``): the predicted labels for each detection
  487. - scores (``Tensor[N]``): the scores of each detection
  488. For more details on the output, you may refer to :ref:`instance_seg_output`.
  489. Faster R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.
  490. Example::
  491. >>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
  492. >>> # For training
  493. >>> images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4)
  494. >>> boxes[:, :, 2:4] = boxes[:, :, 0:2] + boxes[:, :, 2:4]
  495. >>> labels = torch.randint(1, 91, (4, 11))
  496. >>> images = list(image for image in images)
  497. >>> targets = []
  498. >>> for i in range(len(images)):
  499. >>> d = {}
  500. >>> d['boxes'] = boxes[i]
  501. >>> d['labels'] = labels[i]
  502. >>> targets.append(d)
  503. >>> output = model(images, targets)
  504. >>> # For inference
  505. >>> model.eval()
  506. >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
  507. >>> predictions = model(x)
  508. >>>
  509. >>> # optionally, if you want to export the model to ONNX:
  510. >>> torch.onnx.export(model, x, "faster_rcnn.onnx", opset_version = 11)
  511. Args:
  512. weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights`, optional): The
  513. pretrained weights to use. See
  514. :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights` below for
  515. more details, and possible values. By default, no pre-trained
  516. weights are used.
  517. progress (bool, optional): If True, displays a progress bar of the
  518. download to stderr. Default is True.
  519. num_classes (int, optional): number of output classes of the model (including the background)
  520. weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
  521. pretrained weights for the backbone.
  522. trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
  523. final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
  524. trainable. If ``None`` is passed (the default) this value is set to 3.
  525. **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
  526. base class. Please refer to the `source code
  527. <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
  528. for more details about this class.
  529. .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights
  530. :members:
  531. """
    weights = LineNet_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
    model = LineNet(backbone, num_classes=num_classes, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if weights == LineNet_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model
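
# Usage sketch for the factory above (illustrative only; the class count is an assumption,
# since the line-detection task defines its own label set):
# >>> model = linenet_resnet50_fpn(weights=None, num_classes=5)
# >>> model.eval()
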
@register_model()
@handle_legacy_interface(
    weights=("pretrained", LineNet_ResNet50_FPN_V2_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def linenet_resnet50_fpn_v2(
    *,
    weights: Optional[LineNet_ResNet50_FPN_V2_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = None,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
  564. """
  565. Constructs an improved Faster R-CNN model with a ResNet-50-FPN backbone from `Benchmarking Detection
  566. Transfer Learning with Vision Transformers <https://arxiv.org/abs/2111.11429>`__ paper.
  567. .. betastatus:: detection module
  568. It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
  569. :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
  570. details.
  571. Args:
  572. weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights`, optional): The
  573. pretrained weights to use. See
  574. :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights` below for
  575. more details, and possible values. By default, no pre-trained
  576. weights are used.
  577. progress (bool, optional): If True, displays a progress bar of the
  578. download to stderr. Default is True.
  579. num_classes (int, optional): number of output classes of the model (including the background)
  580. weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
  581. pretrained weights for the backbone.
  582. trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
  583. final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
  584. trainable. If ``None`` is passed (the default) this value is set to 3.
  585. **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
  586. base class. Please refer to the `source code
  587. <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
  588. for more details about this class.
  589. .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights
  590. :members:
  591. """
    weights = LineNet_ResNet50_FPN_V2_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)

    backbone = resnet50(weights=weights_backbone, progress=progress)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
    rpn_anchor_generator = _default_anchorgen()
    rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
    box_head = LineNetConvFCHead(
        (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
    )
    model = LineNet(
        backbone,
        num_classes=num_classes,
        rpn_anchor_generator=rpn_anchor_generator,
        rpn_head=rpn_head,
        box_head=box_head,
        **kwargs,
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model

def _linenet_mobilenet_v3_large_fpn(
    *,
    weights: Optional[Union[LineNet_MobileNet_V3_Large_FPN_Weights, LineNet_MobileNet_V3_Large_320_FPN_Weights]],
    progress: bool,
    num_classes: Optional[int],
    weights_backbone: Optional[MobileNet_V3_Large_Weights],
    trainable_backbone_layers: Optional[int],
    **kwargs: Any,
) -> LineNet:
    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
    anchor_sizes = ((32, 64, 128, 256, 512,),) * 3
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    model = LineNet(
        backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
    return model
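
# Note on the helper above (descriptive, not from the original file): the MobileNetV3 FPN
# extractor used here, mirroring torchvision's, exposes three feature maps, which is why the
# anchor sizes are replicated three times; every level then gets the full
# (32, 64, 128, 256, 512) x (0.5, 1.0, 2.0) grid, i.e. 15 anchors per spatial location.
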
@register_model()
@handle_legacy_interface(
    weights=("pretrained", LineNet_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def linenet_mobilenet_v3_large_320_fpn(
    *,
    weights: Optional[LineNet_MobileNet_V3_Large_320_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
  668. """
  669. Low resolution Faster R-CNN model with a MobileNetV3-Large backbone tuned for mobile use cases.
  670. .. betastatus:: detection module
  671. It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
  672. :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
  673. details.
  674. Example::
  675. >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT)
  676. >>> model.eval()
  677. >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
  678. >>> predictions = model(x)
  679. Args:
  680. weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights`, optional): The
  681. pretrained weights to use. See
  682. :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights` below for
  683. more details, and possible values. By default, no pre-trained
  684. weights are used.
  685. progress (bool, optional): If True, displays a progress bar of the
  686. download to stderr. Default is True.
  687. num_classes (int, optional): number of output classes of the model (including the background)
  688. weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
  689. pretrained weights for the backbone.
  690. trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
  691. final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
  692. trainable. If ``None`` is passed (the default) this value is set to 3.
  693. **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
  694. base class. Please refer to the `source code
  695. <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
  696. for more details about this class.
  697. .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights
  698. :members:
  699. """
    weights = LineNet_MobileNet_V3_Large_320_FPN_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

    defaults = {
        "min_size": 320,
        "max_size": 640,
        "rpn_pre_nms_top_n_test": 150,
        "rpn_post_nms_top_n_test": 150,
        "rpn_score_thresh": 0.05,
    }

    kwargs = {**defaults, **kwargs}
    return _linenet_mobilenet_v3_large_fpn(
        weights=weights,
        progress=progress,
        num_classes=num_classes,
        weights_backbone=weights_backbone,
        trainable_backbone_layers=trainable_backbone_layers,
        **kwargs,
    )

@register_model()
@handle_legacy_interface(
    weights=("pretrained", LineNet_MobileNet_V3_Large_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def linenet_mobilenet_v3_large_fpn(
    *,
    weights: Optional[LineNet_MobileNet_V3_Large_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
  732. """
  733. Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.
  734. .. betastatus:: detection module
  735. It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
  736. :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
  737. details.
  738. Example::
  739. >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT)
  740. >>> model.eval()
  741. >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
  742. >>> predictions = model(x)
  743. Args:
  744. weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights`, optional): The
  745. pretrained weights to use. See
  746. :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights` below for
  747. more details, and possible values. By default, no pre-trained
  748. weights are used.
  749. progress (bool, optional): If True, displays a progress bar of the
  750. download to stderr. Default is True.
  751. num_classes (int, optional): number of output classes of the model (including the background)
  752. weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
  753. pretrained weights for the backbone.
  754. trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
  755. final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
  756. trainable. If ``None`` is passed (the default) this value is set to 3.
  757. **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
  758. base class. Please refer to the `source code
  759. <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
  760. for more details about this class.
  761. .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights
  762. :members:
  763. """
    weights = LineNet_MobileNet_V3_Large_FPN_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

    defaults = {
        "rpn_score_thresh": 0.05,
    }

    kwargs = {**defaults, **kwargs}
    return _linenet_mobilenet_v3_large_fpn(
        weights=weights,
        progress=progress,
        num_classes=num_classes,
        weights_backbone=weights_backbone,
        trainable_backbone_layers=trainable_backbone_layers,
        **kwargs,
    )
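
# Usage sketch contrasting the two MobileNetV3 variants above (illustrative only; the
# num_classes value is an assumption): the 320 variant trades accuracy for speed by resizing
# inputs to 320-640 pixels and keeping only 150 RPN proposals at test time.
# >>> fast = linenet_mobilenet_v3_large_320_fpn(weights=None, num_classes=5).eval()
# >>> accurate = linenet_mobilenet_v3_large_fpn(weights=None, num_classes=5).eval()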