|
|
@@ -4,7 +4,9 @@ import torch
|
|
|
from torch import nn
|
|
|
from torchvision.ops import MultiScaleRoIAlign
|
|
|
|
|
|
-from libs.vision_libs.models import MobileNet_V3_Large_Weights, mobilenet_v3_large
|
|
|
+from libs.vision_libs import ops
|
|
|
+from libs.vision_libs.models import MobileNet_V3_Large_Weights, mobilenet_v3_large, EfficientNet_V2_S_Weights, \
|
|
|
+ efficientnet_v2_s, detection
|
|
|
from libs.vision_libs.models.detection.anchor_utils import AnchorGenerator
|
|
|
from libs.vision_libs.models.detection.rpn import RPNHead, RegionProposalNetwork
|
|
|
from libs.vision_libs.models.detection.ssdlite import _mobilenet_extractor
|
|
|
@@ -18,7 +20,8 @@ from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_
|
|
|
from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
|
|
|
from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights, ResNet18_Weights, resnet18
|
|
|
from libs.vision_libs.models.detection._utils import overwrite_eps
|
|
|
-from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
|
|
|
+from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers, \
|
|
|
+ BackboneWithFPN
|
|
|
from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
|
|
|
|
|
|
from .roi_heads import RoIHeads
|
|
|
@@ -334,6 +337,80 @@ _COMMON_META = {
|
|
|
}
|
|
|
|
|
|
|
|
|
+def create_efficientnetv2_backbone(name='efficientnet_v2_s', pretrained=True):
|
|
|
+ # 加载EfficientNetV2模型
|
|
|
+ if name == 'efficientnet_v2_s':
|
|
|
+ weights = EfficientNet_V2_S_Weights.IMAGENET1K_V1 if pretrained else None
|
|
|
+ backbone = efficientnet_v2_s(weights=weights).features
|
|
|
+
|
|
|
+ # 定义返回的层索引和名称
|
|
|
+ return_layers = {"2": "0", "3": "1", "4": "2", "5": "3"}
|
|
|
+
|
|
|
+ # 获取每个层输出通道数
|
|
|
+ in_channels_list = []
|
|
|
+ for layer_idx in [2, 3, 4, 5]:
|
|
|
+ module = backbone[layer_idx]
|
|
|
+ if hasattr(module, 'out_channels'):
|
|
|
+ in_channels_list.append(module.out_channels)
|
|
|
+ elif hasattr(module[-1], 'out_channels'):
|
|
|
+ # 如果module本身没有out_channels,检查最后一个子模块
|
|
|
+ in_channels_list.append(module[-1].out_channels)
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Cannot determine out_channels for layer {layer_idx}")
|
|
|
+
|
|
|
+ # 使用BackboneWithFPN包装backbone
|
|
|
+ backbone_with_fpn = BackboneWithFPN(
|
|
|
+ backbone=backbone,
|
|
|
+ return_layers=return_layers,
|
|
|
+ in_channels_list=in_channels_list,
|
|
|
+ out_channels=256
|
|
|
+ )
|
|
|
+
|
|
|
+ return backbone_with_fpn
|
|
|
+
|
|
|
+
|
|
|
+def get_line_net_efficientnetv2(num_classes, pretrained_backbone=True):
|
|
|
+ # 创建EfficientNetV2 backbone
|
|
|
+ backbone = create_efficientnetv2_backbone(pretrained=pretrained_backbone)
|
|
|
+
|
|
|
+ # 确认 backbone 输出特征图数量
|
|
|
+ # with torch.no_grad():
|
|
|
+ # images = torch.rand(1,3, 600, 800)
|
|
|
+ # features = backbone(images)
|
|
|
+ # featmap_names = list(features.keys())
|
|
|
+ # print("Feature map names:", featmap_names) # 例如 ['0', '1', '2', '3']
|
|
|
+
|
|
|
+ # 根据实际特征层数量设置 anchors
|
|
|
+ # num_levels = len(featmap_names)
|
|
|
+ num_levels=5
|
|
|
+ featmap_names= ['0', '1', '2', '3', 'pool']
|
|
|
+
|
|
|
+ anchor_sizes = tuple((int(16 * 2 ** i),) for i in range(num_levels)) # 自动生成不同大小
|
|
|
+ aspect_ratios = ((0.5, 1.0, 2.0),) * num_levels # 所有层共享相同比例
|
|
|
+
|
|
|
+ anchor_generator = AnchorGenerator(
|
|
|
+ sizes=anchor_sizes,
|
|
|
+ aspect_ratios=aspect_ratios
|
|
|
+ )
|
|
|
+
|
|
|
+ # ROI Pooling
|
|
|
+ roi_pooler = MultiScaleRoIAlign(
|
|
|
+ featmap_names=featmap_names,
|
|
|
+ output_size=7,
|
|
|
+ sampling_ratio=2
|
|
|
+ )
|
|
|
+
|
|
|
+ # 构建模型
|
|
|
+ model = LineNet(
|
|
|
+ backbone=backbone,
|
|
|
+ num_classes=num_classes,
|
|
|
+ rpn_anchor_generator=anchor_generator,
|
|
|
+ box_roi_pool=roi_pooler
|
|
|
+ )
|
|
|
+
|
|
|
+ return model
|
|
|
+
|
|
|
+
|
|
|
class LineNet_ResNet50_FPN_Weights(WeightsEnum):
|
|
|
COCO_V1 = Weights(
|
|
|
url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
|
|
|
@@ -575,6 +652,8 @@ def linenet_resnet50_fpn(
|
|
|
return model
|
|
|
|
|
|
|
|
|
+
|
|
|
+
|
|
|
@register_model()
|
|
|
@handle_legacy_interface(
|
|
|
weights=("pretrained", LineNet_ResNet50_FPN_V2_Weights.COCO_V1),
|