# line_net.py
import os
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
from torch import nn
from torchvision.ops import MultiScaleRoIAlign

from libs.vision_libs import ops
from libs.vision_libs.models import MobileNet_V3_Large_Weights, mobilenet_v3_large, EfficientNet_V2_S_Weights, \
    efficientnet_v2_s, detection, EfficientNet_V2_L_Weights, efficientnet_v2_l, EfficientNet_V2_M_Weights, \
    efficientnet_v2_m
from libs.vision_libs.models.detection.anchor_utils import AnchorGenerator
from libs.vision_libs.models.detection.rpn import RPNHead, RegionProposalNetwork
from libs.vision_libs.models.detection.ssdlite import _mobilenet_extractor
from libs.vision_libs.models.detection.transform import GeneralizedRCNNTransform
from libs.vision_libs.ops import misc as misc_nn_ops
from libs.vision_libs.transforms._presets import ObjectDetection

from .line_head import LineRCNNHeads
from .line_predictor import LineRCNNPredictor

from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES, _COCO_CATEGORIES
from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights, ResNet18_Weights, resnet18, resnet101
from libs.vision_libs.models.detection._utils import overwrite_eps
from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers, \
    BackboneWithFPN
# TwoMLPHead is redefined in this module, so it is not imported here.
from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, FastRCNNPredictor

from .roi_heads import RoIHeads
from .trainer import Trainer
from ..base import backbone_factory
from ..base.backbone_factory import get_convnext_fpn, get_anchor_generator
from ..base.base_detection_net import BaseDetectionNet
import torch.nn.functional as F
from .predict import Predict1, Predict
from ..base.high_reso_resnet import resnet50fpn, resnet18fpn
from ..config.config_tool import read_yaml

FEATURE_DIM = 8
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

__all__ = [
    "LineNet",
    "LineNet_ResNet50_FPN_Weights",
    "LineNet_ResNet50_FPN_V2_Weights",
    "LineNet_MobileNet_V3_Large_FPN_Weights",
    "LineNet_MobileNet_V3_Large_320_FPN_Weights",
    "linenet_resnet50_fpn",
    "linenet_resnet50_fpn_v2",
    "linenet_mobilenet_v3_large_fpn",
    "linenet_mobilenet_v3_large_320_fpn",
]


def _default_anchorgen():
    anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    return AnchorGenerator(anchor_sizes, aspect_ratios)
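
# Sketch of what the default generator yields (standard torchvision AnchorGenerator
# semantics): one anchor size per FPN level with three aspect ratios, i.e. three
# anchors per spatial location on each level.
#
#   gen = _default_anchorgen()
#   gen.num_anchors_per_location()  # -> [3, 3, 3, 3, 3]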


class LineNet(BaseDetectionNet):
    # def __init__(self, cfg, **kwargs):
    #     cfg = read_yaml(cfg)
    #     self.cfg = cfg
    #     backbone = cfg['backbone']
    #     print(f'LineNet Backbone:{backbone}')
    #     num_classes = cfg['num_classes']
    #
    #     if backbone == 'resnet50_fpn':
    #         backbone = backbone_factory.get_resnet50_fpn()
    #         print(f'out_channels:{backbone.out_channels}')
    #     elif backbone == 'mobilenet_v3_large_fpn':
    #         backbone = backbone_factory.get_mobilenet_v3_large_fpn()
    #     elif backbone == 'resnet18_fpn':
    #         backbone = backbone_factory.get_resnet18_fpn()
    #
    #     self.__construct__(backbone=backbone, num_classes=num_classes, **kwargs)

    def __init__(
        self,
        backbone,
        num_classes=None,
        # transform parameters
        min_size=512,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # RPN parameters
        rpn_anchor_generator=None,
        rpn_head=None,
        rpn_pre_nms_top_n_train=2000,
        rpn_pre_nms_top_n_test=1000,
        rpn_post_nms_top_n_train=2000,
        rpn_post_nms_top_n_test=1000,
        rpn_nms_thresh=0.7,
        rpn_fg_iou_thresh=0.7,
        rpn_bg_iou_thresh=0.3,
        rpn_batch_size_per_image=256,
        rpn_positive_fraction=0.5,
        rpn_score_thresh=0.0,
        # Box parameters
        box_roi_pool=None,
        box_head=None,
        box_predictor=None,
        box_score_thresh=0.05,
        box_nms_thresh=0.5,
        box_detections_per_img=100,
        box_fg_iou_thresh=0.5,
        box_bg_iou_thresh=0.5,
        box_batch_size_per_image=512,
        box_positive_fraction=0.25,
        bbox_reg_weights=None,
        # line parameters
        line_head=None,
        line_predictor=None,
        **kwargs,
    ):
        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )

        if not isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))):
            raise TypeError(
                f"rpn_anchor_generator should be of type AnchorGenerator or None instead of {type(rpn_anchor_generator)}"
            )
        if not isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))):
            raise TypeError(
                f"box_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(box_roi_pool)}"
            )

        # Replace the first conv layer, changing in_channels from 3 to 4
        # so the model accepts 4-channel inputs.
        backbone.body.conv1 = nn.Conv2d(
            in_channels=4,
            out_channels=64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )

        if num_classes is not None:
            if box_predictor is not None:
                raise ValueError("num_classes should be None when box_predictor is specified")
        else:
            if box_predictor is None:
                raise ValueError("num_classes should not be None when box_predictor is not specified")

        out_channels = backbone.out_channels
        # cfg = read_yaml(cfg)
        # self.cfg = cfg

        if line_head is None:
            num_class = 5
            line_head = LineRCNNHeads(out_channels, num_class)
        if line_predictor is None:
            line_predictor = LineRCNNPredictor()

        if rpn_anchor_generator is None:
            rpn_anchor_generator = _default_anchorgen()
        if rpn_head is None:
            rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])

        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)

        rpn = RegionProposalNetwork(
            rpn_anchor_generator,
            rpn_head,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_pre_nms_top_n,
            rpn_post_nms_top_n,
            rpn_nms_thresh,
            score_thresh=rpn_score_thresh,
        )

        if box_roi_pool is None:
            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3", "4"], output_size=7, sampling_ratio=2)
        if box_head is None:
            resolution = box_roi_pool.output_size[0]
            representation_size = 1024
            box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
        if box_predictor is None:
            representation_size = 1024
            box_predictor = BoxPredictor(representation_size, num_classes)

        roi_heads = RoIHeads(
            # Box
            box_roi_pool,
            box_head,
            box_predictor,
            line_head,
            line_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
        )

        if image_mean is None:
            # image_mean = [0.485, 0.456, 0.406]
            image_mean = [0.485, 0.456, 0.406, 0.2549]  # mean for the added fourth channel
        if image_std is None:
            # image_std = [0.229, 0.224, 0.225]
            image_std = [0.229, 0.224, 0.225, 0.4093]  # std for the added fourth channel

        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

        super().__init__(backbone, rpn, roi_heads, transform)

        self.roi_heads = roi_heads
        # self.roi_heads.line_head = line_head
        # self.roi_heads.line_predictor = line_predictor

    def start_train(self, cfg):
        # cfg = read_yaml(cfg)
        self.trainer = Trainer()
        self.trainer.train_from_cfg(model=self, cfg=cfg)

    def load_best_model(self, save_path, device='cuda'):
        if os.path.exists(save_path):
            checkpoint = torch.load(save_path, map_location=device)
            self.load_state_dict(checkpoint['model_state_dict'])
            # if optimizer is not None:
            #     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            # epoch = checkpoint['epoch']
            # loss = checkpoint['loss']
            # print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
            print(f"Loaded model from {save_path}")
        else:
            print(f"No saved model found at {save_path}")
        return self

    # Loads weights and runs inference in one call.
    def predict(self, img_path, type=0, threshold=0.5, save_path=None, show=False):
        # self.predict = Predict(pt_path, model, img_path, type, threshold, save_path, show)
        self.eval()
        self.to(device)
        self.predict = Predict(self, img_path, type, threshold, save_path, show)
        self.predict.run()

    # Runs inference without loading weights.
    def predict1(self, model, img_path, type=0, threshold=0.5, save_path=None, show=False):
        self.predict = Predict1(model, img_path, type, threshold, save_path, show)
        self.predict.run()
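
# Minimal usage sketch. Assumption: inputs are 4-channel images, matching the
# patched conv1 and the 4-element image_mean/image_std above; linenet_resnet50_fpn
# is defined later in this module.
#
#   model = linenet_resnet50_fpn(num_classes=5)
#   model.eval()
#   x = [torch.rand(4, 512, 512)]  # one 4-channel image
#   predictions = model(x)         # List[Dict[str, Tensor]] with boxes/labels/scores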


class TwoMLPHead(nn.Module):
    """
    Standard heads for FPN-based models

    Args:
        in_channels (int): number of input channels
        representation_size (int): size of the intermediate representation
    """

    def __init__(self, in_channels, representation_size):
        super().__init__()

        self.fc6 = nn.Linear(in_channels, representation_size)
        self.fc7 = nn.Linear(representation_size, representation_size)

    def forward(self, x):
        x = x.flatten(start_dim=1)

        x = F.relu(self.fc6(x))
        x = F.relu(self.fc7(x))

        return x
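
# Shape sketch: with the defaults used by LineNet above (256-channel FPN features,
# 7x7 RoI output), in_channels = 256 * 7 * 7 = 12544, so this head maps pooled
# features (N, 12544) to representations (N, 1024).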


class LineNetConvFCHead(nn.Sequential):
    def __init__(
        self,
        input_size: Tuple[int, int, int],
        conv_layers: List[int],
        fc_layers: List[int],
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        """
        Args:
            input_size (Tuple[int, int, int]): the input size in CHW format.
            conv_layers (list): feature dimensions of each Convolution layer
            fc_layers (list): feature dimensions of each FCN layer
            norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
        """
        in_channels, in_height, in_width = input_size

        blocks = []
        previous_channels = in_channels
        for current_channels in conv_layers:
            blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, current_channels, norm_layer=norm_layer))
            previous_channels = current_channels
        blocks.append(nn.Flatten())
        previous_channels = previous_channels * in_height * in_width
        for current_channels in fc_layers:
            blocks.append(nn.Linear(previous_channels, current_channels))
            blocks.append(nn.ReLU(inplace=True))
            previous_channels = current_channels

        super().__init__(*blocks)
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)


class BoxPredictor(nn.Module):
    """
    Standard classification + bounding box regression layers
    for Fast R-CNN.

    Args:
        in_channels (int): number of input channels
        num_classes (int): number of output classes (including background)
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.cls_score = nn.Linear(in_channels, num_classes)
        self.bbox_pred = nn.Linear(in_channels, num_classes * 4)

    def forward(self, x):
        if x.dim() == 4:
            torch._assert(
                list(x.shape[2:]) == [1, 1],
                f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}",
            )
        x = x.flatten(start_dim=1)
        scores = self.cls_score(x)
        bbox_deltas = self.bbox_pred(x)

        return scores, bbox_deltas
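
# Output shape sketch: for pooled features x of shape (N, in_channels), or
# (N, in_channels, 1, 1), scores has shape (N, num_classes) and bbox_deltas has
# shape (N, num_classes * 4), i.e. one box regression per class.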


_COMMON_META = {
    "categories": _COCO_CATEGORIES,
    "min_size": (1, 1),
}


def create_efficientnetv2_backbone(name='efficientnet_v2_m', pretrained=True):
    # Load the EfficientNetV2 feature extractor.
    if name == 'efficientnet_v2_s':
        weights = EfficientNet_V2_S_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = efficientnet_v2_s(weights=weights).features
    elif name == 'efficientnet_v2_m':
        weights = EfficientNet_V2_M_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = efficientnet_v2_m(weights=weights).features
    elif name == 'efficientnet_v2_l':
        weights = EfficientNet_V2_L_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = efficientnet_v2_l(weights=weights).features
    else:
        raise ValueError(f"Unknown EfficientNetV2 variant: {name}")

    # Indices of the feature stages to return and their output names.
    return_layers = {"2": "0", "3": "1", "4": "2", "5": "3"}

    # Collect the number of output channels of each returned stage.
    in_channels_list = []
    for layer_idx in [2, 3, 4, 5]:
        module = backbone[layer_idx]
        if hasattr(module, 'out_channels'):
            in_channels_list.append(module.out_channels)
        elif hasattr(module[-1], 'out_channels'):
            # If the module itself has no out_channels, check its last submodule.
            in_channels_list.append(module[-1].out_channels)
        else:
            raise ValueError(f"Cannot determine out_channels for layer {layer_idx}")

    # Wrap the backbone with an FPN.
    backbone_with_fpn = BackboneWithFPN(
        backbone=backbone,
        return_layers=return_layers,
        in_channels_list=in_channels_list,
        out_channels=256,
    )

    return backbone_with_fpn
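
# Usage sketch. Assumption: BackboneWithFPN returns an ordered dict of feature maps
# keyed "0".."3" plus an extra "pool" level, each with 256 channels.
#
#   backbone = create_efficientnetv2_backbone('efficientnet_v2_s', pretrained=False)
#   feats = backbone(torch.rand(1, 3, 224, 224))
#   print([(k, v.shape[1]) for k, v in feats.items()])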


def get_line_net_efficientnetv2(num_classes, pretrained_backbone=True):
    # Create the EfficientNetV2 backbone.
    backbone = create_efficientnetv2_backbone(pretrained=pretrained_backbone)

    # Check how many feature maps the backbone outputs.
    with torch.no_grad():
        images = torch.rand(1, 3, 600, 800)
        features = backbone(images)
        featmap_names = list(features.keys())
        print("Feature map names:", featmap_names)  # e.g. ['0', '1', '2', '3']

    # Configure anchors; the FPN adds a 'pool' level, so use five levels.
    # num_levels = len(featmap_names)
    num_levels = 5
    featmap_names = ['0', '1', '2', '3', 'pool']
    anchor_sizes = tuple((int(16 * 2 ** i),) for i in range(num_levels))  # one size per level, doubling each time
    print(f'anchor_sizes:{anchor_sizes}')
    aspect_ratios = ((0.5, 1.0, 2.0),) * num_levels  # all levels share the same aspect ratios
    print(f'aspect_ratios:{aspect_ratios}')

    anchor_generator = AnchorGenerator(
        sizes=anchor_sizes,
        aspect_ratios=aspect_ratios,
    )

    # RoI pooling over all feature levels.
    roi_pooler = MultiScaleRoIAlign(
        featmap_names=featmap_names,
        output_size=7,
        sampling_ratio=2,
    )

    # Build the model.
    model = LineNet(
        backbone=backbone,
        num_classes=num_classes,
        rpn_anchor_generator=anchor_generator,
        box_roi_pool=roi_pooler,
    )

    return model


def get_line_net_convnext_fpn(num_classes=91):
    backbone = get_convnext_fpn()
    featmap_names = ['0', '1', '2', '3', 'pool']
    roi_pooler = MultiScaleRoIAlign(
        featmap_names=featmap_names,
        output_size=7,
        sampling_ratio=2,
    )
    test_input = torch.rand(1, 3, 224, 224)
    anchor_generator = get_anchor_generator(backbone, test_input)
    model = LineNet(
        backbone=backbone,
        num_classes=num_classes,  # COCO has 91 classes
        rpn_anchor_generator=anchor_generator,
        box_roi_pool=roi_pooler,
    )
    return model


class LineNet_ResNet50_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 41755286,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 37.0,
                }
            },
            "_ops": 134.38,
            "_file_size": 159.743,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


class LineNet_ResNet50_FPN_V2_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 43712278,
            "recipe": "https://github.com/pytorch/vision/pull/5763",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 46.7,
                }
            },
            "_ops": 280.371,
            "_file_size": 167.104,
            "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
        },
    )
    DEFAULT = COCO_V1


class LineNet_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 19386354,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 32.8,
                }
            },
            "_ops": 4.494,
            "_file_size": 74.239,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


class LineNet_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 19386354,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 22.8,
                }
            },
            "_ops": 0.719,
            "_file_size": 74.239,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


def linenet_newresnet18fpn(
    *,
    weights: Optional[LineNet_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet18_Weights] = ResNet18_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
    # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
    # weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    if weights_backbone is not None:
        print('weights_backbone is not None')

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet18fpn()
    featmap_names = ['0', '1', '2', '3', 'pool']
    # print(f'featmap_names:{featmap_names}')
    roi_pooler = MultiScaleRoIAlign(
        featmap_names=featmap_names,
        output_size=7,
        sampling_ratio=2,
    )
    num_features = len(featmap_names)
    anchor_sizes = tuple((int(16 * 2 ** i),) for i in range(num_features))  # one size per level, doubling each time
    # print(f'anchor_sizes:{anchor_sizes}')
    aspect_ratios = ((0.5, 1.0, 2.0),) * num_features
    # print(f'aspect_ratios:{aspect_ratios}')
    anchor_generator = AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)
    # anchors = anchor_generator.generate_anchors()
    # print("Number of anchor sizes:", len(anchor_generator.sizes))  # should be 5

    model = LineNet(backbone, num_classes=num_classes, rpn_anchor_generator=anchor_generator,
                    box_roi_pool=roi_pooler,
                    **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if weights == LineNet_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model


def linenet_newresnet50fpn(
    *,
    weights: Optional[LineNet_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet18_Weights] = ResNet18_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
    # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
    # weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    if weights_backbone is not None:
        print('weights_backbone is not None')

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50fpn()
    featmap_names = ['0', '1', '2', '3', 'pool']
    # print(f'featmap_names:{featmap_names}')
    roi_pooler = MultiScaleRoIAlign(
        featmap_names=featmap_names,
        output_size=7,
        sampling_ratio=2,
    )
    num_features = len(featmap_names)
    anchor_sizes = tuple((int(16 * 2 ** i),) for i in range(num_features))  # one size per level, doubling each time
    # print(f'anchor_sizes:{anchor_sizes}')
    aspect_ratios = ((0.5, 1.0, 2.0),) * num_features
    # print(f'aspect_ratios:{aspect_ratios}')
    anchor_generator = AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)
    # anchors = anchor_generator.generate_anchors()
    # print("Number of anchor sizes:", len(anchor_generator.sizes))  # should be 5

    model = LineNet(backbone, num_classes=num_classes, rpn_anchor_generator=anchor_generator,
                    box_roi_pool=roi_pooler,
                    **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if weights == LineNet_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model


# @register_model()
# @handle_legacy_interface(
#     weights=("pretrained", LineNet_ResNet50_FPN_Weights.COCO_V1),
#     weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
# )
def linenet_resnet18_fpn(
    *,
    weights: Optional[LineNet_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet18_Weights] = ResNet18_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
    # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
    # weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    if weights_backbone is not None:
        print('using pretrained resnet18 backbone weights')

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet18(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
    model = LineNet(backbone, num_classes=num_classes, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if weights == LineNet_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model


def linenet_resnet50_fpn(
    *,
    weights: Optional[LineNet_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
    """
    Faster R-CNN model with a ResNet-50-FPN backbone from the `Faster R-CNN: Towards Real-Time Object
    Detection with Region Proposal Networks <https://arxiv.org/abs/1506.01497>`__
    paper.

    .. betastatus:: detection module

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and a targets (list of dictionary),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses for both the RPN and the R-CNN.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows, where ``N`` is the number of detections:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the predicted labels for each detection
        - scores (``Tensor[N]``): the scores of each detection

    For more details on the output, you may refer to :ref:`instance_seg_output`.

    Faster R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.

    Example::

        >>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
        >>> # For training
        >>> images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4)
        >>> boxes[:, :, 2:4] = boxes[:, :, 0:2] + boxes[:, :, 2:4]
        >>> labels = torch.randint(1, 91, (4, 11))
        >>> images = list(image for image in images)
        >>> targets = []
        >>> for i in range(len(images)):
        >>>     d = {}
        >>>     d['boxes'] = boxes[i]
        >>>     d['labels'] = labels[i]
        >>>     targets.append(d)
        >>> output = model(images, targets)
        >>> # For inference
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
        >>>
        >>> # optionally, if you want to export the model to ONNX:
        >>> torch.onnx.export(model, x, "faster_rcnn.onnx", opset_version = 11)

    Args:
        weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights
        :members:
    """
    weights = LineNet_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    if weights_backbone is not None:
        print('using pretrained resnet50 backbone weights')

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
    model = LineNet(backbone, num_classes=num_classes, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if weights == LineNet_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model


# @register_model()
# @handle_legacy_interface(
#     weights=("pretrained", LineNet_ResNet50_FPN_V2_Weights.COCO_V1),
#     weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
# )
def linenet_resnet50_fpn_v2(
    *,
    weights: Optional[LineNet_ResNet50_FPN_V2_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = None,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
    """
    Constructs an improved Faster R-CNN model with a ResNet-50-FPN backbone from `Benchmarking Detection
    Transfer Learning with Vision Transformers <https://arxiv.org/abs/2111.11429>`__ paper.

    .. betastatus:: detection module

    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
    details.

    Args:
        weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights
        :members:
    """
    weights = LineNet_ResNet50_FPN_V2_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)

    backbone = resnet50(weights=weights_backbone, progress=progress)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
    rpn_anchor_generator = _default_anchorgen()
    rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
    box_head = LineNetConvFCHead(
        (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
    )
    model = LineNet(
        backbone,
        num_classes=num_classes,
        rpn_anchor_generator=rpn_anchor_generator,
        rpn_head=rpn_head,
        box_head=box_head,
        **kwargs,
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


def linenet_resnet101_fpn_v2(
    *,
    weights: Optional[LineNet_ResNet50_FPN_V2_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = None,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
    """
    Constructs an improved Faster R-CNN model with a ResNet-101-FPN backbone, following `Benchmarking
    Detection Transfer Learning with Vision Transformers <https://arxiv.org/abs/2111.11429>`__ paper.

    .. betastatus:: detection module

    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
    details.

    Args:
        weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights
        :members:
    """
    weights = LineNet_ResNet50_FPN_V2_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)

    backbone = resnet101(weights=weights_backbone, progress=progress)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
    rpn_anchor_generator = _default_anchorgen()
    rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
    box_head = LineNetConvFCHead(
        (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
    )
    model = LineNet(
        backbone,
        num_classes=num_classes,
        rpn_anchor_generator=rpn_anchor_generator,
        rpn_head=rpn_head,
        box_head=box_head,
        **kwargs,
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


def _linenet_mobilenet_v3_large_fpn(
    *,
    weights: Optional[Union[LineNet_MobileNet_V3_Large_FPN_Weights, LineNet_MobileNet_V3_Large_320_FPN_Weights]],
    progress: bool,
    num_classes: Optional[int],
    weights_backbone: Optional[MobileNet_V3_Large_Weights],
    trainable_backbone_layers: Optional[int],
    **kwargs: Any,
) -> LineNet:
    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
    anchor_sizes = ((32, 64, 128, 256, 512),) * 3
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    model = LineNet(
        backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model
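
# Note (sketch): the mobilenet extractor exposes three feature maps here, and each
# level gets the full (32, 64, 128, 256, 512) size set with three aspect ratios,
# i.e. 15 anchors per spatial location per level, matching the anchor configuration
# torchvision uses for fasterrcnn_mobilenet_v3_large_fpn.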


# @register_model()
# @handle_legacy_interface(
#     weights=("pretrained", LineNet_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
#     weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
# )
def linenet_mobilenet_v3_large_320_fpn(
    *,
    weights: Optional[LineNet_MobileNet_V3_Large_320_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
    """
    Low resolution Faster R-CNN model with a MobileNetV3-Large backbone tuned for mobile use cases.

    .. betastatus:: detection module

    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
    details.

    Example::

        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_320_FPN_Weights
        :members:
    """
    weights = LineNet_MobileNet_V3_Large_320_FPN_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

    defaults = {
        "min_size": 320,
        "max_size": 640,
        "rpn_pre_nms_top_n_test": 150,
        "rpn_post_nms_top_n_test": 150,
        "rpn_score_thresh": 0.05,
    }

    kwargs = {**defaults, **kwargs}
    return _linenet_mobilenet_v3_large_fpn(
        weights=weights,
        progress=progress,
        num_classes=num_classes,
        weights_backbone=weights_backbone,
        trainable_backbone_layers=trainable_backbone_layers,
        **kwargs,
    )


# @register_model()
# @handle_legacy_interface(
#     weights=("pretrained", LineNet_MobileNet_V3_Large_FPN_Weights.COCO_V1),
#     weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
# )
def linenet_mobilenet_v3_large_fpn(
    *,
    weights: Optional[LineNet_MobileNet_V3_Large_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
    """
    Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.

    .. betastatus:: detection module

    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
    details.

    Example::

        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        weights (:class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.FasterRCNN_MobileNet_V3_Large_FPN_Weights
        :members:
    """
    weights = LineNet_MobileNet_V3_Large_FPN_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

    defaults = {
        "rpn_score_thresh": 0.05,
    }

    kwargs = {**defaults, **kwargs}
    return _linenet_mobilenet_v3_large_fpn(
        weights=weights,
        progress=progress,
        num_classes=num_classes,
        weights_backbone=weights_backbone,
        trainable_backbone_layers=trainable_backbone_layers,
        **kwargs,
    )