# line_net.py

import os
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
from torch import nn
from torchvision.ops import MultiScaleRoIAlign

from libs.vision_libs import ops
from libs.vision_libs.models import MobileNet_V3_Large_Weights, mobilenet_v3_large, EfficientNet_V2_S_Weights, \
    efficientnet_v2_s, detection, EfficientNet_V2_L_Weights, efficientnet_v2_l, EfficientNet_V2_M_Weights, \
    efficientnet_v2_m
from libs.vision_libs.models.detection.anchor_utils import AnchorGenerator
from libs.vision_libs.models.detection.rpn import RPNHead, RegionProposalNetwork
from libs.vision_libs.models.detection.ssdlite import _mobilenet_extractor
from libs.vision_libs.models.detection.transform import GeneralizedRCNNTransform
from libs.vision_libs.ops import misc as misc_nn_ops
from libs.vision_libs.transforms._presets import ObjectDetection
from .line_head import LineRCNNHeads
from .line_predictor import LineRCNNPredictor
from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES, _COCO_CATEGORIES
from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights, ResNet18_Weights, resnet18
from libs.vision_libs.models.detection._utils import overwrite_eps
from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers, \
    BackboneWithFPN
# NOTE: TwoMLPHead imported here is shadowed by the local TwoMLPHead class defined below.
from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
from .roi_heads import RoIHeads
from .trainer import Trainer
from ..base import backbone_factory
from ..base.backbone_factory import get_convnext_fpn, get_anchor_generator
from ..base.base_detection_net import BaseDetectionNet
import torch.nn.functional as F
from .predict import Predict1, Predict
from ..base.high_reso_resnet import resnet50fpn, resnet18fpn
from ..config.config_tool import read_yaml

FEATURE_DIM = 8
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

__all__ = [
    "LineNet",
    "LineNet_ResNet50_FPN_Weights",
    "LineNet_ResNet50_FPN_V2_Weights",
    "LineNet_MobileNet_V3_Large_FPN_Weights",
    "LineNet_MobileNet_V3_Large_320_FPN_Weights",
    "linenet_resnet50_fpn",
    "linenet_resnet50_fpn_v2",
    "linenet_mobilenet_v3_large_fpn",
    "linenet_mobilenet_v3_large_320_fpn",
]


def _default_anchorgen():
    anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    return AnchorGenerator(anchor_sizes, aspect_ratios)
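

# The default generator pairs one anchor size with three aspect ratios per FPN
# level, so every spatial location on each of the five levels gets 3 anchors.
# A quick check (a sketch, run in an interactive session):
#
#   >>> gen = _default_anchorgen()
#   >>> gen.num_anchors_per_location()
#   [3, 3, 3, 3, 3]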


class LineNet(BaseDetectionNet):
    # def __init__(self, cfg, **kwargs):
    #     cfg = read_yaml(cfg)
    #     self.cfg = cfg
    #     backbone = cfg['backbone']
    #     print(f'LineNet Backbone:{backbone}')
    #     num_classes = cfg['num_classes']
    #
    #     if backbone == 'resnet50_fpn':
    #         backbone = backbone_factory.get_resnet50_fpn()
    #         print(f'out_channels:{backbone.out_channels}')
    #     elif backbone == 'mobilenet_v3_large_fpn':
    #         backbone = backbone_factory.get_mobilenet_v3_large_fpn()
    #     elif backbone == 'resnet18_fpn':
    #         backbone = backbone_factory.get_resnet18_fpn()
    #
    #     self.__construct__(backbone=backbone, num_classes=num_classes, **kwargs)
    def __init__(
        self,
        backbone,
        num_classes=None,
        # transform parameters
        min_size=512,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # RPN parameters
        rpn_anchor_generator=None,
        rpn_head=None,
        rpn_pre_nms_top_n_train=2000,
        rpn_pre_nms_top_n_test=1000,
        rpn_post_nms_top_n_train=2000,
        rpn_post_nms_top_n_test=1000,
        rpn_nms_thresh=0.7,
        rpn_fg_iou_thresh=0.7,
        rpn_bg_iou_thresh=0.3,
        rpn_batch_size_per_image=256,
        rpn_positive_fraction=0.5,
        rpn_score_thresh=0.0,
        # Box parameters
        box_roi_pool=None,
        box_head=None,
        box_predictor=None,
        box_score_thresh=0.05,
        box_nms_thresh=0.5,
        box_detections_per_img=100,
        box_fg_iou_thresh=0.5,
        box_bg_iou_thresh=0.5,
        box_batch_size_per_image=512,
        box_positive_fraction=0.25,
        bbox_reg_weights=None,
        # line parameters
        line_head=None,
        line_predictor=None,
        **kwargs,
    ):
        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )

        if not isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))):
            raise TypeError(
                f"rpn_anchor_generator should be of type AnchorGenerator or None instead of {type(rpn_anchor_generator)}"
            )
        if not isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))):
            raise TypeError(
                f"box_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(box_roi_pool)}"
            )

        if num_classes is not None:
            if box_predictor is not None:
                raise ValueError("num_classes should be None when box_predictor is specified")
        else:
            if box_predictor is None:
                raise ValueError("num_classes should not be None when box_predictor is not specified")

        out_channels = backbone.out_channels
        # cfg = read_yaml(cfg)
        # self.cfg=cfg

        if line_head is None:
            num_class = 5
            line_head = LineRCNNHeads(out_channels, num_class)

        if line_predictor is None:
            line_predictor = LineRCNNPredictor()

        if rpn_anchor_generator is None:
            rpn_anchor_generator = _default_anchorgen()
        if rpn_head is None:
            rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])

        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)

        rpn = RegionProposalNetwork(
            rpn_anchor_generator,
            rpn_head,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_pre_nms_top_n,
            rpn_post_nms_top_n,
            rpn_nms_thresh,
            score_thresh=rpn_score_thresh,
        )

        if box_roi_pool is None:
            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3", "4"], output_size=7, sampling_ratio=2)

        if box_head is None:
            resolution = box_roi_pool.output_size[0]
            representation_size = 1024
            box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)

        if box_predictor is None:
            representation_size = 1024
            box_predictor = BoxPredictor(representation_size, num_classes)

        roi_heads = RoIHeads(
            # Box
            box_roi_pool,
            box_head,
            box_predictor,
            line_head,
            line_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
        )

        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

        super().__init__(backbone, rpn, roi_heads, transform)

        self.roi_heads = roi_heads
        # self.roi_heads.line_head = line_head
        # self.roi_heads.line_predictor = line_predictor
    def start_train(self, cfg):
        # cfg = read_yaml(cfg)
        self.trainer = Trainer()
        self.trainer.train_from_cfg(model=self, cfg=cfg)

    def load_best_model(self, save_path, device='cuda'):
        if os.path.exists(save_path):
            checkpoint = torch.load(save_path, map_location=device)
            self.load_state_dict(checkpoint['model_state_dict'])
            # if optimizer is not None:
            #     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            # epoch = checkpoint['epoch']
            # loss = checkpoint['loss']
            # print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
            print(f"Loaded model from {save_path}")
        else:
            print(f"No saved model found at {save_path}")
        return self

    # Load weights and run inference in one step.
    def predict(self, img_path, type=0, threshold=0.5, save_path=None, show=False):
        # self.predict = Predict(pt_path, model, img_path, type, threshold, save_path, show)
        self.eval()
        self.to(device)
        self.predict = Predict(self, img_path, type, threshold, save_path, show)
        self.predict.run()

    # Run inference without loading weights first.
    def predict1(self, model, img_path, type=0, threshold=0.5, save_path=None, show=False):
        self.predict = Predict1(model, img_path, type, threshold, save_path, show)
        self.predict.run()
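

# Usage sketch for LineNet ('train.yaml', 'best.pth' and 'demo.jpg' are
# hypothetical paths; linenet_resnet18_fpn is one of the builders defined
# later in this file):
#
#   >>> model = linenet_resnet18_fpn(num_classes=5)
#   >>> model.start_train('train.yaml')   # train from a config file
#   >>> model.load_best_model('best.pth').predict('demo.jpg', threshold=0.5, show=True)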


class TwoMLPHead(nn.Module):
    """
    Standard heads for FPN-based models

    Args:
        in_channels (int): number of input channels
        representation_size (int): size of the intermediate representation
    """

    def __init__(self, in_channels, representation_size):
        super().__init__()

        self.fc6 = nn.Linear(in_channels, representation_size)
        self.fc7 = nn.Linear(representation_size, representation_size)

    def forward(self, x):
        x = x.flatten(start_dim=1)

        x = F.relu(self.fc6(x))
        x = F.relu(self.fc7(x))

        return x
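

# Quick shape check for TwoMLPHead (a sketch; 256 channels and a 7x7 RoI match
# the defaults used in LineNet.__init__ above):
#
#   >>> head = TwoMLPHead(256 * 7 * 7, 1024)
#   >>> head(torch.rand(8, 256, 7, 7)).shape
#   torch.Size([8, 1024])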


class LineNetConvFCHead(nn.Sequential):
    def __init__(
        self,
        input_size: Tuple[int, int, int],
        conv_layers: List[int],
        fc_layers: List[int],
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        """
        Args:
            input_size (Tuple[int, int, int]): the input size in CHW format.
            conv_layers (list): feature dimensions of each Convolution layer
            fc_layers (list): feature dimensions of each FCN layer
            norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
        """
        in_channels, in_height, in_width = input_size

        blocks = []
        previous_channels = in_channels
        for current_channels in conv_layers:
            blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, current_channels, norm_layer=norm_layer))
            previous_channels = current_channels
        blocks.append(nn.Flatten())
        previous_channels = previous_channels * in_height * in_width
        for current_channels in fc_layers:
            blocks.append(nn.Linear(previous_channels, current_channels))
            blocks.append(nn.ReLU(inplace=True))
            previous_channels = current_channels

        super().__init__(*blocks)
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
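

# Shape check with the configuration used by linenet_resnet50_fpn_v2 below
# (a sketch; the 3x3 convs preserve the 7x7 RoI resolution):
#
#   >>> head = LineNetConvFCHead((256, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d)
#   >>> head(torch.rand(8, 256, 7, 7)).shape
#   torch.Size([8, 1024])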


class BoxPredictor(nn.Module):
    """
    Standard classification + bounding box regression layers
    for Fast R-CNN.

    Args:
        in_channels (int): number of input channels
        num_classes (int): number of output classes (including background)
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.cls_score = nn.Linear(in_channels, num_classes)
        self.bbox_pred = nn.Linear(in_channels, num_classes * 4)

    def forward(self, x):
        if x.dim() == 4:
            torch._assert(
                list(x.shape[2:]) == [1, 1],
                f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}",
            )
        x = x.flatten(start_dim=1)
        scores = self.cls_score(x)
        bbox_deltas = self.bbox_pred(x)

        return scores, bbox_deltas
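

# BoxPredictor emits one score per class and four box deltas per class:
#
#   >>> predictor = BoxPredictor(1024, num_classes=5)
#   >>> scores, deltas = predictor(torch.rand(8, 1024))
#   >>> scores.shape, deltas.shape
#   (torch.Size([8, 5]), torch.Size([8, 20]))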


_COMMON_META = {
    "categories": _COCO_CATEGORIES,
    "min_size": (1, 1),
}


def create_efficientnetv2_backbone(name='efficientnet_v2_m', pretrained=True):
    # Load the EfficientNetV2 feature extractor.
    if name == 'efficientnet_v2_s':
        weights = EfficientNet_V2_S_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = efficientnet_v2_s(weights=weights).features
    elif name == 'efficientnet_v2_m':
        weights = EfficientNet_V2_M_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = efficientnet_v2_m(weights=weights).features
    elif name == 'efficientnet_v2_l':
        weights = EfficientNet_V2_L_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = efficientnet_v2_l(weights=weights).features
    else:
        raise ValueError(f"Unsupported EfficientNetV2 variant: {name}")

    # Indices of the stages to expose and the names the FPN will see.
    return_layers = {"2": "0", "3": "1", "4": "2", "5": "3"}

    # Collect the output channel count of each exposed stage.
    in_channels_list = []
    for layer_idx in [2, 3, 4, 5]:
        module = backbone[layer_idx]
        if hasattr(module, 'out_channels'):
            in_channels_list.append(module.out_channels)
        elif hasattr(module[-1], 'out_channels'):
            # If the stage itself has no out_channels, check its last sub-module.
            in_channels_list.append(module[-1].out_channels)
        else:
            raise ValueError(f"Cannot determine out_channels for layer {layer_idx}")

    # Wrap the backbone with an FPN.
    backbone_with_fpn = BackboneWithFPN(
        backbone=backbone,
        return_layers=return_layers,
        in_channels_list=in_channels_list,
        out_channels=256,
    )
    return backbone_with_fpn
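

# Sanity check of the wrapped backbone (a sketch; exact shapes depend on the
# chosen variant, and BackboneWithFPN appends a 'pool' level by default):
#
#   >>> fpn = create_efficientnetv2_backbone('efficientnet_v2_s', pretrained=False)
#   >>> feats = fpn(torch.rand(1, 3, 512, 512))
#   >>> list(feats.keys())
#   ['0', '1', '2', '3', 'pool']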


def get_line_net_efficientnetv2(num_classes, pretrained_backbone=True):
    # Build the EfficientNetV2 backbone.
    backbone = create_efficientnetv2_backbone(pretrained=pretrained_backbone)

    # Probe how many feature maps the backbone actually outputs.
    with torch.no_grad():
        images = torch.rand(1, 3, 600, 800)
        features = backbone(images)
        featmap_names = list(features.keys())
        print("Feature map names:", featmap_names)  # e.g. ['0', '1', '2', '3', 'pool']

    # Configure anchors for all five FPN levels (including the extra 'pool' level).
    # num_levels = len(featmap_names)
    num_levels = 5
    featmap_names = ['0', '1', '2', '3', 'pool']
    anchor_sizes = tuple((int(16 * 2 ** i),) for i in range(num_levels))  # one size per level
    print(f'anchor_sizes:{anchor_sizes}')
    aspect_ratios = ((0.5, 1.0, 2.0),) * num_levels  # all levels share the same ratios
    print(f'aspect_ratios:{aspect_ratios}')
    anchor_generator = AnchorGenerator(
        sizes=anchor_sizes,
        aspect_ratios=aspect_ratios
    )

    # RoI pooling over the same feature levels.
    roi_pooler = MultiScaleRoIAlign(
        featmap_names=featmap_names,
        output_size=7,
        sampling_ratio=2
    )

    # Assemble the model.
    model = LineNet(
        backbone=backbone,
        num_classes=num_classes,
        rpn_anchor_generator=anchor_generator,
        box_roi_pool=roi_pooler
    )
    return model
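

# Construction sketch (num_classes includes the background class):
#
#   >>> model = get_line_net_efficientnetv2(num_classes=5, pretrained_backbone=False)
#   >>> model.eval()  # switch to inference mode before calling the model on images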


def get_line_net_convnext_fpn(num_classes=91):
    backbone = get_convnext_fpn()
    featmap_names = ['0', '1', '2', '3', 'pool']
    roi_pooler = MultiScaleRoIAlign(
        featmap_names=featmap_names,
        output_size=7,
        sampling_ratio=2
    )
    test_input = torch.rand(1, 3, 224, 224)
    anchor_generator = get_anchor_generator(backbone, test_input)
    model = LineNet(
        backbone=backbone,
        num_classes=num_classes,  # COCO has 91 classes
        rpn_anchor_generator=anchor_generator,
        box_roi_pool=roi_pooler
    )
    return model


class LineNet_ResNet50_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 41755286,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 37.0,
                }
            },
            "_ops": 134.38,
            "_file_size": 159.743,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


class LineNet_ResNet50_FPN_V2_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_v2_coco-dd69338a.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 43712278,
            "recipe": "https://github.com/pytorch/vision/pull/5763",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 46.7,
                }
            },
            "_ops": 280.371,
            "_file_size": 167.104,
            "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
        },
    )
    DEFAULT = COCO_V1


class LineNet_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 19386354,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 32.8,
                }
            },
            "_ops": 4.494,
            "_file_size": 74.239,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


class LineNet_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 19386354,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 22.8,
                }
            },
            "_ops": 0.719,
            "_file_size": 74.239,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1


def linenet_newresnet18fpn(
    *,
    weights: Optional[LineNet_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet18_Weights] = ResNet18_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
    # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
    # weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    if weights_backbone is not None:
        print('resnet18 backbone weights are not None')

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet18fpn()
    featmap_names = ['0', '1', '2', '3', 'pool']
    # print(f'featmap_names:{featmap_names}')
    roi_pooler = MultiScaleRoIAlign(
        featmap_names=featmap_names,
        output_size=7,
        sampling_ratio=2
    )
    num_features = len(featmap_names)
    anchor_sizes = tuple((int(16 * 2 ** i),) for i in range(num_features))  # one size per level
    # print(f'anchor_sizes:{anchor_sizes}')
    aspect_ratios = ((0.5, 1.0, 2.0),) * num_features
    # print(f'aspect_ratios:{aspect_ratios}')
    anchor_generator = AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)
    # anchors = anchor_generator.generate_anchors()
    # print("Number of anchor sizes:", len(anchor_generator.sizes))  # should be 5

    model = LineNet(backbone, num_classes=num_classes, rpn_anchor_generator=anchor_generator,
                    box_roi_pool=roi_pooler,
                    **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if weights == LineNet_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model


def linenet_newresnet50fpn(
    *,
    weights: Optional[LineNet_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet18_Weights] = ResNet18_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
    # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
    # weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    if weights_backbone is not None:
        print('resnet50 backbone weights are not None')

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50fpn()
    featmap_names = ['0', '1', '2', '3', 'pool']
    # print(f'featmap_names:{featmap_names}')
    roi_pooler = MultiScaleRoIAlign(
        featmap_names=featmap_names,
        output_size=7,
        sampling_ratio=2
    )
    num_features = len(featmap_names)
    anchor_sizes = tuple((int(16 * 2 ** i),) for i in range(num_features))  # one size per level
    # print(f'anchor_sizes:{anchor_sizes}')
    aspect_ratios = ((0.5, 1.0, 2.0),) * num_features
    # print(f'aspect_ratios:{aspect_ratios}')
    anchor_generator = AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)
    # anchors = anchor_generator.generate_anchors()
    # print("Number of anchor sizes:", len(anchor_generator.sizes))  # should be 5

    model = LineNet(backbone, num_classes=num_classes, rpn_anchor_generator=anchor_generator,
                    box_roi_pool=roi_pooler,
                    **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if weights == LineNet_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model


# @register_model()
# @handle_legacy_interface(
#     weights=("pretrained", LineNet_ResNet50_FPN_Weights.COCO_V1),
#     weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
# )
def linenet_resnet18_fpn(
    *,
    weights: Optional[LineNet_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet18_Weights] = ResNet18_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
    # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
    # weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    if weights_backbone is not None:
        print('resnet18 backbone weights are not None')

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet18(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
    model = LineNet(backbone, num_classes=num_classes, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if weights == LineNet_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model


def linenet_resnet50_fpn(
    *,
    weights: Optional[LineNet_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
    """
    Faster R-CNN model with a ResNet-50-FPN backbone from the `Faster R-CNN: Towards Real-Time Object
    Detection with Region Proposal Networks <https://arxiv.org/abs/1506.01497>`__
    paper.

    .. betastatus:: detection module

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and a targets (list of dictionary),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses for both the RPN and the R-CNN.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows, where ``N`` is the number of detections:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the predicted labels for each detection
        - scores (``Tensor[N]``): the scores of each detection

    For more details on the output, you may refer to :ref:`instance_seg_output`.

    Faster R-CNN is exportable to ONNX for a fixed batch size with inputs images of fixed size.

    Example::

        >>> model = linenet_resnet50_fpn(weights=LineNet_ResNet50_FPN_Weights.DEFAULT)
        >>> # For training
        >>> images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4)
        >>> boxes[:, :, 2:4] = boxes[:, :, 0:2] + boxes[:, :, 2:4]
        >>> labels = torch.randint(1, 91, (4, 11))
        >>> images = list(image for image in images)
        >>> targets = []
        >>> for i in range(len(images)):
        >>>     d = {}
        >>>     d['boxes'] = boxes[i]
        >>>     d['labels'] = labels[i]
        >>>     targets.append(d)
        >>> output = model(images, targets)
        >>> # For inference
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
        >>>
        >>> # optionally, if you want to export the model to ONNX:
        >>> torch.onnx.export(model, x, "faster_rcnn.onnx", opset_version = 11)

    Args:
        weights (:class:`LineNet_ResNet50_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`LineNet_ResNet50_FPN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
            for more details about this class.

    .. autoclass:: LineNet_ResNet50_FPN_Weights
        :members:
    """
    weights = LineNet_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    if weights_backbone is not None:
        print('resnet50 backbone weights are not None')

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
    model = LineNet(backbone, num_classes=num_classes, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if weights == LineNet_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model


# @register_model()
# @handle_legacy_interface(
#     weights=("pretrained", LineNet_ResNet50_FPN_V2_Weights.COCO_V1),
#     weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
# )
def linenet_resnet50_fpn_v2(
    *,
    weights: Optional[LineNet_ResNet50_FPN_V2_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = None,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
    """
    Constructs an improved Faster R-CNN model with a ResNet-50-FPN backbone from the `Benchmarking Detection
    Transfer Learning with Vision Transformers <https://arxiv.org/abs/2111.11429>`__ paper.

    .. betastatus:: detection module

    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
    :func:`linenet_resnet50_fpn` for more
    details.

    Args:
        weights (:class:`LineNet_ResNet50_FPN_V2_Weights`, optional): The
            pretrained weights to use. See
            :class:`LineNet_ResNet50_FPN_V2_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
            for more details about this class.

    .. autoclass:: LineNet_ResNet50_FPN_V2_Weights
        :members:
    """
    weights = LineNet_ResNet50_FPN_V2_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)

    backbone = resnet50(weights=weights_backbone, progress=progress)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
    rpn_anchor_generator = _default_anchorgen()
    rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
    box_head = LineNetConvFCHead(
        (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
    )
    model = LineNet(
        backbone,
        num_classes=num_classes,
        rpn_anchor_generator=rpn_anchor_generator,
        rpn_head=rpn_head,
        box_head=box_head,
        **kwargs,
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


def _linenet_mobilenet_v3_large_fpn(
    *,
    weights: Optional[Union[LineNet_MobileNet_V3_Large_FPN_Weights, LineNet_MobileNet_V3_Large_320_FPN_Weights]],
    progress: bool,
    num_classes: Optional[int],
    weights_backbone: Optional[MobileNet_V3_Large_Weights],
    trainable_backbone_layers: Optional[int],
    **kwargs: Any,
) -> LineNet:
    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
    # The MobileNetV3 FPN extractor exposes three feature levels, so the full
    # set of anchor sizes is replicated once per level.
    anchor_sizes = ((32, 64, 128, 256, 512),) * 3
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)

    model = LineNet(
        backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


# @register_model()
# @handle_legacy_interface(
#     weights=("pretrained", LineNet_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
#     weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
# )
def linenet_mobilenet_v3_large_320_fpn(
    *,
    weights: Optional[LineNet_MobileNet_V3_Large_320_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
    """
    Low resolution Faster R-CNN model with a MobileNetV3-Large backbone tuned for mobile use cases.

    .. betastatus:: detection module

    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
    :func:`linenet_resnet50_fpn` for more
    details.

    Example::

        >>> model = linenet_mobilenet_v3_large_320_fpn(weights=LineNet_MobileNet_V3_Large_320_FPN_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        weights (:class:`LineNet_MobileNet_V3_Large_320_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`LineNet_MobileNet_V3_Large_320_FPN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
            for more details about this class.

    .. autoclass:: LineNet_MobileNet_V3_Large_320_FPN_Weights
        :members:
    """
    weights = LineNet_MobileNet_V3_Large_320_FPN_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

    defaults = {
        "min_size": 320,
        "max_size": 640,
        "rpn_pre_nms_top_n_test": 150,
        "rpn_post_nms_top_n_test": 150,
        "rpn_score_thresh": 0.05,
    }

    kwargs = {**defaults, **kwargs}
    return _linenet_mobilenet_v3_large_fpn(
        weights=weights,
        progress=progress,
        num_classes=num_classes,
        weights_backbone=weights_backbone,
        trainable_backbone_layers=trainable_backbone_layers,
        **kwargs,
    )


# @register_model()
# @handle_legacy_interface(
#     weights=("pretrained", LineNet_MobileNet_V3_Large_FPN_Weights.COCO_V1),
#     weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
# )
def linenet_mobilenet_v3_large_fpn(
    *,
    weights: Optional[LineNet_MobileNet_V3_Large_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> LineNet:
    """
    Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.

    .. betastatus:: detection module

    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
    :func:`linenet_resnet50_fpn` for more
    details.

    Example::

        >>> model = linenet_mobilenet_v3_large_fpn(weights=LineNet_MobileNet_V3_Large_FPN_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        weights (:class:`LineNet_MobileNet_V3_Large_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`LineNet_MobileNet_V3_Large_FPN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
            pretrained weights for the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
            final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are
            trainable. If ``None`` is passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
            for more details about this class.

    .. autoclass:: LineNet_MobileNet_V3_Large_FPN_Weights
        :members:
    """
    weights = LineNet_MobileNet_V3_Large_FPN_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

    defaults = {
        "rpn_score_thresh": 0.05,
    }

    kwargs = {**defaults, **kwargs}
    return _linenet_mobilenet_v3_large_fpn(
        weights=weights,
        progress=progress,
        num_classes=num_classes,
        weights_backbone=weights_backbone,
        trainable_backbone_layers=trainable_backbone_layers,
        **kwargs,
    )