tiny_encoder.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. # --------------------------------------------------------
  3. # TinyViT Model Architecture
  4. # Copyright (c) 2022 Microsoft
  5. # Adapted from LeViT and Swin Transformer
  6. # LeViT: (https://github.com/facebookresearch/levit)
  7. # Swin: (https://github.com/microsoft/swin-transformer)
  8. # Build the TinyViT Model
  9. # --------------------------------------------------------
  10. import itertools
  11. from typing import Tuple
  12. import torch
  13. import torch.nn as nn
  14. import torch.nn.functional as F
  15. import torch.utils.checkpoint as checkpoint
  16. from ultralytics.nn.modules import LayerNorm2d
  17. from ultralytics.utils.instance import to_2tuple
  18. class Conv2d_BN(torch.nn.Sequential):
  19. """
  20. A sequential container that performs 2D convolution followed by batch normalization.
  21. Attributes:
  22. c (torch.nn.Conv2d): 2D convolution layer.
  23. 1 (torch.nn.BatchNorm2d): Batch normalization layer.
  24. Methods:
  25. __init__: Initializes the Conv2d_BN with specified parameters.
  26. Args:
  27. a (int): Number of input channels.
  28. b (int): Number of output channels.
  29. ks (int): Kernel size for the convolution. Defaults to 1.
  30. stride (int): Stride for the convolution. Defaults to 1.
  31. pad (int): Padding for the convolution. Defaults to 0.
  32. dilation (int): Dilation factor for the convolution. Defaults to 1.
  33. groups (int): Number of groups for the convolution. Defaults to 1.
  34. bn_weight_init (float): Initial value for batch normalization weight. Defaults to 1.
  35. Examples:
  36. >>> conv_bn = Conv2d_BN(3, 64, ks=3, stride=1, pad=1)
  37. >>> input_tensor = torch.randn(1, 3, 224, 224)
  38. >>> output = conv_bn(input_tensor)
  39. >>> print(output.shape)
  40. """
  41. def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
  42. """Initializes a sequential container with 2D convolution followed by batch normalization."""
  43. super().__init__()
  44. self.add_module("c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
  45. bn = torch.nn.BatchNorm2d(b)
  46. torch.nn.init.constant_(bn.weight, bn_weight_init)
  47. torch.nn.init.constant_(bn.bias, 0)
  48. self.add_module("bn", bn)
  49. class PatchEmbed(nn.Module):
  50. """
  51. Embeds images into patches and projects them into a specified embedding dimension.
  52. Attributes:
  53. patches_resolution (Tuple[int, int]): Resolution of the patches after embedding.
  54. num_patches (int): Total number of patches.
  55. in_chans (int): Number of input channels.
  56. embed_dim (int): Dimension of the embedding.
  57. seq (nn.Sequential): Sequence of convolutional and activation layers for patch embedding.
  58. Methods:
  59. forward: Processes the input tensor through the patch embedding sequence.
  60. Examples:
  61. >>> import torch
  62. >>> patch_embed = PatchEmbed(in_chans=3, embed_dim=96, resolution=224, activation=nn.GELU)
  63. >>> x = torch.randn(1, 3, 224, 224)
  64. >>> output = patch_embed(x)
  65. >>> print(output.shape)
  66. """
  67. def __init__(self, in_chans, embed_dim, resolution, activation):
  68. """Initializes patch embedding with convolutional layers for image-to-patch conversion and projection."""
  69. super().__init__()
  70. img_size: Tuple[int, int] = to_2tuple(resolution)
  71. self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
  72. self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
  73. self.in_chans = in_chans
  74. self.embed_dim = embed_dim
  75. n = embed_dim
  76. self.seq = nn.Sequential(
  77. Conv2d_BN(in_chans, n // 2, 3, 2, 1),
  78. activation(),
  79. Conv2d_BN(n // 2, n, 3, 2, 1),
  80. )
  81. def forward(self, x):
  82. """Processes input tensor through patch embedding sequence, converting images to patch embeddings."""
  83. return self.seq(x)
  84. class MBConv(nn.Module):
  85. """
  86. Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture.
  87. Attributes:
  88. in_chans (int): Number of input channels.
  89. hidden_chans (int): Number of hidden channels.
  90. out_chans (int): Number of output channels.
  91. conv1 (Conv2d_BN): First convolutional layer.
  92. act1 (nn.Module): First activation function.
  93. conv2 (Conv2d_BN): Depthwise convolutional layer.
  94. act2 (nn.Module): Second activation function.
  95. conv3 (Conv2d_BN): Final convolutional layer.
  96. act3 (nn.Module): Third activation function.
  97. drop_path (nn.Module): Drop path layer (Identity for inference).
  98. Methods:
  99. forward: Performs the forward pass through the MBConv layer.
  100. Examples:
  101. >>> in_chans, out_chans = 32, 64
  102. >>> mbconv = MBConv(in_chans, out_chans, expand_ratio=4, activation=nn.ReLU, drop_path=0.1)
  103. >>> x = torch.randn(1, in_chans, 56, 56)
  104. >>> output = mbconv(x)
  105. >>> print(output.shape)
  106. torch.Size([1, 64, 56, 56])
  107. """
  108. def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
  109. """Initializes the MBConv layer with specified input/output channels, expansion ratio, and activation."""
  110. super().__init__()
  111. self.in_chans = in_chans
  112. self.hidden_chans = int(in_chans * expand_ratio)
  113. self.out_chans = out_chans
  114. self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1)
  115. self.act1 = activation()
  116. self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans, ks=3, stride=1, pad=1, groups=self.hidden_chans)
  117. self.act2 = activation()
  118. self.conv3 = Conv2d_BN(self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0)
  119. self.act3 = activation()
  120. # NOTE: `DropPath` is needed only for training.
  121. # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  122. self.drop_path = nn.Identity()
  123. def forward(self, x):
  124. """Implements the forward pass of MBConv, applying convolutions and skip connection."""
  125. shortcut = x
  126. x = self.conv1(x)
  127. x = self.act1(x)
  128. x = self.conv2(x)
  129. x = self.act2(x)
  130. x = self.conv3(x)
  131. x = self.drop_path(x)
  132. x += shortcut
  133. return self.act3(x)
  134. class PatchMerging(nn.Module):
  135. """
  136. Merges neighboring patches in the feature map and projects to a new dimension.
  137. This class implements a patch merging operation that combines spatial information and adjusts the feature
  138. dimension. It uses a series of convolutional layers with batch normalization to achieve this.
  139. Attributes:
  140. input_resolution (Tuple[int, int]): The input resolution (height, width) of the feature map.
  141. dim (int): The input dimension of the feature map.
  142. out_dim (int): The output dimension after merging and projection.
  143. act (nn.Module): The activation function used between convolutions.
  144. conv1 (Conv2d_BN): The first convolutional layer for dimension projection.
  145. conv2 (Conv2d_BN): The second convolutional layer for spatial merging.
  146. conv3 (Conv2d_BN): The third convolutional layer for final projection.
  147. Methods:
  148. forward: Applies the patch merging operation to the input tensor.
  149. Examples:
  150. >>> input_resolution = (56, 56)
  151. >>> patch_merging = PatchMerging(input_resolution, dim=64, out_dim=128, activation=nn.ReLU)
  152. >>> x = torch.randn(4, 64, 56, 56)
  153. >>> output = patch_merging(x)
  154. >>> print(output.shape)
  155. """
  156. def __init__(self, input_resolution, dim, out_dim, activation):
  157. """Initializes the PatchMerging module for merging and projecting neighboring patches in feature maps."""
  158. super().__init__()
  159. self.input_resolution = input_resolution
  160. self.dim = dim
  161. self.out_dim = out_dim
  162. self.act = activation()
  163. self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
  164. stride_c = 1 if out_dim in {320, 448, 576} else 2
  165. self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
  166. self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
  167. def forward(self, x):
  168. """Applies patch merging and dimension projection to the input feature map."""
  169. if x.ndim == 3:
  170. H, W = self.input_resolution
  171. B = len(x)
  172. # (B, C, H, W)
  173. x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
  174. x = self.conv1(x)
  175. x = self.act(x)
  176. x = self.conv2(x)
  177. x = self.act(x)
  178. x = self.conv3(x)
  179. return x.flatten(2).transpose(1, 2)
  180. class ConvLayer(nn.Module):
  181. """
  182. Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).
  183. This layer optionally applies downsample operations to the output and supports gradient checkpointing.
  184. Attributes:
  185. dim (int): Dimensionality of the input and output.
  186. input_resolution (Tuple[int, int]): Resolution of the input image.
  187. depth (int): Number of MBConv layers in the block.
  188. use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
  189. blocks (nn.ModuleList): List of MBConv layers.
  190. downsample (Optional[Callable]): Function for downsampling the output.
  191. Methods:
  192. forward: Processes the input through the convolutional layers.
  193. Examples:
  194. >>> input_tensor = torch.randn(1, 64, 56, 56)
  195. >>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU)
  196. >>> output = conv_layer(input_tensor)
  197. >>> print(output.shape)
  198. """
  199. def __init__(
  200. self,
  201. dim,
  202. input_resolution,
  203. depth,
  204. activation,
  205. drop_path=0.0,
  206. downsample=None,
  207. use_checkpoint=False,
  208. out_dim=None,
  209. conv_expand_ratio=4.0,
  210. ):
  211. """
  212. Initializes the ConvLayer with the given dimensions and settings.
  213. This layer consists of multiple MobileNetV3-style inverted bottleneck convolutions (MBConv) and
  214. optionally applies downsampling to the output.
  215. Args:
  216. dim (int): The dimensionality of the input and output.
  217. input_resolution (Tuple[int, int]): The resolution of the input image.
  218. depth (int): The number of MBConv layers in the block.
  219. activation (Callable): Activation function applied after each convolution.
  220. drop_path (float | List[float]): Drop path rate. Single float or a list of floats for each MBConv.
  221. downsample (Optional[Callable]): Function for downsampling the output. None to skip downsampling.
  222. use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
  223. out_dim (Optional[int]): The dimensionality of the output. None means it will be the same as `dim`.
  224. conv_expand_ratio (float): Expansion ratio for the MBConv layers.
  225. Examples:
  226. >>> input_tensor = torch.randn(1, 64, 56, 56)
  227. >>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU)
  228. >>> output = conv_layer(input_tensor)
  229. >>> print(output.shape)
  230. """
  231. super().__init__()
  232. self.dim = dim
  233. self.input_resolution = input_resolution
  234. self.depth = depth
  235. self.use_checkpoint = use_checkpoint
  236. # Build blocks
  237. self.blocks = nn.ModuleList(
  238. [
  239. MBConv(
  240. dim,
  241. dim,
  242. conv_expand_ratio,
  243. activation,
  244. drop_path[i] if isinstance(drop_path, list) else drop_path,
  245. )
  246. for i in range(depth)
  247. ]
  248. )
  249. # Patch merging layer
  250. self.downsample = (
  251. None
  252. if downsample is None
  253. else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
  254. )
  255. def forward(self, x):
  256. """Processes input through convolutional layers, applying MBConv blocks and optional downsampling."""
  257. for blk in self.blocks:
  258. x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
  259. return x if self.downsample is None else self.downsample(x)
  260. class Mlp(nn.Module):
  261. """
  262. Multi-layer Perceptron (MLP) module for transformer architectures.
  263. This module applies layer normalization, two fully-connected layers with an activation function in between,
  264. and dropout. It is commonly used in transformer-based architectures.
  265. Attributes:
  266. norm (nn.LayerNorm): Layer normalization applied to the input.
  267. fc1 (nn.Linear): First fully-connected layer.
  268. fc2 (nn.Linear): Second fully-connected layer.
  269. act (nn.Module): Activation function applied after the first fully-connected layer.
  270. drop (nn.Dropout): Dropout layer applied after the activation function.
  271. Methods:
  272. forward: Applies the MLP operations on the input tensor.
  273. Examples:
  274. >>> import torch
  275. >>> from torch import nn
  276. >>> mlp = Mlp(in_features=256, hidden_features=512, out_features=256, act_layer=nn.GELU, drop=0.1)
  277. >>> x = torch.randn(32, 100, 256)
  278. >>> output = mlp(x)
  279. >>> print(output.shape)
  280. torch.Size([32, 100, 256])
  281. """
  282. def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
  283. """Initializes a multi-layer perceptron with configurable input, hidden, and output dimensions."""
  284. super().__init__()
  285. out_features = out_features or in_features
  286. hidden_features = hidden_features or in_features
  287. self.norm = nn.LayerNorm(in_features)
  288. self.fc1 = nn.Linear(in_features, hidden_features)
  289. self.fc2 = nn.Linear(hidden_features, out_features)
  290. self.act = act_layer()
  291. self.drop = nn.Dropout(drop)
  292. def forward(self, x):
  293. """Applies MLP operations: layer norm, FC layers, activation, and dropout to the input tensor."""
  294. x = self.norm(x)
  295. x = self.fc1(x)
  296. x = self.act(x)
  297. x = self.drop(x)
  298. x = self.fc2(x)
  299. return self.drop(x)
  300. class Attention(torch.nn.Module):
  301. """
  302. Multi-head attention module with spatial awareness and trainable attention biases.
  303. This module implements a multi-head attention mechanism with support for spatial awareness, applying
  304. attention biases based on spatial resolution. It includes trainable attention biases for each unique
  305. offset between spatial positions in the resolution grid.
  306. Attributes:
  307. num_heads (int): Number of attention heads.
  308. scale (float): Scaling factor for attention scores.
  309. key_dim (int): Dimensionality of the keys and queries.
  310. nh_kd (int): Product of num_heads and key_dim.
  311. d (int): Dimensionality of the value vectors.
  312. dh (int): Product of d and num_heads.
  313. attn_ratio (float): Attention ratio affecting the dimensions of the value vectors.
  314. norm (nn.LayerNorm): Layer normalization applied to input.
  315. qkv (nn.Linear): Linear layer for computing query, key, and value projections.
  316. proj (nn.Linear): Linear layer for final projection.
  317. attention_biases (nn.Parameter): Learnable attention biases.
  318. attention_bias_idxs (Tensor): Indices for attention biases.
  319. ab (Tensor): Cached attention biases for inference, deleted during training.
  320. Methods:
  321. train: Sets the module in training mode and handles the 'ab' attribute.
  322. forward: Performs the forward pass of the attention mechanism.
  323. Examples:
  324. >>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))
  325. >>> x = torch.randn(1, 196, 256)
  326. >>> output = attn(x)
  327. >>> print(output.shape)
  328. torch.Size([1, 196, 256])
  329. """
  330. def __init__(
  331. self,
  332. dim,
  333. key_dim,
  334. num_heads=8,
  335. attn_ratio=4,
  336. resolution=(14, 14),
  337. ):
  338. """
  339. Initializes the Attention module for multi-head attention with spatial awareness.
  340. This module implements a multi-head attention mechanism with support for spatial awareness, applying
  341. attention biases based on spatial resolution. It includes trainable attention biases for each unique
  342. offset between spatial positions in the resolution grid.
  343. Args:
  344. dim (int): The dimensionality of the input and output.
  345. key_dim (int): The dimensionality of the keys and queries.
  346. num_heads (int): Number of attention heads. Default is 8.
  347. attn_ratio (float): Attention ratio, affecting the dimensions of the value vectors. Default is 4.
  348. resolution (Tuple[int, int]): Spatial resolution of the input feature map. Default is (14, 14).
  349. Raises:
  350. AssertionError: If 'resolution' is not a tuple of length 2.
  351. Examples:
  352. >>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))
  353. >>> x = torch.randn(1, 196, 256)
  354. >>> output = attn(x)
  355. >>> print(output.shape)
  356. torch.Size([1, 196, 256])
  357. """
  358. super().__init__()
  359. assert isinstance(resolution, tuple) and len(resolution) == 2, "'resolution' argument not tuple of length 2"
  360. self.num_heads = num_heads
  361. self.scale = key_dim**-0.5
  362. self.key_dim = key_dim
  363. self.nh_kd = nh_kd = key_dim * num_heads
  364. self.d = int(attn_ratio * key_dim)
  365. self.dh = int(attn_ratio * key_dim) * num_heads
  366. self.attn_ratio = attn_ratio
  367. h = self.dh + nh_kd * 2
  368. self.norm = nn.LayerNorm(dim)
  369. self.qkv = nn.Linear(dim, h)
  370. self.proj = nn.Linear(self.dh, dim)
  371. points = list(itertools.product(range(resolution[0]), range(resolution[1])))
  372. N = len(points)
  373. attention_offsets = {}
  374. idxs = []
  375. for p1 in points:
  376. for p2 in points:
  377. offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
  378. if offset not in attention_offsets:
  379. attention_offsets[offset] = len(attention_offsets)
  380. idxs.append(attention_offsets[offset])
  381. self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
  382. self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(N, N), persistent=False)
  383. @torch.no_grad()
  384. def train(self, mode=True):
  385. """Performs multi-head attention with spatial awareness and trainable attention biases."""
  386. super().train(mode)
  387. if mode and hasattr(self, "ab"):
  388. del self.ab
  389. else:
  390. self.ab = self.attention_biases[:, self.attention_bias_idxs]
  391. def forward(self, x): # x
  392. """Applies multi-head attention with spatial awareness and trainable attention biases."""
  393. B, N, _ = x.shape # B, N, C
  394. # Normalization
  395. x = self.norm(x)
  396. qkv = self.qkv(x)
  397. # (B, N, num_heads, d)
  398. q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)
  399. # (B, num_heads, N, d)
  400. q = q.permute(0, 2, 1, 3)
  401. k = k.permute(0, 2, 1, 3)
  402. v = v.permute(0, 2, 1, 3)
  403. self.ab = self.ab.to(self.attention_biases.device)
  404. attn = (q @ k.transpose(-2, -1)) * self.scale + (
  405. self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
  406. )
  407. attn = attn.softmax(dim=-1)
  408. x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
  409. return self.proj(x)
  410. class TinyViTBlock(nn.Module):
  411. """
  412. TinyViT Block that applies self-attention and a local convolution to the input.
  413. This block is a key component of the TinyViT architecture, combining self-attention mechanisms with
  414. local convolutions to process input features efficiently.
  415. Attributes:
  416. dim (int): The dimensionality of the input and output.
  417. input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
  418. num_heads (int): Number of attention heads.
  419. window_size (int): Size of the attention window.
  420. mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
  421. drop_path (nn.Module): Stochastic depth layer, identity function during inference.
  422. attn (Attention): Self-attention module.
  423. mlp (Mlp): Multi-layer perceptron module.
  424. local_conv (Conv2d_BN): Depth-wise local convolution layer.
  425. Methods:
  426. forward: Processes the input through the TinyViT block.
  427. extra_repr: Returns a string with extra information about the block's parameters.
  428. Examples:
  429. >>> input_tensor = torch.randn(1, 196, 192)
  430. >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3)
  431. >>> output = block(input_tensor)
  432. >>> print(output.shape)
  433. torch.Size([1, 196, 192])
  434. """
  435. def __init__(
  436. self,
  437. dim,
  438. input_resolution,
  439. num_heads,
  440. window_size=7,
  441. mlp_ratio=4.0,
  442. drop=0.0,
  443. drop_path=0.0,
  444. local_conv_size=3,
  445. activation=nn.GELU,
  446. ):
  447. """
  448. Initializes a TinyViT block with self-attention and local convolution.
  449. This block is a key component of the TinyViT architecture, combining self-attention mechanisms with
  450. local convolutions to process input features efficiently.
  451. Args:
  452. dim (int): Dimensionality of the input and output features.
  453. input_resolution (Tuple[int, int]): Spatial resolution of the input feature map (height, width).
  454. num_heads (int): Number of attention heads.
  455. window_size (int): Size of the attention window. Must be greater than 0.
  456. mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
  457. drop (float): Dropout rate.
  458. drop_path (float): Stochastic depth rate.
  459. local_conv_size (int): Kernel size of the local convolution.
  460. activation (torch.nn.Module): Activation function for MLP.
  461. Raises:
  462. AssertionError: If window_size is not greater than 0.
  463. AssertionError: If dim is not divisible by num_heads.
  464. Examples:
  465. >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3)
  466. >>> input_tensor = torch.randn(1, 196, 192)
  467. >>> output = block(input_tensor)
  468. >>> print(output.shape)
  469. torch.Size([1, 196, 192])
  470. """
  471. super().__init__()
  472. self.dim = dim
  473. self.input_resolution = input_resolution
  474. self.num_heads = num_heads
  475. assert window_size > 0, "window_size must be greater than 0"
  476. self.window_size = window_size
  477. self.mlp_ratio = mlp_ratio
  478. # NOTE: `DropPath` is needed only for training.
  479. # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  480. self.drop_path = nn.Identity()
  481. assert dim % num_heads == 0, "dim must be divisible by num_heads"
  482. head_dim = dim // num_heads
  483. window_resolution = (window_size, window_size)
  484. self.attn = Attention(dim, head_dim, num_heads, attn_ratio=1, resolution=window_resolution)
  485. mlp_hidden_dim = int(dim * mlp_ratio)
  486. mlp_activation = activation
  487. self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=mlp_activation, drop=drop)
  488. pad = local_conv_size // 2
  489. self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
  490. def forward(self, x):
  491. """Applies self-attention, local convolution, and MLP operations to the input tensor."""
  492. h, w = self.input_resolution
  493. b, hw, c = x.shape # batch, height*width, channels
  494. assert hw == h * w, "input feature has wrong size"
  495. res_x = x
  496. if h == self.window_size and w == self.window_size:
  497. x = self.attn(x)
  498. else:
  499. x = x.view(b, h, w, c)
  500. pad_b = (self.window_size - h % self.window_size) % self.window_size
  501. pad_r = (self.window_size - w % self.window_size) % self.window_size
  502. padding = pad_b > 0 or pad_r > 0
  503. if padding:
  504. x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
  505. pH, pW = h + pad_b, w + pad_r
  506. nH = pH // self.window_size
  507. nW = pW // self.window_size
  508. # Window partition
  509. x = (
  510. x.view(b, nH, self.window_size, nW, self.window_size, c)
  511. .transpose(2, 3)
  512. .reshape(b * nH * nW, self.window_size * self.window_size, c)
  513. )
  514. x = self.attn(x)
  515. # Window reverse
  516. x = x.view(b, nH, nW, self.window_size, self.window_size, c).transpose(2, 3).reshape(b, pH, pW, c)
  517. if padding:
  518. x = x[:, :h, :w].contiguous()
  519. x = x.view(b, hw, c)
  520. x = res_x + self.drop_path(x)
  521. x = x.transpose(1, 2).reshape(b, c, h, w)
  522. x = self.local_conv(x)
  523. x = x.view(b, c, hw).transpose(1, 2)
  524. return x + self.drop_path(self.mlp(x))
  525. def extra_repr(self) -> str:
  526. """
  527. Returns a string representation of the TinyViTBlock's parameters.
  528. This method provides a formatted string containing key information about the TinyViTBlock, including its
  529. dimension, input resolution, number of attention heads, window size, and MLP ratio.
  530. Returns:
  531. (str): A formatted string containing the block's parameters.
  532. Examples:
  533. >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0)
  534. >>> print(block.extra_repr())
  535. dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0
  536. """
  537. return (
  538. f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
  539. f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
  540. )
  541. class BasicLayer(nn.Module):
  542. """
  543. A basic TinyViT layer for one stage in a TinyViT architecture.
  544. This class represents a single layer in the TinyViT model, consisting of multiple TinyViT blocks
  545. and an optional downsampling operation.
  546. Attributes:
  547. dim (int): The dimensionality of the input and output features.
  548. input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
  549. depth (int): Number of TinyViT blocks in this layer.
  550. use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
  551. blocks (nn.ModuleList): List of TinyViT blocks that make up this layer.
  552. downsample (nn.Module | None): Downsample layer at the end of the layer, if specified.
  553. Methods:
  554. forward: Processes the input through the layer's blocks and optional downsampling.
  555. extra_repr: Returns a string with the layer's parameters for printing.
  556. Examples:
  557. >>> input_tensor = torch.randn(1, 3136, 192)
  558. >>> layer = BasicLayer(dim=192, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)
  559. >>> output = layer(input_tensor)
  560. >>> print(output.shape)
  561. torch.Size([1, 784, 384])
  562. """
  563. def __init__(
  564. self,
  565. dim,
  566. input_resolution,
  567. depth,
  568. num_heads,
  569. window_size,
  570. mlp_ratio=4.0,
  571. drop=0.0,
  572. drop_path=0.0,
  573. downsample=None,
  574. use_checkpoint=False,
  575. local_conv_size=3,
  576. activation=nn.GELU,
  577. out_dim=None,
  578. ):
  579. """
  580. Initializes a BasicLayer in the TinyViT architecture.
  581. This layer consists of multiple TinyViT blocks and an optional downsampling operation. It is designed to
  582. process feature maps at a specific resolution and dimensionality within the TinyViT model.
  583. Args:
  584. dim (int): Dimensionality of the input and output features.
  585. input_resolution (Tuple[int, int]): Spatial resolution of the input feature map (height, width).
  586. depth (int): Number of TinyViT blocks in this layer.
  587. num_heads (int): Number of attention heads in each TinyViT block.
  588. window_size (int): Size of the local window for attention computation.
  589. mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
  590. drop (float): Dropout rate.
  591. drop_path (float | List[float]): Stochastic depth rate. Can be a float or a list of floats for each block.
  592. downsample (nn.Module | None): Downsampling layer at the end of the layer. None to skip downsampling.
  593. use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
  594. local_conv_size (int): Kernel size for the local convolution in each TinyViT block.
  595. activation (nn.Module): Activation function used in the MLP.
  596. out_dim (int | None): Output dimension after downsampling. None means it will be the same as `dim`.
  597. Raises:
  598. ValueError: If `drop_path` is a list and its length doesn't match `depth`.
  599. Examples:
  600. >>> layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)
  601. >>> x = torch.randn(1, 56 * 56, 96)
  602. >>> output = layer(x)
  603. >>> print(output.shape)
  604. """
  605. super().__init__()
  606. self.dim = dim
  607. self.input_resolution = input_resolution
  608. self.depth = depth
  609. self.use_checkpoint = use_checkpoint
  610. # Build blocks
  611. self.blocks = nn.ModuleList(
  612. [
  613. TinyViTBlock(
  614. dim=dim,
  615. input_resolution=input_resolution,
  616. num_heads=num_heads,
  617. window_size=window_size,
  618. mlp_ratio=mlp_ratio,
  619. drop=drop,
  620. drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
  621. local_conv_size=local_conv_size,
  622. activation=activation,
  623. )
  624. for i in range(depth)
  625. ]
  626. )
  627. # Patch merging layer
  628. self.downsample = (
  629. None
  630. if downsample is None
  631. else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
  632. )
  def forward(self, x):
      """Feed `x` through each block in order, then through the downsample module when one is configured."""
      for block in self.blocks:
          if self.use_checkpoint:
              # Gradient checkpointing: the block's activations are recomputed during backward.
              x = checkpoint.checkpoint(block, x)
          else:
              x = block(x)

      if self.downsample is None:
          return x
      return self.downsample(x)
  def extra_repr(self) -> str:
      """Describe the layer by its feature dimension, input resolution and block count for module printing."""
      dim, resolution, depth = self.dim, self.input_resolution, self.depth
      return f"dim={dim}, input_resolution={resolution}, depth={depth}"
  class TinyViT(nn.Module):
      """
      TinyViT: A compact vision transformer architecture used as an image feature encoder.

      This class implements the TinyViT model, which combines elements of vision transformers and convolutional
      neural networks. A classifier head (`norm_head`, `head`) is constructed, but `forward` does not use it: both
      `forward` and `forward_features` return the spatial feature map produced by `neck`.

      Attributes:
          img_size (int): Input image size.
          num_classes (int): Number of classification classes.
          depths (List[int]): Number of blocks in each stage.
          num_layers (int): Number of stages (`len(depths)`).
          mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
          patch_embed (PatchEmbed): Module for patch embedding.
          patches_resolution (Tuple[int, int]): Resolution of embedded patches.
          layers (nn.ModuleList): List of network stages.
          norm_head (nn.LayerNorm): Layer normalization for the classifier head.
          head (nn.Linear): Linear layer for final classification.
          neck (nn.Sequential): Neck module projecting the last stage to 256 channels.

      Methods:
          set_layer_lr_decay: Sets layer-wise learning rate decay.
          _init_weights: Initializes weights for linear and normalization layers.
          no_weight_decay_keywords: Returns keywords for parameters that should not use weight decay.
          forward_features: Processes input through the feature extraction layers.
          forward: Performs a forward pass; identical to `forward_features`.
          set_imgsz: Updates the stored resolutions so the model can run on a different image size.

      Examples:
          >>> model = TinyViT(img_size=224, num_classes=1000)
          >>> x = torch.randn(1, 3, 224, 224)
          >>> features = model.forward_features(x)
          >>> print(features.shape)
          torch.Size([1, 256, 14, 14])
      """

      def __init__(
          self,
          img_size=224,
          in_chans=3,
          num_classes=1000,
          embed_dims=(96, 192, 384, 768),
          depths=(2, 2, 6, 2),
          num_heads=(3, 6, 12, 24),
          window_sizes=(7, 7, 14, 7),
          mlp_ratio=4.0,
          drop_rate=0.0,
          drop_path_rate=0.1,
          use_checkpoint=False,
          mbconv_expand_ratio=4.0,
          local_conv_size=3,
          layer_lr_decay=1.0,
      ):
          """
          Initializes the TinyViT model.

          This constructor sets up the TinyViT architecture: patch embedding, one convolutional stage followed by
          attention stages, a classification head, and a convolutional neck.

          Args:
              img_size (int): Size of the input image. Default is 224.
              in_chans (int): Number of input channels. Default is 3.
              num_classes (int): Number of classes for classification. Default is 1000.
              embed_dims (Tuple[int, int, int, int]): Embedding dimensions for each stage.
                  Default is (96, 192, 384, 768).
              depths (Tuple[int, int, int, int]): Number of blocks in each stage. Default is (2, 2, 6, 2).
              num_heads (Tuple[int, int, int, int]): Number of attention heads in each stage.
                  Default is (3, 6, 12, 24).
              window_sizes (Tuple[int, int, int, int]): Window sizes for each stage. Default is (7, 7, 14, 7).
              mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. Default is 4.0.
              drop_rate (float): Dropout rate. Default is 0.0.
              drop_path_rate (float): Stochastic depth rate. Default is 0.1.
              use_checkpoint (bool): Whether to use checkpointing to save memory. Default is False.
              mbconv_expand_ratio (float): Expansion ratio for MBConv layer. Default is 4.0.
              local_conv_size (int): Kernel size for local convolutions. Default is 3.
              layer_lr_decay (float): Layer-wise learning rate decay factor. Default is 1.0.

          Examples:
              >>> model = TinyViT(img_size=224, num_classes=1000)
              >>> x = torch.randn(1, 3, 224, 224)
              >>> output = model(x)
              >>> print(output.shape)
              torch.Size([1, 256, 14, 14])
          """
          super().__init__()
          self.img_size = img_size
          self.num_classes = num_classes
          self.depths = depths
          self.num_layers = len(depths)
          self.mlp_ratio = mlp_ratio

          activation = nn.GELU

          self.patch_embed = PatchEmbed(
              in_chans=in_chans, embed_dim=embed_dims[0], resolution=img_size, activation=activation
          )

          patches_resolution = self.patch_embed.patches_resolution
          self.patches_resolution = patches_resolution

          # Stochastic depth decay rule: one linearly increasing rate per block across all stages.
          dpr = torch.linspace(0, drop_path_rate, sum(depths)).tolist()

          # Build stages; `block_offset` is the index into `dpr` of the current stage's first block.
          self.layers = nn.ModuleList()
          block_offset = 0
          for i_layer, depth in enumerate(depths):
              kwargs = dict(
                  dim=embed_dims[i_layer],
                  input_resolution=self._stage_resolution(patches_resolution, i_layer),
                  depth=depth,
                  drop_path=dpr[block_offset : block_offset + depth],
                  # Every stage except the last ends with a PatchMerging module.
                  downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                  use_checkpoint=use_checkpoint,
                  out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)],
                  activation=activation,
              )
              block_offset += depth

              if i_layer == 0:
                  # The first stage is convolutional; all later stages are attention-based.
                  layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs)
              else:
                  layer = BasicLayer(
                      num_heads=num_heads[i_layer],
                      window_size=window_sizes[i_layer],
                      mlp_ratio=self.mlp_ratio,
                      drop=drop_rate,
                      local_conv_size=local_conv_size,
                      **kwargs,
                  )
              self.layers.append(layer)

          # Classifier head
          self.norm_head = nn.LayerNorm(embed_dims[-1])
          self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()

          # Init weights
          self.apply(self._init_weights)
          self.set_layer_lr_decay(layer_lr_decay)

          # The neck is attached only after weight init and lr-scale assignment, so neither touches it.
          # Creating it earlier would trip the "every parameter has lr_scale" assertion in
          # `set_layer_lr_decay`, because the neck is never given a scale.
          self.neck = nn.Sequential(
              nn.Conv2d(
                  embed_dims[-1],
                  256,
                  kernel_size=1,
                  bias=False,
              ),
              LayerNorm2d(256),
              nn.Conv2d(
                  256,
                  256,
                  kernel_size=3,
                  padding=1,
                  bias=False,
              ),
              LayerNorm2d(256),
          )

      @staticmethod
      def _stage_resolution(patches_resolution, i_layer):
          """
          Returns the (height, width) input resolution of stage `i_layer`.

          Each stage halves the patch grid relative to the previous one, except stage index 3, which reuses the
          resolution of stage 2. This matches `forward_features`, which reshapes the final tokens to
          `patches_resolution // 4`.

          Args:
              patches_resolution (Sequence[int]): Patch-grid (height, width) after patch embedding.
              i_layer (int): Zero-based stage index.

          Returns:
              (Tuple[int, int]): Input resolution for the requested stage.
          """
          exponent = i_layer - 1 if i_layer == 3 else i_layer
          scale = 2**exponent
          return patches_resolution[0] // scale, patches_resolution[1] // scale

      def set_layer_lr_decay(self, layer_lr_decay):
          """
          Sets layer-wise learning rate decay for the TinyViT model based on depth.

          Every parameter receives an `lr_scale` attribute of `layer_lr_decay ** (depth - i - 1)`, where `i` is the
          global index of the block it belongs to and `depth` is the total block count. The patch embedding uses
          the first block's scale, a stage's downsample module uses the scale of that stage's last block, and the
          classifier head uses the last scale. Each parameter is also tagged with its qualified name as `param_name`.

          Args:
              layer_lr_decay (float): Multiplicative decay factor per block.
          """
          # Layers -> blocks (depth)
          depth = sum(self.depths)
          lr_scales = [layer_lr_decay ** (depth - i - 1) for i in range(depth)]

          def _set_lr_scale(module, scale):
              """Tags every parameter of `module`, including those of its submodules, with `scale`."""
              for p in module.parameters():
                  p.lr_scale = scale

          _set_lr_scale(self.patch_embed, lr_scales[0])
          i = 0
          for layer in self.layers:
              for block in layer.blocks:
                  _set_lr_scale(block, lr_scales[i])
                  i += 1
              if layer.downsample is not None:
                  # `i` already points past this stage's last block, hence `i - 1`.
                  _set_lr_scale(layer.downsample, lr_scales[i - 1])
          assert i == depth
          for m in (self.norm_head, self.head):
              _set_lr_scale(m, lr_scales[-1])

          for name, p in self.named_parameters():
              p.param_name = name

          # Sanity check: no registered parameter may be left without a scale.
          for p in self.parameters():
              assert hasattr(p, "lr_scale"), p.param_name

      @staticmethod
      def _init_weights(m):
          """Resets linear biases to zero and LayerNorm parameters to bias 0 / weight 1."""
          if isinstance(m, nn.Linear):
              # NOTE: truncated-normal weight initialization (std=0.02) is needed only for training and is
              # intentionally not applied here; only the bias is reset.
              if m.bias is not None:
                  nn.init.constant_(m.bias, 0)
          elif isinstance(m, nn.LayerNorm):
              nn.init.constant_(m.bias, 0)
              nn.init.constant_(m.weight, 1.0)

      @torch.jit.ignore
      def no_weight_decay_keywords(self):
          """Returns a set of keywords for parameters that should not use weight decay."""
          return {"attention_biases"}

      def forward_features(self, x):
          """
          Processes input through feature extraction layers, returning spatial features.

          Args:
              x (torch.Tensor): Input images of shape (N, C, H, W).

          Returns:
              (torch.Tensor): Output of `neck`, computed from the last stage's tokens reshaped to
                  (N, channels, patches_resolution[0] // 4, patches_resolution[1] // 4).
          """
          x = self.patch_embed(x)  # x input is (N, C, H, W)
          for layer in self.layers:
              x = layer(x)

          # Tokens (N, L, C) -> grid (N, H', W', C) -> channels-first (N, C, H', W') for the convolutional neck.
          batch, _, channel = x.shape
          x = x.view(batch, self.patches_resolution[0] // 4, self.patches_resolution[1] // 4, channel)
          x = x.permute(0, 3, 1, 2)
          return self.neck(x)

      def forward(self, x):
          """Performs the forward pass through the TinyViT model, returning the same features as `forward_features`."""
          return self.forward_features(x)

      def set_imgsz(self, imgsz=(1024, 1024)):
          """
          Set image size to make model compatible with different image sizes.

          Updates `patches_resolution` to `imgsz // 4` and rewrites the `input_resolution` attribute of every
          stage, of its downsample module (when present) and, for `BasicLayer` stages, of each block.

          Args:
              imgsz (Sequence[int]): The (height, width) of the input image. Default is (1024, 1024).
          """
          patches_resolution = [s // 4 for s in imgsz]
          self.patches_resolution = patches_resolution
          for i, layer in enumerate(self.layers):
              input_resolution = self._stage_resolution(patches_resolution, i)
              layer.input_resolution = input_resolution
              if layer.downsample is not None:
                  layer.downsample.input_resolution = input_resolution
              if isinstance(layer, BasicLayer):
                  for block in layer.blocks:
                      block.input_resolution = input_resolution