# high_reso_swin.py

import math
from functools import partial
from typing import Any, Callable, List, Optional

import torch
import torch.nn.functional as F
from torch import nn, Tensor

from libs.vision_libs.models._api import register_model, Weights, WeightsEnum
from libs.vision_libs.models._utils import _ovewrite_named_param, handle_legacy_interface
from libs.vision_libs.transforms import InterpolationMode
from libs.vision_libs.transforms._presets import ImageClassification
from libs.vision_libs.utils import _log_api_usage_once
from libs.vision_libs.ops import MLP, Permute
from libs.vision_libs.ops.stochastic_depth import StochasticDepth
from libs.vision_libs.models._meta import _IMAGENET_CATEGORIES

__all__ = [
    "SwinTransformer",
    "Swin_T_Weights",
    "Swin_S_Weights",
    "Swin_B_Weights",
    "Swin_V2_T_Weights",
    "Swin_V2_S_Weights",
    "Swin_V2_B_Weights",
    "swin_t",
    "swin_s",
    "swin_b",
    "swin_v2_t",
    "swin_v2_s",
    "swin_v2_b",
]


def _patch_merging_pad(x: torch.Tensor) -> torch.Tensor:
    H, W, _ = x.shape[-3:]
    x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
    x0 = x[..., 0::2, 0::2, :]  # ... H/2 W/2 C
    x1 = x[..., 1::2, 0::2, :]  # ... H/2 W/2 C
    x2 = x[..., 0::2, 1::2, :]  # ... H/2 W/2 C
    x3 = x[..., 1::2, 1::2, :]  # ... H/2 W/2 C
    x = torch.cat([x0, x1, x2, x3], -1)  # ... H/2 W/2 4*C
    return x


torch.fx.wrap("_patch_merging_pad")
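
# A minimal shape walkthrough of _patch_merging_pad (illustrative comment, not executed):
# given x of shape [B, 7, 7, 96], the pad makes H and W even ([B, 8, 8, 96]); each of the four
# strided slices then has shape [B, 4, 4, 96], and the channel-wise concat yields [B, 4, 4, 384],
# i.e. the spatial resolution halves while the channel count quadruples before the linear
# reduction in PatchMerging brings it back down to 2*C.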


def _get_relative_position_bias(
    relative_position_bias_table: torch.Tensor, relative_position_index: torch.Tensor, window_size: List[int]
) -> torch.Tensor:
    N = window_size[0] * window_size[1]
    relative_position_bias = relative_position_bias_table[relative_position_index]  # type: ignore[index]
    relative_position_bias = relative_position_bias.view(N, N, -1)
    relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
    return relative_position_bias


torch.fx.wrap("_get_relative_position_bias")
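
# How the bias lookup is shaped (illustrative comment, assuming a 7x7 window and nH heads):
# relative_position_index holds N*N = 49*49 flattened offsets into the (2*7-1)*(2*7-1) = 169-row
# bias table, so the gather returns [N*N, nH]; the view/permute/unsqueeze steps reshape that to
# [1, nH, N, N], which broadcasts over the (B * num_windows) leading dimension of the attention
# logits inside shifted_window_attention.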


class PatchMerging(nn.Module):
    """Patch Merging Layer.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
    """

    def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm):
        super().__init__()
        _log_api_usage_once(self)
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x: Tensor):
        """
        Args:
            x (Tensor): input tensor with expected layout of [..., H, W, C]
        Returns:
            Tensor with layout of [..., H/2, W/2, 2*C]
        """
        x = _patch_merging_pad(x)
        x = self.norm(x)
        x = self.reduction(x)  # ... H/2 W/2 2*C
        return x


class PatchMergingV2(nn.Module):
    """Patch Merging Layer for Swin Transformer V2.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
    """

    def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm):
        super().__init__()
        _log_api_usage_once(self)
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(2 * dim)  # difference from V1: norm is applied after the reduction

    def forward(self, x: Tensor):
        """
        Args:
            x (Tensor): input tensor with expected layout of [..., H, W, C]
        Returns:
            Tensor with layout of [..., H/2, W/2, 2*C]
        """
        x = _patch_merging_pad(x)
        x = self.reduction(x)  # ... H/2 W/2 2*C
        x = self.norm(x)
        return x
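
# A minimal usage sketch of the merging layers (illustrative comment; shapes are assumptions):
#
#   merge = PatchMerging(dim=96)
#   y = merge(torch.randn(1, 56, 56, 96))   # -> [1, 28, 28, 192]
#
# PatchMergingV2 produces the same output shape; it only moves the LayerNorm after the linear
# reduction, so the normalized width is 2*dim instead of 4*dim.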


def shifted_window_attention(
    input: Tensor,
    qkv_weight: Tensor,
    proj_weight: Tensor,
    relative_position_bias: Tensor,
    window_size: List[int],
    num_heads: int,
    shift_size: List[int],
    attention_dropout: float = 0.0,
    dropout: float = 0.0,
    qkv_bias: Optional[Tensor] = None,
    proj_bias: Optional[Tensor] = None,
    logit_scale: Optional[torch.Tensor] = None,
    training: bool = True,
) -> Tensor:
    """
    Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.

    Args:
        input (Tensor[N, H, W, C]): The 4-dimensional input tensor.
        qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value.
        proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection.
        relative_position_bias (Tensor): The learned relative position bias added to attention.
        window_size (List[int]): Window size.
        num_heads (int): Number of attention heads.
        shift_size (List[int]): Shift size for shifted window attention.
        attention_dropout (float): Dropout ratio of attention weight. Default: 0.0.
        dropout (float): Dropout ratio of output. Default: 0.0.
        qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
        proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
        logit_scale (Tensor[out_dim], optional): Logit scale of cosine attention for Swin Transformer V2. Default: None.
        training (bool, optional): Training flag used by the dropout parameters. Default: True.
    Returns:
        Tensor[N, H, W, C]: The output tensor after shifted window attention.
    """
    B, H, W, C = input.shape
    # pad feature maps to multiples of window size
    pad_r = (window_size[1] - W % window_size[1]) % window_size[1]
    pad_b = (window_size[0] - H % window_size[0]) % window_size[0]
    x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b))
    _, pad_H, pad_W, _ = x.shape

    shift_size = shift_size.copy()
    # If window size is larger than feature size, there is no need to shift window
    if window_size[0] >= pad_H:
        shift_size[0] = 0
    if window_size[1] >= pad_W:
        shift_size[1] = 0

    # cyclic shift
    if sum(shift_size) > 0:
        x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))

    # partition windows
    num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1])
    x = x.view(B, pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1], C)
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C)  # B*nW, Ws*Ws, C

    # multi-head attention
    if logit_scale is not None and qkv_bias is not None:
        qkv_bias = qkv_bias.clone()
        length = qkv_bias.numel() // 3
        qkv_bias[length : 2 * length].zero_()
    qkv = F.linear(x, qkv_weight, qkv_bias)
    qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]
    if logit_scale is not None:
        # cosine attention
        attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
        logit_scale = torch.clamp(logit_scale, max=math.log(100.0)).exp()
        attn = attn * logit_scale
    else:
        q = q * (C // num_heads) ** -0.5
        attn = q.matmul(k.transpose(-2, -1))
    # add relative position bias
    attn = attn + relative_position_bias

    if sum(shift_size) > 0:
        # generate attention mask
        attn_mask = x.new_zeros((pad_H, pad_W))
        h_slices = ((0, -window_size[0]), (-window_size[0], -shift_size[0]), (-shift_size[0], None))
        w_slices = ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None))
        count = 0
        for h in h_slices:
            for w in w_slices:
                attn_mask[h[0] : h[1], w[0] : w[1]] = count
                count += 1
        attn_mask = attn_mask.view(pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1])
        attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size[0] * window_size[1])
        attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1))
        attn = attn + attn_mask.unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, num_heads, x.size(1), x.size(1))

    attn = F.softmax(attn, dim=-1)
    attn = F.dropout(attn, p=attention_dropout, training=training)

    x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C)
    x = F.linear(x, proj_weight, proj_bias)
    x = F.dropout(x, p=dropout, training=training)

    # reverse windows
    x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C)
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C)

    # reverse cyclic shift
    if sum(shift_size) > 0:
        x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))

    # unpad features
    x = x[:, :H, :W, :].contiguous()
    return x


torch.fx.wrap("shifted_window_attention")
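
# Notes on the masking above (illustrative comment, assuming e.g. window_size=[7, 7], shift_size=[3, 3]):
# after the cyclic roll, windows that straddle the original image border contain tokens from
# disconnected regions. The h_slices/w_slices loop paints each region with a distinct integer,
# and attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2) is non-zero exactly for token pairs drawn
# from different regions; those pairs receive a -100 bias before the softmax, which effectively
# zeroes their attention weights.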


class ShiftedWindowAttention(nn.Module):
    """
    See :func:`shifted_window_attention`.
    """

    def __init__(
        self,
        dim: int,
        window_size: List[int],
        shift_size: List[int],
        num_heads: int,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attention_dropout: float = 0.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        if len(window_size) != 2 or len(shift_size) != 2:
            raise ValueError("window_size and shift_size must be of length 2")
        self.window_size = window_size
        self.shift_size = shift_size
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.dropout = dropout

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)

        self.define_relative_position_bias_table()
        self.define_relative_position_index()

    def define_relative_position_bias_table(self):
        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), self.num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def define_relative_position_index(self):
        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1).flatten()  # Wh*Ww*Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

    def get_relative_position_bias(self) -> torch.Tensor:
        return _get_relative_position_bias(
            self.relative_position_bias_table, self.relative_position_index, self.window_size  # type: ignore[arg-type]
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Tensor with layout of [B, H, W, C]
        Returns:
            Tensor with same layout as input, i.e. [B, H, W, C]
        """
        relative_position_bias = self.get_relative_position_bias()
        return shifted_window_attention(
            x,
            self.qkv.weight,
            self.proj.weight,
            relative_position_bias,
            self.window_size,
            self.num_heads,
            shift_size=self.shift_size,
            attention_dropout=self.attention_dropout,
            dropout=self.dropout,
            qkv_bias=self.qkv.bias,
            proj_bias=self.proj.bias,
            training=self.training,
        )


class ShiftedWindowAttentionV2(ShiftedWindowAttention):
    """
    See :func:`shifted_window_attention_v2`.
    """

    def __init__(
        self,
        dim: int,
        window_size: List[int],
        shift_size: List[int],
        num_heads: int,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attention_dropout: float = 0.0,
        dropout: float = 0.0,
    ):
        super().__init__(
            dim,
            window_size,
            shift_size,
            num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attention_dropout=attention_dropout,
            dropout=dropout,
        )

        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
        # mlp to generate continuous relative position bias
        self.cpb_mlp = nn.Sequential(
            nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False)
        )
        if qkv_bias:
            length = self.qkv.bias.numel() // 3
            self.qkv.bias[length : 2 * length].data.zero_()

    def define_relative_position_bias_table(self):
        # get relative_coords_table
        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
        relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
        relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
        relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
        relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1

        relative_coords_table *= 8  # normalize to -8, 8
        relative_coords_table = (
            torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / 3.0
        )
        self.register_buffer("relative_coords_table", relative_coords_table)

    def get_relative_position_bias(self) -> torch.Tensor:
        relative_position_bias = _get_relative_position_bias(
            self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads),
            self.relative_position_index,  # type: ignore[arg-type]
            self.window_size,
        )
        relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
        return relative_position_bias

    def forward(self, x: Tensor):
        """
        Args:
            x (Tensor): Tensor with layout of [B, H, W, C]
        Returns:
            Tensor with same layout as input, i.e. [B, H, W, C]
        """
        relative_position_bias = self.get_relative_position_bias()
        return shifted_window_attention(
            x,
            self.qkv.weight,
            self.proj.weight,
            relative_position_bias,
            self.window_size,
            self.num_heads,
            shift_size=self.shift_size,
            attention_dropout=self.attention_dropout,
            dropout=self.dropout,
            qkv_bias=self.qkv.bias,
            proj_bias=self.proj.bias,
            logit_scale=self.logit_scale,
            training=self.training,
        )
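
# A minimal usage sketch of the attention modules (illustrative comment; shapes are assumptions):
#
#   attn_v1 = ShiftedWindowAttention(dim=96, window_size=[7, 7], shift_size=[3, 3], num_heads=3)
#   attn_v2 = ShiftedWindowAttentionV2(dim=96, window_size=[8, 8], shift_size=[4, 4], num_heads=3)
#   y = attn_v2(torch.randn(1, 32, 32, 96))   # -> [1, 32, 32, 96]
#
# V2 replaces the scaled dot-product logits with cosine attention scaled by a learned, clamped
# logit_scale, and replaces the learned bias table with a small MLP (cpb_mlp) over log-spaced
# relative coordinates, squashed through 16 * sigmoid.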


class SwinTransformerBlock(nn.Module):
    """
    Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (List[int]): Window size.
        shift_size (List[int]): Shift size for shifted window attention.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        dropout (float): Dropout rate. Default: 0.0.
        attention_dropout (float): Attention dropout rate. Default: 0.0.
        stochastic_depth_prob (float): Stochastic depth rate. Default: 0.0.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttention.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: List[int],
        shift_size: List[int],
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        stochastic_depth_prob: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention,
    ):
        super().__init__()
        _log_api_usage_once(self)

        self.norm1 = norm_layer(dim)
        self.attn = attn_layer(
            dim,
            window_size,
            shift_size,
            num_heads,
            attention_dropout=attention_dropout,
            dropout=dropout,
        )
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
        self.norm2 = norm_layer(dim)
        self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)

        for m in self.mlp.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)

    def forward(self, x: Tensor):
        x = x + self.stochastic_depth(self.attn(self.norm1(x)))
        x = x + self.stochastic_depth(self.mlp(self.norm2(x)))
        # x = x.permute(0, 3, 1, 2).contiguous()
        return x


class SwinTransformerBlockV2(SwinTransformerBlock):
    """
    Swin Transformer V2 Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (List[int]): Window size.
        shift_size (List[int]): Shift size for shifted window attention.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        dropout (float): Dropout rate. Default: 0.0.
        attention_dropout (float): Attention dropout rate. Default: 0.0.
        stochastic_depth_prob (float): Stochastic depth rate. Default: 0.0.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttentionV2.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: List[int],
        shift_size: List[int],
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        stochastic_depth_prob: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_layer: Callable[..., nn.Module] = ShiftedWindowAttentionV2,
    ):
        super().__init__(
            dim,
            num_heads,
            window_size,
            shift_size,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
            attention_dropout=attention_dropout,
            stochastic_depth_prob=stochastic_depth_prob,
            norm_layer=norm_layer,
            attn_layer=attn_layer,
        )

    def forward(self, x: Tensor):
        # Here is the difference: V2 applies the norm after the attention and MLP (post-norm),
        # whereas V1 applies the norm before them (pre-norm).
        x = x + self.stochastic_depth(self.norm1(self.attn(x)))
        x = x + self.stochastic_depth(self.norm2(self.mlp(x)))
        return x
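
# A minimal block-level sketch (illustrative comment; the dimensions are assumptions):
#
#   block = SwinTransformerBlockV2(dim=96, num_heads=3, window_size=[8, 8], shift_size=[4, 4])
#   y = block(torch.randn(1, 64, 64, 96))   # channels-last layout, shape preserved: [1, 64, 64, 96]
#
# Alternating blocks within a stage use shift_size [0, 0] and [ws // 2, ws // 2], which is what
# SwinTransformer sets up below via `0 if i_layer % 2 == 0 else w // 2`.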


class SwinTransformer(nn.Module):
    """
    Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using
    Shifted Windows" <https://arxiv.org/abs/2103.14030>`_ paper.

    Args:
        patch_size (List[int]): Patch size.
        embed_dim (int): Patch embedding dimension.
        depths (List(int)): Depth of each Swin Transformer layer.
        num_heads (List(int)): Number of attention heads in different layers.
        window_size (List[int]): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        dropout (float): Dropout rate. Default: 0.0.
        attention_dropout (float): Attention dropout rate. Default: 0.0.
        stochastic_depth_prob (float): Stochastic depth rate. Default: 0.1.
        num_classes (int): Number of classes for classification head. Default: 1000.
        block (nn.Module, optional): SwinTransformer Block. Default: None.
        norm_layer (nn.Module, optional): Normalization layer. Default: None.
        downsample_layer (nn.Module): Downsample layer (patch merging). Default: PatchMerging.
    """

    def __init__(
        self,
        patch_size: List[int],
        embed_dim: int,
        depths: List[int],
        num_heads: List[int],
        window_size: List[int],
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        stochastic_depth_prob: float = 0.1,
        num_classes: int = 1000,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        block: Optional[Callable[..., nn.Module]] = None,
        downsample_layer: Callable[..., nn.Module] = PatchMerging,
    ):
        super().__init__()
        _log_api_usage_once(self)
        self.num_classes = num_classes

        if block is None:
            block = SwinTransformerBlock
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-5)

        layers: List[nn.Module] = []
        # split image into non-overlapping patches
        layers.append(
            nn.Sequential(
                nn.Conv2d(
                    3, embed_dim, kernel_size=(patch_size[0], patch_size[1]), stride=(patch_size[0], patch_size[1])
                ),
                Permute([0, 2, 3, 1]),
                norm_layer(embed_dim),
            )
        )

        total_stage_blocks = sum(depths)
        stage_block_id = 0
        # build SwinTransformer blocks
        for i_stage in range(len(depths)):
            stage: List[nn.Module] = []
            dim = embed_dim * 2**i_stage
            for i_layer in range(depths[i_stage]):
                # adjust stochastic depth probability based on the depth of the stage block
                sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1)
                stage.append(
                    block(
                        dim,
                        num_heads[i_stage],
                        window_size=window_size,
                        shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size],
                        mlp_ratio=mlp_ratio,
                        dropout=dropout,
                        attention_dropout=attention_dropout,
                        stochastic_depth_prob=sd_prob,
                        norm_layer=norm_layer,
                    )
                )
                stage_block_id += 1
            layers.append(nn.Sequential(*stage))
            # add patch merging layer
            if i_stage < (len(depths) - 1):
                layers.append(downsample_layer(dim, norm_layer))
        self.features = nn.Sequential(*layers)

        num_features = embed_dim * 2 ** (len(depths) - 1)
        self.norm = norm_layer(num_features)
        self.permute = Permute([0, 3, 1, 2])  # B H W C -> B C H W
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten(1)
        self.head = nn.Linear(num_features, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.features(x)
        x = self.norm(x)
        x = self.permute(x)
        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.head(x)
        return x
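
# Resolution bookkeeping (illustrative comment; the 512x512 input comes from the demo at the
# bottom of this file, everything else follows from the code above): with patch_size=[2, 2] the
# patch embedding produces a 256x256 token grid at embed_dim channels, and each of the three
# patch-merging steps halves the grid and doubles the channels, ending at 32x32 with
# embed_dim * 8 channels before the global average pool and classification head.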


def _swin_transformer(
    patch_size: List[int],
    embed_dim: int,
    depths: List[int],
    num_heads: List[int],
    window_size: List[int],
    stochastic_depth_prob: float,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> SwinTransformer:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = SwinTransformer(
        patch_size=patch_size,
        embed_dim=embed_dim,
        depths=depths,
        num_heads=num_heads,
        window_size=window_size,
        stochastic_depth_prob=stochastic_depth_prob,
        **kwargs,
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model
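
# Note on the helper above (descriptive comment): when pretrained weights are passed,
# _ovewrite_named_param pins num_classes to the weights' category count rather than silently
# loading a mismatched head; with weights=None, any SwinTransformer kwarg (e.g. num_classes,
# dropout) passes straight through.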


_COMMON_META = {
    "categories": _IMAGENET_CATEGORIES,
}


class Swin_T_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/swin_t-704ceda3.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=232, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META,
            "num_params": 28288354,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.474,
                    "acc@5": 95.776,
                }
            },
            "_ops": 4.491,
            "_file_size": 108.19,
            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class Swin_S_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/swin_s-5e29d889.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=246, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META,
            "num_params": 49606258,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.196,
                    "acc@5": 96.360,
                }
            },
            "_ops": 8.741,
            "_file_size": 189.786,
            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class Swin_B_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/swin_b-68c6b09e.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META,
            "num_params": 87768224,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.582,
                    "acc@5": 96.640,
                }
            },
            "_ops": 15.431,
            "_file_size": 335.364,
            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class Swin_V2_T_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/swin_v2_t-b137f0e2.pth",
        transforms=partial(
            ImageClassification, crop_size=256, resize_size=260, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META,
            "num_params": 28351570,
            "min_size": (256, 256),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer-v2",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.072,
                    "acc@5": 96.132,
                }
            },
            "_ops": 5.94,
            "_file_size": 108.626,
            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class Swin_V2_S_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/swin_v2_s-637d8ceb.pth",
        transforms=partial(
            ImageClassification, crop_size=256, resize_size=260, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META,
            "num_params": 49737442,
            "min_size": (256, 256),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer-v2",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.712,
                    "acc@5": 96.816,
                }
            },
            "_ops": 11.546,
            "_file_size": 190.675,
            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class Swin_V2_B_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/swin_v2_b-781e5279.pth",
        transforms=partial(
            ImageClassification, crop_size=256, resize_size=272, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META,
            "num_params": 87930848,
            "min_size": (256, 256),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer-v2",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.112,
                    "acc@5": 96.864,
                }
            },
            "_ops": 20.325,
            "_file_size": 336.372,
            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
    """
    Constructs a swin_tiny architecture from
    `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/abs/2103.14030>`_.

    Args:
        weights (:class:`~torchvision.models.Swin_T_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.Swin_T_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.Swin_T_Weights
        :members:
    """
    weights = Swin_T_Weights.verify(weights)

    return _swin_transformer(
        patch_size=[1, 1],
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[7, 7],
        stochastic_depth_prob=0.2,
        weights=weights,
        progress=progress,
        **kwargs,
    )


def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
    """
    Constructs a swin_small architecture from
    `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/abs/2103.14030>`_.

    Args:
        weights (:class:`~torchvision.models.Swin_S_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.Swin_S_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.Swin_S_Weights
        :members:
    """
    weights = Swin_S_Weights.verify(weights)

    return _swin_transformer(
        patch_size=[4, 4],
        embed_dim=96,
        depths=[2, 2, 18, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[7, 7],
        stochastic_depth_prob=0.3,
        weights=weights,
        progress=progress,
        **kwargs,
    )


def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
    """
    Constructs a swin_base architecture from
    `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/abs/2103.14030>`_.

    Args:
        weights (:class:`~torchvision.models.Swin_B_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.Swin_B_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.Swin_B_Weights
        :members:
    """
    weights = Swin_B_Weights.verify(weights)

    return _swin_transformer(
        patch_size=[4, 4],
        embed_dim=128,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        window_size=[7, 7],
        stochastic_depth_prob=0.5,
        weights=weights,
        progress=progress,
        **kwargs,
    )


def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
    """
    Constructs a swin_v2_tiny architecture from
    `Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/abs/2111.09883>`_.

    Args:
        weights (:class:`~torchvision.models.Swin_V2_T_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.Swin_V2_T_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.Swin_V2_T_Weights
        :members:
    """
    weights = Swin_V2_T_Weights.verify(weights)

    return _swin_transformer(
        patch_size=[2, 2],
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[8, 8],
        stochastic_depth_prob=0.2,
        weights=weights,
        progress=progress,
        block=SwinTransformerBlockV2,
        downsample_layer=PatchMergingV2,
        **kwargs,
    )


def swin_v2_s(*, weights: Optional[Swin_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
    """
    Constructs a swin_v2_small architecture from
    `Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/abs/2111.09883>`_.

    Args:
        weights (:class:`~torchvision.models.Swin_V2_S_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.Swin_V2_S_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.Swin_V2_S_Weights
        :members:
    """
    weights = Swin_V2_S_Weights.verify(weights)

    return _swin_transformer(
        patch_size=[4, 4],
        embed_dim=96,
        depths=[2, 2, 18, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[8, 8],
        stochastic_depth_prob=0.3,
        weights=weights,
        progress=progress,
        block=SwinTransformerBlockV2,
        downsample_layer=PatchMergingV2,
        **kwargs,
    )


def swin_v2_b(*, weights: Optional[Swin_V2_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
    """
    Constructs a swin_v2_base architecture from
    `Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/abs/2111.09883>`_.

    Args:
        weights (:class:`~torchvision.models.Swin_V2_B_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.Swin_V2_B_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.Swin_V2_B_Weights
        :members:
    """
    weights = Swin_V2_B_Weights.verify(weights)

    return _swin_transformer(
        patch_size=[4, 4],
        embed_dim=128,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        window_size=[8, 8],
        stochastic_depth_prob=0.5,
        weights=weights,
        progress=progress,
        block=SwinTransformerBlockV2,
        downsample_layer=PatchMergingV2,
        **kwargs,
    )


if __name__ == '__main__':
    input = torch.randn(3, 3, 512, 512)
    model = swin_v2_t(weights=None)
    out = model(input)
    print(f'out:{out.shape}')
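
    # Optional extra smoke test (an illustrative addition, not part of the original demo): the
    # swin_t builder in this file uses patch_size=[1, 1], so every pixel becomes a token and a
    # 512x512 input would be very memory-hungry; a small input keeps this check cheap.
    small_input = torch.randn(1, 3, 64, 64)
    tiny_model = swin_t(weights=None)
    tiny_out = tiny_model(small_input)
    print(f'swin_t out:{tiny_out.shape}')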