Kaynağa Gözat

Update block.py

Mengqi Lei 2 ay önce
ebeveyn
işleme
154ea570ce
1 değiştirilmiş dosya ile 285 ekleme ve 0 silme
  1. 285 0
      ultralytics/nn/modules/block.py

+ 285 - 0
ultralytics/nn/modules/block.py

@@ -1373,6 +1373,32 @@ class A2C2f(nn.Module):
         return self.cv2(torch.cat(y, 1))
 
 class DSBottleneck(nn.Module):
+    """
+    An improved bottleneck block using depthwise separable convolutions (DSConv).
+
+    This class implements a lightweight bottleneck module that replaces standard convolutions with depthwise
+    separable convolutions to reduce parameters and computational cost. 
+
+    Attributes:
+        c1 (int): Number of input channels.
+        c2 (int): Number of output channels.
+        shortcut (bool, optional): Whether to use a residual shortcut connection. The connection is only added if c1 == c2. Defaults to True.
+        e (float, optional): Expansion ratio for the intermediate channels. Defaults to 0.5.
+        k1 (int, optional): Kernel size for the first DSConv layer. Defaults to 3.
+        k2 (int, optional): Kernel size for the second DSConv layer. Defaults to 5.
+        d2 (int, optional): Dilation for the second DSConv layer. Defaults to 1.
+
+    Methods:
+        forward: Performs a forward pass through the DSBottleneck module.
+
+    Examples:
+        >>> import torch
+        >>> model = DSBottleneck(c1=64, c2=64, shortcut=True)
+        >>> x = torch.randn(2, 64, 32, 32)
+        >>> output = model(x)
+        >>> print(output.shape)
+        torch.Size([2, 64, 32, 32])
+    """
     def __init__(self, c1, c2, shortcut=True, e=0.5, k1=3, k2=5, d2=1):
         super().__init__()
         c_ = int(c2 * e)
@@ -1386,6 +1412,34 @@ class DSBottleneck(nn.Module):
 
 
 class DSC3k(C3):
+    """
+    An improved C3k module using DSBottleneck blocks for lightweight feature extraction.
+
+    This class extends the C3 module by replacing its standard bottleneck blocks with DSBottleneck blocks,
+    which use depthwise separable convolutions.
+
+    Attributes:
+        c1 (int): Number of input channels.
+        c2 (int): Number of output channels.
+        n (int, optional): Number of DSBottleneck blocks to stack. Defaults to 1.
+        shortcut (bool, optional): Whether to use shortcut connections within the DSBottlenecks. Defaults to True.
+        g (int, optional): Number of groups for grouped convolution (passed to parent C3). Defaults to 1.
+        e (float, optional): Expansion ratio for the C3 module's hidden channels. Defaults to 0.5.
+        k1 (int, optional): Kernel size for the first DSConv in each DSBottleneck. Defaults to 3.
+        k2 (int, optional): Kernel size for the second DSConv in each DSBottleneck. Defaults to 5.
+        d2 (int, optional): Dilation for the second DSConv in each DSBottleneck. Defaults to 1.
+
+    Methods:
+        forward: Performs a forward pass through the DSC3k module (inherited from C3).
+
+    Examples:
+        >>> import torch
+        >>> model = DSC3k(c1=128, c2=128, n=2, k1=3, k2=7)
+        >>> x = torch.randn(2, 128, 64, 64)
+        >>> output = model(x)
+        >>> print(output.shape)
+        torch.Size([2, 128, 64, 64])
+    """
     def __init__(
         self,
         c1,                
@@ -1416,6 +1470,41 @@ class DSC3k(C3):
         )
 
 class DSC3k2(C2f):
