encoders.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. from typing import List, Optional, Tuple, Type
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from ultralytics.nn.modules import LayerNorm2d
  7. from .blocks import (
  8. Block,
  9. CXBlock,
  10. Fuser,
  11. MaskDownSampler,
  12. MultiScaleBlock,
  13. PatchEmbed,
  14. PositionEmbeddingRandom,
  15. PositionEmbeddingSine,
  16. )
  17. class ImageEncoderViT(nn.Module):
  18. """
  19. An image encoder using Vision Transformer (ViT) architecture for encoding images into a compact latent space.
  20. This class processes images by splitting them into patches, applying transformer blocks, and generating a final
  21. encoded representation through a neck module.
  22. Attributes:
  23. img_size (int): Dimension of input images, assumed to be square.
  24. patch_embed (PatchEmbed): Module for patch embedding.
  25. pos_embed (nn.Parameter | None): Absolute positional embedding for patches.
  26. blocks (nn.ModuleList): List of transformer blocks for processing patch embeddings.
  27. neck (nn.Sequential): Neck module to further process the output.
  28. Methods:
  29. forward: Processes input through patch embedding, positional embedding, blocks, and neck.
  30. Examples:
  31. >>> import torch
  32. >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
  33. >>> input_image = torch.randn(1, 3, 224, 224)
  34. >>> output = encoder(input_image)
  35. >>> print(output.shape)
  36. """
  37. def __init__(
  38. self,
  39. img_size: int = 1024,
  40. patch_size: int = 16,
  41. in_chans: int = 3,
  42. embed_dim: int = 768,
  43. depth: int = 12,
  44. num_heads: int = 12,
  45. mlp_ratio: float = 4.0,
  46. out_chans: int = 256,
  47. qkv_bias: bool = True,
  48. norm_layer: Type[nn.Module] = nn.LayerNorm,
  49. act_layer: Type[nn.Module] = nn.GELU,
  50. use_abs_pos: bool = True,
  51. use_rel_pos: bool = False,
  52. rel_pos_zero_init: bool = True,
  53. window_size: int = 0,
  54. global_attn_indexes: Tuple[int, ...] = (),
  55. ) -> None:
  56. """
  57. Initializes an ImageEncoderViT instance for encoding images using Vision Transformer architecture.
  58. Args:
  59. img_size (int): Input image size, assumed to be square.
  60. patch_size (int): Size of image patches.
  61. in_chans (int): Number of input image channels.
  62. embed_dim (int): Dimension of patch embeddings.
  63. depth (int): Number of transformer blocks.
  64. num_heads (int): Number of attention heads in each block.
  65. mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
  66. out_chans (int): Number of output channels from the neck module.
  67. qkv_bias (bool): If True, adds learnable bias to query, key, value projections.
  68. norm_layer (Type[nn.Module]): Type of normalization layer to use.
  69. act_layer (Type[nn.Module]): Type of activation layer to use.
  70. use_abs_pos (bool): If True, uses absolute positional embeddings.
  71. use_rel_pos (bool): If True, adds relative positional embeddings to attention maps.
  72. rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.
  73. window_size (int): Size of attention window for windowed attention blocks.
  74. global_attn_indexes (Tuple[int, ...]): Indices of blocks that use global attention.
  75. Attributes:
  76. img_size (int): Dimension of input images.
  77. patch_embed (PatchEmbed): Module for patch embedding.
  78. pos_embed (nn.Parameter | None): Absolute positional embedding for patches.
  79. blocks (nn.ModuleList): List of transformer blocks.
  80. neck (nn.Sequential): Neck module for final processing.
  81. Examples:
  82. >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
  83. >>> input_image = torch.randn(1, 3, 224, 224)
  84. >>> output = encoder(input_image)
  85. >>> print(output.shape)
  86. """
  87. super().__init__()
  88. self.img_size = img_size
  89. self.patch_embed = PatchEmbed(
  90. kernel_size=(patch_size, patch_size),
  91. stride=(patch_size, patch_size),
  92. in_chans=in_chans,
  93. embed_dim=embed_dim,
  94. )
  95. self.pos_embed: Optional[nn.Parameter] = None
  96. if use_abs_pos:
  97. # Initialize absolute positional embedding with pretrain image size.
  98. self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim))
  99. self.blocks = nn.ModuleList()
  100. for i in range(depth):
  101. block = Block(
  102. dim=embed_dim,
  103. num_heads=num_heads,
  104. mlp_ratio=mlp_ratio,
  105. qkv_bias=qkv_bias,
  106. norm_layer=norm_layer,
  107. act_layer=act_layer,
  108. use_rel_pos=use_rel_pos,
  109. rel_pos_zero_init=rel_pos_zero_init,
  110. window_size=window_size if i not in global_attn_indexes else 0,
  111. input_size=(img_size // patch_size, img_size // patch_size),
  112. )
  113. self.blocks.append(block)
  114. self.neck = nn.Sequential(
  115. nn.Conv2d(
  116. embed_dim,
  117. out_chans,
  118. kernel_size=1,
  119. bias=False,
  120. ),
  121. LayerNorm2d(out_chans),
  122. nn.Conv2d(
  123. out_chans,
  124. out_chans,
  125. kernel_size=3,
  126. padding=1,
  127. bias=False,
  128. ),
  129. LayerNorm2d(out_chans),
  130. )
  131. def forward(self, x: torch.Tensor) -> torch.Tensor:
  132. """Processes input through patch embedding, positional embedding, transformer blocks, and neck module."""
  133. x = self.patch_embed(x)
  134. if self.pos_embed is not None:
  135. pos_embed = (
  136. F.interpolate(self.pos_embed.permute(0, 3, 1, 2), scale_factor=self.img_size / 1024).permute(0, 2, 3, 1)
  137. if self.img_size != 1024
  138. else self.pos_embed
  139. )
  140. x = x + pos_embed
  141. for blk in self.blocks:
  142. x = blk(x)
  143. return self.neck(x.permute(0, 3, 1, 2))
  144. class PromptEncoder(nn.Module):
  145. """
  146. Encodes different types of prompts for input to SAM's mask decoder, producing sparse and dense embeddings.
  147. Attributes:
  148. embed_dim (int): Dimension of the embeddings.
  149. input_image_size (Tuple[int, int]): Size of the input image as (H, W).
  150. image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W).
  151. pe_layer (PositionEmbeddingRandom): Module for random position embedding.
  152. num_point_embeddings (int): Number of point embeddings for different types of points.
  153. point_embeddings (nn.ModuleList): List of point embeddings.
  154. not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label.
  155. mask_input_size (Tuple[int, int]): Size of the input mask.
  156. mask_downscaling (nn.Sequential): Neural network for downscaling the mask.
  157. no_mask_embed (nn.Embedding): Embedding for cases where no mask is provided.
  158. Methods:
  159. get_dense_pe: Returns the positional encoding used to encode point prompts.
  160. forward: Embeds different types of prompts, returning both sparse and dense embeddings.
  161. Examples:
  162. >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
  163. >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
  164. >>> boxes = torch.rand(1, 2, 2)
  165. >>> masks = torch.rand(1, 1, 256, 256)
  166. >>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)
  167. >>> print(sparse_embeddings.shape, dense_embeddings.shape)
  168. torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
  169. """
  170. def __init__(
  171. self,
  172. embed_dim: int,
  173. image_embedding_size: Tuple[int, int],
  174. input_image_size: Tuple[int, int],
  175. mask_in_chans: int,
  176. activation: Type[nn.Module] = nn.GELU,
  177. ) -> None:
  178. """
  179. Initializes the PromptEncoder module for encoding various types of prompts.
  180. This module encodes different types of prompts (points, boxes, masks) for input to SAM's mask decoder,
  181. producing both sparse and dense embeddings.
  182. Args:
  183. embed_dim (int): The dimension of the embeddings.
  184. image_embedding_size (Tuple[int, int]): The spatial size of the image embedding as (H, W).
  185. input_image_size (Tuple[int, int]): The padded size of the input image as (H, W).
  186. mask_in_chans (int): The number of hidden channels used for encoding input masks.
  187. activation (Type[nn.Module]): The activation function to use when encoding input masks.
  188. Attributes:
  189. embed_dim (int): Dimension of the embeddings.
  190. input_image_size (Tuple[int, int]): Size of the input image as (H, W).
  191. image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W).
  192. pe_layer (PositionEmbeddingRandom): Module for random position embedding.
  193. num_point_embeddings (int): Number of point embeddings for different types of points.
  194. point_embeddings (nn.ModuleList): List of point embeddings.
  195. not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label.
  196. mask_input_size (Tuple[int, int]): Size of the input mask.
  197. mask_downscaling (nn.Sequential): Neural network for downscaling the mask.
  198. Examples:
  199. >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
  200. >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
  201. >>> boxes = torch.rand(1, 2, 2)
  202. >>> masks = torch.rand(1, 1, 256, 256)
  203. >>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)
  204. >>> print(sparse_embeddings.shape, dense_embeddings.shape)
  205. torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
  206. """
  207. super().__init__()
  208. self.embed_dim = embed_dim
  209. self.input_image_size = input_image_size
  210. self.image_embedding_size = image_embedding_size
  211. self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
  212. self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
  213. point_embeddings = [nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)]
  214. self.point_embeddings = nn.ModuleList(point_embeddings)
  215. self.not_a_point_embed = nn.Embedding(1, embed_dim)
  216. self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
  217. self.mask_downscaling = nn.Sequential(
  218. nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
  219. LayerNorm2d(mask_in_chans // 4),
  220. activation(),
  221. nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
  222. LayerNorm2d(mask_in_chans),
  223. activation(),
  224. nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
  225. )
  226. self.no_mask_embed = nn.Embedding(1, embed_dim)
  227. def get_dense_pe(self) -> torch.Tensor:
  228. """
  229. Returns the dense positional encoding used for encoding point prompts.
  230. This method generates a positional encoding for a dense set of points matching the shape of the image
  231. encoding. The encoding is used to provide spatial information to the model when processing point prompts.
  232. Returns:
  233. (torch.Tensor): Positional encoding tensor with shape (1, embed_dim, H, W), where H and W are the
  234. height and width of the image embedding size, respectively.
  235. Examples:
  236. >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
  237. >>> dense_pe = prompt_encoder.get_dense_pe()
  238. >>> print(dense_pe.shape)
  239. torch.Size([1, 256, 64, 64])
  240. """
  241. return self.pe_layer(self.image_embedding_size).unsqueeze(0)
  242. def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
  243. """Embeds point prompts by applying positional encoding and label-specific embeddings."""
  244. points = points + 0.5 # Shift to center of pixel
  245. if pad:
  246. padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
  247. padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
  248. points = torch.cat([points, padding_point], dim=1)
  249. labels = torch.cat([labels, padding_label], dim=1)
  250. point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
  251. point_embedding[labels == -1] = 0.0
  252. point_embedding[labels == -1] += self.not_a_point_embed.weight
  253. point_embedding[labels == 0] += self.point_embeddings[0].weight
  254. point_embedding[labels == 1] += self.point_embeddings[1].weight
  255. point_embedding[labels == 2] += self.point_embeddings[2].weight
  256. point_embedding[labels == 3] += self.point_embeddings[3].weight
  257. return point_embedding
  258. def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
  259. """Embeds box prompts by applying positional encoding and adding corner embeddings."""
  260. boxes = boxes + 0.5 # Shift to center of pixel
  261. coords = boxes.reshape(-1, 2, 2)
  262. corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
  263. corner_embedding[:, 0, :] += self.point_embeddings[2].weight
  264. corner_embedding[:, 1, :] += self.point_embeddings[3].weight
  265. return corner_embedding
  266. def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
  267. """Embeds mask inputs by downscaling and processing through convolutional layers."""
  268. return self.mask_downscaling(masks)
  269. @staticmethod
  270. def _get_batch_size(
  271. points: Optional[Tuple[torch.Tensor, torch.Tensor]],
  272. boxes: Optional[torch.Tensor],
  273. masks: Optional[torch.Tensor],
  274. ) -> int:
  275. """Gets the batch size of the output given the batch size of the input prompts."""
  276. if points is not None:
  277. return points[0].shape[0]
  278. elif boxes is not None:
  279. return boxes.shape[0]
  280. elif masks is not None:
  281. return masks.shape[0]
  282. else:
  283. return 1
  284. def _get_device(self) -> torch.device:
  285. """Returns the device of the first point embedding's weight tensor."""
  286. return self.point_embeddings[0].weight.device
  287. def forward(
  288. self,
  289. points: Optional[Tuple[torch.Tensor, torch.Tensor]],
  290. boxes: Optional[torch.Tensor],
  291. masks: Optional[torch.Tensor],
  292. ) -> Tuple[torch.Tensor, torch.Tensor]:
  293. """
  294. Embeds different types of prompts, returning both sparse and dense embeddings.
  295. Args:
  296. points (Tuple[torch.Tensor, torch.Tensor] | None): Point coordinates and labels to embed. The first
  297. tensor contains coordinates with shape (B, N, 2), and the second tensor contains labels with
  298. shape (B, N).
  299. boxes (torch.Tensor | None): Boxes to embed with shape (B, M, 2, 2), where M is the number of boxes.
  300. masks (torch.Tensor | None): Masks to embed with shape (B, 1, H, W).
  301. Returns:
  302. (Tuple[torch.Tensor, torch.Tensor]): A tuple containing:
  303. - sparse_embeddings (torch.Tensor): Sparse embeddings for points and boxes with shape (B, N, embed_dim).
  304. - dense_embeddings (torch.Tensor): Dense embeddings for masks of shape (B, embed_dim, embed_H, embed_W).
  305. Examples:
  306. >>> encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
  307. >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
  308. >>> boxes = torch.rand(1, 2, 2, 2)
  309. >>> masks = torch.rand(1, 1, 256, 256)
  310. >>> sparse_emb, dense_emb = encoder(points, boxes, masks)
  311. >>> print(sparse_emb.shape, dense_emb.shape)
  312. torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
  313. """
  314. bs = self._get_batch_size(points, boxes, masks)
  315. sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
  316. if points is not None:
  317. coords, labels = points
  318. point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
  319. sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
  320. if boxes is not None:
  321. box_embeddings = self._embed_boxes(boxes)
  322. sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
  323. if masks is not None:
  324. dense_embeddings = self._embed_masks(masks)
  325. else:
  326. dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
  327. bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
  328. )
  329. return sparse_embeddings, dense_embeddings
  330. class MemoryEncoder(nn.Module):
  331. """
  332. Encodes pixel features and masks into a memory representation for efficient image segmentation.
  333. This class processes pixel-level features and masks, fusing them to generate encoded memory representations
  334. suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).
  335. Attributes:
  336. mask_downsampler (MaskDownSampler): Module for downsampling input masks.
  337. pix_feat_proj (nn.Conv2d): Convolutional layer for projecting pixel features.
  338. fuser (Fuser): Module for fusing pixel features and masks.
  339. position_encoding (PositionEmbeddingSine): Module for adding positional encoding to features.
  340. out_proj (nn.Module): Output projection layer, either nn.Identity or nn.Conv2d.
  341. Methods:
  342. forward: Processes input pixel features and masks to generate encoded memory representations.
  343. Examples:
  344. >>> import torch
  345. >>> encoder = MemoryEncoder(out_dim=256, in_dim=256)
  346. >>> pix_feat = torch.randn(1, 256, 64, 64)
  347. >>> masks = torch.randn(1, 1, 64, 64)
  348. >>> encoded_feat, pos = encoder(pix_feat, masks)
  349. >>> print(encoded_feat.shape, pos.shape)
  350. torch.Size([1, 256, 64, 64]) torch.Size([1, 128, 64, 64])
  351. """
  352. def __init__(
  353. self,
  354. out_dim,
  355. in_dim=256, # in_dim of pix_feats
  356. ):
  357. """Initializes the MemoryEncoder for encoding pixel features and masks into memory representations."""
  358. super().__init__()
  359. self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)
  360. self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
  361. self.fuser = Fuser(CXBlock(dim=256), num_layers=2)
  362. self.position_encoding = PositionEmbeddingSine(num_pos_feats=64)
  363. self.out_proj = nn.Identity()
  364. if out_dim != in_dim:
  365. self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
  366. def forward(
  367. self,
  368. pix_feat: torch.Tensor,
  369. masks: torch.Tensor,
  370. skip_mask_sigmoid: bool = False,
  371. ) -> Tuple[torch.Tensor, torch.Tensor]:
  372. """Processes pixel features and masks to generate encoded memory representations for segmentation."""
  373. if not skip_mask_sigmoid:
  374. masks = F.sigmoid(masks)
  375. masks = self.mask_downsampler(masks)
  376. # Fuse pix_feats and downsampled masks, in case the visual features are on CPU, cast them to CUDA
  377. pix_feat = pix_feat.to(masks.device)
  378. x = self.pix_feat_proj(pix_feat)
  379. x = x + masks
  380. x = self.fuser(x)
  381. x = self.out_proj(x)
  382. pos = self.position_encoding(x).to(x.dtype)
  383. return {"vision_features": x, "vision_pos_enc": [pos]}
  384. class ImageEncoder(nn.Module):
  385. """
  386. Encodes images using a trunk-neck architecture, producing multiscale features and positional encodings.
  387. This class combines a trunk network for feature extraction with a neck network for feature refinement
  388. and positional encoding generation. It can optionally discard the lowest resolution features.
  389. Attributes:
  390. trunk (nn.Module): The trunk network for initial feature extraction.
  391. neck (nn.Module): The neck network for feature refinement and positional encoding generation.
  392. scalp (int): Number of lowest resolution feature levels to discard.
  393. Methods:
  394. forward: Processes the input image through the trunk and neck networks.
  395. Examples:
  396. >>> trunk = SomeTrunkNetwork()
  397. >>> neck = SomeNeckNetwork()
  398. >>> encoder = ImageEncoder(trunk, neck, scalp=1)
  399. >>> image = torch.randn(1, 3, 224, 224)
  400. >>> output = encoder(image)
  401. >>> print(output.keys())
  402. dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn'])
  403. """
  404. def __init__(
  405. self,
  406. trunk: nn.Module,
  407. neck: nn.Module,
  408. scalp: int = 0,
  409. ):
  410. """Initializes the ImageEncoder with trunk and neck networks for feature extraction and refinement."""
  411. super().__init__()
  412. self.trunk = trunk
  413. self.neck = neck
  414. self.scalp = scalp
  415. assert self.trunk.channel_list == self.neck.backbone_channel_list, (
  416. f"Channel dims of trunk {self.trunk.channel_list} and neck {self.neck.backbone_channel_list} do not match."
  417. )
  418. def forward(self, sample: torch.Tensor):
  419. """Encodes input through patch embedding, positional embedding, transformer blocks, and neck module."""
  420. features, pos = self.neck(self.trunk(sample))
  421. if self.scalp > 0:
  422. # Discard the lowest resolution features
  423. features, pos = features[: -self.scalp], pos[: -self.scalp]
  424. src = features[-1]
  425. return {
  426. "vision_features": src,
  427. "vision_pos_enc": pos,
  428. "backbone_fpn": features,
  429. }
  430. class FpnNeck(nn.Module):
  431. """
  432. A Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models.
  433. This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
  434. similar to ViT positional embedding interpolation.
  435. Attributes:
  436. position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding module.
  437. convs (nn.ModuleList): List of convolutional layers for each backbone level.
  438. backbone_channel_list (List[int]): List of channel dimensions from the backbone.
  439. fpn_interp_model (str): Interpolation mode for FPN feature resizing.
  440. fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
  441. fpn_top_down_levels (List[int]): Levels to have top-down features in outputs.
  442. Methods:
  443. forward: Performs forward pass through the FPN neck.
  444. Examples:
  445. >>> backbone_channels = [64, 128, 256, 512]
  446. >>> fpn_neck = FpnNeck(256, backbone_channels)
  447. >>> inputs = [torch.rand(1, c, 32, 32) for c in backbone_channels]
  448. >>> outputs, positions = fpn_neck(inputs)
  449. >>> print(len(outputs), len(positions))
  450. 4 4
  451. """
  452. def __init__(
  453. self,
  454. d_model: int,
  455. backbone_channel_list: List[int],
  456. kernel_size: int = 1,
  457. stride: int = 1,
  458. padding: int = 0,
  459. fpn_interp_model: str = "bilinear",
  460. fuse_type: str = "sum",
  461. fpn_top_down_levels: Optional[List[int]] = None,
  462. ):
  463. """
  464. Initializes a modified Feature Pyramid Network (FPN) neck.
  465. This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
  466. similar to ViT positional embedding interpolation.
  467. Args:
  468. d_model (int): Dimension of the model.
  469. backbone_channel_list (List[int]): List of channel dimensions from the backbone.
  470. kernel_size (int): Kernel size for the convolutional layers.
  471. stride (int): Stride for the convolutional layers.
  472. padding (int): Padding for the convolutional layers.
  473. fpn_interp_model (str): Interpolation mode for FPN feature resizing.
  474. fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
  475. fpn_top_down_levels (Optional[List[int]]): Levels to have top-down features in outputs.
  476. Examples:
  477. >>> backbone_channels = [64, 128, 256, 512]
  478. >>> fpn_neck = FpnNeck(256, backbone_channels)
  479. >>> print(fpn_neck)
  480. """
  481. super().__init__()
  482. self.position_encoding = PositionEmbeddingSine(num_pos_feats=256)
  483. self.convs = nn.ModuleList()
  484. self.backbone_channel_list = backbone_channel_list
  485. for dim in backbone_channel_list:
  486. current = nn.Sequential()
  487. current.add_module(
  488. "conv",
  489. nn.Conv2d(
  490. in_channels=dim,
  491. out_channels=d_model,
  492. kernel_size=kernel_size,
  493. stride=stride,
  494. padding=padding,
  495. ),
  496. )
  497. self.convs.append(current)
  498. self.fpn_interp_model = fpn_interp_model
  499. assert fuse_type in {"sum", "avg"}
  500. self.fuse_type = fuse_type
  501. # levels to have top-down features in its outputs
  502. # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
  503. # have top-down propagation, while outputs of level 0 and level 1 have only
  504. # lateral features from the same backbone level.
  505. if fpn_top_down_levels is None:
  506. # default is to have top-down features on all levels
  507. fpn_top_down_levels = range(len(self.convs))
  508. self.fpn_top_down_levels = list(fpn_top_down_levels)
  509. def forward(self, xs: List[torch.Tensor]):
  510. """
  511. Performs forward pass through the Feature Pyramid Network (FPN) neck.
  512. This method processes a list of input tensors from the backbone through the FPN, applying lateral connections
  513. and top-down feature fusion. It generates output feature maps and corresponding positional encodings.
  514. Args:
  515. xs (List[torch.Tensor]): List of input tensors from the backbone, each with shape (B, C, H, W).
  516. Returns:
  517. (Tuple[List[torch.Tensor], List[torch.Tensor]]): A tuple containing:
  518. - out (List[torch.Tensor]): List of output feature maps after FPN processing, each with shape
  519. (B, d_model, H, W).
  520. - pos (List[torch.Tensor]): List of positional encodings corresponding to each output feature map.
  521. Examples:
  522. >>> fpn_neck = FpnNeck(d_model=256, backbone_channel_list=[64, 128, 256, 512])
  523. >>> inputs = [torch.rand(1, c, 32, 32) for c in [64, 128, 256, 512]]
  524. >>> outputs, positions = fpn_neck(inputs)
  525. >>> print(len(outputs), len(positions))
  526. 4 4
  527. """
  528. out = [None] * len(self.convs)
  529. pos = [None] * len(self.convs)
  530. assert len(xs) == len(self.convs)
  531. # fpn forward pass
  532. # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
  533. prev_features = None
  534. # forward in top-down order (from low to high resolution)
  535. n = len(self.convs) - 1
  536. for i in range(n, -1, -1):
  537. x = xs[i]
  538. lateral_features = self.convs[n - i](x)
  539. if i in self.fpn_top_down_levels and prev_features is not None:
  540. top_down_features = F.interpolate(
  541. prev_features.to(dtype=torch.float32),
  542. scale_factor=2.0,
  543. mode=self.fpn_interp_model,
  544. align_corners=(None if self.fpn_interp_model == "nearest" else False),
  545. antialias=False,
  546. )
  547. prev_features = lateral_features + top_down_features
  548. if self.fuse_type == "avg":
  549. prev_features /= 2
  550. else:
  551. prev_features = lateral_features
  552. x_out = prev_features
  553. out[i] = x_out
  554. pos[i] = self.position_encoding(x_out).to(x_out.dtype)
  555. return out, pos
