Explorar o código

添加backbone_factory

RenLiqiang hai 3 meses
pai
achega
a60eb9a2ac
Modificáronse 2 ficheiros con 21 adicións e 80 borrados
  1. 15 0
      models/base/backbone_factory.py
  2. 6 80
      models/line_detect/line_net.py

+ 15 - 0
models/base/backbone_factory.py

@@ -0,0 +1,15 @@
+from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
+from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
+from libs.vision_libs.models.detection._utils import overwrite_eps
+from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
+from libs.vision_libs.ops import misc as misc_nn_ops
+from torch import nn
+
+
+def get_resnet50_fpn():
+    is_trained = False
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, None, 5, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+    backbone = resnet50(weights=None, progress=True, norm_layer=norm_layer)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+    return backbone

+ 6 - 80
models/line_detect/line_net.py

@@ -22,6 +22,7 @@ from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead
 
 from .roi_heads import RoIHeads
 from .trainer import Trainer
+from ..base import backbone_factory
 from ..base.base_detection_net import BaseDetectionNet
 import torch.nn.functional as F
 
@@ -57,87 +58,12 @@ class LineNet(BaseDetectionNet):
         num_classes = cfg['num_classes']
 
         if backbone == 'resnet50_fpn':
-            is_trained = False
-            trainable_backbone_layers = _validate_trainable_layers(is_trained, None, 5, 3)
-            norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
-            backbone = resnet50(weights=None, progress=True, norm_layer=norm_layer)
-            backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+            backbone=backbone_factory.get_resnet50_fpn()
             print(f'out_chanenels:{backbone.out_channels}')
-            self.construct(backbone=backbone,num_classes=num_classes,**kwargs)
-            # out_channels = backbone.out_channels
-            #
-            # min_size = 512,
-            # max_size = 1333,
-            # rpn_pre_nms_top_n_train = 2000,
-            # rpn_pre_nms_top_n_test = 1000,
-            # rpn_post_nms_top_n_train = 2000,
-            # rpn_post_nms_top_n_test = 1000,
-            # rpn_nms_thresh = 0.7,
-            # rpn_fg_iou_thresh = 0.7,
-            # rpn_bg_iou_thresh = 0.3,
-            # rpn_batch_size_per_image = 256,
-            # rpn_positive_fraction = 0.5,
-            # rpn_score_thresh = 0.0,
-            # box_score_thresh = 0.05,
-            # box_nms_thresh = 0.5,
-            # box_detections_per_img = 100,
-            # box_fg_iou_thresh = 0.5,
-            # box_bg_iou_thresh = 0.5,
-            # box_batch_size_per_image = 512,
-            # box_positive_fraction = 0.25,
-            # bbox_reg_weights = None,
-            #
-            # line_head = LineRCNNHeads(out_channels, 5)
-            # line_predictor = LineRCNNPredictor(cfg)
-            # rpn_anchor_generator = _default_anchorgen()
-            # rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])
-            # rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
-            # rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)
-            #
-            # rpn = RegionProposalNetwork(
-            #     rpn_anchor_generator,
-            #     rpn_head,
-            #     rpn_fg_iou_thresh,
-            #     rpn_bg_iou_thresh,
-            #     rpn_batch_size_per_image,
-            #     rpn_positive_fraction,
-            #     rpn_pre_nms_top_n,
-            #     rpn_post_nms_top_n,
-            #     rpn_nms_thresh,
-            #     score_thresh=rpn_score_thresh,
-            # )
-            #
-            # box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)
-            #
-            # resolution = box_roi_pool.output_size[0]
-            # representation_size = 1024
-            # box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
-            # representation_size = 1024
-            # box_predictor = BoxPredictor(representation_size, num_classes)
-            #
-            # roi_heads = RoIHeads(
-            #     # Box
-            #     box_roi_pool,
-            #     box_head,
-            #     box_predictor,
-            #     line_head,
-            #     line_predictor,
-            #     box_fg_iou_thresh,
-            #     box_bg_iou_thresh,
-            #     box_batch_size_per_image,
-            #     box_positive_fraction,
-            #     bbox_reg_weights,
-            #     box_score_thresh,
-            #     box_nms_thresh,
-            #     box_detections_per_img,
-            # )
-            # image_mean = [0.485, 0.456, 0.406]
-            # image_std = [0.229, 0.224, 0.225]
-            # transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)
-            # super().__init__(backbone, rpn, roi_heads, transform)
-            # self.roi_heads = roi_heads
-
-    def construct(
+            self.__construct__(backbone=backbone, num_classes=num_classes, **kwargs)
+            
+
+    def __construct__(
             self,
             backbone,
             num_classes=None,