Explorar o código

添加选择efficientnet s m l 权重功能

RenLiqiang hai 7 meses
pai
achega
9b0399e15c
Modificáronse 1 ficheiros con 9 adicións e 2 borrados
  1. 9 2
      models/line_detect/line_net.py

+ 9 - 2
models/line_detect/line_net.py

@@ -6,7 +6,8 @@ from torchvision.ops import MultiScaleRoIAlign
 
 from libs.vision_libs import ops
 from libs.vision_libs.models import MobileNet_V3_Large_Weights, mobilenet_v3_large, EfficientNet_V2_S_Weights, \
-    efficientnet_v2_s, detection
+    efficientnet_v2_s, detection, EfficientNet_V2_L_Weights, efficientnet_v2_l, EfficientNet_V2_M_Weights, \
+    efficientnet_v2_m
 from libs.vision_libs.models.detection.anchor_utils import AnchorGenerator
 from libs.vision_libs.models.detection.rpn import RPNHead, RegionProposalNetwork
 from libs.vision_libs.models.detection.ssdlite import _mobilenet_extractor
@@ -337,11 +338,17 @@ _COMMON_META = {
 }
 
 
-def create_efficientnetv2_backbone(name='efficientnet_v2_s', pretrained=True):
+def create_efficientnetv2_backbone(name='efficientnet_v2_l', pretrained=True):
     # 加载EfficientNetV2模型
     if name == 'efficientnet_v2_s':
         weights = EfficientNet_V2_S_Weights.IMAGENET1K_V1 if pretrained else None
         backbone = efficientnet_v2_s(weights=weights).features
+    if name == 'efficientnet_v2_m':
+        weights = EfficientNet_V2_M_Weights.IMAGENET1K_V1 if pretrained else None
+        backbone = efficientnet_v2_m(weights=weights).features
+    if name == 'efficientnet_v2_l':
+        weights = EfficientNet_V2_L_Weights.IMAGENET1K_V1 if pretrained else None
+        backbone = efficientnet_v2_l(weights=weights).features
 
     # 定义返回的层索引和名称
     return_layers = {"2": "0", "3": "1", "4": "2", "5": "3"}