- # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- from typing import List
- import torch
- import torch.nn.functional as F
- from torch import nn
- from torch.nn.init import trunc_normal_
- from ultralytics.nn.modules import MLP
- from .blocks import SAM2TwoWayTransformer
- from .decoders import MaskDecoder, SAM2MaskDecoder
- from .encoders import ImageEncoderViT, PromptEncoder
- from .utils import get_1d_sine_pe, select_closest_cond_frames
- # a large negative value as a placeholder score for missing objects
- NO_OBJ_SCORE = -1024.0
- class SAMModel(nn.Module):
- """
- Segment Anything Model (SAM) for object segmentation tasks.
- This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images
- and input prompts.
- Attributes:
- mask_threshold (float): Threshold value for mask prediction.
- image_encoder (ImageEncoderViT): Backbone for encoding images into embeddings.
- prompt_encoder (PromptEncoder): Encoder for various types of input prompts.
- mask_decoder (MaskDecoder): Predicts object masks from image and prompt embeddings.
- Methods:
- __init__: Initializes the SAMModel with encoders, decoder, and normalization parameters.
- Examples:
- >>> image_encoder = ImageEncoderViT(...)
- >>> prompt_encoder = PromptEncoder(...)
- >>> mask_decoder = MaskDecoder(...)
- >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)
- >>> # Further usage depends on SAMPredictor class
- Notes:
- All forward() operations are implemented in the SAMPredictor class.
- """
- mask_threshold: float = 0.0
- def __init__(
- self,
- image_encoder: ImageEncoderViT,
- prompt_encoder: PromptEncoder,
- mask_decoder: MaskDecoder,
- pixel_mean: List[float] = (123.675, 116.28, 103.53),
- pixel_std: List[float] = (58.395, 57.12, 57.375),
- ) -> None:
- """
- Initialize the SAMModel class to predict object masks from an image and input prompts.
- Args:
- image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings.
- prompt_encoder (PromptEncoder): Encodes various types of input prompts.
- mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
- pixel_mean (List[float]): Mean values for normalizing pixels in the input image.
- pixel_std (List[float]): Std values for normalizing pixels in the input image.
- Examples:
- >>> image_encoder = ImageEncoderViT(...)
- >>> prompt_encoder = PromptEncoder(...)
- >>> mask_decoder = MaskDecoder(...)
- >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)
- >>> # Further usage depends on SAMPredictor class
- Notes:
- All forward() operations moved to SAMPredictor.
- """
- super().__init__()
- self.image_encoder = image_encoder
- self.prompt_encoder = prompt_encoder
- self.mask_decoder = mask_decoder
- self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
- self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
- def set_imgsz(self, imgsz):
- """
- Set image size to make model compatible with different image sizes.
- Args:
- imgsz (Tuple[int, int]): The size of the input image.
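- Examples:
- >>> sam_model.set_imgsz((1024, 1024))  # illustrative; assumes `sam_model` was built as in the class docstring
- >>> sam_model.prompt_encoder.image_embedding_size
- [64, 64]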
- """
- if hasattr(self.image_encoder, "set_imgsz"):
- self.image_encoder.set_imgsz(imgsz)
- self.prompt_encoder.input_image_size = imgsz
- self.prompt_encoder.image_embedding_size = [x // 16 for x in imgsz] # 16 is the fixed ViT patch size
- self.image_encoder.img_size = imgsz[0]
- class SAM2Model(torch.nn.Module):
- """
- SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.
- This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms
- for temporal consistency and efficient tracking of objects across frames.
- Attributes:
- mask_threshold (float): Threshold value for mask prediction.
- image_encoder (ImageEncoderViT): Visual encoder for extracting image features.
- memory_attention (nn.Module): Module for attending to memory features.
- memory_encoder (nn.Module): Encoder for generating memory representations.
- num_maskmem (int): Number of accessible memory frames.
- image_size (int): Size of input images.
- backbone_stride (int): Stride of the backbone network output.
- sam_prompt_embed_dim (int): Dimension of SAM prompt embeddings.
- sam_image_embedding_size (int): Size of SAM image embeddings.
- sam_prompt_encoder (PromptEncoder): Encoder for processing input prompts.
- sam_mask_decoder (SAM2MaskDecoder): Decoder for generating object masks.
- obj_ptr_proj (nn.Module): Projection layer for object pointers.
- obj_ptr_tpos_proj (nn.Module): Projection for temporal positional encoding in object pointers.
- Methods:
- forward_image: Processes image batch through encoder to extract multi-level features.
- track_step: Performs a single tracking step, updating object masks and memory features.
- Examples:
- >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)
- >>> image_batch = torch.rand(1, 3, 512, 512)
- >>> features = model.forward_image(image_batch)
- >>> track_results = model.track_step(0, True, features, None, None, None, {})
- """
- mask_threshold: float = 0.0
- def __init__(
- self,
- image_encoder,
- memory_attention,
- memory_encoder,
- num_maskmem=7,
- image_size=512,
- backbone_stride=16,
- sigmoid_scale_for_mem_enc=1.0,
- sigmoid_bias_for_mem_enc=0.0,
- binarize_mask_from_pts_for_mem_enc=False,
- use_mask_input_as_output_without_sam=False,
- max_cond_frames_in_attn=-1,
- directly_add_no_mem_embed=False,
- use_high_res_features_in_sam=False,
- multimask_output_in_sam=False,
- multimask_min_pt_num=1,
- multimask_max_pt_num=1,
- multimask_output_for_tracking=False,
- use_multimask_token_for_obj_ptr: bool = False,
- iou_prediction_use_sigmoid=False,
- memory_temporal_stride_for_eval=1,
- non_overlap_masks_for_mem_enc=False,
- use_obj_ptrs_in_encoder=False,
- max_obj_ptrs_in_encoder=16,
- add_tpos_enc_to_obj_ptrs=True,
- proj_tpos_enc_in_obj_ptrs=False,
- use_signed_tpos_enc_to_obj_ptrs=False,
- only_obj_ptrs_in_the_past_for_eval=False,
- pred_obj_scores: bool = False,
- pred_obj_scores_mlp: bool = False,
- fixed_no_obj_ptr: bool = False,
- soft_no_obj_ptr: bool = False,
- use_mlp_for_obj_ptr_proj: bool = False,
- no_obj_embed_spatial: bool = False,
- sam_mask_decoder_extra_args=None,
- compile_image_encoder: bool = False,
- ):
- """
- Initializes the SAM2Model for video object segmentation with memory-based tracking.
- Args:
- image_encoder (nn.Module): Visual encoder for extracting image features.
- memory_attention (nn.Module): Module for attending to memory features.
- memory_encoder (nn.Module): Encoder for generating memory representations.
- num_maskmem (int): Number of accessible memory frames. Default is 7 (1 input frame + 6 previous frames).
- image_size (int): Size of input images.
- backbone_stride (int): Stride of the image backbone output.
- sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.
- sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.
- binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames
- with clicks during evaluation.
- use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM
- prompt encoder and mask decoder on frames with mask input.
- max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.
- -1 means no limit.
- directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the
- first frame.
- use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.
- multimask_output_in_sam (bool): Whether to output multiple (3) masks for the first click on initial
- conditioning frames.
- multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.
- multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.
- multimask_output_for_tracking (bool): Whether to use multimask output for tracking.
- use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
- iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
- memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
- non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in
- memory encoder during evaluation.
- use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.
- max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder
- cross-attention.
- add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in
- the encoder.
- proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
- encoding in object pointers.
- use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance (instead of unsigned absolute
- distance) in the temporal positional encoding of object pointers. Only relevant when both
- `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`.
- only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past
- during evaluation.
- pred_obj_scores (bool): Whether to predict if there is an object in the frame.
- pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.
- fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.
- soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation.
- use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.
- no_obj_embed_spatial (bool): Whether to add a no-object embedding to the spatial memory frames.
- sam_mask_decoder_extra_args (Dict | None): Extra arguments for constructing the SAM mask decoder.
- compile_image_encoder (bool): Whether to compile the image encoder for faster inference.
- Examples:
- >>> image_encoder = ImageEncoderViT(...)
- >>> memory_attention = SAM2TwoWayTransformer(...)
- >>> memory_encoder = nn.Sequential(...)
- >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)
- >>> image_batch = torch.rand(1, 3, 512, 512)
- >>> features = model.forward_image(image_batch)
- >>> track_results = model.track_step(0, True, features, None, None, None, {})
- """
- super().__init__()
- # Part 1: the image backbone
- self.image_encoder = image_encoder
- # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
- self.use_high_res_features_in_sam = use_high_res_features_in_sam
- self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
- self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
- self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
- if use_obj_ptrs_in_encoder:
- # A conv layer to downsample the mask prompt to stride 4 (the same stride as
- # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
- # so that it can be fed into the SAM mask decoder to generate a pointer.
- self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
- self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
- if proj_tpos_enc_in_obj_ptrs:
- assert add_tpos_enc_to_obj_ptrs # these options need to be used together
- self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
- self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs
- self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
- # Part 2: memory attention to condition current frame's visual features
- # with memories (and obj ptrs) from past frames
- self.memory_attention = memory_attention
- self.hidden_dim = memory_attention.d_model
- # Part 3: memory encoder for the previous frame's outputs
- self.memory_encoder = memory_encoder
- self.mem_dim = self.hidden_dim
- if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"):
- # if there is compression of memories along channel dim
- self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
- self.num_maskmem = num_maskmem # Number of memories accessible
- # Temporal encoding of the memories
- self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim))
- trunc_normal_(self.maskmem_tpos_enc, std=0.02)
- # a single token to indicate no memory embedding from previous frames
- self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
- self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
- trunc_normal_(self.no_mem_embed, std=0.02)
- trunc_normal_(self.no_mem_pos_enc, std=0.02)
- self.directly_add_no_mem_embed = directly_add_no_mem_embed
- # Apply sigmoid to the output raw mask logits (to turn them from
- # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
- self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
- self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
- self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
- self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
- self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
- # On frames with mask input, whether to directly output the input mask without
- # using a SAM prompt encoder + mask decoder
- self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
- self.multimask_output_in_sam = multimask_output_in_sam
- self.multimask_min_pt_num = multimask_min_pt_num
- self.multimask_max_pt_num = multimask_max_pt_num
- self.multimask_output_for_tracking = multimask_output_for_tracking
- self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
- self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid
- # Part 4: SAM-style prompt encoder (for both mask and point inputs)
- # and SAM-style mask decoder for the final mask output
- self.image_size = image_size
- self.backbone_stride = backbone_stride
- self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
- self.pred_obj_scores = pred_obj_scores
- self.pred_obj_scores_mlp = pred_obj_scores_mlp
- self.fixed_no_obj_ptr = fixed_no_obj_ptr
- self.soft_no_obj_ptr = soft_no_obj_ptr
- if self.fixed_no_obj_ptr:
- assert self.pred_obj_scores
- assert self.use_obj_ptrs_in_encoder
- if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
- self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
- trunc_normal_(self.no_obj_ptr, std=0.02)
- self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
- self.no_obj_embed_spatial = None
- if no_obj_embed_spatial:
- self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
- trunc_normal_(self.no_obj_embed_spatial, std=0.02)
- self._build_sam_heads()
- self.max_cond_frames_in_attn = max_cond_frames_in_attn
- # Model compilation
- if compile_image_encoder:
- # Compile the forward function (not the full module) to allow loading checkpoints.
- print("Image encoder compilation is enabled. First forward pass will be slow.")
- self.image_encoder.forward = torch.compile(
- self.image_encoder.forward,
- mode="max-autotune",
- fullgraph=True,
- dynamic=False,
- )
- @property
- def device(self):
- """Returns the device on which the model's parameters are stored."""
- return next(self.parameters()).device
- def forward(self, *args, **kwargs):
- """Processes image and prompt inputs to generate object masks and scores in video sequences."""
- raise NotImplementedError(
- "Please use the corresponding methods in SAM2VideoPredictor for inference."
- "See notebooks/video_predictor_example.ipynb for an example."
- )
- def _build_sam_heads(self):
- """Builds SAM-style prompt encoder and mask decoder for image segmentation tasks."""
- self.sam_prompt_embed_dim = self.hidden_dim
- self.sam_image_embedding_size = self.image_size // self.backbone_stride
- # Build PromptEncoder and MaskDecoder from SAM (hyperparameters like `mask_in_chans=16` are from SAM code)
- self.sam_prompt_encoder = PromptEncoder(
- embed_dim=self.sam_prompt_embed_dim,
- image_embedding_size=(
- self.sam_image_embedding_size,
- self.sam_image_embedding_size,
- ),
- input_image_size=(self.image_size, self.image_size),
- mask_in_chans=16,
- )
- self.sam_mask_decoder = SAM2MaskDecoder(
- num_multimask_outputs=3,
- transformer=SAM2TwoWayTransformer(
- depth=2,
- embedding_dim=self.sam_prompt_embed_dim,
- mlp_dim=2048,
- num_heads=8,
- ),
- transformer_dim=self.sam_prompt_embed_dim,
- iou_head_depth=3,
- iou_head_hidden_dim=256,
- use_high_res_features=self.use_high_res_features_in_sam,
- iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
- pred_obj_scores=self.pred_obj_scores,
- pred_obj_scores_mlp=self.pred_obj_scores_mlp,
- use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
- **(self.sam_mask_decoder_extra_args or {}),
- )
- if self.use_obj_ptrs_in_encoder:
- # a linear projection on SAM output tokens to turn them into object pointers
- self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
- if self.use_mlp_for_obj_ptr_proj:
- self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3)
- else:
- self.obj_ptr_proj = torch.nn.Identity()
- if self.proj_tpos_enc_in_obj_ptrs:
- # a linear projection on temporal positional encoding in object pointers to
- # avoid potential interference with spatial positional encoding
- self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
- else:
- self.obj_ptr_tpos_proj = torch.nn.Identity()
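- # Note: with the defaults image_size=512 and backbone_stride=16, the SAM heads operate on a 32x32
- # embedding grid (sam_image_embedding_size == 512 // 16 == 32), which matches the (1, 256, 32, 32)
- # backbone_features shape used in the _forward_sam_heads docstring example below.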
- def _forward_sam_heads(
- self,
- backbone_features,
- point_inputs=None,
- mask_inputs=None,
- high_res_features=None,
- multimask_output=False,
- ):
- """
- Forward pass through SAM prompt encoders and mask heads.
- This method processes image features and optional point/mask inputs to generate object masks and scores.
- Args:
- backbone_features (torch.Tensor): Image features with shape (B, C, H, W).
- point_inputs (Dict[str, torch.Tensor] | None): Dictionary containing point prompts.
- 'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute
- pixel-unit coordinates in (x, y) format for P input points.
- 'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks,
- 0 means negative clicks, and -1 means padding.
- mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the
- same spatial size as the image.
- high_res_features (List[torch.Tensor] | None): List of two feature maps with shapes
- (B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps
- for SAM decoder.
- multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False,
- output only 1 mask and its IoU estimate.
- Returns:
- (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]):
- low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
- high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits.
- ious: Tensor of shape (B, M) with estimated IoU for each output mask.
- low_res_masks: Tensor of shape (B, 1, H*4, W*4) with the best low-resolution mask.
- high_res_masks: Tensor of shape (B, 1, H*16, W*16) with the best high-resolution mask.
- obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask.
- object_score_logits: Tensor of shape (B) with object score logits.
- Where M is 3 if multimask_output=True, and 1 if multimask_output=False.
- Examples:
- >>> backbone_features = torch.rand(1, 256, 32, 32)
- >>> point_inputs = {"point_coords": torch.rand(1, 2, 2), "point_labels": torch.tensor([[1, 0]])}
- >>> mask_inputs = torch.rand(1, 1, 512, 512)
- >>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs)
- >>> (
- ... low_res_multimasks,
- ... high_res_multimasks,
- ... ious,
- ... low_res_masks,
- ... high_res_masks,
- ... obj_ptr,
- ... object_score_logits,
- ... ) = results
- """
- B = backbone_features.size(0)
- device = backbone_features.device
- assert backbone_features.size(1) == self.sam_prompt_embed_dim
- assert backbone_features.size(2) == self.sam_image_embedding_size
- assert backbone_features.size(3) == self.sam_image_embedding_size
- # a) Handle point prompts
- if point_inputs is not None:
- sam_point_coords = point_inputs["point_coords"]
- sam_point_labels = point_inputs["point_labels"]
- assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
- else:
- # If no points are provided, pad with an empty point (with label -1)
- sam_point_coords = torch.zeros(B, 1, 2, device=device)
- sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
- # b) Handle mask prompts
- if mask_inputs is not None:
- # If mask_inputs is provided, downsize it into low-res mask input if needed
- # and feed it as a dense mask prompt into the SAM mask encoder
- assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
- if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
- sam_mask_prompt = F.interpolate(
- mask_inputs.float(),
- size=self.sam_prompt_encoder.mask_input_size,
- align_corners=False,
- mode="bilinear",
- antialias=True, # use antialias for downsampling
- )
- else:
- sam_mask_prompt = mask_inputs
- else:
- # Otherwise, simply feed None (and SAM's prompt encoder will add
- # a learned `no_mask_embed` to indicate no mask input in this case).
- sam_mask_prompt = None
- sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
- points=(sam_point_coords, sam_point_labels),
- boxes=None,
- masks=sam_mask_prompt,
- )
- low_res_multimasks, ious, sam_output_tokens, object_score_logits = self.sam_mask_decoder(
- image_embeddings=backbone_features,
- image_pe=self.sam_prompt_encoder.get_dense_pe(),
- sparse_prompt_embeddings=sparse_embeddings,
- dense_prompt_embeddings=dense_embeddings,
- multimask_output=multimask_output,
- repeat_image=False, # the image is already batched
- high_res_features=high_res_features,
- )
- if self.pred_obj_scores:
- is_obj_appearing = object_score_logits > 0
- # Spatial memory mask is a *hard* choice between obj and no obj, consistent with actual mask prediction
- low_res_multimasks = torch.where(is_obj_appearing[:, None, None], low_res_multimasks, NO_OBJ_SCORE)
- # convert masks from possibly bfloat16 (or float16) to float32
- # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
- low_res_multimasks = low_res_multimasks.float()
- high_res_multimasks = F.interpolate(
- low_res_multimasks,
- size=(self.image_size, self.image_size),
- mode="bilinear",
- align_corners=False,
- )
- sam_output_token = sam_output_tokens[:, 0]
- if multimask_output:
- # take the best mask prediction (with the highest IoU estimation)
- best_iou_inds = torch.argmax(ious, dim=-1)
- batch_inds = torch.arange(B, device=device)
- low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
- high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
- if sam_output_tokens.size(1) > 1:
- sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
- else:
- low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
- # Extract object pointer from the SAM output token (with occlusion handling)
- obj_ptr = self.obj_ptr_proj(sam_output_token)
- if self.pred_obj_scores:
- # Allow *soft* no obj ptr, unlike for masks
- if self.soft_no_obj_ptr:
- lambda_is_obj_appearing = object_score_logits.sigmoid()
- else:
- lambda_is_obj_appearing = is_obj_appearing.float()
- if self.fixed_no_obj_ptr:
- obj_ptr = lambda_is_obj_appearing * obj_ptr
- obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
- return (
- low_res_multimasks,
- high_res_multimasks,
- ious,
- low_res_masks,
- high_res_masks,
- obj_ptr,
- object_score_logits,
- )
- def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
- """Processes mask inputs directly as output, bypassing SAM encoder/decoder."""
- # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
- out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
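- # e.g. an input mask value of 1.0 maps to a logit of 1.0 * 20.0 - 10.0 = +10 (sigmoid ~ 0.99995), and a
- # value of 0.0 maps to -10 (sigmoid ~ 4.5e-05), so the input mask survives the sigmoid almost unchanged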
- mask_inputs_float = mask_inputs.float()
- high_res_masks = mask_inputs_float * out_scale + out_bias
- low_res_masks = F.interpolate(
- high_res_masks,
- size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
- align_corners=False,
- mode="bilinear",
- antialias=True, # use antialias for downsampling
- )
- # a dummy IoU prediction of all 1's under mask input
- ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
- if not self.use_obj_ptrs_in_encoder:
- # all zeros as a dummy object pointer (of shape [B, C])
- obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device)
- else:
- # produce an object pointer using the SAM decoder from the mask input
- _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
- backbone_features=backbone_features,
- mask_inputs=self.mask_downsample(mask_inputs_float),
- high_res_features=high_res_features,
- )
- # In this method we treat mask_inputs as the output, e.g. using it directly to create the spatial memory;
- # below we follow the same principle and use mask_inputs to decide whether an object appears, instead of
- # relying on the object scores from the SAM decoder.
- is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
- is_obj_appearing = is_obj_appearing[..., None]
- lambda_is_obj_appearing = is_obj_appearing.float()
- object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
- if self.pred_obj_scores:
- if self.fixed_no_obj_ptr:
- obj_ptr = lambda_is_obj_appearing * obj_ptr
- obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
- return (
- low_res_masks,
- high_res_masks,
- ious,
- low_res_masks,
- high_res_masks,
- obj_ptr,
- object_score_logits,
- )
- def forward_image(self, img_batch: torch.Tensor):
- """Processes image batch through encoder to extract multi-level features for SAM model."""
- backbone_out = self.image_encoder(img_batch)
- if self.use_high_res_features_in_sam:
- # precompute projected level 0 and level 1 features in SAM decoder
- # to avoid running it again on every SAM click
- backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
- backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
- return backbone_out
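- # Illustrative usage sketch (assumes `model` is a constructed SAM2Model): the returned dict carries the
- # "backbone_fpn" and "vision_pos_enc" pyramids that _prepare_backbone_features() flattens below, e.g.
- # backbone_out = model.forward_image(torch.rand(1, 3, model.image_size, model.image_size))
- # _, vision_feats, vision_pos_embeds, feat_sizes = model._prepare_backbone_features(backbone_out)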
- def _prepare_backbone_features(self, backbone_out):
- """Prepares and flattens visual features from the image backbone output for further processing."""
- assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
- assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
- feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
- vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
- feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
- # flatten NxCxHxW to HWxNxC
- vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
- vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
- return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
- def _prepare_memory_conditioned_features(
- self,
- frame_idx,
- is_init_cond_frame,
- current_vision_feats,
- current_vision_pos_embeds,
- feat_sizes,
- output_dict,
- num_frames,
- track_in_reverse=False, # tracking in reverse time order (for demo usage)
- ):
- """Prepares memory-conditioned features by fusing current frame's visual features with previous memories."""
- B = current_vision_feats[-1].size(1) # batch size on this frame
- C = self.hidden_dim
- H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
- device = current_vision_feats[-1].device
- # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
- # In this case, we skip the fusion with any memory.
- if self.num_maskmem == 0: # Disable memory and skip fusion
- return current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
- num_obj_ptr_tokens = 0
- tpos_sign_mul = -1 if track_in_reverse else 1
- # Step 1: condition the visual features of the current frame on previous memories
- if not is_init_cond_frame:
- # Retrieve the memories encoded with the maskmem backbone
- to_cat_memory, to_cat_memory_pos_embed = [], []
- # Add the conditioning frames' outputs first (all cond frames have t_pos=0
- # when computing the temporal positional embedding below)
- assert len(output_dict["cond_frame_outputs"]) > 0
- # Select a maximum number of temporally closest cond frames for cross attention
- cond_outputs = output_dict["cond_frame_outputs"]
- selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
- frame_idx, cond_outputs, self.max_cond_frames_in_attn
- )
- t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
- # Add the last (self.num_maskmem - 1) frames before the current frame as non-conditioning memory;
- # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1.
- # We also allow taking memory frames non-consecutively (with stride r>1), in which case
- # we take (self.num_maskmem - 2) frames from every r-th frame, plus the last frame.
- r = 1 if self.training else self.memory_temporal_stride_for_eval
- for t_pos in range(1, self.num_maskmem):
- t_rel = self.num_maskmem - t_pos # how many frames before current frame
- if t_rel == 1:
- # for t_rel == 1, we take the last frame (regardless of r)
- prev_frame_idx = frame_idx + t_rel if track_in_reverse else frame_idx - t_rel
- elif not track_in_reverse:
- # first find the nearest frame among every r-th frames before this frame
- # for r=1, this would be (frame_idx - 2)
- prev_frame_idx = ((frame_idx - 2) // r) * r
- # then seek further among every r-th frames
- prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
- else:
- # first find the nearest frame among every r-th frames after this frame
- # for r=1, this would be (frame_idx + 2)
- prev_frame_idx = -(-(frame_idx + 2) // r) * r
- # then seek further among every r-th frames
- prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
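- # e.g. with r=2 and frame_idx=5, the nearest stride-aligned frame is ((5 - 2) // 2) * 2 = 2 when
- # tracking forward and -(-(5 + 2) // 2) * 2 = 8 when tracking in reverse, stepping by r for larger t_rel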
- out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
- if out is None:
- # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
- # frames, we still attend to it as if it's a non-conditioning frame.
- out = unselected_cond_outputs.get(prev_frame_idx, None)
- t_pos_and_prevs.append((t_pos, out))
- for t_pos, prev in t_pos_and_prevs:
- if prev is None:
- continue # skip padding frames
- # "maskmem_features" might have been offloaded to CPU in demo use cases,
- # so we load it back to inference device (it's a no-op if it's already on device).
- feats = prev["maskmem_features"].to(device=device, non_blocking=True)
- to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
- # Spatial positional encoding (it might have been offloaded to CPU in eval)
- maskmem_enc = prev["maskmem_pos_enc"][-1].to(device=device)
- maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
- # Temporal positional encoding
- maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
- to_cat_memory_pos_embed.append(maskmem_enc)
- # Construct the list of past object pointers
- if self.use_obj_ptrs_in_encoder:
- max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
- # First add those object pointers from selected conditioning frames
- # (optionally, only include object pointers in the past during evaluation)
- if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
- ptr_cond_outputs = {
- t: out
- for t, out in selected_cond_outputs.items()
- if (t >= frame_idx if track_in_reverse else t <= frame_idx)
- }
- else:
- ptr_cond_outputs = selected_cond_outputs
- pos_and_ptrs = [
- # Temporal pos encoding contains how far away each pointer is from current frame
- (
- (
- (frame_idx - t) * tpos_sign_mul
- if self.use_signed_tpos_enc_to_obj_ptrs
- else abs(frame_idx - t)
- ),
- out["obj_ptr"],
- )
- for t, out in ptr_cond_outputs.items()
- ]
- # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
- for t_diff in range(1, max_obj_ptrs_in_encoder):
- t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
- if t < 0 or (num_frames is not None and t >= num_frames):
- break
- out = output_dict["non_cond_frame_outputs"].get(t, unselected_cond_outputs.get(t, None))
- if out is not None:
- pos_and_ptrs.append((t_diff, out["obj_ptr"]))
- # If we have at least one object pointer, add them to the cross-attention
- if pos_and_ptrs:
- pos_list, ptrs_list = zip(*pos_and_ptrs)
- # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
- obj_ptrs = torch.stack(ptrs_list, dim=0)
- # a temporal positional embedding based on how far each object pointer is from
- # the current frame (sine embedding normalized by the max pointer num).
- if self.add_tpos_enc_to_obj_ptrs:
- t_diff_max = max_obj_ptrs_in_encoder - 1
- tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
- obj_pos = torch.tensor(pos_list, device=device)
- obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
- obj_pos = self.obj_ptr_tpos_proj(obj_pos)
- obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
- else:
- obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
- if self.mem_dim < C:
- # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
- obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim)
- obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
- obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
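- # e.g. with C=256 and mem_dim=64, each object pointer is split into 4 tokens of dimension 64, and its
- # temporal positional encoding is repeated once per split token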
- to_cat_memory.append(obj_ptrs)
- to_cat_memory_pos_embed.append(obj_pos)
- num_obj_ptr_tokens = obj_ptrs.shape[0]
- else:
- num_obj_ptr_tokens = 0
- else:
- # for initial conditioning frames, encode them without using any previous memory
- if self.directly_add_no_mem_embed:
- # directly add no-mem embedding (instead of using the transformer encoder)
- pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
- pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
- return pix_feat_with_mem
- # Use a dummy token on the first frame (to avoid empty memory input to transformer encoder)
- to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
- to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
- # Step 2: Concatenate the memories and forward through the transformer encoder
- memory = torch.cat(to_cat_memory, dim=0)
- memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
- pix_feat_with_mem = self.memory_attention(
- curr=current_vision_feats,
- curr_pos=current_vision_pos_embeds,
- memory=memory,
- memory_pos=memory_pos_embed,
- num_obj_ptr_tokens=num_obj_ptr_tokens,
- )
- # reshape the output (HW)BC => BCHW
- pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
- return pix_feat_with_mem
- def _encode_new_memory(
- self,
- current_vision_feats,
- feat_sizes,
- pred_masks_high_res,
- object_score_logits,
- is_mask_from_pts,
- ):
- """Encodes frame features and masks into a new memory representation for video segmentation."""
- B = current_vision_feats[-1].size(1) # batch size on this frame
- C = self.hidden_dim
- H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
- # top-level feature, (HW)BC => BCHW
- pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
- if self.non_overlap_masks_for_mem_enc and not self.training:
- # optionally, apply non-overlapping constraints to the masks (it's applied
- # in the batch dimension and should only be used during eval, where all
- # the objects come from the same video under batch size 1).
- pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res)
- # binarize the raw mask logits or convert them to sigmoid probabilities before memory encoding
- binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
- if binarize and not self.training:
- mask_for_mem = (pred_masks_high_res > 0).float()
- else:
- # apply sigmoid on the raw mask logits to turn them into range (0, 1)
- mask_for_mem = torch.sigmoid(pred_masks_high_res)
- # apply scale and bias terms to the sigmoid probabilities
- if self.sigmoid_scale_for_mem_enc != 1.0:
- mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
- if self.sigmoid_bias_for_mem_enc != 0.0:
- mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
- maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True) # sigmoid already applied
- maskmem_features = maskmem_out["vision_features"]
- maskmem_pos_enc = maskmem_out["vision_pos_enc"]
- # add a no-object embedding to the spatial memory to indicate that the frame
- # is predicted to be occluded (i.e. no object is appearing in the frame)
- if self.no_obj_embed_spatial is not None:
- is_obj_appearing = (object_score_logits > 0).float()
- maskmem_features += (1 - is_obj_appearing[..., None, None]) * self.no_obj_embed_spatial[
- ..., None, None
- ].expand(*maskmem_features.shape)
- return maskmem_features, maskmem_pos_enc
- def _track_step(
- self,
- frame_idx,
- is_init_cond_frame,
- current_vision_feats,
- current_vision_pos_embeds,
- feat_sizes,
- point_inputs,
- mask_inputs,
- output_dict,
- num_frames,
- track_in_reverse,
- prev_sam_mask_logits,
- ):
- """Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
- current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
- # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
- if len(current_vision_feats) > 1:
- high_res_features = [
- x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
- for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
- ]
- else:
- high_res_features = None
- if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
- # When use_mask_input_as_output_without_sam=True, we directly output the mask input
- # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
- pix_feat = current_vision_feats[-1].permute(1, 2, 0)
- pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
- sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
- else:
- # fuse the visual features with previous memory features in the memory bank
- pix_feat = self._prepare_memory_conditioned_features(
- frame_idx=frame_idx,
- is_init_cond_frame=is_init_cond_frame,
- current_vision_feats=current_vision_feats[-1:],
- current_vision_pos_embeds=current_vision_pos_embeds[-1:],
- feat_sizes=feat_sizes[-1:],
- output_dict=output_dict,
- num_frames=num_frames,
- track_in_reverse=track_in_reverse,
- )
- # apply SAM-style segmentation head
- # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
- # e.g. in demo where such logits come from earlier interaction instead of correction sampling
- # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
- if prev_sam_mask_logits is not None:
- assert point_inputs is not None and mask_inputs is None
- mask_inputs = prev_sam_mask_logits
- multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
- sam_outputs = self._forward_sam_heads(
- backbone_features=pix_feat,
- point_inputs=point_inputs,
- mask_inputs=mask_inputs,
- high_res_features=high_res_features,
- multimask_output=multimask_output,
- )
- return current_out, sam_outputs, high_res_features, pix_feat
- def _encode_memory_in_output(
- self,
- current_vision_feats,
- feat_sizes,
- point_inputs,
- run_mem_encoder,
- high_res_masks,
- object_score_logits,
- current_out,
- ):
- """Finally run the memory encoder on the predicted mask to encode, it into a new memory feature (that can be
- used in future frames).
- """
- if run_mem_encoder and self.num_maskmem > 0:
- high_res_masks_for_mem_enc = high_res_masks
- maskmem_features, maskmem_pos_enc = self._encode_new_memory(
- current_vision_feats=current_vision_feats,
- feat_sizes=feat_sizes,
- pred_masks_high_res=high_res_masks_for_mem_enc,
- object_score_logits=object_score_logits,
- is_mask_from_pts=(point_inputs is not None),
- )
- current_out["maskmem_features"] = maskmem_features
- current_out["maskmem_pos_enc"] = maskmem_pos_enc
- else:
- current_out["maskmem_features"] = None
- current_out["maskmem_pos_enc"] = None
- def track_step(
- self,
- frame_idx,
- is_init_cond_frame,
- current_vision_feats,
- current_vision_pos_embeds,
- feat_sizes,
- point_inputs,
- mask_inputs,
- output_dict,
- num_frames,
- track_in_reverse=False, # tracking in reverse time order (for demo usage)
- # Whether to run the memory encoder on the predicted masks. Sometimes we might want
- # to skip the memory encoder with `run_mem_encoder=False`. For example,
- # in demo we might call `track_step` multiple times for each user click,
- # and only encode the memory when the user finalizes their clicks. And in ablation
- # settings like SAM training on static images, we don't need the memory encoder.
- run_mem_encoder=True,
- # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
- prev_sam_mask_logits=None,
- ):
- """Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
- current_out, sam_outputs, _, _ = self._track_step(
- frame_idx,
- is_init_cond_frame,
- current_vision_feats,
- current_vision_pos_embeds,
- feat_sizes,
- point_inputs,
- mask_inputs,
- output_dict,
- num_frames,
- track_in_reverse,
- prev_sam_mask_logits,
- )
- _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = sam_outputs
- current_out["pred_masks"] = low_res_masks
- current_out["pred_masks_high_res"] = high_res_masks
- current_out["obj_ptr"] = obj_ptr
- if not self.training:
- # Only add this in inference (to avoid unused param in activation checkpointing;
- # it's mainly used in the demo to encode spatial memories w/ consolidated masks)
- current_out["object_score_logits"] = object_score_logits
- # Run memory encoder on the predicted mask to encode it into a new memory feature (for use in future frames)
- self._encode_memory_in_output(
- current_vision_feats,
- feat_sizes,
- point_inputs,
- run_mem_encoder,
- high_res_masks,
- object_score_logits,
- current_out,
- )
- return current_out
- def _use_multimask(self, is_init_cond_frame, point_inputs):
- """Determines whether to use multiple mask outputs in the SAM head based on configuration and inputs."""
- num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
- return (
- self.multimask_output_in_sam
- and (is_init_cond_frame or self.multimask_output_for_tracking)
- and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
- )
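- # e.g. with multimask_output_in_sam=True and multimask_min/max_pt_num=1, a single click on an initial
- # conditioning frame triggers multimask output, while later tracked frames fall back to a single mask
- # unless multimask_output_for_tracking is also enabled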
- @staticmethod
- def _apply_non_overlapping_constraints(pred_masks):
- """Applies non-overlapping constraints to masks, keeping the highest scoring object per location."""
- batch_size = pred_masks.size(0)
- if batch_size == 1:
- return pred_masks
- device = pred_masks.device
- # "max_obj_inds": object index of the object with the highest score at each location
- max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
- # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
- batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
- keep = max_obj_inds == batch_obj_inds
- # clamp the scores of overlapping regions to at most -10.0 so that the foreground regions
- # don't overlap (here sigmoid(-10.0)=4.5398e-05)
- pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
- return pred_masks
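- # Illustrative example (hypothetical tensors; the method is a no-op for a single object):
- # >>> masks = torch.tensor([[[[2.0]]], [[[5.0]]]])  # two objects overlapping at one location
- # >>> SAM2Model._apply_non_overlapping_constraints(masks).squeeze()
- # tensor([-10., 5.])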
- def set_binarize(self, binarize=False):
- """Set binarize for VideoPredictor."""
- self.binarize_mask_from_pts_for_mem_enc = binarize
- def set_imgsz(self, imgsz):
- """
- Set image size to make model compatible with different image sizes.
- Args:
- imgsz (Tuple[int, int]): The size of the input image.
- """
- self.image_size = imgsz[0]
- self.sam_prompt_encoder.input_image_size = imgsz
- self.sam_prompt_encoder.image_embedding_size = [x // 16 for x in imgsz] # fixed ViT patch size of 16