|
@@ -1373,6 +1373,32 @@ class A2C2f(nn.Module):
|
|
|
return self.cv2(torch.cat(y, 1))
|
|
|
|
|
|
class DSBottleneck(nn.Module):
|
|
|
+ """
|
|
|
+ An improved bottleneck block using depthwise separable convolutions (DSConv).
|
|
|
+
|
|
|
+ This class implements a lightweight bottleneck module that replaces standard convolutions with depthwise
|
|
|
+ separable convolutions to reduce parameters and computational cost.
|
|
|
+
|
|
|
+ Attributes:
|
|
|
+ c1 (int): Number of input channels.
|
|
|
+ c2 (int): Number of output channels.
|
|
|
+ shortcut (bool, optional): Whether to use a residual shortcut connection. The connection is only added if c1 == c2. Defaults to True.
|
|
|
+ e (float, optional): Expansion ratio for the intermediate channels. Defaults to 0.5.
|
|
|
+ k1 (int, optional): Kernel size for the first DSConv layer. Defaults to 3.
|
|
|
+ k2 (int, optional): Kernel size for the second DSConv layer. Defaults to 5.
|
|
|
+ d2 (int, optional): Dilation for the second DSConv layer. Defaults to 1.
|
|
|
+
|
|
|
+ Methods:
|
|
|
+ forward: Performs a forward pass through the DSBottleneck module.
|
|
|
+
|
|
|
+ Examples:
|
|
|
+ >>> import torch
|
|
|
+ >>> model = DSBottleneck(c1=64, c2=64, shortcut=True)
|
|
|
+ >>> x = torch.randn(2, 64, 32, 32)
|
|
|
+ >>> output = model(x)
|
|
|
+ >>> print(output.shape)
|
|
|
+ torch.Size([2, 64, 32, 32])
|
|
|
+ """
|
|
|
def __init__(self, c1, c2, shortcut=True, e=0.5, k1=3, k2=5, d2=1):
|
|
|
super().__init__()
|
|
|
c_ = int(c2 * e)
|
|
@@ -1386,6 +1412,34 @@ class DSBottleneck(nn.Module):
|
|
|
|
|
|
|
|
|
class DSC3k(C3):
|
|
|
+ """
|
|
|
+ An improved C3k module using DSBottleneck blocks for lightweight feature extraction.
|
|
|
+
|
|
|
+ This class extends the C3 module by replacing its standard bottleneck blocks with DSBottleneck blocks,
|
|
|
+ which use depthwise separable convolutions.
|
|
|
+
|
|
|
+ Attributes:
|
|
|
+ c1 (int): Number of input channels.
|
|
|
+ c2 (int): Number of output channels.
|
|
|
+ n (int, optional): Number of DSBottleneck blocks to stack. Defaults to 1.
|
|
|
+ shortcut (bool, optional): Whether to use shortcut connections within the DSBottlenecks. Defaults to True.
|
|
|
+ g (int, optional): Number of groups for grouped convolution (passed to parent C3). Defaults to 1.
|
|
|
+ e (float, optional): Expansion ratio for the C3 module's hidden channels. Defaults to 0.5.
|
|
|
+ k1 (int, optional): Kernel size for the first DSConv in each DSBottleneck. Defaults to 3.
|
|
|
+ k2 (int, optional): Kernel size for the second DSConv in each DSBottleneck. Defaults to 5.
|
|
|
+ d2 (int, optional): Dilation for the second DSConv in each DSBottleneck. Defaults to 1.
|
|
|
+
|
|
|
+ Methods:
|
|
|
+ forward: Performs a forward pass through the DSC3k module (inherited from C3).
|
|
|
+
|
|
|
+ Examples:
|
|
|
+ >>> import torch
|
|
|
+ >>> model = DSC3k(c1=128, c2=128, n=2, k1=3, k2=7)
|
|
|
+ >>> x = torch.randn(2, 128, 64, 64)
|
|
|
+ >>> output = model(x)
|
|
|
+ >>> print(output.shape)
|
|
|
+ torch.Size([2, 128, 64, 64])
|
|
|
+ """
|
|
|
def __init__(
|
|
|
self,
|
|
|
c1,
|
|
@@ -1416,6 +1470,41 @@ class DSC3k(C3):
|
|
|
)
|
|
|
|
|
|
class DSC3k2(C2f):
|
|
|
+ """
|
|
|
+ An improved C3k2 module that uses lightweight depthwise separable convolution blocks.
|
|
|
+
|
|
|
+ This class redesigns C3k2 module, replacing its internal processing blocks with either DSBottleneck
|
|
|
+ or DSC3k modules.
|
|
|
+
|
|
|
+ Attributes:
|
|
|
+ c1 (int): Number of input channels.
|
|
|
+ c2 (int): Number of output channels.
|
|
|
+ n (int, optional): Number of internal processing blocks to stack. Defaults to 1.
|
|
|
+ dsc3k (bool, optional): If True, use DSC3k as the internal block. If False, use DSBottleneck. Defaults to False.
|
|
|
+ e (float, optional): Expansion ratio for the C2f module's hidden channels. Defaults to 0.5.
|
|
|
+ g (int, optional): Number of groups for grouped convolution (passed to parent C2f). Defaults to 1.
|
|
|
+ shortcut (bool, optional): Whether to use shortcut connections in the internal blocks. Defaults to True.
|
|
|
+ k1 (int, optional): Kernel size for the first DSConv in internal blocks. Defaults to 3.
|
|
|
+ k2 (int, optional): Kernel size for the second DSConv in internal blocks. Defaults to 7.
|
|
|
+ d2 (int, optional): Dilation for the second DSConv in internal blocks. Defaults to 1.
|
|
|
+
|
|
|
+ Methods:
|
|
|
+ forward: Performs a forward pass through the DSC3k2 module (inherited from C2f).
|
|
|
+
|
|
|
+ Examples:
|
|
|
+ >>> import torch
|
|
|
+ >>> # Using DSBottleneck as internal block
|
|
|
+ >>> model1 = DSC3k2(c1=64, c2=64, n=2, dsc3k=False)
|
|
|
+ >>> x = torch.randn(2, 64, 128, 128)
|
|
|
+ >>> output1 = model1(x)
|
|
|
+ >>> print(f"With DSBottleneck: {output1.shape}")
|
|
|
+ With DSBottleneck: torch.Size([2, 64, 128, 128])
|
|
|
+ >>> # Using DSC3k as internal block
|
|
|
+ >>> model2 = DSC3k2(c1=64, c2=64, n=1, dsc3k=True)
|
|
|
+ >>> output2 = model2(x)
|
|
|
+ >>> print(f"With DSC3k: {output2.shape}")
|
|
|
+ With DSC3k: torch.Size([2, 64, 128, 128])
|
|
|
+ """
|
|
|
def __init__(
|
|
|
self,
|
|
|
c1,
|
|
@@ -1458,6 +1547,31 @@ class DSC3k2(C2f):
|
|
|
)
|
|
|
|
|
|
class AdaHyperedgeGen(nn.Module):
|
|
|
+ """
|
|
|
+ Generates an adaptive hyperedge participation matrix from a set of vertex features.
|
|
|
+
|
|
|
+ This module implements the Adaptive Hyperedge Generation mechanism. It generates dynamic hyperedge prototypes
|
|
|
+ based on the global context of the input nodes and calculates a continuous participation matrix (A)
|
|
|
+ that defines the relationship between each vertex and each hyperedge.
|
|
|
+
|
|
|
+ Attributes:
|
|
|
+ node_dim (int): The feature dimension of each input node.
|
|
|
+ num_hyperedges (int): The number of hyperedges to generate.
|
|
|
+ num_heads (int, optional): The number of attention heads for multi-head similarity calculation. Defaults to 4.
|
|
|
+ dropout (float, optional): The dropout rate applied to the logits. Defaults to 0.1.
|
|
|
+ context (str, optional): The type of global context to use ('mean', 'max', or 'both'). Defaults to "both".
|
|
|
+
|
|
|
+ Methods:
|
|
|
+ forward: Takes a batch of vertex features and returns the participation matrix A.
|
|
|
+
|
|
|
+ Examples:
|
|
|
+ >>> import torch
|
|
|
+ >>> model = AdaHyperedgeGen(node_dim=64, num_hyperedges=16, num_heads=4)
|
|
|
+ >>> x = torch.randn(2, 100, 64) # (Batch, Num_Nodes, Node_Dim)
|
|
|
+ >>> A = model(x)
|
|
|
+ >>> print(A.shape)
|
|
|
+ torch.Size([2, 100, 16])
|
|
|
+ """
|
|
|
def __init__(self, node_dim, num_hyperedges, num_heads=4, dropout=0.1, context="both"):
|
|
|
super().__init__()
|
|
|
self.num_heads = num_heads
|
|
@@ -1510,6 +1624,33 @@ class AdaHyperedgeGen(nn.Module):
|
|
|
return F.softmax(logits, dim=1)
|
|
|
|
|
|
class AdaHGConv(nn.Module):
|
|
|
+ """
|
|
|
+ Performs the adaptive hypergraph convolution.
|
|
|
+
|
|
|
+ This module contains the two-stage message passing process of hypergraph convolution:
|
|
|
+ 1. Generates an adaptive participation matrix using AdaHyperedgeGen.
|
|
|
+ 2. Aggregates vertex features into hyperedge features (vertex-to-edge).
|
|
|
+ 3. Disseminates hyperedge features back to update vertex features (edge-to-vertex).
|
|
|
+ A residual connection is added to the final output.
|
|
|
+
|
|
|
+ Attributes:
|
|
|
+ embed_dim (int): The feature dimension of the vertices.
|
|
|
+ num_hyperedges (int, optional): The number of hyperedges for the internal generator. Defaults to 16.
|
|
|
+ num_heads (int, optional): The number of attention heads for the internal generator. Defaults to 4.
|
|
|
+ dropout (float, optional): The dropout rate for the internal generator. Defaults to 0.1.
|
|
|
+ context (str, optional): The context type for the internal generator. Defaults to "both".
|
|
|
+
|
|
|
+ Methods:
|
|
|
+ forward: Performs the adaptive hypergraph convolution on a batch of vertex features.
|
|
|
+
|
|
|
+ Examples:
|
|
|
+ >>> import torch
|
|
|
+ >>> model = AdaHGConv(embed_dim=128, num_hyperedges=16, num_heads=8)
|
|
|
+ >>> x = torch.randn(2, 256, 128) # (Batch, Num_Nodes, Dim)
|
|
|
+ >>> output = model(x)
|
|
|
+ >>> print(output.shape)
|
|
|
+ torch.Size([2, 256, 128])
|
|
|
+ """
|
|
|
def __init__(self, embed_dim, num_hyperedges=16, num_heads=4, dropout=0.1, context="both"):
|
|
|
super().__init__()
|
|
|
self.edge_generator = AdaHyperedgeGen(embed_dim, num_hyperedges, num_heads, dropout, context)
|
|
@@ -1534,6 +1675,31 @@ class AdaHGConv(nn.Module):
|
|
|
return X_new + X
|
|
|
|
|
|
class AdaHGComputation(nn.Module):
|
|
|
+ """
|
|
|
+ A wrapper module for applying adaptive hypergraph convolution to 4D feature maps.
|
|
|
+
|
|
|
+ This class makes the hypergraph convolution compatible with standard CNN architectures. It flattens a
|
|
|
+ 4D input tensor (B, C, H, W) into a sequence of vertices (tokens), applies the AdaHGConv layer to
|
|
|
+ model high-order correlations, and then reshapes the output back into a 4D tensor.
|
|
|
+
|
|
|
+ Attributes:
|
|
|
+ embed_dim (int): The feature dimension of the vertices (equivalent to input channels C).
|
|
|
+ num_hyperedges (int, optional): The number of hyperedges for the underlying AdaHGConv. Defaults to 16.
|
|
|
+ num_heads (int, optional): The number of attention heads for the underlying AdaHGConv. Defaults to 8.
|
|
|
+ dropout (float, optional): The dropout rate for the underlying AdaHGConv. Defaults to 0.1.
|
|
|
+ context (str, optional): The context type for the underlying AdaHGConv. Defaults to "both".
|
|
|
+
|
|
|
+ Methods:
|
|
|
+ forward: Processes a 4D feature map through the adaptive hypergraph computation layer.
|
|
|
+
|
|
|
+ Examples:
|
|
|
+ >>> import torch
|
|
|
+ >>> model = AdaHGComputation(embed_dim=64, num_hyperedges=8, num_heads=4)
|
|
|
+ >>> x = torch.randn(2, 64, 32, 32) # (B, C, H, W)
|
|
|
+ >>> output = model(x)
|
|
|
+ >>> print(output.shape)
|
|
|
+ torch.Size([2, 64, 32, 32])
|
|
|
+ """
|
|
|
def __init__(self, embed_dim, num_hyperedges=16, num_heads=8, dropout=0.1, context="both"):
|
|
|
super().__init__()
|
|
|
self.embed_dim = embed_dim
|
|
@@ -1553,6 +1719,31 @@ class AdaHGComputation(nn.Module):
|
|
|
return x_out
|
|
|
|
|
|
class C3AH(nn.Module):
|
|
|
+ """
|
|
|
+ A CSP-style block integrating Adaptive Hypergraph Computation (C3AH).
|
|
|
+
|
|
|
+ The input feature map is split into two paths.
|
|
|
+ One path is processed by the AdaHGComputation module to model high-order correlations, while the other
|
|
|
+ serves as a shortcut. The outputs are then concatenated to fuse features.
|
|
|
+
|
|
|
+ Attributes:
|
|
|
+ c1 (int): Number of input channels.
|
|
|
+ c2 (int): Number of output channels.
|
|
|
+ e (float, optional): Expansion ratio for the hidden channels. Defaults to 1.0.
|
|
|
+ num_hyperedges (int, optional): The number of hyperedges for the internal AdaHGComputation. Defaults to 8.
|
|
|
+ context (str, optional): The context type for the internal AdaHGComputation. Defaults to "both".
|
|
|
+
|
|
|
+ Methods:
|
|
|
+ forward: Performs a forward pass through the C3AH module.
|
|
|
+
|
|
|
+ Examples:
|
|
|
+ >>> import torch
|
|
|
+ >>> model = C3AH(c1=64, c2=128, num_hyperedges=8)
|
|
|
+ >>> x = torch.randn(2, 64, 32, 32)
|
|
|
+ >>> output = model(x)
|
|
|
+ >>> print(output.shape)
|
|
|
+ torch.Size([2, 128, 32, 32])
|
|
|
+ """
|
|
|
def __init__(self, c1, c2, e=1.0, num_hyperedges=8, context="both"):
|
|
|
super().__init__()
|
|
|
c_ = int(c2 * e)
|
|
@@ -1571,6 +1762,29 @@ class C3AH(nn.Module):
|
|
|
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
|
|
|
|
|
|
class FuseModule(nn.Module):
|
|
|
+ """
|
|
|
+ A module to fuse multi-scale features for the HyperACE block.
|
|
|
+
|
|
|
+ This module takes a list of three feature maps from different scales, aligns them to a common
|
|
|
+ spatial resolution by downsampling the first and upsampling the third, and then concatenates
|
|
|
+ and fuses them with a convolution layer.
|
|
|
+
|
|
|
+ Attributes:
|
|
|
+ c_in (int): The number of channels of the input feature maps.
|
|
|
+ channel_adjust (bool): Whether to adjust the channel count of the concatenated features.
|
|
|
+
|
|
|
+ Methods:
|
|
|
+ forward: Fuses a list of three multi-scale feature maps.
|
|
|
+
|
|
|
+ Examples:
|
|
|
+ >>> import torch
|
|
|
+ >>> model = FuseModule(c_in=64, channel_adjust=False)
|
|
|
+ >>> # Input is a list of features from different backbone stages
|
|
|
+ >>> x_list = [torch.randn(2, 64, 64, 64), torch.randn(2, 64, 32, 32), torch.randn(2, 64, 16, 16)]
|
|
|
+ >>> output = model(x_list)
|
|
|
+ >>> print(output.shape)
|
|
|
+ torch.Size([2, 64, 32, 32])
|
|
|
+ """
|
|
|
def __init__(self, c_in, channel_adjust):
|
|
|
super(FuseModule, self).__init__()
|
|
|
self.downsample = nn.AvgPool2d(kernel_size=2)
|
|
@@ -1588,6 +1802,37 @@ class FuseModule(nn.Module):
|
|
|
return out
|
|
|
|
|
|
class HyperACE(nn.Module):
|
|
|
+ """
|
|
|
+ Hypergraph-based Adaptive Correlation Enhancement (HyperACE).
|
|
|
+
|
|
|
+ This is the core module of YOLOv13, designed to model both global high-order correlations and
|
|
|
+ local low-order correlations. It first fuses multi-scale features, then processes them through parallel
|
|
|
+ branches: two C3AH branches for high-order modeling and a lightweight DSConv-based branch for
|
|
|
+ low-order feature extraction.
|
|
|
+
|
|
|
+ Attributes:
|
|
|
+ c1 (int): Number of input channels for the fuse module.
|
|
|
+ c2 (int): Number of output channels for the entire block.
|
|
|
+ n (int, optional): Number of blocks in the low-order branch. Defaults to 1.
|
|
|
+ num_hyperedges (int, optional): Number of hyperedges for the C3AH branches. Defaults to 8.
|
|
|
+ dsc3k (bool, optional): If True, use DSC3k in the low-order branch; otherwise, use DSBottleneck. Defaults to True.
|
|
|
+ shortcut (bool, optional): Whether to use shortcuts in the low-order branch. Defaults to False.
|
|
|
+ e1 (float, optional): Expansion ratio for the main hidden channels. Defaults to 0.5.
|
|
|
+ e2 (float, optional): Expansion ratio within the C3AH branches. Defaults to 1.
|
|
|
+ context (str, optional): Context type for C3AH branches. Defaults to "both".
|
|
|
+ channel_adjust (bool, optional): Passed to FuseModule for channel configuration. Defaults to True.
|
|
|
+
|
|
|
+ Methods:
|
|
|
+ forward: Performs a forward pass through the HyperACE module.
|
|
|
+
|
|
|
+ Examples:
|
|
|
+ >>> import torch
|
|
|
+ >>> model = HyperACE(c1=64, c2=256, n=1, num_hyperedges=8)
|
|
|
+ >>> x_list = [torch.randn(2, 64, 64, 64), torch.randn(2, 64, 32, 32), torch.randn(2, 64, 16, 16)]
|
|
|
+ >>> output = model(x_list)
|
|
|
+ >>> print(output.shape)
|
|
|
+ torch.Size([2, 256, 32, 32])
|
|
|
+ """
|
|
|
def __init__(self, c1, c2, n=1, num_hyperedges=8, dsc3k=True, shortcut=False, e1=0.5, e2=1, context="both", channel_adjust=True):
|
|
|
super().__init__()
|
|
|
self.c = int(c2 * e1)
|
|
@@ -1611,6 +1856,27 @@ class HyperACE(nn.Module):
|
|
|
return self.cv2(torch.cat(y, 1))
|
|
|
|
|
|
class DownsampleConv(nn.Module):
|
|
|
+ """
|
|
|
+ A simple downsampling block with optional channel adjustment.
|
|
|
+
|
|
|
+ This module uses average pooling to reduce the spatial dimensions (H, W) by a factor of 2. It can
|
|
|
+ optionally include a 1x1 convolution to adjust the number of channels, typically doubling them.
|
|
|
+
|
|
|
+ Attributes:
|
|
|
+ in_channels (int): The number of input channels.
|
|
|
+ channel_adjust (bool, optional): If True, a 1x1 convolution doubles the channel dimension. Defaults to True.
|
|
|
+
|
|
|
+ Methods:
|
|
|
+ forward: Performs the downsampling and optional channel adjustment.
|
|
|
+
|
|
|
+ Examples:
|
|
|
+ >>> import torch
|
|
|
+ >>> model = DownsampleConv(in_channels=64, channel_adjust=True)
|
|
|
+ >>> x = torch.randn(2, 64, 32, 32)
|
|
|
+ >>> output = model(x)
|
|
|
+ >>> print(output.shape)
|
|
|
+ torch.Size([2, 128, 16, 16])
|
|
|
+ """
|
|
|
def __init__(self, in_channels, channel_adjust=True):
|
|
|
super().__init__()
|
|
|
self.downsample = nn.AvgPool2d(kernel_size=2)
|
|
@@ -1623,6 +1889,25 @@ class DownsampleConv(nn.Module):
|
|
|
return self.channel_adjust(self.downsample(x))
|
|
|
|
|
|
class FullPAD_Tunnel(nn.Module):
|
|
|
+ """
|
|
|
+ A gated fusion module for the Full-Pipeline Aggregation-and-Distribution (FullPAD) paradigm.
|
|
|
+
|
|
|
+ This module implements a gated residual connection used to fuse features. It takes two inputs: the original
|
|
|
+ feature map and a correlation-enhanced feature map. It then computes `output = original + gate * enhanced`,
|
|
|
+ where `gate` is a learnable scalar parameter that adaptively balances the contribution of the enhanced features.
|
|
|
+
|
|
|
+ Methods:
|
|
|
+ forward: Performs the gated fusion of two input feature maps.
|
|
|
+
|
|
|
+ Examples:
|
|
|
+ >>> import torch
|
|
|
+ >>> model = FullPAD_Tunnel()
|
|
|
+ >>> original_feature = torch.randn(2, 64, 32, 32)
|
|
|
+ >>> enhanced_feature = torch.randn(2, 64, 32, 32)
|
|
|
+ >>> output = model([original_feature, enhanced_feature])
|
|
|
+ >>> print(output.shape)
|
|
|
+ torch.Size([2, 64, 32, 32])
|
|
|
+ """
|
|
|
def __init__(self):
|
|
|
super().__init__()
|
|
|
self.gate = nn.Parameter(torch.tensor(0.0))
|