Browse Source

添加efficientnet 作为backbone

RenLiqiang 7 tháng trước cách đây
mục cha
commit
37f391a4f9

+ 3 - 2
models/line_detect/111.py

@@ -227,12 +227,13 @@ class Trainer(BaseTrainer):
 
 import torch
 
-from models.line_detect.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn
+from models.line_detect.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn, get_line_net_efficientnetv2
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 if __name__ == '__main__':
     # model = LineNet('line_net.yaml')
-    model = linenet_resnet50_fpn().to(device)
+    # model = linenet_resnet50_fpn().to(device)
+    model=get_line_net_efficientnetv2(2, pretrained_backbone=True).to(device)
     # model=linenet_resnet18_fpn()
     trainer = Trainer()
     trainer.train_cfg(model,cfg='./train.yaml')

+ 81 - 2
models/line_detect/line_net.py

@@ -4,7 +4,9 @@ import torch
 from torch import nn
 from torchvision.ops import MultiScaleRoIAlign
 
-from libs.vision_libs.models import MobileNet_V3_Large_Weights, mobilenet_v3_large
+from libs.vision_libs import ops
+from libs.vision_libs.models import MobileNet_V3_Large_Weights, mobilenet_v3_large, EfficientNet_V2_S_Weights, \
+    efficientnet_v2_s, detection
 from libs.vision_libs.models.detection.anchor_utils import AnchorGenerator
 from libs.vision_libs.models.detection.rpn import RPNHead, RegionProposalNetwork
 from libs.vision_libs.models.detection.ssdlite import _mobilenet_extractor
@@ -18,7 +20,8 @@ from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_
 from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
 from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights, ResNet18_Weights, resnet18
 from libs.vision_libs.models.detection._utils import overwrite_eps
-from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
+from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers, \
+    BackboneWithFPN
 from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
 
 from .roi_heads import RoIHeads
@@ -334,6 +337,80 @@ _COMMON_META = {
 }
 
 
