瀏覽代碼

添加resnet18_fpn backbone

RenLiqiang 3 月之前
父節點
當前提交
962f39456c
共有 4 個文件被更改,包括 32 次插入5 次删除
  1. 20 1
      models/base/backbone_factory.py
  2. 7 1
      models/line_detect/line_net.py
  3. 3 1
      models/line_detect/line_net.yaml
  4. 2 2
      models/line_detect/trainer.py

+ 20 - 1
models/base/backbone_factory.py

@@ -1,5 +1,7 @@
+from libs.vision_libs.models import mobilenet_v3_large
 from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
-from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
+from libs.vision_libs.models.detection.ssdlite import _mobilenet_extractor
+from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights, resnet18
 from libs.vision_libs.models.detection._utils import overwrite_eps
 from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
 from libs.vision_libs.ops import misc as misc_nn_ops
@@ -12,4 +14,21 @@ def get_resnet50_fpn():
     norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
     backbone = resnet50(weights=None, progress=True, norm_layer=norm_layer)
     backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+    return backbone
+
+def get_resnet18_fpn():
+    is_trained = False
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, None, 5, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+    backbone = resnet18(weights=None, progress=True, norm_layer=norm_layer)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+    return backbone
+
+def get_mobilenet_v3_large_fpn():
+    is_trained =False
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, None, 6, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = mobilenet_v3_large(weights=None, progress=True, norm_layer=norm_layer)
+    backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
     return backbone

+ 7 - 1
models/line_detect/line_net.py

@@ -55,12 +55,18 @@ class LineNet(BaseDetectionNet):
         cfg = read_yaml(cfg)
         self.cfg=cfg
         backbone = cfg['backbone']
+        print(f'LineNet Backbone:{backbone}')
         num_classes = cfg['num_classes']
 
         if backbone == 'resnet50_fpn':
             backbone=backbone_factory.get_resnet50_fpn()
             print(f'out_chanenels:{backbone.out_channels}')
-            self.__construct__(backbone=backbone, num_classes=num_classes, **kwargs)
+        elif backbone== 'mobilenet_v3_large_fpn':
+            backbone=backbone_factory.get_mobilenet_v3_large_fpn()
+        elif backbone=='resnet18_fpn':
+            backbone=backbone_factory.get_resnet18_fpn()
+
+        self.__construct__(backbone=backbone, num_classes=num_classes, **kwargs)
 
 
     def __construct__(

+ 3 - 1
models/line_detect/line_net.yaml

@@ -16,7 +16,9 @@ lneg: 1
 boxes: 1.0
 
 # backbone parameters
-backbone: resnet50_fpn
+#backbone: resnet50_fpn
+backbone: resnet18_fpn
+#backbone: mobilenet_v3_large_fpn
 #  backbone: unet
 depth: 4
 num_stacks: 1

+ 2 - 2
models/line_detect/trainer.py

@@ -95,7 +95,7 @@ class Trainer(BaseTrainer):
         dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
         train_sampler = torch.utils.data.RandomSampler(dataset_train)
         # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
-        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
+        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=2, drop_last=True)
         train_collate_fn = utils.collate_fn_wirepoint
         data_loader_train = torch.utils.data.DataLoader(
             dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
@@ -104,7 +104,7 @@ class Trainer(BaseTrainer):
         dataset_val = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='val')
         val_sampler = torch.utils.data.RandomSampler(dataset_val)
         # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
-        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=1, drop_last=True)
+        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=2, drop_last=True)
         val_collate_fn = utils.collate_fn_wirepoint
         data_loader_val = torch.utils.data.DataLoader(
             dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn