|
|
@@ -18,7 +18,7 @@ from .loi_heads import RoIHeads
|
|
|
|
|
|
from .trainer import Trainer
|
|
|
from ..base.backbone_factory import get_anchor_generator, MaxVitBackbone, \
|
|
|
- get_swin_transformer_fpn
|
|
|
+ get_swin_transformer_fpn, get_efficientnetv2_fpn
|
|
|
# from ..base.backbone_factory import get_convnext_fpn, get_anchor_generator
|
|
|
from ..base.base_detection_net import BaseDetectionNet
|
|
|
import torch.nn.functional as F
|
|
|
@@ -500,6 +500,47 @@ def linedetect_newresnet152fpn(
|
|
|
|
|
|
return model
|
|
|
|
|
|
+def linedetect_efficientnet(
|
|
|
+ *,
|
|
|
+ num_classes: Optional[int] = None,
|
|
|
+ num_points:Optional[int] = None,
|
|
|
+ name: Optional[str] = 'efficientnet_v2_l',
|
|
|
+ **kwargs: Any,
|
|
|
+) -> LineDetect:
|
|
|
+ # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
|
|
|
+ # weights_backbone = ResNet50_Weights.verify(weights_backbone)
|
|
|
+ if num_classes is None:
|
|
|
+ num_classes = 5
|
|
|
+ if num_points is None:
|
|
|
+ num_points = 3
|
|
|
+
|
|
|
+
|
|
|
+ size=224*3
|
|
|
+ featmap_names = ['0', '1', '2', '3', '4', 'pool']
|
|
|
+
|
|
|
+ roi_pooler = MultiScaleRoIAlign(
|
|
|
+ featmap_names=featmap_names,
|
|
|
+ output_size=7,
|
|
|
+ sampling_ratio=2
|
|
|
+ )
|
|
|
+ backbone_with_fpn=get_efficientnetv2_fpn(name=name)
|
|
|
+
|
|
|
+ test_input = torch.randn(1, 3,size,size)
|
|
|
+
|
|
|
+ model = LineDetect(
|
|
|
+ backbone=backbone_with_fpn,
|
|
|
+ min_size=size,
|
|
|
+ max_size=size,
|
|
|
+ num_classes=num_classes, # COCO æ°æ®éæ 91 ç±»
|
|
|
+ rpn_anchor_generator=get_anchor_generator(backbone_with_fpn, test_input=test_input),
|
|
|
+ box_roi_pool=roi_pooler,
|
|
|
+ detect_line=False,
|
|
|
+ detect_point=False,
|
|
|
+ detect_arc=False,
|
|
|
+ detect_circle=True,
|
|
|
+ )
|
|
|
+ return model
|
|
|
+
|
|
|
def linedetect_maxvitfpn(
|
|
|
*,
|
|
|
num_classes: Optional[int] = None,
|
|
|
@@ -514,7 +555,7 @@ def linedetect_maxvitfpn(
|
|
|
num_points = 3
|
|
|
|
|
|
|
|
|
- size=224*2
|
|
|
+ size=224*3
|
|
|
|
|
|
|
|
|
maxvit = MaxVitBackbone(input_size=(size,size))
|