Browse Source

重构相关head

RenLiqiang 4 months ago
parent
commit
fbcba61f02

+ 0 - 0
models/line_detect/heads/arc_heads.py → models/line_detect/heads/arc/arc_heads.py


+ 0 - 0
models/line_detect/heads/circle_heads.py → models/line_detect/heads/circle/circle_heads.py


+ 0 - 0
models/line_detect/heads/line_heads.py → models/line_detect/heads/line/line_heads.py


+ 0 - 0
models/line_detect/heads/point_heads.py → models/line_detect/heads/point/point_heads.py


+ 8 - 22
models/line_detect/line_detect.py

@@ -1,37 +1,23 @@
 import os
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple
 import torch
 from torch import nn
 
-
-from libs.vision_libs import ops
-from libs.vision_libs.models import MobileNet_V3_Large_Weights, mobilenet_v3_large, EfficientNet_V2_S_Weights, \
-    efficientnet_v2_s, detection, EfficientNet_V2_L_Weights, efficientnet_v2_l, EfficientNet_V2_M_Weights, \
-    efficientnet_v2_m
 from libs.vision_libs.models.detection.anchor_utils import AnchorGenerator
 from libs.vision_libs.models.detection.rpn import RPNHead, RegionProposalNetwork
-from libs.vision_libs.models.detection.ssdlite import _mobilenet_extractor
 from libs.vision_libs.models.detection.transform import GeneralizedRCNNTransform
 from libs.vision_libs.ops import misc as misc_nn_ops, MultiScaleRoIAlign
-from libs.vision_libs.transforms._presets import ObjectDetection
-from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
-from libs.vision_libs.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES, _COCO_CATEGORIES
-from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
-from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights, ResNet18_Weights, resnet18
-from libs.vision_libs.models.detection._utils import overwrite_eps
-from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers, \
-    BackboneWithFPN, resnet_fpn_backbone
-from libs.vision_libs.models.detection.faster_rcnn import FasterRCNN, TwoMLPHead, FastRCNNPredictor
-from .heads.arc_heads import ArcHeads, ArcPredictor
-from .heads.circle_heads import CircleHeads, CirclePredictor
+from libs.vision_libs.models.detection.backbone_utils import BackboneWithFPN, resnet_fpn_backbone
+from libs.vision_libs.models.detection.faster_rcnn import TwoMLPHead
+from models.line_detect.heads.arc.arc_heads import ArcHeads
+from models.line_detect.heads.circle.circle_heads import CircleHeads, CirclePredictor
 from .heads.decoder import FPNDecoder
-from .heads.line_heads import LinePredictor
-from .heads.point_heads import PointHeads, PointPredictor
+from models.line_detect.heads.line.line_heads import LinePredictor
+from models.line_detect.heads.point.point_heads import PointHeads, PointPredictor
 from .loi_heads import RoIHeads
 
 from .trainer import Trainer
-from ..base import backbone_factory
-from ..base.backbone_factory import get_convnext_fpn, get_anchor_generator, get_maxvit_fpn, MaxVitBackbone, \
+from ..base.backbone_factory import get_anchor_generator, MaxVitBackbone, \
     get_swin_transformer_fpn
 # from ..base.backbone_factory import get_convnext_fpn, get_anchor_generator
 from ..base.base_detection_net import BaseDetectionNet