
Add resnet101

RenLiqiang, 6 months ago
Commit f9a885b322
2 changed files with 94 additions and 24 deletions
  1. models/line_detect/train_demo.py (+6, -11)
  2. models/line_net/line_net.py (+88, -13)

models/line_detect/train_demo.py (+6, -11)

@@ -1,21 +1,16 @@
 import torch

-from models.line_detect.line_detect import linedetect_newresnet18fpn, linedetect_resnet50_fpn, linedetect_resnet18_fpn
-
-
-from models.line_net.trainer import Trainer
+from models.line_detect.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn, linenet_resnet101_fpn_v2
+from models.line_detect.trainer import Trainer

 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 if __name__ == '__main__':

     # model = LineNet('line_net.yaml')
     # model=linenet_resnet50_fpn()
-    # model = linedetect_resnet50_fpn()
     # model=get_line_net_convnext_fpn(num_classes=2).to(device)
-    # model=linenet_newresnet50fpn()
-    # model = lineDetect_resnet18_fpn()
-
-    # model=linedetect_resnet18_fpn()
-    model=linedetect_newresnet18fpn(num_points=3)
-
+    # model=linenet_resnet18_fpn()
+    model=linenet_resnet101_fpn_v2()
+    # trainer = Trainer()
+    # trainer.train_cfg(model,cfg='./train.yaml')
     model.start_train(cfg='train.yaml')
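
As a quick sanity check of the new entry point, here is a minimal sketch assuming linenet_resnet101_fpn_v2 follows the torchvision detection API (an eval-mode forward takes a list of 3-channel image tensors); the 512x512 input size and num_classes=2 are illustrative values, not part of the commit:

import torch

from models.line_detect.line_net import linenet_resnet101_fpn_v2

# Build the new ResNet-101 FPN v2 variant; num_classes=2 is an example value only.
model = linenet_resnet101_fpn_v2(num_classes=2)
model.eval()

# Torchvision-style detectors take a list of CHW tensors in eval mode.
with torch.no_grad():
    outputs = model([torch.rand(3, 512, 512)])

# Each per-image output dict should expose the usual detection fields.
print(outputs[0].keys())  # expected: boxes, labels, scores, ...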

models/line_net/line_net.py (+88, -13)

@@ -19,7 +19,7 @@ from .line_predictor import LineRCNNPredictor
 from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
 from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES, _COCO_CATEGORIES
 from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
-from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights, ResNet18_Weights, resnet18
+from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights, ResNet18_Weights, resnet18, resnet101
 from libs.vision_libs.models.detection._utils import overwrite_eps
 from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers, \
     BackboneWithFPN
@@ -135,14 +135,14 @@ class LineNet(BaseDetectionNet):
             )

         # Modify the first conv layer, changing in_channels from 3 to 4
-        backbone.body.conv1 = nn.Conv2d(
-            in_channels=4,
-            out_channels=64,
-            kernel_size=7,
-            stride=2,
-            padding=3,
-            bias=False
-        )
+        # backbone.body.conv1 = nn.Conv2d(
+        #     in_channels=4,
+        #     out_channels=64,
+        #     kernel_size=7,
+        #     stride=2,
+        #     padding=3,
+        #     bias=False
+        # )
         if num_classes is not None:
             if box_predictor is not None:
                 raise ValueError("num_classes should be None when box_predictor is specified")
@@ -213,12 +213,12 @@ class LineNet(BaseDetectionNet):
         )

         if image_mean is None:
-            # image_mean = [0.485, 0.456, 0.406]
-            image_mean = [0.485, 0.456, 0.406, 0.5]  # assume the newly added channel has mean 0.5
+            image_mean = [0.485, 0.456, 0.406]
+            # image_mean = [0.485, 0.456, 0.406, 0.5]  # assume the newly added channel has mean 0.5

         if image_std is None:
-            # image_std = [0.229, 0.224, 0.225]
-            image_std = [0.229, 0.224, 0.225, 0.2]  # also append a value for the std
+            image_std = [0.229, 0.224, 0.225]
+            # image_std = [0.229, 0.224, 0.225, 0.2]  # also append a value for the std
         transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

         super().__init__(backbone, rpn, roi_heads, transform)
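
The two hunks above revert the earlier 4-channel input experiment: the stem convolution and the normalization statistics must agree on the channel count, so both were switched back to 3 channels together. If that experiment were re-enabled later, a hedged sketch of the pieces involved (the 0.5 / 0.2 values for the extra channel simply mirror the commented-out lines and are assumptions, not measured statistics):

import torch.nn as nn

# Normalization stats for a hypothetical 4th channel; values copied from the
# commented-out code above, not computed from data.
FOUR_CHANNEL_MEAN = [0.485, 0.456, 0.406, 0.5]
FOUR_CHANNEL_STD = [0.229, 0.224, 0.225, 0.2]

def enable_four_channel_stem(backbone):
    # Swap the ResNet stem conv so the FPN backbone accepts 4-channel input.
    # The replacement conv is randomly initialized, so any pretrained RGB stem
    # weights are discarded.
    backbone.body.conv1 = nn.Conv2d(
        in_channels=4, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False
    )
    return backbone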
@@ -885,6 +885,81 @@ def linenet_resnet50_fpn_v2(
     return model


+def linenet_resnet101_fpn_v2(
+        *,
+        weights: Optional[LineNet_ResNet50_FPN_V2_Weights] = None,
+        progress: bool = True,
+        num_classes: Optional[int] = None,
+        weights_backbone: Optional[ResNet50_Weights] = None,
+        trainable_backbone_layers: Optional[int] = None,
+        **kwargs: Any,
+) -> LineNet:
+    """
+    Constructs an improved Faster R-CNN model with a ResNet-101-FPN backbone from `Benchmarking Detection
+    Transfer Learning with Vision Transformers <https://arxiv.org/abs/2111.11429>`__ paper.
+
+    .. betastatus:: detection module
+
+    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
+    :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more
+    details.
+
+    Args:
+        weights (:class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        num_classes (int, optional): number of output classes of the model (including the background)
+        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
+            pretrained weights for the backbone.
+        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from
+            final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are
+            trainable. If ``None`` is passed (the default) this value is set to 3.
+        **kwargs: parameters passed to the ``torchvision.models.detection.faster_rcnn.FasterRCNN``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights
+        :members:
+    """
+    weights = LineNet_ResNet50_FPN_V2_Weights.verify(weights)
+    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+
+    backbone = resnet101(weights=weights_backbone, progress=progress)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
+    rpn_anchor_generator = _default_anchorgen()
+    rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
+    box_head = LineNetConvFCHead(
+        (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
+    )
+    model = LineNet(
+        backbone,
+        num_classes=num_classes,
+        rpn_anchor_generator=rpn_anchor_generator,
+        rpn_head=rpn_head,
+        box_head=box_head,
+        **kwargs,
+    )
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model
+
+
 def _linenet_mobilenet_v3_large_fpn(
         *,
         weights: Optional[Union[LineNet_MobileNet_V3_Large_FPN_Weights, LineNet_MobileNet_V3_Large_320_FPN_Weights]],
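
For completeness, a sketch of how the documented arguments of the new constructor might be combined, assuming they behave as the docstring above describes; the import path follows train_demo.py and the values are illustrative:

from models.line_detect.line_net import linenet_resnet101_fpn_v2

# weights=None leaves the detector randomly initialized, matching the call in train_demo.py.
# trainable_backbone_layers=3 is the documented default; 5 would unfreeze the whole backbone.
model = linenet_resnet101_fpn_v2(
    weights=None,
    num_classes=91,               # includes the background class (the function's own default)
    trainable_backbone_layers=3,
)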