瀏覽代碼

add efficientnet

admin 4 月之前
父節點
當前提交
ec46c89c5d
共有 2 個文件被更改,包括 46 次插入4 次删除
  1. 43 2
      models/line_detect/line_detect.py
  2. 3 2
      models/line_detect/train_demo.py

+ 43 - 2
models/line_detect/line_detect.py

@@ -18,7 +18,7 @@ from .loi_heads import RoIHeads
 
 from .trainer import Trainer
 from ..base.backbone_factory import get_anchor_generator, MaxVitBackbone, \
-    get_swin_transformer_fpn
+    get_swin_transformer_fpn, get_efficientnetv2_fpn
 # from ..base.backbone_factory import get_convnext_fpn, get_anchor_generator
 from ..base.base_detection_net import BaseDetectionNet
 import torch.nn.functional as F
@@ -500,6 +500,47 @@ def linedetect_newresnet152fpn(
 
     return model
 
+def linedetect_efficientnet(
+        *,
+        num_classes: Optional[int] = None,
+        num_points:Optional[int] = None,
+        name: Optional[str] = 'efficientnet_v2_l',
+        **kwargs: Any,
+) -> LineDetect:
+    # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
+    # weights_backbone = ResNet50_Weights.verify(weights_backbone)
+    if num_classes is None:
+        num_classes = 5
+    if num_points is None:
+        num_points = 3
+
+
+    size=224*3
+    featmap_names = ['0', '1', '2', '3', '4', 'pool']
+
+    roi_pooler = MultiScaleRoIAlign(
+        featmap_names=featmap_names,
+        output_size=7,
+        sampling_ratio=2
+    )
+    backbone_with_fpn=get_efficientnetv2_fpn(name=name)
+
+    test_input = torch.randn(1, 3,size,size)
+
+    model = LineDetect(
+        backbone=backbone_with_fpn,
+        min_size=size,
+        max_size=size,
+        num_classes=num_classes,  # COCO 数据集有 91 类
+        rpn_anchor_generator=get_anchor_generator(backbone_with_fpn, test_input=test_input),
+        box_roi_pool=roi_pooler,
+        detect_line=False,
+        detect_point=False,
+        detect_arc=False,
+        detect_circle=True,
+    )
+    return model
+
 def linedetect_maxvitfpn(
         *,
         num_classes: Optional[int] = None,
@@ -514,7 +555,7 @@ def linedetect_maxvitfpn(
         num_points = 3
 
 
-    size=224*2
+    size=224*3
 
 
     maxvit = MaxVitBackbone(input_size=(size,size))

+ 3 - 2
models/line_detect/train_demo.py

@@ -2,7 +2,7 @@ import torch
 
 from models.line_detect.line_detect import linedetect_newresnet18fpn, linedetect_resnet50_fpn, linedetect_resnet18_fpn, \
     linedetect_newresnet50fpn, linedetect_maxvitfpn, linedetect_high_maxvitfpn, linedetect_swin_transformer_fpn, \
-    linedetect_newresnet101fpn, linedetect_newresnet152fpn
+    linedetect_newresnet101fpn, linedetect_newresnet152fpn, linedetect_efficientnet
 
 from models.line_net.trainer import Trainer
 
@@ -18,11 +18,12 @@ if __name__ == '__main__':
 
     # model=linedetect_resnet18_fpn()
     # model=linedetect_newresnet18fpn(num_points=4)
-    model=linedetect_newresnet50fpn(num_points=4)
+    # model=linedetect_newresnet50fpn(num_points=4)
     # model = linedetect_newresnet101fpn(num_points=3)
     # model = linedetect_newresnet152fpn(num_points=3)
     # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250711_114046/weights/best_val.pth')
     # model=linedetect_maxvitfpn()
+    model=linedetect_efficientnet(name='efficientnet_v2_l')
     # model=linedetect_high_maxvitfpn()
 
     # model=linedetect_swin_transformer_fpn(type='t')