import os
from typing import Any, Callable, List, Optional, Tuple

import torch
from torch import nn

from libs.vision_libs.models.detection.anchor_utils import AnchorGenerator
from libs.vision_libs.models.detection.rpn import RPNHead, RegionProposalNetwork
from libs.vision_libs.models.detection.transform import GeneralizedRCNNTransform
from libs.vision_libs.ops import misc as misc_nn_ops, MultiScaleRoIAlign
from libs.vision_libs.models.detection.backbone_utils import BackboneWithFPN, resnet_fpn_backbone
from libs.vision_libs.models.detection.faster_rcnn import TwoMLPHead
from models.line_detect.heads.arc.arc_heads import ArcHeads
from models.line_detect.heads.circle.circle_heads import CircleHeads, CirclePredictor
from .heads.decoder import FPNDecoder
from models.line_detect.heads.line.line_heads import LinePredictor
from models.line_detect.heads.point.point_heads import PointHeads, PointPredictor
from .loi_heads import RoIHeads
from .trainer import Trainer
from ..base.backbone_factory import (
    get_anchor_generator,
    MaxVitBackbone,
    get_swin_transformer_fpn,
    get_efficientnetv2_fpn,
)
# from ..base.backbone_factory import get_convnext_fpn, get_anchor_generator
from ..base.base_detection_net import BaseDetectionNet
import torch.nn.functional as F
from ..base.high_reso_maxvit import maxvit_with_fpn
from ..base.high_reso_resnet import resnet50fpn, resnet18fpn, resnet101fpn, Bottleneck

__all__ = [
    "LineDetect",
    "linedetect_resnet50_fpn",
]

from ..line_net.line_detect import LineHeads


def _default_anchorgen():
    """Return the default RPN anchor generator.

    Five feature levels with anchor sizes 32..512 and aspect ratios 0.5/1.0/2.0.
    """
    anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    return AnchorGenerator(anchor_sizes, aspect_ratios)


