blocks.py 45 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. import copy
  3. import math
  4. from functools import partial
  5. from typing import Any, Optional, Tuple, Type, Union
  6. import numpy as np
  7. import torch
  8. import torch.nn.functional as F
  9. from torch import Tensor, nn
  10. from ultralytics.nn.modules import MLP, LayerNorm2d, MLPBlock
  11. from .transformer import Attention, TwoWayAttentionBlock, TwoWayTransformer
  12. from .utils import add_decomposed_rel_pos, apply_rotary_enc, compute_axial_cis, window_partition, window_unpartition
  13. class DropPath(nn.Module):
  14. """
  15. Implements stochastic depth regularization for neural networks during training.
  16. Attributes:
  17. drop_prob (float): Probability of dropping a path during training.
  18. scale_by_keep (bool): Whether to scale the output by the keep probability.
  19. Methods:
  20. forward: Applies stochastic depth to input tensor during training, with optional scaling.
  21. Examples:
  22. >>> drop_path = DropPath(drop_prob=0.2, scale_by_keep=True)
  23. >>> x = torch.randn(32, 64, 224, 224)
  24. >>> output = drop_path(x)
  25. """
  26. def __init__(self, drop_prob=0.0, scale_by_keep=True):
  27. """Initialize DropPath module for stochastic depth regularization during training."""
  28. super().__init__()
  29. self.drop_prob = drop_prob
  30. self.scale_by_keep = scale_by_keep
  31. def forward(self, x):
  32. """Applies stochastic depth to input tensor during training, with optional scaling."""
  33. if self.drop_prob == 0.0 or not self.training:
  34. return x
  35. keep_prob = 1 - self.drop_prob
  36. shape = (x.shape[0],) + (1,) * (x.ndim - 1)
  37. random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
  38. if keep_prob > 0.0 and self.scale_by_keep:
  39. random_tensor.div_(keep_prob)
  40. return x * random_tensor
  41. class MaskDownSampler(nn.Module):
  42. """
  43. A mask downsampling and embedding module for efficient processing of input masks.
  44. This class implements a mask downsampler that progressively reduces the spatial dimensions of input masks
  45. while expanding their channel dimensions using convolutional layers, layer normalization, and activation
  46. functions.
  47. Attributes:
  48. encoder (nn.Sequential): A sequential container of convolutional layers, layer normalization, and
  49. activation functions for downsampling and embedding masks.
  50. Methods:
  51. forward: Downsamples and encodes input mask to embed_dim channels.
  52. Examples:
  53. >>> mask_downsampler = MaskDownSampler(embed_dim=256, kernel_size=4, stride=4, padding=0, total_stride=16)
  54. >>> input_mask = torch.randn(1, 1, 256, 256)
  55. >>> output = mask_downsampler(input_mask)
  56. >>> print(output.shape)
  57. torch.Size([1, 256, 16, 16])
  58. """
  59. def __init__(
  60. self,
  61. embed_dim=256,
  62. kernel_size=4,
  63. stride=4,
  64. padding=0,
  65. total_stride=16,
  66. activation=nn.GELU,
  67. ):
  68. """Initializes a mask downsampler module for progressive downsampling and channel expansion."""
  69. super().__init__()
  70. num_layers = int(math.log2(total_stride) // math.log2(stride))
  71. assert stride**num_layers == total_stride
  72. self.encoder = nn.Sequential()
  73. mask_in_chans, mask_out_chans = 1, 1
  74. for _ in range(num_layers):
  75. mask_out_chans = mask_in_chans * (stride**2)
  76. self.encoder.append(
  77. nn.Conv2d(
  78. mask_in_chans,
  79. mask_out_chans,
  80. kernel_size=kernel_size,
  81. stride=stride,
  82. padding=padding,
  83. )
  84. )
  85. self.encoder.append(LayerNorm2d(mask_out_chans))
  86. self.encoder.append(activation())
  87. mask_in_chans = mask_out_chans
  88. self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
  89. def forward(self, x):
  90. """Downsamples and encodes input mask to embed_dim channels using convolutional layers and LayerNorm2d."""
  91. return self.encoder(x)
  92. class CXBlock(nn.Module):
  93. """
  94. ConvNeXt Block for efficient feature extraction in convolutional neural networks.
  95. This block implements a modified version of the ConvNeXt architecture, offering improved performance and
  96. flexibility in feature extraction.
  97. Attributes:
  98. dwconv (nn.Conv2d): Depthwise or standard 2D convolution layer.
  99. norm (LayerNorm2d): Layer normalization applied to channels.
  100. pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer.
  101. act (nn.GELU): GELU activation function.
  102. pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer.
  103. gamma (nn.Parameter | None): Learnable scale parameter for layer scaling.
  104. drop_path (nn.Module): DropPath layer for stochastic depth regularization.
  105. Methods:
  106. forward: Processes the input tensor through the ConvNeXt block.
  107. Examples:
  108. >>> import torch
  109. >>> x = torch.randn(1, 64, 56, 56)
  110. >>> block = CXBlock(dim=64, kernel_size=7, padding=3)
  111. >>> output = block(x)
  112. >>> print(output.shape)
  113. torch.Size([1, 64, 56, 56])
  114. """
  115. def __init__(
  116. self,
  117. dim,
  118. kernel_size=7,
  119. padding=3,
  120. drop_path=0.0,
  121. layer_scale_init_value=1e-6,
  122. use_dwconv=True,
  123. ):
  124. """
  125. Initialize a ConvNeXt Block for efficient feature extraction in convolutional neural networks.
  126. This block implements a modified version of the ConvNeXt architecture, offering improved performance and
  127. flexibility in feature extraction.
  128. Args:
  129. dim (int): Number of input channels.
  130. kernel_size (int): Size of the convolutional kernel.
  131. padding (int): Padding size for the convolution.
  132. drop_path (float): Stochastic depth rate.
  133. layer_scale_init_value (float): Initial value for Layer Scale.
  134. use_dwconv (bool): Whether to use depthwise convolution.
  135. Examples:
  136. >>> block = CXBlock(dim=64, kernel_size=7, padding=3)
  137. >>> x = torch.randn(1, 64, 32, 32)
  138. >>> output = block(x)
  139. >>> print(output.shape)
  140. torch.Size([1, 64, 32, 32])
  141. """
  142. super().__init__()
  143. self.dwconv = nn.Conv2d(
  144. dim,
  145. dim,
  146. kernel_size=kernel_size,
  147. padding=padding,
  148. groups=dim if use_dwconv else 1,
  149. ) # depthwise conv
  150. self.norm = LayerNorm2d(dim, eps=1e-6)
  151. self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
  152. self.act = nn.GELU()
  153. self.pwconv2 = nn.Linear(4 * dim, dim)
  154. self.gamma = (
  155. nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
  156. if layer_scale_init_value > 0
  157. else None
  158. )
  159. self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  160. def forward(self, x):
  161. """Applies ConvNeXt block operations to input tensor, including convolutions and residual connection."""
  162. input = x
  163. x = self.dwconv(x)
  164. x = self.norm(x)
  165. x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
  166. x = self.pwconv1(x)
  167. x = self.act(x)
  168. x = self.pwconv2(x)
  169. if self.gamma is not None:
  170. x = self.gamma * x
  171. x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
  172. x = input + self.drop_path(x)
  173. return x
  174. class Fuser(nn.Module):
  175. """
  176. A module for fusing features through multiple layers of a neural network.
  177. This class applies a series of identical layers to an input tensor, optionally projecting the input first.
  178. Attributes:
  179. proj (nn.Module): An optional input projection layer. Identity if no projection is needed.
  180. layers (nn.ModuleList): A list of identical layers to be applied sequentially.
  181. Methods:
  182. forward: Applies the fuser to an input tensor.
  183. Examples:
  184. >>> layer = CXBlock(dim=256)
  185. >>> fuser = Fuser(layer, num_layers=3, dim=256, input_projection=True)
  186. >>> x = torch.randn(1, 256, 32, 32)
  187. >>> output = fuser(x)
  188. >>> print(output.shape)
  189. torch.Size([1, 256, 32, 32])
  190. """
  191. def __init__(self, layer, num_layers, dim=None, input_projection=False):
  192. """
  193. Initializes the Fuser module for feature fusion through multiple layers.
  194. This module creates a sequence of identical layers and optionally applies an input projection.
  195. Args:
  196. layer (nn.Module): The layer to be replicated in the fuser.
  197. num_layers (int): The number of times to replicate the layer.
  198. dim (int | None): The dimension for input projection, if used.
  199. input_projection (bool): Whether to use input projection.
  200. Examples:
  201. >>> layer = nn.Linear(64, 64)
  202. >>> fuser = Fuser(layer, num_layers=3, dim=64, input_projection=True)
  203. >>> input_tensor = torch.randn(1, 64)
  204. >>> output = fuser(input_tensor)
  205. """
  206. super().__init__()
  207. self.proj = nn.Identity()
  208. self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
  209. if input_projection:
  210. assert dim is not None
  211. self.proj = nn.Conv2d(dim, dim, kernel_size=1)
  212. def forward(self, x):
  213. """Applies a series of layers to the input tensor, optionally projecting it first."""
  214. x = self.proj(x)
  215. for layer in self.layers:
  216. x = layer(x)
  217. return x
  218. class SAM2TwoWayAttentionBlock(TwoWayAttentionBlock):
  219. """
  220. A two-way attention block for performing self-attention and cross-attention in both directions.
  221. This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on
  222. sparse inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and
  223. cross-attention from dense to sparse inputs.
  224. Attributes:
  225. self_attn (Attention): Self-attention layer for queries.
  226. norm1 (nn.LayerNorm): Layer normalization after the first attention block.
  227. cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
  228. norm2 (nn.LayerNorm): Layer normalization after the second attention block.
  229. mlp (MLP): MLP block for transforming query embeddings.
  230. norm3 (nn.LayerNorm): Layer normalization after the MLP block.
  231. norm4 (nn.LayerNorm): Layer normalization after the third attention block.
  232. cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
  233. skip_first_layer_pe (bool): Flag to skip positional encoding in the first layer.
  234. Methods:
  235. forward: Processes input through the attention blocks and MLP.
  236. Examples:
  237. >>> block = SAM2TwoWayAttentionBlock(embedding_dim=256, num_heads=8)
  238. >>> sparse_input = torch.randn(1, 100, 256)
  239. >>> dense_input = torch.randn(1, 256, 16, 16)
  240. >>> sparse_output, dense_output = block(sparse_input, dense_input)
  241. """
  242. def __init__(
  243. self,
  244. embedding_dim: int,
  245. num_heads: int,
  246. mlp_dim: int = 2048,
  247. activation: Type[nn.Module] = nn.ReLU,
  248. attention_downsample_rate: int = 2,
  249. skip_first_layer_pe: bool = False,
  250. ) -> None:
  251. """
  252. Initializes a SAM2TwoWayAttentionBlock for performing self-attention and cross-attention in two directions.
  253. This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on sparse
  254. inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and cross-attention
  255. from dense to sparse inputs.
  256. Args:
  257. embedding_dim (int): The channel dimension of the embeddings.
  258. num_heads (int): The number of heads in the attention layers.
  259. mlp_dim (int): The hidden dimension of the MLP block.
  260. activation (Type[nn.Module]): The activation function of the MLP block.
  261. attention_downsample_rate (int): The downsample rate for attention computations.
  262. skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
  263. Examples:
  264. >>> block = SAM2TwoWayAttentionBlock(embedding_dim=256, num_heads=8, mlp_dim=2048)
  265. >>> sparse_inputs = torch.randn(1, 100, 256)
  266. >>> dense_inputs = torch.randn(1, 256, 32, 32)
  267. >>> sparse_outputs, dense_outputs = block(sparse_inputs, dense_inputs)
  268. """
  269. super().__init__(embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate, skip_first_layer_pe)
  270. self.mlp = MLP(embedding_dim, mlp_dim, embedding_dim, num_layers=2, act=activation)
  271. class SAM2TwoWayTransformer(TwoWayTransformer):
  272. """
  273. A Two-Way Transformer module for simultaneous attention to image and query points.
  274. This class extends the TwoWayTransformer, implementing a specialized transformer decoder that attends to an
  275. input image using queries with supplied positional embeddings. It is particularly useful for tasks like
  276. object detection, image segmentation, and point cloud processing.
  277. Attributes:
  278. depth (int): Number of layers in the transformer.
  279. embedding_dim (int): Channel dimension for input embeddings.
  280. num_heads (int): Number of heads for multihead attention.
  281. mlp_dim (int): Internal channel dimension for the MLP block.
  282. layers (nn.ModuleList): List of SAM2TwoWayAttentionBlock layers comprising the transformer.
  283. final_attn_token_to_image (Attention): Final attention layer from queries to image.
  284. norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
  285. Methods:
  286. forward: Processes input image embeddings and query embeddings through the transformer.
  287. Examples:
  288. >>> transformer = SAM2TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)
  289. >>> image_embedding = torch.randn(1, 256, 64, 64)
  290. >>> query_embedding = torch.randn(1, 100, 256)
  291. >>> output = transformer(image_embedding, query_embedding)
  292. >>> print(output[0].shape, output[1].shape)
  293. torch.Size([1, 100, 256]) torch.Size([1, 256, 64, 64])
  294. """
  295. def __init__(
  296. self,
  297. depth: int,
  298. embedding_dim: int,
  299. num_heads: int,
  300. mlp_dim: int,
  301. activation: Type[nn.Module] = nn.ReLU,
  302. attention_downsample_rate: int = 2,
  303. ) -> None:
  304. """
  305. Initializes a SAM2TwoWayTransformer instance.
  306. This transformer decoder attends to an input image using queries with supplied positional embeddings.
  307. It is designed for tasks like object detection, image segmentation, and point cloud processing.
  308. Args:
  309. depth (int): Number of layers in the transformer.
  310. embedding_dim (int): Channel dimension for the input embeddings.
  311. num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.
  312. mlp_dim (int): Channel dimension internal to the MLP block.
  313. activation (Type[nn.Module]): Activation function to use in the MLP block.
  314. attention_downsample_rate (int): Downsampling rate for attention computations.
  315. Examples:
  316. >>> transformer = SAM2TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)
  317. >>> transformer
  318. SAM2TwoWayTransformer(
  319. (layers): ModuleList(
  320. (0-4): 5 x SAM2TwoWayAttentionBlock(...)
  321. )
  322. (final_attn_token_to_image): Attention(...)
  323. (norm_final_attn): LayerNorm(...)
  324. )
  325. """
  326. super().__init__(depth, embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate)
  327. self.layers = nn.ModuleList()
  328. for i in range(depth):
  329. self.layers.append(
  330. SAM2TwoWayAttentionBlock(
  331. embedding_dim=embedding_dim,
  332. num_heads=num_heads,
  333. mlp_dim=mlp_dim,
  334. activation=activation,
  335. attention_downsample_rate=attention_downsample_rate,
  336. skip_first_layer_pe=(i == 0),
  337. )
  338. )
  339. class RoPEAttention(Attention):
  340. """
  341. Implements rotary position encoding for attention mechanisms in transformer architectures.
  342. This class extends the base Attention class by incorporating Rotary Position Encoding (RoPE) to enhance
  343. the positional awareness of the attention mechanism.
  344. Attributes:
  345. compute_cis (Callable): Function to compute axial complex numbers for rotary encoding.
  346. freqs_cis (Tensor): Precomputed frequency tensor for rotary encoding.
  347. rope_k_repeat (bool): Flag to repeat query RoPE to match key length for cross-attention to memories.
  348. Methods:
  349. forward: Applies rotary position encoding and computes attention between query, key, and value tensors.
  350. Examples:
  351. >>> rope_attn = RoPEAttention(embedding_dim=256, num_heads=8, rope_theta=10000.0, feat_sizes=(32, 32))
  352. >>> q = torch.randn(1, 1024, 256)
  353. >>> k = torch.randn(1, 1024, 256)
  354. >>> v = torch.randn(1, 1024, 256)
  355. >>> output = rope_attn(q, k, v)
  356. >>> print(output.shape)
  357. torch.Size([1, 1024, 256])
  358. """
  359. def __init__(
  360. self,
  361. *args,
  362. rope_theta=10000.0,
  363. rope_k_repeat=False,
  364. feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
  365. **kwargs,
  366. ):
  367. """Initializes RoPEAttention with rotary position encoding for enhanced positional awareness."""
  368. super().__init__(*args, **kwargs)
  369. self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta)
  370. freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
  371. self.freqs_cis = freqs_cis
  372. self.rope_k_repeat = rope_k_repeat # repeat q rope to match k length, needed for cross-attention to memories
  373. def forward(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0) -> Tensor:
  374. """Applies rotary position encoding and computes attention between query, key, and value tensors."""
  375. q = self.q_proj(q)
  376. k = self.k_proj(k)
  377. v = self.v_proj(v)
  378. # Separate into heads
  379. q = self._separate_heads(q, self.num_heads)
  380. k = self._separate_heads(k, self.num_heads)
  381. v = self._separate_heads(v, self.num_heads)
  382. # Apply rotary position encoding
  383. w = h = math.sqrt(q.shape[-2])
  384. self.freqs_cis = self.freqs_cis.to(q.device)
  385. if self.freqs_cis.shape[0] != q.shape[-2]:
  386. self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
  387. if q.shape[-2] != k.shape[-2]:
  388. assert self.rope_k_repeat
  389. num_k_rope = k.size(-2) - num_k_exclude_rope
  390. q, k[:, :, :num_k_rope] = apply_rotary_enc(
  391. q,
  392. k[:, :, :num_k_rope],
  393. freqs_cis=self.freqs_cis,
  394. repeat_freqs_k=self.rope_k_repeat,
  395. )
  396. # Attention
  397. _, _, _, c_per_head = q.shape
  398. attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
  399. attn = attn / math.sqrt(c_per_head)
  400. attn = torch.softmax(attn, dim=-1)
  401. # Get output
  402. out = attn @ v
  403. out = self._recombine_heads(out)
  404. out = self.out_proj(out)
  405. return out
  406. def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
  407. """Applies pooling and optional normalization to a tensor, handling spatial dimension permutations."""
  408. if pool is None:
  409. return x
  410. # (B, H, W, C) -> (B, C, H, W)
  411. x = x.permute(0, 3, 1, 2)
  412. x = pool(x)
  413. # (B, C, H', W') -> (B, H', W', C)
  414. x = x.permute(0, 2, 3, 1)
  415. if norm:
  416. x = norm(x)
  417. return x
  418. class MultiScaleAttention(nn.Module):
  419. """
  420. Implements multiscale self-attention with optional query pooling for efficient feature extraction.
  421. This class provides a flexible implementation of multiscale attention, allowing for optional
  422. downsampling of query features through pooling. It's designed to enhance the model's ability to
  423. capture multiscale information in visual tasks.
  424. Attributes:
  425. dim (int): Input dimension of the feature map.
  426. dim_out (int): Output dimension of the attention module.
  427. num_heads (int): Number of attention heads.
  428. scale (float): Scaling factor for dot-product attention.
  429. q_pool (nn.Module | None): Optional pooling module for query features.
  430. qkv (nn.Linear): Linear projection for query, key, and value.
  431. proj (nn.Linear): Output projection.
  432. Methods:
  433. forward: Applies multiscale attention to the input tensor.
  434. Examples:
  435. >>> import torch
  436. >>> from torch import nn
  437. >>> x = torch.randn(1, 64, 64, 256)
  438. >>> msa = MultiScaleAttention(dim=256, dim_out=256, num_heads=8)
  439. >>> output = msa(x)
  440. >>> print(output.shape)
  441. torch.Size([1, 64, 64, 256])
  442. """
  443. def __init__(
  444. self,
  445. dim: int,
  446. dim_out: int,
  447. num_heads: int,
  448. q_pool: nn.Module = None,
  449. ):
  450. """Initializes multiscale attention with optional query pooling for efficient feature extraction."""
  451. super().__init__()
  452. self.dim = dim
  453. self.dim_out = dim_out
  454. self.num_heads = num_heads
  455. head_dim = dim_out // num_heads
  456. self.scale = head_dim**-0.5
  457. self.q_pool = q_pool
  458. self.qkv = nn.Linear(dim, dim_out * 3)
  459. self.proj = nn.Linear(dim_out, dim_out)
  460. def forward(self, x: torch.Tensor) -> torch.Tensor:
  461. """Applies multiscale attention with optional query pooling to extract multiscale features."""
  462. B, H, W, _ = x.shape
  463. # qkv with shape (B, H * W, 3, nHead, C)
  464. qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
  465. # q, k, v with shape (B, H * W, nheads, C)
  466. q, k, v = torch.unbind(qkv, 2)
  467. # Q pooling (for downsample at stage changes)
  468. if self.q_pool:
  469. q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
  470. H, W = q.shape[1:3] # downsampled shape
  471. q = q.reshape(B, H * W, self.num_heads, -1)
  472. # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
  473. x = F.scaled_dot_product_attention(
  474. q.transpose(1, 2),
  475. k.transpose(1, 2),
  476. v.transpose(1, 2),
  477. )
  478. # Transpose back
  479. x = x.transpose(1, 2)
  480. x = x.reshape(B, H, W, -1)
  481. x = self.proj(x)
  482. return x
  483. class MultiScaleBlock(nn.Module):
  484. """
  485. A multiscale attention block with window partitioning and query pooling for efficient vision transformers.
  486. This class implements a multiscale attention mechanism with optional window partitioning and downsampling,
  487. designed for use in vision transformer architectures.
  488. Attributes:
  489. dim (int): Input dimension of the block.
  490. dim_out (int): Output dimension of the block.
  491. norm1 (nn.Module): First normalization layer.
  492. window_size (int): Size of the window for partitioning.
  493. pool (nn.Module | None): Pooling layer for query downsampling.
  494. q_stride (Tuple[int, int] | None): Stride for query pooling.
  495. attn (MultiScaleAttention): Multi-scale attention module.
  496. drop_path (nn.Module): Drop path layer for regularization.
  497. norm2 (nn.Module): Second normalization layer.
  498. mlp (MLP): Multi-layer perceptron module.
  499. proj (nn.Linear | None): Projection layer for dimension mismatch.
  500. Methods:
  501. forward: Processes input tensor through the multiscale block.
  502. Examples:
  503. >>> block = MultiScaleBlock(dim=256, dim_out=512, num_heads=8, window_size=7)
  504. >>> x = torch.randn(1, 56, 56, 256)
  505. >>> output = block(x)
  506. >>> print(output.shape)
  507. torch.Size([1, 28, 28, 512])
  508. """
  509. def __init__(
  510. self,
  511. dim: int,
  512. dim_out: int,
  513. num_heads: int,
  514. mlp_ratio: float = 4.0,
  515. drop_path: float = 0.0,
  516. norm_layer: Union[nn.Module, str] = "LayerNorm",
  517. q_stride: Tuple[int, int] = None,
  518. act_layer: nn.Module = nn.GELU,
  519. window_size: int = 0,
  520. ):
  521. """Initializes a multiscale attention block with window partitioning and optional query pooling."""
  522. super().__init__()
  523. if isinstance(norm_layer, str):
  524. norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
  525. self.dim = dim
  526. self.dim_out = dim_out
  527. self.norm1 = norm_layer(dim)
  528. self.window_size = window_size
  529. self.pool, self.q_stride = None, q_stride
  530. if self.q_stride:
  531. self.pool = nn.MaxPool2d(kernel_size=q_stride, stride=q_stride, ceil_mode=False)
  532. self.attn = MultiScaleAttention(
  533. dim,
  534. dim_out,
  535. num_heads=num_heads,
  536. q_pool=self.pool,
  537. )
  538. self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  539. self.norm2 = norm_layer(dim_out)
  540. self.mlp = MLP(
  541. dim_out,
  542. int(dim_out * mlp_ratio),
  543. dim_out,
  544. num_layers=2,
  545. act=act_layer,
  546. )
  547. if dim != dim_out:
  548. self.proj = nn.Linear(dim, dim_out)
  549. def forward(self, x: torch.Tensor) -> torch.Tensor:
  550. """Processes input through multiscale attention and MLP, with optional windowing and downsampling."""
  551. shortcut = x # B, H, W, C
  552. x = self.norm1(x)
  553. # Skip connection
  554. if self.dim != self.dim_out:
  555. shortcut = do_pool(self.proj(x), self.pool)
  556. # Window partition
  557. window_size = self.window_size
  558. if window_size > 0:
  559. H, W = x.shape[1], x.shape[2]
  560. x, pad_hw = window_partition(x, window_size)
  561. # Window Attention + Q Pooling (if stage change)
  562. x = self.attn(x)
  563. if self.q_stride:
  564. # Shapes have changed due to Q pooling
  565. window_size = self.window_size // self.q_stride[0]
  566. H, W = shortcut.shape[1:3]
  567. pad_h = (window_size - H % window_size) % window_size
  568. pad_w = (window_size - W % window_size) % window_size
  569. pad_hw = (H + pad_h, W + pad_w)
  570. # Reverse window partition
  571. if self.window_size > 0:
  572. x = window_unpartition(x, window_size, pad_hw, (H, W))
  573. x = shortcut + self.drop_path(x)
  574. # MLP
  575. x = x + self.drop_path(self.mlp(self.norm2(x)))
  576. return x
  577. class PositionEmbeddingSine(nn.Module):
  578. """
  579. A module for generating sinusoidal positional embeddings for 2D inputs like images.
  580. This class implements sinusoidal position encoding for 2D spatial positions, which can be used in
  581. transformer-based models for computer vision tasks.
  582. Attributes:
  583. num_pos_feats (int): Number of positional features (half of the embedding dimension).
  584. temperature (int): Temperature parameter for the sinusoidal functions.
  585. normalize (bool): Whether to normalize the positional embeddings.
  586. scale (float): Scaling factor for the embeddings when normalize is True.
  587. cache (Dict): Cache for storing precomputed embeddings.
  588. Methods:
  589. _encode_xy: Encodes 2D positions using sine and cosine functions.
  590. encode_boxes: Encodes box coordinates and dimensions into positional embeddings.
  591. encode_points: Encodes 2D point coordinates with sinusoidal positional embeddings.
  592. forward: Generates sinusoidal position embeddings for 2D inputs.
  593. Examples:
  594. >>> pos_emb = PositionEmbeddingSine(num_pos_feats=128)
  595. >>> x = torch.randn(1, 3, 224, 224)
  596. >>> embeddings = pos_emb(x)
  597. >>> print(embeddings.shape)
  598. torch.Size([1, 256, 224, 224])
  599. """
  600. def __init__(
  601. self,
  602. num_pos_feats,
  603. temperature: int = 10000,
  604. normalize: bool = True,
  605. scale: Optional[float] = None,
  606. ):
  607. """Initializes sinusoidal position embeddings for 2D image inputs."""
  608. super().__init__()
  609. assert num_pos_feats % 2 == 0, "Expecting even model width"
  610. self.num_pos_feats = num_pos_feats // 2
  611. self.temperature = temperature
  612. self.normalize = normalize
  613. if scale is not None and not normalize:
  614. raise ValueError("normalize should be True if scale is passed")
  615. if scale is None:
  616. scale = 2 * math.pi
  617. self.scale = scale
  618. self.cache = {}
  619. def _encode_xy(self, x, y):
  620. """Encodes 2D positions using sine/cosine functions for transformer positional embeddings."""
  621. assert len(x) == len(y) and x.ndim == y.ndim == 1
  622. x_embed = x * self.scale
  623. y_embed = y * self.scale
  624. dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
  625. dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
  626. pos_x = x_embed[:, None] / dim_t
  627. pos_y = y_embed[:, None] / dim_t
  628. pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
  629. pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)
  630. return pos_x, pos_y
  631. @torch.no_grad()
  632. def encode_boxes(self, x, y, w, h):
  633. """Encodes box coordinates and dimensions into positional embeddings for detection."""
  634. pos_x, pos_y = self._encode_xy(x, y)
  635. return torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
  636. encode = encode_boxes # Backwards compatibility
  637. @torch.no_grad()
  638. def encode_points(self, x, y, labels):
  639. """Encodes 2D points with sinusoidal embeddings and appends labels."""
  640. (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
  641. assert bx == by and nx == ny and bx == bl and nx == nl
  642. pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
  643. pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
  644. return torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
  645. @torch.no_grad()
  646. def forward(self, x: torch.Tensor):
  647. """Generates sinusoidal position embeddings for 2D inputs like images."""
  648. cache_key = (x.shape[-2], x.shape[-1])
  649. if cache_key in self.cache:
  650. return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
  651. y_embed = (
  652. torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
  653. .view(1, -1, 1)
  654. .repeat(x.shape[0], 1, x.shape[-1])
  655. )
  656. x_embed = (
  657. torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
  658. .view(1, 1, -1)
  659. .repeat(x.shape[0], x.shape[-2], 1)
  660. )
  661. if self.normalize:
  662. eps = 1e-6
  663. y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
  664. x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
  665. dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
  666. dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
  667. pos_x = x_embed[:, :, :, None] / dim_t
  668. pos_y = y_embed[:, :, :, None] / dim_t
  669. pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
  670. pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
  671. pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
  672. self.cache[cache_key] = pos[0]
  673. return pos
  674. class PositionEmbeddingRandom(nn.Module):
  675. """
  676. Positional encoding using random spatial frequencies.
  677. This class generates positional embeddings for input coordinates using random spatial frequencies. It is
  678. particularly useful for transformer-based models that require position information.
  679. Attributes:
  680. positional_encoding_gaussian_matrix (torch.Tensor): A buffer containing random values for encoding.
  681. Methods:
  682. _pe_encoding: Positionally encodes points that are normalized to [0,1].
  683. forward: Generates positional encoding for a grid of the specified size.
  684. forward_with_coords: Positionally encodes points that are not normalized to [0,1].
  685. Examples:
  686. >>> pe = PositionEmbeddingRandom(num_pos_feats=64)
  687. >>> size = (32, 32)
  688. >>> encoding = pe(size)
  689. >>> print(encoding.shape)
  690. torch.Size([128, 32, 32])
  691. """
  692. def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
  693. """Initializes random spatial frequency position embedding for transformers."""
  694. super().__init__()
  695. if scale is None or scale <= 0.0:
  696. scale = 1.0
  697. self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats)))
  698. # Set non-deterministic for forward() error 'cumsum_cuda_kernel does not have a deterministic implementation'
  699. torch.use_deterministic_algorithms(False)
  700. torch.backends.cudnn.deterministic = False
  701. def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
  702. """Encodes normalized [0,1] coordinates using random spatial frequencies."""
  703. # Assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
  704. coords = 2 * coords - 1
  705. coords = coords @ self.positional_encoding_gaussian_matrix
  706. coords = 2 * np.pi * coords
  707. # Outputs d_1 x ... x d_n x C shape
  708. return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
  709. def forward(self, size: Tuple[int, int]) -> torch.Tensor:
  710. """Generates positional encoding for a grid using random spatial frequencies."""
  711. h, w = size
  712. device: Any = self.positional_encoding_gaussian_matrix.device
  713. grid = torch.ones((h, w), device=device, dtype=torch.float32)
  714. y_embed = grid.cumsum(dim=0) - 0.5
  715. x_embed = grid.cumsum(dim=1) - 0.5
  716. y_embed = y_embed / h
  717. x_embed = x_embed / w
  718. pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
  719. return pe.permute(2, 0, 1) # C x H x W
  720. def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
  721. """Positionally encodes input coordinates, normalizing them to [0,1] based on the given image size."""
  722. coords = coords_input.clone()
  723. coords[:, :, 0] = coords[:, :, 0] / image_size[1]
  724. coords[:, :, 1] = coords[:, :, 1] / image_size[0]
  725. return self._pe_encoding(coords.to(torch.float)) # B x N x C
  726. class Block(nn.Module):
  727. """
  728. Transformer block with support for window attention and residual propagation.
  729. This class implements a transformer block that can use either global or windowed self-attention,
  730. followed by a feed-forward network. It supports relative positional embeddings and is designed
  731. for use in vision transformer architectures.
  732. Attributes:
  733. norm1 (nn.Module): First normalization layer.
  734. attn (REAttention): Self-attention layer with optional relative positional encoding.
  735. norm2 (nn.Module): Second normalization layer.
  736. mlp (MLPBlock): Multi-layer perceptron block.
  737. window_size (int): Size of attention window. If 0, global attention is used.
  738. Methods:
  739. forward: Processes input through the transformer block.
  740. Examples:
  741. >>> import torch
  742. >>> block = Block(dim=256, num_heads=8, window_size=7)
  743. >>> x = torch.randn(1, 56, 56, 256)
  744. >>> output = block(x)
  745. >>> print(output.shape)
  746. torch.Size([1, 56, 56, 256])
  747. """
  748. def __init__(
  749. self,
  750. dim: int,
  751. num_heads: int,
  752. mlp_ratio: float = 4.0,
  753. qkv_bias: bool = True,
  754. norm_layer: Type[nn.Module] = nn.LayerNorm,
  755. act_layer: Type[nn.Module] = nn.GELU,
  756. use_rel_pos: bool = False,
  757. rel_pos_zero_init: bool = True,
  758. window_size: int = 0,
  759. input_size: Optional[Tuple[int, int]] = None,
  760. ) -> None:
  761. """
  762. Initializes a transformer block with optional window attention and relative positional embeddings.
  763. This constructor sets up a transformer block that can use either global or windowed self-attention,
  764. followed by a feed-forward network. It supports relative positional embeddings and is designed
  765. for use in vision transformer architectures.
  766. Args:
  767. dim (int): Number of input channels.
  768. num_heads (int): Number of attention heads in the self-attention layer.
  769. mlp_ratio (float): Ratio of mlp hidden dimension to embedding dimension.
  770. qkv_bias (bool): If True, adds a learnable bias to query, key, value projections.
  771. norm_layer (Type[nn.Module]): Type of normalization layer to use.
  772. act_layer (Type[nn.Module]): Type of activation function to use in the MLP block.
  773. use_rel_pos (bool): If True, uses relative positional embeddings in attention.
  774. rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.
  775. window_size (int): Size of attention window. If 0, uses global attention.
  776. input_size (Optional[Tuple[int, int]]): Input resolution for calculating relative positional parameter size.
  777. Examples:
  778. >>> block = Block(dim=256, num_heads=8, window_size=7)
  779. >>> x = torch.randn(1, 56, 56, 256)
  780. >>> output = block(x)
  781. >>> print(output.shape)
  782. torch.Size([1, 56, 56, 256])
  783. """
  784. super().__init__()
  785. self.norm1 = norm_layer(dim)
  786. self.attn = REAttention(
  787. dim,
  788. num_heads=num_heads,
  789. qkv_bias=qkv_bias,
  790. use_rel_pos=use_rel_pos,
  791. rel_pos_zero_init=rel_pos_zero_init,
  792. input_size=input_size if window_size == 0 else (window_size, window_size),
  793. )
  794. self.norm2 = norm_layer(dim)
  795. self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
  796. self.window_size = window_size
  797. def forward(self, x: torch.Tensor) -> torch.Tensor:
  798. """Processes input through transformer block with optional windowed self-attention and residual connection."""
  799. shortcut = x
  800. x = self.norm1(x)
  801. # Window partition
  802. if self.window_size > 0:
  803. H, W = x.shape[1], x.shape[2]
  804. x, pad_hw = window_partition(x, self.window_size)
  805. x = self.attn(x)
  806. # Reverse window partition
  807. if self.window_size > 0:
  808. x = window_unpartition(x, self.window_size, pad_hw, (H, W))
  809. x = shortcut + x
  810. return x + self.mlp(self.norm2(x))
  811. class REAttention(nn.Module):
  812. """
  813. Rotary Embedding Attention module for efficient self-attention in transformer architectures.
  814. This class implements a multi-head attention mechanism with rotary positional embeddings, designed
  815. for use in vision transformer models. It supports optional query pooling and window partitioning
  816. for efficient processing of large inputs.
  817. Attributes:
  818. compute_cis (Callable): Function to compute axial complex numbers for rotary encoding.
  819. freqs_cis (Tensor): Precomputed frequency tensor for rotary encoding.
  820. rope_k_repeat (bool): Flag to repeat query RoPE to match key length for cross-attention to memories.
  821. q_proj (nn.Linear): Linear projection for query.
  822. k_proj (nn.Linear): Linear projection for key.
  823. v_proj (nn.Linear): Linear projection for value.
  824. out_proj (nn.Linear): Output projection.
  825. num_heads (int): Number of attention heads.
  826. internal_dim (int): Internal dimension for attention computation.
  827. Methods:
  828. forward: Applies rotary position encoding and computes attention between query, key, and value tensors.
  829. Examples:
  830. >>> rope_attn = REAttention(embedding_dim=256, num_heads=8, rope_theta=10000.0, feat_sizes=(32, 32))
  831. >>> q = torch.randn(1, 1024, 256)
  832. >>> k = torch.randn(1, 1024, 256)
  833. >>> v = torch.randn(1, 1024, 256)
  834. >>> output = rope_attn(q, k, v)
  835. >>> print(output.shape)
  836. torch.Size([1, 1024, 256])
  837. """
  838. def __init__(
  839. self,
  840. dim: int,
  841. num_heads: int = 8,
  842. qkv_bias: bool = True,
  843. use_rel_pos: bool = False,
  844. rel_pos_zero_init: bool = True,
  845. input_size: Optional[Tuple[int, int]] = None,
  846. ) -> None:
  847. """
  848. Initializes a Relative Position Attention module for transformer-based architectures.
  849. This module implements multi-head attention with optional relative positional encodings, designed
  850. specifically for vision tasks in transformer models.
  851. Args:
  852. dim (int): Number of input channels.
  853. num_heads (int): Number of attention heads. Default is 8.
  854. qkv_bias (bool): If True, adds a learnable bias to query, key, value projections. Default is True.
  855. use_rel_pos (bool): If True, uses relative positional encodings. Default is False.
  856. rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero. Default is True.
  857. input_size (Tuple[int, int] | None): Input resolution for calculating relative positional parameter size.
  858. Required if use_rel_pos is True. Default is None.
  859. Examples:
  860. >>> attention = REAttention(dim=256, num_heads=8, input_size=(32, 32))
  861. >>> x = torch.randn(1, 32, 32, 256)
  862. >>> output = attention(x)
  863. >>> print(output.shape)
  864. torch.Size([1, 32, 32, 256])
  865. """
  866. super().__init__()
  867. self.num_heads = num_heads
  868. head_dim = dim // num_heads
  869. self.scale = head_dim**-0.5
  870. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
  871. self.proj = nn.Linear(dim, dim)
  872. self.use_rel_pos = use_rel_pos
  873. if self.use_rel_pos:
  874. assert input_size is not None, "Input size must be provided if using relative positional encoding."
  875. # Initialize relative positional embeddings
  876. self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
  877. self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
  878. def forward(self, x: torch.Tensor) -> torch.Tensor:
  879. """Applies multi-head attention with optional relative positional encoding to input tensor."""
  880. B, H, W, _ = x.shape
  881. # qkv with shape (3, B, nHead, H * W, C)
  882. qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
  883. # q, k, v with shape (B * nHead, H * W, C)
  884. q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
  885. attn = (q * self.scale) @ k.transpose(-2, -1)
  886. if self.use_rel_pos:
  887. attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
  888. attn = attn.softmax(dim=-1)
  889. x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
  890. return self.proj(x)
  891. class PatchEmbed(nn.Module):
  892. """
  893. Image to Patch Embedding module for vision transformer architectures.
  894. This module converts an input image into a sequence of patch embeddings using a convolutional layer.
  895. It is commonly used as the first layer in vision transformer architectures to transform image data
  896. into a suitable format for subsequent transformer blocks.
  897. Attributes:
  898. proj (nn.Conv2d): Convolutional layer for projecting image patches to embeddings.
  899. Methods:
  900. forward: Applies patch embedding to the input tensor.
  901. Examples:
  902. >>> patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768)
  903. >>> x = torch.randn(1, 3, 224, 224)
  904. >>> output = patch_embed(x)
  905. >>> print(output.shape)
  906. torch.Size([1, 768, 14, 14])
  907. """
  908. def __init__(
  909. self,
  910. kernel_size: Tuple[int, int] = (16, 16),
  911. stride: Tuple[int, int] = (16, 16),
  912. padding: Tuple[int, int] = (0, 0),
  913. in_chans: int = 3,
  914. embed_dim: int = 768,
  915. ) -> None:
  916. """
  917. Initializes the PatchEmbed module for converting image patches to embeddings.
  918. This module is typically used as the first layer in vision transformer architectures to transform
  919. image data into a suitable format for subsequent transformer blocks.
  920. Args:
  921. kernel_size (Tuple[int, int]): Size of the convolutional kernel for patch extraction.
  922. stride (Tuple[int, int]): Stride of the convolutional operation.
  923. padding (Tuple[int, int]): Padding applied to the input before convolution.
  924. in_chans (int): Number of input image channels.
  925. embed_dim (int): Dimensionality of the output patch embeddings.
  926. Examples:
  927. >>> patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768)
  928. >>> x = torch.randn(1, 3, 224, 224)
  929. >>> output = patch_embed(x)
  930. >>> print(output.shape)
  931. torch.Size([1, 768, 14, 14])
  932. """
  933. super().__init__()
  934. self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
  935. def forward(self, x: torch.Tensor) -> torch.Tensor:
  936. """Computes patch embedding by applying convolution and transposing resulting tensor."""
  937. return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C