Forráskód Böngészése

add linenet_resnet18_fpn

RenLiqiang 3 hónapja
szülő
commit
b9183cd92c
2 módosított fájl, 41 hozzáadás és 3 törlés
  1. 38 1
      models/line_detect/line_net.py
  2. 3 2
      models/line_detect/test_train.py

+ 38 - 1
models/line_detect/line_net.py

@@ -15,7 +15,7 @@ from .line_predictor import LineRCNNPredictor
 from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
 from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES, _COCO_CATEGORIES
 from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
-from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights
+from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights, ResNet18_Weights, resnet18
 from libs.vision_libs.models.detection._utils import overwrite_eps
 from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
 from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
@@ -396,6 +396,43 @@ class LineNet_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
     weights=("pretrained", LineNet_ResNet50_FPN_Weights.COCO_V1),
     weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
 )
+
+def linenet_resnet18_fpn(
+        *,
+        weights: Optional[LineNet_ResNet50_FPN_Weights] = None,
+        progress: bool = True,
+        num_classes: Optional[int] = None,
+        weights_backbone: Optional[ResNet18_Weights] = ResNet18_Weights.IMAGENET1K_V1,
+        trainable_backbone_layers: Optional[int] = None,
+        **kwargs: Any,
+) -> LineNet:
+
+    # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
+    # weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+    if weights_backbone is not None:
+        print(f'resnet50 weights is not None')
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone = resnet18(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+    model = LineNet(backbone, num_classes=num_classes, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+        if weights == LineNet_ResNet50_FPN_Weights.COCO_V1:
+            overwrite_eps(model, 0.0)
+
+    return model
+
 def linenet_resnet50_fpn(
         *,
         weights: Optional[LineNet_ResNet50_FPN_Weights] = None,

+ 3 - 2
models/line_detect/test_train.py

@@ -1,13 +1,14 @@
 import torch
 
-from models.line_detect.line_net import linenet_resnet50_fpn, LineNet
+from models.line_detect.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn
 from models.line_detect.trainer import Trainer
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 if __name__ == '__main__':
 
     # model = LineNet('line_net.yaml')
-    model=linenet_resnet50_fpn()
+    # model=linenet_resnet50_fpn()
+    model=linenet_resnet18_fpn()
     # trainer = Trainer()
     # trainer.train_cfg(model,cfg='./train.yaml')
     model.train_by_cfg(cfg='train.yaml')