class Hiera(nn.Module):
    """
    Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks.

    This class implements a Hiera model, which is a hierarchical vision transformer architecture designed for
    efficient multiscale feature extraction. It uses a series of transformer blocks organized into stages,
    with optional pooling and global attention mechanisms.

    Attributes:
        window_spec (Tuple[int, ...]): Window sizes for each stage.
        q_stride (Tuple[int, int]): Downsampling stride between stages.
        stage_ends (List[int]): Indices of the last block in each stage.
        q_pool_blocks (List[int]): Indices of blocks where pooling is applied (first block of each of the first
            q_pool stage transitions).
        return_interm_layers (bool): Whether to return intermediate layer outputs.
        patch_embed (PatchEmbed): Module for patch embedding.
        global_att_blocks (Tuple[int, ...]): Indices of blocks with global attention.
        window_pos_embed_bkg_spatial_size (Tuple[int, int]): Spatial size for window positional embedding background.
        pos_embed (nn.Parameter): Background positional embedding, interpolated to the token grid at runtime.
        pos_embed_window (nn.Parameter): Window positional embedding of size window_spec[0], tiled over the grid.
        blocks (nn.ModuleList): List of MultiScaleBlock modules.
        channel_list (List[int]): Output channel dims per stage, ordered from the LAST stage to the first
            (or only the final block's dim when return_interm_layers is False).

    Methods:
        _get_pos_embed: Generates positional embeddings by interpolating and combining window and background embeddings.
        forward: Performs the forward pass through the Hiera model.

    Examples:
        >>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
        >>> input_tensor = torch.randn(1, 3, 224, 224)
        >>> output_features = model(input_tensor)
        >>> for feat in output_features:
        ...     print(feat.shape)
    """

    def __init__(
        self,
        embed_dim: int = 96,  # initial embed dim
        num_heads: int = 1,  # initial number of heads
        drop_path_rate: float = 0.0,  # stochastic depth
        q_pool: int = 3,  # number of q_pool stages
        q_stride: Tuple[int, int] = (2, 2),  # downsample stride bet. stages
        stages: Tuple[int, ...] = (2, 3, 16, 3),  # blocks per stage
        dim_mul: float = 2.0,  # dim_mul factor at stage shift
        head_mul: float = 2.0,  # head_mul factor at stage shift
        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
        # window size per stage, when not using global att.
        window_spec: Tuple[int, ...] = (
            8,
            4,
            14,
            7,
        ),
        # global attn in these blocks
        global_att_blocks: Tuple[int, ...] = (
            12,
            16,
            20,
        ),
        return_interm_layers: bool = True,  # return feats from every stage
    ) -> None:
        """
        Initializes the Hiera model, configuring its hierarchical vision transformer architecture.

        Args:
            embed_dim (int): Channel width of the first stage; multiplied by dim_mul at every stage transition.
            num_heads (int): Attention heads in the first stage; multiplied by head_mul at every stage transition.
            drop_path_rate (float): Maximum stochastic-depth rate, linearly ramped from 0 across all blocks.
            q_pool (int): Number of stage transitions that apply query pooling with q_stride.
            q_stride (Tuple[int, int]): Query-pooling stride used at those transitions.
            stages (Tuple[int, ...]): Number of blocks in each stage.
            dim_mul (float): Channel multiplier applied at each stage transition.
            head_mul (float): Head-count multiplier applied at each stage transition.
            window_pos_embed_bkg_spatial_size (Tuple[int, int]): Spatial size of the background positional embedding.
            window_spec (Tuple[int, ...]): Attention window size per stage; must have one entry per stage.
            global_att_blocks (Tuple[int, ...]): Block indices that use global attention (window size 0).
            return_interm_layers (bool): If True, forward returns the output of every stage, else only the last.
        """
        super().__init__()
        assert len(stages) == len(window_spec)
        self.window_spec = window_spec
        depth = sum(stages)
        self.q_stride = q_stride
        # Index of the last block of each stage, e.g. stages (2, 3, 16, 3) -> [1, 4, 20, 23].
        self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
        assert 0 <= q_pool <= len(self.stage_ends[:-1])
        # Query pooling happens in the first block of a new stage, for the first q_pool transitions only.
        self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
        self.return_interm_layers = return_interm_layers
        self.patch_embed = PatchEmbed(
            embed_dim=embed_dim,
            kernel_size=(7, 7),
            stride=(4, 4),
            padding=(3, 3),
        )
        # Which blocks have global att?
        self.global_att_blocks = global_att_blocks
        # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
        self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
        self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size))
        self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]))
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        cur_stage = 1
        self.blocks = nn.ModuleList()
        for i in range(depth):
            dim_out = embed_dim
            # lags by a block, so first block of
            # next stage uses an initial window size
            # of previous stage and final window size of current stage
            window_size = self.window_spec[cur_stage - 1]
            if self.global_att_blocks is not None:
                window_size = 0 if i in self.global_att_blocks else window_size
            # First block of a new stage: widen the output and increase heads; the block's input dim stays the
            # previous stage's width (embed_dim is only updated after the block is built).
            if i - 1 in self.stage_ends:
                dim_out = int(embed_dim * dim_mul)
                num_heads = int(num_heads * head_mul)
                cur_stage += 1
            block = MultiScaleBlock(
                dim=embed_dim,
                dim_out=dim_out,
                num_heads=num_heads,
                drop_path=dpr[i],
                q_stride=self.q_stride if i in self.q_pool_blocks else None,
                window_size=window_size,
            )
            embed_dim = dim_out
            self.blocks.append(block)
        # Reversed stage order (last stage first), which is the order FpnNeck's backbone_channel_list expects.
        self.channel_list = (
            [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
            if return_interm_layers
            else [self.blocks[-1].dim_out]
        )

    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
        """
        Generates positional embeddings by interpolating and combining window and background embeddings.

        Args:
            hw (Tuple[int, int]): Token-grid height and width.

        Returns:
            (torch.Tensor): Channels-last positional embedding of shape (1, h, w, C).
        """
        h, w = hw
        window_embed = self.pos_embed_window
        # Resize the background embedding to the token grid.
        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
        # Tile the window embedding across the grid using floor-divided repeat counts; presumably h and w must be
        # multiples of window_spec[0], otherwise the tiled tensor is smaller than the grid and the addition fails.
        pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])
        pos_embed = pos_embed.permute(0, 2, 3, 1)
        return pos_embed

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Performs forward pass through Hiera model, extracting multiscale features from input images.

        Args:
            x (torch.Tensor): Input image batch.

        Returns:
            (List[torch.Tensor]): Channels-first (B, C, H, W) outputs of each stage's last block in stage order,
                or only the final stage's output when return_interm_layers is False.
        """
        x = self.patch_embed(x)
        # x: (B, H, W, C)

        # Add pos embed
        x = x + self._get_pos_embed(x.shape[1:3])

        outputs = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            # Always emit the final block; emit other stage ends only when intermediate layers are requested.
            if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers):
                feats = x.permute(0, 3, 1, 2)
                outputs.append(feats)

        return outputs