+    """
+    An improved C3k2 module that uses lightweight depthwise separable convolution blocks.
+
+    This class redesigns C3k2 module, replacing its internal processing blocks with either DSBottleneck
+    or DSC3k modules.
+
+    Attributes:
+        c1 (int): Number of input channels.
+        c2 (int): Number of output channels.
+        n (int, optional): Number of internal processing blocks to stack. Defaults to 1.
+        dsc3k (bool, optional): If True, use DSC3k as the internal block. If False, use DSBottleneck. Defaults to False.
+        e (float, optional): Expansion ratio for the C2f module's hidden channels. Defaults to 0.5.
+        g (int, optional): Number of groups for grouped convolution (passed to parent C2f). Defaults to 1.
+        shortcut (bool, optional): Whether to use shortcut connections in the internal blocks. Defaults to True.
+        k1 (int, optional): Kernel size for the first DSConv in internal blocks. Defaults to 3.
+        k2 (int, optional): Kernel size for the second DSConv in internal blocks. Defaults to 7.
+        d2 (int, optional): Dilation for the second DSConv in internal blocks. Defaults to 1.
+
+    Methods:
+        forward: Performs a forward pass through the DSC3k2 module (inherited from C2f).
+
+    Examples:
+        >>> import torch
+        >>> # Using DSBottleneck as internal block
+        >>> model1 = DSC3k2(c1=64, c2=64, n=2, dsc3k=False)
+        >>> x = torch.randn(2, 64, 128, 128)
+        >>> output1 = model1(x)
+        >>> print(f"With DSBottleneck: {output1.shape}")
+        With DSBottleneck: torch.Size([2, 64, 128, 128])
+        >>> # Using DSC3k as internal block
+        >>> model2 = DSC3k2(c1=64, c2=64, n=1, dsc3k=True)
+        >>> output2 = model2(x)
+        >>> print(f"With DSC3k: {output2.shape}")
+        With DSC3k: torch.Size([2, 64, 128, 128])
+    """
     def __init__(
         self,
         c1,          
@@ -1458,6 +1547,31 @@ class DSC3k2(C2f):
             )
 
 class AdaHyperedgeGen(nn.Module):
+    """
+    Generates an adaptive hyperedge participation matrix from a set of vertex features.
+
+    This module implements the Adaptive Hyperedge Generation mechanism. It generates dynamic hyperedge prototypes
+    based on the global context of the input nodes and calculates a continuous participation matrix (A)
+    that defines the relationship between each vertex and each hyperedge.
+
+    Attributes:
+        node_dim (int): The feature dimension of each input node.
+        num_hyperedges (int): The number of hyperedges to generate.
+        num_heads (int, optional): The number of attention heads for multi-head similarity calculation. Defaults to 4.
+        dropout (float, optional): The dropout rate applied to the logits. Defaults to 0.1.
+        context (str, optional): The type of global context to use ('mean', 'max', or 'both'). Defaults to "both".
+
+    Methods:
+        forward: Takes a batch of vertex features and returns the participation matrix A.
+
+    Examples:
+        >>> import torch
+        >>> model = AdaHyperedgeGen(node_dim=64, num_hyperedges=16, num_heads=4)
+        >>> x = torch.randn(2, 100, 64)  # (Batch, Num_Nodes, Node_Dim)
+        >>> A = model(x)
+        >>> print(A.shape)
+        torch.Size([2, 100, 16])
+    """
     def __init__(self, node_dim, num_hyperedges, num_heads=4, dropout=0.1, context="both"):
         super().__init__()
         self.num_heads = num_heads
@@ -1510,6 +1624,33 @@ class AdaHyperedgeGen(nn.Module):
         return F.softmax(logits, dim=1)
 
 class AdaHGConv(nn.Module):
+    """
+    Performs the adaptive hypergraph convolution.
+
+    This module contains the two-stage message passing process of hypergraph convolution:
+    1. Generates an adaptive participation matrix using AdaHyperedgeGen.
+    2. Aggregates vertex features into hyperedge features (vertex-to-edge).
+    3. Disseminates hyperedge features back to update vertex features (edge-to-vertex).
+    A residual connection is added to the final output.
+
+    Attributes:
+        embed_dim (int): The feature dimension of the vertices.
+        num_hyperedges (int, optional): The number of hyperedges for the internal generator. Defaults to 16.
+        num_heads (int, optional): The number of attention heads for the internal generator. Defaults to 4.
+        dropout (float, optional): The dropout rate for the internal generator. Defaults to 0.1.
+        context (str, optional): The context type for the internal generator. Defaults to "both".
+
+    Methods:
+        forward: Performs the adaptive hypergraph convolution on a batch of vertex features.
+
+    Examples:
+        >>> import torch
+        >>> model = AdaHGConv(embed_dim=128, num_hyperedges=16, num_heads=8)
+        >>> x = torch.randn(2, 256, 128) # (Batch, Num_Nodes, Dim)
+        >>> output = model(x)
+        >>> print(output.shape)
+        torch.Size([2, 256, 128])
+    """
     def __init__(self, embed_dim, num_hyperedges=16, num_heads=4, dropout=0.1, context="both"):
         super().__init__()
         self.edge_generator = AdaHyperedgeGen(embed_dim, num_hyperedges, num_heads, dropout, context)