class LineDetect(BaseDetectionNet):
    """Multi-task detector assembled on top of ``BaseDetectionNet``.

    The constructor wires an RPN, a box branch (``RoIHeads``) and a
    ``GeneralizedRCNNTransform`` around ``backbone``, which must expose
    ``out_channels``. It then attaches optional line / point / arc / circle
    branches to ``self.roi_heads``. A branch's default head/predictor is only
    built when its ``detect_*`` flag is set and the caller did not pass one;
    otherwise whatever was passed (possibly ``None``) is attached as-is.

    Args:
        backbone: feature extractor exposing ``out_channels``.
        num_classes: number of box classes given to the default box predictor.
        min_size, max_size, image_mean, image_std: transform settings
            (ImageNet mean/std are used when ``None``).
        rpn_*: RegionProposalNetwork settings.
        box_*, bbox_reg_weights: RoIHeads box-branch settings.
        line_*, point_*, arc_*, circle_*: optional per-task modules.
        num_points: value repeated eight times to form the ``layers`` tuple
            handed to the default Line/Point/Arc/Circle heads.
        detect_point, detect_line, detect_arc, detect_circle: enable the
            corresponding branch (also forwarded to ``RoIHeads``).
        **kwargs: forwarded to ``GeneralizedRCNNTransform``.
    """

    def __init__(
        self,
        backbone,
        num_classes=3,
        # transform parameters
        min_size=512,
        max_size=512,
        image_mean=None,
        image_std=None,
        # RPN parameters
        rpn_anchor_generator=None,
        rpn_head=None,
        rpn_pre_nms_top_n_train=2000,
        rpn_pre_nms_top_n_test=1000,
        rpn_post_nms_top_n_train=2000,
        rpn_post_nms_top_n_test=1000,
        rpn_nms_thresh=0.7,
        rpn_fg_iou_thresh=0.7,
        rpn_bg_iou_thresh=0.3,
        rpn_batch_size_per_image=256,
        rpn_positive_fraction=0.5,
        rpn_score_thresh=0.0,
        # Box parameters
        box_roi_pool=None,
        box_head=None,
        box_predictor=None,
        box_score_thresh=0.05,
        box_nms_thresh=0.5,
        box_detections_per_img=200,
        box_fg_iou_thresh=0.7,
        box_bg_iou_thresh=0.3,
        box_batch_size_per_image=512,
        box_positive_fraction=0.25,
        bbox_reg_weights=None,
        # line parameters
        line_roi_pool=None,
        line_head=None,
        line_predictor=None,
        # point parameters
        point_roi_pool=None,
        point_head=None,
        point_predictor=None,
        # circle parameters
        circle_head=None,
        circle_predictor=None,
        circle_roi_pool=None,
        # arc parameters
        arc_roi_pool=None,
        arc_head=None,
        arc_predictor=None,
        num_points=4,
        detect_point=False,
        detect_line=False,
        detect_arc=True,
        detect_circle=False,
        **kwargs,
    ):
        out_channels = backbone.out_channels

        # ---- Region proposal network -------------------------------------
        if rpn_anchor_generator is None:
            rpn_anchor_generator = _default_anchorgen()
        if rpn_head is None:
            rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])

        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)

        rpn = RegionProposalNetwork(
            rpn_anchor_generator,
            rpn_head,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_pre_nms_top_n,
            rpn_post_nms_top_n,
            rpn_nms_thresh,
            score_thresh=rpn_score_thresh,
        )

        # ---- Box branch --------------------------------------------------
        if box_roi_pool is None:
            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)

        if box_head is None:
            resolution = box_roi_pool.output_size[0]
            representation_size = 1024
            box_head = TwoMLPHead(out_channels * resolution**2, representation_size)

        if box_predictor is None:
            # Must match the default TwoMLPHead representation size above.
            representation_size = 1024
            box_predictor = ObjectionPredictor(representation_size, num_classes)

        roi_heads = RoIHeads(
            # Box
            box_roi_pool,
            box_head,
            box_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
            detect_point=detect_point,
            detect_line=detect_line,
            detect_arc=detect_arc,
            detect_circle=detect_circle,
        )

        # ---- Input transform ---------------------------------------------
        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

        super().__init__(backbone, rpn, roi_heads, transform)

        # ---- Optional task branches --------------------------------------
        # Every default auxiliary head receives the same layer spec:
        # ``num_points`` repeated eight times.
        head_layers = tuple(num_points for _ in range(8))

        # Construction order below is kept stable so parameter initialisation
        # (random state consumption) is reproducible across versions.
        if line_head is None and detect_line:
            line_head = LineHeads(8, head_layers)
        if line_predictor is None and detect_line:
            line_predictor = LinePredictor(in_channels=256)

        if point_head is None and detect_point:
            point_head = PointHeads(8, head_layers)
        if point_predictor is None and detect_point:
            point_predictor = PointPredictor(in_channels=256)

        if detect_arc and arc_head is None:
            arc_head = ArcHeads(8, head_layers)
        if detect_arc and arc_predictor is None:
            # Previously: ArcPredictor(in_channels=256, out_channels=1)
            arc_predictor = FPNDecoder(Bottleneck)

        if detect_circle and circle_head is None:
            circle_head = CircleHeads(8, head_layers)
        if detect_circle and circle_predictor is None:
            # Previously: CirclePredictor(in_channels=256, out_channels=4)
            circle_predictor = FPNDecoder(Bottleneck)

        self.roi_heads.line_roi_pool = line_roi_pool
        self.roi_heads.line_head = line_head
        self.roi_heads.line_predictor = line_predictor

        self.roi_heads.point_roi_pool = point_roi_pool
        self.roi_heads.point_head = point_head
        self.roi_heads.point_predictor = point_predictor

        self.roi_heads.arc_roi_pool = arc_roi_pool
        self.roi_heads.arc_head = arc_head
        self.roi_heads.arc_predictor = arc_predictor

        self.roi_heads.circle_roi_pool = circle_roi_pool
        self.roi_heads.circle_head = circle_head
        self.roi_heads.circle_predictor = circle_predictor

    def start_train(self, cfg):
        """Create a ``Trainer`` (stored on ``self.trainer``) and train this model from ``cfg``."""
        # cfg = read_yaml(cfg)
        self.trainer = Trainer()
        self.trainer.train_from_cfg(model=self, cfg=cfg)

    def load_weights(self, save_path, device='cuda'):
        """Load ``checkpoint['model_state_dict']`` from ``save_path`` if the file exists.

        Prints whether a checkpoint was found; a missing file is not an error.
        ``torch.load`` unpickles the file, so only load checkpoints you trust.

        Returns:
            ``self``, to allow chaining.
        """
        if os.path.exists(save_path):
            checkpoint = torch.load(save_path, map_location=device)
            self.load_state_dict(checkpoint['model_state_dict'])
            print(f"Loaded model from {save_path}")
        else:
            print(f"No saved model found at {save_path}")
        return self


