sam.py 52 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. # Copyright (c) Meta Platforms, Inc. and affiliates.
  3. # All rights reserved.
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. from typing import List
  7. import torch
  8. import torch.nn.functional as F
  9. from torch import nn
  10. from torch.nn.init import trunc_normal_
  11. from ultralytics.nn.modules import MLP
  12. from .blocks import SAM2TwoWayTransformer
  13. from .decoders import MaskDecoder, SAM2MaskDecoder
  14. from .encoders import ImageEncoderViT, PromptEncoder
  15. from .utils import get_1d_sine_pe, select_closest_cond_frames
# A large negative value used as the placeholder mask score for missing objects:
# SAM2Model._forward_sam_heads writes it into the low-res mask logits wherever the
# predicted object score logit is not positive.
NO_OBJ_SCORE = -1024.0
  18. class SAMModel(nn.Module):
  19. """
  20. Segment Anything Model (SAM) for object segmentation tasks.
  21. This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images
  22. and input prompts.
  23. Attributes:
  24. mask_threshold (float): Threshold value for mask prediction.
  25. image_encoder (ImageEncoderViT): Backbone for encoding images into embeddings.
  26. prompt_encoder (PromptEncoder): Encoder for various types of input prompts.
  27. mask_decoder (MaskDecoder): Predicts object masks from image and prompt embeddings.
  28. Methods:
  29. __init__: Initializes the SAMModel with encoders, decoder, and normalization parameters.
  30. Examples:
  31. >>> image_encoder = ImageEncoderViT(...)
  32. >>> prompt_encoder = PromptEncoder(...)
  33. >>> mask_decoder = MaskDecoder(...)
  34. >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)
  35. >>> # Further usage depends on SAMPredictor class
  36. Notes:
  37. All forward() operations are implemented in the SAMPredictor class.
  38. """
  39. mask_threshold: float = 0.0
  40. def __init__(
  41. self,
  42. image_encoder: ImageEncoderViT,
  43. prompt_encoder: PromptEncoder,
  44. mask_decoder: MaskDecoder,
  45. pixel_mean: List[float] = (123.675, 116.28, 103.53),
  46. pixel_std: List[float] = (58.395, 57.12, 57.375),
  47. ) -> None:
  48. """
  49. Initialize the SAMModel class to predict object masks from an image and input prompts.
  50. Args:
  51. image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings.
  52. prompt_encoder (PromptEncoder): Encodes various types of input prompts.
  53. mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
  54. pixel_mean (List[float]): Mean values for normalizing pixels in the input image.
  55. pixel_std (List[float]): Std values for normalizing pixels in the input image.
  56. Examples:
  57. >>> image_encoder = ImageEncoderViT(...)
  58. >>> prompt_encoder = PromptEncoder(...)
  59. >>> mask_decoder = MaskDecoder(...)
  60. >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)
  61. >>> # Further usage depends on SAMPredictor class
  62. Notes:
  63. All forward() operations moved to SAMPredictor.
  64. """
  65. super().__init__()
  66. self.image_encoder = image_encoder
  67. self.prompt_encoder = prompt_encoder
  68. self.mask_decoder = mask_decoder
  69. self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
  70. self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
  71. def set_imgsz(self, imgsz):
  72. """
  73. Set image size to make model compatible with different image sizes.
  74. Args:
  75. imgsz (Tuple[int, int]): The size of the input image.
  76. """
  77. if hasattr(self.image_encoder, "set_imgsz"):
  78. self.image_encoder.set_imgsz(imgsz)
  79. self.prompt_encoder.input_image_size = imgsz
  80. self.prompt_encoder.image_embedding_size = [x // 16 for x in imgsz] # 16 is fixed as patch size of ViT model
  81. self.image_encoder.img_size = imgsz[0]
  82. class SAM2Model(torch.nn.Module):
  83. """
  84. SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.
  85. This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms
  86. for temporal consistency and efficient tracking of objects across frames.
  87. Attributes:
  88. mask_threshold (float): Threshold value for mask prediction.
  89. image_encoder (ImageEncoderViT): Visual encoder for extracting image features.
  90. memory_attention (nn.Module): Module for attending to memory features.
  91. memory_encoder (nn.Module): Encoder for generating memory representations.
  92. num_maskmem (int): Number of accessible memory frames.
  93. image_size (int): Size of input images.
  94. backbone_stride (int): Stride of the backbone network output.
  95. sam_prompt_embed_dim (int): Dimension of SAM prompt embeddings.
  96. sam_image_embedding_size (int): Size of SAM image embeddings.
  97. sam_prompt_encoder (PromptEncoder): Encoder for processing input prompts.
  98. sam_mask_decoder (SAM2MaskDecoder): Decoder for generating object masks.
  99. obj_ptr_proj (nn.Module): Projection layer for object pointers.
  100. obj_ptr_tpos_proj (nn.Module): Projection for temporal positional encoding in object pointers.
  101. Methods:
  102. forward_image: Processes image batch through encoder to extract multi-level features.
  103. track_step: Performs a single tracking step, updating object masks and memory features.
  104. Examples:
  105. >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)
  106. >>> image_batch = torch.rand(1, 3, 512, 512)
  107. >>> features = model.forward_image(image_batch)
  108. >>> track_results = model.track_step(0, True, features, None, None, None, {})
  109. """
  110. mask_threshold: float = 0.0
  111. def __init__(
  112. self,
  113. image_encoder,
  114. memory_attention,
  115. memory_encoder,
  116. num_maskmem=7,
  117. image_size=512,
  118. backbone_stride=16,
  119. sigmoid_scale_for_mem_enc=1.0,
  120. sigmoid_bias_for_mem_enc=0.0,
  121. binarize_mask_from_pts_for_mem_enc=False,
  122. use_mask_input_as_output_without_sam=False,
  123. max_cond_frames_in_attn=-1,
  124. directly_add_no_mem_embed=False,
  125. use_high_res_features_in_sam=False,
  126. multimask_output_in_sam=False,
  127. multimask_min_pt_num=1,
  128. multimask_max_pt_num=1,
  129. multimask_output_for_tracking=False,
  130. use_multimask_token_for_obj_ptr: bool = False,
  131. iou_prediction_use_sigmoid=False,
  132. memory_temporal_stride_for_eval=1,
  133. non_overlap_masks_for_mem_enc=False,
  134. use_obj_ptrs_in_encoder=False,
  135. max_obj_ptrs_in_encoder=16,
  136. add_tpos_enc_to_obj_ptrs=True,
  137. proj_tpos_enc_in_obj_ptrs=False,
  138. use_signed_tpos_enc_to_obj_ptrs=False,
  139. only_obj_ptrs_in_the_past_for_eval=False,
  140. pred_obj_scores: bool = False,
  141. pred_obj_scores_mlp: bool = False,
  142. fixed_no_obj_ptr: bool = False,
  143. soft_no_obj_ptr: bool = False,
  144. use_mlp_for_obj_ptr_proj: bool = False,
  145. no_obj_embed_spatial: bool = False,
  146. sam_mask_decoder_extra_args=None,
  147. compile_image_encoder: bool = False,
  148. ):
  149. """
  150. Initializes the SAM2Model for video object segmentation with memory-based tracking.
  151. Args:
  152. image_encoder (nn.Module): Visual encoder for extracting image features.
  153. memory_attention (nn.Module): Module for attending to memory features.
  154. memory_encoder (nn.Module): Encoder for generating memory representations.
  155. num_maskmem (int): Number of accessible memory frames. Default is 7 (1 input frame + 6 previous frames).
  156. image_size (int): Size of input images.
  157. backbone_stride (int): Stride of the image backbone output.
  158. sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.
  159. sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.
  160. binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames
  161. with clicks during evaluation.
  162. use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM
  163. prompt encoder and mask decoder on frames with mask input.
  164. max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.
  165. -1 means no limit.
  166. directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the
  167. first frame.
  168. use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.
  169. multimask_output_in_sam (bool): Whether to output multiple (3) masks for the first click on initial
  170. conditioning frames.
  171. multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.
  172. multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.
  173. multimask_output_for_tracking (bool): Whether to use multimask output for tracking.
  174. use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
  175. iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
  176. memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
  177. non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in
  178. memory encoder during evaluation.
  179. use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.
  180. max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder
  181. cross-attention.
  182. add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in
  183. the encoder.
  184. proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
  185. encoding in object pointers.
  186. use_signed_tpos_enc_to_obj_ptrs (bool): whether to use signed distance (instead of unsigned absolute distance)
  187. in the temporal positional encoding in the object pointers, only relevant when both `use_obj_ptrs_in_encoder=True`
  188. and `add_tpos_enc_to_obj_ptrs=True`.
  189. only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past
  190. during evaluation.
  191. pred_obj_scores (bool): Whether to predict if there is an object in the frame.
  192. pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.
  193. fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.
  194. soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation.
  195. use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.
  196. no_obj_embed_spatial (bool): Whether add no obj embedding to spatial frames.
  197. sam_mask_decoder_extra_args (Dict | None): Extra arguments for constructing the SAM mask decoder.
  198. compile_image_encoder (bool): Whether to compile the image encoder for faster inference.
  199. Examples:
  200. >>> image_encoder = ImageEncoderViT(...)
  201. >>> memory_attention = SAM2TwoWayTransformer(...)
  202. >>> memory_encoder = nn.Sequential(...)
  203. >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)
  204. >>> image_batch = torch.rand(1, 3, 512, 512)
  205. >>> features = model.forward_image(image_batch)
  206. >>> track_results = model.track_step(0, True, features, None, None, None, {})
  207. """
  208. super().__init__()
  209. # Part 1: the image backbone
  210. self.image_encoder = image_encoder
  211. # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
  212. self.use_high_res_features_in_sam = use_high_res_features_in_sam
  213. self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
  214. self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
  215. self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
  216. if use_obj_ptrs_in_encoder:
  217. # A conv layer to downsample the mask prompt to stride 4 (the same stride as
  218. # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
  219. # so that it can be fed into the SAM mask decoder to generate a pointer.
  220. self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
  221. self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
  222. if proj_tpos_enc_in_obj_ptrs:
  223. assert add_tpos_enc_to_obj_ptrs # these options need to be used together
  224. self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
  225. self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs
  226. self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
  227. # Part 2: memory attention to condition current frame's visual features
  228. # with memories (and obj ptrs) from past frames
  229. self.memory_attention = memory_attention
  230. self.hidden_dim = memory_attention.d_model
  231. # Part 3: memory encoder for the previous frame's outputs
  232. self.memory_encoder = memory_encoder
  233. self.mem_dim = self.hidden_dim
  234. if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"):
  235. # if there is compression of memories along channel dim
  236. self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
  237. self.num_maskmem = num_maskmem # Number of memories accessible
  238. # Temporal encoding of the memories
  239. self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim))
  240. trunc_normal_(self.maskmem_tpos_enc, std=0.02)
  241. # a single token to indicate no memory embedding from previous frames
  242. self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
  243. self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
  244. trunc_normal_(self.no_mem_embed, std=0.02)
  245. trunc_normal_(self.no_mem_pos_enc, std=0.02)
  246. self.directly_add_no_mem_embed = directly_add_no_mem_embed
  247. # Apply sigmoid to the output raw mask logits (to turn them from
  248. # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
  249. self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
  250. self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
  251. self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
  252. self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
  253. self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
  254. # On frames with mask input, whether to directly output the input mask without
  255. # using a SAM prompt encoder + mask decoder
  256. self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
  257. self.multimask_output_in_sam = multimask_output_in_sam
  258. self.multimask_min_pt_num = multimask_min_pt_num
  259. self.multimask_max_pt_num = multimask_max_pt_num
  260. self.multimask_output_for_tracking = multimask_output_for_tracking
  261. self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
  262. self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid
  263. # Part 4: SAM-style prompt encoder (for both mask and point inputs)
  264. # and SAM-style mask decoder for the final mask output
  265. self.image_size = image_size
  266. self.backbone_stride = backbone_stride
  267. self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
  268. self.pred_obj_scores = pred_obj_scores
  269. self.pred_obj_scores_mlp = pred_obj_scores_mlp
  270. self.fixed_no_obj_ptr = fixed_no_obj_ptr
  271. self.soft_no_obj_ptr = soft_no_obj_ptr
  272. if self.fixed_no_obj_ptr:
  273. assert self.pred_obj_scores
  274. assert self.use_obj_ptrs_in_encoder
  275. if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
  276. self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
  277. trunc_normal_(self.no_obj_ptr, std=0.02)
  278. self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
  279. self.no_obj_embed_spatial = None
  280. if no_obj_embed_spatial:
  281. self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
  282. trunc_normal_(self.no_obj_embed_spatial, std=0.02)
  283. self._build_sam_heads()
  284. self.max_cond_frames_in_attn = max_cond_frames_in_attn
  285. # Model compilation
  286. if compile_image_encoder:
  287. # Compile the forward function (not the full module) to allow loading checkpoints.
  288. print("Image encoder compilation is enabled. First forward pass will be slow.")
  289. self.image_encoder.forward = torch.compile(
  290. self.image_encoder.forward,
  291. mode="max-autotune",
  292. fullgraph=True,
  293. dynamic=False,
  294. )
  295. @property
  296. def device(self):
  297. """Returns the device on which the model's parameters are stored."""
  298. return next(self.parameters()).device
  299. def forward(self, *args, **kwargs):
  300. """Processes image and prompt inputs to generate object masks and scores in video sequences."""
  301. raise NotImplementedError(
  302. "Please use the corresponding methods in SAM2VideoPredictor for inference."
  303. "See notebooks/video_predictor_example.ipynb for an example."
  304. )
  305. def _build_sam_heads(self):
  306. """Builds SAM-style prompt encoder and mask decoder for image segmentation tasks."""
  307. self.sam_prompt_embed_dim = self.hidden_dim
  308. self.sam_image_embedding_size = self.image_size // self.backbone_stride
  309. # Build PromptEncoder and MaskDecoder from SAM (hyperparameters like `mask_in_chans=16` are from SAM code)
  310. self.sam_prompt_encoder = PromptEncoder(
  311. embed_dim=self.sam_prompt_embed_dim,
  312. image_embedding_size=(
  313. self.sam_image_embedding_size,
  314. self.sam_image_embedding_size,
  315. ),
  316. input_image_size=(self.image_size, self.image_size),
  317. mask_in_chans=16,
  318. )
  319. self.sam_mask_decoder = SAM2MaskDecoder(
  320. num_multimask_outputs=3,
  321. transformer=SAM2TwoWayTransformer(
  322. depth=2,
  323. embedding_dim=self.sam_prompt_embed_dim,
  324. mlp_dim=2048,
  325. num_heads=8,
  326. ),
  327. transformer_dim=self.sam_prompt_embed_dim,
  328. iou_head_depth=3,
  329. iou_head_hidden_dim=256,
  330. use_high_res_features=self.use_high_res_features_in_sam,
  331. iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
  332. pred_obj_scores=self.pred_obj_scores,
  333. pred_obj_scores_mlp=self.pred_obj_scores_mlp,
  334. use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
  335. **(self.sam_mask_decoder_extra_args or {}),
  336. )
  337. if self.use_obj_ptrs_in_encoder:
  338. # a linear projection on SAM output tokens to turn them into object pointers
  339. self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
  340. if self.use_mlp_for_obj_ptr_proj:
  341. self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3)
  342. else:
  343. self.obj_ptr_proj = torch.nn.Identity()
  344. if self.proj_tpos_enc_in_obj_ptrs:
  345. # a linear projection on temporal positional encoding in object pointers to
  346. # avoid potential interference with spatial positional encoding
  347. self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
  348. else:
  349. self.obj_ptr_tpos_proj = torch.nn.Identity()
  350. def _forward_sam_heads(
  351. self,
  352. backbone_features,
  353. point_inputs=None,
  354. mask_inputs=None,
  355. high_res_features=None,
  356. multimask_output=False,
  357. ):
  358. """
  359. Forward pass through SAM prompt encoders and mask heads.
  360. This method processes image features and optional point/mask inputs to generate object masks and scores.
  361. Args:
  362. backbone_features (torch.Tensor): Image features with shape (B, C, H, W).
  363. point_inputs (Dict[str, torch.Tensor] | None): Dictionary containing point prompts.
  364. 'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute
  365. pixel-unit coordinates in (x, y) format for P input points.
  366. 'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks,
  367. 0 means negative clicks, and -1 means padding.
  368. mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the
  369. same spatial size as the image.
  370. high_res_features (List[torch.Tensor] | None): List of two feature maps with shapes
  371. (B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps
  372. for SAM decoder.
  373. multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False,
  374. output only 1 mask and its IoU estimate.
  375. Returns:
  376. (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]):
  377. low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
  378. high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits.
  379. ious: Tensor of shape (B, M) with estimated IoU for each output mask.
  380. low_res_masks: Tensor of shape (B, 1, H*4, W*4) with the best low-resolution mask.
  381. high_res_masks: Tensor of shape (B, 1, H*16, W*16) with the best high-resolution mask.
  382. obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask.
  383. object_score_logits: Tensor of shape (B) with object score logits.
  384. Where M is 3 if multimask_output=True, and 1 if multimask_output=False.
  385. Examples:
  386. >>> backbone_features = torch.rand(1, 256, 32, 32)
  387. >>> point_inputs = {"point_coords": torch.rand(1, 2, 2), "point_labels": torch.tensor([[1, 0]])}
  388. >>> mask_inputs = torch.rand(1, 1, 512, 512)
  389. >>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs)
  390. >>> (
  391. ... low_res_multimasks,
  392. ... high_res_multimasks,
  393. ... ious,
  394. ... low_res_masks,
  395. ... high_res_masks,
  396. ... obj_ptr,
  397. ... object_score_logits,
  398. ... ) = results
  399. """
  400. B = backbone_features.size(0)
  401. device = backbone_features.device
  402. assert backbone_features.size(1) == self.sam_prompt_embed_dim
  403. assert backbone_features.size(2) == self.sam_image_embedding_size
  404. assert backbone_features.size(3) == self.sam_image_embedding_size
  405. # a) Handle point prompts
  406. if point_inputs is not None:
  407. sam_point_coords = point_inputs["point_coords"]
  408. sam_point_labels = point_inputs["point_labels"]
  409. assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
  410. else:
  411. # If no points are provide, pad with an empty point (with label -1)
  412. sam_point_coords = torch.zeros(B, 1, 2, device=device)
  413. sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
  414. # b) Handle mask prompts
  415. if mask_inputs is not None:
  416. # If mask_inputs is provided, downsize it into low-res mask input if needed
  417. # and feed it as a dense mask prompt into the SAM mask encoder
  418. assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
  419. if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
  420. sam_mask_prompt = F.interpolate(
  421. mask_inputs.float(),
  422. size=self.sam_prompt_encoder.mask_input_size,
  423. align_corners=False,
  424. mode="bilinear",
  425. antialias=True, # use antialias for downsampling
  426. )
  427. else:
  428. sam_mask_prompt = mask_inputs
  429. else:
  430. # Otherwise, simply feed None (and SAM's prompt encoder will add
  431. # a learned `no_mask_embed` to indicate no mask input in this case).
  432. sam_mask_prompt = None
  433. sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
  434. points=(sam_point_coords, sam_point_labels),
  435. boxes=None,
  436. masks=sam_mask_prompt,
  437. )
  438. low_res_multimasks, ious, sam_output_tokens, object_score_logits = self.sam_mask_decoder(
  439. image_embeddings=backbone_features,
  440. image_pe=self.sam_prompt_encoder.get_dense_pe(),
  441. sparse_prompt_embeddings=sparse_embeddings,
  442. dense_prompt_embeddings=dense_embeddings,
  443. multimask_output=multimask_output,
  444. repeat_image=False, # the image is already batched
  445. high_res_features=high_res_features,
  446. )
  447. if self.pred_obj_scores:
  448. is_obj_appearing = object_score_logits > 0
  449. # Spatial memory mask is a *hard* choice between obj and no obj, consistent with actual mask prediction
  450. low_res_multimasks = torch.where(is_obj_appearing[:, None, None], low_res_multimasks, NO_OBJ_SCORE)
  451. # convert masks from possibly bfloat16 (or float16) to float32
  452. # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
  453. low_res_multimasks = low_res_multimasks.float()
  454. high_res_multimasks = F.interpolate(
  455. low_res_multimasks,
  456. size=(self.image_size, self.image_size),
  457. mode="bilinear",
  458. align_corners=False,
  459. )
  460. sam_output_token = sam_output_tokens[:, 0]
  461. if multimask_output:
  462. # take the best mask prediction (with the highest IoU estimation)
  463. best_iou_inds = torch.argmax(ious, dim=-1)
  464. batch_inds = torch.arange(B, device=device)
  465. low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
  466. high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
  467. if sam_output_tokens.size(1) > 1:
  468. sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
  469. else:
  470. low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
  471. # Extract object pointer from the SAM output token (with occlusion handling)
  472. obj_ptr = self.obj_ptr_proj(sam_output_token)
  473. if self.pred_obj_scores:
  474. # Allow *soft* no obj ptr, unlike for masks
  475. if self.soft_no_obj_ptr:
  476. lambda_is_obj_appearing = object_score_logits.sigmoid()
  477. else:
  478. lambda_is_obj_appearing = is_obj_appearing.float()
  479. if self.fixed_no_obj_ptr:
  480. obj_ptr = lambda_is_obj_appearing * obj_ptr
  481. obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
  482. return (
  483. low_res_multimasks,
  484. high_res_multimasks,
  485. ious,
  486. low_res_masks,
  487. high_res_masks,
  488. obj_ptr,
  489. object_score_logits,
  490. )
  491. def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
  492. """Processes mask inputs directly as output, bypassing SAM encoder/decoder."""
  493. # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
  494. out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
  495. mask_inputs_float = mask_inputs.float()
  496. high_res_masks = mask_inputs_float * out_scale + out_bias
  497. low_res_masks = F.interpolate(
  498. high_res_masks,
  499. size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
  500. align_corners=False,
  501. mode="bilinear",
  502. antialias=True, # use antialias for downsampling
  503. )
  504. # a dummy IoU prediction of all 1's under mask input
  505. ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
  506. if not self.use_obj_ptrs_in_encoder:
  507. # all zeros as a dummy object pointer (of shape [B, C])
  508. obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device)
  509. else:
  510. # produce an object pointer using the SAM decoder from the mask input
  511. _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
  512. backbone_features=backbone_features,
  513. mask_inputs=self.mask_downsample(mask_inputs_float),
  514. high_res_features=high_res_features,
  515. )
  516. # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
  517. # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying
  518. # on the object_scores from the SAM decoder.
  519. is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
  520. is_obj_appearing = is_obj_appearing[..., None]
  521. lambda_is_obj_appearing = is_obj_appearing.float()
  522. object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
  523. if self.pred_obj_scores:
  524. if self.fixed_no_obj_ptr:
  525. obj_ptr = lambda_is_obj_appearing * obj_ptr
  526. obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
  527. return (
  528. low_res_masks,
  529. high_res_masks,
  530. ious,
  531. low_res_masks,
  532. high_res_masks,
  533. obj_ptr,
  534. object_score_logits,
  535. )
  def forward_image(self, img_batch: torch.Tensor):
      """
      Run the image encoder on a batch of images and return its output dictionary.

      When `self.use_high_res_features_in_sam` is set, the first two entries of the
      "backbone_fpn" list are replaced in place by their `conv_s0` / `conv_s1`
      projections from the SAM mask decoder, so the projection is computed once here
      rather than on every SAM click.

      Args:
          img_batch (torch.Tensor): Image batch passed to `self.image_encoder`.

      Returns:
          The (possibly updated) output of `self.image_encoder`.
      """
      encoded = self.image_encoder(img_batch)
      if not self.use_high_res_features_in_sam:
          return encoded

      fpn_levels = encoded["backbone_fpn"]
      decoder = self.sam_mask_decoder
      # level 0 -> conv_s0, level 1 -> conv_s1
      for level, project in enumerate((decoder.conv_s0, decoder.conv_s1)):
          fpn_levels[level] = project(fpn_levels[level])
      return encoded
  545. def _prepare_backbone_features(self, backbone_out):
  546. """Prepares and flattens visual features from the image backbone output for further processing."""
  547. assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
  548. assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
  549. feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
  550. vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
  551. feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
  552. # flatten NxCxHxW to HWxNxC
  553. vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
  554. vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
  555. return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
  556. def _prepare_memory_conditioned_features(
  557. self,
  558. frame_idx,
  559. is_init_cond_frame,
  560. current_vision_feats,
  561. current_vision_pos_embeds,
  562. feat_sizes,
  563. output_dict,
  564. num_frames,
  565. track_in_reverse=False, # tracking in reverse time order (for demo usage)
  566. ):
  567. """Prepares memory-conditioned features by fusing current frame's visual features with previous memories."""
  568. B = current_vision_feats[-1].size(1) # batch size on this frame
  569. C = self.hidden_dim
  570. H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
  571. device = current_vision_feats[-1].device
  572. # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
  573. # In this case, we skip the fusion with any memory.
  574. if self.num_maskmem == 0: # Disable memory and skip fusion
  575. return current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
  576. num_obj_ptr_tokens = 0
  577. tpos_sign_mul = -1 if track_in_reverse else 1
  578. # Step 1: condition the visual features of the current frame on previous memories
  579. if not is_init_cond_frame:
  580. # Retrieve the memories encoded with the maskmem backbone
  581. to_cat_memory, to_cat_memory_pos_embed = [], []
  582. # Add conditioning frame's output first (all cond frames have t_pos=0 for
  583. # when getting temporal positional embedding below)
  584. assert len(output_dict["cond_frame_outputs"]) > 0
  585. # Select a maximum number of temporally closest cond frames for cross attention
  586. cond_outputs = output_dict["cond_frame_outputs"]
  587. selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
  588. frame_idx, cond_outputs, self.max_cond_frames_in_attn
  589. )
  590. t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
  591. # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
  592. # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
  593. # We also allow taking the memory frame non-consecutively (with r>1), in which case
  594. # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame.
  595. r = 1 if self.training else self.memory_temporal_stride_for_eval
  596. for t_pos in range(1, self.num_maskmem):
  597. t_rel = self.num_maskmem - t_pos # how many frames before current frame
  598. if t_rel == 1:
  599. # for t_rel == 1, we take the last frame (regardless of r)
  600. prev_frame_idx = frame_idx + t_rel if track_in_reverse else frame_idx - t_rel
  601. elif not track_in_reverse:
  602. # first find the nearest frame among every r-th frames before this frame
  603. # for r=1, this would be (frame_idx - 2)
  604. prev_frame_idx = ((frame_idx - 2) // r) * r
  605. # then seek further among every r-th frames
  606. prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
  607. else:
  608. # first find the nearest frame among every r-th frames after this frame
  609. # for r=1, this would be (frame_idx + 2)
  610. prev_frame_idx = -(-(frame_idx + 2) // r) * r
  611. # then seek further among every r-th frames
  612. prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
  613. out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
  614. if out is None:
  615. # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
  616. # frames, we still attend to it as if it's a non-conditioning frame.
  617. out = unselected_cond_outputs.get(prev_frame_idx, None)
  618. t_pos_and_prevs.append((t_pos, out))
  619. for t_pos, prev in t_pos_and_prevs:
  620. if prev is None:
  621. continue # skip padding frames
  622. # "maskmem_features" might have been offloaded to CPU in demo use cases,
  623. # so we load it back to inference device (it's a no-op if it's already on device).
  624. feats = prev["maskmem_features"].to(device=device, non_blocking=True)
  625. to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
  626. # Spatial positional encoding (it might have been offloaded to CPU in eval)
  627. maskmem_enc = prev["maskmem_pos_enc"][-1].to(device=device)
  628. maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
  629. # Temporal positional encoding
  630. maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
  631. to_cat_memory_pos_embed.append(maskmem_enc)
  632. # Construct the list of past object pointers
  633. if self.use_obj_ptrs_in_encoder:
  634. max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
  635. # First add those object pointers from selected conditioning frames
  636. # (optionally, only include object pointers in the past during evaluation)
  637. if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
  638. ptr_cond_outputs = {
  639. t: out
  640. for t, out in selected_cond_outputs.items()
  641. if (t >= frame_idx if track_in_reverse else t <= frame_idx)
  642. }
  643. else:
  644. ptr_cond_outputs = selected_cond_outputs
  645. pos_and_ptrs = [
  646. # Temporal pos encoding contains how far away each pointer is from current frame
  647. (
  648. (
  649. (frame_idx - t) * tpos_sign_mul
  650. if self.use_signed_tpos_enc_to_obj_ptrs
  651. else abs(frame_idx - t)
  652. ),
  653. out["obj_ptr"],
  654. )
  655. for t, out in ptr_cond_outputs.items()
  656. ]
  657. # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
  658. for t_diff in range(1, max_obj_ptrs_in_encoder):
  659. t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
  660. if t < 0 or (num_frames is not None and t >= num_frames):
  661. break
  662. out = output_dict["non_cond_frame_outputs"].get(t, unselected_cond_outputs.get(t, None))
  663. if out is not None:
  664. pos_and_ptrs.append((t_diff, out["obj_ptr"]))
  665. # If we have at least one object pointer, add them to the across attention
  666. if pos_and_ptrs:
  667. pos_list, ptrs_list = zip(*pos_and_ptrs)
  668. # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
  669. obj_ptrs = torch.stack(ptrs_list, dim=0)
  670. # a temporal positional embedding based on how far each object pointer is from
  671. # the current frame (sine embedding normalized by the max pointer num).
  672. if self.add_tpos_enc_to_obj_ptrs:
  673. t_diff_max = max_obj_ptrs_in_encoder - 1
  674. tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
  675. obj_pos = torch.tensor(pos_list, device=device)
  676. obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
  677. obj_pos = self.obj_ptr_tpos_proj(obj_pos)
  678. obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
  679. else:
  680. obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
  681. if self.mem_dim < C:
  682. # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
  683. obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim)
  684. obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
  685. obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
  686. to_cat_memory.append(obj_ptrs)
  687. to_cat_memory_pos_embed.append(obj_pos)
  688. num_obj_ptr_tokens = obj_ptrs.shape[0]
  689. else:
  690. num_obj_ptr_tokens = 0
  691. else:
  692. # for initial conditioning frames, encode them without using any previous memory
  693. if self.directly_add_no_mem_embed:
  694. # directly add no-mem embedding (instead of using the transformer encoder)
  695. pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
  696. pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
  697. return pix_feat_with_mem
  698. # Use a dummy token on the first frame (to avoid empty memory input to transformer encoder)
  699. to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
  700. to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
  701. # Step 2: Concatenate the memories and forward through the transformer encoder
  702. memory = torch.cat(to_cat_memory, dim=0)
  703. memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
  704. pix_feat_with_mem = self.memory_attention(
  705. curr=current_vision_feats,
  706. curr_pos=current_vision_pos_embeds,
  707. memory=memory,
  708. memory_pos=memory_pos_embed,
  709. num_obj_ptr_tokens=num_obj_ptr_tokens,
  710. )
  711. # reshape the output (HW)BC => BCHW
  712. pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
  713. return pix_feat_with_mem
  714. def _encode_new_memory(
  715. self,
  716. current_vision_feats,
  717. feat_sizes,
  718. pred_masks_high_res,
  719. object_score_logits,
  720. is_mask_from_pts,
  721. ):
  722. """Encodes frame features and masks into a new memory representation for video segmentation."""
  723. B = current_vision_feats[-1].size(1) # batch size on this frame
  724. C = self.hidden_dim
  725. H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
  726. # top-level feature, (HW)BC => BCHW
  727. pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
  728. if self.non_overlap_masks_for_mem_enc and not self.training:
  729. # optionally, apply non-overlapping constraints to the masks (it's applied
  730. # in the batch dimension and should only be used during eval, where all
  731. # the objects come from the same video under batch size 1).
  732. pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res)
  733. # scale the raw mask logits with a temperature before applying sigmoid
  734. binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
  735. if binarize and not self.training:
  736. mask_for_mem = (pred_masks_high_res > 0).float()
  737. else:
  738. # apply sigmoid on the raw mask logits to turn them into range (0, 1)
  739. mask_for_mem = torch.sigmoid(pred_masks_high_res)
  740. # apply scale and bias terms to the sigmoid probabilities
  741. if self.sigmoid_scale_for_mem_enc != 1.0:
  742. mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
  743. if self.sigmoid_bias_for_mem_enc != 0.0:
  744. mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
  745. maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True) # sigmoid already applied
  746. maskmem_features = maskmem_out["vision_features"]
  747. maskmem_pos_enc = maskmem_out["vision_pos_enc"]
  748. # add a no-object embedding to the spatial memory to indicate that the frame
  749. # is predicted to be occluded (i.e. no object is appearing in the frame)
  750. if self.no_obj_embed_spatial is not None:
  751. is_obj_appearing = (object_score_logits > 0).float()
  752. maskmem_features += (1 - is_obj_appearing[..., None, None]) * self.no_obj_embed_spatial[
  753. ..., None, None
  754. ].expand(*maskmem_features.shape)
  755. return maskmem_features, maskmem_pos_enc
  756. def _track_step(
  757. self,
  758. frame_idx,
  759. is_init_cond_frame,
  760. current_vision_feats,
  761. current_vision_pos_embeds,
  762. feat_sizes,
  763. point_inputs,
  764. mask_inputs,
  765. output_dict,
  766. num_frames,
  767. track_in_reverse,
  768. prev_sam_mask_logits,
  769. ):
  770. """Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
  771. current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
  772. # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
  773. if len(current_vision_feats) > 1:
  774. high_res_features = [
  775. x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
  776. for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
  777. ]
  778. else:
  779. high_res_features = None
  780. if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
  781. # When use_mask_input_as_output_without_sam=True, we directly output the mask input
  782. # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
  783. pix_feat = current_vision_feats[-1].permute(1, 2, 0)
  784. pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
  785. sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
  786. else:
  787. # fused the visual feature with previous memory features in the memory bank
  788. pix_feat = self._prepare_memory_conditioned_features(
  789. frame_idx=frame_idx,
  790. is_init_cond_frame=is_init_cond_frame,
  791. current_vision_feats=current_vision_feats[-1:],
  792. current_vision_pos_embeds=current_vision_pos_embeds[-1:],
  793. feat_sizes=feat_sizes[-1:],
  794. output_dict=output_dict,
  795. num_frames=num_frames,
  796. track_in_reverse=track_in_reverse,
  797. )
  798. # apply SAM-style segmentation head
  799. # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
  800. # e.g. in demo where such logits come from earlier interaction instead of correction sampling
  801. # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
  802. if prev_sam_mask_logits is not None:
  803. assert point_inputs is not None and mask_inputs is None
  804. mask_inputs = prev_sam_mask_logits
  805. multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
  806. sam_outputs = self._forward_sam_heads(
  807. backbone_features=pix_feat,
  808. point_inputs=point_inputs,
  809. mask_inputs=mask_inputs,
  810. high_res_features=high_res_features,
  811. multimask_output=multimask_output,
  812. )
  813. return current_out, sam_outputs, high_res_features, pix_feat
  814. def _encode_memory_in_output(
  815. self,
  816. current_vision_feats,
  817. feat_sizes,
  818. point_inputs,
  819. run_mem_encoder,
  820. high_res_masks,
  821. object_score_logits,
  822. current_out,
  823. ):
  824. """Finally run the memory encoder on the predicted mask to encode, it into a new memory feature (that can be
  825. used in future frames).
  826. """
  827. if run_mem_encoder and self.num_maskmem > 0:
  828. high_res_masks_for_mem_enc = high_res_masks
  829. maskmem_features, maskmem_pos_enc = self._encode_new_memory(
  830. current_vision_feats=current_vision_feats,
  831. feat_sizes=feat_sizes,
  832. pred_masks_high_res=high_res_masks_for_mem_enc,
  833. object_score_logits=object_score_logits,
  834. is_mask_from_pts=(point_inputs is not None),
  835. )
  836. current_out["maskmem_features"] = maskmem_features
  837. current_out["maskmem_pos_enc"] = maskmem_pos_enc
  838. else:
  839. current_out["maskmem_features"] = None
  840. current_out["maskmem_pos_enc"] = None
  def track_step(
      self,
      frame_idx,
      is_init_cond_frame,
      current_vision_feats,
      current_vision_pos_embeds,
      feat_sizes,
      point_inputs,
      mask_inputs,
      output_dict,
      num_frames,
      track_in_reverse=False,  # tracking in reverse time order (for demo usage)
      # Whether to run the memory encoder on the predicted masks. Skipping it with
      # `run_mem_encoder=False` is useful e.g. in the demo, where `track_step` may be called
      # once per user click and the memory is only encoded when the clicks are finalized, or
      # in ablations such as SAM training on static images that need no memory encoder.
      run_mem_encoder=True,
      # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
      prev_sam_mask_logits=None,
  ):
      """
      Run one tracking step on a frame and return its output dictionary.

      Returns:
          (dict): The frame output with "pred_masks", "pred_masks_high_res" and "obj_ptr",
              plus "object_score_logits" outside of training and whatever memory entries
              `self._encode_memory_in_output` adds.
      """
      step_args = (
          frame_idx,
          is_init_cond_frame,
          current_vision_feats,
          current_vision_pos_embeds,
          feat_sizes,
          point_inputs,
          mask_inputs,
          output_dict,
          num_frames,
          track_in_reverse,
          prev_sam_mask_logits,
      )
      current_out, sam_outputs, _, _ = self._track_step(*step_args)

      # only the last four SAM outputs are needed here
      (_, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits) = sam_outputs
      current_out.update(
          pred_masks=low_res_masks,
          pred_masks_high_res=high_res_masks,
          obj_ptr=obj_ptr,
      )
      if not self.training:
          # Only add this in inference (to avoid unused param in activation checkpointing;
          # it's mainly used in the demo to encode spatial memories w/ consolidated masks)
          current_out["object_score_logits"] = object_score_logits

      # Encode the predicted mask into a new memory feature (for use in future frames)
      self._encode_memory_in_output(
          current_vision_feats,
          feat_sizes,
          point_inputs,
          run_mem_encoder,
          high_res_masks,
          object_score_logits,
          current_out,
      )
      return current_out
  895. def _use_multimask(self, is_init_cond_frame, point_inputs):
  896. """Determines whether to use multiple mask outputs in the SAM head based on configuration and inputs."""
  897. num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
  898. return (
  899. self.multimask_output_in_sam
  900. and (is_init_cond_frame or self.multimask_output_for_tracking)
  901. and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
  902. )
  903. @staticmethod
  904. def _apply_non_overlapping_constraints(pred_masks):
  905. """Applies non-overlapping constraints to masks, keeping the highest scoring object per location."""
  906. batch_size = pred_masks.size(0)
  907. if batch_size == 1:
  908. return pred_masks
  909. device = pred_masks.device
  910. # "max_obj_inds": object index of the object with the highest score at each location
  911. max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
  912. # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
  913. batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
  914. keep = max_obj_inds == batch_obj_inds
  915. # suppress overlapping regions' scores below -10.0 so that the foreground regions
  916. # don't overlap (here sigmoid(-10.0)=4.5398e-05)
  917. pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
  918. return pred_masks
  def set_binarize(self, binarize=False):
      """Store `binarize` as the `binarize_mask_from_pts_for_mem_enc` flag (used by VideoPredictor)."""
      setattr(self, "binarize_mask_from_pts_for_mem_enc", binarize)
  def set_imgsz(self, imgsz):
      """
      Update the model and its prompt encoder for a new input image size.

      Args:
          imgsz (Tuple[int, int]): The size of the input image. Its first entry becomes
              `self.image_size`; the prompt encoder's embedding size is each entry
              floor-divided by 16.
      """
      patch_size = 16  # fixed ViT patch size of 16
      self.image_size = imgsz[0]
      prompt_encoder = self.sam_prompt_encoder
      prompt_encoder.input_image_size = imgsz
      prompt_encoder.image_embedding_size = [side // patch_size for side in imgsz]