# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from typing import List, Optional, Tuple, Type

import torch
from torch import nn

from ultralytics.nn.modules import MLP, LayerNorm2d


class MaskDecoder(nn.Module):
    """
    Decoder module for generating masks and their associated quality scores using a transformer architecture.

    This class predicts masks given image and prompt embeddings, utilizing a transformer to process the inputs and
    generate mask predictions along with their quality scores.

    Attributes:
        transformer_dim (int): Channel dimension for the transformer module.
        transformer (nn.Module): Transformer module used for mask prediction.
        num_multimask_outputs (int): Number of masks to predict for disambiguating masks.
        iou_token (nn.Embedding): Embedding for the IoU token.
        num_mask_tokens (int): Number of mask tokens.
        mask_tokens (nn.Embedding): Embedding for the mask tokens.
        output_upscaling (nn.Sequential): Neural network sequence for upscaling the output.
        output_hypernetworks_mlps (nn.ModuleList): Hypernetwork MLPs for generating masks.
        iou_prediction_head (nn.Module): MLP for predicting mask quality.

    Methods:
        forward: Predicts masks given image and prompt embeddings.
        predict_masks: Internal method for mask prediction.

    Examples:
        >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)
        >>> masks, iou_pred = decoder(
        ...     image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, multimask_output=True
        ... )
        >>> print(f"Predicted masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}")
    """

    def __init__(
        self,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
    ) -> None:
        """
        Initializes the MaskDecoder module for generating masks and their quality scores.

        Args:
            transformer_dim (int): Channel dimension for the transformer module.
            transformer (nn.Module): Transformer module used for mask prediction.
            num_multimask_outputs (int): Number of masks to predict for disambiguating masks.
            activation (Type[nn.Module]): Type of activation to use when upscaling masks.
            iou_head_depth (int): Depth of the MLP used to predict mask quality.
            iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.

        Examples:
            >>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6)
            >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer)
            >>> print(decoder)
        """
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
            activation(),
        )
        self.output_hypernetworks_mlps = nn.ModuleList(
            [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
        )

        self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth)

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predicts masks given image and prompt embeddings.

        Args:
            image_embeddings (torch.Tensor): Embeddings from the image encoder.
            image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings.
            sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes.
            dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs.
            multimask_output (bool): Whether to return multiple masks or a single mask.

        Returns:
            (Tuple[torch.Tensor, torch.Tensor]): A tuple containing:
                - masks (torch.Tensor): Batched predicted masks.
                - iou_pred (torch.Tensor): Batched predictions of mask quality.

        Examples:
            >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)
            >>> image_emb = torch.rand(1, 256, 64, 64)
            >>> image_pe = torch.rand(1, 256, 64, 64)
            >>> sparse_emb = torch.rand(1, 2, 256)
            >>> dense_emb = torch.rand(1, 256, 64, 64)
            >>> masks, iou_pred = decoder(image_emb, image_pe, sparse_emb, dense_emb, multimask_output=True)
            >>> print(f"Masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}")
        """
        masks, iou_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
        )

        # Select the correct mask or masks for output
        mask_slice = slice(1, None) if multimask_output else slice(0, 1)
        masks = masks[:, mask_slice, :, :]
        iou_pred = iou_pred[:, mask_slice]

        # Prepare output
        return masks, iou_pred

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predicts masks and quality scores using image and prompt embeddings via transformer architecture."""
        # Concatenate output tokens
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.shape[0], -1, -1)
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        src = src + dense_prompt_embeddings
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)
        upscaled_embedding = self.output_upscaling(src)
        hyper_in_list: List[torch.Tensor] = [
            self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
        ]
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding.shape
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred
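

# Illustrative shape walkthrough for MaskDecoder.predict_masks (a sketch assuming the default transformer_dim=256,
# num_multimask_outputs=3, and a 64x64 image embedding; the concrete numbers are assumptions, not upstream API):
#   tokens             -> (B, 1 + 4 + N_prompts, 256)   # IoU token, 4 mask tokens, then sparse prompt tokens
#   upscaled_embedding -> (B, 32, 256, 256)             # two stride-2 ConvTranspose2d layers give a 4x upscale
#   hyper_in           -> (B, 4, 32)                    # one 32-dim weight vector per mask token
#   masks              -> (B, 4, 256, 256)              # hyper_in @ upscaled_embedding.view(B, 32, 256 * 256)
# Each mask logit is a dot product between a token-conditioned hypernetwork output and the upscaled per-pixel
# embedding at that location.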


class SAM2MaskDecoder(nn.Module):
    """
    Transformer-based decoder for predicting instance segmentation masks from image and prompt embeddings.

    This class extends the functionality of the MaskDecoder, incorporating additional features such as
    high-resolution feature processing, dynamic multimask output, and object score prediction.

    Attributes:
        transformer_dim (int): Channel dimension of the transformer.
        transformer (nn.Module): Transformer used to predict masks.
        num_multimask_outputs (int): Number of masks to predict when disambiguating masks.
        iou_token (nn.Embedding): Embedding for the IoU token.
        num_mask_tokens (int): Total number of mask tokens.
        mask_tokens (nn.Embedding): Embedding for mask tokens.
        pred_obj_scores (bool): Whether to predict object scores.
        obj_score_token (nn.Embedding): Embedding for object score token.
        use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.
        output_upscaling (nn.Sequential): Upscaling layers for output.
        use_high_res_features (bool): Whether to use high-resolution features.
        conv_s0 (nn.Conv2d): Convolutional layer for high-resolution features (s0).
        conv_s1 (nn.Conv2d): Convolutional layer for high-resolution features (s1).
        output_hypernetworks_mlps (nn.ModuleList): List of MLPs for output hypernetworks.
        iou_prediction_head (MLP): MLP for IoU prediction.
        pred_obj_score_head (nn.Linear | MLP): Linear layer or MLP for object score prediction.
        dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.
        dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.
        dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.

    Methods:
        forward: Predicts masks given image and prompt embeddings.
        predict_masks: Predicts instance segmentation masks from image and prompt embeddings.
        _get_stability_scores: Computes mask stability scores based on IoU between thresholds.
        _dynamic_multimask_via_stability: Dynamically selects the most stable mask output.

    Examples:
        >>> image_embeddings = torch.rand(1, 256, 64, 64)
        >>> image_pe = torch.rand(1, 256, 64, 64)
        >>> sparse_prompt_embeddings = torch.rand(1, 2, 256)
        >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
        >>> decoder = SAM2MaskDecoder(256, transformer)
        >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
        ...     image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False
        ... )
    """

    def __init__(
        self,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
        use_high_res_features: bool = False,
        iou_prediction_use_sigmoid: bool = False,
        dynamic_multimask_via_stability: bool = False,
        dynamic_multimask_stability_delta: float = 0.05,
        dynamic_multimask_stability_thresh: float = 0.98,
        pred_obj_scores: bool = False,
        pred_obj_scores_mlp: bool = False,
        use_multimask_token_for_obj_ptr: bool = False,
    ) -> None:
        """
        Initializes the SAM2MaskDecoder module for predicting instance segmentation masks.

        This decoder extends the functionality of MaskDecoder, incorporating additional features such as
        high-resolution feature processing, dynamic multimask output, and object score prediction.

        Args:
            transformer_dim (int): Channel dimension of the transformer.
            transformer (nn.Module): Transformer used to predict masks.
            num_multimask_outputs (int): Number of masks to predict when disambiguating masks.
            activation (Type[nn.Module]): Type of activation to use when upscaling masks.
            iou_head_depth (int): Depth of the MLP used to predict mask quality.
            iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.
            use_high_res_features (bool): Whether to use high-resolution features.
            iou_prediction_use_sigmoid (bool): Whether to use sigmoid for IoU prediction.
            dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.
            dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.
            dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.
            pred_obj_scores (bool): Whether to predict object scores.
            pred_obj_scores_mlp (bool): Whether to use MLP for object score prediction.
            use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.

        Examples:
            >>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6)
            >>> decoder = SAM2MaskDecoder(transformer_dim=256, transformer=transformer)
            >>> print(decoder)
        """
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        self.pred_obj_scores = pred_obj_scores
        if self.pred_obj_scores:
            self.obj_score_token = nn.Embedding(1, transformer_dim)
        self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr

        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
            activation(),
        )
        self.use_high_res_features = use_high_res_features
        if use_high_res_features:
            self.conv_s0 = nn.Conv2d(transformer_dim, transformer_dim // 8, kernel_size=1, stride=1)
            self.conv_s1 = nn.Conv2d(transformer_dim, transformer_dim // 4, kernel_size=1, stride=1)

        self.output_hypernetworks_mlps = nn.ModuleList(
            [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
        )

        self.iou_prediction_head = MLP(
            transformer_dim,
            iou_head_hidden_dim,
            self.num_mask_tokens,
            iou_head_depth,
            sigmoid=iou_prediction_use_sigmoid,
        )
        if self.pred_obj_scores:
            self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
            if pred_obj_scores_mlp:
                self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)

        # When outputting a single mask, optionally we can dynamically fall back to the best
        # multimask output token if the single mask output token gives low stability scores.
        self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
        self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
        self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        repeat_image: bool,
        high_res_features: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predicts masks given image and prompt embeddings.

        Args:
            image_embeddings (torch.Tensor): Embeddings from the image encoder with shape (B, C, H, W).
            image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings (B, C, H, W).
            sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes with shape (B, N, C).
            dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs with shape (B, C, H, W).
            multimask_output (bool): Whether to return multiple masks or a single mask.
            repeat_image (bool): Flag to repeat the image embeddings.
            high_res_features (List[torch.Tensor] | None): Optional high-resolution features.

        Returns:
            (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing:
                - masks (torch.Tensor): Batched predicted masks with shape (B, N, H, W).
                - iou_pred (torch.Tensor): Batched predictions of mask quality with shape (B, N).
                - sam_tokens_out (torch.Tensor): Batched SAM token for mask output with shape (B, N, C).
                - object_score_logits (torch.Tensor): Batched object score logits with shape (B, 1).

        Examples:
            >>> image_embeddings = torch.rand(1, 256, 64, 64)
            >>> image_pe = torch.rand(1, 256, 64, 64)
            >>> sparse_prompt_embeddings = torch.rand(1, 2, 256)
            >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
            >>> decoder = SAM2MaskDecoder(256, transformer)
            >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
            ...     image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False
            ... )
        """
        masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
            repeat_image=repeat_image,
            high_res_features=high_res_features,
        )

        # Select the correct mask or masks for output
        if multimask_output:
            masks = masks[:, 1:, :, :]
            iou_pred = iou_pred[:, 1:]
        elif self.dynamic_multimask_via_stability and not self.training:
            masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
        else:
            masks = masks[:, 0:1, :, :]
            iou_pred = iou_pred[:, 0:1]

        if multimask_output and self.use_multimask_token_for_obj_ptr:
            sam_tokens_out = mask_tokens_out[:, 1:]  # [b, 3, c] shape
        else:
            # Take the mask output token. Here we *always* use the token for single mask output.
            # At test time, even if we track after 1-click (and using multimask_output=True),
            # we still take the single mask token here. The rationale is that we always track
            # after multiple clicks during training, so the past tokens seen during training
            # are always the single mask token (and we'll let it be the object-memory token).
            sam_tokens_out = mask_tokens_out[:, 0:1]  # [b, 1, c] shape

        # Prepare output
        return masks, iou_pred, sam_tokens_out, object_score_logits

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        repeat_image: bool,
        high_res_features: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Predicts instance segmentation masks from image and prompt embeddings using a transformer."""
        # Concatenate output tokens
        s = 0
        if self.pred_obj_scores:
            output_tokens = torch.cat(
                [
                    self.obj_score_token.weight,
                    self.iou_token.weight,
                    self.mask_tokens.weight,
                ],
                dim=0,
            )
            s = 1
        else:
            output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask
        if repeat_image:
            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        else:
            assert image_embeddings.shape[0] == tokens.shape[0]
            src = image_embeddings
        src = src + dense_prompt_embeddings
        assert image_pe.size(0) == 1, "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, s, :]
        mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)
        if not self.use_high_res_features:
            upscaled_embedding = self.output_upscaling(src)
        else:
            dc1, ln1, act1, dc2, act2 = self.output_upscaling
            feat_s0, feat_s1 = high_res_features
            upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
            upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)

        hyper_in_list: List[torch.Tensor] = [
            self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
        ]
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding.shape
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)
        if self.pred_obj_scores:
            assert s == 1
            object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
        else:
            # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
            object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)

        return masks, iou_pred, mask_tokens_out, object_score_logits

    def _get_stability_scores(self, mask_logits: torch.Tensor) -> torch.Tensor:
        """Computes mask stability scores based on IoU between upper and lower thresholds."""
        # Stability is the IoU between the mask binarized at +delta and at -delta: a mask whose area barely
        # changes when the logit threshold shifts is considered stable.
        mask_logits = mask_logits.flatten(-2)
        stability_delta = self.dynamic_multimask_stability_delta
        area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
        area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
        return torch.where(area_u > 0, area_i / area_u, 1.0)

    def _dynamic_multimask_via_stability(
        self, all_mask_logits: torch.Tensor, all_iou_scores: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Dynamically selects the most stable mask output based on stability scores and IoU predictions.

        This method is used when outputting a single mask. If the stability score from the current single-mask
        output (based on output token 0) falls below a threshold, it instead selects from multi-mask outputs
        (based on output tokens 1-3) the mask with the highest predicted IoU score. This ensures a valid mask
        for both clicking and tracking scenarios.

        Args:
            all_mask_logits (torch.Tensor): Logits for all predicted masks, shape (B, N, H, W) where B is
                batch size, N is number of masks (typically 4), and H, W are mask dimensions.
            all_iou_scores (torch.Tensor): Predicted IoU scores for all masks, shape (B, N).

        Returns:
            (Tuple[torch.Tensor, torch.Tensor]):
                - mask_logits_out (torch.Tensor): Selected mask logits, shape (B, 1, H, W).
                - iou_scores_out (torch.Tensor): Selected IoU scores, shape (B, 1).

        Examples:
            >>> decoder = SAM2MaskDecoder(...)
            >>> all_mask_logits = torch.rand(2, 4, 256, 256)  # 2 images, 4 masks each
            >>> all_iou_scores = torch.rand(2, 4)
            >>> mask_logits, iou_scores = decoder._dynamic_multimask_via_stability(all_mask_logits, all_iou_scores)
            >>> print(mask_logits.shape, iou_scores.shape)
            torch.Size([2, 1, 256, 256]) torch.Size([2, 1])
        """
        # The best mask from multimask output tokens (1~3)
        multimask_logits = all_mask_logits[:, 1:, :, :]
        multimask_iou_scores = all_iou_scores[:, 1:]
        best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
        batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device)
        best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
        best_multimask_logits = best_multimask_logits.unsqueeze(1)
        best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
        best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)

        # The mask from singlemask output token 0 and its stability score
        singlemask_logits = all_mask_logits[:, 0:1, :, :]
        singlemask_iou_scores = all_iou_scores[:, 0:1]
        stability_scores = self._get_stability_scores(singlemask_logits)
        is_stable = stability_scores >= self.dynamic_multimask_stability_thresh

        # Dynamically fall back to best multimask output upon low stability scores.
        mask_logits_out = torch.where(
            is_stable[..., None, None].expand_as(singlemask_logits),
            singlemask_logits,
            best_multimask_logits,
        )
        iou_scores_out = torch.where(
            is_stable.expand_as(singlemask_iou_scores),
            singlemask_iou_scores,
            best_multimask_iou_scores,
        )
        return mask_logits_out, iou_scores_out
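

if __name__ == "__main__":
    # Minimal smoke-test sketch: it assumes a TwoWayTransformer is importable from the sibling transformer
    # module (the decoders themselves accept any compatible nn.Module) and only checks that the docstring
    # shapes above hold for random inputs. It is not part of the decoder API.
    from ultralytics.models.sam.modules.transformer import TwoWayTransformer

    decoder = MaskDecoder(
        transformer_dim=256,
        transformer=TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8),
    )
    image_emb = torch.rand(1, 256, 64, 64)  # image encoder output (B, C, H, W)
    image_pe = torch.rand(1, 256, 64, 64)  # dense positional encoding, same shape
    sparse_emb = torch.rand(1, 2, 256)  # e.g. two point/box prompt tokens (B, N, C)
    dense_emb = torch.rand(1, 256, 64, 64)  # mask prompt embedding (B, C, H, W)

    masks, iou_pred = decoder(image_emb, image_pe, sparse_emb, dense_emb, multimask_output=True)
    print(f"MaskDecoder masks: {tuple(masks.shape)}, IoU predictions: {tuple(iou_pred.shape)}")
    # Expected: masks (1, 3, 256, 256), iou_pred (1, 3)

    sam2_decoder = SAM2MaskDecoder(
        transformer_dim=256,
        transformer=TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8),
    )
    masks, iou_pred, sam_tokens_out, obj_score_logits = sam2_decoder(
        image_emb, image_pe, sparse_emb, dense_emb, multimask_output=True, repeat_image=False
    )
    print(f"SAM2MaskDecoder masks: {tuple(masks.shape)}, object score logits: {tuple(obj_score_logits.shape)}")
    # Expected: masks (1, 3, 256, 256), obj_score_logits (1, 1)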