فهرست منبع

add high maxvit

lstrlq 5 ماه پیش
والد
کامیت
c5419d7bb2
3فایلهای تغییر یافته به همراه988 افزوده شده و 6 حذف شده
  1. 942 0
      models/base/high_reso_maxvit.py
  2. 43 4
      models/line_detect/line_detect.py
  3. 3 2
      models/line_detect/train_demo.py

+ 942 - 0
models/base/high_reso_maxvit.py

@@ -0,0 +1,942 @@
+import math
+from collections import OrderedDict
+from functools import partial
+from typing import Any, Callable, List, Optional, Sequence, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+from torchvision.models._api import register_model, Weights, WeightsEnum
+from torchvision.models._meta import _IMAGENET_CATEGORIES
+from torchvision.models._utils import _ovewrite_named_param, handle_legacy_interface
+from torchvision.ops.misc import Conv2dNormActivation, SqueezeExcitation
+from torchvision.ops.stochastic_depth import StochasticDepth
+from torchvision.transforms._presets import ImageClassification, InterpolationMode
+from torchvision.utils import _log_api_usage_once
+
+__all__ = [
+    "MaxVit",
+    "MaxVit_T_Weights",
+    "maxvit_t",
+]
+
+from libs.vision_libs.models.detection.backbone_utils import BackboneWithFPN
+
+
+def _get_conv_output_shape(input_size: Tuple[int, int], kernel_size: int, stride: int, padding: int) -> Tuple[int, int]:
+    return (
+        (input_size[0] - kernel_size + 2 * padding) // stride + 1,
+        (input_size[1] - kernel_size + 2 * padding) // stride + 1,
+    )
+
+
+def _make_block_input_shapes(input_size: Tuple[int, int], n_blocks: int) -> List[Tuple[int, int]]:
+    """Util function to check that the input size is correct for a MaxVit configuration."""
+    shapes = []
+    block_input_shape = _get_conv_output_shape(input_size, 3, 2, 1)
+    for _ in range(n_blocks):
+        block_input_shape = _get_conv_output_shape(block_input_shape, 3, 2, 1)
+        shapes.append(block_input_shape)
+    return shapes
+
+
+def _get_relative_position_index(height: int, width: int) -> torch.Tensor:
+    coords = torch.stack(torch.meshgrid([torch.arange(height), torch.arange(width)]))
+    coords_flat = torch.flatten(coords, 1)
+    relative_coords = coords_flat[:, :, None] - coords_flat[:, None, :]
+    relative_coords = relative_coords.permute(1, 2, 0).contiguous()
+    relative_coords[:, :, 0] += height - 1
+    relative_coords[:, :, 1] += width - 1
+    relative_coords[:, :, 0] *= 2 * width - 1
+    return relative_coords.sum(-1)
+
+
+class MBConv(nn.Module):
+    """MBConv: Mobile Inverted Residual Bottleneck.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        expansion_ratio (float): Expansion ratio in the bottleneck.
+        squeeze_ratio (float): Squeeze ratio in the SE Layer.
+        stride (int): Stride of the depthwise convolution.
+        activation_layer (Callable[..., nn.Module]): Activation function.
+        norm_layer (Callable[..., nn.Module]): Normalization function.
+        p_stochastic_dropout (float): Probability of stochastic depth.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        expansion_ratio: float,
+        squeeze_ratio: float,
+        stride: int,
+        activation_layer: Callable[..., nn.Module],
+        norm_layer: Callable[..., nn.Module],
+        p_stochastic_dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+
+        proj: Sequence[nn.Module]
+        self.proj: nn.Module
+
+        should_proj = stride != 1 or in_channels != out_channels
+        if should_proj:
+            proj = [nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=True)]
+            if stride == 2:
+                proj = [nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)] + proj  # type: ignore
+            self.proj = nn.Sequential(*proj)
+        else:
+            self.proj = nn.Identity()  # type: ignore
+
+        mid_channels = int(out_channels * expansion_ratio)
+        sqz_channels = int(out_channels * squeeze_ratio)
+
+        if p_stochastic_dropout:
+            self.stochastic_depth = StochasticDepth(p_stochastic_dropout, mode="row")  # type: ignore
+        else:
+            self.stochastic_depth = nn.Identity()  # type: ignore
+
+        _layers = OrderedDict()
+        _layers["pre_norm"] = norm_layer(in_channels)
+        _layers["conv_a"] = Conv2dNormActivation(
+            in_channels,
+            mid_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            activation_layer=activation_layer,
+            norm_layer=norm_layer,
+            inplace=None,
+        )
+        _layers["conv_b"] = Conv2dNormActivation(
+            mid_channels,
+            mid_channels,
+            kernel_size=3,
+            stride=stride,
+            padding=1,
+            activation_layer=activation_layer,
+            norm_layer=norm_layer,
+            groups=mid_channels,
+            inplace=None,
+        )
+        _layers["squeeze_excitation"] = SqueezeExcitation(mid_channels, sqz_channels, activation=nn.SiLU)
+        _layers["conv_c"] = nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1, bias=True)
+
+        self.layers = nn.Sequential(_layers)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Args:
+            x (Tensor): Input tensor with expected layout of [B, C, H, W].
+        Returns:
+            Tensor: Output tensor with expected layout of [B, C, H / stride, W / stride].
+        """
+        res = self.proj(x)
+        x = self.stochastic_depth(self.layers(x))
+        return res + x
+
+
+class RelativePositionalMultiHeadAttention(nn.Module):
+    """Relative Positional Multi-Head Attention.
+
+    Args:
+        feat_dim (int): Number of input features.
+        head_dim (int): Number of features per head.
+        max_seq_len (int): Maximum sequence length.
+    """
+
+    def __init__(
+        self,
+        feat_dim: int,
+        head_dim: int,
+        max_seq_len: int,
+    ) -> None:
+        super().__init__()
+
+        if feat_dim % head_dim != 0:
+            raise ValueError(f"feat_dim: {feat_dim} must be divisible by head_dim: {head_dim}")
+
+        self.n_heads = feat_dim // head_dim
+        self.head_dim = head_dim
+        self.size = int(math.sqrt(max_seq_len))
+        self.max_seq_len = max_seq_len
+
+        self.to_qkv = nn.Linear(feat_dim, self.n_heads * self.head_dim * 3)
+        self.scale_factor = feat_dim**-0.5
+
+        self.merge = nn.Linear(self.head_dim * self.n_heads, feat_dim)
+        self.relative_position_bias_table = nn.parameter.Parameter(
+            torch.empty(((2 * self.size - 1) * (2 * self.size - 1), self.n_heads), dtype=torch.float32),
+        )
+
+        self.register_buffer("relative_position_index", _get_relative_position_index(self.size, self.size))
+        # initialize with truncated normal the bias
+        torch.nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)
+
+    def get_relative_positional_bias(self) -> torch.Tensor:
+        bias_index = self.relative_position_index.view(-1)  # type: ignore
+        relative_bias = self.relative_position_bias_table[bias_index].view(self.max_seq_len, self.max_seq_len, -1)  # type: ignore
+        relative_bias = relative_bias.permute(2, 0, 1).contiguous()
+        return relative_bias.unsqueeze(0)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Args:
+            x (Tensor): Input tensor with expected layout of [B, G, P, D].
+        Returns:
+            Tensor: Output tensor with expected layout of [B, G, P, D].
+        """
+        B, G, P, D = x.shape
+        H, DH = self.n_heads, self.head_dim
+
+        qkv = self.to_qkv(x)
+        q, k, v = torch.chunk(qkv, 3, dim=-1)
+
+        q = q.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4)
+        k = k.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4)
+        v = v.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4)
+
+        k = k * self.scale_factor
+        dot_prod = torch.einsum("B G H I D, B G H J D -> B G H I J", q, k)
+        pos_bias = self.get_relative_positional_bias()
+
+        dot_prod = F.softmax(dot_prod + pos_bias, dim=-1)
+
+        out = torch.einsum("B G H I J, B G H J D -> B G H I D", dot_prod, v)
+        out = out.permute(0, 1, 3, 2, 4).reshape(B, G, P, D)
+
+        out = self.merge(out)
+        return out
+
+
+class SwapAxes(nn.Module):
+    """Permute the axes of a tensor."""
+
+    def __init__(self, a: int, b: int) -> None:
+        super().__init__()
+        self.a = a
+        self.b = b
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        res = torch.swapaxes(x, self.a, self.b)
+        return res
+
+
+class WindowPartition(nn.Module):
+    """
+    Partition the input tensor into non-overlapping windows.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, x: Tensor, p: int) -> Tensor:
+        """
+        Args:
+            x (Tensor): Input tensor with expected layout of [B, C, H, W].
+            p (int): Number of partitions.
+        Returns:
+            Tensor: Output tensor with expected layout of [B, H/P, W/P, P*P, C].
+        """
+        B, C, H, W = x.shape
+        P = p
+        # chunk up H and W dimensions
+        x = x.reshape(B, C, H // P, P, W // P, P)
+        x = x.permute(0, 2, 4, 3, 5, 1)
+        # colapse P * P dimension
+        x = x.reshape(B, (H // P) * (W // P), P * P, C)
+        return x
+
+
+class WindowDepartition(nn.Module):
+    """
+    Departition the input tensor of non-overlapping windows into a feature volume of layout [B, C, H, W].
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, x: Tensor, p: int, h_partitions: int, w_partitions: int) -> Tensor:
+        """
+        Args:
+            x (Tensor): Input tensor with expected layout of [B, (H/P * W/P), P*P, C].
+            p (int): Number of partitions.
+            h_partitions (int): Number of vertical partitions.
+            w_partitions (int): Number of horizontal partitions.
+        Returns:
+            Tensor: Output tensor with expected layout of [B, C, H, W].
+        """
+        B, G, PP, C = x.shape
+        P = p
+        HP, WP = h_partitions, w_partitions
+        # split P * P dimension into 2 P tile dimensionsa
+        x = x.reshape(B, HP, WP, P, P, C)
+        # permute into B, C, HP, P, WP, P
+        x = x.permute(0, 5, 1, 3, 2, 4)
+        # reshape into B, C, H, W
+        x = x.reshape(B, C, HP * P, WP * P)
+        return x
+
+
+class PartitionAttentionLayer(nn.Module):
+    """
+    Layer for partitioning the input tensor into non-overlapping windows and applying attention to each window.
+
+    Args:
+        in_channels (int): Number of input channels.
+        head_dim (int): Dimension of each attention head.
+        partition_size (int): Size of the partitions.
+        partition_type (str): Type of partitioning to use. Can be either "grid" or "window".
+        grid_size (Tuple[int, int]): Size of the grid to partition the input tensor into.
+        mlp_ratio (int): Ratio of the  feature size expansion in the MLP layer.
+        activation_layer (Callable[..., nn.Module]): Activation function to use.
+        norm_layer (Callable[..., nn.Module]): Normalization function to use.
+        attention_dropout (float): Dropout probability for the attention layer.
+        mlp_dropout (float): Dropout probability for the MLP layer.
+        p_stochastic_dropout (float): Probability of dropping out a partition.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        head_dim: int,
+        # partitioning parameters
+        partition_size: int,
+        partition_type: str,
+        # grid size needs to be known at initialization time
+        # because we need to know hamy relative offsets there are in the grid
+        grid_size: Tuple[int, int],
+        mlp_ratio: int,
+        activation_layer: Callable[..., nn.Module],
+        norm_layer: Callable[..., nn.Module],
+        attention_dropout: float,
+        mlp_dropout: float,
+        p_stochastic_dropout: float,
+    ) -> None:
+        super().__init__()
+
+        self.n_heads = in_channels // head_dim
+        self.head_dim = head_dim
+        self.n_partitions = grid_size[0] // partition_size
+        self.partition_type = partition_type
+        self.grid_size = grid_size
+
+        if partition_type not in ["grid", "window"]:
+            raise ValueError("partition_type must be either 'grid' or 'window'")
+
+        if partition_type == "window":
+            self.p, self.g = partition_size, self.n_partitions
+        else:
+            self.p, self.g = self.n_partitions, partition_size
+
+        self.partition_op = WindowPartition()
+        self.departition_op = WindowDepartition()
+        self.partition_swap = SwapAxes(-2, -3) if partition_type == "grid" else nn.Identity()
+        self.departition_swap = SwapAxes(-2, -3) if partition_type == "grid" else nn.Identity()
+
+        self.attn_layer = nn.Sequential(
+            norm_layer(in_channels),
+            # it's always going to be partition_size ** 2 because
+            # of the axis swap in the case of grid partitioning
+            RelativePositionalMultiHeadAttention(in_channels, head_dim, partition_size**2),
+            nn.Dropout(attention_dropout),
+        )
+
+        # pre-normalization similar to transformer layers
+        self.mlp_layer = nn.Sequential(
+            nn.LayerNorm(in_channels),
+            nn.Linear(in_channels, in_channels * mlp_ratio),
+            activation_layer(),
+            nn.Linear(in_channels * mlp_ratio, in_channels),
+            nn.Dropout(mlp_dropout),
+        )
+
+        # layer scale factors
+        self.stochastic_dropout = StochasticDepth(p_stochastic_dropout, mode="row")
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Args:
+            x (Tensor): Input tensor with expected layout of [B, C, H, W].
+        Returns:
+            Tensor: Output tensor with expected layout of [B, C, H, W].
+        """
+
+        # Undefined behavior if H or W are not divisible by p
+        # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L766
+        gh, gw = self.grid_size[0] // self.p, self.grid_size[1] // self.p
+        torch._assert(
+            self.grid_size[0] % self.p == 0 and self.grid_size[1] % self.p == 0,
+            "Grid size must be divisible by partition size. Got grid size of {} and partition size of {}".format(
+                self.grid_size, self.p
+            ),
+        )
+
+        x = self.partition_op(x, self.p)
+        x = self.partition_swap(x)
+        x = x + self.stochastic_dropout(self.attn_layer(x))
+        x = x + self.stochastic_dropout(self.mlp_layer(x))
+        x = self.departition_swap(x)
+        x = self.departition_op(x, self.p, gh, gw)
+
+        return x
+
+
+class MaxVitLayer(nn.Module):
+    """
+    MaxVit layer consisting of a MBConv layer followed by a PartitionAttentionLayer with `window` and a PartitionAttentionLayer with `grid`.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        expansion_ratio (float): Expansion ratio in the bottleneck.
+        squeeze_ratio (float): Squeeze ratio in the SE Layer.
+        stride (int): Stride of the depthwise convolution.
+        activation_layer (Callable[..., nn.Module]): Activation function.
+        norm_layer (Callable[..., nn.Module]): Normalization function.
+        head_dim (int): Dimension of the attention heads.
+        mlp_ratio (int): Ratio of the MLP layer.
+        mlp_dropout (float): Dropout probability for the MLP layer.
+        attention_dropout (float): Dropout probability for the attention layer.
+        p_stochastic_dropout (float): Probability of stochastic depth.
+        partition_size (int): Size of the partitions.
+        grid_size (Tuple[int, int]): Size of the input feature grid.
+    """
+
+    def __init__(
+        self,
+        # conv parameters
+        in_channels: int,
+        out_channels: int,
+        squeeze_ratio: float,
+        expansion_ratio: float,
+        stride: int,
+        # conv + transformer parameters
+        norm_layer: Callable[..., nn.Module],
+        activation_layer: Callable[..., nn.Module],
+        # transformer parameters
+        head_dim: int,
+        mlp_ratio: int,
+        mlp_dropout: float,
+        attention_dropout: float,
+        p_stochastic_dropout: float,
+        # partitioning parameters
+        partition_size: int,
+        grid_size: Tuple[int, int],
+    ) -> None:
+        super().__init__()
+
+        layers: OrderedDict = OrderedDict()
+
+        # convolutional layer
+        layers["MBconv"] = MBConv(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            expansion_ratio=expansion_ratio,
+            squeeze_ratio=squeeze_ratio,
+            stride=stride,
+            activation_layer=activation_layer,
+            norm_layer=norm_layer,
+            p_stochastic_dropout=p_stochastic_dropout,
+        )
+        # attention layers, block -> grid
+        layers["window_attention"] = PartitionAttentionLayer(
+            in_channels=out_channels,
+            head_dim=head_dim,
+            partition_size=partition_size,
+            partition_type="window",
+            grid_size=grid_size,
+            mlp_ratio=mlp_ratio,
+            activation_layer=activation_layer,
+            norm_layer=nn.LayerNorm,
+            attention_dropout=attention_dropout,
+            mlp_dropout=mlp_dropout,
+            p_stochastic_dropout=p_stochastic_dropout,
+        )
+        layers["grid_attention"] = PartitionAttentionLayer(
+            in_channels=out_channels,
+            head_dim=head_dim,
+            partition_size=partition_size,
+            partition_type="grid",
+            grid_size=grid_size,
+            mlp_ratio=mlp_ratio,
+            activation_layer=activation_layer,
+            norm_layer=nn.LayerNorm,
+            attention_dropout=attention_dropout,
+            mlp_dropout=mlp_dropout,
+            p_stochastic_dropout=p_stochastic_dropout,
+        )
+        self.layers = nn.Sequential(layers)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Args:
+            x (Tensor): Input tensor of shape (B, C, H, W).
+        Returns:
+            Tensor: Output tensor of shape (B, C, H, W).
+        """
+        x = self.layers(x)
+        return x
+
+
+class MaxVitBlock(nn.Module):
+    """
+    A MaxVit block consisting of `n_layers` MaxVit layers.
+
+     Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        expansion_ratio (float): Expansion ratio in the bottleneck.
+        squeeze_ratio (float): Squeeze ratio in the SE Layer.
+        activation_layer (Callable[..., nn.Module]): Activation function.
+        norm_layer (Callable[..., nn.Module]): Normalization function.
+        head_dim (int): Dimension of the attention heads.
+        mlp_ratio (int): Ratio of the MLP layer.
+        mlp_dropout (float): Dropout probability for the MLP layer.
+        attention_dropout (float): Dropout probability for the attention layer.
+        p_stochastic_dropout (float): Probability of stochastic depth.
+        partition_size (int): Size of the partitions.
+        input_grid_size (Tuple[int, int]): Size of the input feature grid.
+        n_layers (int): Number of layers in the block.
+        p_stochastic (List[float]): List of probabilities for stochastic depth for each layer.
+    """
+
+    def __init__(
+        self,
+        # conv parameters
+        in_channels: int,
+        out_channels: int,
+        squeeze_ratio: float,
+        expansion_ratio: float,
+        # conv + transformer parameters
+        norm_layer: Callable[..., nn.Module],
+        activation_layer: Callable[..., nn.Module],
+        # transformer parameters
+        head_dim: int,
+        mlp_ratio: int,
+        mlp_dropout: float,
+        attention_dropout: float,
+        # partitioning parameters
+        partition_size: int,
+        input_grid_size: Tuple[int, int],
+        # number of layers
+        n_layers: int,
+        p_stochastic: List[float],
+    ) -> None:
+        super().__init__()
+        if not len(p_stochastic) == n_layers:
+            raise ValueError(f"p_stochastic must have length n_layers={n_layers}, got p_stochastic={p_stochastic}.")
+
+        self.layers = nn.ModuleList()
+        # account for the first stride of the first layer
+        self.grid_size = _get_conv_output_shape(input_grid_size, kernel_size=3, stride=2, padding=1)
+
+        for idx, p in enumerate(p_stochastic):
+            stride = 2 if idx == 0 else 1
+            self.layers += [
+                MaxVitLayer(
+                    in_channels=in_channels if idx == 0 else out_channels,
+                    out_channels=out_channels,
+                    squeeze_ratio=squeeze_ratio,
+                    expansion_ratio=expansion_ratio,
+                    stride=stride,
+                    norm_layer=norm_layer,
+                    activation_layer=activation_layer,
+                    head_dim=head_dim,
+                    mlp_ratio=mlp_ratio,
+                    mlp_dropout=mlp_dropout,
+                    attention_dropout=attention_dropout,
+                    partition_size=partition_size,
+                    grid_size=self.grid_size,
+                    p_stochastic_dropout=p,
+                ),
+            ]
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        Args:
+            x (Tensor): Input tensor of shape (B, C, H, W).
+        Returns:
+            Tensor: Output tensor of shape (B, C, H, W).
+        """
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+
+class MaxVit(nn.Module):
+    """
+    Implements MaxVit Transformer from the `MaxViT: Multi-Axis Vision Transformer <https://arxiv.org/abs/2204.01697>`_ paper.
+    Args:
+        input_size (Tuple[int, int]): Size of the input image.
+        stem_channels (int): Number of channels in the stem.
+        partition_size (int): Size of the partitions.
+        block_channels (List[int]): Number of channels in each block.
+        block_layers (List[int]): Number of layers in each block.
+        stochastic_depth_prob (float): Probability of stochastic depth. Expands to a list of probabilities for each layer that scales linearly to the specified value.
+        squeeze_ratio (float): Squeeze ratio in the SE Layer. Default: 0.25.
+        expansion_ratio (float): Expansion ratio in the MBConv bottleneck. Default: 4.
+        norm_layer (Callable[..., nn.Module]): Normalization function. Default: None (setting to None will produce a `BatchNorm2d(eps=1e-3, momentum=0.99)`).
+        activation_layer (Callable[..., nn.Module]): Activation function Default: nn.GELU.
+        head_dim (int): Dimension of the attention heads.
+        mlp_ratio (int): Expansion ratio of the MLP layer. Default: 4.
+        mlp_dropout (float): Dropout probability for the MLP layer. Default: 0.0.
+        attention_dropout (float): Dropout probability for the attention layer. Default: 0.0.
+        num_classes (int): Number of classes. Default: 1000.
+    """
+
+    def __init__(
+        self,
+        # input size parameters
+        input_size: Tuple[int, int],
+        # stem and task parameters
+        stem_channels: int,
+        # partitioning parameters
+        partition_size: int,
+        # block parameters
+        block_channels: List[int],
+        block_layers: List[int],
+        # attention head dimensions
+        head_dim: int,
+        stochastic_depth_prob: float,
+        # conv + transformer parameters
+        # norm_layer is applied only to the conv layers
+        # activation_layer is applied both to conv and transformer layers
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        activation_layer: Callable[..., nn.Module] = nn.GELU,
+        # conv parameters
+        squeeze_ratio: float = 0.25,
+        expansion_ratio: float = 4,
+        # transformer parameters
+        mlp_ratio: int = 4,
+        mlp_dropout: float = 0.0,
+        attention_dropout: float = 0.0,
+        # task parameters
+        num_classes: int = 1000,
+    ) -> None:
+        super().__init__()
+        _log_api_usage_once(self)
+
+        input_channels = 3
+
+        # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L1029-L1030
+        # for the exact parameters used in batchnorm
+        if norm_layer is None:
+            norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.99)
+
+        # Make sure input size will be divisible by the partition size in all blocks
+        # Undefined behavior if H or W are not divisible by p
+        # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L766
+        block_input_sizes = _make_block_input_shapes(input_size, len(block_channels))
+        for idx, block_input_size in enumerate(block_input_sizes):
+            if block_input_size[0] % partition_size != 0 or block_input_size[1] % partition_size != 0:
+                raise ValueError(
+                    f"Input size {block_input_size} of block {idx} is not divisible by partition size {partition_size}. "
+                    f"Consider changing the partition size or the input size.\n"
+                    f"Current configuration yields the following block input sizes: {block_input_sizes}."
+                )
+
+        # stem
+        self.stem0 = nn.Sequential(
+            Conv2dNormActivation(
+                input_channels,
+                stem_channels,
+                3,
+                stride=1,
+                norm_layer=norm_layer,
+                activation_layer=activation_layer,
+                bias=False,
+                inplace=None,
+            ),
+            Conv2dNormActivation(
+                stem_channels, stem_channels, 3, stride=1, norm_layer=None, activation_layer=None, bias=True
+            ),
+        )
+
+        self.stem1 = nn.Sequential(
+            Conv2dNormActivation(
+                stem_channels,
+                stem_channels,
+                3,
+                stride=2,
+                norm_layer=norm_layer,
+                activation_layer=activation_layer,
+                bias=False,
+                inplace=None,
+            ),
+            Conv2dNormActivation(
+                stem_channels, stem_channels, 3, stride=1, norm_layer=None, activation_layer=None, bias=True
+            ),
+        )
+
+
+        # account for stem stride
+        input_size = _get_conv_output_shape(input_size, kernel_size=3, stride=2, padding=1)
+        self.partition_size = partition_size
+
+        # blocks
+        self.blocks = nn.ModuleList()
+        in_channels = [stem_channels] + block_channels[:-1]
+        out_channels = block_channels
+
+        # precompute the stochastich depth probabilities from 0 to stochastic_depth_prob
+        # since we have N blocks with L layers, we will have N * L probabilities uniformly distributed
+        # over the range [0, stochastic_depth_prob]
+        p_stochastic = np.linspace(0, stochastic_depth_prob, sum(block_layers)).tolist()
+
+        p_idx = 0
+        for in_channel, out_channel, num_layers in zip(in_channels, out_channels, block_layers):
+            self.blocks.append(
+                MaxVitBlock(
+                    in_channels=in_channel,
+                    out_channels=out_channel,
+                    squeeze_ratio=squeeze_ratio,
+                    expansion_ratio=expansion_ratio,
+                    norm_layer=norm_layer,
+                    activation_layer=activation_layer,
+                    head_dim=head_dim,
+                    mlp_ratio=mlp_ratio,
+                    mlp_dropout=mlp_dropout,
+                    attention_dropout=attention_dropout,
+                    partition_size=partition_size,
+                    input_grid_size=input_size,
+                    n_layers=num_layers,
+                    p_stochastic=p_stochastic[p_idx : p_idx + num_layers],
+                ),
+            )
+            input_size = self.blocks[-1].grid_size
+            p_idx += num_layers
+
+        # see https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L1137-L1158
+        # for why there is Linear -> Tanh -> Linear
+        self.classifier = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Flatten(),
+            nn.LayerNorm(block_channels[-1]),
+            nn.Linear(block_channels[-1], block_channels[-1]),
+            nn.Tanh(),
+            nn.Linear(block_channels[-1], num_classes, bias=False),
+        )
+
+        self._init_weights()
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.stem0(x)
+        x=self.stem1(x)
+        for block in self.blocks:
+            x = block(x)
+        x = self.classifier(x)
+        return x
+
+    def _init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.normal_(m.weight, std=0.02)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, std=0.02)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+
+
+def _maxvit(
+    # stem parameters
+    stem_channels: int,
+    # block parameters
+    block_channels: List[int],
+    block_layers: List[int],
+    stochastic_depth_prob: float,
+    # partitioning parameters
+    partition_size: int,
+    # transformer parameters
+    head_dim: int,
+    # Weights API
+    weights: Optional[WeightsEnum] = None,
+    progress: bool = False,
+    # kwargs,
+    **kwargs: Any,
+) -> MaxVit:
+
+    if weights is not None:
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+        assert weights.meta["min_size"][0] == weights.meta["min_size"][1]
+        _ovewrite_named_param(kwargs, "input_size", weights.meta["min_size"])
+
+    input_size = kwargs.pop("input_size", (224, 224))
+
+    model = MaxVit(
+        stem_channels=stem_channels,
+        block_channels=block_channels,
+        block_layers=block_layers,
+        stochastic_depth_prob=stochastic_depth_prob,
+        head_dim=head_dim,
+        partition_size=partition_size,
+        input_size=input_size,
+        **kwargs,
+    )
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+
+    return model
+
+
+class MaxVit_T_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        # URL empty until official release
+        url="https://download.pytorch.org/models/maxvit_t-bc5ab103.pth",
+        transforms=partial(
+            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
+        ),
+        meta={
+            "categories": _IMAGENET_CATEGORIES,
+            "num_params": 30919624,
+            "min_size": (224, 224),
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#maxvit",
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 83.700,
+                    "acc@5": 96.722,
+                }
+            },
+            "_ops": 5.558,
+            "_file_size": 118.769,
+            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+
+@handle_legacy_interface(weights=("pretrained", MaxVit_T_Weights.IMAGENET1K_V1))
+def maxvit_t(*, weights: Optional[MaxVit_T_Weights] = None, progress: bool = True, **kwargs: Any) -> MaxVit:
+    """
+    Constructs a maxvit_t architecture from
+    `MaxViT: Multi-Axis Vision Transformer <https://arxiv.org/abs/2204.01697>`_.
+
+    Args:
+        weights (:class:`~torchvision.models.MaxVit_T_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.MaxVit_T_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.maxvit.MaxVit``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/maxvit.py>`_
+            for more details about this class.
+
+    .. autoclass:: torchvision.models.MaxVit_T_Weights
+        :members:
+    """
+    weights = MaxVit_T_Weights.verify(weights)
+
+    return _maxvit(
+        stem_channels=64,
+        block_channels=[64, 128, 256, 512],
+        block_layers=[2, 2, 5, 2],
+        head_dim=32,
+        stochastic_depth_prob=0.2,
+        partition_size=7,
+        weights=weights,
+        progress=progress,
+        **kwargs,
+    )
+
+class MaxVitBackbone(torch.nn.Module):
+    def __init__(self,input_size=(224*2,224*2)):
+        super(MaxVitBackbone, self).__init__()
+        # 提取MaxVit的部分层作为特征提取器
+        maxvit_model = maxvit_t(pretrained=False,input_size=input_size)
+
+        self.stem0 = maxvit_model.stem0  # Stem层
+        self.stem1 = maxvit_model.stem1  # Stem层
+        self.block0= maxvit_model.blocks[0]
+        self.block1 = maxvit_model.blocks[1]
+        self.block2 = maxvit_model.blocks[2]
+        self.block3 = maxvit_model.blocks[3]
+
+    def forward(self, x):
+        print("Input size:", x.shape)
+        x = self.stem0(x)
+        print("After stem0 size:", x.shape)
+        x=self.stem1(x)
+        print("After stem1 size:", x.shape)
+        x = self.block0(x)
+        print("After block0 size:", x.shape)
+        x = self.block1(x)
+        print("After block1 size:", x.shape)
+        x = self.block2(x)
+        print("After block2 size:", x.shape)
+        x = self.block3(x)
+        print("After block3 size:", x.shape)
+        return x
+
+def maxvit_with_fpn(size=224):
+    maxvit = MaxVitBackbone(input_size=(size, size))
+    in_channels_list = [64, 64, 64, 128, 256, 512]
+    featmap_names = ['0', '1', '2', '3', '4','5', 'pool']
+    # print(f'featmap_names:{featmap_names}')
+    # roi_pooler = MultiScaleRoIAlign(
+    #     featmap_names=featmap_names,
+    #     output_size=7,
+    #     sampling_ratio=2
+    # )
+    backbone_with_fpn = BackboneWithFPN(
+        maxvit,
+        return_layers={'stem0': '0', 'stem1': '1', 'block0': '2', 'block1': '3', 'block2': '4', 'block3': '5'},
+        # 确保这些键对应到实际的层
+        in_channels_list=in_channels_list,
+        out_channels=256
+    )
+    test_input = torch.randn(1, 3, size, size)
+
+    return backbone_with_fpn
+
+if __name__ == '__main__':
+    maxvit = MaxVitBackbone(input_size=(224 * 3, 224 * 3))
+    in_channels_list = [64,64, 64, 128, 256, 512]
+    featmap_names = ['0', '1', '2', '3', '4', 'pool']
+    # print(f'featmap_names:{featmap_names}')
+    # roi_pooler = MultiScaleRoIAlign(
+    #     featmap_names=featmap_names,
+    #     output_size=7,
+    #     sampling_ratio=2
+    # )
+    backbone_with_fpn = BackboneWithFPN(
+        maxvit,
+        return_layers={'stem0': '0','stem1': '1', 'block0': '2', 'block1': '3', 'block2': '4', 'block3': '5'},  # 确保这些键对应到实际的层
+        in_channels_list=in_channels_list,
+        out_channels=256
+    )
+    test_input = torch.randn(1, 3, 224 * 3, 224* 3 )
+
+    # model = FasterRCNN(
+    #     backbone=backbone_with_fpn,
+    #     min_size=224 * 5,
+    #     max_size=224 * 5,
+    #     num_classes=91,  # COCO 数据集有 91 类
+    #     rpn_anchor_generator=get_anchor_generator(backbone_with_fpn, test_input=test_input),
+    #     box_roi_pool=roi_pooler
+    # )
+
+    out = maxvit(test_input)
+    with torch.no_grad():
+        output = backbone_with_fpn(test_input)
+    #
+    print("Output feature maps:")
+
+    for k, v in output.items():
+        print(f"{k}: {v.shape}")
+
+    # model.eval()
+    # output = model(test_input)
+    # print(f'fasterrcnn output:{output}')

+ 43 - 4
models/line_detect/line_detect.py

@@ -31,6 +31,7 @@ from ..base.backbone_factory import get_convnext_fpn, get_anchor_generator, get_
 from ..base.base_detection_net import BaseDetectionNet
 import torch.nn.functional as F
 
+from ..base.high_reso_maxvit import maxvit_with_fpn
 from ..base.high_reso_resnet import resnet50fpn, resnet18fpn, resnet101fpn
 
 __all__ = [
@@ -412,7 +413,7 @@ def linedetect_maxvitfpn(
     if num_points is None:
         num_points = 3
 
-    size=224*2
+    size=224*4
 
     maxvit = MaxVitBackbone(input_size=(size,size))
     # print(maxvit.named_children())
@@ -432,21 +433,59 @@ def linedetect_maxvitfpn(
         return_layers={'stem': '0', 'block0': '1', 'block1': '2', 'block2': '3', 'block3': '4'},
         # 确保这些键对应到实际的层
         in_channels_list=in_channels_list,
-        out_channels=256
+        out_channels=64
     )
-    test_input = torch.randn(1, 3, 224 * 2, 224 * 2)
+    test_input = torch.randn(1, 3,size,size)
 
     model = LineDetect(
         backbone=backbone_with_fpn,
         min_size=size,
         max_size=size,
-        num_classes=91,  # COCO 数据集有 91 类
+        num_classes=3,  # COCO 数据集有 91 类
         rpn_anchor_generator=get_anchor_generator(backbone_with_fpn, test_input=test_input),
         box_roi_pool=roi_pooler
     )
     return model
 
+def linedetect_high_maxvitfpn(
+        *,
+        num_classes: Optional[int] = None,
+        num_points:Optional[int] = None,
+        **kwargs: Any,
+) -> LineDetect:
+    # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
+    # weights_backbone = ResNet50_Weights.verify(weights_backbone)
+    if num_classes is None:
+        num_classes = 3
+    if num_points is None:
+        num_points = 3
+
+    size=224*2
+
+    maxvitfpn =maxvit_with_fpn(size=size)
+    # print(maxvit.named_children())
 
+    # for i,layer in enumerate(maxvit.named_children()):
+    #     print(f'layer:{i}:{layer}')
+
+    in_channels_list = [64,64, 64, 128, 256, 512]
+    featmap_names = ['0', '1', '2', '3', '4', '5','pool']
+    roi_pooler = MultiScaleRoIAlign(
+        featmap_names=featmap_names,
+        output_size=7,
+        sampling_ratio=2
+    )
+    test_input = torch.randn(1, 3,size,size)
+
+    model = LineDetect(
+        backbone=maxvitfpn,
+        min_size=size,
+        max_size=size,
+        num_classes=3,  # COCO 数据集有 91 类
+        rpn_anchor_generator=get_anchor_generator(maxvitfpn, test_input=test_input),
+        box_roi_pool=roi_pooler
+    )
+    return model
 
 def linedetect_resnet18_fpn(
         *,

+ 3 - 2
models/line_detect/train_demo.py

@@ -1,7 +1,7 @@
 import torch
 
 from models.line_detect.line_detect import linedetect_newresnet18fpn, linedetect_resnet50_fpn, linedetect_resnet18_fpn, \
-    linedetect_newresnet50fpn, linedetect_maxvitfpn
+    linedetect_newresnet50fpn, linedetect_maxvitfpn, linedetect_high_maxvitfpn
 
 from models.line_net.trainer import Trainer
 
@@ -19,5 +19,6 @@ if __name__ == '__main__':
     # model=linedetect_newresnet50fpn(num_points=3)
     # model = linedetect_newresnet50fpn(num_points=3)
     # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250711_114046/weights/best_val.pth')
-    model=linedetect_maxvitfpn()
+    # model=linedetect_maxvitfpn()
+    model=linedetect_high_maxvitfpn()
     model.start_train(cfg='train.yaml')