class TwoMLPHead(nn.Module):
    """
    Standard heads for FPN-based models.

    NOTE: this local definition shadows the ``TwoMLPHead`` imported from
    ``faster_rcnn``. It is the one ``LineDetect`` resolves at construction
    time. Parameter names ``fc6``/``fc7`` are kept for state-dict compatibility.

    Args:
        in_channels (int): number of input channels
        representation_size (int): size of the intermediate representation
    """

    def __init__(self, in_channels, representation_size):
        super().__init__()
        self.fc6 = nn.Linear(in_channels, representation_size)
        self.fc7 = nn.Linear(representation_size, representation_size)

    def forward(self, x):
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc6(x))
        x = F.relu(self.fc7(x))
        return x


class ObjectionConvFCHead(nn.Sequential):
    """Conv blocks followed by flatten and Linear+ReLU layers (box head variant)."""

    def __init__(
        self,
        input_size: Tuple[int, int, int],
        conv_layers: List[int],
        fc_layers: List[int],
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        """
        Args:
            input_size (Tuple[int, int, int]): the input size in CHW format.
            conv_layers (list): feature dimensions of each Convolution layer
            fc_layers (list): feature dimensions of each FCN layer
            norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
        """
        in_channels, in_height, in_width = input_size

        blocks = []
        previous_channels = in_channels
        for current_channels in conv_layers:
            blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, current_channels, norm_layer=norm_layer))
            previous_channels = current_channels
        blocks.append(nn.Flatten())
        # Conv blocks here keep the spatial size, so the flattened width is C*H*W.
        previous_channels = previous_channels * in_height * in_width
        for current_channels in fc_layers:
            blocks.append(nn.Linear(previous_channels, current_channels))
            blocks.append(nn.ReLU(inplace=True))
            previous_channels = current_channels

        super().__init__(*blocks)
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)


