transformer.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. import math
  3. from typing import Tuple, Type
  4. import torch
  5. from torch import Tensor, nn
  6. from ultralytics.nn.modules import MLPBlock
  7. class TwoWayTransformer(nn.Module):
  8. """
  9. A Two-Way Transformer module for simultaneous attention to image and query points.
  10. This class implements a specialized transformer decoder that attends to an input image using queries with
  11. supplied positional embeddings. It's useful for tasks like object detection, image segmentation, and point
  12. cloud processing.
  13. Attributes:
  14. depth (int): Number of layers in the transformer.
  15. embedding_dim (int): Channel dimension for input embeddings.
  16. num_heads (int): Number of heads for multihead attention.
  17. mlp_dim (int): Internal channel dimension for the MLP block.
  18. layers (nn.ModuleList): List of TwoWayAttentionBlock layers composing the transformer.
  19. final_attn_token_to_image (Attention): Final attention layer from queries to image.
  20. norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
  21. Methods:
  22. forward: Processes image and point embeddings through the transformer.
  23. Examples:
  24. >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
  25. >>> image_embedding = torch.randn(1, 256, 32, 32)
  26. >>> image_pe = torch.randn(1, 256, 32, 32)
  27. >>> point_embedding = torch.randn(1, 100, 256)
  28. >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
  29. >>> print(output_queries.shape, output_image.shape)
  30. """
  31. def __init__(
  32. self,
  33. depth: int,
  34. embedding_dim: int,
  35. num_heads: int,
  36. mlp_dim: int,
  37. activation: Type[nn.Module] = nn.ReLU,
  38. attention_downsample_rate: int = 2,
  39. ) -> None:
  40. """
  41. Initialize a Two-Way Transformer for simultaneous attention to image and query points.
  42. Args:
  43. depth (int): Number of layers in the transformer.
  44. embedding_dim (int): Channel dimension for input embeddings.
  45. num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.
  46. mlp_dim (int): Internal channel dimension for the MLP block.
  47. activation (Type[nn.Module]): Activation function to use in the MLP block.
  48. attention_downsample_rate (int): Downsampling rate for attention mechanism.
  49. Attributes:
  50. depth (int): Number of layers in the transformer.
  51. embedding_dim (int): Channel dimension for input embeddings.
  52. num_heads (int): Number of heads for multihead attention.
  53. mlp_dim (int): Internal channel dimension for the MLP block.
  54. layers (nn.ModuleList): List of TwoWayAttentionBlock layers.
  55. final_attn_token_to_image (Attention): Final attention layer from queries to image.
  56. norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
  57. Examples:
  58. >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
  59. >>> image_embedding = torch.randn(1, 256, 32, 32)
  60. >>> image_pe = torch.randn(1, 256, 32, 32)
  61. >>> point_embedding = torch.randn(1, 100, 256)
  62. >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
  63. >>> print(output_queries.shape, output_image.shape)
  64. """
  65. super().__init__()
  66. self.depth = depth
  67. self.embedding_dim = embedding_dim
  68. self.num_heads = num_heads
  69. self.mlp_dim = mlp_dim
  70. self.layers = nn.ModuleList()
  71. for i in range(depth):
  72. self.layers.append(
  73. TwoWayAttentionBlock(
  74. embedding_dim=embedding_dim,
  75. num_heads=num_heads,
  76. mlp_dim=mlp_dim,
  77. activation=activation,
  78. attention_downsample_rate=attention_downsample_rate,
  79. skip_first_layer_pe=(i == 0),
  80. )
  81. )
  82. self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
  83. self.norm_final_attn = nn.LayerNorm(embedding_dim)
  84. def forward(
  85. self,
  86. image_embedding: Tensor,
  87. image_pe: Tensor,
  88. point_embedding: Tensor,
  89. ) -> Tuple[Tensor, Tensor]:
  90. """
  91. Processes image and point embeddings through the Two-Way Transformer.
  92. Args:
  93. image_embedding (torch.Tensor): Image to attend to, with shape (B, embedding_dim, H, W).
  94. image_pe (torch.Tensor): Positional encoding to add to the image, with same shape as image_embedding.
  95. point_embedding (torch.Tensor): Embedding to add to query points, with shape (B, N_points, embedding_dim).
  96. Returns:
  97. (Tuple[torch.Tensor, torch.Tensor]): Processed point_embedding and image_embedding.
  98. Examples:
  99. >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
  100. >>> image_embedding = torch.randn(1, 256, 32, 32)
  101. >>> image_pe = torch.randn(1, 256, 32, 32)
  102. >>> point_embedding = torch.randn(1, 100, 256)
  103. >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
  104. >>> print(output_queries.shape, output_image.shape)
  105. """
  106. # BxCxHxW -> BxHWxC == B x N_image_tokens x C
  107. image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
  108. image_pe = image_pe.flatten(2).permute(0, 2, 1)
  109. # Prepare queries
  110. queries = point_embedding
  111. keys = image_embedding
  112. # Apply transformer blocks and final layernorm
  113. for layer in self.layers:
  114. queries, keys = layer(
  115. queries=queries,
  116. keys=keys,
  117. query_pe=point_embedding,
  118. key_pe=image_pe,
  119. )
  120. # Apply the final attention layer from the points to the image
  121. q = queries + point_embedding
  122. k = keys + image_pe
  123. attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
  124. queries = queries + attn_out
  125. queries = self.norm_final_attn(queries)
  126. return queries, keys
  127. class TwoWayAttentionBlock(nn.Module):
  128. """
  129. A two-way attention block for simultaneous attention to image and query points.
  130. This class implements a specialized transformer block with four main layers: self-attention on sparse inputs,
  131. cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense
  132. inputs to sparse inputs.
  133. Attributes:
  134. self_attn (Attention): Self-attention layer for queries.
  135. norm1 (nn.LayerNorm): Layer normalization after self-attention.
  136. cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
  137. norm2 (nn.LayerNorm): Layer normalization after token-to-image attention.
  138. mlp (MLPBlock): MLP block for transforming query embeddings.
  139. norm3 (nn.LayerNorm): Layer normalization after MLP block.
  140. norm4 (nn.LayerNorm): Layer normalization after image-to-token attention.
  141. cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
  142. skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
  143. Methods:
  144. forward: Applies self-attention and cross-attention to queries and keys.
  145. Examples:
  146. >>> embedding_dim, num_heads = 256, 8
  147. >>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
  148. >>> queries = torch.randn(1, 100, embedding_dim)
  149. >>> keys = torch.randn(1, 1000, embedding_dim)
  150. >>> query_pe = torch.randn(1, 100, embedding_dim)
  151. >>> key_pe = torch.randn(1, 1000, embedding_dim)
  152. >>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
  153. """
  154. def __init__(
  155. self,
  156. embedding_dim: int,
  157. num_heads: int,
  158. mlp_dim: int = 2048,
  159. activation: Type[nn.Module] = nn.ReLU,
  160. attention_downsample_rate: int = 2,
  161. skip_first_layer_pe: bool = False,
  162. ) -> None:
  163. """
  164. Initializes a TwoWayAttentionBlock for simultaneous attention to image and query points.
  165. This block implements a specialized transformer layer with four main components: self-attention on sparse
  166. inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention
  167. of dense inputs to sparse inputs.
  168. Args:
  169. embedding_dim (int): Channel dimension of the embeddings.
  170. num_heads (int): Number of attention heads in the attention layers.
  171. mlp_dim (int): Hidden dimension of the MLP block.
  172. activation (Type[nn.Module]): Activation function for the MLP block.
  173. attention_downsample_rate (int): Downsampling rate for the attention mechanism.
  174. skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
  175. Examples:
  176. >>> embedding_dim, num_heads = 256, 8
  177. >>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
  178. >>> queries = torch.randn(1, 100, embedding_dim)
  179. >>> keys = torch.randn(1, 1000, embedding_dim)
  180. >>> query_pe = torch.randn(1, 100, embedding_dim)
  181. >>> key_pe = torch.randn(1, 1000, embedding_dim)
  182. >>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
  183. """
  184. super().__init__()
  185. self.self_attn = Attention(embedding_dim, num_heads)
  186. self.norm1 = nn.LayerNorm(embedding_dim)
  187. self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
  188. self.norm2 = nn.LayerNorm(embedding_dim)
  189. self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
  190. self.norm3 = nn.LayerNorm(embedding_dim)
  191. self.norm4 = nn.LayerNorm(embedding_dim)
  192. self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
  193. self.skip_first_layer_pe = skip_first_layer_pe
  194. def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
  195. """Applies two-way attention to process query and key embeddings in a transformer block."""
  196. # Self attention block
  197. if self.skip_first_layer_pe:
  198. queries = self.self_attn(q=queries, k=queries, v=queries)
  199. else:
  200. q = queries + query_pe
  201. attn_out = self.self_attn(q=q, k=q, v=queries)
  202. queries = queries + attn_out
  203. queries = self.norm1(queries)
  204. # Cross attention block, tokens attending to image embedding
  205. q = queries + query_pe
  206. k = keys + key_pe
  207. attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
  208. queries = queries + attn_out
  209. queries = self.norm2(queries)
  210. # MLP block
  211. mlp_out = self.mlp(queries)
  212. queries = queries + mlp_out
  213. queries = self.norm3(queries)
  214. # Cross attention block, image embedding attending to tokens
  215. q = queries + query_pe
  216. k = keys + key_pe
  217. attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
  218. keys = keys + attn_out
  219. keys = self.norm4(keys)
  220. return queries, keys
  221. class Attention(nn.Module):
  222. """
  223. An attention layer with downscaling capability for embedding size after projection.
  224. This class implements a multi-head attention mechanism with the option to downsample the internal
  225. dimension of queries, keys, and values.
  226. Attributes:
  227. embedding_dim (int): Dimensionality of input embeddings.
  228. kv_in_dim (int): Dimensionality of key and value inputs.
  229. internal_dim (int): Internal dimension after downsampling.
  230. num_heads (int): Number of attention heads.
  231. q_proj (nn.Linear): Linear projection for queries.
  232. k_proj (nn.Linear): Linear projection for keys.
  233. v_proj (nn.Linear): Linear projection for values.
  234. out_proj (nn.Linear): Linear projection for output.
  235. Methods:
  236. _separate_heads: Separates input tensor into attention heads.
  237. _recombine_heads: Recombines separated attention heads.
  238. forward: Computes attention output for given query, key, and value tensors.
  239. Examples:
  240. >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
  241. >>> q = torch.randn(1, 100, 256)
  242. >>> k = v = torch.randn(1, 50, 256)
  243. >>> output = attn(q, k, v)
  244. >>> print(output.shape)
  245. torch.Size([1, 100, 256])
  246. """
  247. def __init__(
  248. self,
  249. embedding_dim: int,
  250. num_heads: int,
  251. downsample_rate: int = 1,
  252. kv_in_dim: int = None,
  253. ) -> None:
  254. """
  255. Initializes the Attention module with specified dimensions and settings.
  256. This class implements a multi-head attention mechanism with optional downsampling of the internal
  257. dimension for queries, keys, and values.
  258. Args:
  259. embedding_dim (int): Dimensionality of input embeddings.
  260. num_heads (int): Number of attention heads.
  261. downsample_rate (int): Factor by which internal dimensions are downsampled. Defaults to 1.
  262. kv_in_dim (int | None): Dimensionality of key and value inputs. If None, uses embedding_dim.
  263. Raises:
  264. AssertionError: If num_heads does not evenly divide the internal dim (embedding_dim / downsample_rate).
  265. Examples:
  266. >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
  267. >>> q = torch.randn(1, 100, 256)
  268. >>> k = v = torch.randn(1, 50, 256)
  269. >>> output = attn(q, k, v)
  270. >>> print(output.shape)
  271. torch.Size([1, 100, 256])
  272. """
  273. super().__init__()
  274. self.embedding_dim = embedding_dim
  275. self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
  276. self.internal_dim = embedding_dim // downsample_rate
  277. self.num_heads = num_heads
  278. assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
  279. self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
  280. self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
  281. self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
  282. self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
  283. @staticmethod
  284. def _separate_heads(x: Tensor, num_heads: int) -> Tensor:
  285. """Separates the input tensor into the specified number of attention heads."""
  286. b, n, c = x.shape
  287. x = x.reshape(b, n, num_heads, c // num_heads)
  288. return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
  289. @staticmethod
  290. def _recombine_heads(x: Tensor) -> Tensor:
  291. """Recombines separated attention heads into a single tensor."""
  292. b, n_heads, n_tokens, c_per_head = x.shape
  293. x = x.transpose(1, 2)
  294. return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
  295. def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
  296. """Applies multi-head attention to query, key, and value tensors with optional downsampling."""
  297. # Input projections
  298. q = self.q_proj(q)
  299. k = self.k_proj(k)
  300. v = self.v_proj(v)
  301. # Separate into heads
  302. q = self._separate_heads(q, self.num_heads)
  303. k = self._separate_heads(k, self.num_heads)
  304. v = self._separate_heads(v, self.num_heads)
  305. # Attention
  306. _, _, _, c_per_head = q.shape
  307. attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
  308. attn = attn / math.sqrt(c_per_head)
  309. attn = torch.softmax(attn, dim=-1)
  310. # Get output
  311. out = attn @ v
  312. out = self._recombine_heads(out)
  313. return self.out_proj(out)