memory_attention.py

# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import copy
from typing import Optional

import torch
from torch import Tensor, nn

from .blocks import RoPEAttention


class MemoryAttentionLayer(nn.Module):
  8. """
  9. Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.
  10. This class combines self-attention, cross-attention, and feedforward components to process input tensors and
  11. generate memory-based attention outputs.
  12. Attributes:
  13. d_model (int): Dimensionality of the model.
  14. dim_feedforward (int): Dimensionality of the feedforward network.
  15. dropout_value (float): Dropout rate for regularization.
  16. self_attn (RoPEAttention): Self-attention mechanism using RoPE (Rotary Position Embedding).
  17. cross_attn_image (RoPEAttention): Cross-attention mechanism for image processing.
  18. linear1 (nn.Linear): First linear layer of the feedforward network.
  19. linear2 (nn.Linear): Second linear layer of the feedforward network.
  20. norm1 (nn.LayerNorm): Layer normalization for self-attention output.
  21. norm2 (nn.LayerNorm): Layer normalization for cross-attention output.
  22. norm3 (nn.LayerNorm): Layer normalization for feedforward network output.
  23. dropout1 (nn.Dropout): Dropout layer after self-attention.
  24. dropout2 (nn.Dropout): Dropout layer after cross-attention.
  25. dropout3 (nn.Dropout): Dropout layer after feedforward network.
  26. activation (nn.ReLU): Activation function for the feedforward network.
  27. pos_enc_at_attn (bool): Flag to add positional encoding at attention.
  28. pos_enc_at_cross_attn_queries (bool): Flag to add positional encoding to cross-attention queries.
  29. pos_enc_at_cross_attn_keys (bool): Flag to add positional encoding to cross-attention keys.
  30. Methods:
  31. forward: Performs the full memory attention operation on input tensors.
  32. _forward_sa: Performs self-attention on input tensor.
  33. _forward_ca: Performs cross-attention between target and memory tensors.
  34. Examples:
  35. >>> layer = MemoryAttentionLayer(d_model=256, dim_feedforward=2048, dropout=0.1)
  36. >>> tgt = torch.randn(1, 100, 256)
  37. >>> memory = torch.randn(1, 100, 64)
  38. >>> pos = torch.randn(1, 100, 256)
  39. >>> query_pos = torch.randn(1, 100, 256)
  40. >>> output = layer(tgt, memory, pos, query_pos)
  41. >>> print(output.shape)
  42. torch.Size([1, 100, 256])
  43. """

    def __init__(
        self,
        d_model: int = 256,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        pos_enc_at_attn: bool = False,
        pos_enc_at_cross_attn_keys: bool = True,
        pos_enc_at_cross_attn_queries: bool = False,
    ):
        """Initializes a memory attention layer with self-attention, cross-attention, and feedforward components."""
        super().__init__()
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.dropout_value = dropout
        self.self_attn = RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1)
        self.cross_attn_image = RoPEAttention(
            rope_k_repeat=True,
            embedding_dim=256,
            num_heads=1,
            downsample_rate=1,
            kv_in_dim=64,
        )
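        # Note: the attention embedding dims above are hard-coded (256-dim, single-head) rather than derived
        # from d_model, and cross-attention projects 64-dim memory keys/values (kv_in_dim=64), so `memory`
        # and `pos` are expected to have a last dimension of 64.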

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = nn.ReLU()

        # Where to add pos enc
        self.pos_enc_at_attn = pos_enc_at_attn
        self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
        self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys

    def _forward_sa(self, tgt, query_pos):
        """Performs self-attention on the input tensor using the RoPE attention mechanism."""
        # Pre-norm residual self-attention; query_pos is only added when pos_enc_at_attn is set
        tgt2 = self.norm1(tgt)
        q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
        tgt2 = self.self_attn(q, k, v=tgt2)
        tgt = tgt + self.dropout1(tgt2)
        return tgt

    def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
        """Performs cross-attention between target and memory tensors using the RoPEAttention mechanism."""
        kwds = {}
        if num_k_exclude_rope > 0:
            assert isinstance(self.cross_attn_image, RoPEAttention)
            kwds = {"num_k_exclude_rope": num_k_exclude_rope}

        # Cross-attention: queries come from the normalized target, keys/values from memory
        tgt2 = self.norm2(tgt)
        tgt2 = self.cross_attn_image(
            q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
            k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
            v=memory,
            **kwds,
        )
        tgt = tgt + self.dropout2(tgt2)
        return tgt

    def forward(
        self,
        tgt,
        memory,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        num_k_exclude_rope: int = 0,
    ) -> torch.Tensor:
        """Processes input tensors using self-attention, cross-attention, and an MLP for memory-based attention."""
        tgt = self._forward_sa(tgt, query_pos)
        tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)

        # MLP with pre-norm residual connection
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt
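
# Summary of the per-layer computation in MemoryAttentionLayer (pre-norm residual blocks, applied in order):
#   tgt = tgt + dropout1(self_attn(LayerNorm(tgt) [+ query_pos]))
#   tgt = tgt + dropout2(cross_attn_image(q=LayerNorm(tgt) [+ query_pos], k=memory [+ pos], v=memory))
#   tgt = tgt + dropout3(linear2(dropout(ReLU(linear1(LayerNorm(tgt))))))
# The bracketed positional terms are only added when the corresponding pos_enc_at_* flag is set.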


class MemoryAttention(nn.Module):
    """
    Memory attention module for processing sequential data with self- and cross-attention mechanisms.

    This class implements a multi-layer attention mechanism that combines self-attention and cross-attention
    for processing sequential data, and is particularly useful in transformer-like architectures.

    Attributes:
        d_model (int): The dimension of the model's hidden state.
        layers (nn.ModuleList): A list of MemoryAttentionLayer modules.
        num_layers (int): The number of attention layers.
        norm (nn.LayerNorm): Layer normalization applied to the output.
        pos_enc_at_input (bool): Whether to apply positional encoding at the input.
        batch_first (bool): Whether the input tensors are in batch-first format.

    Methods:
        forward: Processes input tensors through the attention layers.

    Examples:
        >>> d_model = 256
        >>> layer = MemoryAttentionLayer(d_model)
        >>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
        >>> curr = torch.randn(10, 32, d_model)  # (seq_len, batch_size, d_model)
        >>> memory = torch.randn(20, 32, 64)  # (mem_len, batch_size, 64); cross-attention keys/values are 64-dim
        >>> curr_pos = torch.randn(10, 32, d_model)
        >>> memory_pos = torch.randn(20, 32, 64)
        >>> output = attention(curr, memory, curr_pos, memory_pos)
        >>> print(output.shape)
        torch.Size([10, 32, 256])
    """

    def __init__(
        self,
        d_model: int,
        pos_enc_at_input: bool,
        layer: nn.Module,
        num_layers: int,
        batch_first: bool = True,  # Do layers expect batch-first input?
    ):
        """Initializes the MemoryAttention module with stacked attention layers and output normalization."""
        super().__init__()
        self.d_model = d_model
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(d_model)
        self.pos_enc_at_input = pos_enc_at_input
        self.batch_first = batch_first

    def forward(
        self,
        curr: torch.Tensor,  # self-attention inputs
        memory: torch.Tensor,  # cross-attention inputs
        curr_pos: Optional[Tensor] = None,  # positional encoding for self-attention inputs
        memory_pos: Optional[Tensor] = None,  # positional encoding for cross-attention inputs
        num_obj_ptr_tokens: int = 0,  # number of object pointer *tokens*
    ):
        """Processes input tensors through multiple attention layers, applying self- and cross-attention mechanisms."""
        if isinstance(curr, list):
            assert isinstance(curr_pos, list)
            assert len(curr) == len(curr_pos) == 1
            curr, curr_pos = curr[0], curr_pos[0]

        assert curr.shape[1] == memory.shape[1], "Batch size must be the same for curr and memory"

        output = curr
        if self.pos_enc_at_input and curr_pos is not None:
            # Add the positional encoding to the input, scaled down by 0.1
            output = output + 0.1 * curr_pos

        if self.batch_first:
            # Convert to batch first
            output = output.transpose(0, 1)
            curr_pos = curr_pos.transpose(0, 1)
            memory = memory.transpose(0, 1)
            memory_pos = memory_pos.transpose(0, 1)

        for layer in self.layers:
            kwds = {}
            if isinstance(layer.cross_attn_image, RoPEAttention):
                kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}

            output = layer(
                tgt=output,
                memory=memory,
                pos=memory_pos,
                query_pos=curr_pos,
                **kwds,
            )

        normed_output = self.norm(output)

        if self.batch_first:
            # Convert back to seq first
            normed_output = normed_output.transpose(0, 1)
            curr_pos = curr_pos.transpose(0, 1)

        return normed_output
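

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the library API). Shapes mirror the MemoryAttentionLayer
    # docstring example above: 256-dim queries and 64-dim memory features. Run this file as a module
    # (e.g. `python -m <package>.memory_attention`) so the relative import of RoPEAttention resolves.
    layer = MemoryAttentionLayer(d_model=256, dim_feedforward=2048, dropout=0.1)
    attention = MemoryAttention(d_model=256, pos_enc_at_input=True, layer=layer, num_layers=2)

    curr = torch.randn(100, 2, 256)  # (seq_len, batch_size, d_model)
    memory = torch.randn(100, 2, 64)  # (mem_len, batch_size, 64)
    curr_pos = torch.randn(100, 2, 256)
    memory_pos = torch.randn(100, 2, 64)

    out = attention(curr, memory, curr_pos=curr_pos, memory_pos=memory_pos)
    print(out.shape)  # expected: torch.Size([100, 2, 256])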