@@ -1534,6 +1675,31 @@ class AdaHGConv(nn.Module):
         return X_new + X
         
 class AdaHGComputation(nn.Module):
+    """
+    A wrapper module for applying adaptive hypergraph convolution to 4D feature maps.
+
+    This class makes the hypergraph convolution compatible with standard CNN architectures. It flattens a
+    4D input tensor (B, C, H, W) into a sequence of vertices (tokens), applies the AdaHGConv layer to
+    model high-order correlations, and then reshapes the output back into a 4D tensor.
+
+    Attributes:
+        embed_dim (int): The feature dimension of the vertices (equivalent to input channels C).
+        num_hyperedges (int, optional): The number of hyperedges for the underlying AdaHGConv. Defaults to 16.
+        num_heads (int, optional): The number of attention heads for the underlying AdaHGConv. Defaults to 8.
+        dropout (float, optional): The dropout rate for the underlying AdaHGConv. Defaults to 0.1.
+        context (str, optional): The context type for the underlying AdaHGConv. Defaults to "both".
+
+    Methods:
+        forward: Processes a 4D feature map through the adaptive hypergraph computation layer.
+
+    Examples:
+        >>> import torch
+        >>> model = AdaHGComputation(embed_dim=64, num_hyperedges=8, num_heads=4)
+        >>> x = torch.randn(2, 64, 32, 32) # (B, C, H, W)
+        >>> output = model(x)
+        >>> print(output.shape)
+        torch.Size([2, 64, 32, 32])
+    """
     def __init__(self, embed_dim, num_hyperedges=16, num_heads=8, dropout=0.1, context="both"):
         super().__init__()
         self.embed_dim = embed_dim
@@ -1553,6 +1719,31 @@ class AdaHGComputation(nn.Module):
         return x_out 
 
 class C3AH(nn.Module):
+    """
+    A CSP-style block integrating Adaptive Hypergraph Computation (C3AH).
+
+    The input feature map is split into two paths.
+    One path is processed by the AdaHGComputation module to model high-order correlations, while the other
+    serves as a shortcut. The outputs are then concatenated to fuse features.
+
+    Attributes:
+        c1 (int): Number of input channels.
+        c2 (int): Number of output channels.
+        e (float, optional): Expansion ratio for the hidden channels. Defaults to 1.0.
+        num_hyperedges (int, optional): The number of hyperedges for the internal AdaHGComputation. Defaults to 8.
+        context (str, optional): The context type for the internal AdaHGComputation. Defaults to "both".
+
+    Methods:
+        forward: Performs a forward pass through the C3AH module.
+
+    Examples:
+        >>> import torch
+        >>> model = C3AH(c1=64, c2=128, num_hyperedges=8)
+        >>> x = torch.randn(2, 64, 32, 32)
+        >>> output = model(x)
+        >>> print(output.shape)
+        torch.Size([2, 128, 32, 32])
+    """
     def __init__(self, c1, c2, e=1.0, num_hyperedges=8, context="both"):
         super().__init__()
         c_ = int(c2 * e)  
@@ -1571,6 +1762,29 @@ class C3AH(nn.Module):
         return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
 
 class FuseModule(nn.Module):
+    """
+    A module to fuse multi-scale features for the HyperACE block.
+
+    This module takes a list of three feature maps from different scales, aligns them to a common
+    spatial resolution by downsampling the first and upsampling the third, and then concatenates
+    and fuses them with a convolution layer.
+
+    Attributes:
+        c_in (int): The number of channels of the input feature maps.
+        channel_adjust (bool): Whether to adjust the channel count of the concatenated features.
+
+    Methods:
+        forward: Fuses a list of three multi-scale feature maps.
+
+    Examples:
+        >>> import torch
+        >>> model = FuseModule(c_in=64, channel_adjust=False)
+        >>> # Input is a list of features from different backbone stages
+        >>> x_list = [torch.randn(2, 64, 64, 64), torch.randn(2, 64, 32, 32), torch.randn(2, 64, 16, 16)]
+        >>> output = model(x_list)
+        >>> print(output.shape)
+        torch.Size([2, 64, 32, 32])
+    """
     def __init__(self, c_in, channel_adjust):
         super(FuseModule, self).__init__()
         self.downsample = nn.AvgPool2d(kernel_size=2)