+def create_efficientnetv2_backbone(name='efficientnet_v2_s', pretrained=True):
+    """Build an EfficientNetV2 feature extractor wrapped in an FPN.
+
+    Args:
+        name: Backbone variant to build. Only ``'efficientnet_v2_s'`` is
+            supported.
+        pretrained: When true, the backbone is created with
+            ``EfficientNet_V2_S_Weights.IMAGENET1K_V1``; otherwise with
+            ``weights=None``.
+
+    Returns:
+        A ``BackboneWithFPN`` that taps entries 2-5 of the model's
+        ``features`` (exposed under the names ``"0"``-``"3"``) and uses
+        256 FPN output channels.
+
+    Raises:
+        ValueError: If ``name`` is not a supported variant, or if the channel
+            width of a tapped entry cannot be determined.
+    """
+    # Reject unknown variants up front; otherwise ``backbone`` would be
+    # unbound further down and fail with a confusing UnboundLocalError.
+    if name != 'efficientnet_v2_s':
+        raise ValueError(
+            f"Unsupported EfficientNetV2 backbone {name!r}; expected 'efficientnet_v2_s'"
+        )
+
+    # Load the EfficientNetV2 model and keep only its feature extractor.
+    weights = EfficientNet_V2_S_Weights.IMAGENET1K_V1 if pretrained else None
+    backbone = efficientnet_v2_s(weights=weights).features
+
+    # Index inside ``features`` -> name of the returned feature map.
+    return_layers = {"2": "0", "3": "1", "4": "2", "5": "3"}
+
+    # Collect the channel width of every tapped entry, in the same order as
+    # ``return_layers`` so the two cannot drift apart.
+    in_channels_list = []
+    for layer_name in return_layers:
+        module = backbone[int(layer_name)]
+        if hasattr(module, 'out_channels'):
+            in_channels_list.append(module.out_channels)
+        elif hasattr(module[-1], 'out_channels'):
+            # The container itself has no out_channels; use its last sub-module.
+            in_channels_list.append(module[-1].out_channels)
+        else:
+            raise ValueError(f"Cannot determine out_channels for layer {layer_name}")
+
+    # Wrap the truncated backbone with a feature pyramid network.
+    backbone_with_fpn = BackboneWithFPN(
+        backbone=backbone,
+        return_layers=return_layers,
+        in_channels_list=in_channels_list,
+        out_channels=256
+    )
+
+    return backbone_with_fpn
+
+
+def get_line_net_efficientnetv2(num_classes, pretrained_backbone=True):
+    """Assemble a ``LineNet`` on top of the EfficientNetV2 + FPN backbone.
+
+    Args:
+        num_classes: Forwarded to ``LineNet`` as ``num_classes``.
+        pretrained_backbone: Forwarded to ``create_efficientnetv2_backbone``
+            as ``pretrained``.
+
+    Returns:
+        The constructed ``LineNet`` model.
+    """
+    # Build the EfficientNetV2 backbone first.
+    backbone = create_efficientnetv2_backbone(pretrained=pretrained_backbone)
+
+    # NOTE(review): the feature-map names are hard-coded here rather than
+    # probed from a forward pass of the backbone -- confirm they match the
+    # keys the backbone actually returns.
+    featmap_names = ['0', '1', '2', '3', 'pool']
+    level_count = len(featmap_names)
+
+    # One anchor size per level, doubling from 16: (16,), (32,), ..., (256,).
+    sizes_per_level = tuple((16 << level,) for level in range(level_count))
+    # Every level shares the same three aspect ratios.
+    ratios_per_level = tuple((0.5, 1.0, 2.0) for _ in range(level_count))
+
+    rpn_anchors = AnchorGenerator(
+        sizes=sizes_per_level,
+        aspect_ratios=ratios_per_level
+    )
+
+    # RoI pooling over all named feature maps.
+    box_pooler = MultiScaleRoIAlign(
+        featmap_names=featmap_names,
+        output_size=7,
+        sampling_ratio=2
+    )
+
+    # Put the pieces together.
+    return LineNet(
+        backbone=backbone,
+        num_classes=num_classes,
+        rpn_anchor_generator=rpn_anchors,
+        box_roi_pool=box_pooler
+    )
+
+
 class LineNet_ResNet50_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
@@ -575,6 +652,8 @@ def linenet_resnet50_fpn(
     return model
 
 
+
+
 @register_model()
 @handle_legacy_interface(
     weights=("pretrained", LineNet_ResNet50_FPN_V2_Weights.COCO_V1),

+ 3 - 3
models/line_detect/line_predictor.py

@@ -147,7 +147,7 @@ class LineRCNNPredictor(nn.Module):
         #     print(f'out:{out.shape}')
         # outputs=merge_features(outputs,100)
         batch, channel, row, col = inputs.shape
-        print(f'outputs:{inputs.shape}')
+        # print(f'outputs:{inputs.shape}')
         # print(f'batch:{batch}, channel:{channel}, row:{row}, col:{col}')
 
         if targets is not None:
@@ -256,7 +256,7 @@ class LineRCNNPredictor(nn.Module):
                 .permute(1, 0, 2)
             )
             xp = self.pooling(xp)
-            print(f'xp forward.shape:{xp.shape}')
+            # print(f'xp forward.shape:{xp.shape}')
             xs.append(xp)
             idx.append(idx[-1] + xp.shape[0])
             # print(f'idx__:{idx}')
@@ -264,7 +264,7 @@ class LineRCNNPredictor(nn.Module):
         x, y = torch.cat(xs), torch.cat(ys)
         f = torch.cat(fs)
         x = x.reshape(-1, self.n_pts1 * self.dim_loi)
-        print(f' x reshape:{x.shape}')
+        # print(f' x reshape:{x.shape}')
 
         # print("Weight dtype:", self.fc2.weight.dtype)
         x = torch.cat([x, f], 1)

+ 1 - 1
models/line_detect/train.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: logs/
-  datadir: I:/datasets/4_23jiagonggongjian
+  datadir: I:/datasets/0322_suanzaisheng
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
   resume_from: