# line_net.py

import os
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
from torch import nn
from torchvision.ops import MultiScaleRoIAlign

from libs.vision_libs.models import MobileNet_V3_Large_Weights, mobilenet_v3_large
from libs.vision_libs.models.detection.anchor_utils import AnchorGenerator
from libs.vision_libs.models.detection.rpn import RPNHead, RegionProposalNetwork
from libs.vision_libs.models.detection.ssdlite import _mobilenet_extractor
from libs.vision_libs.models.detection.transform import GeneralizedRCNNTransform
from libs.vision_libs.ops import misc as misc_nn_ops
from libs.vision_libs.transforms._presets import ObjectDetection
from .line_head import LineRCNNHeads
from .line_predictor import LineRCNNPredictor
from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES, _COCO_CATEGORIES
from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights, ResNet18_Weights, resnet18
from libs.vision_libs.models.detection._utils import overwrite_eps
from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
from .roi_heads import RoIHeads
from .trainer import Trainer
from ..base import backbone_factory
from ..base.base_detection_net import BaseDetectionNet
import torch.nn.functional as F
from .predict import Predict1, Predict
from ..config.config_tool import read_yaml

FEATURE_DIM = 8
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

__all__ = [
    "LineNet",
    "LineNet_ResNet50_FPN_Weights",
    "LineNet_ResNet50_FPN_V2_Weights",
    "LineNet_MobileNet_V3_Large_FPN_Weights",
    "LineNet_MobileNet_V3_Large_320_FPN_Weights",
    "linenet_resnet18_fpn",
    "linenet_resnet50_fpn",
    "linenet_resnet50_fpn_v2",
    "linenet_mobilenet_v3_large_fpn",
    "linenet_mobilenet_v3_large_320_fpn",
]


def _default_anchorgen():
    anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    return AnchorGenerator(anchor_sizes, aspect_ratios)
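
# Quick sanity check for the generator above (illustrative sketch, not part of
# the original module): one size tuple per FPN level and three aspect ratios
# yield three anchors per spatial location on each of the five levels.
#
#     gen = _default_anchorgen()
#     assert gen.num_anchors_per_location() == [3, 3, 3, 3, 3]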


class LineNet(BaseDetectionNet):
    # Legacy config-driven constructor, kept for reference:
    # def __init__(self, cfg, **kwargs):
    #     cfg = read_yaml(cfg)
    #     self.cfg = cfg
    #     backbone = cfg['backbone']
    #     print(f'LineNet backbone: {backbone}')
    #     num_classes = cfg['num_classes']
    #
    #     if backbone == 'resnet50_fpn':
    #         backbone = backbone_factory.get_resnet50_fpn()
    #         print(f'out_channels: {backbone.out_channels}')
    #     elif backbone == 'mobilenet_v3_large_fpn':
    #         backbone = backbone_factory.get_mobilenet_v3_large_fpn()
    #     elif backbone == 'resnet18_fpn':
    #         backbone = backbone_factory.get_resnet18_fpn()
    #
    #     self.__construct__(backbone=backbone, num_classes=num_classes, **kwargs)

    def __init__(
        self,
        backbone,
        num_classes=None,
        # transform parameters
        min_size=512,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # RPN parameters
        rpn_anchor_generator=None,
        rpn_head=None,
        rpn_pre_nms_top_n_train=2000,
        rpn_pre_nms_top_n_test=1000,
        rpn_post_nms_top_n_train=2000,
        rpn_post_nms_top_n_test=1000,
        rpn_nms_thresh=0.7,
        rpn_fg_iou_thresh=0.7,
        rpn_bg_iou_thresh=0.3,
        rpn_batch_size_per_image=256,
        rpn_positive_fraction=0.5,
        rpn_score_thresh=0.0,
        # Box parameters
        box_roi_pool=None,
        box_head=None,
        box_predictor=None,
        box_score_thresh=0.05,
        box_nms_thresh=0.5,
        box_detections_per_img=100,
        box_fg_iou_thresh=0.5,
        box_bg_iou_thresh=0.5,
        box_batch_size_per_image=512,
        box_positive_fraction=0.25,
        bbox_reg_weights=None,
        # line parameters
        line_head=None,
        line_predictor=None,
        **kwargs,
    ):
        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )
        if not isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))):
            raise TypeError(
                f"rpn_anchor_generator should be of type AnchorGenerator or None instead of {type(rpn_anchor_generator)}"
            )
        if not isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))):
            raise TypeError(
                f"box_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(box_roi_pool)}"
            )

        if num_classes is not None:
            if box_predictor is not None:
                raise ValueError("num_classes should be None when box_predictor is specified")
        else:
            if box_predictor is None:
                raise ValueError("num_classes should not be None when box_predictor is not specified")

        out_channels = backbone.out_channels
        # cfg = read_yaml(cfg)
        # self.cfg = cfg

        if line_head is None:
            num_class = 5
            line_head = LineRCNNHeads(out_channels, num_class)
        if line_predictor is None:
            line_predictor = LineRCNNPredictor()

        if rpn_anchor_generator is None:
            rpn_anchor_generator = _default_anchorgen()
        if rpn_head is None:
            rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])

        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)

        rpn = RegionProposalNetwork(
            rpn_anchor_generator,
            rpn_head,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_pre_nms_top_n,
            rpn_post_nms_top_n,
            rpn_nms_thresh,
            score_thresh=rpn_score_thresh,
        )

        if box_roi_pool is None:
            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)
        if box_head is None:
            resolution = box_roi_pool.output_size[0]
            representation_size = 1024
            box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
        if box_predictor is None:
            representation_size = 1024
            box_predictor = BoxPredictor(representation_size, num_classes)

        roi_heads = RoIHeads(
            # Box
            box_roi_pool,
            box_head,
            box_predictor,
            line_head,
            line_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
        )

        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

        super().__init__(backbone, rpn, roi_heads, transform)
        self.roi_heads = roi_heads
        # self.roi_heads.line_head = line_head
        # self.roi_heads.line_predictor = line_predictor

    def train_by_cfg(self, cfg):
        # cfg = read_yaml(cfg)
        self.trainer = Trainer()
        self.trainer.train_cfg(model=self, cfg=cfg)

    def load_best_model(self, model, save_path, device='cuda'):
        if os.path.exists(save_path):
            checkpoint = torch.load(save_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            # if optimizer is not None:
            #     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            # epoch = checkpoint['epoch']
            # loss = checkpoint['loss']
            # print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
            print(f"Loaded model from {save_path}")
        else:
            print(f"No saved model found at {save_path}")
        return model

    # Load weights and run inference in a single call.
    def predict(self, pt_path, model, img_path, type=0, threshold=0.5, save_path=None, show=False):
        # Keep the helper on a separate attribute: assigning it to self.predict
        # would shadow this method and break any subsequent call.
        self.predictor = Predict(pt_path, model, img_path, type, threshold, save_path, show)
        self.predictor.run()

    # Run inference without loading weights.
    def predict1(self, model, img_path, type=0, threshold=0.5, save_path=None, show=False):
        self.predictor = Predict1(model, img_path, type, threshold, save_path, show)
        self.predictor.run()
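
# Illustrative usage sketch (assumes backbone_factory exposes get_resnet50_fpn(),
# as the commented-out legacy constructor above suggests):
#
#     backbone = backbone_factory.get_resnet50_fpn()
#     model = LineNet(backbone, num_classes=5)
#     model.eval()
#     with torch.no_grad():
#         detections = model([torch.rand(3, 512, 512)])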


# Note: this local definition shadows the TwoMLPHead imported from
# faster_rcnn above; it mirrors torchvision's implementation.
class TwoMLPHead(nn.Module):
    """
    Standard heads for FPN-based models

    Args:
        in_channels (int): number of input channels
        representation_size (int): size of the intermediate representation
    """

    def __init__(self, in_channels, representation_size):
        super().__init__()
        self.fc6 = nn.Linear(in_channels, representation_size)
        self.fc7 = nn.Linear(representation_size, representation_size)

    def forward(self, x):
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc6(x))
        x = F.relu(self.fc7(x))
        return x


class LineNetConvFCHead(nn.Sequential):
    def __init__(
        self,
        input_size: Tuple[int, int, int],
        conv_layers: List[int],
        fc_layers: List[int],
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        """
        Args:
            input_size (Tuple[int, int, int]): the input size in CHW format.
            conv_layers (list): feature dimensions of each Convolution layer
            fc_layers (list): feature dimensions of each FCN layer
            norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
        """
        in_channels, in_height, in_width = input_size

        blocks = []
        previous_channels = in_channels
        for current_channels in conv_layers:
            blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, current_channels, norm_layer=norm_layer))
            previous_channels = current_channels
        blocks.append(nn.Flatten())
        previous_channels = previous_channels * in_height * in_width
        for current_channels in fc_layers:
            blocks.append(nn.Linear(previous_channels, current_channels))
            blocks.append(nn.ReLU(inplace=True))
            previous_channels = current_channels

        super().__init__(*blocks)
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)


class BoxPredictor(nn.Module):
    """
    Standard classification + bounding box regression layers
    for Fast R-CNN.

    Args:
        in_channels (int): number of input channels
        num_classes (int): number of output classes (including background)
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.cls_score = nn.Linear(in_channels, num_classes)
        self.bbox_pred = nn.Linear(in_channels, num_classes * 4)

    def forward(self, x):
        if x.dim() == 4:
            torch._assert(
                list(x.shape[2:]) == [1, 1],
                f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}",
            )
        x = x.flatten(start_dim=1)
        scores = self.cls_score(x)
        bbox_deltas = self.bbox_pred(x)
        return scores, bbox_deltas
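
# Shape sketch for BoxPredictor (illustrative only): for N proposals the head
# maps [N, in_channels] features to [N, num_classes] class scores and
# [N, num_classes * 4] box regression deltas.
#
#     head = BoxPredictor(in_channels=1024, num_classes=5)
#     scores, deltas = head(torch.rand(8, 1024))
#     assert scores.shape == (8, 5) and deltas.shape == (8, 20)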


_COMMON_META = {
    "categories": _COCO_CATEGORIES,
    "min_size": (1, 1),
}


class LineNet_ResNet50_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 41755286,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 37.0,
                }
            },
            "_ops": 134.38,
            "_file_size": 159.743,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


class LineNet_ResNet50_FPN_V2_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 43712278,
            "recipe": "https://github.com/pytorch/vision/pull/5763",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 46.7,
                }
            },
            "_ops": 280.371,
            "_file_size": 167.104,
            "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
        },
    )
    DEFAULT = COCO_V1


class LineNet_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 19386354,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 32.8,
                }
            },
            "_ops": 4.494,
            "_file_size": 74.239,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


class LineNet_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 19386354,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 22.8,
                }
            },
            "_ops": 0.719,
            "_file_size": 74.239,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1
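
# The enums above follow torchvision's Weights API; each entry carries its
# metadata (illustrative):
#
#     w = LineNet_ResNet50_FPN_Weights.DEFAULT
#     print(w.meta["_metrics"]["COCO-val2017"]["box_map"])  # 37.0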


@register_model()
@handle_legacy_interface(
    weights=("pretrained", LineNet_ResNet50_FPN_Weights.COCO_V1),
    # Use resnet18 backbone weights here: the original decorator referenced
    # ResNet50_Weights, whose state dict cannot be loaded into resnet18.
    weights_backbone=("pretrained_backbone", ResNet18_Weights.IMAGENET1K_V1),
)
def linenet_resnet18_fpn(
    *,
    weights: Optional[LineNet_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet18_Weights] = ResNet18_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
    # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
    # weights_backbone = ResNet18_Weights.verify(weights_backbone)
    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    if weights_backbone is not None:
        print('loading pretrained resnet18 backbone weights')

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet18(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
    model = LineNet(backbone, num_classes=num_classes, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if weights == LineNet_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model


@register_model()
@handle_legacy_interface(
    weights=("pretrained", LineNet_ResNet50_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def linenet_resnet50_fpn(
    *,
    weights: Optional[LineNet_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
  401. """
  402. Faster R-CNN model with a ResNet-50-FPN backbone from the `Faster R-CNN: Towards Real-Time Object
  403. Detection with Region Proposal Networks <https://arxiv.org/abs/1506.01497>`__
  404. paper.
  405. .. betastatus:: detection module
  406. The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
  407. image, and should be in ``0-1`` range. Different images can have different sizes.
  408. The behavior of the model changes depending on if it is in training or evaluation mode.
  409. During training, the model expects both the input tensors and a targets (list of dictionary),
  410. containing:
  411. - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
  412. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  413. - labels (``Int64Tensor[N]``): the class label for each ground-truth box
  414. The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
  415. losses for both the RPN and the R-CNN.
  416. During inference, the model requires only the input tensors, and returns the post-processed
  417. predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
  418. follows, where ``N`` is the number of detections:
  419. - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
  420. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
  421. - labels (``Int64Tensor[N]``): the predicted labels for each detection
  422. - scores (``Tensor[N]``): the scores of each detection
  423. For more details on the output, you may refer to :ref:`instance_seg_output`.
  424. Faster R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.
  425. Example::
  426. >>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
  427. >>> # For training
  428. >>> images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4)
  429. >>> boxes[:, :, 2:4] = boxes[:, :, 0:2] + boxes[:, :, 2:4]
  430. >>> labels = torch.randint(1, 91, (4, 11))
  431. >>> images = list(image for image in images)
  432. >>> targets = []
  433. >>> for i in range(len(images)):
  434. >>> d = {}
  435. >>> d['boxes'] = boxes[i]
  436. >>> d['labels'] = labels[i]
  437. >>> targets.append(d)
  438. >>> output = model(images, targets)
  439. >>> # For inference
  440. >>> model.eval()
  441. >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
  442. >>> predictions = model(x)
  443. >>>
  444. >>> # optionally, if you want to export the model to ONNX:
  445. >>> torch.onnx.export(model, x, "faster_rcnn.onnx", opset_version = 11)
  446. Args:
  447. weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights`, optional): The
  448. pretrained weights to use. See
  449. :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights` below for
  450. more details, and possible values. By default, no pre-trained
  451. weights are used.
  452. progress (bool, optional): If True, displays a progress bar of the
  453. download to stderr. Default is True.
  454. num_classes (int, optional): number of output classes of the model (including the background)
  455. weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
  456. pretrained weights for the backbone.
  457. trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
  458. final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
  459. trainable. If ``None`` is passed (the default) this value is set to 3.
  460. **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
  461. base class. Please refer to the `source code
  462. <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
  463. for more details about this class.
  464. .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights
  465. :members:
  466. """
    weights = LineNet_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    if weights_backbone is not None:
        print('loading pretrained resnet50 backbone weights')

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
    model = LineNet(backbone, num_classes=num_classes, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if weights == LineNet_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model


@register_model()
@handle_legacy_interface(
    weights=("pretrained", LineNet_ResNet50_FPN_V2_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def linenet_resnet50_fpn_v2(
    *,
    weights: Optional[LineNet_ResNet50_FPN_V2_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = None,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
    """
    Constructs an improved Faster R-CNN model with a ResNet-50-FPN backbone from `Benchmarking Detection
    Transfer Learning with Vision Transformers <https://arxiv.org/abs/2111.11429>`__ paper.

    .. betastatus:: detection module

    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
    details.

    Args:
        weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights
        :members:
    """
    weights = LineNet_ResNet50_FPN_V2_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)

    backbone = resnet50(weights=weights_backbone, progress=progress)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
    rpn_anchor_generator = _default_anchorgen()
    rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
    box_head = LineNetConvFCHead(
        (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
    )
    model = LineNet(
        backbone,
        num_classes=num_classes,
        rpn_anchor_generator=rpn_anchor_generator,
        rpn_head=rpn_head,
        box_head=box_head,
        **kwargs,
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


def _linenet_mobilenet_v3_large_fpn(
    *,
    weights: Optional[Union[LineNet_MobileNet_V3_Large_FPN_Weights, LineNet_MobileNet_V3_Large_320_FPN_Weights]],
    progress: bool,
    num_classes: Optional[int],
    weights_backbone: Optional[MobileNet_V3_Large_Weights],
    trainable_backbone_layers: Optional[int],
    **kwargs: Any,
) -> LineNet:
    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
    # The MobileNetV3 FPN extractor exposes three feature maps; each level sees
    # the full range of anchor sizes.
    anchor_sizes = ((32, 64, 128, 256, 512),) * 3
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    model = LineNet(
        backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
    return model


@register_model()
@handle_legacy_interface(
    weights=("pretrained", LineNet_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def linenet_mobilenet_v3_large_320_fpn(
    *,
    weights: Optional[LineNet_MobileNet_V3_Large_320_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
    """
    Low resolution Faster R-CNN model with a MobileNetV3-Large backbone tuned for mobile use cases.

    .. betastatus:: detection module

    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
    details.

    Example::

        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights
        :members:
    """
    weights = LineNet_MobileNet_V3_Large_320_FPN_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

    defaults = {
        "min_size": 320,
        "max_size": 640,
        "rpn_pre_nms_top_n_test": 150,
        "rpn_post_nms_top_n_test": 150,
        "rpn_score_thresh": 0.05,
    }

    kwargs = {**defaults, **kwargs}
    return _linenet_mobilenet_v3_large_fpn(
        weights=weights,
        progress=progress,
        num_classes=num_classes,
        weights_backbone=weights_backbone,
        trainable_backbone_layers=trainable_backbone_layers,
        **kwargs,
    )


@register_model()
@handle_legacy_interface(
    weights=("pretrained", LineNet_MobileNet_V3_Large_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def linenet_mobilenet_v3_large_fpn(
    *,
    weights: Optional[LineNet_MobileNet_V3_Large_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
    """
    Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.

    .. betastatus:: detection module

    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
    details.

    Example::

        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights
        :members:
    """
    weights = LineNet_MobileNet_V3_Large_FPN_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

    defaults = {
        "rpn_score_thresh": 0.05,
    }

    kwargs = {**defaults, **kwargs}
    return _linenet_mobilenet_v3_large_fpn(
        weights=weights,
        progress=progress,
        num_classes=num_classes,
        weights_backbone=weights_backbone,
        trainable_backbone_layers=trainable_backbone_layers,
        **kwargs,
    )
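

# Minimal smoke test (illustrative sketch, not part of the original module).
# Builds the untrained resnet50 variant and runs a forward pass on a random
# image; num_classes=5 mirrors the default used for the line head above.
if __name__ == "__main__":
    net = linenet_resnet50_fpn(weights=None, weights_backbone=None, num_classes=5)
    net.eval()
    with torch.no_grad():
        outputs = net([torch.rand(3, 512, 512)])
    print(outputs)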