Browse Source

add high resolution ratio resnet50fpn

RenLiqiang 6 months ago
parent
commit
13a75ea6b7

+ 1 - 0
libs/vision_libs/models/detection/anchor_utils.py

@@ -85,6 +85,7 @@ class AnchorGenerator(nn.Module):
         anchors = []
         cell_anchors = self.cell_anchors
         torch._assert(cell_anchors is not None, "cell_anchors should not be None")
+        print(f'grid_sizes:{len(grid_sizes)},len(strides):{len(strides)},len(cell_anchors):{len(cell_anchors)}')
         torch._assert(
             len(grid_sizes) == len(strides) == len(cell_anchors),
             "Anchors should be Tuple[Tuple[int]] because each feature "

+ 1 - 1
models/base/backbone_factory.py

@@ -113,7 +113,7 @@ def get_anchor_generator(backbone, test_input):
     features = backbone(test_input)  # 获取 backbone 输出的所有特征图
     featmap_names = list(features.keys())
     print(f'featmap_names:{featmap_names}')
-    num_features = len(features)     # 特征图数量
+    num_features = len(features)    # 特征图数量
     print(f'num_features:{num_features}')
     # num_features=num_features-1
 

+ 266 - 0
models/base/resnet50fpn.py

@@ -0,0 +1,266 @@
+from typing import Any, Callable, Dict, List, Optional, Type, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torchvision.models.detection.backbone_utils import BackboneWithFPN
+
+# ----------------------------
+# Utility functions
+# ----------------------------
+
+def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
+    """3x3 convolution with padding"""
+    return nn.Conv2d(
+        in_planes,
+        out_planes,
+        kernel_size=3,
+        stride=stride,
+        padding=dilation,
+        groups=groups,
+        bias=False,
+        dilation=dilation,
+    )
+
+def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
+    """1x1 convolution"""
+
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+# ----------------------------
+# Bottleneck block
+# ----------------------------
+
+class Bottleneck(nn.Module):
+    expansion: int = 4
+
+    def __init__(
+        self,
+        inplanes: int,
+        planes: int,
+        stride: int = 1,
+        downsample: Optional[nn.Module] = None,
+        groups: int = 1,
+        base_width: int = 64,
+        dilation: int = 1,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+    ) -> None:
+        super().__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        width = int(planes * (base_width / 64.0)) * groups
+        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv1x1(inplanes, width)
+        self.bn1 = norm_layer(width)
+        self.conv2 = conv3x3(width, width, stride, groups, dilation)
+        self.bn2 = norm_layer(width)
+        self.conv3 = conv1x1(width, planes * self.expansion)
+        self.bn3 = norm_layer(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x: Tensor) -> Tensor:
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+# ----------------------------
+# ResNet backbone and FPN factory
+# ----------------------------
+
+def resnet50fpn(out_channels=256):
+    backbone = ResNet(Bottleneck)
+    return_layers = {
+        'encoder0': '0',
+        'encoder1': '1',
+        'encoder2': '2',
+        'encoder3': '3',
+        # 'encoder4': '5'
+    }
+
+    # in_channels_list = [self.inplanes, 64, 128, 256, 512]
+    # in_channels_list = [64, 256, 512, 1024, 2048]
+    in_channels_list = [64, 256, 512, 1024]
+
+    return BackboneWithFPN(
+        backbone,
+        return_layers=return_layers,
+        in_channels_list=in_channels_list,
+        out_channels=out_channels,
+    )
+
+
+class ResNet(nn.Module):
+    def __init__(self, block: Type[Union[Bottleneck]],):
+        super(ResNet, self).__init__()
+        self._norm_layer = nn.BatchNorm2d
+        self.inplanes = 64
+        self.dilation = 1
+        self.groups = 1
+        self.base_width = 64
+
+
+        self.encoder0 = nn.Sequential(
+            nn.Conv2d(3, self.inplanes, kernel_size=3,padding=1, bias=False),
+            self._norm_layer(self.inplanes),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
+        )
+        self.encoder1 = self._make_layer(block, 64, 3,stride=2)
+        self.encoder2 = self._make_layer(block, 128, 4, stride=2)
+        self.encoder3 = self._make_layer(block, 256, 6, stride=2)
+        # self.encoder4 = self._make_layer(block, 512, 3, stride=2)
+        # self.encoder5 = self._make_layer(block, 512, 3, stride=2)
+        # self.body = nn.ModuleDict({
+        #     'encoder0': self.encoder0,
+        #     'encoder1': self.encoder1,
+        #     'encoder2': self.encoder2,
+        #     'encoder3': self.encoder3,
+        #     'encoder4': self.encoder4
+        # })
+        # self.fpn = self.get_convnext_fpn(
+        #     backbone=self.body,
+        #     trainable_layers=5,
+        #     returned_layers=[0, 1, 2, 3, 4],
+        #     extra_blocks=None,
+        #     norm_layer=None
+        # )
+
+
+
+
+    def _make_layer(self, block: Type[Union[Bottleneck]], planes: int, blocks: int,
+                    stride: int = 1, dilate: bool = False) -> nn.Sequential:
+        norm_layer = self._norm_layer
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(
+            block(
+                self.inplanes, planes, stride, downsample, self.groups, self.base_width,
+                previous_dilation, norm_layer
+            )
+        )
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(
+                block(
+                    self.inplanes,
+                    planes,
+                    groups=self.groups,
+                    base_width=self.base_width,
+                    dilation=self.dilation,
+                    norm_layer=norm_layer,
+                )
+            )
+
+        return nn.Sequential(*layers)
+
+    def _make_decoder_layer(self, block: Type[Union[Bottleneck]], in_channels: int,
+                            out_channels: int, blocks: int = 1) -> nn.Sequential:
+        """
+        构建解码器部分的残差块
+        """
+        assert in_channels == out_channels, "in_channels must equal out_channels"
+        layers = []
+        for _ in range(blocks):
+            layers.append(
+                block(
+                    in_channels,
+                    in_channels // block.expansion,
+                    groups=self.groups,
+                    base_width=self.base_width,
+                    dilation=self.dilation,
+                    norm_layer=self._norm_layer,
+                )
+            )
+        return nn.Sequential(*layers)
+
+    def _make_upsample_layer(self, in_channels: int, out_channels: int) -> nn.Module:
+        """
+        使用转置卷积进行上采样
+        """
+        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
+
+    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
+
+        # out = self.fpn(x)
+        # print("ssssssss")
+        x0=self.encoder0(x)
+        print(f'x0:{x0.shape}')
+        x1=self.encoder1(x0)
+        print(f'x1:{x1.shape}')
+        x2= self.encoder2(x1)
+        print(f'x2:{x2.shape}')
+        x3= self.encoder3(x2)
+        print(f'x3:{x3.shape}')
+        # x4= self.encoder4(x3)
+        # print(f'x4:{x4.shape}')
+        out={
+            'encoder0':x0,
+            'encoder1': x1,
+            'encoder2': x2,
+            'encoder3': x3,
+            # 'encoder4': x4,
+        }
+        return out
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self._forward_impl(x)
+
+
+
+# ----------------------------
+# Smoke test
+# ----------------------------
+
+if __name__ == "__main__":
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # model = ResNet(Bottleneck, n_classes=5).to(device)
+    # print(model)
+    model=resnet50fpn().to(device)
+
+
+    input_tensor = torch.randn(1, 3, 512, 512).to(device)
+    output_tensor = model(input_tensor)
+
+    backbone = ResNet(Bottleneck).to(device)
+    features = backbone(input_tensor)
+    print("Raw backbone output:", list(features.keys()))
+    print(f"Input shape: {input_tensor.shape}")
+    print(f'feat_names:{list(output_tensor.keys())}')
+    print(f"Output shape0: {output_tensor['0'].shape}")
+    print(f"Output shape1: {output_tensor['1'].shape}")
+    print(f"Output shape2: {output_tensor['2'].shape}")
+    print(f"Output shape3: {output_tensor['3'].shape}")
+    # print(f"Output shape4: {output_tensor['5'].shape}")
+    print(f"Output shape5: {output_tensor['pool'].shape}")

+ 56 - 1
models/line_detect/line_net.py

@@ -34,6 +34,7 @@ from ..base.base_detection_net import BaseDetectionNet
 import torch.nn.functional as F
 
 from .predict import Predict1, Predict
+from ..base.resnet50fpn import resnet50fpn
 
 from ..config.config_tool import read_yaml
 
@@ -174,7 +175,7 @@ class LineNet(BaseDetectionNet):
         )
 
         if box_roi_pool is None:
-            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)
+            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3","4"], output_size=7, sampling_ratio=2)
 
         if box_head is None:
             resolution = box_roi_pool.output_size[0]
@@ -527,6 +528,60 @@ class LineNet_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
     DEFAULT = COCO_V1
 
 
+def linenet_newresnet50fpn(
+        *,
+        weights: Optional[LineNet_ResNet50_FPN_Weights] = None,
+        progress: bool = True,
+        num_classes: Optional[int] = None,
+        weights_backbone: Optional[ResNet18_Weights] = ResNet18_Weights.IMAGENET1K_V1,
+        trainable_backbone_layers: Optional[int] = None,
+        **kwargs: Any,
+) -> LineNet:
+    # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
+    # weights_backbone = ResNet50_Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
+    elif num_classes is None:
+        num_classes = 91
+    if weights_backbone is not None:
+        print(f'resnet50 weights is not None')
+
+    is_trained = weights is not None or weights_backbone is not None
+    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+
+    backbone =resnet50fpn()
+    featmap_names=['0', '1', '2', '3','pool']
+    print(f'featmap_names:{featmap_names}')
+    roi_pooler = MultiScaleRoIAlign(
+        featmap_names=featmap_names,
+        output_size=7,
+        sampling_ratio=2
+    )
+    num_features=len(featmap_names)
+    anchor_sizes = tuple((int(16 * 2 ** i),) for i in range(num_features))  # 自动生成不同大小
+    print(f'anchor_sizes:{anchor_sizes}')
+    aspect_ratios = ((0.5, 1.0, 2.0),) * num_features
+    print(f'aspect_ratios:{aspect_ratios}')
+
+
+    anchor_generator =  AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)
+    # anchors = anchor_generator.generate_anchors()
+    print("Number of anchor sizes:", len(anchor_generator.sizes))  # 应为 5
+    model = LineNet(backbone, num_classes=num_classes,anchor_generator=anchor_generator,
+
+                    box_roi_pool=roi_pooler,
+                    **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+        if weights == LineNet_ResNet50_FPN_Weights.COCO_V1:
+            overwrite_eps(model, 0.0)
+
+    return model
+
 # @register_model()
 # @handle_legacy_interface(
 #     weights=("pretrained", LineNet_ResNet50_FPN_Weights.COCO_V1),

+ 1 - 1
models/line_detect/line_predictor.py

@@ -322,7 +322,7 @@ class LineRCNNPredictor(nn.Module):
             # for t in range(n_type):
             #     match[t, jtyp[match[t]] != t] = N
 
-            match[cost > 1.5 * 1.5] = N
+            match[cost > 4 * 4] = N
             match = match.flatten()
 
             _ = torch.arange(n_type * K, device=device)

+ 1 - 1
models/line_detect/train.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: train_results
-  datadir: \\192.168.50.222/share/zyh/513train/a_dataset
+  datadir: \\192.168.50.222/share/lm/Dataset_all
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
 

+ 4 - 2
models/line_detect/train_demo.py

@@ -1,6 +1,7 @@
 import torch
 
-from models.line_detect.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn
+from models.line_detect.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn, linenet_newresnet50fpn, \
+    get_line_net_convnext_fpn
 from models.line_detect.trainer import Trainer
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -8,8 +9,9 @@ if __name__ == '__main__':
 
     # model = LineNet('line_net.yaml')
     # model=linenet_resnet50_fpn()
+    # model = linenet_resnet18_fpn()
     # model=get_line_net_convnext_fpn(num_classes=2).to(device)
-    model=linenet_resnet18_fpn()
+    model=linenet_newresnet50fpn()
     # trainer = Trainer()
     # trainer.train_cfg(model,cfg='./train.yaml')
     model.start_train(cfg='train.yaml')