class ObjectionPredictor(nn.Module):
    """
    Standard classification + bounding box regression layers for Fast R-CNN.

    Args:
        in_channels (int): number of input channels
        num_classes (int): number of output classes (including background)
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.cls_score = nn.Linear(in_channels, num_classes)
        self.bbox_pred = nn.Linear(in_channels, num_classes * 4)

    def forward(self, x):
        """Return ``(scores, bbox_deltas)``; 4-D input must have 1x1 spatial dims."""
        if x.dim() == 4:
            torch._assert(
                list(x.shape[2:]) == [1, 1],
                f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}",
            )
        x = x.flatten(start_dim=1)
        scores = self.cls_score(x)
        bbox_deltas = self.bbox_pred(x)

        return scores, bbox_deltas


def _pyramid_anchor_generator(num_features: int) -> AnchorGenerator:
    """Build an AnchorGenerator with one anchor size per feature level.

    Level ``i`` uses size ``16 * 2**i`` (sizes are generated automatically)
    and aspect ratios 0.5, 1.0 and 2.0.
    """
    anchor_sizes = tuple((int(16 * 2 ** i),) for i in range(num_features))
    aspect_ratios = ((0.5, 1.0, 2.0),) * num_features
    return AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)


def _circle_model_from_backbone(backbone, num_classes, num_points, size, /, **kwargs) -> LineDetect:
    """Shared builder for the ``linedetect_newresnet*fpn`` factories.

    Wraps ``backbone`` in a circle-only ``LineDetect``. It uses six pyramid
    levels ('0'..'4', 'pool') for RoI pooling and anchors, and square
    ``size`` x ``size`` inputs. Positional-only parameters keep
    caller ``kwargs`` (forwarded to ``LineDetect``) from colliding with them.
    """
    featmap_names = ['0', '1', '2', '3', '4', 'pool']
    roi_pooler = MultiScaleRoIAlign(featmap_names=featmap_names, output_size=7, sampling_ratio=2)
    anchor_generator = _pyramid_anchor_generator(len(featmap_names))
    return LineDetect(
        backbone,
        num_classes,
        min_size=size,
        max_size=size,
        num_points=num_points,
        rpn_anchor_generator=anchor_generator,
        box_roi_pool=roi_pooler,
        detect_point=False,
        detect_line=False,
        detect_arc=False,
        detect_circle=True,
        **kwargs,
    )


def linedetect_newresnet18fpn(
    *,
    num_classes: Optional[int] = None,
    num_points: Optional[int] = None,
    **kwargs: Any,
) -> LineDetect:
    """Circle-detection LineDetect on the high-resolution ResNet-18 FPN (768x768 input).

    Defaults: ``num_classes=5``, ``num_points=4``. Extra ``kwargs`` go to ``LineDetect``.
    """
    if num_classes is None:
        num_classes = 5
    if num_points is None:
        num_points = 4
    return _circle_model_from_backbone(resnet18fpn(), num_classes, num_points, 768, **kwargs)


def linedetect_newresnet50fpn(
    *,
    num_classes: Optional[int] = None,
    num_points: Optional[int] = None,
    **kwargs: Any,
) -> LineDetect:
    """Circle-detection LineDetect on the high-resolution ResNet-50 FPN (768x768 input).

    Defaults: ``num_classes=5``, ``num_points=4``. Extra ``kwargs`` go to ``LineDetect``.
    """
    if num_classes is None:
        num_classes = 5
    if num_points is None:
        num_points = 4
    return _circle_model_from_backbone(resnet50fpn(out_channels=256), num_classes, num_points, 768, **kwargs)


def linedetect_newresnet101fpn(
    *,
    num_classes: Optional[int] = None,
    num_points: Optional[int] = None,
    **kwargs: Any,
) -> LineDetect:
    """Circle-detection LineDetect on the high-resolution ResNet-101 FPN (768x768 input).

    Defaults: ``num_classes=5``, ``num_points=3``. Extra ``kwargs`` go to ``LineDetect``.
    """
    if num_classes is None:
        num_classes = 5
    if num_points is None:
        num_points = 3
    return _circle_model_from_backbone(resnet101fpn(out_channels=256), num_classes, num_points, 768, **kwargs)


def linedetect_newresnet152fpn(
    *,
    num_classes: Optional[int] = None,
    num_points: Optional[int] = None,
    **kwargs: Any,
) -> LineDetect:
    """Circle-detection LineDetect intended for a ResNet-152 FPN (768x768 input).

    Defaults: ``num_classes=5``, ``num_points=3``. Extra ``kwargs`` go to ``LineDetect``.
    """
    if num_classes is None:
        num_classes = 5
    if num_points is None:
        num_points = 3
    # NOTE(review): despite the name this builds resnet101fpn; no resnet152fpn
    # is imported from high_reso_resnet. Confirm intent before changing.
    return _circle_model_from_backbone(resnet101fpn(out_channels=256), num_classes, num_points, 768, **kwargs)


def linedetect_efficientnet(
    *,
    num_classes: Optional[int] = None,
    num_points: Optional[int] = None,
    name: Optional[str] = 'efficientnet_v2_l',
    **kwargs: Any,
) -> LineDetect:
    """Circle-detection LineDetect on an EfficientNetV2 FPN backbone (672x672 input).

    Anchors are derived from the backbone by probing it with a random input.
    NOTE(review): ``num_points`` and ``kwargs`` are currently not forwarded.
    """
    if num_classes is None:
        num_classes = 5
    if num_points is None:
        num_points = 3

    size = 224 * 3
    featmap_names = ['0', '1', '2', '3', '4', 'pool']
    roi_pooler = MultiScaleRoIAlign(featmap_names=featmap_names, output_size=7, sampling_ratio=2)

    backbone_with_fpn = get_efficientnetv2_fpn(name=name)
    test_input = torch.randn(1, 3, size, size)

    model = LineDetect(
        backbone=backbone_with_fpn,
        min_size=size,
        max_size=size,
        num_classes=num_classes,
        rpn_anchor_generator=get_anchor_generator(backbone_with_fpn, test_input=test_input),
        box_roi_pool=roi_pooler,
        detect_line=False,
        detect_point=False,
        detect_arc=False,
        detect_circle=True,
    )
    return model


def linedetect_maxvitfpn(
    *,
    num_classes: Optional[int] = None,
    num_points: Optional[int] = None,
    **kwargs: Any,
) -> LineDetect:
    """Circle-detection LineDetect on a MaxViT backbone wrapped in an FPN (672x672 input).

    NOTE(review): ``num_points`` and ``kwargs`` are currently not forwarded.
    """
    if num_classes is None:
        num_classes = 5
    if num_points is None:
        num_points = 3

    size = 224 * 3
    maxvit = MaxVitBackbone(input_size=(size, size))

    in_channels_list = [64, 64, 128, 256, 512]
    featmap_names = ['0', '1', '2', '3', '4', 'pool']
    roi_pooler = MultiScaleRoIAlign(featmap_names=featmap_names, output_size=7, sampling_ratio=2)
    backbone_with_fpn = BackboneWithFPN(
        maxvit,
        # These keys must match the actual layer names of MaxVitBackbone.
        return_layers={'stem': '0', 'block0': '1', 'block1': '2', 'block2': '3', 'block3': '4'},
        in_channels_list=in_channels_list,
        out_channels=256,
    )

    test_input = torch.randn(1, 3, size, size)

    model = LineDetect(
        backbone=backbone_with_fpn,
        min_size=size,
        max_size=size,
        num_classes=num_classes,
        rpn_anchor_generator=get_anchor_generator(backbone_with_fpn, test_input=test_input),
        box_roi_pool=roi_pooler,
        detect_line=False,
        detect_point=False,
        detect_arc=False,
        detect_circle=True,
    )
    return model


def linedetect_high_maxvitfpn(
    *,
    num_classes: Optional[int] = None,
    num_points: Optional[int] = None,
    **kwargs: Any,
) -> LineDetect:
    """LineDetect on the high-resolution MaxViT FPN (448x448 input, seven pooled levels).

    Uses the ``LineDetect`` default task flags.
    NOTE(review): ``num_points`` and ``kwargs`` are currently not forwarded.
    """
    if num_classes is None:
        num_classes = 5
    if num_points is None:
        num_points = 3

    size = 224 * 2
    maxvitfpn = maxvit_with_fpn(size=size)

    featmap_names = ['0', '1', '2', '3', '4', '5', 'pool']
    roi_pooler = MultiScaleRoIAlign(featmap_names=featmap_names, output_size=7, sampling_ratio=2)

    test_input = torch.randn(1, 3, size, size)

    model = LineDetect(
        backbone=maxvitfpn,
        num_classes=num_classes,
        min_size=size,
        max_size=size,
        rpn_anchor_generator=get_anchor_generator(maxvitfpn, test_input=test_input),
        box_roi_pool=roi_pooler,
    )
    return model


def linedetect_swin_transformer_fpn(
    *,
    num_classes: Optional[int] = None,
    num_points: Optional[int] = None,
    type='t',
    **kwargs: Any,
) -> LineDetect:
    """LineDetect on a Swin Transformer FPN (512x512 input).

    The backbone, RoI pooler and anchor generator all come from
    ``get_swin_transformer_fpn(type=type)``. Default ``num_classes=3``.
    NOTE(review): ``num_points`` and ``kwargs`` are currently not forwarded.
    """
    if num_classes is None:
        num_classes = 3
    if num_points is None:
        num_points = 3

    size = 512
    backbone_with_fpn, roi_pooler, anchor_generator = get_swin_transformer_fpn(type=type)

    model = LineDetect(
        backbone=backbone_with_fpn,
        min_size=size,
        max_size=size,
        # Previously hard-coded to 3, silently ignoring the caller's value.
        num_classes=num_classes,
        rpn_anchor_generator=anchor_generator,
        box_roi_pool=roi_pooler,
        detect_line=False,
        detect_point=False,
    )
    return model


def linedetect_resnet18_fpn(
    *,
    num_classes: Optional[int] = None,
    num_points: Optional[int] = None,
    **kwargs: Any,
) -> LineDetect:
    """LineDetect on a randomly initialised ResNet-18 FPN (1024x1024 input).

    Defaults: ``num_classes=4``, ``num_points=3``. Extra ``kwargs`` go to ``LineDetect``.
    """
    if num_classes is None:
        num_classes = 4
    if num_points is None:
        num_points = 3

    size = 1024
    backbone = resnet_fpn_backbone(backbone_name='resnet18', weights=None)
    model = LineDetect(
        backbone,
        min_size=size,
        max_size=size,
        num_classes=num_classes,
        num_points=num_points,
        **kwargs,
    )
    return model


def linedetect_resnet50_fpn(
    *,
    num_classes: Optional[int] = None,
    num_points: Optional[int] = None,
    **kwargs: Any,
) -> LineDetect:
    """LineDetect on a randomly initialised ResNet-50 FPN (LineDetect default input size).

    Defaults: ``num_classes=3``, ``num_points=3``. Extra ``kwargs`` go to ``LineDetect``.
    """
    if num_classes is None:
        num_classes = 3
    if num_points is None:
        num_points = 3

    # Previously built 'resnet18' here, contradicting the factory's name.
    backbone = resnet_fpn_backbone(backbone_name='resnet50', weights=None)
    model = LineDetect(backbone, num_classes, num_points=num_points, **kwargs)
    return model