123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605 |
- # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
- """
- Generate predictions using the Segment Anything Model (SAM).
- SAM is an advanced image segmentation model offering features like promptable segmentation and zero-shot performance.
- This module contains the implementation of the prediction logic and auxiliary utilities required to perform segmentation
- using SAM. It forms an integral part of the Ultralytics framework and is designed for high-performance, real-time image
- segmentation tasks.
- """
- from collections import OrderedDict
- import numpy as np
- import torch
- import torch.nn.functional as F
- from ultralytics.data.augment import LetterBox
- from ultralytics.engine.predictor import BasePredictor
- from ultralytics.engine.results import Results
- from ultralytics.utils import DEFAULT_CFG, ops
- from ultralytics.utils.torch_utils import select_device, smart_inference_mode
- from .amg import (
- batch_iterator,
- batched_mask_to_box,
- build_all_layer_point_grids,
- calculate_stability_score,
- generate_crop_boxes,
- is_box_near_crop_edge,
- remove_small_regions,
- uncrop_boxes_xyxy,
- uncrop_masks,
- )
- from .build import build_sam
- class Predictor(BasePredictor):
- """
- Predictor class for SAM, enabling real-time image segmentation with promptable capabilities.
- This class extends BasePredictor and implements the Segment Anything Model (SAM) for advanced image
- segmentation tasks. It supports various input prompts like points, bounding boxes, and masks for
- fine-grained control over segmentation results.
- Attributes:
- args (SimpleNamespace): Configuration arguments for the predictor.
- model (torch.nn.Module): The loaded SAM model.
- device (torch.device): The device (CPU or GPU) on which the model is loaded.
- im (torch.Tensor): The preprocessed input image.
- features (torch.Tensor): Extracted image features.
- prompts (Dict): Dictionary to store various types of prompts (e.g., bboxes, points, masks).
- segment_all (bool): Flag to indicate if full image segmentation should be performed.
- mean (torch.Tensor): Mean values for image normalization.
- std (torch.Tensor): Standard deviation values for image normalization.
- Methods:
- preprocess: Prepares input images for model inference.
- pre_transform: Performs initial transformations on the input image.
- inference: Performs segmentation inference based on input prompts.
- prompt_inference: Internal function for prompt-based segmentation inference.
- generate: Generates segmentation masks for an entire image.
- setup_model: Initializes the SAM model for inference.
- get_model: Builds and returns a SAM model.
- postprocess: Post-processes model outputs to generate final results.
- setup_source: Sets up the data source for inference.
- set_image: Sets and preprocesses a single image for inference.
- get_im_features: Extracts image features using the SAM image encoder.
- set_prompts: Sets prompts for subsequent inference.
- reset_image: Resets the current image and its features.
- remove_small_regions: Removes small disconnected regions and holes from masks.
- Examples:
- >>> predictor = Predictor()
- >>> predictor.setup_model(model_path="sam_model.pt")
- >>> predictor.set_image("image.jpg")
- >>> bboxes = [[100, 100, 200, 200]]
- >>> results = predictor(bboxes=bboxes)
- """
- def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
- """
- Initialize the Predictor with configuration, overrides, and callbacks.
- Sets up the Predictor object for SAM (Segment Anything Model) and applies any configuration overrides or
- callbacks provided. Initializes task-specific settings for SAM, such as retina_masks being set to True
- for optimal results.
- Args:
- cfg (Dict): Configuration dictionary containing default settings.
- overrides (Dict | None): Dictionary of values to override default configuration.
- _callbacks (Dict | None): Dictionary of callback functions to customize behavior.
- Examples:
- >>> predictor_example = Predictor(cfg=DEFAULT_CFG)
- >>> predictor_example_with_imgsz = Predictor(overrides={"imgsz": 640})
- >>> predictor_example_with_callback = Predictor(_callbacks={"on_predict_start": custom_callback})
- """
- if overrides is None:
- overrides = {}
- overrides.update(dict(task="segment", mode="predict", batch=1))
- super().__init__(cfg, overrides, _callbacks)
- self.args.retina_masks = True
- self.im = None
- self.features = None
- self.prompts = {}
- self.segment_all = False
- def preprocess(self, im):
- """
- Preprocess the input image for model inference.
- This method prepares the input image by applying transformations and normalization. It supports both
- torch.Tensor and list of np.ndarray as input formats.
- Args:
- im (torch.Tensor | List[np.ndarray]): Input image(s) in BCHW tensor format or list of HWC numpy arrays.
- Returns:
- im (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype.
- Examples:
- >>> predictor = Predictor()
- >>> image = torch.rand(1, 3, 640, 640)
- >>> preprocessed_image = predictor.preprocess(image)
- """
- if self.im is not None:
- return self.im
- not_tensor = not isinstance(im, torch.Tensor)
- if not_tensor:
- im = np.stack(self.pre_transform(im))
- im = im[..., ::-1].transpose((0, 3, 1, 2))
- im = np.ascontiguousarray(im)
- im = torch.from_numpy(im)
- im = im.to(self.device)
- im = im.half() if self.model.fp16 else im.float()
- if not_tensor:
- im = (im - self.mean) / self.std
- return im
- def pre_transform(self, im):
- """
- Perform initial transformations on the input image for preprocessing.
- This method applies transformations such as resizing to prepare the image for further preprocessing.
- Currently, batched inference is not supported; hence the list length should be 1.
- Args:
- im (List[np.ndarray]): List containing a single image in HWC numpy array format.
- Returns:
- (List[np.ndarray]): List containing the transformed image.
- Raises:
- AssertionError: If the input list contains more than one image.
- Examples:
- >>> predictor = Predictor()
- >>> image = np.random.rand(480, 640, 3) # Single HWC image
- >>> transformed = predictor.pre_transform([image])
- >>> print(len(transformed))
- 1
- """
- assert len(im) == 1, "SAM model does not currently support batched inference"
- letterbox = LetterBox(self.args.imgsz, auto=False, center=False)
- return [letterbox(image=x) for x in im]
- def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
- """
- Perform image segmentation inference based on the given input cues, using the currently loaded image.
- This method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt
- encoder, and mask decoder for real-time and promptable segmentation tasks.
- Args:
- im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
- bboxes (np.ndarray | List | None): Bounding boxes with shape (N, 4), in XYXY format.
- points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.
- labels (np.ndarray | List | None): Labels for point prompts, shape (N,). 1 = foreground, 0 = background.
- masks (np.ndarray | None): Low-resolution masks from previous predictions, shape (N, H, W). For SAM H=W=256.
- multimask_output (bool): Flag to return multiple masks. Helpful for ambiguous prompts.
- *args (Any): Additional positional arguments.
- **kwargs (Any): Additional keyword arguments.
- Returns:
- (np.ndarray): The output masks in shape (C, H, W), where C is the number of generated masks.
- (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
- (np.ndarray): Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256.
- Examples:
- >>> predictor = Predictor()
- >>> predictor.setup_model(model_path="sam_model.pt")
- >>> predictor.set_image("image.jpg")
- >>> results = predictor(bboxes=[[0, 0, 100, 100]])
- """
- # Override prompts if any stored in self.prompts
- bboxes = self.prompts.pop("bboxes", bboxes)
- points = self.prompts.pop("points", points)
- masks = self.prompts.pop("masks", masks)
- labels = self.prompts.pop("labels", labels)
- if all(i is None for i in [bboxes, points, masks]):
- return self.generate(im, *args, **kwargs)
- return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)
- def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
- """
- Performs image segmentation inference based on input cues using SAM's specialized architecture.
- This internal function leverages the Segment Anything Model (SAM) for prompt-based, real-time segmentation.
- It processes various input prompts such as bounding boxes, points, and masks to generate segmentation masks.
- Args:
- im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
- bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
- points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
- labels (np.ndarray | List | None): Point prompt labels with shape (N) or (N, num_points). 1 for foreground, 0 for background.
- masks (np.ndarray | None): Low-res masks from previous predictions with shape (N, H, W). For SAM, H=W=256.
- multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
- Raises:
- AssertionError: If the number of points don't match the number of labels, in case labels were passed.
- Returns:
- (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
- (np.ndarray): Quality scores predicted by the model for each mask, with length C.
- Examples:
- >>> predictor = Predictor()
- >>> im = torch.rand(1, 3, 1024, 1024)
- >>> bboxes = [[100, 100, 200, 200]]
- >>> masks, scores, logits = predictor.prompt_inference(im, bboxes=bboxes)
- """
- features = self.get_im_features(im) if self.features is None else self.features
- bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
- points = (points, labels) if points is not None else None
- # Embed prompts
- sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks)
- # Predict masks
- pred_masks, pred_scores = self.model.mask_decoder(
- image_embeddings=features,
- image_pe=self.model.prompt_encoder.get_dense_pe(),
- sparse_prompt_embeddings=sparse_embeddings,
- dense_prompt_embeddings=dense_embeddings,
- multimask_output=multimask_output,
- )
- # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
- # `d` could be 1 or 3 depends on `multimask_output`.
- return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
- def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
- """
- Prepares and transforms the input prompts for processing based on the destination shape.
- Args:
- dst_shape (tuple): The target shape (height, width) for the prompts.
- bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
- points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
- labels (np.ndarray | List | None): Point prompt labels with shape (N) or (N, num_points). 1 for foreground, 0 for background.
- masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array.
- Raises:
- AssertionError: If the number of points don't match the number of labels, in case labels were passed.
- Returns:
- (tuple): A tuple containing transformed bounding boxes, points, labels, and masks.
- """
- src_shape = self.batch[1][0].shape[:2]
- r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
- # Transform input prompts
- if points is not None:
- points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
- points = points[None] if points.ndim == 1 else points
- # Assuming labels are all positive if users don't pass labels.
- if labels is None:
- labels = np.ones(points.shape[:-1])
- labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
- assert points.shape[-2] == labels.shape[-1], (
- f"Number of points {points.shape[-2]} should match number of labels {labels.shape[-1]}."
- )
- points *= r
- if points.ndim == 2:
- # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
- points, labels = points[:, None, :], labels[:, None]
- if bboxes is not None:
- bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
- bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
- bboxes *= r
- if masks is not None:
- masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
- return bboxes, points, labels, masks
- def generate(
- self,
- im,
- crop_n_layers=0,
- crop_overlap_ratio=512 / 1500,
- crop_downscale_factor=1,
- point_grids=None,
- points_stride=32,
- points_batch_size=64,
- conf_thres=0.88,
- stability_score_thresh=0.95,
- stability_score_offset=0.95,
- crop_nms_thresh=0.7,
- ):
- """
- Perform image segmentation using the Segment Anything Model (SAM).
- This method segments an entire image into constituent parts by leveraging SAM's advanced architecture
- and real-time performance capabilities. It can optionally work on image crops for finer segmentation.
- Args:
- im (torch.Tensor): Input tensor representing the preprocessed image with shape (N, C, H, W).
- crop_n_layers (int): Number of layers for additional mask predictions on image crops.
- crop_overlap_ratio (float): Overlap between crops, scaled down in subsequent layers.
- crop_downscale_factor (int): Scaling factor for sampled points-per-side in each layer.
- point_grids (List[np.ndarray] | None): Custom grids for point sampling normalized to [0,1].
- points_stride (int): Number of points to sample along each side of the image.
- points_batch_size (int): Batch size for the number of points processed simultaneously.
- conf_thres (float): Confidence threshold [0,1] for filtering based on mask quality prediction.
- stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on stability.
- stability_score_offset (float): Offset value for calculating stability score.
- crop_nms_thresh (float): IoU cutoff for NMS to remove duplicate masks between crops.
- Returns:
- pred_masks (torch.Tensor): Segmented masks with shape (N, H, W).
- pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,).
- pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4).
- Examples:
- >>> predictor = Predictor()
- >>> im = torch.rand(1, 3, 1024, 1024) # Example input image
- >>> masks, scores, boxes = predictor.generate(im)
- """
- import torchvision # scope for faster 'import ultralytics'
- self.segment_all = True
- ih, iw = im.shape[2:]
- crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio)
- if point_grids is None:
- point_grids = build_all_layer_point_grids(points_stride, crop_n_layers, crop_downscale_factor)
- pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], []
- for crop_region, layer_idx in zip(crop_regions, layer_idxs):
- x1, y1, x2, y2 = crop_region
- w, h = x2 - x1, y2 - y1
- area = torch.tensor(w * h, device=im.device)
- points_scale = np.array([[w, h]]) # w, h
- # Crop image and interpolate to input size
- crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode="bilinear", align_corners=False)
- # (num_points, 2)
- points_for_image = point_grids[layer_idx] * points_scale
- crop_masks, crop_scores, crop_bboxes = [], [], []
- for (points,) in batch_iterator(points_batch_size, points_for_image):
- pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True)
- # Interpolate predicted masks to input size
- pred_mask = F.interpolate(pred_mask[None], (h, w), mode="bilinear", align_corners=False)[0]
- idx = pred_score > conf_thres
- pred_mask, pred_score = pred_mask[idx], pred_score[idx]
- stability_score = calculate_stability_score(
- pred_mask, self.model.mask_threshold, stability_score_offset
- )
- idx = stability_score > stability_score_thresh
- pred_mask, pred_score = pred_mask[idx], pred_score[idx]
- # Bool type is much more memory-efficient.
- pred_mask = pred_mask > self.model.mask_threshold
- # (N, 4)
- pred_bbox = batched_mask_to_box(pred_mask).float()
- keep_mask = ~is_box_near_crop_edge(pred_bbox, crop_region, [0, 0, iw, ih])
- if not torch.all(keep_mask):
- pred_bbox, pred_mask, pred_score = pred_bbox[keep_mask], pred_mask[keep_mask], pred_score[keep_mask]
- crop_masks.append(pred_mask)
- crop_bboxes.append(pred_bbox)
- crop_scores.append(pred_score)
- # Do nms within this crop
- crop_masks = torch.cat(crop_masks)
- crop_bboxes = torch.cat(crop_bboxes)
- crop_scores = torch.cat(crop_scores)
- keep = torchvision.ops.nms(crop_bboxes, crop_scores, self.args.iou) # NMS
- crop_bboxes = uncrop_boxes_xyxy(crop_bboxes[keep], crop_region)
- crop_masks = uncrop_masks(crop_masks[keep], crop_region, ih, iw)
- crop_scores = crop_scores[keep]
- pred_masks.append(crop_masks)
- pred_bboxes.append(crop_bboxes)
- pred_scores.append(crop_scores)
- region_areas.append(area.expand(len(crop_masks)))
- pred_masks = torch.cat(pred_masks)
- pred_bboxes = torch.cat(pred_bboxes)
- pred_scores = torch.cat(pred_scores)
- region_areas = torch.cat(region_areas)
- # Remove duplicate masks between crops
- if len(crop_regions) > 1:
- scores = 1 / region_areas
- keep = torchvision.ops.nms(pred_bboxes, scores, crop_nms_thresh)
- pred_masks, pred_bboxes, pred_scores = pred_masks[keep], pred_bboxes[keep], pred_scores[keep]
- return pred_masks, pred_scores, pred_bboxes
- def setup_model(self, model=None, verbose=True):
- """
- Initializes the Segment Anything Model (SAM) for inference.
- This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary
- parameters for image normalization and other Ultralytics compatibility settings.
- Args:
- model (torch.nn.Module | None): A pretrained SAM model. If None, a new model is built based on config.
- verbose (bool): If True, prints selected device information.
- Examples:
- >>> predictor = Predictor()
- >>> predictor.setup_model(model=sam_model, verbose=True)
- """
- device = select_device(self.args.device, verbose=verbose)
- if model is None:
- model = self.get_model()
- model.eval()
- self.model = model.to(device)
- self.device = device
- self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)
- self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)
- # Ultralytics compatibility settings
- self.model.pt = False
- self.model.triton = False
- self.model.stride = 32
- self.model.fp16 = False
- self.done_warmup = True
- def get_model(self):
- """Retrieves or builds the Segment Anything Model (SAM) for image segmentation tasks."""
- return build_sam(self.args.model)
- def postprocess(self, preds, img, orig_imgs):
- """
- Post-processes SAM's inference outputs to generate object detection masks and bounding boxes.
- This method scales masks and boxes to the original image size and applies a threshold to the mask
- predictions. It leverages SAM's advanced architecture for real-time, promptable segmentation tasks.
- Args:
- preds (Tuple[torch.Tensor]): The output from SAM model inference, containing:
- - pred_masks (torch.Tensor): Predicted masks with shape (N, 1, H, W).
- - pred_scores (torch.Tensor): Confidence scores for each mask with shape (N, 1).
- - pred_bboxes (torch.Tensor, optional): Predicted bounding boxes if segment_all is True.
- img (torch.Tensor): The processed input image tensor with shape (C, H, W).
- orig_imgs (List[np.ndarray] | torch.Tensor): The original, unprocessed images.
- Returns:
- results (List[Results]): List of Results objects containing detection masks, bounding boxes, and other
- metadata for each processed image.
- Examples:
- >>> predictor = Predictor()
- >>> preds = predictor.inference(img)
- >>> results = predictor.postprocess(preds, img, orig_imgs)
- """
- # (N, 1, H, W), (N, 1)
- pred_masks, pred_scores = preds[:2]
- pred_bboxes = preds[2] if self.segment_all else None
- names = dict(enumerate(str(i) for i in range(len(pred_masks))))
- if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
- orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
- results = []
- for masks, orig_img, img_path in zip([pred_masks], orig_imgs, self.batch[0]):
- if len(masks) == 0:
- masks, pred_bboxes = None, torch.zeros((0, 6), device=pred_masks.device)
- else:
- masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
- masks = masks > self.model.mask_threshold # to bool
- if pred_bboxes is not None:
- pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)
- else:
- pred_bboxes = batched_mask_to_box(masks)
- # NOTE: SAM models do not return cls info. This `cls` here is just a placeholder for consistency.
- cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
- pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)
- results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
- # Reset segment-all mode.
- self.segment_all = False
- return results
- def setup_source(self, source):
- """
- Sets up the data source for inference.
- This method configures the data source from which images will be fetched for inference. It supports
- various input types such as image files, directories, video files, and other compatible data sources.
- Args:
- source (str | Path | None): The path or identifier for the image data source. Can be a file path,
- directory path, URL, or other supported source types.
- Examples:
- >>> predictor = Predictor()
- >>> predictor.setup_source("path/to/images")
- >>> predictor.setup_source("video.mp4")
- >>> predictor.setup_source(None) # Uses default source if available
- Notes:
- - If source is None, the method may use a default source if configured.
- - The method adapts to different source types and prepares them for subsequent inference steps.
- - Supported source types may include local files, directories, URLs, and video streams.
- """
- if source is not None:
- super().setup_source(source)
- def set_image(self, image):
- """
- Preprocesses and sets a single image for inference.
- This method prepares the model for inference on a single image by setting up the model if not already
- initialized, configuring the data source, and preprocessing the image for feature extraction. It
- ensures that only one image is set at a time and extracts image features for subsequent use.
- Args:
- image (str | np.ndarray): Path to the image file as a string, or a numpy array representing
- an image read by cv2.
- Raises:
- AssertionError: If more than one image is attempted to be set.
- Examples:
- >>> predictor = Predictor()
- >>> predictor.set_image("path/to/image.jpg")
- >>> predictor.set_image(cv2.imread("path/to/image.jpg"))
- Notes:
- - This method should be called before performing inference on a new image.
- - The extracted features are stored in the `self.features` attribute for later use.
- """
- if self.model is None:
- self.setup_model(model=None)
- self.setup_source(image)
- assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
- for batch in self.dataset:
- im = self.preprocess(batch[1])
- self.features = self.get_im_features(im)
- break
- def get_im_features(self, im):
- """Extracts image features using the SAM model's image encoder for subsequent mask prediction."""
- assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], (
- f"SAM models only support square image size, but got {self.imgsz}."
- )
- self.model.set_imgsz(self.imgsz)
- return self.model.image_encoder(im)
- def set_prompts(self, prompts):
- """Sets prompts for subsequent inference operations."""
- self.prompts = prompts
- def reset_image(self):
- """Resets the current image and its features, clearing them for subsequent inference."""
- self.im = None
- self.features = None
- @staticmethod
- def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
- """
- Remove small disconnected regions and holes from segmentation masks.
- This function performs post-processing on segmentation masks generated by the Segment Anything Model (SAM).
- It removes small disconnected regions and holes from the input masks, and then performs Non-Maximum
- Suppression (NMS) to eliminate any newly created duplicate boxes.
- Args:
- masks (torch.Tensor): Segmentation masks to be processed, with shape (N, H, W) where N is the number of
- masks, H is height, and W is width.
- min_area (int): Minimum area threshold for removing disconnected regions and holes. Regions smaller than
- this will be removed.
- nms_thresh (float): IoU threshold for the NMS algorithm to remove duplicate boxes.
- Returns:
- new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W).
- keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes.
- Examples:
- >>> masks = torch.rand(5, 640, 640) > 0.5 # 5 random binary masks
- >>> new_masks, keep = remove_small_regions(masks, min_area=100, nms_thresh=0.7)
- >>> print(f"Original masks: {masks.shape}, Processed masks: {new_masks.shape}")
- >>> print(f"Indices of kept masks: {keep}")
- """
- import torchvision # scope for faster 'import ultralytics'
- if len(masks) == 0:
- return masks
- # Filter small disconnected regions and holes
- new_masks = []
- scores = []
- for mask in masks:
- mask = mask.cpu().numpy().astype(np.uint8)
- mask, changed = remove_small_regions(mask, min_area, mode="holes")
- unchanged = not changed
- mask, changed = remove_small_regions(mask, min_area, mode="islands")
- unchanged = unchanged and not changed
- new_masks.append(torch.as_tensor(mask).unsqueeze(0))
- # Give score=0 to changed masks and 1 to unchanged masks so NMS prefers masks not needing postprocessing
- scores.append(float(unchanged))
- # Recalculate boxes and remove any new duplicates
- new_masks = torch.cat(new_masks, dim=0)
- boxes = batched_mask_to_box(new_masks)
- keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh)
- return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep
- class SAM2Predictor(Predictor):
- """
- SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture.
- This class extends the base Predictor class to implement SAM2-specific functionality for image
- segmentation tasks. It provides methods for model initialization, feature extraction, and
- prompt-based inference.
- Attributes:
- _bb_feat_sizes (List[Tuple[int, int]]): Feature sizes for different backbone levels.
- model (torch.nn.Module): The loaded SAM2 model.
- device (torch.device): The device (CPU or GPU) on which the model is loaded.
- features (Dict[str, torch.Tensor]): Cached image features for efficient inference.
- segment_all (bool): Flag to indicate if all segments should be predicted.
- prompts (Dict): Dictionary to store various types of prompts for inference.
- Methods:
- get_model: Retrieves and initializes the SAM2 model.
- prompt_inference: Performs image segmentation inference based on various prompts.
- set_image: Preprocesses and sets a single image for inference.
- get_im_features: Extracts and processes image features using SAM2's image encoder.
- Examples:
- >>> predictor = SAM2Predictor(cfg)
- >>> predictor.set_image("path/to/image.jpg")
- >>> bboxes = [[100, 100, 200, 200]]
- >>> result = predictor(bboxes=bboxes)[0]
- >>> print(f"Predicted {len(result.masks)} masks with average score {result.boxes.conf.mean():.2f}")
- """
- _bb_feat_sizes = [
- (256, 256),
- (128, 128),
- (64, 64),
- ]
- def get_model(self):
- """Retrieves and initializes the Segment Anything Model 2 (SAM2) for image segmentation tasks."""
- return build_sam(self.args.model)
- def prompt_inference(
- self,
- im,
- bboxes=None,
- points=None,
- labels=None,
- masks=None,
- multimask_output=False,
- img_idx=-1,
- ):
- """
- Performs image segmentation inference based on various prompts using SAM2 architecture.
- This method leverages the Segment Anything Model 2 (SAM2) to generate segmentation masks for input images
- based on provided prompts such as bounding boxes, points, or existing masks. It supports both single and
- multi-object prediction scenarios.
- Args:
- im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
- bboxes (np.ndarray | List[List[float]] | None): Bounding boxes in XYXY format with shape (N, 4).
- points (np.ndarray | List[List[float]] | None): Object location points with shape (N, 2), in pixels.
- labels (np.ndarray | List[int] | None): Point prompt labels with shape (N,). 1 = foreground, 0 = background.
- masks (np.ndarray | None): Low-resolution masks from previous predictions with shape (N, H, W).
- multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
- img_idx (int): Index of the image in the batch to process.
- Returns:
- (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
- (np.ndarray): Quality scores for each mask, with length C.
- Examples:
- >>> predictor = SAM2Predictor(cfg)
- >>> image = torch.rand(1, 3, 640, 640)
- >>> bboxes = [[100, 100, 200, 200]]
- >>> result = predictor(image, bboxes=bboxes)[0]
- >>> print(f"Generated {result.masks.shape[0]} masks with average score {result.boxes.conf.mean():.2f}")
- Notes:
- - The method supports batched inference for multiple objects when points or bboxes are provided.
- - Input prompts (bboxes, points) are automatically scaled to match the input image dimensions.
- - When both bboxes and points are provided, they are merged into a single 'points' input for the model.
- References:
- - SAM2 Paper: [Add link to SAM2 paper when available]
- """
- features = self.get_im_features(im) if self.features is None else self.features
- points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
- points = (points, labels) if points is not None else None
- sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
- points=points,
- boxes=None,
- masks=masks,
- )
- # Predict masks
- batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction
- high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]]
- pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder(
- image_embeddings=features["image_embed"][img_idx].unsqueeze(0),
- image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
- sparse_prompt_embeddings=sparse_embeddings,
- dense_prompt_embeddings=dense_embeddings,
- multimask_output=multimask_output,
- repeat_image=batched_mode,
- high_res_features=high_res_features,
- )
- # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
- # `d` could be 1 or 3 depends on `multimask_output`.
- return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
- def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
- """
- Prepares and transforms the input prompts for processing based on the destination shape.
- Args:
- dst_shape (tuple): The target shape (height, width) for the prompts.
- bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
- points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
- labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background.
- masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array.
- Raises:
- AssertionError: If the number of points don't match the number of labels, in case labels were passed.
- Returns:
- (tuple): A tuple containing transformed points, labels, and masks.
- """
- bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks)
- if bboxes is not None:
- bboxes = bboxes.view(-1, 2, 2)
- bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
- # NOTE: merge "boxes" and "points" into a single "points" input
- # (where boxes are added at the beginning) to model.sam_prompt_encoder
- if points is not None:
- points = torch.cat([bboxes, points], dim=1)
- labels = torch.cat([bbox_labels, labels], dim=1)
- else:
- points, labels = bboxes, bbox_labels
- return points, labels, masks
- def set_image(self, image):
- """
- Preprocesses and sets a single image for inference using the SAM2 model.
- This method initializes the model if not already done, configures the data source to the specified image,
- and preprocesses the image for feature extraction. It supports setting only one image at a time.
- Args:
- image (str | np.ndarray): Path to the image file as a string, or a numpy array representing the image.
- Raises:
- AssertionError: If more than one image is attempted to be set.
- Examples:
- >>> predictor = SAM2Predictor()
- >>> predictor.set_image("path/to/image.jpg")
- >>> predictor.set_image(np.array([...])) # Using a numpy array
- Notes:
- - This method must be called before performing any inference on a new image.
- - The method caches the extracted features for efficient subsequent inferences on the same image.
- - Only one image can be set at a time. To process multiple images, call this method for each new image.
- """
- if self.model is None:
- self.setup_model(model=None)
- self.setup_source(image)
- assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
- for batch in self.dataset:
- im = self.preprocess(batch[1])
- self.features = self.get_im_features(im)
- break
- def get_im_features(self, im):
- """Extracts image features from the SAM image encoder for subsequent processing."""
- assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], (
- f"SAM 2 models only support square image size, but got {self.imgsz}."
- )
- self.model.set_imgsz(self.imgsz)
- self._bb_feat_sizes = [[x // (4 * i) for x in self.imgsz] for i in [1, 2, 4]]
- backbone_out = self.model.forward_image(im)
- _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
- if self.model.directly_add_no_mem_embed:
- vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
- feats = [
- feat.permute(1, 2, 0).view(1, -1, *feat_size)
- for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
- ][::-1]
- return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
- class SAM2VideoPredictor(SAM2Predictor):
- """
- SAM2VideoPredictor to handle user interactions with videos and manage inference states.
- This class extends the functionality of SAM2Predictor to support video processing and maintains
- the state of inference operations. It includes configurations for managing non-overlapping masks,
- clearing memory for non-conditional inputs, and setting up callbacks for prediction events.
- Attributes:
- inference_state (Dict): A dictionary to store the current state of inference operations.
- non_overlap_masks (bool): A flag indicating whether masks should be non-overlapping.
- clear_non_cond_mem_around_input (bool): A flag to control clearing non-conditional memory around inputs.
- clear_non_cond_mem_for_multi_obj (bool): A flag to control clearing non-conditional memory for multi-object scenarios.
- callbacks (Dict): A dictionary of callbacks for various prediction lifecycle events.
- Args:
- cfg (Dict, Optional): Configuration settings for the predictor. Defaults to DEFAULT_CFG.
- overrides (Dict, Optional): Additional configuration overrides. Defaults to None.
- _callbacks (List, Optional): Custom callbacks to be added. Defaults to None.
- Note:
- The `fill_hole_area` attribute is defined but not used in the current implementation.
- """
- # fill_hole_area = 8 # not used
- def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
- """
- Initialize the predictor with configuration and optional overrides.
- This constructor initializes the SAM2VideoPredictor with a given configuration, applies any
- specified overrides, and sets up the inference state along with certain flags
- that control the behavior of the predictor.
- Args:
- cfg (Dict): Configuration dictionary containing default settings.
- overrides (Dict | None): Dictionary of values to override default configuration.
- _callbacks (Dict | None): Dictionary of callback functions to customize behavior.
- Examples:
- >>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG)
- >>> predictor_example_with_imgsz = SAM2VideoPredictor(overrides={"imgsz": 640})
- >>> predictor_example_with_callback = SAM2VideoPredictor(_callbacks={"on_predict_start": custom_callback})
- """
- super().__init__(cfg, overrides, _callbacks)
- self.inference_state = {}
- self.non_overlap_masks = True
- self.clear_non_cond_mem_around_input = False
- self.clear_non_cond_mem_for_multi_obj = False
- self.callbacks["on_predict_start"].append(self.init_state)
- def get_model(self):
- """
- Retrieves and configures the model with binarization enabled.
- Note:
- This method overrides the base class implementation to set the binarize flag to True.
- """
- model = super().get_model()
- model.set_binarize(True)
- return model
- def inference(self, im, bboxes=None, points=None, labels=None, masks=None):
- """
- Perform image segmentation inference based on the given input cues, using the currently loaded image. This
- method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and
- mask decoder for real-time and promptable segmentation tasks.
- Args:
- im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
- bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
- points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
- labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
- masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.
- Returns:
- (np.ndarray): The output masks in shape CxHxW, where C is the number of generated masks.
- (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
- """
- # Override prompts if any stored in self.prompts
- bboxes = self.prompts.pop("bboxes", bboxes)
- points = self.prompts.pop("points", points)
- masks = self.prompts.pop("masks", masks)
- frame = self.dataset.frame
- self.inference_state["im"] = im
- output_dict = self.inference_state["output_dict"]
- if len(output_dict["cond_frame_outputs"]) == 0: # initialize prompts
- points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
- if points is not None:
- for i in range(len(points)):
- self.add_new_prompts(obj_id=i, points=points[[i]], labels=labels[[i]], frame_idx=frame)
- elif masks is not None:
- for i in range(len(masks)):
- self.add_new_prompts(obj_id=i, masks=masks[[i]], frame_idx=frame)
- self.propagate_in_video_preflight()
- consolidated_frame_inds = self.inference_state["consolidated_frame_inds"]
- batch_size = len(self.inference_state["obj_idx_to_id"])
- if len(output_dict["cond_frame_outputs"]) == 0:
- raise RuntimeError("No points are provided; please add points first")
- if frame in consolidated_frame_inds["cond_frame_outputs"]:
- storage_key = "cond_frame_outputs"
- current_out = output_dict[storage_key][frame]
- if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1):
- # clear non-conditioning memory of the surrounding frames
- self._clear_non_cond_mem_around_input(frame)
- elif frame in consolidated_frame_inds["non_cond_frame_outputs"]:
- storage_key = "non_cond_frame_outputs"
- current_out = output_dict[storage_key][frame]
- else:
- storage_key = "non_cond_frame_outputs"
- current_out = self._run_single_frame_inference(
- output_dict=output_dict,
- frame_idx=frame,
- batch_size=batch_size,
- is_init_cond_frame=False,
- point_inputs=None,
- mask_inputs=None,
- reverse=False,
- run_mem_encoder=True,
- )
- output_dict[storage_key][frame] = current_out
- # Create slices of per-object outputs for subsequent interaction with each
- # individual object after tracking.
- self._add_output_per_object(frame, current_out, storage_key)
- self.inference_state["frames_already_tracked"].append(frame)
- pred_masks = current_out["pred_masks"].flatten(0, 1)
- pred_masks = pred_masks[(pred_masks > self.model.mask_threshold).sum((1, 2)) > 0] # filter blank masks
- return pred_masks, torch.ones(len(pred_masks), dtype=pred_masks.dtype, device=pred_masks.device)
- def postprocess(self, preds, img, orig_imgs):
- """
- Post-processes the predictions to apply non-overlapping constraints if required.
- This method extends the post-processing functionality by applying non-overlapping constraints
- to the predicted masks if the `non_overlap_masks` flag is set to True. This ensures that
- the masks do not overlap, which can be useful for certain applications.
- Args:
- preds (Tuple[torch.Tensor]): The predictions from the model.
- img (torch.Tensor): The processed image tensor.
- orig_imgs (List[np.ndarray]): The original images before processing.
- Returns:
- results (list): The post-processed predictions.
- Note:
- If `non_overlap_masks` is True, the method applies constraints to ensure non-overlapping masks.
- """
- results = super().postprocess(preds, img, orig_imgs)
- if self.non_overlap_masks:
- for result in results:
- if result.masks is None or len(result.masks) == 0:
- continue
- result.masks.data = self.model._apply_non_overlapping_constraints(result.masks.data.unsqueeze(0))[0]
- return results
- @smart_inference_mode()
- def add_new_prompts(
- self,
- obj_id,
- points=None,
- labels=None,
- masks=None,
- frame_idx=0,
- ):
- """
- Adds new points or masks to a specific frame for a given object ID.
- This method updates the inference state with new prompts (points or masks) for a specified
- object and frame index. It ensures that the prompts are either points or masks, but not both,
- and updates the internal state accordingly. It also handles the generation of new segmentations
- based on the provided prompts and the existing state.
- Args:
- obj_id (int): The ID of the object to which the prompts are associated.
- points (torch.Tensor, Optional): The coordinates of the points of interest. Defaults to None.
- labels (torch.Tensor, Optional): The labels corresponding to the points. Defaults to None.
- masks (torch.Tensor, optional): Binary masks for the object. Defaults to None.
- frame_idx (int, optional): The index of the frame to which the prompts are applied. Defaults to 0.
- Returns:
- (tuple): A tuple containing the flattened predicted masks and a tensor of ones indicating the number of objects.
- Raises:
- AssertionError: If both `masks` and `points` are provided, or neither is provided.
- Note:
- - Only one type of prompt (either points or masks) can be added per call.
- - If the frame is being tracked for the first time, it is treated as an initial conditioning frame.
- - The method handles the consolidation of outputs and resizing of masks to the original video resolution.
- """
- assert (masks is None) ^ (points is None), "'masks' and 'points' prompts are not compatible with each other."
- obj_idx = self._obj_id_to_idx(obj_id)
- point_inputs = None
- pop_key = "point_inputs_per_obj"
- if points is not None:
- point_inputs = {"point_coords": points, "point_labels": labels}
- self.inference_state["point_inputs_per_obj"][obj_idx][frame_idx] = point_inputs
- pop_key = "mask_inputs_per_obj"
- self.inference_state["mask_inputs_per_obj"][obj_idx][frame_idx] = masks
- self.inference_state[pop_key][obj_idx].pop(frame_idx, None)
- # If this frame hasn't been tracked before, we treat it as an initial conditioning
- # frame, meaning that the inputs points are to generate segments on this frame without
- # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
- # the input points will be used to correct the already tracked masks.
- is_init_cond_frame = frame_idx not in self.inference_state["frames_already_tracked"]
- obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx]
- obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx]
- # Add a frame to conditioning output if it's an initial conditioning frame or
- # if the model sees all frames receiving clicks/mask as conditioning frames.
- is_cond = is_init_cond_frame or self.model.add_all_frames_to_correct_as_cond
- storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
- # Get any previously predicted mask logits on this object and feed it along with
- # the new clicks into the SAM mask decoder.
- prev_sam_mask_logits = None
- # lookup temporary output dict first, which contains the most recent output
- # (if not found, then lookup conditioning and non-conditioning frame output)
- if point_inputs is not None:
- prev_out = (
- obj_temp_output_dict[storage_key].get(frame_idx)
- or obj_output_dict["cond_frame_outputs"].get(frame_idx)
- or obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
- )
- if prev_out is not None and prev_out.get("pred_masks") is not None:
- prev_sam_mask_logits = prev_out["pred_masks"].to(device=self.device, non_blocking=True)
- # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
- prev_sam_mask_logits.clamp_(-32.0, 32.0)
- current_out = self._run_single_frame_inference(
- output_dict=obj_output_dict, # run on the slice of a single object
- frame_idx=frame_idx,
- batch_size=1, # run on the slice of a single object
- is_init_cond_frame=is_init_cond_frame,
- point_inputs=point_inputs,
- mask_inputs=masks,
- reverse=False,
- # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
- # at the beginning of `propagate_in_video` (after user finalize their clicks). This
- # allows us to enforce non-overlapping constraints on all objects before encoding
- # them into memory.
- run_mem_encoder=False,
- prev_sam_mask_logits=prev_sam_mask_logits,
- )
- # Add the output to the output dict (to be used as future memory)
- obj_temp_output_dict[storage_key][frame_idx] = current_out
- # Resize the output mask to the original video resolution
- consolidated_out = self._consolidate_temp_output_across_obj(
- frame_idx,
- is_cond=is_cond,
- run_mem_encoder=False,
- )
- pred_masks = consolidated_out["pred_masks"].flatten(0, 1)
- return pred_masks.flatten(0, 1), torch.ones(1, dtype=pred_masks.dtype, device=pred_masks.device)
- @smart_inference_mode()
- def propagate_in_video_preflight(self):
- """
- Prepare inference_state and consolidate temporary outputs before tracking.
- This method marks the start of tracking, disallowing the addition of new objects until the session is reset.
- It consolidates temporary outputs from `temp_output_dict_per_obj` and merges them into `output_dict`.
- Additionally, it clears non-conditioning memory around input frames and ensures that the state is consistent
- with the provided inputs.
- """
- # Tracking has started and we don't allow adding new objects until session is reset.
- self.inference_state["tracking_has_started"] = True
- batch_size = len(self.inference_state["obj_idx_to_id"])
- # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
- # add them into "output_dict".
- temp_output_dict_per_obj = self.inference_state["temp_output_dict_per_obj"]
- output_dict = self.inference_state["output_dict"]
- # "consolidated_frame_inds" contains indices of those frames where consolidated
- # temporary outputs have been added (either in this call or any previous calls
- # to `propagate_in_video_preflight`).
- consolidated_frame_inds = self.inference_state["consolidated_frame_inds"]
- for is_cond in {False, True}:
- # Separately consolidate conditioning and non-conditioning temp outputs
- storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
- # Find all the frames that contain temporary outputs for any objects
- # (these should be the frames that have just received clicks for mask inputs
- # via `add_new_points` or `add_new_mask`)
- temp_frame_inds = set()
- for obj_temp_output_dict in temp_output_dict_per_obj.values():
- temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
- consolidated_frame_inds[storage_key].update(temp_frame_inds)
- # consolidate the temporary output across all objects on this frame
- for frame_idx in temp_frame_inds:
- consolidated_out = self._consolidate_temp_output_across_obj(
- frame_idx, is_cond=is_cond, run_mem_encoder=True
- )
- # merge them into "output_dict" and also create per-object slices
- output_dict[storage_key][frame_idx] = consolidated_out
- self._add_output_per_object(frame_idx, consolidated_out, storage_key)
- if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1):
- # clear non-conditioning memory of the surrounding frames
- self._clear_non_cond_mem_around_input(frame_idx)
- # clear temporary outputs in `temp_output_dict_per_obj`
- for obj_temp_output_dict in temp_output_dict_per_obj.values():
- obj_temp_output_dict[storage_key].clear()
- # edge case: if an output is added to "cond_frame_outputs", we remove any prior
- # output on the same frame in "non_cond_frame_outputs"
- for frame_idx in output_dict["cond_frame_outputs"]:
- output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
- for obj_output_dict in self.inference_state["output_dict_per_obj"].values():
- for frame_idx in obj_output_dict["cond_frame_outputs"]:
- obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
- for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
- assert frame_idx in output_dict["cond_frame_outputs"]
- consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
- # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
- # with either points or mask inputs (which should be true under a correct workflow).
- all_consolidated_frame_inds = (
- consolidated_frame_inds["cond_frame_outputs"] | consolidated_frame_inds["non_cond_frame_outputs"]
- )
- input_frames_inds = set()
- for point_inputs_per_frame in self.inference_state["point_inputs_per_obj"].values():
- input_frames_inds.update(point_inputs_per_frame.keys())
- for mask_inputs_per_frame in self.inference_state["mask_inputs_per_obj"].values():
- input_frames_inds.update(mask_inputs_per_frame.keys())
- assert all_consolidated_frame_inds == input_frames_inds
- @staticmethod
- def init_state(predictor):
- """
- Initialize an inference state for the predictor.
- This function sets up the initial state required for performing inference on video data.
- It includes initializing various dictionaries and ordered dictionaries that will store
- inputs, outputs, and other metadata relevant to the tracking process.
- Args:
- predictor (SAM2VideoPredictor): The predictor object for which to initialize the state.
- """
- if len(predictor.inference_state) > 0: # means initialized
- return
- assert predictor.dataset is not None
- assert predictor.dataset.mode == "video"
- inference_state = {
- "num_frames": predictor.dataset.frames,
- "point_inputs_per_obj": {}, # inputs points on each frame
- "mask_inputs_per_obj": {}, # inputs mask on each frame
- "constants": {}, # values that don't change across frames (so we only need to hold one copy of them)
- # mapping between client-side object id and model-side object index
- "obj_id_to_idx": OrderedDict(),
- "obj_idx_to_id": OrderedDict(),
- "obj_ids": [],
- # A storage to hold the model's tracking results and states on each frame
- "output_dict": {
- "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
- "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
- },
- # Slice (view) of each object tracking results, sharing the same memory with "output_dict"
- "output_dict_per_obj": {},
- # A temporary storage to hold new outputs when user interact with a frame
- # to add clicks or mask (it's merged into "output_dict" before propagation starts)
- "temp_output_dict_per_obj": {},
- # Frames that already holds consolidated outputs from click or mask inputs
- # (we directly use their consolidated outputs during tracking)
- "consolidated_frame_inds": {
- "cond_frame_outputs": set(), # set containing frame indices
- "non_cond_frame_outputs": set(), # set containing frame indices
- },
- # metadata for each tracking frame (e.g. which direction it's tracked)
- "tracking_has_started": False,
- "frames_already_tracked": [],
- }
- predictor.inference_state = inference_state
- def get_im_features(self, im, batch=1):
- """
- Extracts and processes image features using SAM2's image encoder for subsequent segmentation tasks.
- Args:
- im (torch.Tensor): The input image tensor.
- batch (int, optional): The batch size for expanding features if there are multiple prompts. Defaults to 1.
- Returns:
- vis_feats (torch.Tensor): The visual features extracted from the image.
- vis_pos_embed (torch.Tensor): The positional embeddings for the visual features.
- feat_sizes (List(Tuple[int])): A list containing the sizes of the extracted features.
- Note:
- - If `batch` is greater than 1, the features are expanded to fit the batch size.
- - The method leverages the model's `_prepare_backbone_features` method to prepare the backbone features.
- """
- backbone_out = self.model.forward_image(im)
- if batch > 1: # expand features if there's more than one prompt
- for i, feat in enumerate(backbone_out["backbone_fpn"]):
- backbone_out["backbone_fpn"][i] = feat.expand(batch, -1, -1, -1)
- for i, pos in enumerate(backbone_out["vision_pos_enc"]):
- pos = pos.expand(batch, -1, -1, -1)
- backbone_out["vision_pos_enc"][i] = pos
- _, vis_feats, vis_pos_embed, feat_sizes = self.model._prepare_backbone_features(backbone_out)
- return vis_feats, vis_pos_embed, feat_sizes
- def _obj_id_to_idx(self, obj_id):
- """
- Map client-side object id to model-side object index.
- Args:
- obj_id (int): The unique identifier of the object provided by the client side.
- Returns:
- obj_idx (int): The index of the object on the model side.
- Raises:
- RuntimeError: If an attempt is made to add a new object after tracking has started.
- Note:
- - The method updates or retrieves mappings between object IDs and indices stored in
- `inference_state`.
- - It ensures that new objects can only be added before tracking commences.
- - It maintains two-way mappings between IDs and indices (`obj_id_to_idx` and `obj_idx_to_id`).
- - Additional data structures are initialized for the new object to store inputs and outputs.
- """
- obj_idx = self.inference_state["obj_id_to_idx"].get(obj_id, None)
- if obj_idx is not None:
- return obj_idx
- # This is a new object id not sent to the server before. We only allow adding
- # new objects *before* the tracking starts.
- allow_new_object = not self.inference_state["tracking_has_started"]
- if allow_new_object:
- # get the next object slot
- obj_idx = len(self.inference_state["obj_id_to_idx"])
- self.inference_state["obj_id_to_idx"][obj_id] = obj_idx
- self.inference_state["obj_idx_to_id"][obj_idx] = obj_id
- self.inference_state["obj_ids"] = list(self.inference_state["obj_id_to_idx"])
- # set up input and output structures for this object
- self.inference_state["point_inputs_per_obj"][obj_idx] = {}
- self.inference_state["mask_inputs_per_obj"][obj_idx] = {}
- self.inference_state["output_dict_per_obj"][obj_idx] = {
- "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
- "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
- }
- self.inference_state["temp_output_dict_per_obj"][obj_idx] = {
- "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
- "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
- }
- return obj_idx
- else:
- raise RuntimeError(
- f"Cannot add new object id {obj_id} after tracking starts. "
- f"All existing object ids: {self.inference_state['obj_ids']}. "
- f"Please call 'reset_state' to restart from scratch."
- )
- def _run_single_frame_inference(
- self,
- output_dict,
- frame_idx,
- batch_size,
- is_init_cond_frame,
- point_inputs,
- mask_inputs,
- reverse,
- run_mem_encoder,
- prev_sam_mask_logits=None,
- ):
- """
- Run tracking on a single frame based on current inputs and previous memory.
- Args:
- output_dict (Dict): The dictionary containing the output states of the tracking process.
- frame_idx (int): The index of the current frame.
- batch_size (int): The batch size for processing the frame.
- is_init_cond_frame (bool): Indicates if the current frame is an initial conditioning frame.
- point_inputs (Dict, Optional): Input points and their labels. Defaults to None.
- mask_inputs (torch.Tensor, Optional): Input binary masks. Defaults to None.
- reverse (bool): Indicates if the tracking should be performed in reverse order.
- run_mem_encoder (bool): Indicates if the memory encoder should be executed.
- prev_sam_mask_logits (torch.Tensor, Optional): Previous mask logits for the current object. Defaults to None.
- Returns:
- current_out (dict): A dictionary containing the output of the tracking step, including updated features and predictions.
- Raises:
- AssertionError: If both `point_inputs` and `mask_inputs` are provided, or neither is provided.
- Note:
- - The method assumes that `point_inputs` and `mask_inputs` are mutually exclusive.
- - The method retrieves image features using the `get_im_features` method.
- - The `maskmem_pos_enc` is assumed to be constant across frames, hence only one copy is stored.
- - The `fill_holes_in_mask_scores` function is commented out and currently unsupported due to CUDA extension requirements.
- """
- # Retrieve correct image features
- current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(
- self.inference_state["im"], batch_size
- )
- # point and mask should not appear as input simultaneously on the same frame
- assert point_inputs is None or mask_inputs is None
- current_out = self.model.track_step(
- frame_idx=frame_idx,
- is_init_cond_frame=is_init_cond_frame,
- current_vision_feats=current_vision_feats,
- current_vision_pos_embeds=current_vision_pos_embeds,
- feat_sizes=feat_sizes,
- point_inputs=point_inputs,
- mask_inputs=mask_inputs,
- output_dict=output_dict,
- num_frames=self.inference_state["num_frames"],
- track_in_reverse=reverse,
- run_mem_encoder=run_mem_encoder,
- prev_sam_mask_logits=prev_sam_mask_logits,
- )
- maskmem_features = current_out["maskmem_features"]
- if maskmem_features is not None:
- current_out["maskmem_features"] = maskmem_features.to(
- dtype=torch.float16, device=self.device, non_blocking=True
- )
- # NOTE: Do not support the `fill_holes_in_mask_scores` function since it needs cuda extensions
- # potentially fill holes in the predicted masks
- # if self.fill_hole_area > 0:
- # pred_masks = current_out["pred_masks"].to(self.device, non_blocking=True)
- # pred_masks = fill_holes_in_mask_scores(pred_masks, self.fill_hole_area)
- # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
- current_out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(current_out["maskmem_pos_enc"])
- return current_out
- def _get_maskmem_pos_enc(self, out_maskmem_pos_enc):
- """
- Caches and manages the positional encoding for mask memory across frames and objects.
- This method optimizes storage by caching the positional encoding (`maskmem_pos_enc`) for
- mask memory, which is constant across frames and objects, thus reducing the amount of
- redundant information stored during an inference session. It checks if the positional
- encoding has already been cached; if not, it caches a slice of the provided encoding.
- If the batch size is greater than one, it expands the cached positional encoding to match
- the current batch size.
- Args:
- out_maskmem_pos_enc (List[torch.Tensor] or None): The positional encoding for mask memory.
- Should be a list of tensors or None.
- Returns:
- out_maskmem_pos_enc (List[torch.Tensor]): The positional encoding for mask memory, either cached or expanded.
- Note:
- - The method assumes that `out_maskmem_pos_enc` is a list of tensors or None.
- - Only a single object's slice is cached since the encoding is the same across objects.
- - The method checks if the positional encoding has already been cached in the session's constants.
- - If the batch size is greater than one, the cached encoding is expanded to fit the batch size.
- """
- model_constants = self.inference_state["constants"]
- # "out_maskmem_pos_enc" should be either a list of tensors or None
- if out_maskmem_pos_enc is not None:
- if "maskmem_pos_enc" not in model_constants:
- assert isinstance(out_maskmem_pos_enc, list)
- # only take the slice for one object, since it's same across objects
- maskmem_pos_enc = [x[:1].clone() for x in out_maskmem_pos_enc]
- model_constants["maskmem_pos_enc"] = maskmem_pos_enc
- else:
- maskmem_pos_enc = model_constants["maskmem_pos_enc"]
- # expand the cached maskmem_pos_enc to the actual batch size
- batch_size = out_maskmem_pos_enc[0].size(0)
- if batch_size > 1:
- out_maskmem_pos_enc = [x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc]
- return out_maskmem_pos_enc
- def _consolidate_temp_output_across_obj(
- self,
- frame_idx,
- is_cond=False,
- run_mem_encoder=False,
- ):
- """
- Consolidates per-object temporary outputs into a single output for all objects.
- This method combines the temporary outputs for each object on a given frame into a unified
- output. It fills in any missing objects either from the main output dictionary or leaves
- placeholders if they do not exist in the main output. Optionally, it can re-run the memory
- encoder after applying non-overlapping constraints to the object scores.
- Args:
- frame_idx (int): The index of the frame for which to consolidate outputs.
- is_cond (bool, Optional): Indicates if the frame is considered a conditioning frame.
- Defaults to False.
- run_mem_encoder (bool, Optional): Specifies whether to run the memory encoder after
- consolidating the outputs. Defaults to False.
- Returns:
- consolidated_out (dict): A consolidated output dictionary containing the combined results for all objects.
- Note:
- - The method initializes the consolidated output with placeholder values for missing objects.
- - It searches for outputs in both the temporary and main output dictionaries.
- - If `run_mem_encoder` is True, it applies non-overlapping constraints and re-runs the memory encoder.
- - The `maskmem_features` and `maskmem_pos_enc` are only populated when `run_mem_encoder` is True.
- """
- batch_size = len(self.inference_state["obj_idx_to_id"])
- storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
- # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
- # will be added when rerunning the memory encoder after applying non-overlapping
- # constraints to object scores. Its "pred_masks" are prefilled with a large
- # negative value (NO_OBJ_SCORE) to represent missing objects.
- consolidated_out = {
- "maskmem_features": None,
- "maskmem_pos_enc": None,
- "pred_masks": torch.full(
- size=(batch_size, 1, self.imgsz[0] // 4, self.imgsz[1] // 4),
- fill_value=-1024.0,
- dtype=torch.float32,
- device=self.device,
- ),
- "obj_ptr": torch.full(
- size=(batch_size, self.model.hidden_dim),
- fill_value=-1024.0,
- dtype=torch.float32,
- device=self.device,
- ),
- "object_score_logits": torch.full(
- size=(batch_size, 1),
- # default to 10.0 for object_score_logits, i.e. assuming the object is
- # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
- fill_value=10.0,
- dtype=torch.float32,
- device=self.device,
- ),
- }
- for obj_idx in range(batch_size):
- obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx]
- obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx]
- out = (
- obj_temp_output_dict[storage_key].get(frame_idx)
- # If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
- # we fall back and look up its previous output in "output_dict_per_obj".
- # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
- # "output_dict_per_obj" to find a previous output for this object.
- or obj_output_dict["cond_frame_outputs"].get(frame_idx)
- or obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
- )
- # If the object doesn't appear in "output_dict_per_obj" either, we skip it
- # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
- # placeholder above) and set its object pointer to be a dummy pointer.
- if out is None:
- # Fill in dummy object pointers for those objects without any inputs or
- # tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
- # i.e. when we need to build the memory for tracking).
- if run_mem_encoder:
- # fill object pointer with a dummy pointer (based on an empty mask)
- consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = self._get_empty_mask_ptr(frame_idx)
- continue
- # Add the temporary object output mask to consolidated output mask
- consolidated_out["pred_masks"][obj_idx : obj_idx + 1] = out["pred_masks"]
- consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
- # Optionally, apply non-overlapping constraints on the consolidated scores and rerun the memory encoder
- if run_mem_encoder:
- high_res_masks = F.interpolate(
- consolidated_out["pred_masks"],
- size=self.imgsz,
- mode="bilinear",
- align_corners=False,
- )
- if self.model.non_overlap_masks_for_mem_enc:
- high_res_masks = self.model._apply_non_overlapping_constraints(high_res_masks)
- consolidated_out["maskmem_features"], consolidated_out["maskmem_pos_enc"] = self._run_memory_encoder(
- batch_size=batch_size,
- high_res_masks=high_res_masks,
- is_mask_from_pts=True, # these frames are what the user interacted with
- object_score_logits=consolidated_out["object_score_logits"],
- )
- return consolidated_out
- def _get_empty_mask_ptr(self, frame_idx):
- """
- Get a dummy object pointer based on an empty mask on the current frame.
- Args:
- frame_idx (int): The index of the current frame for which to generate the dummy object pointer.
- Returns:
- (torch.Tensor): A tensor representing the dummy object pointer generated from the empty mask.
- """
- # Retrieve correct image features
- current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(self.inference_state["im"])
- # Feed the empty mask and image feature above to get a dummy object pointer
- current_out = self.model.track_step(
- frame_idx=frame_idx,
- is_init_cond_frame=True,
- current_vision_feats=current_vision_feats,
- current_vision_pos_embeds=current_vision_pos_embeds,
- feat_sizes=feat_sizes,
- point_inputs=None,
- # A dummy (empty) mask with a single object
- mask_inputs=torch.zeros((1, 1, *self.imgsz), dtype=torch.float32, device=self.device),
- output_dict={},
- num_frames=self.inference_state["num_frames"],
- track_in_reverse=False,
- run_mem_encoder=False,
- prev_sam_mask_logits=None,
- )
- return current_out["obj_ptr"]
- def _run_memory_encoder(self, batch_size, high_res_masks, object_score_logits, is_mask_from_pts):
- """
- Run the memory encoder on masks.
- This is usually after applying non-overlapping constraints to object scores. Since their scores changed, their
- memory also needs to be computed again with the memory encoder.
- Args:
- batch_size (int): The batch size for processing the frame.
- high_res_masks (torch.Tensor): High-resolution masks for which to compute the memory.
- object_score_logits (torch.Tensor): Logits representing the object scores.
- is_mask_from_pts (bool): Indicates if the mask is derived from point interactions.
- Returns:
- (tuple[torch.Tensor, torch.Tensor]): A tuple containing the encoded mask features and positional encoding.
- """
- # Retrieve correct image features
- current_vision_feats, _, feat_sizes = self.get_im_features(self.inference_state["im"], batch_size)
- maskmem_features, maskmem_pos_enc = self.model._encode_new_memory(
- current_vision_feats=current_vision_feats,
- feat_sizes=feat_sizes,
- pred_masks_high_res=high_res_masks,
- is_mask_from_pts=is_mask_from_pts,
- object_score_logits=object_score_logits,
- )
- # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
- maskmem_pos_enc = self._get_maskmem_pos_enc(maskmem_pos_enc)
- return maskmem_features.to(dtype=torch.float16, device=self.device, non_blocking=True), maskmem_pos_enc
- def _add_output_per_object(self, frame_idx, current_out, storage_key):
- """
- Split a multi-object output into per-object output slices and add them into Output_Dict_Per_Obj.
- The resulting slices share the same tensor storage.
- Args:
- frame_idx (int): The index of the current frame.
- current_out (Dict): The current output dictionary containing multi-object outputs.
- storage_key (str): The key used to store the output in the per-object output dictionary.
- """
- maskmem_features = current_out["maskmem_features"]
- assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
- maskmem_pos_enc = current_out["maskmem_pos_enc"]
- assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
- for obj_idx, obj_output_dict in self.inference_state["output_dict_per_obj"].items():
- obj_slice = slice(obj_idx, obj_idx + 1)
- obj_out = {
- "maskmem_features": None,
- "maskmem_pos_enc": None,
- "pred_masks": current_out["pred_masks"][obj_slice],
- "obj_ptr": current_out["obj_ptr"][obj_slice],
- }
- if maskmem_features is not None:
- obj_out["maskmem_features"] = maskmem_features[obj_slice]
- if maskmem_pos_enc is not None:
- obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
- obj_output_dict[storage_key][frame_idx] = obj_out
- def _clear_non_cond_mem_around_input(self, frame_idx):
- """
- Remove the non-conditioning memory around the input frame.
- When users provide correction clicks, the surrounding frames' non-conditioning memories can still contain outdated
- object appearance information and could confuse the model. This method clears those non-conditioning memories
- surrounding the interacted frame to avoid giving the model both old and new information about the object.
- Args:
- frame_idx (int): The index of the current frame where user interaction occurred.
- """
- r = self.model.memory_temporal_stride_for_eval
- frame_idx_begin = frame_idx - r * self.model.num_maskmem
- frame_idx_end = frame_idx + r * self.model.num_maskmem
- for t in range(frame_idx_begin, frame_idx_end + 1):
- self.inference_state["output_dict"]["non_cond_frame_outputs"].pop(t, None)
- for obj_output_dict in self.inference_state["output_dict_per_obj"].values():
- obj_output_dict["non_cond_frame_outputs"].pop(t, None)
|