# high_reso_maxvit.py
import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, List, Optional, Sequence, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torchvision.models._api import register_model, Weights, WeightsEnum
from torchvision.models._meta import _IMAGENET_CATEGORIES
from torchvision.models._utils import _ovewrite_named_param, handle_legacy_interface
from torchvision.ops.misc import Conv2dNormActivation, SqueezeExcitation
from torchvision.ops.stochastic_depth import StochasticDepth
from torchvision.transforms._presets import ImageClassification, InterpolationMode
from torchvision.utils import _log_api_usage_once

from libs.vision_libs.models.detection.backbone_utils import BackboneWithFPN

__all__ = [
    "MaxVit",
    "MaxVit_T_Weights",
    "maxvit_t",
]

def _get_conv_output_shape(input_size: Tuple[int, int], kernel_size: int, stride: int, padding: int) -> Tuple[int, int]:
    return (
        (input_size[0] - kernel_size + 2 * padding) // stride + 1,
        (input_size[1] - kernel_size + 2 * padding) // stride + 1,
    )


def _make_block_input_shapes(input_size: Tuple[int, int], n_blocks: int) -> List[Tuple[int, int]]:
    """Util function to check that the input size is correct for a MaxVit configuration."""
    shapes = []
    block_input_shape = _get_conv_output_shape(input_size, 3, 2, 1)
    for _ in range(n_blocks):
        block_input_shape = _get_conv_output_shape(block_input_shape, 3, 2, 1)
        shapes.append(block_input_shape)
    return shapes

def _get_relative_position_index(height: int, width: int) -> torch.Tensor:
    coords = torch.stack(torch.meshgrid([torch.arange(height), torch.arange(width)]))
    coords_flat = torch.flatten(coords, 1)
    relative_coords = coords_flat[:, :, None] - coords_flat[:, None, :]
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()
    relative_coords[:, :, 0] += height - 1
    relative_coords[:, :, 1] += width - 1
    relative_coords[:, :, 0] *= 2 * width - 1
    return relative_coords.sum(-1)
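
# Illustrative sketch (added; not part of the original module): for a 2x2 window the
# helper above maps every pair of window positions to an index into a
# (2*2 - 1) * (2*2 - 1) = 9-entry relative-bias table, one entry per (dy, dx) offset:
#
#     >>> _get_relative_position_index(2, 2)
#     tensor([[4, 3, 1, 0],
#             [5, 4, 2, 1],
#             [7, 6, 4, 3],
#             [8, 7, 5, 4]])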

class MBConv(nn.Module):
    """MBConv: Mobile Inverted Residual Bottleneck.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        expansion_ratio (float): Expansion ratio in the bottleneck.
        squeeze_ratio (float): Squeeze ratio in the SE Layer.
        stride (int): Stride of the depthwise convolution.
        activation_layer (Callable[..., nn.Module]): Activation function.
        norm_layer (Callable[..., nn.Module]): Normalization function.
        p_stochastic_dropout (float): Probability of stochastic depth.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        expansion_ratio: float,
        squeeze_ratio: float,
        stride: int,
        activation_layer: Callable[..., nn.Module],
        norm_layer: Callable[..., nn.Module],
        p_stochastic_dropout: float = 0.0,
    ) -> None:
        super().__init__()

        proj: Sequence[nn.Module]
        self.proj: nn.Module

        should_proj = stride != 1 or in_channels != out_channels
        if should_proj:
            proj = [nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=True)]
            if stride == 2:
                proj = [nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)] + proj  # type: ignore
            self.proj = nn.Sequential(*proj)
        else:
            self.proj = nn.Identity()  # type: ignore

        mid_channels = int(out_channels * expansion_ratio)
        sqz_channels = int(out_channels * squeeze_ratio)

        if p_stochastic_dropout:
            self.stochastic_depth = StochasticDepth(p_stochastic_dropout, mode="row")  # type: ignore
        else:
            self.stochastic_depth = nn.Identity()  # type: ignore

        _layers = OrderedDict()
        _layers["pre_norm"] = norm_layer(in_channels)
        _layers["conv_a"] = Conv2dNormActivation(
            in_channels,
            mid_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            activation_layer=activation_layer,
            norm_layer=norm_layer,
            inplace=None,
        )
        _layers["conv_b"] = Conv2dNormActivation(
            mid_channels,
            mid_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            activation_layer=activation_layer,
            norm_layer=norm_layer,
            groups=mid_channels,
            inplace=None,
        )
        _layers["squeeze_excitation"] = SqueezeExcitation(mid_channels, sqz_channels, activation=nn.SiLU)
        _layers["conv_c"] = nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1, bias=True)
        self.layers = nn.Sequential(_layers)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor with expected layout of [B, C, H, W].
        Returns:
            Tensor: Output tensor with expected layout of [B, C, H / stride, W / stride].
        """
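        # Note (added for clarity): when stride == 2 the shortcut `self.proj`
        # average-pools and 1x1-projects the input so that `res` matches the
        # downsampled main branch before the residual sum.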
        res = self.proj(x)
        x = self.stochastic_depth(self.layers(x))
        return res + x

class RelativePositionalMultiHeadAttention(nn.Module):
    """Relative Positional Multi-Head Attention.

    Args:
        feat_dim (int): Number of input features.
        head_dim (int): Number of features per head.
        max_seq_len (int): Maximum sequence length.
    """

    def __init__(
        self,
        feat_dim: int,
        head_dim: int,
        max_seq_len: int,
    ) -> None:
        super().__init__()

        if feat_dim % head_dim != 0:
            raise ValueError(f"feat_dim: {feat_dim} must be divisible by head_dim: {head_dim}")

        self.n_heads = feat_dim // head_dim
        self.head_dim = head_dim
        self.size = int(math.sqrt(max_seq_len))
        self.max_seq_len = max_seq_len

        self.to_qkv = nn.Linear(feat_dim, self.n_heads * self.head_dim * 3)
        self.scale_factor = feat_dim**-0.5

        self.merge = nn.Linear(self.head_dim * self.n_heads, feat_dim)
        self.relative_position_bias_table = nn.parameter.Parameter(
            torch.empty(((2 * self.size - 1) * (2 * self.size - 1), self.n_heads), dtype=torch.float32),
        )
        self.register_buffer("relative_position_index", _get_relative_position_index(self.size, self.size))
        # initialize the relative positional bias with a truncated normal
        torch.nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def get_relative_positional_bias(self) -> torch.Tensor:
        bias_index = self.relative_position_index.view(-1)  # type: ignore
        relative_bias = self.relative_position_bias_table[bias_index].view(self.max_seq_len, self.max_seq_len, -1)  # type: ignore
        relative_bias = relative_bias.permute(2, 0, 1).contiguous()
        return relative_bias.unsqueeze(0)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor with expected layout of [B, G, P, D].
        Returns:
            Tensor: Output tensor with expected layout of [B, G, P, D].
        """
        B, G, P, D = x.shape
        H, DH = self.n_heads, self.head_dim

        qkv = self.to_qkv(x)
        q, k, v = torch.chunk(qkv, 3, dim=-1)

        q = q.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4)
        k = k.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4)
        v = v.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4)

        k = k * self.scale_factor
        dot_prod = torch.einsum("B G H I D, B G H J D -> B G H I J", q, k)
        pos_bias = self.get_relative_positional_bias()

        dot_prod = F.softmax(dot_prod + pos_bias, dim=-1)

        out = torch.einsum("B G H I J, B G H J D -> B G H I D", dot_prod, v)
        out = out.permute(0, 1, 3, 2, 4).reshape(B, G, P, D)
        out = self.merge(out)
        return out
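
# Shape sketch (added; illustrative only): with feat_dim=64, head_dim=32 and a 7x7
# partition, an input of layout [B, G, 49, 64] is split into 2 heads of 32 channels,
# attended over the 49 positions of each partition with a learned (2*7-1)**2 = 169
# entry relative bias per head, and merged back to [B, G, 49, 64].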

class SwapAxes(nn.Module):
    """Permute the axes of a tensor."""

    def __init__(self, a: int, b: int) -> None:
        super().__init__()
        self.a = a
        self.b = b

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        res = torch.swapaxes(x, self.a, self.b)
        return res

class WindowPartition(nn.Module):
    """
    Partition the input tensor into non-overlapping windows.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor, p: int) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor with expected layout of [B, C, H, W].
            p (int): Side length of each partition window.
        Returns:
            Tensor: Output tensor with expected layout of [B, H/P * W/P, P*P, C].
        """
        B, C, H, W = x.shape
        P = p
        # chunk up H and W dimensions
        x = x.reshape(B, C, H // P, P, W // P, P)
        x = x.permute(0, 2, 4, 3, 5, 1)
        # collapse the P * P dimension
        x = x.reshape(B, (H // P) * (W // P), P * P, C)
        return x

class WindowDepartition(nn.Module):
    """
    Departition the input tensor of non-overlapping windows into a feature volume of layout [B, C, H, W].
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor, p: int, h_partitions: int, w_partitions: int) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor with expected layout of [B, (H/P * W/P), P*P, C].
            p (int): Side length of each partition window.
            h_partitions (int): Number of vertical partitions.
            w_partitions (int): Number of horizontal partitions.
        Returns:
            Tensor: Output tensor with expected layout of [B, C, H, W].
        """
        B, G, PP, C = x.shape
        P = p
        HP, WP = h_partitions, w_partitions
        # split the P * P dimension into two P tile dimensions
        x = x.reshape(B, HP, WP, P, P, C)
        # permute into B, C, HP, P, WP, P
        x = x.permute(0, 5, 1, 3, 2, 4)
        # reshape into B, C, H, W
        x = x.reshape(B, C, HP * P, WP * P)
        return x
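
# Illustrative round-trip (added; a sketch, not part of the original module):
# partitioning a [B, C, H, W] tensor into P x P windows and departitioning it again
# is lossless, which is what PartitionAttentionLayer below relies on.
#
#     part, depart = WindowPartition(), WindowDepartition()
#     x = torch.randn(2, 8, 16, 16)
#     windows = part(x, 4)                 # -> [2, (16//4) * (16//4), 4 * 4, 8] == [2, 16, 16, 8]
#     restored = depart(windows, 4, 4, 4)  # -> [2, 8, 16, 16]
#     assert torch.equal(x, restored)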

class PartitionAttentionLayer(nn.Module):
    """
    Layer for partitioning the input tensor into non-overlapping windows and applying attention to each window.

    Args:
        in_channels (int): Number of input channels.
        head_dim (int): Dimension of each attention head.
        partition_size (int): Size of the partitions.
        partition_type (str): Type of partitioning to use. Can be either "grid" or "window".
        grid_size (Tuple[int, int]): Size of the grid to partition the input tensor into.
        mlp_ratio (int): Ratio of the feature size expansion in the MLP layer.
        activation_layer (Callable[..., nn.Module]): Activation function to use.
        norm_layer (Callable[..., nn.Module]): Normalization function to use.
        attention_dropout (float): Dropout probability for the attention layer.
        mlp_dropout (float): Dropout probability for the MLP layer.
        p_stochastic_dropout (float): Probability of stochastic depth.
    """

    def __init__(
        self,
        in_channels: int,
        head_dim: int,
        # partitioning parameters
        partition_size: int,
        partition_type: str,
        # grid size needs to be known at initialization time
        # because we need to know how many relative offsets there are in the grid
        grid_size: Tuple[int, int],
        mlp_ratio: int,
        activation_layer: Callable[..., nn.Module],
        norm_layer: Callable[..., nn.Module],
        attention_dropout: float,
        mlp_dropout: float,
        p_stochastic_dropout: float,
    ) -> None:
        super().__init__()

        self.n_heads = in_channels // head_dim
        self.head_dim = head_dim
        self.n_partitions = grid_size[0] // partition_size
        self.partition_type = partition_type
        self.grid_size = grid_size

        if partition_type not in ["grid", "window"]:
            raise ValueError("partition_type must be either 'grid' or 'window'")

        if partition_type == "window":
            self.p, self.g = partition_size, self.n_partitions
        else:
            self.p, self.g = self.n_partitions, partition_size

        self.partition_op = WindowPartition()
        self.departition_op = WindowDepartition()
        self.partition_swap = SwapAxes(-2, -3) if partition_type == "grid" else nn.Identity()
        self.departition_swap = SwapAxes(-2, -3) if partition_type == "grid" else nn.Identity()
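        # Note (added for clarity): for "grid" partitioning, self.p equals the number
        # of partitions, so the feature map is cut into contiguous tiles of side
        # n_partitions, and the SwapAxes(-2, -3) above makes each attention sequence
        # the partition_size**2 locations that share the same within-tile offset,
        # i.e. a sparse partition_size x partition_size grid over the whole feature map.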

        self.attn_layer = nn.Sequential(
            norm_layer(in_channels),
            # it's always going to be partition_size ** 2 because
            # of the axis swap in the case of grid partitioning
            RelativePositionalMultiHeadAttention(in_channels, head_dim, partition_size**2),
            nn.Dropout(attention_dropout),
        )

        # pre-normalization similar to transformer layers
        self.mlp_layer = nn.Sequential(
            nn.LayerNorm(in_channels),
            nn.Linear(in_channels, in_channels * mlp_ratio),
            activation_layer(),
            nn.Linear(in_channels * mlp_ratio, in_channels),
            nn.Dropout(mlp_dropout),
        )

        # layer scale factors
        self.stochastic_dropout = StochasticDepth(p_stochastic_dropout, mode="row")

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor with expected layout of [B, C, H, W].
        Returns:
            Tensor: Output tensor with expected layout of [B, C, H, W].
        """
        # Undefined behavior if H or W are not divisible by p
        # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L766
        gh, gw = self.grid_size[0] // self.p, self.grid_size[1] // self.p
        torch._assert(
            self.grid_size[0] % self.p == 0 and self.grid_size[1] % self.p == 0,
            "Grid size must be divisible by partition size. Got grid size of {} and partition size of {}".format(
                self.grid_size, self.p
            ),
        )

        x = self.partition_op(x, self.p)
        x = self.partition_swap(x)
        x = x + self.stochastic_dropout(self.attn_layer(x))
        x = x + self.stochastic_dropout(self.mlp_layer(x))
        x = self.departition_swap(x)
        x = self.departition_op(x, self.p, gh, gw)

        return x

class MaxVitLayer(nn.Module):
    """
    MaxVit layer consisting of a MBConv layer followed by a PartitionAttentionLayer with `window` and a PartitionAttentionLayer with `grid`.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        expansion_ratio (float): Expansion ratio in the bottleneck.
        squeeze_ratio (float): Squeeze ratio in the SE Layer.
        stride (int): Stride of the depthwise convolution.
        activation_layer (Callable[..., nn.Module]): Activation function.
        norm_layer (Callable[..., nn.Module]): Normalization function.
        head_dim (int): Dimension of the attention heads.
        mlp_ratio (int): Ratio of the MLP layer.
        mlp_dropout (float): Dropout probability for the MLP layer.
        attention_dropout (float): Dropout probability for the attention layer.
        p_stochastic_dropout (float): Probability of stochastic depth.
        partition_size (int): Size of the partitions.
        grid_size (Tuple[int, int]): Size of the input feature grid.
    """

    def __init__(
        self,
        # conv parameters
        in_channels: int,
        out_channels: int,
        squeeze_ratio: float,
        expansion_ratio: float,
        stride: int,
        # conv + transformer parameters
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
        # transformer parameters
        head_dim: int,
        mlp_ratio: int,
        mlp_dropout: float,
        attention_dropout: float,
        p_stochastic_dropout: float,
        # partitioning parameters
        partition_size: int,
        grid_size: Tuple[int, int],
    ) -> None:
        super().__init__()

        layers: OrderedDict = OrderedDict()

        # convolutional layer
        layers["MBconv"] = MBConv(
            in_channels=in_channels,
            out_channels=out_channels,
            expansion_ratio=expansion_ratio,
            squeeze_ratio=squeeze_ratio,
            stride=stride,
            activation_layer=activation_layer,
            norm_layer=norm_layer,
            p_stochastic_dropout=p_stochastic_dropout,
        )
        # attention layers, block -> grid
        layers["window_attention"] = PartitionAttentionLayer(
            in_channels=out_channels,
            head_dim=head_dim,
            partition_size=partition_size,
            partition_type="window",
            grid_size=grid_size,
            mlp_ratio=mlp_ratio,
            activation_layer=activation_layer,
            norm_layer=nn.LayerNorm,
            attention_dropout=attention_dropout,
            mlp_dropout=mlp_dropout,
            p_stochastic_dropout=p_stochastic_dropout,
        )
        layers["grid_attention"] = PartitionAttentionLayer(
            in_channels=out_channels,
            head_dim=head_dim,
            partition_size=partition_size,
            partition_type="grid",
            grid_size=grid_size,
            mlp_ratio=mlp_ratio,
            activation_layer=activation_layer,
            norm_layer=nn.LayerNorm,
            attention_dropout=attention_dropout,
            mlp_dropout=mlp_dropout,
            p_stochastic_dropout=p_stochastic_dropout,
        )
        self.layers = nn.Sequential(layers)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, C, H, W).
        Returns:
            Tensor: Output tensor of shape (B, C, H / stride, W / stride).
        """
        x = self.layers(x)
        return x

class MaxVitBlock(nn.Module):
    """
    A MaxVit block consisting of `n_layers` MaxVit layers.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        expansion_ratio (float): Expansion ratio in the bottleneck.
        squeeze_ratio (float): Squeeze ratio in the SE Layer.
        activation_layer (Callable[..., nn.Module]): Activation function.
        norm_layer (Callable[..., nn.Module]): Normalization function.
        head_dim (int): Dimension of the attention heads.
        mlp_ratio (int): Ratio of the MLP layer.
        mlp_dropout (float): Dropout probability for the MLP layer.
        attention_dropout (float): Dropout probability for the attention layer.
        p_stochastic_dropout (float): Probability of stochastic depth.
        partition_size (int): Size of the partitions.
        input_grid_size (Tuple[int, int]): Size of the input feature grid.
        n_layers (int): Number of layers in the block.
        p_stochastic (List[float]): List of probabilities for stochastic depth for each layer.
    """

    def __init__(
        self,
        # conv parameters
        in_channels: int,
        out_channels: int,
        squeeze_ratio: float,
        expansion_ratio: float,
        # conv + transformer parameters
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
        # transformer parameters
        head_dim: int,
        mlp_ratio: int,
        mlp_dropout: float,
        attention_dropout: float,
        # partitioning parameters
        partition_size: int,
        input_grid_size: Tuple[int, int],
        # number of layers
        n_layers: int,
        p_stochastic: List[float],
    ) -> None:
        super().__init__()
        if not len(p_stochastic) == n_layers:
            raise ValueError(f"p_stochastic must have length n_layers={n_layers}, got p_stochastic={p_stochastic}.")

        self.layers = nn.ModuleList()
        # account for the first stride of the first layer
        self.grid_size = _get_conv_output_shape(input_grid_size, kernel_size=3, stride=2, padding=1)

        for idx, p in enumerate(p_stochastic):
            stride = 2 if idx == 0 else 1
            self.layers += [
                MaxVitLayer(
                    in_channels=in_channels if idx == 0 else out_channels,
                    out_channels=out_channels,
                    squeeze_ratio=squeeze_ratio,
                    expansion_ratio=expansion_ratio,
                    stride=stride,
                    norm_layer=norm_layer,
                    activation_layer=activation_layer,
                    head_dim=head_dim,
                    mlp_ratio=mlp_ratio,
                    mlp_dropout=mlp_dropout,
                    attention_dropout=attention_dropout,
                    partition_size=partition_size,
                    grid_size=self.grid_size,
                    p_stochastic_dropout=p,
                ),
            ]

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, C, H, W).
        Returns:
            Tensor: Output tensor of shape (B, C, H // 2, W // 2); the first layer of the block downsamples by 2.
        """
        for layer in self.layers:
            x = layer(x)
        return x

class MaxVit(nn.Module):
    """
    Implements MaxVit Transformer from the `MaxViT: Multi-Axis Vision Transformer <https://arxiv.org/abs/2204.01697>`_ paper.

    Args:
        input_size (Tuple[int, int]): Size of the input image.
        stem_channels (int): Number of channels in the stem.
        partition_size (int): Size of the partitions.
        block_channels (List[int]): Number of channels in each block.
        block_layers (List[int]): Number of layers in each block.
        stochastic_depth_prob (float): Probability of stochastic depth. Expands to a list of probabilities for each layer that scales linearly to the specified value.
        squeeze_ratio (float): Squeeze ratio in the SE Layer. Default: 0.25.
        expansion_ratio (float): Expansion ratio in the MBConv bottleneck. Default: 4.
        norm_layer (Callable[..., nn.Module]): Normalization function. Default: None (setting to None will produce a `BatchNorm2d(eps=1e-3, momentum=0.99)`).
        activation_layer (Callable[..., nn.Module]): Activation function. Default: nn.GELU.
        head_dim (int): Dimension of the attention heads.
        mlp_ratio (int): Expansion ratio of the MLP layer. Default: 4.
        mlp_dropout (float): Dropout probability for the MLP layer. Default: 0.0.
        attention_dropout (float): Dropout probability for the attention layer. Default: 0.0.
        num_classes (int): Number of classes. Default: 1000.
    """

    def __init__(
        self,
        # input size parameters
        input_size: Tuple[int, int],
        # stem and task parameters
        stem_channels: int,
        # partitioning parameters
        partition_size: int,
        # block parameters
        block_channels: List[int],
        block_layers: List[int],
        # attention head dimensions
        head_dim: int,
        stochastic_depth_prob: float,
        # conv + transformer parameters
        # norm_layer is applied only to the conv layers
        # activation_layer is applied both to conv and transformer layers
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        activation_layer: Callable[..., nn.Module] = nn.GELU,
        # conv parameters
        squeeze_ratio: float = 0.25,
        expansion_ratio: float = 4,
        # transformer parameters
        mlp_ratio: int = 4,
        mlp_dropout: float = 0.0,
        attention_dropout: float = 0.0,
        # task parameters
        num_classes: int = 1000,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        input_channels = 3

        # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L1029-L1030
        # for the exact parameters used in batchnorm
        if norm_layer is None:
            norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.99)

        # Make sure input size will be divisible by the partition size in all blocks
        # Undefined behavior if H or W are not divisible by p
        # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L766
        block_input_sizes = _make_block_input_shapes(input_size, len(block_channels))
        for idx, block_input_size in enumerate(block_input_sizes):
            if block_input_size[0] % partition_size != 0 or block_input_size[1] % partition_size != 0:
                raise ValueError(
                    f"Input size {block_input_size} of block {idx} is not divisible by partition size {partition_size}. "
                    f"Consider changing the partition size or the input size.\n"
                    f"Current configuration yields the following block input sizes: {block_input_sizes}."
                )
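
        # Shape sketch (added; illustrative): for a 224x224 input the stem halves the
        # resolution to 112x112 and the four blocks run at 56, 28, 14 and 7 pixels per
        # side, so with partition_size=7 every stage stays divisible by the partition
        # size, which is exactly what the check above enforces.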

        # stem
        self.stem0 = nn.Sequential(
            Conv2dNormActivation(
                input_channels,
                stem_channels,
                3,
                stride=1,
                norm_layer=norm_layer,
                activation_layer=activation_layer,
                bias=False,
                inplace=None,
            ),
            Conv2dNormActivation(
                stem_channels, stem_channels, 3, stride=1, norm_layer=None, activation_layer=None, bias=True
            ),
        )
        self.stem1 = nn.Sequential(
            Conv2dNormActivation(
                stem_channels,
                stem_channels,
                3,
                stride=2,
                norm_layer=norm_layer,
                activation_layer=activation_layer,
                bias=False,
                inplace=None,
            ),
            Conv2dNormActivation(
                stem_channels, stem_channels, 3, stride=1, norm_layer=None, activation_layer=None, bias=True
            ),
        )

        # account for stem stride
        input_size = _get_conv_output_shape(input_size, kernel_size=3, stride=2, padding=1)
        self.partition_size = partition_size

        # blocks
        self.blocks = nn.ModuleList()
        in_channels = [stem_channels] + block_channels[:-1]
        out_channels = block_channels

        # precompute the stochastic depth probabilities from 0 to stochastic_depth_prob
        # since we have N blocks with L layers, we will have N * L probabilities uniformly distributed
        # over the range [0, stochastic_depth_prob]
        p_stochastic = np.linspace(0, stochastic_depth_prob, sum(block_layers)).tolist()

        p_idx = 0
        for in_channel, out_channel, num_layers in zip(in_channels, out_channels, block_layers):
            self.blocks.append(
                MaxVitBlock(
                    in_channels=in_channel,
                    out_channels=out_channel,
                    squeeze_ratio=squeeze_ratio,
                    expansion_ratio=expansion_ratio,
                    norm_layer=norm_layer,
                    activation_layer=activation_layer,
                    head_dim=head_dim,
                    mlp_ratio=mlp_ratio,
                    mlp_dropout=mlp_dropout,
                    attention_dropout=attention_dropout,
                    partition_size=partition_size,
                    input_grid_size=input_size,
                    n_layers=num_layers,
                    p_stochastic=p_stochastic[p_idx : p_idx + num_layers],
                ),
            )
            input_size = self.blocks[-1].grid_size
            p_idx += num_layers

        # see https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L1137-L1158
        # for why there is Linear -> Tanh -> Linear
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.LayerNorm(block_channels[-1]),
            nn.Linear(block_channels[-1], block_channels[-1]),
            nn.Tanh(),
            nn.Linear(block_channels[-1], num_classes, bias=False),
        )

        self._init_weights()

    def forward(self, x: Tensor) -> Tensor:
        x = self.stem0(x)
        x = self.stem1(x)
        for block in self.blocks:
            x = block(x)
        x = self.classifier(x)
        return x

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

def _maxvit(
    # stem parameters
    stem_channels: int,
    # block parameters
    block_channels: List[int],
    block_layers: List[int],
    stochastic_depth_prob: float,
    # partitioning parameters
    partition_size: int,
    # transformer parameters
    head_dim: int,
    # Weights API
    weights: Optional[WeightsEnum] = None,
    progress: bool = False,
    # kwargs,
    **kwargs: Any,
) -> MaxVit:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
        assert weights.meta["min_size"][0] == weights.meta["min_size"][1]
        _ovewrite_named_param(kwargs, "input_size", weights.meta["min_size"])

    input_size = kwargs.pop("input_size", (224, 224))

    model = MaxVit(
        stem_channels=stem_channels,
        block_channels=block_channels,
        block_layers=block_layers,
        stochastic_depth_prob=stochastic_depth_prob,
        head_dim=head_dim,
        partition_size=partition_size,
        input_size=input_size,
        **kwargs,
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model

class MaxVit_T_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/maxvit_t-bc5ab103.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            "categories": _IMAGENET_CATEGORIES,
            "num_params": 30919624,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#maxvit",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.700,
                    "acc@5": 96.722,
                }
            },
            "_ops": 5.558,
            "_file_size": 118.769,
            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
        },
    )
    DEFAULT = IMAGENET1K_V1

@handle_legacy_interface(weights=("pretrained", MaxVit_T_Weights.IMAGENET1K_V1))
def maxvit_t(*, weights: Optional[MaxVit_T_Weights] = None, progress: bool = True, **kwargs: Any) -> MaxVit:
    """
    Constructs a maxvit_t architecture from
    `MaxViT: Multi-Axis Vision Transformer <https://arxiv.org/abs/2204.01697>`_.

    Args:
        weights (:class:`~torchvision.models.MaxVit_T_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MaxVit_T_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.maxvit.MaxVit``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/maxvit.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MaxVit_T_Weights
        :members:
    """
    weights = MaxVit_T_Weights.verify(weights)

    return _maxvit(
        stem_channels=64,
        block_channels=[64, 128, 256, 512],
        block_layers=[2, 2, 5, 2],
        head_dim=32,
        stochastic_depth_prob=0.2,
        partition_size=7,
        weights=weights,
        progress=progress,
        **kwargs,
    )
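
# Usage sketch (added; illustrative only): the classifier variant can be built for
# larger inputs as long as every stage resolution stays divisible by the partition
# size of 7, e.g. a 448x448 model:
#
#     model = maxvit_t(weights=None, input_size=(448, 448))
#     logits = model(torch.randn(1, 3, 448, 448))  # -> [1, 1000]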

class MaxVitBackbone(torch.nn.Module):
    def __init__(self, input_size=(224 * 2, 224 * 2)):
        super(MaxVitBackbone, self).__init__()
        # Reuse part of the MaxVit layers as a feature extractor
        maxvit_model = maxvit_t(weights=None, input_size=input_size)
        self.stem0 = maxvit_model.stem0  # stem layers
        self.stem1 = maxvit_model.stem1  # stem layers
        self.block0 = maxvit_model.blocks[0]
        self.block1 = maxvit_model.blocks[1]
        self.block2 = maxvit_model.blocks[2]
        self.block3 = maxvit_model.blocks[3]

    def forward(self, x):
        print("Input size:", x.shape)
        x = self.stem0(x)
        print("After stem0 size:", x.shape)
        x = self.stem1(x)
        print("After stem1 size:", x.shape)
        x = self.block0(x)
        print("After block0 size:", x.shape)
        x = self.block1(x)
        print("After block1 size:", x.shape)
        x = self.block2(x)
        print("After block2 size:", x.shape)
        x = self.block3(x)
        print("After block3 size:", x.shape)
        return x

def maxvit_with_fpn(size=224):
    maxvit = MaxVitBackbone(input_size=(size, size))
    in_channels_list = [64, 64, 64, 128, 256, 512]
    featmap_names = ['0', '1', '2', '3', '4', '5', 'pool']
    # print(f'featmap_names:{featmap_names}')
    # roi_pooler = MultiScaleRoIAlign(
    #     featmap_names=featmap_names,
    #     output_size=7,
    #     sampling_ratio=2
    # )
    backbone_with_fpn = BackboneWithFPN(
        maxvit,
        # make sure these keys map to actual layers of the backbone
        return_layers={'stem0': '0', 'stem1': '1', 'block0': '2', 'block1': '3', 'block2': '4', 'block3': '5'},
        in_channels_list=in_channels_list,
        out_channels=256
    )
    test_input = torch.randn(1, 3, size, size)
    return backbone_with_fpn
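
# Usage sketch for maxvit_with_fpn (added; illustrative only; `size` must keep every
# stage divisible by the partition size of 7, e.g. 224, 448 or 672):
#
#     fpn_backbone = maxvit_with_fpn(size=448)
#     feats = fpn_backbone(torch.randn(1, 3, 448, 448))
#     # -> OrderedDict with keys '0'..'5' and 'pool', each feature map with 256 channels.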

if __name__ == '__main__':
    maxvit = MaxVitBackbone(input_size=(224 * 3, 224 * 3))
    in_channels_list = [64, 64, 64, 128, 256, 512]
    featmap_names = ['0', '1', '2', '3', '4', 'pool']
    # print(f'featmap_names:{featmap_names}')
    # roi_pooler = MultiScaleRoIAlign(
    #     featmap_names=featmap_names,
    #     output_size=7,
    #     sampling_ratio=2
    # )
    backbone_with_fpn = BackboneWithFPN(
        maxvit,
        # make sure these keys map to actual layers of the backbone
        return_layers={'stem0': '0', 'stem1': '1', 'block0': '2', 'block1': '3', 'block2': '4', 'block3': '5'},
        in_channels_list=in_channels_list,
        out_channels=256
    )
    test_input = torch.randn(1, 3, 224 * 3, 224 * 3)
    # model = FasterRCNN(
    #     backbone=backbone_with_fpn,
    #     min_size=224 * 5,
    #     max_size=224 * 5,
    #     num_classes=91,  # COCO has 91 classes
    #     rpn_anchor_generator=get_anchor_generator(backbone_with_fpn, test_input=test_input),
    #     box_roi_pool=roi_pooler
    # )
    out = maxvit(test_input)
    with torch.no_grad():
        output = backbone_with_fpn(test_input)
    print("Output feature maps:")
    for k, v in output.items():
        print(f"{k}: {v.shape}")
    # model.eval()
    # output = model(test_input)
    # print(f'fasterrcnn output:{output}')