@@ -1588,6 +1802,37 @@ class FuseModule(nn.Module):
         return out
 
 class HyperACE(nn.Module):
+    """
+    Hypergraph-based Adaptive Correlation Enhancement (HyperACE).
+
+    This is the core module of YOLOv13, designed to model both global high-order correlations and
+    local low-order correlations. It first fuses multi-scale features, then processes them through parallel
+    branches: two C3AH branches for high-order modeling and a lightweight DSConv-based branch for
+    low-order feature extraction.
+
+    Attributes:
+        c1 (int): Number of input channels for the fuse module.
+        c2 (int): Number of output channels for the entire block.
+        n (int, optional): Number of blocks in the low-order branch. Defaults to 1.
+        num_hyperedges (int, optional): Number of hyperedges for the C3AH branches. Defaults to 8.
+        dsc3k (bool, optional): If True, use DSC3k in the low-order branch; otherwise, use DSBottleneck. Defaults to True.
+        shortcut (bool, optional): Whether to use shortcuts in the low-order branch. Defaults to False.
+        e1 (float, optional): Expansion ratio for the main hidden channels. Defaults to 0.5.
+        e2 (float, optional): Expansion ratio within the C3AH branches. Defaults to 1.
+        context (str, optional): Context type for C3AH branches. Defaults to "both".
+        channel_adjust (bool, optional): Passed to FuseModule for channel configuration. Defaults to True.
+
+    Methods:
+        forward: Performs a forward pass through the HyperACE module.
+
+    Examples:
+        >>> import torch
+        >>> model = HyperACE(c1=64, c2=256, n=1, num_hyperedges=8)
+        >>> x_list = [torch.randn(2, 64, 64, 64), torch.randn(2, 64, 32, 32), torch.randn(2, 64, 16, 16)]
+        >>> output = model(x_list)
+        >>> print(output.shape)
+        torch.Size([2, 256, 32, 32])
+    """
     def __init__(self, c1, c2, n=1, num_hyperedges=8, dsc3k=True, shortcut=False, e1=0.5, e2=1, context="both", channel_adjust=True):
         super().__init__()
         self.c = int(c2 * e1) 
@@ -1611,6 +1856,27 @@ class HyperACE(nn.Module):
         return self.cv2(torch.cat(y, 1))
 
 class DownsampleConv(nn.Module):
+    """
+    A simple downsampling block with optional channel adjustment.
+
+    This module uses average pooling to reduce the spatial dimensions (H, W) by a factor of 2. It can
+    optionally include a 1x1 convolution to adjust the number of channels, typically doubling them.
+
+    Attributes:
+        in_channels (int): The number of input channels.
+        channel_adjust (bool, optional): If True, a 1x1 convolution doubles the channel dimension. Defaults to True.
+
+    Methods:
+        forward: Performs the downsampling and optional channel adjustment.
+
+    Examples:
+        >>> import torch
+        >>> model = DownsampleConv(in_channels=64, channel_adjust=True)
+        >>> x = torch.randn(2, 64, 32, 32)
+        >>> output = model(x)
+        >>> print(output.shape)
+        torch.Size([2, 128, 16, 16])
+    """
     def __init__(self, in_channels, channel_adjust=True):
         super().__init__()
         self.downsample = nn.AvgPool2d(kernel_size=2)
@@ -1623,6 +1889,25 @@ class DownsampleConv(nn.Module):
         return self.channel_adjust(self.downsample(x))
 
 class FullPAD_Tunnel(nn.Module):
+    """
+    A gated fusion module for the Full-Pipeline Aggregation-and-Distribution (FullPAD) paradigm.
+
+    This module implements a gated residual connection used to fuse features. It takes two inputs: the original
+    feature map and a correlation-enhanced feature map. It then computes `output = original + gate * enhanced`,
+    where `gate` is a learnable scalar parameter that adaptively balances the contribution of the enhanced features.
+
+    Methods:
+        forward: Performs the gated fusion of two input feature maps.
+
+    Examples:
+        >>> import torch
+        >>> model = FullPAD_Tunnel()
+        >>> original_feature = torch.randn(2, 64, 32, 32)
+        >>> enhanced_feature = torch.randn(2, 64, 32, 32)
+        >>> output = model([original_feature, enhanced_feature])
+        >>> print(output.shape)
+        torch.Size([2, 64, 32, 32])
+    """
     def __init__(self):
         super().__init__()
         self.gate = nn.Parameter(torch.tensor(0.0))