@@ -0,0 +1,9 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from .fastsam import FastSAM
+from .nas import NAS
+from .rtdetr import RTDETR
+from .sam import SAM
+from .yolo import YOLO, YOLOWorld
+__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld" # allow simpler import
@@ -0,0 +1,7 @@
+from .model import FastSAM
+from .predict import FastSAMPredictor
+from .val import FastSAMValidator
+__all__ = "FastSAMPredictor", "FastSAM", "FastSAMValidator"
@@ -0,0 +1,55 @@
+from pathlib import Path
+from ultralytics.engine.model import Model
+from .predict import FastSAMPredictor
+from .val import FastSAMValidator
+class FastSAM(Model):
+ """
+ FastSAM model interface.
+ Example:
+ ```python
+ from ultralytics import FastSAM
+ model = FastSAM("last.pt")
+ results = model.predict("ultralytics/assets/bus.jpg")
+        ```
+    """
+    def __init__(self, model="FastSAM-x.pt"):
+ """Call the __init__ method of the parent class (YOLO) with the updated default model."""
+ if str(model) == "FastSAM.pt":
+ model = "FastSAM-x.pt"
+ assert Path(model).suffix not in {".yaml", ".yml"}, "FastSAM models only support pre-trained models."
+ super().__init__(model=model, task="segment")
+ def predict(self, source, stream=False, bboxes=None, points=None, labels=None, texts=None, **kwargs):
+        """
+        Perform segmentation prediction on image or video source.
+ Supports prompted segmentation with bounding boxes, points, labels, and texts.
+ Args:
+ source (str | PIL.Image | numpy.ndarray): Input source.
+ stream (bool): Enable real-time streaming.
+ bboxes (list): Bounding box coordinates for prompted segmentation.
+ points (list): Points for prompted segmentation.
+ labels (list): Labels for prompted segmentation.
+ texts (list): Texts for prompted segmentation.
+ **kwargs (Any): Additional keyword arguments.
+ Returns:
+            (list): Model predictions.
+        """
+        prompts = dict(bboxes=bboxes, points=points, labels=labels, texts=texts)
+ return super().predict(source, stream, prompts=prompts, **kwargs)
+ @property
+ def task_map(self):
+ """Returns a dictionary mapping segment task to corresponding predictor and validator classes."""
+ return {"segment": {"predictor": FastSAMPredictor, "validator": FastSAMValidator}}
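+# Illustrative usage sketch for the FastSAM interface above. It assumes the "FastSAM-s.pt"
+# weights and the bundled bus.jpg asset are available locally; the prompt values are
+# arbitrary examples, not recommended settings.
+if __name__ == "__main__":
+    from ultralytics import FastSAM
+    model = FastSAM("FastSAM-s.pt")
+    # Box, point and text prompts are packed into the `prompts` dict by predict()
+    # and consumed by FastSAMPredictor during post-processing.
+    results = model.predict(
+        "ultralytics/assets/bus.jpg",
+        bboxes=[[100, 200, 300, 400]],  # XYXY pixel coordinates
+        points=[[250, 300]],
+        labels=[1],  # 1 = foreground point
+        texts="a photo of a bus",
+    )
+    print(len(results[0].masks))  # number of masks kept after prompting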
@@ -0,0 +1,150 @@
+import torch
+from PIL import Image
+from ultralytics.models.yolo.segment import SegmentationPredictor
+from ultralytics.utils import DEFAULT_CFG, checks
+from ultralytics.utils.metrics import box_iou
+from ultralytics.utils.ops import scale_masks
+from .utils import adjust_bboxes_to_image_border
+class FastSAMPredictor(SegmentationPredictor):
+    """
+    FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks in Ultralytics
+ YOLO framework.
+ This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It
+ adjusts post-processing steps to incorporate mask prediction and non-max suppression while optimizing for single-
+    class segmentation.
+    """
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+ """Initializes a FastSAMPredictor for fast SAM segmentation tasks in Ultralytics YOLO framework."""
+ super().__init__(cfg, overrides, _callbacks)
+ self.prompts = {}
+ def postprocess(self, preds, img, orig_imgs):
+ """Applies box postprocess for FastSAM predictions."""
+ bboxes = self.prompts.pop("bboxes", None)
+ points = self.prompts.pop("points", None)
+ labels = self.prompts.pop("labels", None)
+ texts = self.prompts.pop("texts", None)
+ results = super().postprocess(preds, img, orig_imgs)
+ for result in results:
+ full_box = torch.tensor(
+ [0, 0, result.orig_shape[1], result.orig_shape[0]], device=preds[0].device, dtype=torch.float32
+ )
+ boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape)
+ idx = torch.nonzero(box_iou(full_box[None], boxes) > 0.9).flatten()
+ if idx.numel() != 0:
+ result.boxes.xyxy[idx] = full_box
+ return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)
+ def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):
+        """
+        Internal function for image segmentation inference based on cues like bounding boxes, points, and masks.
+        Leverages SAM's specialized architecture for prompt-based, real-time segmentation.
+        Args:
+            results (Results | List[Results]): The original inference results from FastSAM models without any prompts.
+            bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
+            points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
+            labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
+            texts (str | List[str], optional): Textual prompts, a list containing string objects.
+        Returns:
+            (List[Results]): The output results determined by prompts.
+        """
+        if bboxes is None and points is None and texts is None:
+ return results
+ prompt_results = []
+ if not isinstance(results, list):
+ results = [results]
+        for result in results:
+            if len(result) == 0:
+ prompt_results.append(result)
+ continue
+ masks = result.masks.data
+ if masks.shape[1:] != result.orig_shape:
+ masks = scale_masks(masks[None], result.orig_shape)[0]
+ # bboxes prompt
+ idx = torch.zeros(len(result), dtype=torch.bool, device=self.device)
+ if bboxes is not None:
+ bboxes = torch.as_tensor(bboxes, dtype=torch.int32, device=self.device)
+ bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
+ bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
+ mask_areas = torch.stack([masks[:, b[1] : b[3], b[0] : b[2]].sum(dim=(1, 2)) for b in bboxes])
+ full_mask_areas = torch.sum(masks, dim=(1, 2))
+ union = bbox_areas[:, None] + full_mask_areas - mask_areas
+ idx[torch.argmax(mask_areas / union, dim=1)] = True
+ if points is not None:
+ points = torch.as_tensor(points, dtype=torch.int32, device=self.device)
+ points = points[None] if points.ndim == 1 else points
+ if labels is None:
+ labels = torch.ones(points.shape[0])
+ labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
+                assert len(labels) == len(points), (
+                    f"Expected `labels` to have the same size as `points`, but got {len(labels)} and {len(points)}"
+                )
+                point_idx = (
+ torch.ones(len(result), dtype=torch.bool, device=self.device)
+ if labels.sum() == 0 # all negative points
+                    else torch.zeros(len(result), dtype=torch.bool, device=self.device)
+                )
+                for point, label in zip(points, labels):
+ point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = bool(label)
+ idx |= point_idx
+ if texts is not None:
+ if isinstance(texts, str):
+ texts = [texts]
+ crop_ims, filter_idx = [], []
+ for i, b in enumerate(result.boxes.xyxy.tolist()):
+ x1, y1, x2, y2 = (int(x) for x in b)
+ if masks[i].sum() <= 100:
+                        filter_idx.append(i)
+                        continue
+                    crop_ims.append(Image.fromarray(result.orig_img[y1:y2, x1:x2, ::-1]))
+ similarity = self._clip_inference(crop_ims, texts)
+ text_idx = torch.argmax(similarity, dim=-1) # (M, )
+ if len(filter_idx):
+ text_idx += (torch.tensor(filter_idx, device=self.device)[None] <= int(text_idx)).sum(0)
+ idx[text_idx] = True
+ prompt_results.append(result[idx])
+ return prompt_results
+ def _clip_inference(self, images, texts):
+        """
+        CLIP Inference process.
+        Args:
+            images (List[PIL.Image]): A list of source images, each of which should be a PIL.Image with RGB channel order.
+            texts (List[str]): A list of prompt texts, each of which should be a string object.
+        Returns:
+            (torch.Tensor): The similarity between the given images and texts.
+        """
+ try:
+ import clip
+ except ImportError:
+            checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
+            import clip
+ if (not hasattr(self, "clip_model")) or (not hasattr(self, "clip_preprocess")):
+ self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)
+ images = torch.stack([self.clip_preprocess(image).to(self.device) for image in images])
+ tokenized_text = clip.tokenize(texts).to(self.device)
+ image_features = self.clip_model.encode_image(images)
+ text_features = self.clip_model.encode_text(tokenized_text)
+ image_features /= image_features.norm(dim=-1, keepdim=True) # (N, 512)
+ text_features /= text_features.norm(dim=-1, keepdim=True) # (M, 512)
+ return (image_features * text_features[:, None]).sum(-1) # (M, N)
+ def set_prompts(self, prompts):
+ """Set prompts in advance."""
+ self.prompts = prompts
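+# Illustrative sketch of the prompt flow above: FastSAM.predict() forwards prompts via
+# set_prompts(), and postprocess() applies them once. Checkpoint and image paths are
+# placeholders that assume locally available assets.
+if __name__ == "__main__":
+    from ultralytics import FastSAM
+    model = FastSAM("FastSAM-s.pt")
+    everything = model("ultralytics/assets/bus.jpg")  # no prompts: all masks are kept
+    # With text prompts, each predicted region is cropped and ranked against the text by CLIP,
+    # so only the best-matching mask per text survives.
+    prompted = model("ultralytics/assets/bus.jpg", texts=["a photo of a bus"])
+    print(len(everything[0]), len(prompted[0]))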
@@ -0,0 +1,24 @@
+def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
+    """
+    Adjust bounding boxes to stick to the image border if they are within a certain threshold.
+    Args:
+        boxes (torch.Tensor): Bounding boxes with shape (n, 4).
+        image_shape (tuple): Image dimensions as (height, width).
+        threshold (int): Pixel threshold for snapping a box edge to the border.
+    Returns:
+        adjusted_boxes (torch.Tensor): The adjusted bounding boxes.
+    """
+ # Image dimensions
+ h, w = image_shape
+ # Adjust boxes
+ boxes[boxes[:, 0] < threshold, 0] = 0 # x1
+ boxes[boxes[:, 1] < threshold, 1] = 0 # y1
+ boxes[boxes[:, 2] > w - threshold, 2] = w # x2
+ boxes[boxes[:, 3] > h - threshold, 3] = h # y2
+ return boxes
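+# Quick standalone check of the snapping behaviour above: box edges that fall within
+# `threshold` pixels of the image border are clamped onto it, others are left untouched.
+if __name__ == "__main__":
+    import torch
+    boxes = torch.tensor([[5.0, 15.0, 630.0, 470.0], [100.0, 100.0, 200.0, 200.0]])
+    print(adjust_bboxes_to_image_border(boxes, image_shape=(480, 640), threshold=20))
+    # -> first box becomes [0, 0, 640, 480]; the second is unchanged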
@@ -0,0 +1,40 @@
+from ultralytics.models.yolo.segment import SegmentationValidator
+from ultralytics.utils.metrics import SegmentMetrics
+class FastSAMValidator(SegmentationValidator):
+    """
+    Custom validation class for fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework.
+ Extends the SegmentationValidator class, customizing the validation process specifically for fast SAM. This class
+ sets the task to 'segment' and uses the SegmentMetrics for evaluation. Additionally, plotting features are disabled
+ to avoid errors during validation.
+ Attributes:
+ dataloader: The data loader object used for validation.
+ save_dir (str): The directory where validation results will be saved.
+ pbar: A progress bar object.
+ args: Additional arguments for customization.
+        _callbacks: List of callback functions to be invoked during validation.
+    """
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
+        """
+        Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics.
+        Args:
+            dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
+ save_dir (Path, optional): Directory to save results.
+ pbar (tqdm.tqdm): Progress bar for displaying progress.
+ args (SimpleNamespace): Configuration for the validator.
+ _callbacks (dict): Dictionary to store various callback functions.
+ Notes:
+            Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors.
+        """
+        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
+ self.args.task = "segment"
+ self.args.plots = False # disable ConfusionMatrix and other plots to avoid errors
+ self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
+from .model import NAS
+from .predict import NASPredictor
+from .val import NASValidator
+__all__ = "NASPredictor", "NASValidator", "NAS"
@@ -0,0 +1,94 @@
+"""
+YOLO-NAS model interface.
+Example:
+ from ultralytics import NAS
+ model = NAS("yolo_nas_s")
+"""
+from pathlib import Path
+import torch
+from ultralytics.engine.model import Model
+from ultralytics.utils.downloads import attempt_download_asset
+from ultralytics.utils.torch_utils import model_info
+from .predict import NASPredictor
+from .val import NASValidator
+class NAS(Model):
+    """
+    YOLO NAS model for object detection.
+    This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine.
+    It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.
+    Attributes:
+        model (str): Path to the pre-trained model or model name. Defaults to 'yolo_nas_s.pt'.
+    Note:
+        YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.
+    """
+    def __init__(self, model="yolo_nas_s.pt") -> None:
+ """Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model."""
+ assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models."
+ super().__init__(model, task="detect")
+ def _load(self, weights: str, task=None) -> None:
+ """Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided."""
+ import super_gradients
+ suffix = Path(weights).suffix
+ if suffix == ".pt":
+ self.model = torch.load(attempt_download_asset(weights))
+ elif suffix == "":
+ self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")
+ # Override the forward method to ignore additional arguments
+ def new_forward(x, *args, **kwargs):
+ """Ignore additional __call__ arguments."""
+ return self.model._original_forward(x)
+ self.model._original_forward = self.model.forward
+ self.model.forward = new_forward
+ # Standardize model
+ self.model.fuse = lambda verbose=True: self.model
+ self.model.stride = torch.tensor([32])
+ self.model.names = dict(enumerate(self.model._class_names))
+ self.model.is_fused = lambda: False # for info()
+ self.model.yaml = {} # for info()
+ self.model.pt_path = weights # for export()
+ self.model.task = "detect" # for export()
+ def info(self, detailed=False, verbose=True):
+        """
+        Logs model info.
+        Args:
+            detailed (bool): Show detailed information about the model.
+            verbose (bool): Controls verbosity.
+        """
+        return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
+    @property
+    def task_map(self):
+        """Returns a dictionary mapping tasks to respective predictor and validator classes."""
+ return {"detect": {"predictor": NASPredictor, "validator": NASValidator}}
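+# Illustrative usage sketch: loading YOLO-NAS weights through this interface and running a
+# prediction. Assumes the "yolo_nas_s.pt" asset can be downloaded and that the
+# super-gradients package required by _load() is installed.
+if __name__ == "__main__":
+    from ultralytics import NAS
+    model = NAS("yolo_nas_s.pt")
+    model.info()  # summary via ultralytics.utils.torch_utils.model_info
+    results = model.predict("ultralytics/assets/bus.jpg")
+    print(results[0].boxes.xyxy)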
@@ -0,0 +1,57 @@
+import torch
+from ultralytics.engine.predictor import BasePredictor
+from ultralytics.engine.results import Results
+from ultralytics.utils import ops
+class NASPredictor(BasePredictor):
+    """
+    Ultralytics YOLO NAS Predictor for object detection.
+    This class extends the `BasePredictor` from Ultralytics engine and is responsible for post-processing the
+    raw predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and
+    scaling the bounding boxes to fit the original image dimensions.
+    Attributes:
+        args (Namespace): Namespace containing various configurations for post-processing.
+    Examples:
+        >>> predictor = model.predictor
+        >>> # Assumes that raw_preds, img, orig_imgs are available
+        >>> results = predictor.postprocess(raw_preds, img, orig_imgs)
+    Note:
+        Typically, this class is not instantiated directly. It is used internally within the `NAS` class.
+    """
+ def postprocess(self, preds_in, img, orig_imgs):
+ """Postprocess predictions and returns a list of Results objects."""
+ # Cat boxes and class scores
+ boxes = ops.xyxy2xywh(preds_in[0][0])
+ preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
+ preds = ops.non_max_suppression(
+ preds,
+ self.args.conf,
+ self.args.iou,
+ agnostic=self.args.agnostic_nms,
+ max_det=self.args.max_det,
+            classes=self.args.classes,
+        )
+        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
+ orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+ results = []
+ for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
+ pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
+        return results
@@ -0,0 +1,50 @@
+import torch
+from ultralytics.models.yolo.detect import DetectionValidator
+from ultralytics.utils import ops
+__all__ = ["NASValidator"]
+class NASValidator(DetectionValidator):
+    """
+    Ultralytics YOLO NAS Validator for object detection.
+    Extends `DetectionValidator` from the Ultralytics models package and is designed to post-process the raw predictions
+    generated by YOLO NAS models. It performs non-maximum suppression to remove overlapping and low-confidence boxes,
+    ultimately producing the final detections.
+    Attributes:
+        args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU.
+        lb (torch.Tensor): Optional tensor for multilabel NMS.
+    Examples:
+        >>> validator = model.validator
+        >>> # Assumes that raw_preds are available
+        >>> final_preds = validator.postprocess(raw_preds)
+    Note:
+        This class is generally not instantiated directly but is used internally within the `NAS` class.
+    """
+ def postprocess(self, preds_in):
+ """Apply Non-maximum suppression to prediction outputs."""
+        boxes = ops.xyxy2xywh(preds_in[0][0])
+        preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
+        return ops.non_max_suppression(
+            preds,
+            self.args.conf,
+            self.args.iou,
+            labels=self.lb,
+            multi_label=False,
+            agnostic=self.args.single_cls or self.args.agnostic_nms,
+            max_det=self.args.max_det,
+            max_time_img=0.5,
+        )
+from .model import RTDETR
+from .predict import RTDETRPredictor
+from .val import RTDETRValidator
+__all__ = "RTDETRPredictor", "RTDETRValidator", "RTDETR"
@@ -0,0 +1,54 @@
+"""
+Interface for Baidu's RT-DETR, a Vision Transformer-based real-time object detector. RT-DETR offers real-time
+performance and high accuracy, excelling in accelerated backends like CUDA with TensorRT. It features an efficient
+hybrid encoder and IoU-aware query selection for enhanced detection accuracy.
+For more information on RT-DETR, visit: https://arxiv.org/pdf/2304.08069.pdf
+"""
+from ultralytics.engine.model import Model
+from ultralytics.nn.tasks import RTDETRDetectionModel
+from .predict import RTDETRPredictor
+from .train import RTDETRTrainer
+from .val import RTDETRValidator
+class RTDETR(Model):
+    """
+    Interface for Baidu's RT-DETR model. This Vision Transformer-based object detector provides real-time performance
+    with high accuracy. It supports efficient hybrid encoding, IoU-aware query selection, and adaptable inference speed.
+    Attributes:
+        model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'.
+    """
+    def __init__(self, model="rtdetr-l.pt") -> None:
+        """
+        Initializes the RT-DETR model with the given pre-trained model file. Supports .pt and .yaml formats.
+        Args:
+            model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'.
+        Raises:
+            NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
+        """
+        super().__init__(model=model, task="detect")
+    @property
+    def task_map(self) -> dict:
+        """
+        Returns a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
+        Returns:
+            dict: A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
+        """
+        return {
+ "detect": {
+ "predictor": RTDETRPredictor,
+ "validator": RTDETRValidator,
+ "trainer": RTDETRTrainer,
+ "model": RTDETRDetectionModel,
+            }
+        }
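+# Illustrative usage sketch: RTDETR behaves like any other Ultralytics Model subclass;
+# "rtdetr-l.pt" is the default checkpoint named in the class docstring and is downloaded on first use.
+if __name__ == "__main__":
+    from ultralytics import RTDETR
+    model = RTDETR("rtdetr-l.pt")
+    results = model.predict("ultralytics/assets/bus.jpg", conf=0.5)
+    for r in results:
+        print(r.boxes.cls, r.boxes.conf)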
@@ -0,0 +1,84 @@
+import torch
+from ultralytics.data.augment import LetterBox
+from ultralytics.engine.predictor import BasePredictor
+from ultralytics.engine.results import Results
+from ultralytics.utils import ops
+class RTDETRPredictor(BasePredictor):
+    """
+    RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions using
+    Baidu's RT-DETR model.
+    This class leverages the power of Vision Transformers to provide real-time object detection while maintaining
+    high accuracy. It supports key features like efficient hybrid encoding and IoU-aware query selection.
+    Examples:
+        >>> from ultralytics.utils import ASSETS
+        >>> from ultralytics.models.rtdetr import RTDETRPredictor
+        >>> args = dict(model="rtdetr-l.pt", source=ASSETS)
+        >>> predictor = RTDETRPredictor(overrides=args)
+        >>> predictor.predict_cli()
+    Attributes:
+        imgsz (int): Image size for inference (must be square and scale-filled).
+        args (dict): Argument overrides for the predictor.
+    """
+    def postprocess(self, preds, img, orig_imgs):
+        """
+        Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.
+        The method filters detections based on confidence and class if specified in `self.args`.
+        Args:
+            preds (list): List of [predictions, extra] from the model.
+ img (torch.Tensor): Processed input images.
+ orig_imgs (list or torch.Tensor): Original, unprocessed images.
+        Returns:
+            (list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence scores,
+                and class labels.
+        """
+        if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
+ preds = [preds, None]
+ nd = preds[0].shape[-1]
+ bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
+        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
+            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+        results = []
+        for bbox, score, orig_img, img_path in zip(bboxes, scores, orig_imgs, self.batch[0]):  # (300, 4)
+ bbox = ops.xywh2xyxy(bbox)
+ max_score, cls = score.max(-1, keepdim=True) # (300, 1)
+ idx = max_score.squeeze(-1) > self.args.conf # (300, )
+ if self.args.classes is not None:
+ idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
+ pred = torch.cat([bbox, max_score, cls], dim=-1)[idx] # filter
+ oh, ow = orig_img.shape[:2]
+ pred[..., [0, 2]] *= ow
+            pred[..., [1, 3]] *= oh
+            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
+        return results
+ def pre_transform(self, im):
+        """
+        Pre-transforms the input images before feeding them into the model for inference. The input images are
+        letterboxed to ensure a square aspect ratio and are scale-filled, since inference requires a square size (e.g. 640).
+        Args:
+            im (list[np.ndarray] | torch.Tensor): Input images of shape (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
+        Returns:
+            (list): List of pre-transformed images ready for model inference.
+        """
+        letterbox = LetterBox(self.imgsz, auto=False, scaleFill=True)
+ return [letterbox(image=x) for x in im]
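+# Standalone sketch of the coordinate handling in postprocess() above: RT-DETR emits normalized
+# xywh boxes, which are converted to xyxy and scaled by the original image width/height instead
+# of being letterbox-rescaled.
+if __name__ == "__main__":
+    bbox = torch.tensor([[0.5, 0.5, 0.2, 0.4]])  # one normalized xywh box
+    xyxy = ops.xywh2xyxy(bbox)
+    oh, ow = 480, 640  # original image height/width
+    xyxy[..., [0, 2]] *= ow
+    xyxy[..., [1, 3]] *= oh
+    print(xyxy)  # tensor([[256., 144., 384., 336.]])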
@@ -0,0 +1,105 @@
+from copy import copy
+import torch
+from ultralytics.nn.tasks import RTDETRDetectionModel
+from ultralytics.models.yolo.detect import DetectionTrainer
+from ultralytics.utils import RANK, colorstr
+from .val import RTDETRDataset, RTDETRValidator
+class RTDETRTrainer(DetectionTrainer):
+    """
+    Trainer class for the RT-DETR model developed by Baidu for real-time object detection. Extends the DetectionTrainer
+    class for YOLO to adapt to the specific features and architecture of RT-DETR. This model leverages Vision
+    Transformers and has capabilities like IoU-aware query selection and adaptable inference speed.
+    Notes:
+        - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
+        - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
+    Examples:
+        >>> from ultralytics.models.rtdetr.train import RTDETRTrainer
+        >>> args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)
+        >>> trainer = RTDETRTrainer(overrides=args)
+        >>> trainer.train()
+    """
+ def get_model(self, cfg=None, weights=None, verbose=True):
+        """
+        Initialize and return an RT-DETR model for object detection tasks.
+        Args:
+            cfg (dict, optional): Model configuration. Defaults to None.
+            weights (str, optional): Path to pre-trained model weights. Defaults to None.
+            verbose (bool): Verbose logging if True. Defaults to True.
+        Returns:
+            (RTDETRDetectionModel): Initialized model.
+        """
+ model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
+ if weights:
+ model.load(weights)
+ return model
+ def build_dataset(self, img_path, mode="val", batch=None):
+        """
+        Build and return an RT-DETR dataset for training or validation.
+        Args:
+            img_path (str): Path to the folder containing images.
+            mode (str): Dataset mode, either 'train' or 'val'.
+            batch (int, optional): Batch size for rectangle training. Defaults to None.
+        Returns:
+            (RTDETRDataset): Dataset object for the specific mode.
+        """
+ return RTDETRDataset(
+ img_path=img_path,
+ imgsz=self.args.imgsz,
+ batch_size=batch,
+ augment=mode == "train",
+ hyp=self.args,
+ rect=False,
+ cache=self.args.cache or None,
+ single_cls=self.args.single_cls or False,
+ prefix=colorstr(f"{mode}: "),
+ data=self.data,
+            fraction=self.args.fraction if mode == "train" else 1.0,
+        )
+ def get_validator(self):
+        """
+        Returns a DetectionValidator suitable for RT-DETR model validation.
+        Returns:
+            (RTDETRValidator): Validator object for model validation.
+        """
+ self.loss_names = "giou_loss", "cls_loss", "l1_loss"
+ return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
+ def preprocess_batch(self, batch):
+        """
+        Preprocess a batch of images. Scales and converts the images to float format.
+        Args:
+            batch (dict): Dictionary containing a batch of images, bboxes, and labels.
+        Returns:
+            (dict): Preprocessed batch.
+        """
+ batch = super().preprocess_batch(batch)
+ bs = len(batch["img"])
+ batch_idx = batch["batch_idx"]
+ gt_bbox, gt_class = [], []
+ for i in range(bs):
+ gt_bbox.append(batch["bboxes"][batch_idx == i].to(batch_idx.device))
+ gt_class.append(batch["cls"][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
+ return batch
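+# Illustrative training sketch mirroring the class docstring example; coco8.yaml is the tiny
+# demo dataset shipped with Ultralytics, so this run is only a quick smoke test.
+if __name__ == "__main__":
+    args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=1)
+    trainer = RTDETRTrainer(overrides=args)
+    trainer.train()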
@@ -0,0 +1,135 @@
+import torch
+from ultralytics.data import YOLODataset
+from ultralytics.data.augment import Compose, Format, v8_transforms
+from ultralytics.models.yolo.detect import DetectionValidator
+from ultralytics.utils import colorstr, ops
+__all__ = ("RTDETRValidator",) # tuple or list
+class RTDETRDataset(YOLODataset):
+    """
+    Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.
+    This specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for
+    real-time detection and tracking tasks.
+    """
+ def __init__(self, *args, data=None, **kwargs):
+ """Initialize the RTDETRDataset class by inheriting from the YOLODataset class."""
+ super().__init__(*args, data=data, **kwargs)
+ # NOTE: add stretch version load_image for RTDETR mosaic
+ def load_image(self, i, rect_mode=False):
+ """Loads 1 image from dataset index 'i', returns (im, resized hw)."""
+ return super().load_image(i=i, rect_mode=rect_mode)
+ def build_transforms(self, hyp=None):
+ """Temporary, only for evaluation."""
+ if self.augment:
+ hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
+ hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
+ transforms = v8_transforms(self, self.imgsz, hyp, stretch=True)
+ else:
+ # transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scaleFill=True)])
+ transforms = Compose([])
+ transforms.append(
+ Format(
+ bbox_format="xywh",
+ normalize=True,
+ return_mask=self.use_segments,
+ return_keypoint=self.use_keypoints,
+ batch_idx=True,
+ mask_ratio=hyp.mask_ratio,
+                mask_overlap=hyp.overlap_mask,
+            )
+        )
+ return transforms
+class RTDETRValidator(DetectionValidator):
+    """
+    RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for
+    the RT-DETR (Real-Time DETR) object detection model.
+    The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for
+    post-processing, and updates evaluation metrics accordingly.
+    Examples:
+        >>> from ultralytics.models.rtdetr import RTDETRValidator
+        >>> args = dict(model="rtdetr-l.pt", data="coco8.yaml")
+        >>> validator = RTDETRValidator(args=args)
+        >>> validator()
+    Note:
+        For further details on the attributes and methods, refer to the parent DetectionValidator class.
+    """
+    def build_dataset(self, img_path, mode="val", batch=None):
+        """
+        Build an RTDETR Dataset.
+        Args:
+            img_path (str): Path to the folder containing images.
+            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
+            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
+        """
+        return RTDETRDataset(
+            img_path=img_path,
+            imgsz=self.args.imgsz,
+            batch_size=batch,
+            augment=False,  # no augmentation
+            hyp=self.args,
+            rect=False,  # no rect
+            cache=self.args.cache or None,
+            prefix=colorstr(f"{mode}: "),
+            data=self.data,
+        )
+ def postprocess(self, preds):
+        """Apply Non-maximum suppression to prediction outputs."""
+        bs, _, nd = preds[0].shape
+        bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
+        bboxes *= self.args.imgsz
+ outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs
+ for i, bbox in enumerate(bboxes): # (300, 4)
+ score, cls = scores[i].max(-1) # (300, )
+ # Do not need threshold for evaluation as only got 300 boxes here
+ # idx = score > self.args.conf
+ pred = torch.cat([bbox, score[..., None], cls[..., None]], dim=-1) # filter
+ # Sort by confidence to correctly get internal metrics
+ pred = pred[score.argsort(descending=True)]
+ outputs[i] = pred # [idx]
+ return outputs
+ def _prepare_batch(self, si, batch):
+ """Prepares a batch for training or inference by applying transformations."""
+ idx = batch["batch_idx"] == si
+ cls = batch["cls"][idx].squeeze(-1)
+ bbox = batch["bboxes"][idx]
+ ori_shape = batch["ori_shape"][si]
+ imgsz = batch["img"].shape[2:]
+ ratio_pad = batch["ratio_pad"][si]
+ if len(cls):
+ bbox = ops.xywh2xyxy(bbox) # target boxes
+ bbox[..., [0, 2]] *= ori_shape[1] # native-space pred
+ bbox[..., [1, 3]] *= ori_shape[0] # native-space pred
+ return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
+ def _prepare_pred(self, pred, pbatch):
+ """Prepares and returns a batch with transformed bounding boxes and class labels."""
+ predn = pred.clone()
+ predn[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz # native-space pred
+ predn[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz # native-space pred
+ return predn.float()
@@ -0,0 +1,6 @@
+from .model import SAM
+from .predict import Predictor, SAM2Predictor, SAM2VideoPredictor
+__all__ = "SAM", "Predictor", "SAM2Predictor", "SAM2VideoPredictor" # tuple or list
@@ -0,0 +1,193 @@
+import math
+from itertools import product
+from typing import Any, Generator, List, Tuple
+import numpy as np
+import torch
+def is_box_near_crop_edge(
+ boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
+) -> torch.Tensor:
+ """Determines if bounding boxes are near the edge of a cropped image region using a specified tolerance."""
+ crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
+ orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
+ boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
+ near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
+ near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
+ near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
+ return torch.any(near_crop_edge, dim=1)
+def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
+ """Yields batches of data from input arguments with specified batch size for efficient processing."""
+ assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs."
+ n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
+ for b in range(n_batches):
+ yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
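+# Standalone sketch of batch_iterator: parallel arrays are sliced together and the final batch
+# carries the remainder.
+if __name__ == "__main__":
+    pts = np.arange(10).reshape(5, 2)
+    scores = np.linspace(0.0, 1.0, 5)
+    for pts_b, scores_b in batch_iterator(2, pts, scores):
+        print(pts_b.shape, scores_b.shape)  # (2, 2) (2,), (2, 2) (2,), then (1, 2) (1,)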
+def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
+    """
+    Computes the stability score for a batch of masks.
+    The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at
+    high and low values.
+    Args:
+        masks (torch.Tensor): Batch of predicted mask logits.
+        mask_threshold (float): Threshold value for creating binary masks.
+        threshold_offset (float): Offset applied to the threshold for creating high and low binary masks.
+    Returns:
+        (torch.Tensor): Stability scores for each mask in the batch.
+    Notes:
+        - One mask is always contained inside the other.
+        - Memory is saved by preventing unnecessary cast to torch.int64.
+    Examples:
+        >>> masks = torch.rand(10, 256, 256)  # Batch of 10 masks
+        >>> mask_threshold = 0.5
+        >>> threshold_offset = 0.1
+        >>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset)
+    """
+ intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
+ unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
+ return intersections / unions
+def build_point_grid(n_per_side: int) -> np.ndarray:
+ """Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1] for image segmentation tasks."""
+ offset = 1 / (2 * n_per_side)
+ points_one_side = np.linspace(offset, 1 - offset, n_per_side)
+ points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
+ points_y = np.tile(points_one_side[:, None], (1, n_per_side))
+ return np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
+def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
+ """Generates point grids for multiple crop layers with varying scales and densities."""
+ return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]
+def generate_crop_boxes(
+ im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
+) -> Tuple[List[List[int]], List[int]]:
+ """Generates crop boxes of varying sizes for multiscale image processing, with layered overlapping regions."""
+ crop_boxes, layer_idxs = [], []
+ im_h, im_w = im_size
+ short_side = min(im_h, im_w)
+ # Original image
+ crop_boxes.append([0, 0, im_w, im_h])
+ layer_idxs.append(0)
+ def crop_len(orig_len, n_crops, overlap):
+        """Computes the length of each crop for a given original length, number of crops, and overlap."""
+ return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
+ for i_layer in range(n_layers):
+ n_crops_per_side = 2 ** (i_layer + 1)
+ overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
+ crop_w = crop_len(im_w, n_crops_per_side, overlap)
+ crop_h = crop_len(im_h, n_crops_per_side, overlap)
+ crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
+ crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
+        # Crops in XYXY format
+ for x0, y0 in product(crop_box_x0, crop_box_y0):
+ box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
+ crop_boxes.append(box)
+ layer_idxs.append(i_layer + 1)
+ return crop_boxes, layer_idxs
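+# Standalone sketch of generate_crop_boxes: layer 0 is always the full image, and each further
+# layer adds a (2**layer) x (2**layer) grid of overlapping XYXY crops.
+if __name__ == "__main__":
+    crop_boxes, layer_idxs = generate_crop_boxes((480, 640), n_layers=1, overlap_ratio=512 / 1500)
+    print(len(crop_boxes), layer_idxs)  # 5 boxes: the full image plus a 2x2 grid -> [0, 1, 1, 1, 1]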
+def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+ """Uncrop bounding boxes by adding the crop box offset to their coordinates."""
+ x0, y0, _, _ = crop_box
+ offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
+ # Check if boxes has a channel dimension
+ if len(boxes.shape) == 3:
+ offset = offset.unsqueeze(1)
+ return boxes + offset
+def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+ """Uncrop points by adding the crop box offset to their coordinates."""
+    x0, y0, _, _ = crop_box
+    offset = torch.tensor([[x0, y0]], device=points.device)
+    # Check if points has a channel dimension
+    if len(points.shape) == 3:
+        offset = offset.unsqueeze(1)
+    return points + offset
+def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor:
+ """Uncrop masks by padding them to the original image size, handling coordinate transformations."""
+ x0, y0, x1, y1 = crop_box
+ if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
+ return masks
+ # Coordinate transform masks
+ pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
+ pad = (x0, pad_x - x0, y0, pad_y - y0)
+ return torch.nn.functional.pad(masks, pad, value=0)
+def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
+ """Removes small disconnected regions or holes in a mask based on area threshold and mode."""
+ import cv2 # type: ignore
+ assert mode in {"holes", "islands"}, f"Provided mode {mode} is invalid"
+ correct_holes = mode == "holes"
+ working_mask = (correct_holes ^ mask).astype(np.uint8)
+ n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
+ sizes = stats[:, -1][1:] # Row 0 is background label
+ small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
+ if not small_regions:
+ return mask, False
+ fill_labels = [0] + small_regions
+ if not correct_holes:
+ # If every region is below threshold, keep largest
+ fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1]
+ mask = np.isin(regions, fill_labels)
+ return mask, True
+def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
+ """Calculates bounding boxes in XYXY format around binary masks, handling empty masks and various input shapes."""
+ # torch.max below raises an error on empty inputs, just skip in this case
+ if torch.numel(masks) == 0:
+ return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
+ # Normalize shape to CxHxW
+ shape = masks.shape
+ h, w = shape[-2:]
+ masks = masks.flatten(0, -3) if len(shape) > 2 else masks.unsqueeze(0)
+ # Get top and bottom edges
+ in_height, _ = torch.max(masks, dim=-1)
+ in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
+ bottom_edges, _ = torch.max(in_height_coords, dim=-1)
+ in_height_coords = in_height_coords + h * (~in_height)
+ top_edges, _ = torch.min(in_height_coords, dim=-1)
+ # Get left and right edges
+ in_width, _ = torch.max(masks, dim=-2)
+ in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
+ right_edges, _ = torch.max(in_width_coords, dim=-1)
+ in_width_coords = in_width_coords + w * (~in_width)
+ left_edges, _ = torch.min(in_width_coords, dim=-1)
+ # If the mask is empty the right edge will be to the left of the left edge.
+ # Replace these boxes with [0, 0, 0, 0]
+ empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+ out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
+ out = out * (~empty_filter).unsqueeze(-1)
+ # Return to original shape
+ return out.reshape(*shape[:-2], 4) if len(shape) > 2 else out[0]
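+# Standalone sketch of batched_mask_to_box: a binary mask is reduced to its tight XYXY box,
+# and an all-zero mask maps to [0, 0, 0, 0].
+if __name__ == "__main__":
+    demo_masks = torch.zeros(2, 8, 8, dtype=torch.bool)
+    demo_masks[0, 2:5, 3:7] = True  # rows 2-4, cols 3-6
+    print(batched_mask_to_box(demo_masks))  # tensor([[3, 2, 6, 4], [0, 0, 0, 0]])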
@@ -0,0 +1,358 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from functools import partial
+import torch
+from ultralytics.utils.downloads import attempt_download_asset
+from .modules.decoders import MaskDecoder
+from .modules.encoders import FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder
+from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
+from .modules.sam import SAM2Model, SAMModel
+from .modules.tiny_encoder import TinyViT
+from .modules.transformer import TwoWayTransformer
+def build_sam_vit_h(checkpoint=None):
+ """Builds and returns a Segment Anything Model (SAM) h-size model with specified encoder parameters."""
+ return _build_sam(
+ encoder_embed_dim=1280,
+ encoder_depth=32,
+ encoder_num_heads=16,
+ encoder_global_attn_indexes=[7, 15, 23, 31],
+        checkpoint=checkpoint,
+    )
+def build_sam_vit_l(checkpoint=None):
+ """Builds and returns a Segment Anything Model (SAM) l-size model with specified encoder parameters."""
+    return _build_sam(
+        encoder_embed_dim=1024,
+        encoder_depth=24,
+        encoder_num_heads=16,
+        encoder_global_attn_indexes=[5, 11, 17, 23],
+        checkpoint=checkpoint,
+    )
+def build_sam_vit_b(checkpoint=None):
+ """Constructs and returns a Segment Anything Model (SAM) with b-size architecture and optional checkpoint."""
+    return _build_sam(
+        encoder_embed_dim=768,
+        encoder_depth=12,
+        encoder_num_heads=12,
+        encoder_global_attn_indexes=[2, 5, 8, 11],
+        checkpoint=checkpoint,
+    )
+def build_mobile_sam(checkpoint=None):
+ """Builds and returns a Mobile Segment Anything Model (Mobile-SAM) for efficient image segmentation."""
+    return _build_sam(
+        encoder_embed_dim=[64, 128, 160, 320],
+        encoder_depth=[2, 2, 6, 2],
+        encoder_num_heads=[2, 4, 5, 10],
+        encoder_global_attn_indexes=None,
+        mobile_sam=True,
+        checkpoint=checkpoint,
+    )
+def build_sam2_t(checkpoint=None):
+ """Builds and returns a Segment Anything Model 2 (SAM2) tiny-size model with specified architecture parameters."""
+ return _build_sam2(
+ encoder_embed_dim=96,
+ encoder_stages=[1, 2, 7, 2],
+ encoder_num_heads=1,
+ encoder_global_att_blocks=[5, 7, 9],
+ encoder_window_spec=[8, 4, 14, 7],
+        encoder_backbone_channel_list=[768, 384, 192, 96],
+        checkpoint=checkpoint,
+    )
+def build_sam2_s(checkpoint=None):
+ """Builds and returns a small-size Segment Anything Model (SAM2) with specified architecture parameters."""
+    return _build_sam2(
+        encoder_embed_dim=96,
+        encoder_stages=[1, 2, 11, 2],
+        encoder_num_heads=1,
+        encoder_global_att_blocks=[7, 10, 13],
+        encoder_window_spec=[8, 4, 14, 7],
+        encoder_backbone_channel_list=[768, 384, 192, 96],
+        checkpoint=checkpoint,
+    )
+def build_sam2_b(checkpoint=None):
+ """Builds and returns a SAM2 base-size model with specified architecture parameters."""
+    return _build_sam2(
+        encoder_embed_dim=112,
+ encoder_stages=[2, 3, 16, 3],
+ encoder_num_heads=2,
+ encoder_global_att_blocks=[12, 16, 20],
+ encoder_window_spatial_size=[14, 14],
+        encoder_backbone_channel_list=[896, 448, 224, 112],
+        encoder_window_spec=[8, 4, 14, 7],
+        checkpoint=checkpoint,
+    )
+def build_sam2_l(checkpoint=None):
+ """Builds and returns a large-size Segment Anything Model (SAM2) with specified architecture parameters."""
+    return _build_sam2(
+        encoder_embed_dim=144,
+        encoder_num_heads=2,
+ encoder_stages=[2, 6, 36, 4],
+ encoder_global_att_blocks=[23, 33, 43],
+ encoder_window_spec=[8, 4, 16, 8],
+        encoder_backbone_channel_list=[1152, 576, 288, 144],
+        checkpoint=checkpoint,
+    )
+def _build_sam(
+ encoder_embed_dim,
+ encoder_depth,
+ encoder_num_heads,
+ encoder_global_attn_indexes,
+ checkpoint=None,
+ mobile_sam=False,
+):
+    """
+    Builds a Segment Anything Model (SAM) with specified encoder parameters.
+    Args:
+        encoder_embed_dim (int | List[int]): Embedding dimension for the encoder.
+        encoder_depth (int | List[int]): Depth of the encoder.
+        encoder_num_heads (int | List[int]): Number of attention heads in the encoder.
+        encoder_global_attn_indexes (List[int] | None): Indexes for global attention in the encoder.
+        checkpoint (str | None): Path to the model checkpoint file.
+        mobile_sam (bool): Whether to build a Mobile-SAM model.
+    Returns:
+        (SAMModel): A Segment Anything Model instance with the specified architecture.
+    Examples:
+        >>> sam = _build_sam(768, 12, 12, [2, 5, 8, 11])
+        >>> sam = _build_sam([64, 128, 160, 320], [2, 2, 6, 2], [2, 4, 5, 10], None, mobile_sam=True)
+    """
+ prompt_embed_dim = 256
+ image_size = 1024
+ vit_patch_size = 16
+ image_embedding_size = image_size // vit_patch_size
+ image_encoder = (
+ TinyViT(
+ img_size=1024,
+ in_chans=3,
+ num_classes=1000,
+ embed_dims=encoder_embed_dim,
+ depths=encoder_depth,
+ num_heads=encoder_num_heads,
+ window_sizes=[7, 7, 14, 7],
+ mlp_ratio=4.0,
+ drop_rate=0.0,
+ drop_path_rate=0.0,
+ use_checkpoint=False,
+ mbconv_expand_ratio=4.0,
+ local_conv_size=3,
+            layer_lr_decay=0.8,
+        )
+        if mobile_sam
+ else ImageEncoderViT(
+ depth=encoder_depth,
+ embed_dim=encoder_embed_dim,
+ img_size=image_size,
+ mlp_ratio=4,
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
+ patch_size=vit_patch_size,
+ qkv_bias=True,
+ use_rel_pos=True,
+ global_attn_indexes=encoder_global_attn_indexes,
+ window_size=14,
+            out_chans=prompt_embed_dim,
+        )
+    )
+    sam = SAMModel(
+ image_encoder=image_encoder,
+ prompt_encoder=PromptEncoder(
+ embed_dim=prompt_embed_dim,
+ image_embedding_size=(image_embedding_size, image_embedding_size),
+ input_image_size=(image_size, image_size),
+ mask_in_chans=16,
+ ),
+ mask_decoder=MaskDecoder(
+ num_multimask_outputs=3,
+ transformer=TwoWayTransformer(
+ depth=2,
+ embedding_dim=prompt_embed_dim,
+ mlp_dim=2048,
+ num_heads=8,
+            ),
+            transformer_dim=prompt_embed_dim,
+ iou_head_depth=3,
+ iou_head_hidden_dim=256,
+        ),
+        pixel_mean=[123.675, 116.28, 103.53],
+        pixel_std=[58.395, 57.12, 57.375],
+    )
+ if checkpoint is not None:
+ checkpoint = attempt_download_asset(checkpoint)
+ with open(checkpoint, "rb") as f:
+ state_dict = torch.load(f)
+ sam.load_state_dict(state_dict)
+ sam.eval()
+ return sam
+def _build_sam2(
+    encoder_embed_dim=1280,
+    encoder_stages=[2, 6, 36, 4],
+    encoder_num_heads=2,
+    encoder_global_att_blocks=[7, 15, 23, 31],
+    encoder_window_spec=[8, 4, 16, 8],
+    encoder_window_spatial_size=[7, 7],
+    encoder_backbone_channel_list=[1152, 576, 288, 144],
+    checkpoint=None,
+):
+    """
+    Builds and returns a Segment Anything Model 2 (SAM2) with specified architecture parameters.
+    Args:
+        encoder_embed_dim (int): Embedding dimension for the encoder.
+ encoder_stages (List[int]): Number of blocks in each stage of the encoder.
+ encoder_num_heads (int): Number of attention heads in the encoder.
+ encoder_global_att_blocks (List[int]): Indices of global attention blocks in the encoder.
+ encoder_backbone_channel_list (List[int]): Channel dimensions for each level of the encoder backbone.
+ encoder_window_spatial_size (List[int]): Spatial size of the window for position embeddings.
+ encoder_window_spec (List[int]): Window specifications for each stage of the encoder.
+        checkpoint (str | None): Path to the checkpoint file for loading pre-trained weights.
+    Returns:
+        (SAM2Model): A configured and initialized SAM2 model.
+    Examples:
+        >>> sam2_model = _build_sam2(encoder_embed_dim=96, encoder_stages=[1, 2, 7, 2])
+        >>> sam2_model.eval()
+    """
+ image_encoder = ImageEncoder(
+        trunk=Hiera(
+            embed_dim=encoder_embed_dim,
+            num_heads=encoder_num_heads,
+ stages=encoder_stages,
+ global_att_blocks=encoder_global_att_blocks,
+ window_pos_embed_bkg_spatial_size=encoder_window_spatial_size,
+            window_spec=encoder_window_spec,
+        ),
+        neck=FpnNeck(
+ d_model=256,
+ backbone_channel_list=encoder_backbone_channel_list,
+ fpn_top_down_levels=[2, 3],
+ fpn_interp_model="nearest",
+        ),
+        scalp=1,
+    )
+ memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer())
+ memory_encoder = MemoryEncoder(out_dim=64)
+ is_sam2_1 = checkpoint is not None and "sam2.1" in checkpoint
+    sam2 = SAM2Model(
+        image_encoder=image_encoder,
+ memory_attention=memory_attention,
+ memory_encoder=memory_encoder,
+ num_maskmem=7,
+ image_size=1024,
+ sigmoid_scale_for_mem_enc=20.0,
+ sigmoid_bias_for_mem_enc=-10.0,
+ use_mask_input_as_output_without_sam=True,
+ directly_add_no_mem_embed=True,
+ use_high_res_features_in_sam=True,
+ multimask_output_in_sam=True,
+ iou_prediction_use_sigmoid=True,
+ use_obj_ptrs_in_encoder=True,
+ add_tpos_enc_to_obj_ptrs=True,
+ only_obj_ptrs_in_the_past_for_eval=True,
+ pred_obj_scores=True,
+ pred_obj_scores_mlp=True,
+ fixed_no_obj_ptr=True,
+ multimask_output_for_tracking=True,
+ use_multimask_token_for_obj_ptr=True,
+ multimask_min_pt_num=0,
+ multimask_max_pt_num=1,
+ use_mlp_for_obj_ptr_proj=True,
+ compile_image_encoder=False,
+ no_obj_embed_spatial=is_sam2_1,
+ proj_tpos_enc_in_obj_ptrs=is_sam2_1,
+ use_signed_tpos_enc_to_obj_ptrs=is_sam2_1,
+ sam_mask_decoder_extra_args=dict(
+ dynamic_multimask_via_stability=True,
+ dynamic_multimask_stability_delta=0.05,
+ dynamic_multimask_stability_thresh=0.98,
+        ),
+    )
+    if checkpoint is not None:
+        checkpoint = attempt_download_asset(checkpoint)
+        with open(checkpoint, "rb") as f:
+            state_dict = torch.load(f)["model"]
+ sam2.load_state_dict(state_dict)
+ sam2.eval()
+ return sam2
+sam_model_map = {
+ "sam_h.pt": build_sam_vit_h,
+ "sam_l.pt": build_sam_vit_l,
+ "sam_b.pt": build_sam_vit_b,
+ "mobile_sam.pt": build_mobile_sam,
+ "sam2_t.pt": build_sam2_t,
+ "sam2_s.pt": build_sam2_s,
+ "sam2_b.pt": build_sam2_b,
+ "sam2_l.pt": build_sam2_l,
+ "sam2.1_t.pt": build_sam2_t,
+ "sam2.1_s.pt": build_sam2_s,
+ "sam2.1_b.pt": build_sam2_b,
+ "sam2.1_l.pt": build_sam2_l,
+}
+def build_sam(ckpt="sam_b.pt"):
+    """
+    Builds and returns a Segment Anything Model (SAM) based on the provided checkpoint.
+    Args:
+        ckpt (str | Path): Path to the checkpoint file or name of a pre-defined SAM model.
+    Returns:
+        (SAMModel | SAM2Model): A configured and initialized SAM or SAM2 model instance.
+    Raises:
+        FileNotFoundError: If the provided checkpoint is not a supported SAM model.
+    Examples:
+        >>> sam_model = build_sam("sam_b.pt")
+        >>> sam_model = build_sam("path/to/custom_checkpoint.pt")
+    Notes:
+        Supported pre-defined models include:
+        - SAM: 'sam_h.pt', 'sam_l.pt', 'sam_b.pt', 'mobile_sam.pt'
+        - SAM2: 'sam2_t.pt', 'sam2_s.pt', 'sam2_b.pt', 'sam2_l.pt'
+    """
+ model_builder = None
+ ckpt = str(ckpt) # to allow Path ckpt types
+ for k in sam_model_map.keys():
+ if ckpt.endswith(k):
+ model_builder = sam_model_map.get(k)
+ if not model_builder:
+ raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}")
+ return model_builder(ckpt)
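+# Illustrative sketch of how these builders are reached in practice: build_sam() matches the
+# checkpoint name against `sam_model_map` suffixes. The named weights are downloaded on first
+# use, so this is only a smoke test.
+if __name__ == "__main__":
+    sam = build_sam("sam_b.pt")  # -> build_sam_vit_b
+    mobile = build_sam("mobile_sam.pt")  # -> build_mobile_sam (TinyViT image encoder)
+    print(type(sam).__name__, type(mobile).__name__)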
@@ -0,0 +1,175 @@
+"""
+SAM model interface.
+This module provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for real-time image
+segmentation tasks. The SAM model allows for promptable segmentation with unparalleled versatility in image analysis,
+and has been trained on the SA-1B dataset. It features zero-shot performance capabilities, enabling it to adapt to new
+image distributions and tasks without prior knowledge.
+Key Features:
+ - Promptable segmentation
+ - Real-time performance
+ - Zero-shot transfer capabilities
+    - Trained on SA-1B dataset
+"""
+from pathlib import Path
+from ultralytics.engine.model import Model
+from ultralytics.utils.torch_utils import model_info
+from .build import build_sam
+from .predict import Predictor, SAM2Predictor
+class SAM(Model):
+    """
+    SAM (Segment Anything Model) interface class for real-time image segmentation tasks.
+ This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for
+ promptable segmentation with versatility in image analysis. It supports various prompts such as bounding
+ boxes, points, or labels, and features zero-shot performance capabilities.
+ model (torch.nn.Module): The loaded SAM model.
+ is_sam2 (bool): Indicates whether the model is SAM2 variant.
+ task (str): The task type, set to "segment" for SAM models.
+ Methods:
+ predict: Performs segmentation prediction on the given image or video source.
+ info: Logs information about the SAM model.
+    Examples:
+        >>> sam = SAM("sam_b.pt")
+ >>> results = sam.predict("image.jpg", points=[[500, 375]])
+ >>> for r in results:
+        ...     print(f"Detected {len(r.masks)} masks")
+    """
+    def __init__(self, model="sam_b.pt") -> None:
+        """
+        Initializes the SAM (Segment Anything Model) instance.
+        Args:
+            model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.
+        Raises:
+            NotImplementedError: If the model file extension is not .pt or .pth.
+        Examples:
+            >>> sam = SAM("sam_b.pt")
+            >>> print(sam.is_sam2)
+        """
+ if model and Path(model).suffix not in {".pt", ".pth"}:
+ raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
+        self.is_sam2 = "sam2" in Path(model).stem
+        super().__init__(model=model, task="segment")
+ def _load(self, weights: str, task=None):
+        """
+        Loads the specified weights into the SAM model.
+        This method initializes the SAM model with the provided weights file, setting up the model architecture
+        and loading the pre-trained parameters.
+        Args:
+            weights (str): Path to the weights file. Should be a .pt or .pth file containing the model parameters.
+            task (str | None): Task name. If provided, it specifies the particular task the model is being loaded for.
+        Examples:
+            >>> sam = SAM("sam_b.pt")
+            >>> sam._load("path/to/custom_weights.pt")
+        """
+ self.model = build_sam(weights)
+ def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
+        """
+        Performs segmentation prediction on the given image or video source.
+        Args:
+            source (str | PIL.Image | numpy.ndarray): Path to the image or video file, or a PIL.Image object, or
+                a numpy.ndarray object.
+            stream (bool): If True, enables real-time streaming.
+            bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.
+            points (List[List[float]] | None): List of points for prompted segmentation.
+            labels (List[int] | None): List of labels for prompted segmentation.
+            **kwargs (Any): Additional keyword arguments for prediction.
+        Returns:
+            (List): The model predictions.
+        Examples:
+            >>> sam = SAM("sam_b.pt")
+            >>> results = sam.predict("image.jpg", points=[[500, 375]])
+            >>> for r in results:
+            ...     print(f"Detected {len(r.masks)} masks")
+        """
+        overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
+ kwargs = {**overrides, **kwargs}
+        prompts = dict(bboxes=bboxes, points=points, labels=labels)
+        return super().predict(source, stream, prompts=prompts, **kwargs)
+ def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
+        """
+        This method is an alias for the 'predict' method, providing a convenient way to call the SAM model
+        for segmentation tasks.
+        Args:
+            source (str | PIL.Image | numpy.ndarray | None): Path to the image or video file, or a PIL.Image
+                object, or a numpy.ndarray object.
+            **kwargs (Any): Additional keyword arguments to be passed to the predict method.
+        Returns:
+            (List): The model predictions, typically containing segmentation masks and other relevant information.
+        Examples:
+            >>> sam = SAM("sam_b.pt")
+            >>> results = sam("image.jpg", points=[[500, 375]])
+            >>> print(f"Detected {len(results[0].masks)} masks")
+        """
+        return self.predict(source, stream, bboxes, points, labels, **kwargs)
+    def info(self, detailed=False, verbose=True):
+        """
+        Logs information about the SAM model.
+        This method provides details about the Segment Anything Model (SAM), including its architecture,
+        parameters, and computational requirements.
+        Args:
+            detailed (bool): If True, displays detailed information about the model layers and operations.
+            verbose (bool): If True, prints the information to the console.
+        Returns:
+            (tuple): A tuple containing the model's information (string representations of the model).
+        Examples:
+            >>> sam = SAM("sam_b.pt")
+            >>> info = sam.info()
+            >>> print(info[0])  # Print summary information
+        """
+ return model_info(self.model, detailed=detailed, verbose=verbose)
+    @property
+    def task_map(self):
+        """
+        Provides a mapping from the 'segment' task to its corresponding 'Predictor'.
+        Returns:
+            (Dict[str, Type[Predictor]]): A dictionary mapping the 'segment' task to its corresponding Predictor
+                class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.
+        Examples:
+            >>> sam = SAM("sam_b.pt")
+            >>> task_map = sam.task_map
+            >>> print(task_map)
+            {'segment': <class 'ultralytics.models.sam.predict.Predictor'>}
+        """
+ return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}
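+# Illustrative prompting sketch for the SAM interface above, assuming the "sam_b.pt" checkpoint
+# and a local image are available; the point prompt mirrors the docstring example.
+if __name__ == "__main__":
+    sam = SAM("sam_b.pt")
+    sam.info()
+    results = sam("ultralytics/assets/bus.jpg", points=[[500, 375]], labels=[1])
+    print(len(results[0].masks))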
@@ -0,0 +1 @@
@@ -0,0 +1,1129 @@
+import copy
+import math
+from typing import Any, Optional, Tuple, Type, Union
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+from ultralytics.nn.modules import MLP, LayerNorm2d, MLPBlock
+from .transformer import Attention, TwoWayAttentionBlock, TwoWayTransformer
+from .utils import add_decomposed_rel_pos, apply_rotary_enc, compute_axial_cis, window_partition, window_unpartition
+class DropPath(nn.Module):
+    """
+    Implements stochastic depth regularization for neural networks during training.
+    Attributes:
+        drop_prob (float): Probability of dropping a path during training.
+        scale_by_keep (bool): Whether to scale the output by the keep probability.
+    Methods:
+        forward: Applies stochastic depth to input tensor during training, with optional scaling.
+    Examples:
+        >>> drop_path = DropPath(drop_prob=0.2, scale_by_keep=True)
+        >>> x = torch.randn(32, 64, 224, 224)
+        >>> output = drop_path(x)
+    """
+ def __init__(self, drop_prob=0.0, scale_by_keep=True):
+ """Initialize DropPath module for stochastic depth regularization during training."""
+ super().__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+ def forward(self, x):
+ """Applies stochastic depth to input tensor during training, with optional scaling."""
+ if self.drop_prob == 0.0 or not self.training:
+ return x
+ keep_prob = 1 - self.drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and self.scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
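+# Standalone sketch of DropPath behaviour: an identity in eval mode, while in training mode
+# roughly `drop_prob` of the samples are zeroed and the survivors are rescaled by 1/keep_prob.
+if __name__ == "__main__":
+    dp = DropPath(drop_prob=0.5)
+    x = torch.ones(8, 4)
+    dp.eval()
+    print(torch.equal(dp(x), x))  # True: no-op at inference time
+    dp.train()
+    print(dp(x)[:, 0])  # each sample is either 0.0 or 1 / keep_prob = 2.0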
+class MaskDownSampler(nn.Module):
+    """
+    A mask downsampling and embedding module for efficient processing of input masks.
+    This class implements a mask downsampler that progressively reduces the spatial dimensions of input masks
+    while expanding their channel dimensions using convolutional layers, layer normalization, and activation
+    functions.
+    Attributes:
+        encoder (nn.Sequential): A sequential container of convolutional layers, layer normalization, and
+            activation functions for downsampling and embedding masks.
+    Methods:
+        forward: Downsamples and encodes input mask to embed_dim channels.
+    Examples:
+        >>> mask_downsampler = MaskDownSampler(embed_dim=256, kernel_size=4, stride=4, padding=0, total_stride=16)
+        >>> input_mask = torch.randn(1, 1, 256, 256)
+        >>> output = mask_downsampler(input_mask)
+        >>> print(output.shape)
+        torch.Size([1, 256, 16, 16])
+    """
+ def __init__(
+ self,
+ embed_dim=256,
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ total_stride=16,
+ activation=nn.GELU,
+ ):
+ """Initializes a mask downsampler module for progressive downsampling and channel expansion."""
+        super().__init__()
+        num_layers = int(math.log2(total_stride) // math.log2(stride))
+ assert stride**num_layers == total_stride
+ self.encoder = nn.Sequential()
+ mask_in_chans, mask_out_chans = 1, 1
+ for _ in range(num_layers):
+ mask_out_chans = mask_in_chans * (stride**2)
+ self.encoder.append(
+ nn.Conv2d(
+ mask_in_chans,
+ mask_out_chans,
+ kernel_size=kernel_size,
+ stride=stride,
+                    padding=padding,
+                )
+            )
+ self.encoder.append(LayerNorm2d(mask_out_chans))
+ self.encoder.append(activation())
+ mask_in_chans = mask_out_chans
+ self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
+    def forward(self, x):
+        """Downsamples and encodes input mask to embed_dim channels using convolutional layers and LayerNorm2d."""
+ return self.encoder(x)
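+# Standalone sketch of MaskDownSampler shapes: with stride=4 applied twice (total_stride=16),
+# a 1-channel 256x256 mask becomes a 256-channel 16x16 embedding, matching the docstring example.
+if __name__ == "__main__":
+    downsampler = MaskDownSampler(embed_dim=256, kernel_size=4, stride=4, padding=0, total_stride=16)
+    dummy_mask = torch.randn(1, 1, 256, 256)
+    print(downsampler(dummy_mask).shape)  # torch.Size([1, 256, 16, 16])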
+class CXBlock(nn.Module):
+    """
+    ConvNeXt Block for efficient feature extraction in convolutional neural networks.
+    This block implements a modified version of the ConvNeXt architecture, offering improved performance and
+    flexibility in feature extraction.
+    Attributes:
+        dwconv (nn.Conv2d): Depthwise or standard 2D convolution layer.
+        norm (LayerNorm2d): Layer normalization applied to channels.
+        pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer.
+        act (nn.GELU): GELU activation function.
+        pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer.
+        gamma (nn.Parameter | None): Learnable scale parameter for layer scaling.
+        drop_path (nn.Module): DropPath layer for stochastic depth regularization.
+    Methods:
+        forward: Processes the input tensor through the ConvNeXt block.
+    Examples:
+        >>> import torch
+        >>> x = torch.randn(1, 64, 56, 56)
+        >>> block = CXBlock(dim=64, kernel_size=7, padding=3)
+        >>> output = block(x)
+        >>> print(output.shape)
+        torch.Size([1, 64, 56, 56])
+    """
+    def __init__(
+        self,
+        dim,
+ kernel_size=7,
+ padding=3,
+ drop_path=0.0,
+ layer_scale_init_value=1e-6,
+        use_dwconv=True,
+    ):
+        """
+        Initialize a ConvNeXt Block for efficient feature extraction in convolutional neural networks.
+        Args:
+ dim (int): Number of input channels.
+ kernel_size (int): Size of the convolutional kernel.
+ padding (int): Padding size for the convolution.
+ drop_path (float): Stochastic depth rate.
+ layer_scale_init_value (float): Initial value for Layer Scale.
+ use_dwconv (bool): Whether to use depthwise convolution.
+        Examples:
+            >>> x = torch.randn(1, 64, 32, 32)
+            >>> block = CXBlock(dim=64, kernel_size=7, padding=3)
+            >>> output = block(x)
+            >>> print(output.shape)
+            torch.Size([1, 64, 32, 32])
+        """
+        super().__init__()
+        self.dwconv = nn.Conv2d(
+            dim,
+            dim,
+            kernel_size=kernel_size,
+            padding=padding,
+ groups=dim if use_dwconv else 1,
+ ) # depthwise conv
+ self.norm = LayerNorm2d(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(4 * dim, dim)
+ self.gamma = (
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
+ if layer_scale_init_value > 0
+            else None
+        )
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+    def forward(self, x):
+        """Applies ConvNeXt block operations to input tensor, including convolutions and residual connection."""
+ input = x
+ x = self.dwconv(x)
+ x = self.norm(x)
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ if self.gamma is not None:
+ x = self.gamma * x
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
+        x = input + self.drop_path(x)
+        return x
+class Fuser(nn.Module):
+    """
+    A module for fusing features through multiple layers of a neural network.
+    This class applies a series of identical layers to an input tensor, optionally projecting the input first.
+    Attributes:
+        proj (nn.Module): An optional input projection layer. Identity if no projection is needed.
+        layers (nn.ModuleList): A list of identical layers to be applied sequentially.
+    Methods:
+        forward: Applies the fuser to an input tensor.
+    Examples:
+ >>> layer = CXBlock(dim=256)
+ >>> fuser = Fuser(layer, num_layers=3, dim=256, input_projection=True)
+ >>> x = torch.randn(1, 256, 32, 32)
+ >>> output = fuser(x)
+        >>> print(output.shape)
+        torch.Size([1, 256, 32, 32])
+    """
+    def __init__(self, layer, num_layers, dim=None, input_projection=False):
+ Initializes the Fuser module for feature fusion through multiple layers.
+ This module creates a sequence of identical layers and optionally applies an input projection.
+ layer (nn.Module): The layer to be replicated in the fuser.
+ num_layers (int): The number of times to replicate the layer.
+ dim (int | None): The dimension for input projection, if used.
+ input_projection (bool): Whether to use input projection.
+ >>> layer = nn.Linear(64, 64)
+ >>> fuser = Fuser(layer, num_layers=3, dim=64, input_projection=True)
+ >>> input_tensor = torch.randn(1, 64)
+ >>> output = fuser(input_tensor)
+ self.proj = nn.Identity()
+ self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
+ if input_projection:
+ assert dim is not None
+ self.proj = nn.Conv2d(dim, dim, kernel_size=1)
+ """Applies a series of layers to the input tensor, optionally projecting it first."""
+ x = self.proj(x)
+ for layer in self.layers:
+ x = layer(x)
+ return x
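A minimal sketch of the replicate-and-stack pattern `Fuser` uses, assuming only standard `torch.nn` modules: `copy.deepcopy` gives each replica independent parameters, which is why the listing deep-copies the template layer rather than appending the same instance several times.

```python
import copy

import torch
import torch.nn as nn

template = nn.Conv2d(16, 16, kernel_size=3, padding=1)
layers = nn.ModuleList([copy.deepcopy(template) for _ in range(3)])

x = torch.randn(1, 16, 8, 8)
for layer in layers:                             # applied sequentially, as in Fuser.forward
    x = layer(x)

print(x.shape)                                   # torch.Size([1, 16, 8, 8])
print(layers[0].weight is layers[1].weight)      # False: each replica has its own parameters
```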
+class SAM2TwoWayAttentionBlock(TwoWayAttentionBlock):
+ A two-way attention block for performing self-attention and cross-attention in both directions.
+ This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on
+ sparse inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and
+ cross-attention from dense to sparse inputs.
+ self_attn (Attention): Self-attention layer for queries.
+ norm1 (nn.LayerNorm): Layer normalization after the first attention block.
+ cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
+ norm2 (nn.LayerNorm): Layer normalization after the second attention block.
+ mlp (MLP): MLP block for transforming query embeddings.
+ norm3 (nn.LayerNorm): Layer normalization after the MLP block.
+ norm4 (nn.LayerNorm): Layer normalization after the third attention block.
+ cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
+ skip_first_layer_pe (bool): Flag to skip positional encoding in the first layer.
+ forward: Processes input through the attention blocks and MLP.
+ >>> block = SAM2TwoWayAttentionBlock(embedding_dim=256, num_heads=8)
+ >>> sparse_input = torch.randn(1, 100, 256)
+ >>> dense_input = torch.randn(1, 256, 16, 16)
+ >>> sparse_output, dense_output = block(sparse_input, dense_input)
+ embedding_dim: int,
+ num_heads: int,
+ mlp_dim: int = 2048,
+ activation: Type[nn.Module] = nn.ReLU,
+ attention_downsample_rate: int = 2,
+ skip_first_layer_pe: bool = False,
+ ) -> None:
+ Initializes a SAM2TwoWayAttentionBlock for performing self-attention and cross-attention in two directions.
+ This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on sparse
+ inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and cross-attention
+ from dense to sparse inputs.
+ embedding_dim (int): The channel dimension of the embeddings.
+ num_heads (int): The number of heads in the attention layers.
+ mlp_dim (int): The hidden dimension of the MLP block.
+ activation (Type[nn.Module]): The activation function of the MLP block.
+ attention_downsample_rate (int): The downsample rate for attention computations.
+ skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
+ >>> block = SAM2TwoWayAttentionBlock(embedding_dim=256, num_heads=8, mlp_dim=2048)
+ >>> sparse_inputs = torch.randn(1, 100, 256)
+ >>> dense_inputs = torch.randn(1, 256, 32, 32)
+ >>> sparse_outputs, dense_outputs = block(sparse_inputs, dense_inputs)
+ super().__init__(embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate, skip_first_layer_pe)
+ self.mlp = MLP(embedding_dim, mlp_dim, embedding_dim, num_layers=2, act=activation)
+class SAM2TwoWayTransformer(TwoWayTransformer):
+ A Two-Way Transformer module for simultaneous attention to image and query points.
+ This class extends the TwoWayTransformer, implementing a specialized transformer decoder that attends to an
+ input image using queries with supplied positional embeddings. It is particularly useful for tasks like
+ object detection, image segmentation, and point cloud processing.
+ depth (int): Number of layers in the transformer.
+ embedding_dim (int): Channel dimension for input embeddings.
+ num_heads (int): Number of heads for multihead attention.
+ mlp_dim (int): Internal channel dimension for the MLP block.
+ layers (nn.ModuleList): List of SAM2TwoWayAttentionBlock layers comprising the transformer.
+ final_attn_token_to_image (Attention): Final attention layer from queries to image.
+ norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
+ forward: Processes input image embeddings and query embeddings through the transformer.
+ >>> transformer = SAM2TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)
+ >>> image_embedding = torch.randn(1, 256, 64, 64)
+ >>> query_embedding = torch.randn(1, 100, 256)
+ >>> image_pe = torch.randn(1, 256, 64, 64)
+ >>> output = transformer(image_embedding, image_pe, query_embedding)
+ >>> print(output[0].shape, output[1].shape)
+ torch.Size([1, 100, 256]) torch.Size([1, 4096, 256])
+ depth: int,
+ mlp_dim: int,
+ Initializes a SAM2TwoWayTransformer instance.
+ This transformer decoder attends to an input image using queries with supplied positional embeddings.
+ It is designed for tasks like object detection, image segmentation, and point cloud processing.
+ embedding_dim (int): Channel dimension for the input embeddings.
+ num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.
+ mlp_dim (int): Channel dimension internal to the MLP block.
+ activation (Type[nn.Module]): Activation function to use in the MLP block.
+ attention_downsample_rate (int): Downsampling rate for attention computations.
+ >>> transformer
+ SAM2TwoWayTransformer(
+ (layers): ModuleList(
+ (0-4): 5 x SAM2TwoWayAttentionBlock(...)
+ (final_attn_token_to_image): Attention(...)
+ (norm_final_attn): LayerNorm(...)
+ super().__init__(depth, embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate)
+ self.layers = nn.ModuleList()
+ for i in range(depth):
+ self.layers.append(
+ SAM2TwoWayAttentionBlock(
+ embedding_dim=embedding_dim,
+ num_heads=num_heads,
+ mlp_dim=mlp_dim,
+ activation=activation,
+ attention_downsample_rate=attention_downsample_rate,
+ skip_first_layer_pe=(i == 0),
+class RoPEAttention(Attention):
+ Implements rotary position encoding for attention mechanisms in transformer architectures.
+ This class extends the base Attention class by incorporating Rotary Position Encoding (RoPE) to enhance
+ the positional awareness of the attention mechanism.
+ compute_cis (Callable): Function to compute axial complex numbers for rotary encoding.
+ freqs_cis (Tensor): Precomputed frequency tensor for rotary encoding.
+ rope_k_repeat (bool): Flag to repeat query RoPE to match key length for cross-attention to memories.
+ forward: Applies rotary position encoding and computes attention between query, key, and value tensors.
+ >>> rope_attn = RoPEAttention(embedding_dim=256, num_heads=8, rope_theta=10000.0, feat_sizes=(32, 32))
+ >>> q = torch.randn(1, 1024, 256)
+ >>> k = torch.randn(1, 1024, 256)
+ >>> v = torch.randn(1, 1024, 256)
+ >>> output = rope_attn(q, k, v)
+ torch.Size([1, 1024, 256])
+ *args,
+ rope_theta=10000.0,
+ rope_k_repeat=False,
+ feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
+ **kwargs,
+ """Initializes RoPEAttention with rotary position encoding for enhanced positional awareness."""
+ super().__init__(*args, **kwargs)
+ self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta)
+ freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
+ self.freqs_cis = freqs_cis
+ self.rope_k_repeat = rope_k_repeat # repeat q rope to match k length, needed for cross-attention to memories
+ def forward(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0) -> Tensor:
+ """Applies rotary position encoding and computes attention between query, key, and value tensors."""
+ q = self.q_proj(q)
+ k = self.k_proj(k)
+ v = self.v_proj(v)
+ # Separate into heads
+ q = self._separate_heads(q, self.num_heads)
+ k = self._separate_heads(k, self.num_heads)
+ v = self._separate_heads(v, self.num_heads)
+ # Apply rotary position encoding
+ w = h = math.sqrt(q.shape[-2])
+ self.freqs_cis = self.freqs_cis.to(q.device)
+ if self.freqs_cis.shape[0] != q.shape[-2]:
+ self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
+ if q.shape[-2] != k.shape[-2]:
+ assert self.rope_k_repeat
+ num_k_rope = k.size(-2) - num_k_exclude_rope
+ q, k[:, :, :num_k_rope] = apply_rotary_enc(
+ q,
+ k[:, :, :num_k_rope],
+ freqs_cis=self.freqs_cis,
+ repeat_freqs_k=self.rope_k_repeat,
+ # Attention
+ _, _, _, c_per_head = q.shape
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
+ attn = attn / math.sqrt(c_per_head)
+ attn = torch.softmax(attn, dim=-1)
+ # Get output
+ out = attn @ v
+ out = self._recombine_heads(out)
+ out = self.out_proj(out)
+ return out
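For intuition, here is a minimal 1-D rotary-encoding sketch. It is an illustrative re-implementation, not the `compute_axial_cis`/`apply_rotary_enc` helpers used above (those apply an axial 2-D variant); the function name `rope_1d` and the toy shapes are assumptions for demonstration.

```python
import torch

def rope_1d(x: torch.Tensor, theta: float = 10000.0) -> torch.Tensor:
    """Rotary-encode x of shape (seq_len, dim), dim even, by position-dependent complex rotation."""
    seq_len, dim = x.shape
    freqs = 1.0 / theta ** (torch.arange(0, dim, 2).float() / dim)     # (dim/2,)
    angles = torch.outer(torch.arange(seq_len).float(), freqs)         # (seq_len, dim/2)
    rot = torch.polar(torch.ones_like(angles), angles)                 # unit-norm complex rotations
    x_c = torch.view_as_complex(x.float().reshape(seq_len, dim // 2, 2))
    return torch.view_as_real(x_c * rot).reshape(seq_len, dim)

q = torch.randn(16, 64)
q_rot = rope_1d(q)
# Rotation preserves norms, so only the relative phase (i.e. position) information changes.
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))   # True
```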
+def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
+ """Applies pooling and optional normalization to a tensor, handling spatial dimension permutations."""
+ if pool is None:
+ return x
+ # (B, H, W, C) -> (B, C, H, W)
+ x = x.permute(0, 3, 1, 2)
+ x = pool(x)
+ # (B, C, H', W') -> (B, H', W', C)
+ x = x.permute(0, 2, 3, 1)
+ if norm:
+ x = norm(x)
+ return x
+class MultiScaleAttention(nn.Module):
+ Implements multiscale self-attention with optional query pooling for efficient feature extraction.
+ This class provides a flexible implementation of multiscale attention, allowing for optional
+ downsampling of query features through pooling. It's designed to enhance the model's ability to
+ capture multiscale information in visual tasks.
+ dim (int): Input dimension of the feature map.
+ dim_out (int): Output dimension of the attention module.
+ num_heads (int): Number of attention heads.
+ scale (float): Scaling factor for dot-product attention.
+ q_pool (nn.Module | None): Optional pooling module for query features.
+ qkv (nn.Linear): Linear projection for query, key, and value.
+ proj (nn.Linear): Output projection.
+ forward: Applies multiscale attention to the input tensor.
+ >>> from torch import nn
+ >>> x = torch.randn(1, 64, 64, 256)
+ >>> msa = MultiScaleAttention(dim=256, dim_out=256, num_heads=8)
+ >>> output = msa(x)
+ torch.Size([1, 64, 64, 256])
+ dim: int,
+ dim_out: int,
+ q_pool: nn.Module = None,
+ """Initializes multiscale attention with optional query pooling for efficient feature extraction."""
+ self.dim = dim
+ self.dim_out = dim_out
+ self.num_heads = num_heads
+ head_dim = dim_out // num_heads
+ self.scale = head_dim**-0.5
+ self.q_pool = q_pool
+ self.qkv = nn.Linear(dim, dim_out * 3)
+ self.proj = nn.Linear(dim_out, dim_out)
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Applies multiscale attention with optional query pooling to extract multiscale features."""
+ B, H, W, _ = x.shape
+ # qkv with shape (B, H * W, 3, nHead, C)
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
+ # q, k, v with shape (B, H * W, nheads, C)
+ q, k, v = torch.unbind(qkv, 2)
+ # Q pooling (for downsample at stage changes)
+ if self.q_pool:
+ q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
+ H, W = q.shape[1:3] # downsampled shape
+ q = q.reshape(B, H * W, self.num_heads, -1)
+ # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
+ x = F.scaled_dot_product_attention(
+ q.transpose(1, 2),
+ k.transpose(1, 2),
+ v.transpose(1, 2),
+ # Transpose back
+ x = x.transpose(1, 2)
+ x = x.reshape(B, H, W, -1)
+ x = self.proj(x)
+ return x
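A minimal sketch of the head-splitting layout that `F.scaled_dot_product_attention` (PyTorch 2.x) expects, mirroring the reshape/transpose steps above; the toy dimensions are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

B, H, W, C, n_heads = 1, 8, 8, 64, 4
x = torch.randn(B, H, W, C)

qkv = nn.Linear(C, 3 * C)(x).reshape(B, H * W, 3, n_heads, C // n_heads)
q, k, v = torch.unbind(qkv, dim=2)               # each (B, H*W, n_heads, head_dim)

out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
)                                                # (B, n_heads, H*W, head_dim)
out = out.transpose(1, 2).reshape(B, H, W, C)    # merge heads, restore spatial layout
print(out.shape)                                 # torch.Size([1, 8, 8, 64])
```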
+class MultiScaleBlock(nn.Module):
+ A multiscale attention block with window partitioning and query pooling for efficient vision transformers.
+ This class implements a multiscale attention mechanism with optional window partitioning and downsampling,
+ designed for use in vision transformer architectures.
+ dim (int): Input dimension of the block.
+ dim_out (int): Output dimension of the block.
+ norm1 (nn.Module): First normalization layer.
+ window_size (int): Size of the window for partitioning.
+ pool (nn.Module | None): Pooling layer for query downsampling.
+ q_stride (Tuple[int, int] | None): Stride for query pooling.
+ attn (MultiScaleAttention): Multi-scale attention module.
+ drop_path (nn.Module): Drop path layer for regularization.
+ norm2 (nn.Module): Second normalization layer.
+ mlp (MLP): Multi-layer perceptron module.
+ proj (nn.Linear | None): Projection layer for dimension mismatch.
+ forward: Processes input tensor through the multiscale block.
+ >>> block = MultiScaleBlock(dim=256, dim_out=512, num_heads=8, window_size=7)
+ >>> x = torch.randn(1, 56, 56, 256)
+ torch.Size([1, 56, 56, 512])
+ mlp_ratio: float = 4.0,
+ drop_path: float = 0.0,
+ norm_layer: Union[nn.Module, str] = "LayerNorm",
+ q_stride: Tuple[int, int] = None,
+ act_layer: nn.Module = nn.GELU,
+ window_size: int = 0,
+ """Initializes a multiscale attention block with window partitioning and optional query pooling."""
+ if isinstance(norm_layer, str):
+ norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
+ self.norm1 = norm_layer(dim)
+ self.window_size = window_size
+ self.pool, self.q_stride = None, q_stride
+ if self.q_stride:
+ self.pool = nn.MaxPool2d(kernel_size=q_stride, stride=q_stride, ceil_mode=False)
+ self.attn = MultiScaleAttention(
+ dim_out,
+ q_pool=self.pool,
+ self.norm2 = norm_layer(dim_out)
+ self.mlp = MLP(
+ int(dim_out * mlp_ratio),
+ num_layers=2,
+ act=act_layer,
+ if dim != dim_out:
+ self.proj = nn.Linear(dim, dim_out)
+ """Processes input through multiscale attention and MLP, with optional windowing and downsampling."""
+ shortcut = x # B, H, W, C
+ x = self.norm1(x)
+ # Skip connection
+ if self.dim != self.dim_out:
+ shortcut = do_pool(self.proj(x), self.pool)
+ # Window partition
+ window_size = self.window_size
+ if window_size > 0:
+ H, W = x.shape[1], x.shape[2]
+ x, pad_hw = window_partition(x, window_size)
+ # Window Attention + Q Pooling (if stage change)
+ x = self.attn(x)
+ # Shapes have changed due to Q pooling
+ window_size = self.window_size // self.q_stride[0]
+ H, W = shortcut.shape[1:3]
+ pad_h = (window_size - H % window_size) % window_size
+ pad_w = (window_size - W % window_size) % window_size
+ pad_hw = (H + pad_h, W + pad_w)
+ # Reverse window partition
+ if self.window_size > 0:
+ x = window_unpartition(x, window_size, pad_hw, (H, W))
+ x = shortcut + self.drop_path(x)
+ # MLP
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
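A minimal sketch of the padding arithmetic behind window partitioning, assuming the same rule as above (pad H and W up to multiples of the window size). `simple_window_partition` is an illustrative stand-in, not the Ultralytics `window_partition` utility.

```python
import torch
import torch.nn.functional as F

def simple_window_partition(x: torch.Tensor, ws: int):
    """Pad (B, H, W, C) so H and W are multiples of ws, then split into (num_windows*B, ws, ws, C)."""
    B, H, W, C = x.shape
    pad_h = (ws - H % ws) % ws
    pad_w = (ws - W % ws) % ws
    x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))     # pad order: C, then W, then H
    Hp, Wp = H + pad_h, W + pad_w
    x = x.view(B, Hp // ws, ws, Wp // ws, ws, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, ws, ws, C)
    return windows, (Hp, Wp)

windows, pad_hw = simple_window_partition(torch.randn(1, 56, 56, 96), ws=7)
print(windows.shape, pad_hw)                     # torch.Size([64, 7, 7, 96]) (56, 56)
```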
+class PositionEmbeddingSine(nn.Module):
+ A module for generating sinusoidal positional embeddings for 2D inputs like images.
+ This class implements sinusoidal position encoding for 2D spatial positions, which can be used in
+ transformer-based models for computer vision tasks.
+ num_pos_feats (int): Number of positional features (half of the embedding dimension).
+ temperature (int): Temperature parameter for the sinusoidal functions.
+ normalize (bool): Whether to normalize the positional embeddings.
+ scale (float): Scaling factor for the embeddings when normalize is True.
+ cache (Dict): Cache for storing precomputed embeddings.
+ _encode_xy: Encodes 2D positions using sine and cosine functions.
+ encode_boxes: Encodes box coordinates and dimensions into positional embeddings.
+ encode_points: Encodes 2D point coordinates with sinusoidal positional embeddings.
+ forward: Generates sinusoidal position embeddings for 2D inputs.
+ >>> pos_emb = PositionEmbeddingSine(num_pos_feats=128)
+ >>> x = torch.randn(1, 3, 224, 224)
+ >>> embeddings = pos_emb(x)
+ >>> print(embeddings.shape)
+ torch.Size([1, 128, 224, 224])
+ num_pos_feats,
+ temperature: int = 10000,
+ normalize: bool = True,
+ scale: Optional[float] = None,
+ """Initializes sinusoidal position embeddings for 2D image inputs."""
+ assert num_pos_feats % 2 == 0, "Expecting even model width"
+ self.num_pos_feats = num_pos_feats // 2
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and not normalize:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+ self.cache = {}
+ def _encode_xy(self, x, y):
+ """Encodes 2D positions using sine/cosine functions for transformer positional embeddings."""
+ assert len(x) == len(y) and x.ndim == y.ndim == 1
+ x_embed = x * self.scale
+ y_embed = y * self.scale
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+ pos_x = x_embed[:, None] / dim_t
+ pos_y = y_embed[:, None] / dim_t
+ pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
+ pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)
+ return pos_x, pos_y
+ @torch.no_grad()
+ def encode_boxes(self, x, y, w, h):
+ """Encodes box coordinates and dimensions into positional embeddings for detection."""
+ pos_x, pos_y = self._encode_xy(x, y)
+ return torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
+ encode = encode_boxes # Backwards compatibility
+ def encode_points(self, x, y, labels):
+ """Encodes 2D points with sinusoidal embeddings and appends labels."""
+ (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
+ assert bx == by and nx == ny and bx == bl and nx == nl
+ pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
+ pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
+ return torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
+ def forward(self, x: torch.Tensor):
+ """Generates sinusoidal position embeddings for 2D inputs like images."""
+ cache_key = (x.shape[-2], x.shape[-1])
+ if cache_key in self.cache:
+ return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
+ y_embed = (
+ torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
+ .view(1, -1, 1)
+ .repeat(x.shape[0], 1, x.shape[-1])
+ x_embed = (
+ torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
+ .view(1, 1, -1)
+ .repeat(x.shape[0], x.shape[-2], 1)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ self.cache[cache_key] = pos[0]
+ return pos
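A minimal sketch of the channel bookkeeping in the sinusoidal encoding above, assuming the construction shown (half the requested features per spatial axis, sine/cosine interleaved, then the y and x encodings concatenated): the final embedding ends up with `num_pos_feats` channels.

```python
import torch

num_pos_feats = 128
per_axis = num_pos_feats // 2                    # mirrors self.num_pos_feats = num_pos_feats // 2
dim_t = torch.arange(per_axis, dtype=torch.float32)
dim_t = 10000 ** (2 * (dim_t // 2) / per_axis)

coords = torch.arange(1, 8, dtype=torch.float32) # toy 1-D positions
pos = coords[:, None] / dim_t                    # (7, per_axis)
pos = torch.stack((pos[:, 0::2].sin(), pos[:, 1::2].cos()), dim=2).flatten(1)
print(pos.shape)                                 # torch.Size([7, 64]): per_axis channels per axis
print(2 * pos.shape[1])                          # 128: concatenating y and x gives num_pos_feats channels
```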
+class PositionEmbeddingRandom(nn.Module):
+ Positional encoding using random spatial frequencies.
+ This class generates positional embeddings for input coordinates using random spatial frequencies. It is
+ particularly useful for transformer-based models that require position information.
+ positional_encoding_gaussian_matrix (torch.Tensor): A buffer containing random values for encoding.
+ _pe_encoding: Positionally encodes points that are normalized to [0,1].
+ forward: Generates positional encoding for a grid of the specified size.
+ forward_with_coords: Positionally encodes points that are not normalized to [0,1].
+ >>> pe = PositionEmbeddingRandom(num_pos_feats=64)
+ >>> size = (32, 32)
+ >>> encoding = pe(size)
+ >>> print(encoding.shape)
+ torch.Size([128, 32, 32])
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
+ """Initializes random spatial frequency position embedding for transformers."""
+ if scale is None or scale <= 0.0:
+ scale = 1.0
+ self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats)))
+ # Use non-deterministic algorithms to avoid the forward() error 'cumsum_cuda_kernel does not have a deterministic implementation'
+ torch.use_deterministic_algorithms(False)
+ torch.backends.cudnn.deterministic = False
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
+ """Encodes normalized [0,1] coordinates using random spatial frequencies."""
+ # Assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+ coords = 2 * coords - 1
+ coords = coords @ self.positional_encoding_gaussian_matrix
+ coords = 2 * np.pi * coords
+ # Outputs d_1 x ... x d_n x C shape
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
+ """Generates positional encoding for a grid using random spatial frequencies."""
+ h, w = size
+ device: Any = self.positional_encoding_gaussian_matrix.device
+ grid = torch.ones((h, w), device=device, dtype=torch.float32)
+ y_embed = grid.cumsum(dim=0) - 0.5
+ x_embed = grid.cumsum(dim=1) - 0.5
+ y_embed = y_embed / h
+ x_embed = x_embed / w
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
+ return pe.permute(2, 0, 1) # C x H x W
+ def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
+ """Positionally encodes input coordinates, normalizing them to [0,1] based on the given image size."""
+ coords = coords_input.clone()
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
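A minimal random-Fourier-feature sketch of the encoding above; it builds the normalized grid with `arange` instead of the cumsum used in `forward`, so treat it as an illustrative simplification rather than the exact module (requires a PyTorch version with `torch.meshgrid(..., indexing=...)`, i.e. 1.10+).

```python
import torch

num_pos_feats, h, w = 64, 32, 32
gaussian = torch.randn(2, num_pos_feats)                  # fixed random projection matrix

ys = (torch.arange(h, dtype=torch.float32) + 0.5) / h     # pixel centers normalized to (0, 1)
xs = (torch.arange(w, dtype=torch.float32) + 0.5) / w
grid = torch.stack(torch.meshgrid(xs, ys, indexing="xy"), dim=-1)  # (h, w, 2) of (x, y) coords

proj = (2 * grid - 1) @ gaussian                          # map to [-1, 1], then project
pe = torch.cat([torch.sin(2 * torch.pi * proj), torch.cos(2 * torch.pi * proj)], dim=-1)
print(pe.permute(2, 0, 1).shape)                          # torch.Size([128, 32, 32]) = 2*num_pos_feats x H x W
```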
+class Block(nn.Module):
+ Transformer block with support for window attention and residual propagation.
+ This class implements a transformer block that can use either global or windowed self-attention,
+ followed by a feed-forward network. It supports relative positional embeddings and is designed
+ for use in vision transformer architectures.
+ attn (REAttention): Self-attention layer with optional relative positional encoding.
+ mlp (MLPBlock): Multi-layer perceptron block.
+ window_size (int): Size of attention window. If 0, global attention is used.
+ forward: Processes input through the transformer block.
+ >>> block = Block(dim=256, num_heads=8, window_size=7)
+ torch.Size([1, 56, 56, 256])
+ qkv_bias: bool = True,
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
+ act_layer: Type[nn.Module] = nn.GELU,
+ use_rel_pos: bool = False,
+ rel_pos_zero_init: bool = True,
+ input_size: Optional[Tuple[int, int]] = None,
+ Initializes a transformer block with optional window attention and relative positional embeddings.
+ This constructor sets up a transformer block that can use either global or windowed self-attention,
+ num_heads (int): Number of attention heads in the self-attention layer.
+ mlp_ratio (float): Ratio of mlp hidden dimension to embedding dimension.
+ qkv_bias (bool): If True, adds a learnable bias to query, key, value projections.
+ norm_layer (Type[nn.Module]): Type of normalization layer to use.
+ act_layer (Type[nn.Module]): Type of activation function to use in the MLP block.
+ use_rel_pos (bool): If True, uses relative positional embeddings in attention.
+ rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.
+ window_size (int): Size of attention window. If 0, uses global attention.
+ input_size (Optional[Tuple[int, int]]): Input resolution for calculating relative positional parameter size.
+ self.attn = REAttention(
+ qkv_bias=qkv_bias,
+ use_rel_pos=use_rel_pos,
+ rel_pos_zero_init=rel_pos_zero_init,
+ input_size=input_size if window_size == 0 else (window_size, window_size),
+ self.norm2 = norm_layer(dim)
+ self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
+ """Processes input through transformer block with optional windowed self-attention and residual connection."""
+ shortcut = x
+ x, pad_hw = window_partition(x, self.window_size)
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
+ x = shortcut + x
+ return x + self.mlp(self.norm2(x))
+class REAttention(nn.Module):
+ Relative Position Attention module for multi-head self-attention in transformer architectures.
+ This class implements a multi-head attention mechanism with optional decomposed relative positional
+ embeddings, designed for vision transformer models such as SAM's ViT image encoder. It operates on
+ channels-last (B, H, W, C) feature maps.
+ qkv (nn.Linear): Combined linear projection for query, key, and value.
+ proj (nn.Linear): Output projection.
+ use_rel_pos (bool): Whether decomposed relative positional embeddings are used.
+ rel_pos_h (nn.Parameter): Relative positional embeddings for the height axis (when use_rel_pos is True).
+ rel_pos_w (nn.Parameter): Relative positional embeddings for the width axis (when use_rel_pos is True).
+ >>> attention = REAttention(dim=256, num_heads=8, input_size=(32, 32))
+ num_heads: int = 8,
+ Initializes a Relative Position Attention module for transformer-based architectures.
+ This module implements multi-head attention with optional relative positional encodings, designed
+ specifically for vision tasks in transformer models.
+ num_heads (int): Number of attention heads. Default is 8.
+ qkv_bias (bool): If True, adds a learnable bias to query, key, value projections. Default is True.
+ use_rel_pos (bool): If True, uses relative positional encodings. Default is False.
+ rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero. Default is True.
+ input_size (Tuple[int, int] | None): Input resolution for calculating relative positional parameter size.
+ Required if use_rel_pos is True. Default is None.
+ >>> attention = REAttention(dim=256, num_heads=8, input_size=(32, 32))
+ >>> x = torch.randn(1, 32, 32, 256)
+ >>> output = attention(x)
+ torch.Size([1, 32, 32, 256])
+ head_dim = dim // num_heads
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim, dim)
+ self.use_rel_pos = use_rel_pos
+ if self.use_rel_pos:
+ assert input_size is not None, "Input size must be provided if using relative positional encoding."
+ # Initialize relative positional embeddings
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
+ """Applies multi-head attention with optional relative positional encoding to input tensor."""
+ # qkv with shape (3, B, nHead, H * W, C)
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ # q, k, v with shape (B * nHead, H * W, C)
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
+ attn = (q * self.scale) @ k.transpose(-2, -1)
+ attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
+ attn = attn.softmax(dim=-1)
+ x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
+ return self.proj(x)
+class PatchEmbed(nn.Module):
+ Image to Patch Embedding module for vision transformer architectures.
+ This module converts an input image into a sequence of patch embeddings using a convolutional layer.
+ It is commonly used as the first layer in vision transformer architectures to transform image data
+ into a suitable format for subsequent transformer blocks.
+ proj (nn.Conv2d): Convolutional layer for projecting image patches to embeddings.
+ forward: Applies patch embedding to the input tensor.
+ >>> patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768)
+ >>> output = patch_embed(x)
+ torch.Size([1, 768, 14, 14])
+ kernel_size: Tuple[int, int] = (16, 16),
+ stride: Tuple[int, int] = (16, 16),
+ padding: Tuple[int, int] = (0, 0),
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ Initializes the PatchEmbed module for converting image patches to embeddings.
+ This module is typically used as the first layer in vision transformer architectures to transform
+ image data into a suitable format for subsequent transformer blocks.
+ kernel_size (Tuple[int, int]): Size of the convolutional kernel for patch extraction.
+ stride (Tuple[int, int]): Stride of the convolutional operation.
+ padding (Tuple[int, int]): Padding applied to the input before convolution.
+ in_chans (int): Number of input image channels.
+ embed_dim (int): Dimensionality of the output patch embeddings.
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
+ """Computes patch embedding by applying convolution and transposing resulting tensor."""
+ return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C
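A minimal sketch of the patch-embedding step: a strided `nn.Conv2d` turns a 224x224 image into a 14x14 grid of 768-dim tokens, and the final permute yields the channels-last layout used by the transformer blocks above. Sizes follow the defaults shown in the listing.

```python
import torch
import torch.nn as nn

proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)
image = torch.randn(1, 3, 224, 224)
patches = proj(image).permute(0, 2, 3, 1)        # B C H W -> B H W C
print(patches.shape)                             # torch.Size([1, 14, 14, 768]); 224 / 16 = 14 per side
```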
@@ -0,0 +1,518 @@
+from typing import List, Optional, Tuple, Type
+from torch import nn
+from ultralytics.nn.modules import MLP, LayerNorm2d
+class MaskDecoder(nn.Module):
+ Decoder module for generating masks and their associated quality scores using a transformer architecture.
+ This class predicts masks given image and prompt embeddings, utilizing a transformer to process the inputs and
+ generate mask predictions along with their quality scores.
+ transformer_dim (int): Channel dimension for the transformer module.
+ transformer (nn.Module): Transformer module used for mask prediction.
+ num_multimask_outputs (int): Number of masks to predict for disambiguating masks.
+ iou_token (nn.Embedding): Embedding for the IoU token.
+ num_mask_tokens (int): Number of mask tokens.
+ mask_tokens (nn.Embedding): Embedding for the mask tokens.
+ output_upscaling (nn.Sequential): Neural network sequence for upscaling the output.
+ output_hypernetworks_mlps (nn.ModuleList): Hypernetwork MLPs for generating masks.
+ iou_prediction_head (nn.Module): MLP for predicting mask quality.
+ forward: Predicts masks given image and prompt embeddings.
+ predict_masks: Internal method for mask prediction.
+ >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)
+ >>> masks, iou_pred = decoder(
+ ... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, multimask_output=True
+ ... )
+ >>> print(f"Predicted masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}")
+ transformer_dim: int,
+ transformer: nn.Module,
+ num_multimask_outputs: int = 3,
+ activation: Type[nn.Module] = nn.GELU,
+ iou_head_depth: int = 3,
+ iou_head_hidden_dim: int = 256,
+ Initializes the MaskDecoder module for generating masks and their quality scores.
+ activation (Type[nn.Module]): Type of activation to use when upscaling masks.
+ iou_head_depth (int): Depth of the MLP used to predict mask quality.
+ iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.
+ >>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6)
+ >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer)
+ >>> print(decoder)
+ self.transformer_dim = transformer_dim
+ self.transformer = transformer
+ self.num_multimask_outputs = num_multimask_outputs
+ self.iou_token = nn.Embedding(1, transformer_dim)
+ self.num_mask_tokens = num_multimask_outputs + 1
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
+ self.output_upscaling = nn.Sequential(
+ nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
+ LayerNorm2d(transformer_dim // 4),
+ activation(),
+ nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
+ self.output_hypernetworks_mlps = nn.ModuleList(
+ [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
+ self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth)
+ def forward(
+ image_embeddings: torch.Tensor,
+ image_pe: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ multimask_output: bool,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ Predicts masks given image and prompt embeddings.
+ image_embeddings (torch.Tensor): Embeddings from the image encoder.
+ image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings.
+ sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes.
+ dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs.
+ multimask_output (bool): Whether to return multiple masks or a single mask.
+ (Tuple[torch.Tensor, torch.Tensor]): A tuple containing:
+ - masks (torch.Tensor): Batched predicted masks.
+ - iou_pred (torch.Tensor): Batched predictions of mask quality.
+ >>> image_emb = torch.rand(1, 256, 64, 64)
+ >>> image_pe = torch.rand(1, 256, 64, 64)
+ >>> sparse_emb = torch.rand(1, 2, 256)
+ >>> dense_emb = torch.rand(1, 256, 64, 64)
+ >>> masks, iou_pred = decoder(image_emb, image_pe, sparse_emb, dense_emb, multimask_output=True)
+ >>> print(f"Masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}")
+ masks, iou_pred = self.predict_masks(
+ image_embeddings=image_embeddings,
+ image_pe=image_pe,
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
+ dense_prompt_embeddings=dense_prompt_embeddings,
+ # Select the correct mask or masks for output
+ mask_slice = slice(1, None) if multimask_output else slice(0, 1)
+ masks = masks[:, mask_slice, :, :]
+ iou_pred = iou_pred[:, mask_slice]
+ # Prepare output
+ return masks, iou_pred
+ def predict_masks(
+ """Predicts masks and quality scores using image and prompt embeddings via transformer architecture."""
+ # Concatenate output tokens
+ output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
+ output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.shape[0], -1, -1)
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+ # Expand per-image data in batch direction to be per-mask
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+ src = src + dense_prompt_embeddings
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+ b, c, h, w = src.shape
+ # Run the transformer
+ hs, src = self.transformer(src, pos_src, tokens)
+ iou_token_out = hs[:, 0, :]
+ mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
+ # Upscale mask embeddings and predict masks using the mask tokens
+ src = src.transpose(1, 2).view(b, c, h, w)
+ upscaled_embedding = self.output_upscaling(src)
+ hyper_in_list: List[torch.Tensor] = [
+ self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
+ ]
+ hyper_in = torch.stack(hyper_in_list, dim=1)
+ b, c, h, w = upscaled_embedding.shape
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
+ # Generate mask quality predictions
+ iou_pred = self.iou_prediction_head(iou_token_out)
+ return masks, iou_pred
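A minimal sketch of the hypernetwork step above, with illustrative shapes: each mask token yields a weight vector that is matrix-multiplied against the flattened upscaled embedding, producing one mask per token.

```python
import torch

b, c, h, w, num_tokens = 1, 32, 256, 256, 4
hyper_in = torch.randn(b, num_tokens, c)         # per-token dynamic weights from the hypernetwork MLPs
upscaled = torch.randn(b, c, h, w)               # upscaled image embedding

masks = (hyper_in @ upscaled.view(b, c, h * w)).view(b, num_tokens, h, w)
print(masks.shape)                               # torch.Size([1, 4, 256, 256]): one mask per token
```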
+class SAM2MaskDecoder(nn.Module):
+ Transformer-based decoder for predicting instance segmentation masks from image and prompt embeddings.
+ This class extends the functionality of the MaskDecoder, incorporating additional features such as
+ high-resolution feature processing, dynamic multimask output, and object score prediction.
+ transformer_dim (int): Channel dimension of the transformer.
+ transformer (nn.Module): Transformer used to predict masks.
+ num_multimask_outputs (int): Number of masks to predict when disambiguating masks.
+ iou_token (nn.Embedding): Embedding for IOU token.
+ num_mask_tokens (int): Total number of mask tokens.
+ mask_tokens (nn.Embedding): Embedding for mask tokens.
+ pred_obj_scores (bool): Whether to predict object scores.
+ obj_score_token (nn.Embedding): Embedding for object score token.
+ use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.
+ output_upscaling (nn.Sequential): Upscaling layers for output.
+ use_high_res_features (bool): Whether to use high-resolution features.
+ conv_s0 (nn.Conv2d): Convolutional layer for high-resolution features (s0).
+ conv_s1 (nn.Conv2d): Convolutional layer for high-resolution features (s1).
+ output_hypernetworks_mlps (nn.ModuleList): List of MLPs for output hypernetworks.
+ iou_prediction_head (MLP): MLP for IOU prediction.
+ pred_obj_score_head (nn.Linear | MLP): Linear layer or MLP for object score prediction.
+ dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.
+ dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.
+ dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.
+ predict_masks: Predicts instance segmentation masks from image and prompt embeddings.
+ _get_stability_scores: Computes mask stability scores based on IoU between thresholds.
+ _dynamic_multimask_via_stability: Dynamically selects the most stable mask output.
+ >>> image_embeddings = torch.rand(1, 256, 64, 64)
+ >>> sparse_prompt_embeddings = torch.rand(1, 2, 256)
+ >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
+ >>> decoder = SAM2MaskDecoder(256, transformer)
+ >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
+ ... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False
+ use_high_res_features: bool = False,
+ iou_prediction_use_sigmoid=False,
+ dynamic_multimask_via_stability=False,
+ pred_obj_scores: bool = False,
+ pred_obj_scores_mlp: bool = False,
+ use_multimask_token_for_obj_ptr: bool = False,
+ Initializes the SAM2MaskDecoder module for predicting instance segmentation masks.
+ This decoder extends the functionality of MaskDecoder, incorporating additional features such as
+ iou_prediction_use_sigmoid (bool): Whether to use sigmoid for IOU prediction.
+ pred_obj_scores_mlp (bool): Whether to use MLP for object score prediction.
+ >>> decoder = SAM2MaskDecoder(transformer_dim=256, transformer=transformer)
+ self.pred_obj_scores = pred_obj_scores
+ if self.pred_obj_scores:
+ self.obj_score_token = nn.Embedding(1, transformer_dim)
+ self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
+ self.use_high_res_features = use_high_res_features
+ if use_high_res_features:
+ self.conv_s0 = nn.Conv2d(transformer_dim, transformer_dim // 8, kernel_size=1, stride=1)
+ self.conv_s1 = nn.Conv2d(transformer_dim, transformer_dim // 4, kernel_size=1, stride=1)
+ self.iou_prediction_head = MLP(
+ transformer_dim,
+ iou_head_hidden_dim,
+ self.num_mask_tokens,
+ iou_head_depth,
+ sigmoid=iou_prediction_use_sigmoid,
+ self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
+ if pred_obj_scores_mlp:
+ self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
+ # When outputting a single mask, optionally we can dynamically fall back to the best
+ # multimask output token if the single mask output token gives low stability scores.
+ self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
+ self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
+ self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
+ repeat_image: bool,
+ high_res_features: Optional[List[torch.Tensor]] = None,
+ image_embeddings (torch.Tensor): Embeddings from the image encoder with shape (B, C, H, W).
+ image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings (B, C, H, W).
+ sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes with shape (B, N, C).
+ dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs with shape (B, C, H, W).
+ repeat_image (bool): Flag to repeat the image embeddings.
+ high_res_features (List[torch.Tensor] | None): Optional high-resolution features.
+ (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing:
+ - masks (torch.Tensor): Batched predicted masks with shape (B, N, H, W).
+ - iou_pred (torch.Tensor): Batched predictions of mask quality with shape (B, N).
+ - sam_tokens_out (torch.Tensor): Batched SAM token for mask output with shape (B, N, C).
+ - object_score_logits (torch.Tensor): Batched object score logits with shape (B, 1).
+ masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
+ repeat_image=repeat_image,
+ high_res_features=high_res_features,
+ if multimask_output:
+ masks = masks[:, 1:, :, :]
+ iou_pred = iou_pred[:, 1:]
+ elif self.dynamic_multimask_via_stability and not self.training:
+ masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
+ else:
+ masks = masks[:, 0:1, :, :]
+ iou_pred = iou_pred[:, 0:1]
+ if multimask_output and self.use_multimask_token_for_obj_ptr:
+ sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
+ else:
+ # Take the mask output token. Here we *always* use the token for single mask output.
+ # At test time, even if we track after 1-click (and using multimask_output=True),
+ # we still take the single mask token here. The rationale is that we always track
+ # after multiple clicks during training, so the past tokens seen during training
+ # are always the single mask token (and we'll let it be the object-memory token).
+ sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
+ return masks, iou_pred, sam_tokens_out, object_score_logits
+ """Predicts instance segmentation masks from image and prompt embeddings using a transformer."""
+ s = 0
+ if self.pred_obj_scores:
+ output_tokens = torch.cat(
+ [
+ self.obj_score_token.weight,
+ self.iou_token.weight,
+ self.mask_tokens.weight,
+ ],
+ dim=0,
+ )
+ s = 1
+ else:
+ output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
+ output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
+ if repeat_image:
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+ else:
+ assert image_embeddings.shape[0] == tokens.shape[0]
+ src = image_embeddings
+ src = src + dense_prompt_embeddings
+ assert image_pe.size(0) == 1, "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
+ iou_token_out = hs[:, s, :]
+ mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
+ if not self.use_high_res_features:
+ upscaled_embedding = self.output_upscaling(src)
+ else:
+ dc1, ln1, act1, dc2, act2 = self.output_upscaling
+ feat_s0, feat_s1 = high_res_features
+ upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
+ upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
+ if self.pred_obj_scores:
+ assert s == 1
+ object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
+ else:
+ # Object score logits default to 10.0, i.e. assume the object is present (sigmoid(10) is ~1)
+ object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
+ return masks, iou_pred, mask_tokens_out, object_score_logits
+ def _get_stability_scores(self, mask_logits):
+ """Computes mask stability scores based on IoU between upper and lower thresholds."""
+ mask_logits = mask_logits.flatten(-2)
+ stability_delta = self.dynamic_multimask_stability_delta
+ area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
+ area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
+ return torch.where(area_u > 0, area_i / area_u, 1.0)
+ def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
+ Dynamically selects the most stable mask output based on stability scores and IoU predictions.
+ This method is used when outputting a single mask. If the stability score from the current single-mask
+ output (based on output token 0) falls below a threshold, it instead selects from multi-mask outputs
+ (based on output tokens 1-3) the mask with the highest predicted IoU score. This ensures a valid mask
+ for both clicking and tracking scenarios.
+ all_mask_logits (torch.Tensor): Logits for all predicted masks, shape (B, N, H, W) where B is
+ batch size, N is number of masks (typically 4), and H, W are mask dimensions.
+ all_iou_scores (torch.Tensor): Predicted IoU scores for all masks, shape (B, N).
+ (Tuple[torch.Tensor, torch.Tensor]):
+ - mask_logits_out (torch.Tensor): Selected mask logits, shape (B, 1, H, W).
+ - iou_scores_out (torch.Tensor): Selected IoU scores, shape (B, 1).
+ >>> decoder = SAM2MaskDecoder(...)
+ >>> all_mask_logits = torch.rand(2, 4, 256, 256) # 2 images, 4 masks each
+ >>> all_iou_scores = torch.rand(2, 4)
+ >>> mask_logits, iou_scores = decoder._dynamic_multimask_via_stability(all_mask_logits, all_iou_scores)
+ >>> print(mask_logits.shape, iou_scores.shape)
+ torch.Size([2, 1, 256, 256]) torch.Size([2, 1])
+ # The best mask from multimask output tokens (1~3)
+ multimask_logits = all_mask_logits[:, 1:, :, :]
+ multimask_iou_scores = all_iou_scores[:, 1:]
+ best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
+ batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device)
+ best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
+ best_multimask_logits = best_multimask_logits.unsqueeze(1)
+ best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
+ best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
+ # The mask from singlemask output token 0 and its stability score
+ singlemask_logits = all_mask_logits[:, 0:1, :, :]
+ singlemask_iou_scores = all_iou_scores[:, 0:1]
+ stability_scores = self._get_stability_scores(singlemask_logits)
+ is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
+ # Dynamically fall back to best multimask output upon low stability scores.
+ mask_logits_out = torch.where(
+ is_stable[..., None, None].expand_as(singlemask_logits),
+ singlemask_logits,
+ best_multimask_logits,
+ iou_scores_out = torch.where(
+ is_stable.expand_as(singlemask_iou_scores),
+ singlemask_iou_scores,
+ best_multimask_iou_scores,
+ return mask_logits_out, iou_scores_out
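A minimal sketch of the stability score used by `_get_stability_scores` and `_dynamic_multimask_via_stability`, with illustrative shapes: the score is the ratio between the mask area at a tight logit threshold (+delta) and at a loose one (-delta), so values near 1.0 indicate a threshold-insensitive, stable mask.

```python
import torch

mask_logits = torch.randn(2, 1, 64, 64)          # (B, 1, H, W) single-mask logits
delta = 0.05

flat = mask_logits.flatten(-2)                   # flatten the spatial dimensions
area_i = torch.sum(flat > delta, dim=-1).float()   # area surviving the tight threshold
area_u = torch.sum(flat > -delta, dim=-1).float()  # area surviving the loose threshold
stability = torch.where(area_u > 0, area_i / area_u, torch.ones_like(area_u))
print(stability.shape)                           # torch.Size([2, 1]); values near 1.0 are stable
```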
@@ -0,0 +1,794 @@
+import torch.nn as nn
+from ultralytics.nn.modules import LayerNorm2d
+from .blocks import (
+ Block,
+ CXBlock,
+ Fuser,
+ MaskDownSampler,
+ MultiScaleBlock,
+ PatchEmbed,
+ PositionEmbeddingRandom,
+ PositionEmbeddingSine,
+)
+class ImageEncoderViT(nn.Module):
+ An image encoder using Vision Transformer (ViT) architecture for encoding images into a compact latent space.
+ This class processes images by splitting them into patches, applying transformer blocks, and generating a final
+ encoded representation through a neck module.
+ img_size (int): Dimension of input images, assumed to be square.
+ patch_embed (PatchEmbed): Module for patch embedding.
+ pos_embed (nn.Parameter | None): Absolute positional embedding for patches.
+ blocks (nn.ModuleList): List of transformer blocks for processing patch embeddings.
+ neck (nn.Sequential): Neck module to further process the output.
+ forward: Processes input through patch embedding, positional embedding, blocks, and neck.
+ >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
+ >>> input_image = torch.randn(1, 3, 224, 224)
+ >>> output = encoder(input_image)
+ img_size: int = 1024,
+ patch_size: int = 16,
+ depth: int = 12,
+ num_heads: int = 12,
+ out_chans: int = 256,
+ use_abs_pos: bool = True,
+ global_attn_indexes: Tuple[int, ...] = (),
+ Initializes an ImageEncoderViT instance for encoding images using Vision Transformer architecture.
+ img_size (int): Input image size, assumed to be square.
+ patch_size (int): Size of image patches.
+ embed_dim (int): Dimension of patch embeddings.
+ depth (int): Number of transformer blocks.
+ num_heads (int): Number of attention heads in each block.
+ mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
+ out_chans (int): Number of output channels from the neck module.
+ qkv_bias (bool): If True, adds learnable bias to query, key, value projections.
+ act_layer (Type[nn.Module]): Type of activation layer to use.
+ use_abs_pos (bool): If True, uses absolute positional embeddings.
+ use_rel_pos (bool): If True, adds relative positional embeddings to attention maps.
+ window_size (int): Size of attention window for windowed attention blocks.
+ global_attn_indexes (Tuple[int, ...]): Indices of blocks that use global attention.
+ img_size (int): Dimension of input images.
+ blocks (nn.ModuleList): List of transformer blocks.
+ neck (nn.Sequential): Neck module for final processing.
+ self.img_size = img_size
+ self.patch_embed = PatchEmbed(
+ kernel_size=(patch_size, patch_size),
+ stride=(patch_size, patch_size),
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ self.pos_embed: Optional[nn.Parameter] = None
+ if use_abs_pos:
+ # Initialize absolute positional embedding with pretrain image size.
+ self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim))
+ self.blocks = nn.ModuleList()
+ block = Block(
+ dim=embed_dim,
+ mlp_ratio=mlp_ratio,
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ window_size=window_size if i not in global_attn_indexes else 0,
+ input_size=(img_size // patch_size, img_size // patch_size),
+ self.blocks.append(block)
+ self.neck = nn.Sequential(
+ embed_dim,
+ out_chans,
+ kernel_size=1,
+ bias=False,
+ LayerNorm2d(out_chans),
+ kernel_size=3,
+ padding=1,
+ """Processes input through patch embedding, positional embedding, transformer blocks, and neck module."""
+ x = self.patch_embed(x)
+ if self.pos_embed is not None:
+ pos_embed = (
+ F.interpolate(self.pos_embed.permute(0, 3, 1, 2), scale_factor=self.img_size / 1024).permute(0, 2, 3, 1)
+ if self.img_size != 1024
+ else self.pos_embed
+ x = x + pos_embed
+ for blk in self.blocks:
+ x = blk(x)
+ return self.neck(x.permute(0, 3, 1, 2))
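A minimal sketch of the positional-embedding rescaling in the forward pass above, with illustrative sizes: an absolute embedding learned on a 64x64 patch grid (1024-px input, stride 16) is resized with `F.interpolate` when the encoder runs at a different `img_size`.

```python
import torch
import torch.nn.functional as F

pos_embed = torch.randn(1, 64, 64, 768)          # learned for a 1024-px input (1024 / 16 = 64)
scale = 512 / 1024                               # running the encoder at img_size=512 instead
resized = F.interpolate(pos_embed.permute(0, 3, 1, 2), scale_factor=scale).permute(0, 2, 3, 1)
print(resized.shape)                             # torch.Size([1, 32, 32, 768])
```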
+class PromptEncoder(nn.Module):
+ Encodes different types of prompts for input to SAM's mask decoder, producing sparse and dense embeddings.
+ embed_dim (int): Dimension of the embeddings.
+ input_image_size (Tuple[int, int]): Size of the input image as (H, W).
+ image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W).
+ pe_layer (PositionEmbeddingRandom): Module for random position embedding.
+ num_point_embeddings (int): Number of point embeddings for different types of points.
+ point_embeddings (nn.ModuleList): List of point embeddings.
+ not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label.
+ mask_input_size (Tuple[int, int]): Size of the input mask.
+ mask_downscaling (nn.Sequential): Neural network for downscaling the mask.
+ no_mask_embed (nn.Embedding): Embedding for cases where no mask is provided.
+ get_dense_pe: Returns the positional encoding used to encode point prompts.
+ forward: Embeds different types of prompts, returning both sparse and dense embeddings.
+ >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
+ >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
+ >>> boxes = torch.rand(1, 2, 2)
+ >>> masks = torch.rand(1, 1, 256, 256)
+ >>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)
+ >>> print(sparse_embeddings.shape, dense_embeddings.shape)
+ torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
+ embed_dim: int,
+ image_embedding_size: Tuple[int, int],
+ input_image_size: Tuple[int, int],
+ mask_in_chans: int,
+ Initializes the PromptEncoder module for encoding various types of prompts.
+ This module encodes different types of prompts (points, boxes, masks) for input to SAM's mask decoder,
+ producing both sparse and dense embeddings.
+ embed_dim (int): The dimension of the embeddings.
+ image_embedding_size (Tuple[int, int]): The spatial size of the image embedding as (H, W).
+ input_image_size (Tuple[int, int]): The padded size of the input image as (H, W).
+ mask_in_chans (int): The number of hidden channels used for encoding input masks.
+ activation (Type[nn.Module]): The activation function to use when encoding input masks.
+ self.embed_dim = embed_dim
+ self.input_image_size = input_image_size
+ self.image_embedding_size = image_embedding_size
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
+ self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
+ point_embeddings = [nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)]
+ self.point_embeddings = nn.ModuleList(point_embeddings)
+ self.not_a_point_embed = nn.Embedding(1, embed_dim)
+ self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
+ self.mask_downscaling = nn.Sequential(
+ nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
+ LayerNorm2d(mask_in_chans // 4),
+ nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
+ LayerNorm2d(mask_in_chans),
+ nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
+ self.no_mask_embed = nn.Embedding(1, embed_dim)
+ def get_dense_pe(self) -> torch.Tensor:
+ Returns the dense positional encoding used for encoding point prompts.
+ This method generates a positional encoding for a dense set of points matching the shape of the image
+ encoding. The encoding is used to provide spatial information to the model when processing point prompts.
+ (torch.Tensor): Positional encoding tensor with shape (1, embed_dim, H, W), where H and W are the
+ height and width of the image embedding size, respectively.
+ >>> dense_pe = prompt_encoder.get_dense_pe()
+ >>> print(dense_pe.shape)
+ torch.Size([1, 256, 64, 64])
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
+ def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
+ """Embeds point prompts by applying positional encoding and label-specific embeddings."""
+ points = points + 0.5 # Shift to center of pixel
+ if pad:
+ padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
+ padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
+ points = torch.cat([points, padding_point], dim=1)
+ labels = torch.cat([labels, padding_label], dim=1)
+ point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
+ point_embedding[labels == -1] = 0.0
+ point_embedding[labels == -1] += self.not_a_point_embed.weight
+ point_embedding[labels == 0] += self.point_embeddings[0].weight
+ point_embedding[labels == 1] += self.point_embeddings[1].weight
+ point_embedding[labels == 2] += self.point_embeddings[2].weight
+ point_embedding[labels == 3] += self.point_embeddings[3].weight
+ return point_embedding
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
+ """Embeds box prompts by applying positional encoding and adding corner embeddings."""
+ boxes = boxes + 0.5 # Shift to center of pixel
+ coords = boxes.reshape(-1, 2, 2)
+ corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
+ corner_embedding[:, 0, :] += self.point_embeddings[2].weight
+ corner_embedding[:, 1, :] += self.point_embeddings[3].weight
+ return corner_embedding
+ def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
+ """Embeds mask inputs by downscaling and processing through convolutional layers."""
+ return self.mask_downscaling(masks)
+ @staticmethod
+ def _get_batch_size(
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+ boxes: Optional[torch.Tensor],
+ masks: Optional[torch.Tensor],
+ ) -> int:
+ """Gets the batch size of the output given the batch size of the input prompts."""
+ if points is not None:
+ return points[0].shape[0]
+ elif boxes is not None:
+ return boxes.shape[0]
+ elif masks is not None:
+ return masks.shape[0]
+ return 1
+ def _get_device(self) -> torch.device:
+ """Returns the device of the first point embedding's weight tensor."""
+ return self.point_embeddings[0].weight.device
+ Embeds different types of prompts, returning both sparse and dense embeddings.
+ points (Tuple[torch.Tensor, torch.Tensor] | None): Point coordinates and labels to embed. The first
+ tensor contains coordinates with shape (B, N, 2), and the second tensor contains labels with
+ shape (B, N).
+ boxes (torch.Tensor | None): Boxes to embed with shape (B, M, 2, 2), where M is the number of boxes.
+ masks (torch.Tensor | None): Masks to embed with shape (B, 1, H, W).
+ - sparse_embeddings (torch.Tensor): Sparse embeddings for points and boxes with shape (B, N, embed_dim).
+ - dense_embeddings (torch.Tensor): Dense embeddings for masks of shape (B, embed_dim, embed_H, embed_W).
+ >>> encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
+ >>> boxes = torch.rand(1, 2, 2, 2)
+ >>> sparse_emb, dense_emb = encoder(points, boxes, masks)
+ >>> print(sparse_emb.shape, dense_emb.shape)
+ bs = self._get_batch_size(points, boxes, masks)
+ sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
+ if points is not None:
+ coords, labels = points
+ point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
+ sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
+ if boxes is not None:
+ box_embeddings = self._embed_boxes(boxes)
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
+ if masks is not None:
+ dense_embeddings = self._embed_masks(masks)
+ else:
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+ bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+ return sparse_embeddings, dense_embeddings
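A minimal bookkeeping sketch for the sparse and dense embeddings above, under the token-count rule implied by the listing: one token per point (plus one padding token when no boxes are given), two corner tokens per box, and a single learned embedding broadcast over the spatial grid when no mask is provided. The counts here are illustrative.

```python
import torch

embed_dim, n_points, n_boxes = 256, 5, 1
pad = 1 if n_boxes == 0 else 0                   # a padding point is added only when no boxes are given
sparse_tokens = n_points + pad + 2 * n_boxes     # two corner tokens per box
print(sparse_tokens)                             # 7, matching the example shape (1, 7, 256) above

no_mask_embed = torch.randn(1, embed_dim)        # stand-in for nn.Embedding(1, embed_dim).weight
dense = no_mask_embed.reshape(1, -1, 1, 1).expand(1, -1, 64, 64)
print(dense.shape)                               # torch.Size([1, 256, 64, 64])
```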
+class MemoryEncoder(nn.Module):
+ Encodes pixel features and masks into a memory representation for efficient image segmentation.
+ This class processes pixel-level features and masks, fusing them to generate encoded memory representations
+ suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).
+ mask_downsampler (MaskDownSampler): Module for downsampling input masks.
+ pix_feat_proj (nn.Conv2d): Convolutional layer for projecting pixel features.
+ fuser (Fuser): Module for fusing pixel features and masks.
+ position_encoding (PositionEmbeddingSine): Module for adding positional encoding to features.
+ out_proj (nn.Module): Output projection layer, either nn.Identity or nn.Conv2d.
+ forward: Processes input pixel features and masks to generate encoded memory representations.
+ >>> encoder = MemoryEncoder(out_dim=256, in_dim=256)
+ >>> pix_feat = torch.randn(1, 256, 64, 64)
+ >>> masks = torch.randn(1, 1, 64, 64)
+ >>> encoded_feat, pos = encoder(pix_feat, masks)
+ >>> print(encoded_feat.shape, pos.shape)
+ torch.Size([1, 256, 64, 64]) torch.Size([1, 128, 64, 64])
+ def __init__(
+ self,
+ out_dim,
+ in_dim=256, # in_dim of pix_feats
+ ):
+ """Initializes the MemoryEncoder for encoding pixel features and masks into memory representations."""
+ super().__init__()
+ self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)
+ self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
+ self.fuser = Fuser(CXBlock(dim=256), num_layers=2)
+ self.position_encoding = PositionEmbeddingSine(num_pos_feats=64)
+ self.out_proj = nn.Identity()
+ if out_dim != in_dim:
+ self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
+ def forward(
+ self,
+ pix_feat: torch.Tensor,
+ masks: torch.Tensor,
+ skip_mask_sigmoid: bool = False,
+ ):
+ """Processes pixel features and masks to generate encoded memory representations for segmentation."""
+ if not skip_mask_sigmoid:
+ masks = F.sigmoid(masks)
+ masks = self.mask_downsampler(masks)
+ # Fuse pix_feats and downsampled masks, in case the visual features are on CPU, cast them to CUDA
+ pix_feat = pix_feat.to(masks.device)
+ x = self.pix_feat_proj(pix_feat)
+ x = x + masks
+ x = self.fuser(x)
+ x = self.out_proj(x)
+ pos = self.position_encoding(x).to(x.dtype)
+ return {"vision_features": x, "vision_pos_enc": [pos]}
+class ImageEncoder(nn.Module):
+ Encodes images using a trunk-neck architecture, producing multiscale features and positional encodings.
+ This class combines a trunk network for feature extraction with a neck network for feature refinement
+ and positional encoding generation. It can optionally discard the lowest resolution features.
+ trunk (nn.Module): The trunk network for initial feature extraction.
+ neck (nn.Module): The neck network for feature refinement and positional encoding generation.
+ scalp (int): Number of lowest resolution feature levels to discard.
+ forward: Processes the input image through the trunk and neck networks.
+ >>> trunk = SomeTrunkNetwork()
+ >>> neck = SomeNeckNetwork()
+ >>> encoder = ImageEncoder(trunk, neck, scalp=1)
+ >>> image = torch.randn(1, 3, 224, 224)
+ >>> output = encoder(image)
+ >>> print(output.keys())
+ dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn'])
+ def __init__(
+ self,
+ trunk: nn.Module,
+ neck: nn.Module,
+ scalp: int = 0,
+ ):
+ """Initializes the ImageEncoder with trunk and neck networks for feature extraction and refinement."""
+ super().__init__()
+ self.trunk = trunk
+ self.neck = neck
+ self.scalp = scalp
+ assert self.trunk.channel_list == self.neck.backbone_channel_list, (
+ f"Channel dims of trunk {self.trunk.channel_list} and neck {self.neck.backbone_channel_list} do not match."
+ )
+ def forward(self, sample: torch.Tensor):
+ """Encodes input through patch embedding, positional embedding, transformer blocks, and neck module."""
+ features, pos = self.neck(self.trunk(sample))
+ if self.scalp > 0:
+ # Discard the lowest resolution features
+ features, pos = features[: -self.scalp], pos[: -self.scalp]
+ src = features[-1]
+ return {
+ "vision_features": src,
+ "vision_pos_enc": pos,
+ "backbone_fpn": features,
+ }
+class FpnNeck(nn.Module):
+ A Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models.
+ This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
+ similar to ViT positional embedding interpolation.
+ position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding module.
+ convs (nn.ModuleList): List of convolutional layers for each backbone level.
+ backbone_channel_list (List[int]): List of channel dimensions from the backbone.
+ fpn_interp_model (str): Interpolation mode for FPN feature resizing.
+ fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
+ fpn_top_down_levels (List[int]): Levels to have top-down features in outputs.
+ forward: Performs forward pass through the FPN neck.
+ >>> backbone_channels = [64, 128, 256, 512]
+ >>> fpn_neck = FpnNeck(256, backbone_channels)
+ >>> inputs = [torch.rand(1, c, 32, 32) for c in backbone_channels]
+ >>> outputs, positions = fpn_neck(inputs)
+ >>> print(len(outputs), len(positions))
+ 4 4
+ def __init__(
+ self,
+ d_model: int,
+ backbone_channel_list: List[int],
+ kernel_size: int = 1,
+ stride: int = 1,
+ padding: int = 0,
+ fpn_interp_model: str = "bilinear",
+ fuse_type: str = "sum",
+ fpn_top_down_levels: Optional[List[int]] = None,
+ ):
+ Initializes a modified Feature Pyramid Network (FPN) neck.
+ d_model (int): Dimension of the model.
+ kernel_size (int): Kernel size for the convolutional layers.
+ stride (int): Stride for the convolutional layers.
+ padding (int): Padding for the convolutional layers.
+ fpn_top_down_levels (Optional[List[int]]): Levels to have top-down features in outputs.
+ >>> print(fpn_neck)
+ super().__init__()
+ self.position_encoding = PositionEmbeddingSine(num_pos_feats=256)
+ self.convs = nn.ModuleList()
+ self.backbone_channel_list = backbone_channel_list
+ for dim in backbone_channel_list:
+ current = nn.Sequential()
+ current.add_module(
+ "conv",
+ nn.Conv2d(
+ in_channels=dim,
+ out_channels=d_model,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ ),
+ )
+ self.convs.append(current)
+ self.fpn_interp_model = fpn_interp_model
+ assert fuse_type in {"sum", "avg"}
+ self.fuse_type = fuse_type
+ # levels to have top-down features in its outputs
+ # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
+ # have top-down propagation, while outputs of level 0 and level 1 have only
+ # lateral features from the same backbone level.
+ if fpn_top_down_levels is None:
+ # default is to have top-down features on all levels
+ fpn_top_down_levels = range(len(self.convs))
+ self.fpn_top_down_levels = list(fpn_top_down_levels)
+ def forward(self, xs: List[torch.Tensor]):
+ Performs forward pass through the Feature Pyramid Network (FPN) neck.
+ This method processes a list of input tensors from the backbone through the FPN, applying lateral connections
+ and top-down feature fusion. It generates output feature maps and corresponding positional encodings.
+ xs (List[torch.Tensor]): List of input tensors from the backbone, each with shape (B, C, H, W).
+ (Tuple[List[torch.Tensor], List[torch.Tensor]]): A tuple containing:
+ - out (List[torch.Tensor]): List of output feature maps after FPN processing, each with shape
+ (B, d_model, H, W).
+ - pos (List[torch.Tensor]): List of positional encodings corresponding to each output feature map.
+ >>> fpn_neck = FpnNeck(d_model=256, backbone_channel_list=[64, 128, 256, 512])
+ >>> inputs = [torch.rand(1, c, 32, 32) for c in [64, 128, 256, 512]]
+ out = [None] * len(self.convs)
+ pos = [None] * len(self.convs)
+ assert len(xs) == len(self.convs)
+ # fpn forward pass
+ # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
+ prev_features = None
+ # forward in top-down order (from low to high resolution)
+ n = len(self.convs) - 1
+ for i in range(n, -1, -1):
+ x = xs[i]
+ lateral_features = self.convs[n - i](x)
+ if i in self.fpn_top_down_levels and prev_features is not None:
+ top_down_features = F.interpolate(
+ prev_features.to(dtype=torch.float32),
+ scale_factor=2.0,
+ mode=self.fpn_interp_model,
+ align_corners=(None if self.fpn_interp_model == "nearest" else False),
+ antialias=False,
+ )
+ prev_features = lateral_features + top_down_features
+ if self.fuse_type == "avg":
+ prev_features /= 2
+ else:
+ prev_features = lateral_features
+ x_out = prev_features
+ out[i] = x_out
+ pos[i] = self.position_encoding(x_out).to(x_out.dtype)
+ return out, pos
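+# Hedged summary of the fusion rule above, written out for one level i (processed in top-down order):
+#   lateral = convs[n - i](xs[i])
+#   out[i] = lateral + interpolate(out[i + 1], scale_factor=2)   # only when i is in fpn_top_down_levels
+#   out[i] = out[i] / 2                                          # only when fuse_type == "avg"
+# Levels outside fpn_top_down_levels keep only their lateral features.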
+class Hiera(nn.Module):
+ Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks.
+ This class implements a Hiera model, which is a hierarchical vision transformer architecture designed for
+ efficient multiscale feature extraction. It uses a series of transformer blocks organized into stages,
+ with optional pooling and global attention mechanisms.
+ window_spec (Tuple[int, ...]): Window sizes for each stage.
+ q_stride (Tuple[int, int]): Downsampling stride between stages.
+ stage_ends (List[int]): Indices of the last block in each stage.
+ q_pool_blocks (List[int]): Indices of blocks where pooling is applied.
+ return_interm_layers (bool): Whether to return intermediate layer outputs.
+ global_att_blocks (Tuple[int, ...]): Indices of blocks with global attention.
+ window_pos_embed_bkg_spatial_size (Tuple[int, int]): Spatial size for window positional embedding background.
+ pos_embed (nn.Parameter): Positional embedding for the background.
+ pos_embed_window (nn.Parameter): Positional embedding for the window.
+ blocks (nn.ModuleList): List of MultiScaleBlock modules.
+ channel_list (List[int]): List of output channel dimensions for each stage.
+ _get_pos_embed: Generates positional embeddings by interpolating and combining window and background embeddings.
+ forward: Performs the forward pass through the Hiera model.
+ >>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
+ >>> input_tensor = torch.randn(1, 3, 224, 224)
+ >>> output_features = model(input_tensor)
+ >>> for feat in output_features:
+ ... print(feat.shape)
+ def __init__(
+ self,
+ embed_dim: int = 96, # initial embed dim
+ num_heads: int = 1, # initial number of heads
+ drop_path_rate: float = 0.0, # stochastic depth
+ q_pool: int = 3, # number of q_pool stages
+ q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
+ stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
+ dim_mul: float = 2.0, # dim_mul factor at stage shift
+ head_mul: float = 2.0, # head_mul factor at stage shift
+ window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
+ # window size per stage, when not using global att.
+ window_spec: Tuple[int, ...] = (
+ 8,
+ 4,
+ 14,
+ 7,
+ ),
+ # global attn in these blocks
+ global_att_blocks: Tuple[int, ...] = (
+ 12,
+ 16,
+ 20,
+ ),
+ return_interm_layers=True, # return feats from every stage
+ ):
+ """Initializes the Hiera model, configuring its hierarchical vision transformer architecture."""
+ super().__init__()
+ assert len(stages) == len(window_spec)
+ self.window_spec = window_spec
+ depth = sum(stages)
+ self.q_stride = q_stride
+ self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
+ assert 0 <= q_pool <= len(self.stage_ends[:-1])
+ self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
+ self.return_interm_layers = return_interm_layers
+ self.patch_embed = PatchEmbed(
+ embed_dim=embed_dim,
+ kernel_size=(7, 7),
+ stride=(4, 4),
+ padding=(3, 3),
+ )
+ # Which blocks have global att?
+ self.global_att_blocks = global_att_blocks
+ # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
+ self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
+ self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size))
+ self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]))
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ cur_stage = 1
+ self.blocks = nn.ModuleList()
+ for i in range(depth):
+ dim_out = embed_dim
+ # lags by a block, so first block of
+ # next stage uses an initial window size
+ # of previous stage and final window size of current stage
+ window_size = self.window_spec[cur_stage - 1]
+ if self.global_att_blocks is not None:
+ window_size = 0 if i in self.global_att_blocks else window_size
+ if i - 1 in self.stage_ends:
+ dim_out = int(embed_dim * dim_mul)
+ num_heads = int(num_heads * head_mul)
+ cur_stage += 1
+ block = MultiScaleBlock(
+ dim=embed_dim,
+ dim_out=dim_out,
+ num_heads=num_heads,
+ drop_path=dpr[i],
+ q_stride=self.q_stride if i in self.q_pool_blocks else None,
+ window_size=window_size,
+ )
+ embed_dim = dim_out
+ self.blocks.append(block)
+ self.channel_list = (
+ [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
+ if return_interm_layers
+ else [self.blocks[-1].dim_out]
+ )
+ def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
+ """Generates positional embeddings by interpolating and combining window and background embeddings."""
+ h, w = hw
+ window_embed = self.pos_embed_window
+ pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
+ pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])
+ pos_embed = pos_embed.permute(0, 2, 3, 1)
+ return pos_embed
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+ """Performs forward pass through Hiera model, extracting multiscale features from input images."""
+ x = self.patch_embed(x)
+ # x: (B, H, W, C)
+ # Add pos embed
+ x = x + self._get_pos_embed(x.shape[1:3])
+ outputs = []
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers):
+ feats = x.permute(0, 3, 1, 2)
+ outputs.append(feats)
+ return outputs
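+# Hedged shape note (following the class docstring example): with embed_dim=96, stages=(2, 3, 16, 3) and
+# return_interm_layers=True, a (1, 3, 224, 224) input yields four (B, C, H, W) feature maps with channels
+# 96, 192, 384, 768 and spatial sizes 56, 28, 14, 7 (patch stride 4, then q_stride=(2, 2) at each stage shift).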
@@ -0,0 +1,237 @@
+import copy
+from typing import Optional
+import torch
+from torch import Tensor, nn
+from .blocks import RoPEAttention
+class MemoryAttentionLayer(nn.Module):
+ Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.
+ This class combines self-attention, cross-attention, and feedforward components to process input tensors and
+ generate memory-based attention outputs.
+ d_model (int): Dimensionality of the model.
+ dim_feedforward (int): Dimensionality of the feedforward network.
+ dropout_value (float): Dropout rate for regularization.
+ self_attn (RoPEAttention): Self-attention mechanism using RoPE (Rotary Position Embedding).
+ cross_attn_image (RoPEAttention): Cross-attention mechanism for image processing.
+ linear1 (nn.Linear): First linear layer of the feedforward network.
+ linear2 (nn.Linear): Second linear layer of the feedforward network.
+ norm1 (nn.LayerNorm): Layer normalization for self-attention output.
+ norm2 (nn.LayerNorm): Layer normalization for cross-attention output.
+ norm3 (nn.LayerNorm): Layer normalization for feedforward network output.
+ dropout1 (nn.Dropout): Dropout layer after self-attention.
+ dropout2 (nn.Dropout): Dropout layer after cross-attention.
+ dropout3 (nn.Dropout): Dropout layer after feedforward network.
+ activation (nn.ReLU): Activation function for the feedforward network.
+ pos_enc_at_attn (bool): Flag to add positional encoding at attention.
+ pos_enc_at_cross_attn_queries (bool): Flag to add positional encoding to cross-attention queries.
+ pos_enc_at_cross_attn_keys (bool): Flag to add positional encoding to cross-attention keys.
+ forward: Performs the full memory attention operation on input tensors.
+ _forward_sa: Performs self-attention on input tensor.
+ _forward_ca: Performs cross-attention between target and memory tensors.
+ >>> layer = MemoryAttentionLayer(d_model=256, dim_feedforward=2048, dropout=0.1)
+ >>> tgt = torch.randn(1, 100, 256)
+ >>> memory = torch.randn(1, 100, 64)
+ >>> pos = torch.randn(1, 100, 256)
+ >>> query_pos = torch.randn(1, 100, 256)
+ >>> output = layer(tgt, memory, pos, query_pos)
+ torch.Size([1, 100, 256])
+ def __init__(
+ self,
+ d_model: int = 256,
+ dim_feedforward: int = 2048,
+ dropout: float = 0.1,
+ pos_enc_at_attn: bool = False,
+ pos_enc_at_cross_attn_keys: bool = True,
+ pos_enc_at_cross_attn_queries: bool = False,
+ ):
+ """Initializes a memory attention layer with self-attention, cross-attention, and feedforward components."""
+ super().__init__()
+ self.d_model = d_model
+ self.dim_feedforward = dim_feedforward
+ self.dropout_value = dropout
+ self.self_attn = RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1)
+ self.cross_attn_image = RoPEAttention(
+ rope_k_repeat=True,
+ embedding_dim=256,
+ num_heads=1,
+ downsample_rate=1,
+ kv_in_dim=64,
+ )
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.norm3 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+ self.activation = nn.ReLU()
+ # Where to add pos enc
+ self.pos_enc_at_attn = pos_enc_at_attn
+ self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
+ self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
+ def _forward_sa(self, tgt, query_pos):
+ """Performs self-attention on input tensor using positional encoding and RoPE attention mechanism."""
+ tgt2 = self.norm1(tgt)
+ q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
+ tgt2 = self.self_attn(q, k, v=tgt2)
+ tgt = tgt + self.dropout1(tgt2)
+ return tgt
+ def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
+ """Performs cross-attention between target and memory tensors using RoPEAttention mechanism."""
+ kwds = {}
+ if num_k_exclude_rope > 0:
+ assert isinstance(self.cross_attn_image, RoPEAttention)
+ kwds = {"num_k_exclude_rope": num_k_exclude_rope}
+ # Cross-Attention
+ tgt2 = self.norm2(tgt)
+ tgt2 = self.cross_attn_image(
+ q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
+ k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
+ v=memory,
+ **kwds,
+ )
+ tgt = tgt + self.dropout2(tgt2)
+ return tgt
+ def forward(
+ self,
+ tgt,
+ memory,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ num_k_exclude_rope: int = 0,
+ ) -> torch.Tensor:
+ """Processes input tensors using self-attention, cross-attention, and MLP for memory-based attention."""
+ tgt = self._forward_sa(tgt, query_pos)
+ tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
+ tgt2 = self.norm3(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout3(tgt2)
+ return tgt
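+# Hedged summary of the pre-norm residual structure implemented above (with the default pos_enc_at_* flags):
+#   tgt = tgt + dropout1(self_attn(q=norm1(tgt), k=norm1(tgt), v=norm1(tgt)))
+#   tgt = tgt + dropout2(cross_attn_image(q=norm2(tgt), k=memory + pos, v=memory))
+#   tgt = tgt + dropout3(linear2(dropout(relu(linear1(norm3(tgt))))))
+# Positional terms are only added where the corresponding pos_enc_at_* flag enables them.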
+class MemoryAttention(nn.Module):
+ Memory attention module for processing sequential data with self and cross-attention mechanisms.
+ This class implements a multi-layer attention mechanism that combines self-attention and cross-attention
+ for processing sequential data, particularly useful in transformer-like architectures.
+ d_model (int): The dimension of the model's hidden state.
+ layers (nn.ModuleList): A list of MemoryAttentionLayer modules.
+ num_layers (int): The number of attention layers.
+ norm (nn.LayerNorm): Layer normalization applied to the output.
+ pos_enc_at_input (bool): Whether to apply positional encoding at the input.
+ batch_first (bool): Whether the input tensors are in batch-first format.
+ forward: Processes input tensors through the attention layers.
+ >>> d_model = 256
+ >>> layer = MemoryAttentionLayer(d_model)
+ >>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
+ >>> curr = torch.randn(10, 32, d_model) # (seq_len, batch_size, d_model)
+ >>> memory = torch.randn(20, 32, d_model) # (mem_len, batch_size, d_model)
+ >>> curr_pos = torch.randn(10, 32, d_model)
+ >>> memory_pos = torch.randn(20, 32, d_model)
+ >>> output = attention(curr, memory, curr_pos, memory_pos)
+ torch.Size([10, 32, 256])
+ def __init__(
+ self,
+ d_model: int,
+ pos_enc_at_input: bool,
+ layer: nn.Module,
+ num_layers: int,
+ batch_first: bool = True, # Do layers expect batch first input?
+ ):
+ """Initializes MemoryAttention module with layers and normalization for attention processing."""
+ super().__init__()
+ self.d_model = d_model
+ self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
+ self.num_layers = num_layers
+ self.norm = nn.LayerNorm(d_model)
+ self.pos_enc_at_input = pos_enc_at_input
+ self.batch_first = batch_first
+ def forward(
+ self,
+ curr: torch.Tensor, # self-attention inputs
+ memory: torch.Tensor, # cross-attention inputs
+ curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs
+ memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
+ num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
+ ):
+ """Processes input tensors through multiple attention layers, applying self and cross-attention mechanisms."""
+ if isinstance(curr, list):
+ assert isinstance(curr_pos, list)
+ assert len(curr) == len(curr_pos) == 1
+ curr, curr_pos = (
+ curr[0],
+ curr_pos[0],
+ )
+ assert curr.shape[1] == memory.shape[1], "Batch size must be the same for curr and memory"
+ output = curr
+ if self.pos_enc_at_input and curr_pos is not None:
+ output = output + 0.1 * curr_pos
+ if self.batch_first:
+ # Convert to batch first
+ output = output.transpose(0, 1)
+ curr_pos = curr_pos.transpose(0, 1)
+ memory = memory.transpose(0, 1)
+ memory_pos = memory_pos.transpose(0, 1)
+ for layer in self.layers:
+ kwds = {}
+ if isinstance(layer.cross_attn_image, RoPEAttention):
+ kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
+ output = layer(
+ tgt=output,
+ memory=memory,
+ pos=memory_pos,
+ query_pos=curr_pos,
+ **kwds,
+ )
+ normed_output = self.norm(output)
+ if self.batch_first:
+ # Convert back to seq first
+ normed_output = normed_output.transpose(0, 1)
+ return normed_output
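+# Hedged note: with batch_first=True (the default above), inputs are expected as (seq_len, batch, dim), converted
+# to batch-first before the layer stack, and the normalized output is transposed back to (seq_len, batch, dim).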
@@ -0,0 +1,1013 @@
+from typing import List
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn.init import trunc_normal_
+from ultralytics.nn.modules import MLP
+from .blocks import SAM2TwoWayTransformer
+from .decoders import MaskDecoder, SAM2MaskDecoder
+from .encoders import ImageEncoderViT, PromptEncoder
+from .utils import get_1d_sine_pe, select_closest_cond_frames
+# a large negative value as a placeholder score for missing objects
+NO_OBJ_SCORE = -1024.0
+class SAMModel(nn.Module):
+ Segment Anything Model (SAM) for object segmentation tasks.
+ This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images
+ and input prompts.
+ mask_threshold (float): Threshold value for mask prediction.
+ image_encoder (ImageEncoderViT): Backbone for encoding images into embeddings.
+ prompt_encoder (PromptEncoder): Encoder for various types of input prompts.
+ mask_decoder (MaskDecoder): Predicts object masks from image and prompt embeddings.
+ __init__: Initializes the SAMModel with encoders, decoder, and normalization parameters.
+ >>> image_encoder = ImageEncoderViT(...)
+ >>> prompt_encoder = PromptEncoder(...)
+ >>> mask_decoder = MaskDecoder(...)
+ >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)
+ >>> # Further usage depends on SAMPredictor class
+ All forward() operations are implemented in the SAMPredictor class.
+ mask_threshold: float = 0.0
+ def __init__(
+ self,
+ image_encoder: ImageEncoderViT,
+ prompt_encoder: PromptEncoder,
+ mask_decoder: MaskDecoder,
+ pixel_mean: List[float] = (123.675, 116.28, 103.53),
+ pixel_std: List[float] = (58.395, 57.12, 57.375),
+ ) -> None:
+ Initialize the SAMModel class to predict object masks from an image and input prompts.
+ image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings.
+ prompt_encoder (PromptEncoder): Encodes various types of input prompts.
+ mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
+ pixel_mean (List[float]): Mean values for normalizing pixels in the input image.
+ pixel_std (List[float]): Std values for normalizing pixels in the input image.
+ All forward() operations moved to SAMPredictor.
+ super().__init__()
+ self.image_encoder = image_encoder
+ self.prompt_encoder = prompt_encoder
+ self.mask_decoder = mask_decoder
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
+ def set_imgsz(self, imgsz):
+ Set image size to make model compatible with different image sizes.
+ imgsz (Tuple[int, int]): The size of the input image.
+ if hasattr(self.image_encoder, "set_imgsz"):
+ self.image_encoder.set_imgsz(imgsz)
+ self.prompt_encoder.input_image_size = imgsz
+ self.prompt_encoder.image_embedding_size = [x // 16 for x in imgsz] # 16 is fixed as patch size of ViT model
+ self.image_encoder.img_size = imgsz[0]
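+# Hedged usage sketch: resizing an existing SAMModel to 512x512 inputs updates the image encoder (when it exposes
+# set_imgsz) and the prompt-encoder bookkeeping, e.g. sam_model.set_imgsz((512, 512)) leaves
+# prompt_encoder.image_embedding_size == [32, 32] under the fixed ViT patch size of 16 noted above.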
+class SAM2Model(torch.nn.Module):
+ SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.
+ This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms
+ for temporal consistency and efficient tracking of objects across frames.
+ image_encoder (ImageEncoderViT): Visual encoder for extracting image features.
+ memory_attention (nn.Module): Module for attending to memory features.
+ memory_encoder (nn.Module): Encoder for generating memory representations.
+ num_maskmem (int): Number of accessible memory frames.
+ image_size (int): Size of input images.
+ backbone_stride (int): Stride of the backbone network output.
+ sam_prompt_embed_dim (int): Dimension of SAM prompt embeddings.
+ sam_image_embedding_size (int): Size of SAM image embeddings.
+ sam_prompt_encoder (PromptEncoder): Encoder for processing input prompts.
+ sam_mask_decoder (SAM2MaskDecoder): Decoder for generating object masks.
+ obj_ptr_proj (nn.Module): Projection layer for object pointers.
+ obj_ptr_tpos_proj (nn.Module): Projection for temporal positional encoding in object pointers.
+ forward_image: Processes image batch through encoder to extract multi-level features.
+ track_step: Performs a single tracking step, updating object masks and memory features.
+ >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)
+ >>> image_batch = torch.rand(1, 3, 512, 512)
+ >>> features = model.forward_image(image_batch)
+ >>> track_results = model.track_step(0, True, features, None, None, None, {})
+ def __init__(
+ self,
+ image_encoder,
+ memory_attention,
+ memory_encoder,
+ image_size=512,
+ backbone_stride=16,
+ sigmoid_scale_for_mem_enc=1.0,
+ sigmoid_bias_for_mem_enc=0.0,
+ binarize_mask_from_pts_for_mem_enc=False,
+ use_mask_input_as_output_without_sam=False,
+ max_cond_frames_in_attn=-1,
+ directly_add_no_mem_embed=False,
+ use_high_res_features_in_sam=False,
+ multimask_output_in_sam=False,
+ multimask_min_pt_num=1,
+ multimask_output_for_tracking=False,
+ memory_temporal_stride_for_eval=1,
+ non_overlap_masks_for_mem_enc=False,
+ use_obj_ptrs_in_encoder=False,
+ max_obj_ptrs_in_encoder=16,
+ proj_tpos_enc_in_obj_ptrs=False,
+ use_signed_tpos_enc_to_obj_ptrs=False,
+ only_obj_ptrs_in_the_past_for_eval=False,
+ fixed_no_obj_ptr: bool = False,
+ soft_no_obj_ptr: bool = False,
+ use_mlp_for_obj_ptr_proj: bool = False,
+ no_obj_embed_spatial: bool = False,
+ sam_mask_decoder_extra_args=None,
+ compile_image_encoder: bool = False,
+ ):
+ Initializes the SAM2Model for video object segmentation with memory-based tracking.
+ image_encoder (nn.Module): Visual encoder for extracting image features.
+ num_maskmem (int): Number of accessible memory frames. Default is 7 (1 input frame + 6 previous frames).
+ backbone_stride (int): Stride of the image backbone output.
+ sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.
+ sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.
+ binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames
+ with clicks during evaluation.
+ use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM
+ prompt encoder and mask decoder on frames with mask input.
+ max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.
+ -1 means no limit.
+ directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the
+ first frame.
+ use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.
+ multimask_output_in_sam (bool): Whether to output multiple (3) masks for the first click on initial
+ conditioning frames.
+ multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.
+ multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.
+ multimask_output_for_tracking (bool): Whether to use multimask output for tracking.
+ use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
+ iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
+ memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
+ non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in
+ memory encoder during evaluation.
+ use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.
+ max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder
+ cross-attention.
+ add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in
+ the encoder.
+ proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
+ encoding in object pointers.
+ use_signed_tpos_enc_to_obj_ptrs (bool): whether to use signed distance (instead of unsigned absolute distance)
+ in the temporal positional encoding in the object pointers, only relevant when both `use_obj_ptrs_in_encoder=True`
+ and `add_tpos_enc_to_obj_ptrs=True`.
+ only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past
+ during evaluation.
+ pred_obj_scores (bool): Whether to predict if there is an object in the frame.
+ pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.
+ fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.
+ soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation.
+ use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.
+ no_obj_embed_spatial (bool): Whether add no obj embedding to spatial frames.
+ sam_mask_decoder_extra_args (Dict | None): Extra arguments for constructing the SAM mask decoder.
+ compile_image_encoder (bool): Whether to compile the image encoder for faster inference.
+ >>> memory_attention = SAM2TwoWayTransformer(...)
+ >>> memory_encoder = nn.Sequential(...)
+ super().__init__()
+ # Part 1: the image backbone
+ self.image_encoder = image_encoder
+ # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
+ self.use_high_res_features_in_sam = use_high_res_features_in_sam
+ self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
+ self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
+ self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
+ if use_obj_ptrs_in_encoder:
+ # A conv layer to downsample the mask prompt to stride 4 (the same stride as
+ # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
+ # so that it can be fed into the SAM mask decoder to generate a pointer.
+ self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
+ self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
+ if proj_tpos_enc_in_obj_ptrs:
+ assert add_tpos_enc_to_obj_ptrs # these options need to be used together
+ self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
+ self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs
+ self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
+ # Part 2: memory attention to condition current frame's visual features
+ # with memories (and obj ptrs) from past frames
+ self.memory_attention = memory_attention
+ self.hidden_dim = memory_attention.d_model
+ # Part 3: memory encoder for the previous frame's outputs
+ self.memory_encoder = memory_encoder
+ self.mem_dim = self.hidden_dim
+ if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"):
+ # if there is compression of memories along channel dim
+ self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
+ self.num_maskmem = num_maskmem # Number of memories accessible
+ # Temporal encoding of the memories
+ self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim))
+ trunc_normal_(self.maskmem_tpos_enc, std=0.02)
+ # a single token to indicate no memory embedding from previous frames
+ self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
+ self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
+ trunc_normal_(self.no_mem_embed, std=0.02)
+ trunc_normal_(self.no_mem_pos_enc, std=0.02)
+ self.directly_add_no_mem_embed = directly_add_no_mem_embed
+ # Apply sigmoid to the output raw mask logits (to turn them from
+ # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
+ self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
+ self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
+ self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
+ self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
+ self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
+ # On frames with mask input, whether to directly output the input mask without
+ # using a SAM prompt encoder + mask decoder
+ self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
+ self.multimask_output_in_sam = multimask_output_in_sam
+ self.multimask_min_pt_num = multimask_min_pt_num
+ self.multimask_max_pt_num = multimask_max_pt_num
+ self.multimask_output_for_tracking = multimask_output_for_tracking
+ self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid
+ # Part 4: SAM-style prompt encoder (for both mask and point inputs)
+ # and SAM-style mask decoder for the final mask output
+ self.image_size = image_size
+ self.backbone_stride = backbone_stride
+ self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
+ self.pred_obj_scores_mlp = pred_obj_scores_mlp
+ self.fixed_no_obj_ptr = fixed_no_obj_ptr
+ self.soft_no_obj_ptr = soft_no_obj_ptr
+ if self.fixed_no_obj_ptr:
+ assert self.pred_obj_scores
+ assert self.use_obj_ptrs_in_encoder
+ if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
+ self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
+ trunc_normal_(self.no_obj_ptr, std=0.02)
+ self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
+ self.no_obj_embed_spatial = None
+ if no_obj_embed_spatial:
+ self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
+ trunc_normal_(self.no_obj_embed_spatial, std=0.02)
+ self._build_sam_heads()
+ self.max_cond_frames_in_attn = max_cond_frames_in_attn
+ # Model compilation
+ if compile_image_encoder:
+ # Compile the forward function (not the full module) to allow loading checkpoints.
+ print("Image encoder compilation is enabled. First forward pass will be slow.")
+ self.image_encoder.forward = torch.compile(
+ self.image_encoder.forward,
+ mode="max-autotune",
+ fullgraph=True,
+ dynamic=False,
+ )
+ @property
+ def device(self):
+ """Returns the device on which the model's parameters are stored."""
+ return next(self.parameters()).device
+ def forward(self, *args, **kwargs):
+ """Processes image and prompt inputs to generate object masks and scores in video sequences."""
+ raise NotImplementedError(
+ "Please use the corresponding methods in SAM2VideoPredictor for inference."
+ "See notebooks/video_predictor_example.ipynb for an example."
+ def _build_sam_heads(self):
+ """Builds SAM-style prompt encoder and mask decoder for image segmentation tasks."""
+ self.sam_prompt_embed_dim = self.hidden_dim
+ self.sam_image_embedding_size = self.image_size // self.backbone_stride
+ # Build PromptEncoder and MaskDecoder from SAM (hyperparameters like `mask_in_chans=16` are from SAM code)
+ self.sam_prompt_encoder = PromptEncoder(
+ embed_dim=self.sam_prompt_embed_dim,
+ image_embedding_size=(
+ self.sam_image_embedding_size,
+ self.sam_image_embedding_size,
+ ),
+ input_image_size=(self.image_size, self.image_size),
+ mask_in_chans=16,
+ )
+ self.sam_mask_decoder = SAM2MaskDecoder(
+ transformer=SAM2TwoWayTransformer(
+ embedding_dim=self.sam_prompt_embed_dim,
+ ),
+ transformer_dim=self.sam_prompt_embed_dim,
+ use_high_res_features=self.use_high_res_features_in_sam,
+ iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
+ pred_obj_scores=self.pred_obj_scores,
+ pred_obj_scores_mlp=self.pred_obj_scores_mlp,
+ use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
+ **(self.sam_mask_decoder_extra_args or {}),
+ )
+ if self.use_obj_ptrs_in_encoder:
+ # a linear projection on SAM output tokens to turn them into object pointers
+ self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
+ if self.use_mlp_for_obj_ptr_proj:
+ self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3)
+ else:
+ self.obj_ptr_proj = torch.nn.Identity()
+ if self.proj_tpos_enc_in_obj_ptrs:
+ # a linear projection on temporal positional encoding in object pointers to
+ # avoid potential interference with spatial positional encoding
+ self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
+ else:
+ self.obj_ptr_tpos_proj = torch.nn.Identity()
+ def _forward_sam_heads(
+ self,
+ backbone_features,
+ point_inputs=None,
+ mask_inputs=None,
+ high_res_features=None,
+ multimask_output=False,
+ ):
+ Forward pass through SAM prompt encoders and mask heads.
+ This method processes image features and optional point/mask inputs to generate object masks and scores.
+ backbone_features (torch.Tensor): Image features with shape (B, C, H, W).
+ point_inputs (Dict[str, torch.Tensor] | None): Dictionary containing point prompts.
+ 'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute
+ pixel-unit coordinates in (x, y) format for P input points.
+ 'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks,
+ 0 means negative clicks, and -1 means padding.
+ mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the
+ same spatial size as the image.
+ high_res_features (List[torch.Tensor] | None): List of two feature maps with shapes
+ (B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps
+ for SAM decoder.
+ multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False,
+ output only 1 mask and its IoU estimate.
+ (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]):
+ low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
+ high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits.
+ ious: Tensor of shape (B, M) with estimated IoU for each output mask.
+ low_res_masks: Tensor of shape (B, 1, H*4, W*4) with the best low-resolution mask.
+ high_res_masks: Tensor of shape (B, 1, H*16, W*16) with the best high-resolution mask.
+ obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask.
+ object_score_logits: Tensor of shape (B) with object score logits.
+ Where M is 3 if multimask_output=True, and 1 if multimask_output=False.
+ >>> backbone_features = torch.rand(1, 256, 32, 32)
+ >>> point_inputs = {"point_coords": torch.rand(1, 2, 2), "point_labels": torch.tensor([[1, 0]])}
+ >>> mask_inputs = torch.rand(1, 1, 512, 512)
+ >>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs)
+ >>> (
+ ... low_res_multimasks,
+ ... high_res_multimasks,
+ ... ious,
+ ... low_res_masks,
+ ... high_res_masks,
+ ... obj_ptr,
+ ... object_score_logits,
+ ... ) = results
+ B = backbone_features.size(0)
+ device = backbone_features.device
+ assert backbone_features.size(1) == self.sam_prompt_embed_dim
+ assert backbone_features.size(2) == self.sam_image_embedding_size
+ assert backbone_features.size(3) == self.sam_image_embedding_size
+ # a) Handle point prompts
+ if point_inputs is not None:
+ sam_point_coords = point_inputs["point_coords"]
+ sam_point_labels = point_inputs["point_labels"]
+ assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
+ else:
+ # If no points are provided, pad with an empty point (with label -1)
+ sam_point_coords = torch.zeros(B, 1, 2, device=device)
+ sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
+ # b) Handle mask prompts
+ if mask_inputs is not None:
+ # If mask_inputs is provided, downsize it into low-res mask input if needed
+ # and feed it as a dense mask prompt into the SAM mask encoder
+ assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
+ if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
+ sam_mask_prompt = F.interpolate(
+ mask_inputs.float(),
+ size=self.sam_prompt_encoder.mask_input_size,
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ )
+ else:
+ sam_mask_prompt = mask_inputs
+ else:
+ # Otherwise, simply feed None (and SAM's prompt encoder will add
+ # a learned `no_mask_embed` to indicate no mask input in this case).
+ sam_mask_prompt = None
+ sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
+ points=(sam_point_coords, sam_point_labels),
+ boxes=None,
+ masks=sam_mask_prompt,
+ )
+ low_res_multimasks, ious, sam_output_tokens, object_score_logits = self.sam_mask_decoder(
+ image_embeddings=backbone_features,
+ image_pe=self.sam_prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ repeat_image=False, # the image is already batched
+ high_res_features=high_res_features,
+ )
+ if self.pred_obj_scores:
+ is_obj_appearing = object_score_logits > 0
+ # Spatial memory mask is a *hard* choice between obj and no obj, consistent with actual mask prediction
+ low_res_multimasks = torch.where(is_obj_appearing[:, None, None], low_res_multimasks, NO_OBJ_SCORE)
+ # convert masks from possibly bfloat16 (or float16) to float32
+ # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
+ low_res_multimasks = low_res_multimasks.float()
+ high_res_multimasks = F.interpolate(
+ low_res_multimasks,
+ size=(self.image_size, self.image_size),
+ mode="bilinear",
+ align_corners=False,
+ )
+ sam_output_token = sam_output_tokens[:, 0]
+ if multimask_output:
+ # take the best mask prediction (with the highest IoU estimation)
+ best_iou_inds = torch.argmax(ious, dim=-1)
+ batch_inds = torch.arange(B, device=device)
+ low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
+ high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
+ if sam_output_tokens.size(1) > 1:
+ sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
+ else:
+ low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
+ # Extract object pointer from the SAM output token (with occlusion handling)
+ obj_ptr = self.obj_ptr_proj(sam_output_token)
+ # Allow *soft* no obj ptr, unlike for masks
+ if self.soft_no_obj_ptr:
+ lambda_is_obj_appearing = object_score_logits.sigmoid()
+ else:
+ lambda_is_obj_appearing = is_obj_appearing.float()
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
+ return (
+ low_res_multimasks,
+ high_res_multimasks,
+ ious,
+ low_res_masks,
+ high_res_masks,
+ obj_ptr,
+ object_score_logits,
+ )
+ def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
+ """Processes mask inputs directly as output, bypassing SAM encoder/decoder."""
+ # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
+ out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
+ mask_inputs_float = mask_inputs.float()
+ high_res_masks = mask_inputs_float * out_scale + out_bias
+ low_res_masks = F.interpolate(
+ high_res_masks,
+ size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ )
+ # a dummy IoU prediction of all 1's under mask input
+ ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
+ if not self.use_obj_ptrs_in_encoder:
+ # all zeros as a dummy object pointer (of shape [B, C])
+ obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device)
+ else:
+ # produce an object pointer using the SAM decoder from the mask input
+ _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
+ backbone_features=backbone_features,
+ mask_inputs=self.mask_downsample(mask_inputs_float),
+ high_res_features=high_res_features,
+ )
+ # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
+ # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying
+ # on the object_scores from the SAM decoder.
+ is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
+ is_obj_appearing = is_obj_appearing[..., None]
+ lambda_is_obj_appearing = is_obj_appearing.float()
+ object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
+ return low_res_masks, high_res_masks, ious, low_res_masks, high_res_masks, obj_ptr, object_score_logits
+ def forward_image(self, img_batch: torch.Tensor):
+ """Processes image batch through encoder to extract multi-level features for SAM model."""
+ backbone_out = self.image_encoder(img_batch)
+ if self.use_high_res_features_in_sam:
+ # precompute projected level 0 and level 1 features in SAM decoder
+ # to avoid running it again on every SAM click
+ backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
+ backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
+ return backbone_out
+ def _prepare_backbone_features(self, backbone_out):
+ """Prepares and flattens visual features from the image backbone output for further processing."""
+ assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
+ assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
+ feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
+ vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
+ feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
+ # flatten NxCxHxW to HWxNxC
+ vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
+ vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
+ return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
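+ # Hedged shape note for the flattening above: a feature map of shape (N, C, H, W) becomes (H*W, N, C), while
+ # feat_sizes records the original (H, W) so later steps can restore BCHW via .permute(1, 2, 0).view(N, C, H, W).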
+ def _prepare_memory_conditioned_features(
+ self,
+ frame_idx,
+ is_init_cond_frame,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ output_dict,
+ num_frames,
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
+ ):
+ """Prepares memory-conditioned features by fusing current frame's visual features with previous memories."""
+ B = current_vision_feats[-1].size(1) # batch size on this frame
+ C = self.hidden_dim
+ H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
+ device = current_vision_feats[-1].device
+ # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
+ # In this case, we skip the fusion with any memory.
+ if self.num_maskmem == 0: # Disable memory and skip fusion
+ return current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
+ num_obj_ptr_tokens = 0
+ tpos_sign_mul = -1 if track_in_reverse else 1
+ # Step 1: condition the visual features of the current frame on previous memories
+ if not is_init_cond_frame:
+ # Retrieve the memories encoded with the maskmem backbone
+ to_cat_memory, to_cat_memory_pos_embed = [], []
+ # Add conditioning frame's output first (all cond frames have t_pos=0 for
+ # when getting temporal positional embedding below)
+ assert len(output_dict["cond_frame_outputs"]) > 0
+ # Select a maximum number of temporally closest cond frames for cross attention
+ cond_outputs = output_dict["cond_frame_outputs"]
+ selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
+ frame_idx, cond_outputs, self.max_cond_frames_in_attn
+ )
+ t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
+ # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
+ # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
+ # We also allow taking the memory frame non-consecutively (with r>1), in which case
+ # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame.
+ r = 1 if self.training else self.memory_temporal_stride_for_eval
+ for t_pos in range(1, self.num_maskmem):
+ t_rel = self.num_maskmem - t_pos # how many frames before current frame
+ if t_rel == 1:
+ # for t_rel == 1, we take the last frame (regardless of r)
+ prev_frame_idx = frame_idx + t_rel if track_in_reverse else frame_idx - t_rel
+ elif not track_in_reverse:
+ # first find the nearest frame among every r-th frames before this frame
+ # for r=1, this would be (frame_idx - 2)
+ prev_frame_idx = ((frame_idx - 2) // r) * r
+ # then seek further among every r-th frames
+ prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
+ else:
+ # first find the nearest frame among every r-th frames after this frame
+ # for r=1, this would be (frame_idx + 2)
+ prev_frame_idx = -(-(frame_idx + 2) // r) * r
+ prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
+ out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
+ if out is None:
+ # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
+ # frames, we still attend to it as if it's a non-conditioning frame.
+ out = unselected_cond_outputs.get(prev_frame_idx, None)
+ t_pos_and_prevs.append((t_pos, out))
+ for t_pos, prev in t_pos_and_prevs:
+ if prev is None:
+ continue # skip padding frames
+ # "maskmem_features" might have been offloaded to CPU in demo use cases,
+ # so we load it back to inference device (it's a no-op if it's already on device).
+ feats = prev["maskmem_features"].to(device=device, non_blocking=True)
+ to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
+ # Spatial positional encoding (it might have been offloaded to CPU in eval)
+ maskmem_enc = prev["maskmem_pos_enc"][-1].to(device=device)
+ maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
+ # Temporal positional encoding
+ maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
+ to_cat_memory_pos_embed.append(maskmem_enc)
+ # Construct the list of past object pointers
+ if self.use_obj_ptrs_in_encoder:
+ max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
+ # First add those object pointers from selected conditioning frames
+ # (optionally, only include object pointers in the past during evaluation)
+ if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
+ ptr_cond_outputs = {
+ t: out
+ for t, out in selected_cond_outputs.items()
+ if (t >= frame_idx if track_in_reverse else t <= frame_idx)
+ }
+ else:
+ ptr_cond_outputs = selected_cond_outputs
+ pos_and_ptrs = [
+ # Temporal pos encoding contains how far away each pointer is from current frame
+ (
+ (frame_idx - t) * tpos_sign_mul
+ if self.use_signed_tpos_enc_to_obj_ptrs
+ else abs(frame_idx - t),
+ out["obj_ptr"],
+ )
+ for t, out in ptr_cond_outputs.items()
+ ]
+ # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
+ for t_diff in range(1, max_obj_ptrs_in_encoder):
+ t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
+ if t < 0 or (num_frames is not None and t >= num_frames):
+ break
+ out = output_dict["non_cond_frame_outputs"].get(t, unselected_cond_outputs.get(t, None))
+ if out is not None:
+ pos_and_ptrs.append((t_diff, out["obj_ptr"]))
+ # If we have at least one object pointer, add them to the across attention
+ if pos_and_ptrs:
+ pos_list, ptrs_list = zip(*pos_and_ptrs)
+ # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
+ obj_ptrs = torch.stack(ptrs_list, dim=0)
+ # a temporal positional embedding based on how far each object pointer is from
+ # the current frame (sine embedding normalized by the max pointer num).
+ if self.add_tpos_enc_to_obj_ptrs:
+ t_diff_max = max_obj_ptrs_in_encoder - 1
+ tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
+ obj_pos = torch.tensor(pos_list, device=device)
+ obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
+ obj_pos = self.obj_ptr_tpos_proj(obj_pos)
+ obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
+ else:
+ obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
+ if self.mem_dim < C:
+ # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
+ obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim)
+ obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
+ obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
+ to_cat_memory.append(obj_ptrs)
+ to_cat_memory_pos_embed.append(obj_pos)
+ num_obj_ptr_tokens = obj_ptrs.shape[0]
+ else:
+ # for initial conditioning frames, encode them without using any previous memory
+ if self.directly_add_no_mem_embed:
+ # directly add no-mem embedding (instead of using the transformer encoder)
+ pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
+ pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
+ return pix_feat_with_mem
+ # Use a dummy token on the first frame (to avoid empty memory input to transformer encoder)
+ to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
+ to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
+ # Step 2: Concatenate the memories and forward through the transformer encoder
+ memory = torch.cat(to_cat_memory, dim=0)
+ memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
+ pix_feat_with_mem = self.memory_attention(
+ curr=current_vision_feats,
+ curr_pos=current_vision_pos_embeds,
+ memory=memory,
+ memory_pos=memory_pos_embed,
+ num_obj_ptr_tokens=num_obj_ptr_tokens,
+ )
+ # reshape the output (HW)BC => BCHW
+ pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
+ return pix_feat_with_mem
+ def _encode_new_memory(
+ self,
+ current_vision_feats,
+ feat_sizes,
+ pred_masks_high_res,
+ object_score_logits,
+ is_mask_from_pts,
+ ):
+ """Encodes frame features and masks into a new memory representation for video segmentation."""
+ B = current_vision_feats[-1].size(1) # batch size on this frame
+ C = self.hidden_dim
+ H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
+ # top-level feature, (HW)BC => BCHW
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
+ if self.non_overlap_masks_for_mem_enc and not self.training:
+ # optionally, apply non-overlapping constraints to the masks (it's applied
+ # in the batch dimension and should only be used during eval, where all
+ # the objects come from the same video under batch size 1).
+ pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res)
+ # scale the raw mask logits with a temperature before applying sigmoid
+ binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
+ if binarize and not self.training:
+ mask_for_mem = (pred_masks_high_res > 0).float()
+ else:
+ # apply sigmoid on the raw mask logits to turn them into range (0, 1)
+ mask_for_mem = torch.sigmoid(pred_masks_high_res)
+ # apply scale and bias terms to the sigmoid probabilities
+ if self.sigmoid_scale_for_mem_enc != 1.0:
+ mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
+ if self.sigmoid_bias_for_mem_enc != 0.0:
+ mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
+ maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True) # sigmoid already applied
+ maskmem_features = maskmem_out["vision_features"]
+ maskmem_pos_enc = maskmem_out["vision_pos_enc"]
+ # add a no-object embedding to the spatial memory to indicate that the frame
+ # is predicted to be occluded (i.e. no object is appearing in the frame)
+ if self.no_obj_embed_spatial is not None:
+ is_obj_appearing = (object_score_logits > 0).float()
+ maskmem_features += (1 - is_obj_appearing[..., None, None]) * self.no_obj_embed_spatial[
+ ..., None, None
+ ].expand(*maskmem_features.shape)
+ return maskmem_features, maskmem_pos_enc
+ def _track_step(
+ self,
+ frame_idx,
+ is_init_cond_frame,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ point_inputs,
+ mask_inputs,
+ output_dict,
+ num_frames,
+ track_in_reverse,
+ prev_sam_mask_logits,
+ ):
+ """Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
+ current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
+ # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
+ if len(current_vision_feats) > 1:
+ high_res_features = [
+ x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
+ for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
+ ]
+ else:
+ high_res_features = None
+ if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
+ # When use_mask_input_as_output_without_sam=True, we directly output the mask input
+ # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0)
+ pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
+ sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
+ else:
+ # fuse the visual feature with previous memory features in the memory bank
+ pix_feat = self._prepare_memory_conditioned_features(
+ frame_idx=frame_idx,
+ is_init_cond_frame=is_init_cond_frame,
+ current_vision_feats=current_vision_feats[-1:],
+ current_vision_pos_embeds=current_vision_pos_embeds[-1:],
+ feat_sizes=feat_sizes[-1:],
+ output_dict=output_dict,
+ num_frames=num_frames,
+ track_in_reverse=track_in_reverse,
+ )
+ # apply SAM-style segmentation head
+ # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
+ # e.g. in demo where such logits come from earlier interaction instead of correction sampling
+ # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
+ if prev_sam_mask_logits is not None:
+ assert point_inputs is not None and mask_inputs is None
+ mask_inputs = prev_sam_mask_logits
+ multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
+ sam_outputs = self._forward_sam_heads(
+ backbone_features=pix_feat,
+ point_inputs=point_inputs,
+ mask_inputs=mask_inputs,
+ high_res_features=high_res_features,
+ multimask_output=multimask_output,
+ )
+ return current_out, sam_outputs, high_res_features, pix_feat
+ def _encode_memory_in_output(
+ self,
+ current_vision_feats,
+ feat_sizes,
+ point_inputs,
+ run_mem_encoder,
+ high_res_masks,
+ object_score_logits,
+ current_out,
+ ):
+ """Finally run the memory encoder on the predicted mask to encode it into a new memory feature (that can be
+ used in future frames).
+ """
+ if run_mem_encoder and self.num_maskmem > 0:
+ high_res_masks_for_mem_enc = high_res_masks
+ maskmem_features, maskmem_pos_enc = self._encode_new_memory(
+ current_vision_feats=current_vision_feats,
+ feat_sizes=feat_sizes,
+ pred_masks_high_res=high_res_masks_for_mem_enc,
+ object_score_logits=object_score_logits,
+ is_mask_from_pts=(point_inputs is not None),
+ )
+ current_out["maskmem_features"] = maskmem_features
+ current_out["maskmem_pos_enc"] = maskmem_pos_enc
+ current_out["maskmem_features"] = None
+ current_out["maskmem_pos_enc"] = None
+ def track_step(
+ self,
+ frame_idx,
+ is_init_cond_frame,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ point_inputs,
+ mask_inputs,
+ output_dict,
+ num_frames,
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
+ # Whether to run the memory encoder on the predicted masks. Sometimes we might want
+ # to skip the memory encoder with `run_mem_encoder=False`. For example,
+ # in demo we might call `track_step` multiple times for each user click,
+ # and only encode the memory when the user finalizes their clicks. And in ablation
+ # settings like SAM training on static images, we don't need the memory encoder.
+ run_mem_encoder=True,
+ # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
+ prev_sam_mask_logits=None,
+ ):
+ """Performs a single tracking step, returning the predicted masks, object pointer, and memory features for the frame."""
+ current_out, sam_outputs, _, _ = self._track_step(
+ frame_idx, is_init_cond_frame, current_vision_feats, current_vision_pos_embeds, feat_sizes,
+ point_inputs, mask_inputs, output_dict, num_frames, track_in_reverse, prev_sam_mask_logits,
+ )
+ _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = sam_outputs
+ current_out["pred_masks"] = low_res_masks
+ current_out["pred_masks_high_res"] = high_res_masks
+ current_out["obj_ptr"] = obj_ptr
+ if not self.training:
+ # Only add this in inference (to avoid unused param in activation checkpointing;
+ # it's mainly used in the demo to encode spatial memories w/ consolidated masks)
+ current_out["object_score_logits"] = object_score_logits
+ # Run memory encoder on the predicted mask to encode it into a new memory feature (for use in future frames)
+ self._encode_memory_in_output(
+ current_vision_feats, feat_sizes, point_inputs, run_mem_encoder, high_res_masks, object_score_logits, current_out
+ )
+ return current_out
+ def _use_multimask(self, is_init_cond_frame, point_inputs):
+ """Determines whether to use multiple mask outputs in the SAM head based on configuration and inputs."""
+ num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
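+ # Illustrative example: a single click on an initial conditioning frame yields multimask output as long as
+ # multimask_output_in_sam is enabled and 1 lies within [multimask_min_pt_num, multimask_max_pt_num].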
+ return (
+ self.multimask_output_in_sam
+ and (is_init_cond_frame or self.multimask_output_for_tracking)
+ and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
+ )
+ @staticmethod
+ def _apply_non_overlapping_constraints(pred_masks):
+ """Applies non-overlapping constraints to masks, keeping the highest scoring object per location."""
+ batch_size = pred_masks.size(0)
+ if batch_size == 1:
+ return pred_masks
+ device = pred_masks.device
+ # "max_obj_inds": object index of the object with the highest score at each location
+ max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
+ # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
+ batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
+ keep = max_obj_inds == batch_obj_inds
+ # suppress overlapping regions' scores below -10.0 so that the foreground regions
+ # don't overlap (here sigmoid(-10.0)=4.5398e-05)
+ pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
+ return pred_masks
+ def set_binarize(self, binarize=False):
+ """Set binarize for VideoPredictor."""
+ self.binarize_mask_from_pts_for_mem_enc = binarize
+ def set_imgsz(self, imgsz):
+ """Set image size to make the model compatible with different input image sizes."""
+ self.image_size = imgsz[0]
+ self.sam_prompt_encoder.input_image_size = imgsz
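+ # Illustrative example: with imgsz=(1024, 1024) the prompt encoder operates on a 64x64 embedding grid (1024 / 16).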
+ self.sam_prompt_encoder.image_embedding_size = [x // 16 for x in imgsz] # fixed ViT patch size of 16
+# --------------------------------------------------------
+# TinyViT Model Architecture
+# Copyright (c) 2022 Microsoft
+# Adapted from LeViT and Swin Transformer
+# LeViT: (https://github.com/facebookresearch/levit)
+# Swin: (https://github.com/microsoft/swin-transformer)
+# Build the TinyViT Model
+ import itertools
+ from typing import Tuple
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.utils.checkpoint as checkpoint
+ from ultralytics.nn.modules import LayerNorm2d
+ from ultralytics.utils.instance import to_2tuple
+class Conv2d_BN(torch.nn.Sequential):
+ A sequential container that performs 2D convolution followed by batch normalization.
+ c (torch.nn.Conv2d): 2D convolution layer.
+ bn (torch.nn.BatchNorm2d): Batch normalization layer.
+ __init__: Initializes the Conv2d_BN with specified parameters.
+ a (int): Number of input channels.
+ b (int): Number of output channels.
+ ks (int): Kernel size for the convolution. Defaults to 1.
+ stride (int): Stride for the convolution. Defaults to 1.
+ pad (int): Padding for the convolution. Defaults to 0.
+ dilation (int): Dilation factor for the convolution. Defaults to 1.
+ groups (int): Number of groups for the convolution. Defaults to 1.
+ bn_weight_init (float): Initial value for batch normalization weight. Defaults to 1.
+ >>> conv_bn = Conv2d_BN(3, 64, ks=3, stride=1, pad=1)
+ >>> output = conv_bn(input_tensor)
+ def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
+ """Initializes a sequential container with 2D convolution followed by batch normalization."""
+ self.add_module("c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
+ bn = torch.nn.BatchNorm2d(b)
+ torch.nn.init.constant_(bn.weight, bn_weight_init)
+ torch.nn.init.constant_(bn.bias, 0)
+ self.add_module("bn", bn)
+ class PatchEmbed(nn.Module):
+ """
+ Embeds images into patches and projects them into a specified embedding dimension.
+ patches_resolution (Tuple[int, int]): Resolution of the patches after embedding.
+ num_patches (int): Total number of patches.
+ in_chans (int): Number of input channels.
+ embed_dim (int): Dimension of the embedding.
+ seq (nn.Sequential): Sequence of convolutional and activation layers for patch embedding.
+ forward: Processes the input tensor through the patch embedding sequence.
+ >>> patch_embed = PatchEmbed(in_chans=3, embed_dim=96, resolution=224, activation=nn.GELU)
+ """
+ def __init__(self, in_chans, embed_dim, resolution, activation):
+ """Initializes patch embedding with convolutional layers for image-to-patch conversion and projection."""
+ super().__init__()
+ img_size: Tuple[int, int] = to_2tuple(resolution)
+ self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
+ self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
+ self.in_chans = in_chans
+ n = embed_dim
+ self.seq = nn.Sequential(
+ Conv2d_BN(in_chans, n // 2, 3, 2, 1),
+ activation(),
+ Conv2d_BN(n // 2, n, 3, 2, 1),
+ )
+ def forward(self, x):
+ """Processes input tensor through patch embedding sequence, converting images to patch embeddings."""
+ return self.seq(x)
+class MBConv(nn.Module):
+ Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture.
+ hidden_chans (int): Number of hidden channels.
+ out_chans (int): Number of output channels.
+ conv1 (Conv2d_BN): First convolutional layer.
+ act1 (nn.Module): First activation function.
+ conv2 (Conv2d_BN): Depthwise convolutional layer.
+ act2 (nn.Module): Second activation function.
+ conv3 (Conv2d_BN): Final convolutional layer.
+ act3 (nn.Module): Third activation function.
+ drop_path (nn.Module): Drop path layer (Identity for inference).
+ forward: Performs the forward pass through the MBConv layer.
+ >>> in_chans, out_chans = 32, 64
+ >>> mbconv = MBConv(in_chans, out_chans, expand_ratio=4, activation=nn.ReLU, drop_path=0.1)
+ >>> x = torch.randn(1, in_chans, 56, 56)
+ >>> output = mbconv(x)
+ def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
+ """Initializes the MBConv layer with specified input/output channels, expansion ratio, and activation."""
+ self.hidden_chans = int(in_chans * expand_ratio)
+ self.out_chans = out_chans
+ self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1)
+ self.act1 = activation()
+ self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans, ks=3, stride=1, pad=1, groups=self.hidden_chans)
+ self.act2 = activation()
+ self.conv3 = Conv2d_BN(self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0)
+ self.act3 = activation()
+ # NOTE: `DropPath` is needed only for training.
+ # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.drop_path = nn.Identity()
+ """Implements the forward pass of MBConv, applying convolutions and skip connection."""
+ x = self.conv1(x)
+ x = self.act1(x)
+ x = self.conv2(x)
+ x = self.act2(x)
+ x = self.conv3(x)
+ x = self.drop_path(x)
+ x += shortcut
+ return self.act3(x)
+class PatchMerging(nn.Module):
+ Merges neighboring patches in the feature map and projects to a new dimension.
+ This class implements a patch merging operation that combines spatial information and adjusts the feature
+ dimension. It uses a series of convolutional layers with batch normalization to achieve this.
+ input_resolution (Tuple[int, int]): The input resolution (height, width) of the feature map.
+ dim (int): The input dimension of the feature map.
+ out_dim (int): The output dimension after merging and projection.
+ act (nn.Module): The activation function used between convolutions.
+ conv1 (Conv2d_BN): The first convolutional layer for dimension projection.
+ conv2 (Conv2d_BN): The second convolutional layer for spatial merging.
+ conv3 (Conv2d_BN): The third convolutional layer for final projection.
+ forward: Applies the patch merging operation to the input tensor.
+ >>> input_resolution = (56, 56)
+ >>> patch_merging = PatchMerging(input_resolution, dim=64, out_dim=128, activation=nn.ReLU)
+ >>> x = torch.randn(4, 64, 56, 56)
+ >>> output = patch_merging(x)
+ def __init__(self, input_resolution, dim, out_dim, activation):
+ """Initializes the PatchMerging module for merging and projecting neighboring patches in feature maps."""
+ self.input_resolution = input_resolution
+ self.out_dim = out_dim
+ self.act = activation()
+ self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
+ stride_c = 1 if out_dim in {320, 448, 576} else 2
+ self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
+ self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
+ """Applies patch merging and dimension projection to the input feature map."""
+ if x.ndim == 3:
+ H, W = self.input_resolution
+ B = len(x)
+ # (B, C, H, W)
+ x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
+ x = self.conv1(x)
+ x = self.act(x)
+ x = self.conv2(x)
+ x = self.act(x)
+ x = self.conv3(x)
+ return x.flatten(2).transpose(1, 2)
+class ConvLayer(nn.Module):
+ Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).
+ This layer optionally applies downsample operations to the output and supports gradient checkpointing.
+ dim (int): Dimensionality of the input and output.
+ input_resolution (Tuple[int, int]): Resolution of the input image.
+ depth (int): Number of MBConv layers in the block.
+ use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
+ blocks (nn.ModuleList): List of MBConv layers.
+ downsample (Optional[Callable]): Function for downsampling the output.
+ forward: Processes the input through the convolutional layers.
+ >>> input_tensor = torch.randn(1, 64, 56, 56)
+ >>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU)
+ >>> output = conv_layer(input_tensor)
+ def __init__(
+ self,
+ dim,
+ input_resolution,
+ depth,
+ activation,
+ downsample=None,
+ out_dim=None,
+ drop_path=0.0,
+ use_checkpoint=False,
+ conv_expand_ratio=4.0,
+ ):
+ """
+ Initializes the ConvLayer with the given dimensions and settings.
+ This layer consists of multiple MobileNetV3-style inverted bottleneck convolutions (MBConv) and
+ optionally applies downsampling to the output.
+ dim (int): The dimensionality of the input and output.
+ input_resolution (Tuple[int, int]): The resolution of the input image.
+ depth (int): The number of MBConv layers in the block.
+ activation (Callable): Activation function applied after each convolution.
+ drop_path (float | List[float]): Drop path rate. Single float or a list of floats for each MBConv.
+ downsample (Optional[Callable]): Function for downsampling the output. None to skip downsampling.
+ out_dim (Optional[int]): The dimensionality of the output. None means it will be the same as `dim`.
+ conv_expand_ratio (float): Expansion ratio for the MBConv layers.
+ """
+ super().__init__()
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+ # Build blocks
+ self.blocks = nn.ModuleList(
+ [
+ MBConv(
+ dim,
+ dim,
+ conv_expand_ratio,
+ activation,
+ drop_path[i] if isinstance(drop_path, list) else drop_path,
+ )
+ for i in range(depth)
+ ]
+ )
+ # Patch merging layer
+ self.downsample = (
+ None
+ if downsample is None
+ else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
+ )
+ def forward(self, x):
+ """Processes input through convolutional layers, applying MBConv blocks and optional downsampling."""
+ for blk in self.blocks:
+ x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
+ return x if self.downsample is None else self.downsample(x)
+class Mlp(nn.Module):
+ Multi-layer Perceptron (MLP) module for transformer architectures.
+ This module applies layer normalization, two fully-connected layers with an activation function in between,
+ and dropout. It is commonly used in transformer-based architectures.
+ norm (nn.LayerNorm): Layer normalization applied to the input.
+ fc1 (nn.Linear): First fully-connected layer.
+ fc2 (nn.Linear): Second fully-connected layer.
+ act (nn.Module): Activation function applied after the first fully-connected layer.
+ drop (nn.Dropout): Dropout layer applied after the activation function.
+ forward: Applies the MLP operations on the input tensor.
+ >>> mlp = Mlp(in_features=256, hidden_features=512, out_features=256, act_layer=nn.GELU, drop=0.1)
+ >>> x = torch.randn(32, 100, 256)
+ >>> output = mlp(x)
+ torch.Size([32, 100, 256])
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
+ """Initializes a multi-layer perceptron with configurable input, hidden, and output dimensions."""
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.norm = nn.LayerNorm(in_features)
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.act = act_layer()
+ self.drop = nn.Dropout(drop)
+ """Applies MLP operations: layer norm, FC layers, activation, and dropout to the input tensor."""
+ x = self.fc1(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ return self.drop(x)
+class Attention(torch.nn.Module):
+ Multi-head attention module with spatial awareness and trainable attention biases.
+ This module implements a multi-head attention mechanism with support for spatial awareness, applying
+ attention biases based on spatial resolution. It includes trainable attention biases for each unique
+ offset between spatial positions in the resolution grid.
+ scale (float): Scaling factor for attention scores.
+ key_dim (int): Dimensionality of the keys and queries.
+ nh_kd (int): Product of num_heads and key_dim.
+ d (int): Dimensionality of the value vectors.
+ dh (int): Product of d and num_heads.
+ attn_ratio (float): Attention ratio affecting the dimensions of the value vectors.
+ norm (nn.LayerNorm): Layer normalization applied to input.
+ qkv (nn.Linear): Linear layer for computing query, key, and value projections.
+ proj (nn.Linear): Linear layer for final projection.
+ attention_biases (nn.Parameter): Learnable attention biases.
+ attention_bias_idxs (Tensor): Indices for attention biases.
+ ab (Tensor): Cached attention biases for inference, deleted during training.
+ train: Sets the module in training mode and handles the 'ab' attribute.
+ forward: Performs the forward pass of the attention mechanism.
+ >>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))
+ >>> x = torch.randn(1, 196, 256)
+ >>> output = attn(x)
+ torch.Size([1, 196, 256])
+ def __init__(
+ self,
+ dim,
+ key_dim,
+ num_heads=8,
+ attn_ratio=4,
+ resolution=(14, 14),
+ ):
+ """
+ Initializes the Attention module for multi-head attention with spatial awareness.
+ key_dim (int): The dimensionality of the keys and queries.
+ attn_ratio (float): Attention ratio, affecting the dimensions of the value vectors. Default is 4.
+ resolution (Tuple[int, int]): Spatial resolution of the input feature map. Default is (14, 14).
+ AssertionError: If 'resolution' is not a tuple of length 2.
+ """
+ super().__init__()
+ assert isinstance(resolution, tuple) and len(resolution) == 2, "'resolution' argument not tuple of length 2"
+ self.num_heads = num_heads
+ self.scale = key_dim**-0.5
+ self.key_dim = key_dim
+ self.nh_kd = nh_kd = key_dim * num_heads
+ self.d = int(attn_ratio * key_dim)
+ self.dh = int(attn_ratio * key_dim) * num_heads
+ self.attn_ratio = attn_ratio
+ h = self.dh + nh_kd * 2
+ self.norm = nn.LayerNorm(dim)
+ self.qkv = nn.Linear(dim, h)
+ self.proj = nn.Linear(self.dh, dim)
+ points = list(itertools.product(range(resolution[0]), range(resolution[1])))
+ N = len(points)
+ attention_offsets = {}
+ idxs = []
+ for p1 in points:
+ for p2 in points:
+ offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
+ if offset not in attention_offsets:
+ attention_offsets[offset] = len(attention_offsets)
+ idxs.append(attention_offsets[offset])
+ self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
+ self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(N, N), persistent=False)
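+ # Illustrative example: for the default (14, 14) resolution there are 196 unique (|dx|, |dy|) offsets, so
+ # attention_biases has shape (num_heads, 196) and attention_bias_idxs has shape (196, 196).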
+ @torch.no_grad()
+ def train(self, mode=True):
+ """Sets the module in training mode and handles caching of the attention biases."""
+ super().train(mode)
+ if mode and hasattr(self, "ab"):
+ del self.ab
+ else:
+ self.ab = self.attention_biases[:, self.attention_bias_idxs]
+ def forward(self, x): # x
+ """Applies multi-head attention with spatial awareness and trainable attention biases."""
+ B, N, _ = x.shape # B, N, C
+ # Normalization
+ x = self.norm(x)
+ qkv = self.qkv(x)
+ # (B, N, num_heads, d)
+ q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)
+ # (B, num_heads, N, d)
+ q = q.permute(0, 2, 1, 3)
+ k = k.permute(0, 2, 1, 3)
+ v = v.permute(0, 2, 1, 3)
+ self.ab = self.ab.to(self.attention_biases.device)
+ attn = (q @ k.transpose(-2, -1)) * self.scale + (
+ self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
+ )
+ attn = attn.softmax(dim=-1)
+ x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
+ return self.proj(x)
+class TinyViTBlock(nn.Module):
+ TinyViT Block that applies self-attention and a local convolution to the input.
+ This block is a key component of the TinyViT architecture, combining self-attention mechanisms with
+ local convolutions to process input features efficiently.
+ input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
+ window_size (int): Size of the attention window.
+ drop_path (nn.Module): Stochastic depth layer, identity function during inference.
+ attn (Attention): Self-attention module.
+ mlp (Mlp): Multi-layer perceptron module.
+ local_conv (Conv2d_BN): Depth-wise local convolution layer.
+ forward: Processes the input through the TinyViT block.
+ extra_repr: Returns a string with extra information about the block's parameters.
+ >>> input_tensor = torch.randn(1, 196, 192)
+ >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3)
+ >>> output = block(input_tensor)
+ torch.Size([1, 196, 192])
+ def __init__(
+ self,
+ dim,
+ input_resolution,
+ num_heads,
+ window_size=7,
+ mlp_ratio=4.0,
+ drop=0.0,
+ drop_path=0.0,
+ local_conv_size=3,
+ activation=nn.GELU,
+ ):
+ """
+ Initializes a TinyViT block with self-attention and local convolution.
+ dim (int): Dimensionality of the input and output features.
+ input_resolution (Tuple[int, int]): Spatial resolution of the input feature map (height, width).
+ window_size (int): Size of the attention window. Must be greater than 0.
+ drop (float): Dropout rate.
+ local_conv_size (int): Kernel size of the local convolution.
+ activation (torch.nn.Module): Activation function for MLP.
+ AssertionError: If window_size is not greater than 0.
+ AssertionError: If dim is not divisible by num_heads.
+ """
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ assert window_size > 0, "window_size must be greater than 0"
+ self.window_size = window_size
+ self.mlp_ratio = mlp_ratio
+ # NOTE: `DropPath` is needed only for training.
+ self.drop_path = nn.Identity()
+ assert dim % num_heads == 0, "dim must be divisible by num_heads"
+ head_dim = dim // num_heads
+ window_resolution = (window_size, window_size)
+ self.attn = Attention(dim, head_dim, num_heads, attn_ratio=1, resolution=window_resolution)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ mlp_activation = activation
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=mlp_activation, drop=drop)
+ pad = local_conv_size // 2
+ self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
+ """Applies self-attention, local convolution, and MLP operations to the input tensor."""
+ h, w = self.input_resolution
+ b, hw, c = x.shape # batch, height*width, channels
+ assert hw == h * w, "input feature has wrong size"
+ res_x = x
+ if h == self.window_size and w == self.window_size:
+ x = self.attn(x)
+ else:
+ x = x.view(b, h, w, c)
+ pad_b = (self.window_size - h % self.window_size) % self.window_size
+ pad_r = (self.window_size - w % self.window_size) % self.window_size
+ padding = pad_b > 0 or pad_r > 0
+ if padding:
+ x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
+ pH, pW = h + pad_b, w + pad_r
+ nH = pH // self.window_size
+ nW = pW // self.window_size
+ x = (
+ x.view(b, nH, self.window_size, nW, self.window_size, c)
+ .transpose(2, 3)
+ .reshape(b * nH * nW, self.window_size * self.window_size, c)
+ )
+ x = self.attn(x)
+ # Window reverse
+ x = x.view(b, nH, nW, self.window_size, self.window_size, c).transpose(2, 3).reshape(b, pH, pW, c)
+ if padding:
+ x = x[:, :h, :w].contiguous()
+ x = x.view(b, hw, c)
+ x = res_x + self.drop_path(x)
+ x = x.transpose(1, 2).reshape(b, c, h, w)
+ x = self.local_conv(x)
+ x = x.view(b, c, hw).transpose(1, 2)
+ return x + self.drop_path(self.mlp(x))
+ def extra_repr(self) -> str:
+ Returns a string representation of the TinyViTBlock's parameters.
+ This method provides a formatted string containing key information about the TinyViTBlock, including its
+ dimension, input resolution, number of attention heads, window size, and MLP ratio.
+ (str): A formatted string containing the block's parameters.
+ >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0)
+ >>> print(block.extra_repr())
+ dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0
+ f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
+ f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
+class BasicLayer(nn.Module):
+ A basic TinyViT layer for one stage in a TinyViT architecture.
+ This class represents a single layer in the TinyViT model, consisting of multiple TinyViT blocks
+ and an optional downsampling operation.
+ dim (int): The dimensionality of the input and output features.
+ depth (int): Number of TinyViT blocks in this layer.
+ blocks (nn.ModuleList): List of TinyViT blocks that make up this layer.
+ downsample (nn.Module | None): Downsample layer at the end of the layer, if specified.
+ forward: Processes the input through the layer's blocks and optional downsampling.
+ extra_repr: Returns a string with the layer's parameters for printing.
+ >>> input_tensor = torch.randn(1, 3136, 192)
+ >>> layer = BasicLayer(dim=192, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)
+ >>> output = layer(input_tensor)
+ torch.Size([1, 784, 384])
+ def __init__(
+ self, dim, input_resolution, depth, num_heads, window_size,
+ mlp_ratio=4.0, drop=0.0, drop_path=0.0, downsample=None, use_checkpoint=False,
+ local_conv_size=3, activation=nn.GELU, out_dim=None,
+ ):
+ """
+ Initializes a BasicLayer in the TinyViT architecture.
+ This layer consists of multiple TinyViT blocks and an optional downsampling operation. It is designed to
+ process feature maps at a specific resolution and dimensionality within the TinyViT model.
+ num_heads (int): Number of attention heads in each TinyViT block.
+ window_size (int): Size of the local window for attention computation.
+ drop_path (float | List[float]): Stochastic depth rate. Can be a float or a list of floats for each block.
+ downsample (nn.Module | None): Downsampling layer at the end of the layer. None to skip downsampling.
+ local_conv_size (int): Kernel size for the local convolution in each TinyViT block.
+ activation (nn.Module): Activation function used in the MLP.
+ out_dim (int | None): Output dimension after downsampling. None means it will be the same as `dim`.
+ ValueError: If `drop_path` is a list and its length doesn't match `depth`.
+ >>> layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)
+ >>> x = torch.randn(1, 56 * 56, 96)
+ >>> output = layer(x)
+ """
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+ # Build blocks
+ self.blocks = nn.ModuleList(
+ [
+ TinyViTBlock(
+ dim=dim,
+ input_resolution=input_resolution,
+ num_heads=num_heads,
+ window_size=window_size,
+ mlp_ratio=mlp_ratio,
+ drop=drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ local_conv_size=local_conv_size,
+ activation=activation,
+ )
+ for i in range(depth)
+ ]
+ )
+ # Patch merging layer
+ self.downsample = (
+ None if downsample is None else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
+ )
+ """Processes input through TinyViT blocks and optional downsampling."""
+ """Returns a string with the layer's parameters for printing."""
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+class TinyViT(nn.Module):
+ TinyViT: A compact vision transformer architecture for efficient image classification and feature extraction.
+ This class implements the TinyViT model, which combines elements of vision transformers and convolutional
+ neural networks for improved efficiency and performance on vision tasks.
+ img_size (int): Input image size.
+ num_classes (int): Number of classification classes.
+ depths (List[int]): Number of blocks in each stage.
+ num_layers (int): Total number of layers in the network.
+ patches_resolution (Tuple[int, int]): Resolution of embedded patches.
+ layers (nn.ModuleList): List of network layers.
+ norm_head (nn.LayerNorm): Layer normalization for the classifier head.
+ head (nn.Linear): Linear layer for final classification.
+ neck (nn.Sequential): Neck module for feature refinement.
+ set_layer_lr_decay: Sets layer-wise learning rate decay.
+ _init_weights: Initializes weights for linear and normalization layers.
+ no_weight_decay_keywords: Returns keywords for parameters that should not use weight decay.
+ forward_features: Processes input through the feature extraction layers.
+ forward: Performs a forward pass through the entire network.
+ >>> model = TinyViT(img_size=224, num_classes=1000)
+ >>> features = model.forward_features(x)
+ >>> print(features.shape)
+ def __init__(
+ self,
+ img_size=224,
+ in_chans=3,
+ num_classes=1000,
+ embed_dims=(96, 192, 384, 768),
+ depths=(2, 2, 6, 2),
+ num_heads=(3, 6, 12, 24),
+ window_sizes=(7, 7, 14, 7),
+ mlp_ratio=4.0,
+ drop_rate=0.0,
+ drop_path_rate=0.1,
+ use_checkpoint=False,
+ mbconv_expand_ratio=4.0,
+ local_conv_size=3,
+ layer_lr_decay=1.0,
+ ):
+ """
+ Initializes the TinyViT model.
+ This constructor sets up the TinyViT architecture, including patch embedding, multiple layers of
+ attention and convolution blocks, and a classification head.
+ img_size (int): Size of the input image. Default is 224.
+ in_chans (int): Number of input channels. Default is 3.
+ num_classes (int): Number of classes for classification. Default is 1000.
+ embed_dims (Tuple[int, int, int, int]): Embedding dimensions for each stage.
+ Default is (96, 192, 384, 768).
+ depths (Tuple[int, int, int, int]): Number of blocks in each stage. Default is (2, 2, 6, 2).
+ num_heads (Tuple[int, int, int, int]): Number of attention heads in each stage.
+ Default is (3, 6, 12, 24).
+ window_sizes (Tuple[int, int, int, int]): Window sizes for each stage. Default is (7, 7, 14, 7).
+ mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. Default is 4.0.
+ drop_rate (float): Dropout rate. Default is 0.0.
+ drop_path_rate (float): Stochastic depth rate. Default is 0.1.
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default is False.
+ mbconv_expand_ratio (float): Expansion ratio for MBConv layer. Default is 4.0.
+ local_conv_size (int): Kernel size for local convolutions. Default is 3.
+ layer_lr_decay (float): Layer-wise learning rate decay factor. Default is 1.0.
+ >>> output = model(x)
+ torch.Size([1, 1000])
+ """
+ super().__init__()
+ self.img_size = img_size
+ self.num_classes = num_classes
+ self.depths = depths
+ self.num_layers = len(depths)
+ self.mlp_ratio = mlp_ratio
+ activation = nn.GELU
+ self.patch_embed = PatchEmbed(
+ in_chans=in_chans, embed_dim=embed_dims[0], resolution=img_size, activation=activation
+ )
+ patches_resolution = self.patch_embed.patches_resolution
+ self.patches_resolution = patches_resolution
+ # Stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+ # Build layers
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ kwargs = dict(
+ dim=embed_dims[i_layer],
+ input_resolution=(
+ patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
+ patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
+ ),
+ # input_resolution=(patches_resolution[0] // (2 ** i_layer),
+ # patches_resolution[1] // (2 ** i_layer)),
+ depth=depths[i_layer],
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+ use_checkpoint=use_checkpoint,
+ out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)],
+ activation=activation,
+ )
+ if i_layer == 0:
+ layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs)
+ else:
+ layer = BasicLayer(
+ num_heads=num_heads[i_layer],
+ window_size=window_sizes[i_layer],
+ mlp_ratio=self.mlp_ratio,
+ drop=drop_rate,
+ local_conv_size=local_conv_size,
+ **kwargs,
+ )
+ self.layers.append(layer)
+ # Classifier head
+ self.norm_head = nn.LayerNorm(embed_dims[-1])
+ self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
+ # Init weights
+ self.apply(self._init_weights)
+ self.set_layer_lr_decay(layer_lr_decay)
+ # Neck for dense prediction feature maps
+ self.neck = nn.Sequential(
+ nn.Conv2d(embed_dims[-1], 256, kernel_size=1, bias=False),
+ LayerNorm2d(256),
+ nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
+ LayerNorm2d(256),
+ )
+ def set_layer_lr_decay(self, layer_lr_decay):
+ """Sets layer-wise learning rate decay for the TinyViT model based on depth."""
+ decay_rate = layer_lr_decay
+ # Layers -> blocks (depth)
+ depth = sum(self.depths)
+ lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]
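+ # Illustrative example: with layer_lr_decay=0.9 and 12 blocks in total, the earliest block gets
+ # lr_scale = 0.9 ** 11 ~= 0.31 while the final block keeps lr_scale = 1.0.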
+ def _set_lr_scale(m, scale):
+ """Sets the learning rate scale for each layer in the model based on the layer's depth."""
+ for p in m.parameters():
+ p.lr_scale = scale
+ self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0]))
+ i = 0
+ for layer in self.layers:
+ for block in layer.blocks:
+ block.apply(lambda x: _set_lr_scale(x, lr_scales[i]))
+ i += 1
+ if layer.downsample is not None:
+ layer.downsample.apply(lambda x: _set_lr_scale(x, lr_scales[i - 1]))
+ assert i == depth
+ for m in [self.norm_head, self.head]:
+ m.apply(lambda x: _set_lr_scale(x, lr_scales[-1]))
+ for k, p in self.named_parameters():
+ p.param_name = k
+ def _check_lr_scale(m):
+ """Checks if the learning rate scale attribute is present in module's parameters."""
+ assert hasattr(p, "lr_scale"), p.param_name
+ self.apply(_check_lr_scale)
+ @staticmethod
+ def _init_weights(m):
+ """Initializes weights for linear and normalization layers in the TinyViT model."""
+ if isinstance(m, nn.Linear):
+ # NOTE: This initialization is needed only for training.
+ # trunc_normal_(m.weight, std=.02)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.weight, 1.0)
+ nn.init.constant_(m.bias, 0)
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ """Returns a set of keywords for parameters that should not use weight decay."""
+ return {"attention_biases"}
+ def forward_features(self, x):
+ """Processes input through feature extraction layers, returning spatial features."""
+ x = self.patch_embed(x) # x input is (N, C, H, W)
+ x = self.layers[0](x)
+ start_i = 1
+ for i in range(start_i, len(self.layers)):
+ layer = self.layers[i]
+ x = layer(x)
+ batch, _, channel = x.shape
+ x = x.view(batch, self.patches_resolution[0] // 4, self.patches_resolution[1] // 4, channel)
+ x = x.permute(0, 3, 1, 2)
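+ # Illustrative example: for a 1024x1024 input, patches_resolution is 256, so the tensor is reshaped to
+ # (B, 64, 64, C) and permuted to (B, C, 64, 64) before the neck projects it to 256 channels.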
+ return self.neck(x)
+ """Performs the forward pass through the TinyViT model, extracting features from the input image."""
+ return self.forward_features(x)
+ def set_imgsz(self, imgsz=[1024, 1024]):
+ """Set image size to make the model compatible with different input image sizes."""
+ imgsz = [s // 4 for s in imgsz]
+ self.patches_resolution = imgsz
+ for i, layer in enumerate(self.layers):
+ input_resolution = (
+ imgsz[0] // (2 ** (i - 1 if i == 3 else i)),
+ imgsz[1] // (2 ** (i - 1 if i == 3 else i)),
+ )
+ layer.input_resolution = input_resolution
+ if layer.downsample is not None:
+ layer.downsample.input_resolution = input_resolution
+ if isinstance(layer, BasicLayer):
+ for b in layer.blocks:
+ b.input_resolution = input_resolution
@@ -0,0 +1,373 @@
+ import math
+ from typing import Tuple, Type
+ import torch
+ from torch import Tensor, nn
+ from ultralytics.nn.modules import MLPBlock
+class TwoWayTransformer(nn.Module):
+ This class implements a specialized transformer decoder that attends to an input image using queries with
+ supplied positional embeddings. It's useful for tasks like object detection, image segmentation, and point
+ cloud processing.
+ layers (nn.ModuleList): List of TwoWayAttentionBlock layers composing the transformer.
+ forward: Processes image and point embeddings through the transformer.
+ >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
+ >>> image_embedding = torch.randn(1, 256, 32, 32)
+ >>> image_pe = torch.randn(1, 256, 32, 32)
+ >>> point_embedding = torch.randn(1, 100, 256)
+ >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
+ >>> print(output_queries.shape, output_image.shape)
+ def __init__(
+ self, depth: int, embedding_dim: int, num_heads: int, mlp_dim: int,
+ activation: Type[nn.Module] = nn.ReLU, attention_downsample_rate: int = 2,
+ ) -> None:
+ """
+ Initialize a Two-Way Transformer for simultaneous attention to image and query points.
+ attention_downsample_rate (int): Downsampling rate for attention mechanism.
+ layers (nn.ModuleList): List of TwoWayAttentionBlock layers.
+ """
+ super().__init__()
+ self.depth = depth
+ self.embedding_dim = embedding_dim
+ self.num_heads = num_heads
+ self.mlp_dim = mlp_dim
+ self.layers = nn.ModuleList()
+ for i in range(depth):
+ self.layers.append(
+ TwoWayAttentionBlock(embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate, skip_first_layer_pe=(i == 0))
+ )
+ self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
+ def forward(
+ self,
+ image_embedding: Tensor,
+ image_pe: Tensor,
+ point_embedding: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+ """
+ Processes image and point embeddings through the Two-Way Transformer.
+ image_embedding (torch.Tensor): Image to attend to, with shape (B, embedding_dim, H, W).
+ image_pe (torch.Tensor): Positional encoding to add to the image, with same shape as image_embedding.
+ point_embedding (torch.Tensor): Embedding to add to query points, with shape (B, N_points, embedding_dim).
+ (Tuple[torch.Tensor, torch.Tensor]): Processed point_embedding and image_embedding.
+ """
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
+ # Prepare queries
+ queries = point_embedding
+ keys = image_embedding
+ # Apply transformer blocks and final layernorm
+ for layer in self.layers:
+ queries, keys = layer(
+ queries=queries,
+ keys=keys,
+ query_pe=point_embedding,
+ key_pe=image_pe,
+ )
+ # Apply the final attention layer from the points to the image
+ q = queries + point_embedding
+ k = keys + image_pe
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
+ queries = queries + attn_out
+ queries = self.norm_final_attn(queries)
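+ # Illustrative example: with a (1, 256, 32, 32) image embedding and 100 point queries, the transformer
+ # returns queries of shape (1, 100, 256) and keys of shape (1, 1024, 256).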
+ return queries, keys
+class TwoWayAttentionBlock(nn.Module):
+ A two-way attention block for simultaneous attention to image and query points.
+ This class implements a specialized transformer block with four main layers: self-attention on sparse inputs,
+ cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense
+ inputs to sparse inputs.
+ norm1 (nn.LayerNorm): Layer normalization after self-attention.
+ norm2 (nn.LayerNorm): Layer normalization after token-to-image attention.
+ mlp (MLPBlock): MLP block for transforming query embeddings.
+ norm3 (nn.LayerNorm): Layer normalization after MLP block.
+ norm4 (nn.LayerNorm): Layer normalization after image-to-token attention.
+ skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
+ forward: Applies self-attention and cross-attention to queries and keys.
+ >>> embedding_dim, num_heads = 256, 8
+ >>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
+ >>> queries = torch.randn(1, 100, embedding_dim)
+ >>> keys = torch.randn(1, 1000, embedding_dim)
+ >>> query_pe = torch.randn(1, 100, embedding_dim)
+ >>> key_pe = torch.randn(1, 1000, embedding_dim)
+ >>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
+ def __init__(
+ self, embedding_dim: int, num_heads: int, mlp_dim: int = 2048, activation: Type[nn.Module] = nn.ReLU,
+ attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False,
+ ) -> None:
+ """
+ Initializes a TwoWayAttentionBlock for simultaneous attention to image and query points.
+ This block implements a specialized transformer layer with four main components: self-attention on sparse
+ inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention
+ of dense inputs to sparse inputs.
+ embedding_dim (int): Channel dimension of the embeddings.
+ num_heads (int): Number of attention heads in the attention layers.
+ mlp_dim (int): Hidden dimension of the MLP block.
+ activation (Type[nn.Module]): Activation function for the MLP block.
+ attention_downsample_rate (int): Downsampling rate for the attention mechanism.
+ skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
+ """
+ super().__init__()
+ self.self_attn = Attention(embedding_dim, num_heads)
+ self.norm1 = nn.LayerNorm(embedding_dim)
+ self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
+ self.norm2 = nn.LayerNorm(embedding_dim)
+ self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
+ self.norm3 = nn.LayerNorm(embedding_dim)
+ self.norm4 = nn.LayerNorm(embedding_dim)
+ self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
+ self.skip_first_layer_pe = skip_first_layer_pe
+ def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
+ """Applies two-way attention to process query and key embeddings in a transformer block."""
+ # Self attention block
+ if self.skip_first_layer_pe:
+ queries = self.self_attn(q=queries, k=queries, v=queries)
+ else:
+ q = queries + query_pe
+ attn_out = self.self_attn(q=q, k=q, v=queries)
+ queries = queries + attn_out
+ queries = self.norm1(queries)
+ # Cross attention block, tokens attending to image embedding
+ q = queries + query_pe
+ k = keys + key_pe
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
+ queries = queries + attn_out
+ queries = self.norm2(queries)
+ # MLP block
+ mlp_out = self.mlp(queries)
+ queries = queries + mlp_out
+ queries = self.norm3(queries)
+ # Cross attention block, image embedding attending to tokens
+ q = queries + query_pe
+ k = keys + key_pe
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
+ keys = keys + attn_out
+ keys = self.norm4(keys)
+ return queries, keys
+class Attention(nn.Module):
+ An attention layer with downscaling capability for embedding size after projection.
+ This class implements a multi-head attention mechanism with the option to downsample the internal
+ dimension of queries, keys, and values.
+ embedding_dim (int): Dimensionality of input embeddings.
+ kv_in_dim (int): Dimensionality of key and value inputs.
+ internal_dim (int): Internal dimension after downsampling.
+ q_proj (nn.Linear): Linear projection for queries.
+ k_proj (nn.Linear): Linear projection for keys.
+ v_proj (nn.Linear): Linear projection for values.
+ out_proj (nn.Linear): Linear projection for output.
+ _separate_heads: Separates input tensor into attention heads.
+ _recombine_heads: Recombines separated attention heads.
+ forward: Computes attention output for given query, key, and value tensors.
+ >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
+ >>> q = torch.randn(1, 100, 256)
+ >>> k = v = torch.randn(1, 50, 256)
+ >>> output = attn(q, k, v)
+ def __init__(
+ self, embedding_dim: int, num_heads: int, downsample_rate: int = 1, kv_in_dim: int = None,
+ ) -> None:
+ """
+ Initializes the Attention module with specified dimensions and settings.
+ This class implements a multi-head attention mechanism with optional downsampling of the internal
+ dimension for queries, keys, and values.
+ downsample_rate (int): Factor by which internal dimensions are downsampled. Defaults to 1.
+ kv_in_dim (int | None): Dimensionality of key and value inputs. If None, uses embedding_dim.
+ AssertionError: If num_heads does not evenly divide the internal dim (embedding_dim / downsample_rate).
+ """
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
+ self.internal_dim = embedding_dim // downsample_rate
+ self.num_heads = num_heads
+ assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+ self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
+ @staticmethod
+ def _separate_heads(x: Tensor, num_heads: int) -> Tensor:
+ """Separates the input tensor into the specified number of attention heads."""
+ b, n, c = x.shape
+ x = x.reshape(b, n, num_heads, c // num_heads)
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
+ @staticmethod
+ def _recombine_heads(x: Tensor) -> Tensor:
+ """Recombines separated attention heads into a single tensor."""
+ b, n_heads, n_tokens, c_per_head = x.shape
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+ """Applies multi-head attention to query, key, and value tensors with optional downsampling."""
+ # Input projections
+ q = self.q_proj(q)
+ k = self.k_proj(k)
+ v = self.v_proj(v)
+ # Separate into heads
+ q = self._separate_heads(q, self.num_heads)
+ k = self._separate_heads(k, self.num_heads)
+ v = self._separate_heads(v, self.num_heads)
+ # Scaled dot-product attention
+ _, _, _, c_per_head = q.shape
+ attn = torch.softmax((q @ k.permute(0, 1, 3, 2)) / math.sqrt(c_per_head), dim=-1)
+ # Recombine heads and project the output
+ out = self._recombine_heads(attn @ v)
+ return self.out_proj(out)
@@ -0,0 +1,293 @@
+ import torch
+ import torch.nn.functional as F
+ def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
+ """
+ Selects the closest conditioning frames to a given frame index.
+ frame_idx (int): Current frame index.
+ cond_frame_outputs (Dict[int, Any]): Dictionary of conditioning frame outputs keyed by frame indices.
+ max_cond_frame_num (int): Maximum number of conditioning frames to select.
+ (Tuple[Dict[int, Any], Dict[int, Any]]): A tuple containing two dictionaries:
+ - selected_outputs: Selected items from cond_frame_outputs.
+ - unselected_outputs: Items not selected from cond_frame_outputs.
+ >>> frame_idx = 5
+ >>> cond_frame_outputs = {1: "a", 3: "b", 7: "c", 9: "d"}
+ >>> max_cond_frame_num = 2
+ >>> selected, unselected = select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num)
+ >>> print(selected)
+ {3: 'b', 7: 'c'}
+ >>> print(unselected)
+ {1: 'a', 9: 'd'}
+ """
+ if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
+ selected_outputs = cond_frame_outputs
+ unselected_outputs = {}
+ else:
+ assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
+ selected_outputs = {}
+ # the closest conditioning frame before `frame_idx` (if any)
+ idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
+ if idx_before is not None:
+ selected_outputs[idx_before] = cond_frame_outputs[idx_before]
+ # the closest conditioning frame after `frame_idx` (if any)
+ idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
+ if idx_after is not None:
+ selected_outputs[idx_after] = cond_frame_outputs[idx_after]
+ # add other temporally closest conditioning frames until reaching a total
+ # of `max_cond_frame_num` conditioning frames.
+ num_remain = max_cond_frame_num - len(selected_outputs)
+ inds_remain = sorted(
+ (t for t in cond_frame_outputs if t not in selected_outputs),
+ key=lambda x: abs(x - frame_idx),
+ )[:num_remain]
+ selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
+ unselected_outputs = {t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs}
+ return selected_outputs, unselected_outputs
+def get_1d_sine_pe(pos_inds, dim, temperature=10000):
+ """Generates 1D sinusoidal positional embeddings for given positions and dimensions."""
+ pe_dim = dim // 2
+ dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
+ dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
+ pos_embed = pos_inds.unsqueeze(-1) / dim_t
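+ # Illustrative example: get_1d_sine_pe(torch.arange(4), dim=256) concatenates the sine and cosine halves
+ # below into a (4, 256) embedding tensor.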
+ pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
+ return pos_embed
+def init_t_xy(end_x: int, end_y: int):
+ """Initializes 1D and 2D coordinate tensors for a grid of specified dimensions."""
+ t = torch.arange(end_x * end_y, dtype=torch.float32)
+ t_x = (t % end_x).float()
+ t_y = torch.div(t, end_x, rounding_mode="floor").float()
+ return t_x, t_y
+def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
+ """Computes axial complex exponential positional encodings for 2D spatial positions in a grid."""
+ freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+ freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+ t_x, t_y = init_t_xy(end_x, end_y)
+ freqs_x = torch.outer(t_x, freqs_x)
+ freqs_y = torch.outer(t_y, freqs_y)
+ freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
+ freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
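+ # Illustrative example: compute_axial_cis(dim=64, end_x=8, end_y=8) returns a complex tensor of shape
+ # (64, 32): one row per grid position and dim // 4 frequencies for each of the x and y axes.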
+ return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+ """Reshapes frequency tensor for broadcasting with input tensor, ensuring dimensional compatibility."""
+ ndim = x.ndim
+ assert 0 <= 1 < ndim
+ assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
+ shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
+ return freqs_cis.view(*shape)
+def apply_rotary_enc(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+ repeat_freqs_k: bool = False,
+ """Applies rotary positional encoding to query and key tensors using complex-valued frequency components."""
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) if xk.shape[-2] != 0 else None
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+ if xk_ is None:
+ # no keys to rotate, due to dropout
+ return xq_out.type_as(xq).to(xq.device), xk
+ # repeat freqs along seq_len dim to match k seq_len
+ if repeat_freqs_k:
+ r = xk_.shape[-2] // xq_.shape[-2]
+ freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+ return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
+def window_partition(x, window_size):
+ Partitions input tensor into non-overlapping windows with padding if needed.
+ x (torch.Tensor): Input tensor with shape (B, H, W, C).
+ window_size (int): Size of each window.
+ (Tuple[torch.Tensor, Tuple[int, int]]): A tuple containing:
+ - windows (torch.Tensor): Partitioned windows with shape (B * num_windows, window_size, window_size, C).
+ - (Hp, Wp) (Tuple[int, int]): Padded height and width before partition.
+ >>> x = torch.randn(1, 16, 16, 3)
+ >>> windows, (Hp, Wp) = window_partition(x, window_size=4)
+ >>> print(windows.shape, Hp, Wp)
+ torch.Size([16, 4, 4, 3]) 16 16
+ B, H, W, C = x.shape
+ pad_h = (window_size - H % window_size) % window_size
+ pad_w = (window_size - W % window_size) % window_size
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+ Hp, Wp = H + pad_h, W + pad_w
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows, (Hp, Wp)
+def window_unpartition(windows, window_size, pad_hw, hw):
+ Unpartitions windowed sequences into original sequences and removes padding.
+ This function reverses the windowing process, reconstructing the original input from windowed segments
+ and removing any padding that was added during the windowing process.
+ windows (torch.Tensor): Input tensor of windowed sequences with shape (B * num_windows, window_size,
+ window_size, C), where B is the batch size, num_windows is the number of windows, window_size is
+ the size of each window, and C is the number of channels.
+ pad_hw (Tuple[int, int]): Padded height and width (Hp, Wp) of the input before windowing.
+ hw (Tuple[int, int]): Original height and width (H, W) of the input before padding and windowing.
+ (torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W
+ are the original height and width, and C is the number of channels.
+ >>> windows = torch.rand(32, 8, 8, 64) # 32 windows of size 8x8 with 64 channels
+ >>> pad_hw = (16, 16) # Padded height and width
+ >>> hw = (15, 14) # Original height and width
+ >>> x = window_unpartition(windows, window_size=8, pad_hw=pad_hw, hw=hw)
+ >>> print(x.shape)
+ torch.Size([1, 15, 14, 64])
+ Hp, Wp = pad_hw
+ H, W = hw
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+ if Hp > H or Wp > W:
+ x = x[:, :H, :W, :].contiguous()
+ return x
+def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
+ Extracts relative positional embeddings based on query and key sizes.
+ q_size (int): Size of the query.
+ k_size (int): Size of the key.
+ rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative
+ distance and C is the embedding dimension.
+ (torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size,
+ k_size, C).
+ >>> q_size, k_size = 8, 16
+ >>> rel_pos = torch.randn(31, 64) # 31 = 2 * max(8, 16) - 1
+ >>> extracted_pos = get_rel_pos(q_size, k_size, rel_pos)
+ >>> print(extracted_pos.shape)
+ torch.Size([8, 16, 64])
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
+ # Interpolate rel pos if needed.
+ if rel_pos.shape[0] != max_rel_dist:
+ # Interpolate rel pos.
+ rel_pos_resized = F.interpolate(
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
+ size=max_rel_dist,
+ mode="linear",
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+ rel_pos_resized = rel_pos
+ # Scale the coords with short length if shapes for q and k are different.
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+ return rel_pos_resized[relative_coords.long()]
+def add_decomposed_rel_pos(
+ attn: torch.Tensor,
+ q: torch.Tensor,
+ rel_pos_h: torch.Tensor,
+ rel_pos_w: torch.Tensor,
+ q_size: Tuple[int, int],
+ k_size: Tuple[int, int],
+ ) -> torch.Tensor:
+ """
+ Adds decomposed Relative Positional Embeddings to the attention map.
+ This function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2
+ paper. It enhances the attention mechanism by incorporating spatial relationships between query and key
+ positions.
+ attn (torch.Tensor): Attention map with shape (B, q_h * q_w, k_h * k_w).
+ q (torch.Tensor): Query tensor in the attention layer with shape (B, q_h * q_w, C).
+ rel_pos_h (torch.Tensor): Relative position embeddings for height axis with shape (Lh, C).
+ rel_pos_w (torch.Tensor): Relative position embeddings for width axis with shape (Lw, C).
+ q_size (Tuple[int, int]): Spatial sequence size of query q as (q_h, q_w).
+ k_size (Tuple[int, int]): Spatial sequence size of key k as (k_h, k_w).
+ (torch.Tensor): Updated attention map with added relative positional embeddings, shape
+ (B, q_h * q_w, k_h * k_w).
+ >>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8
+ >>> attn = torch.rand(B, q_h * q_w, k_h * k_w)
+ >>> q = torch.rand(B, q_h * q_w, C)
+ >>> rel_pos_h = torch.rand(2 * max(q_h, k_h) - 1, C)
+ >>> rel_pos_w = torch.rand(2 * max(q_w, k_w) - 1, C)
+ >>> q_size, k_size = (q_h, q_w), (k_h, k_w)
+ >>> updated_attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size)
+ >>> print(updated_attn.shape)
+ torch.Size([1, 64, 64])
+ References:
+ https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py
+ """
+ q_h, q_w = q_size
+ k_h, k_w = k_size
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
+ B, _, dim = q.shape
+ r_q = q.reshape(B, q_h, q_w, dim)
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+ attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
+ B, q_h * q_w, k_h * k_w
+ )
+ return attn
@@ -0,0 +1,1605 @@
+ """
+ Generate predictions using the Segment Anything Model (SAM).
+ SAM is an advanced image segmentation model offering features like promptable segmentation and zero-shot performance.
+ This module contains the implementation of the prediction logic and auxiliary utilities required to perform segmentation
+ using SAM. It forms an integral part of the Ultralytics framework and is designed for high-performance, real-time image
+ segmentation tasks.
+ """
+ from collections import OrderedDict
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from ultralytics.data.augment import LetterBox
+ from ultralytics.engine.predictor import BasePredictor
+ from ultralytics.engine.results import Results
+ from ultralytics.utils import DEFAULT_CFG, ops
+from ultralytics.utils.torch_utils import select_device, smart_inference_mode
+from .amg import (
+ batch_iterator,
+ batched_mask_to_box,
+ build_all_layer_point_grids,
+ calculate_stability_score,
+ generate_crop_boxes,
+ is_box_near_crop_edge,
+ remove_small_regions,
+ uncrop_boxes_xyxy,
+ uncrop_masks,
+ )
+class Predictor(BasePredictor):
+ Predictor class for SAM, enabling real-time image segmentation with promptable capabilities.
+ This class extends BasePredictor and implements the Segment Anything Model (SAM) for advanced image
+ segmentation tasks. It supports various input prompts like points, bounding boxes, and masks for
+ fine-grained control over segmentation results.
+ args (SimpleNamespace): Configuration arguments for the predictor.
+ device (torch.device): The device (CPU or GPU) on which the model is loaded.
+ im (torch.Tensor): The preprocessed input image.
+ features (torch.Tensor): Extracted image features.
+ prompts (Dict): Dictionary to store various types of prompts (e.g., bboxes, points, masks).
+ segment_all (bool): Flag to indicate if full image segmentation should be performed.
+ mean (torch.Tensor): Mean values for image normalization.
+ std (torch.Tensor): Standard deviation values for image normalization.
+ preprocess: Prepares input images for model inference.
+ pre_transform: Performs initial transformations on the input image.
+ inference: Performs segmentation inference based on input prompts.
+ prompt_inference: Internal function for prompt-based segmentation inference.
+ generate: Generates segmentation masks for an entire image.
+ setup_model: Initializes the SAM model for inference.
+ get_model: Builds and returns a SAM model.
+ postprocess: Post-processes model outputs to generate final results.
+ setup_source: Sets up the data source for inference.
+ set_image: Sets and preprocesses a single image for inference.
+ get_im_features: Extracts image features using the SAM image encoder.
+ set_prompts: Sets prompts for subsequent inference.
+ reset_image: Resets the current image and its features.
+ remove_small_regions: Removes small disconnected regions and holes from masks.
+ >>> predictor = Predictor()
+ >>> predictor.setup_model(model_path="sam_model.pt")
+ >>> predictor.set_image("image.jpg")
+ >>> bboxes = [[100, 100, 200, 200]]
+ >>> results = predictor(bboxes=bboxes)
+ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+ """
+ Initialize the Predictor with configuration, overrides, and callbacks.
+ Sets up the Predictor object for SAM (Segment Anything Model) and applies any configuration overrides or
+ callbacks provided. Initializes task-specific settings for SAM, such as retina_masks being set to True
+ for optimal results.
+ cfg (Dict): Configuration dictionary containing default settings.
+ overrides (Dict | None): Dictionary of values to override default configuration.
+ _callbacks (Dict | None): Dictionary of callback functions to customize behavior.
+ >>> predictor_example = Predictor(cfg=DEFAULT_CFG)
+ >>> predictor_example_with_imgsz = Predictor(overrides={"imgsz": 640})
+ >>> predictor_example_with_callback = Predictor(_callbacks={"on_predict_start": custom_callback})
+ """
+ if overrides is None:
+ overrides = {}
+ overrides.update(dict(task="segment", mode="predict", batch=1))
+ super().__init__(cfg, overrides, _callbacks)
+ self.args.retina_masks = True
+ self.im = None
+ self.features = None
+ self.prompts = {}
+ self.segment_all = False
+ def preprocess(self, im):
+ Preprocess the input image for model inference.
+ This method prepares the input image by applying transformations and normalization. It supports both
+ torch.Tensor and list of np.ndarray as input formats.
+ im (torch.Tensor | List[np.ndarray]): Input image(s) in BCHW tensor format or list of HWC numpy arrays.
+ im (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype.
+ >>> image = torch.rand(1, 3, 640, 640)
+ >>> preprocessed_image = predictor.preprocess(image)
+ if self.im is not None:
+ return self.im
+ not_tensor = not isinstance(im, torch.Tensor)
+ if not_tensor:
+ im = np.stack(self.pre_transform(im))
+ im = im[..., ::-1].transpose((0, 3, 1, 2))
+ im = np.ascontiguousarray(im)
+ im = torch.from_numpy(im)
+ im = im.to(self.device)
+ im = im.half() if self.model.fp16 else im.float()
+ if not_tensor:
+ im = (im - self.mean) / self.std
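+ # Illustrative example: a single HWC BGR numpy image is letterboxed, converted to a (1, 3, H, W) RGB tensor,
+ # and normalized with the predictor's mean/std buffers; tensors passed in directly skip normalization.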
+ return im
+ def pre_transform(self, im):
+ """
+ Perform initial transformations on the input image for preprocessing.
+ This method applies transformations such as resizing to prepare the image for further preprocessing.
+ Currently, batched inference is not supported; hence the list length should be 1.
+ im (List[np.ndarray]): List containing a single image in HWC numpy array format.
+ (List[np.ndarray]): List containing the transformed image.
+ AssertionError: If the input list contains more than one image.
+ >>> image = np.random.rand(480, 640, 3) # Single HWC image
+ >>> transformed = predictor.pre_transform([image])
+ >>> print(len(transformed))
+ 1
+ """
+ assert len(im) == 1, "SAM model does not currently support batched inference"
+ letterbox = LetterBox(self.args.imgsz, auto=False, center=False)
+ return [letterbox(image=x) for x in im]
+ def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
+ Perform image segmentation inference based on the given input cues, using the currently loaded image.
+ This method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt
+ encoder, and mask decoder for real-time and promptable segmentation tasks.
+ im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
+ bboxes (np.ndarray | List | None): Bounding boxes with shape (N, 4), in XYXY format.
+ points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.
+ labels (np.ndarray | List | None): Labels for point prompts, shape (N,). 1 = foreground, 0 = background.
+ masks (np.ndarray | None): Low-resolution masks from previous predictions, shape (N, H, W). For SAM H=W=256.
+ multimask_output (bool): Flag to return multiple masks. Helpful for ambiguous prompts.
+ *args (Any): Additional positional arguments.
+ (np.ndarray): The output masks in shape (C, H, W), where C is the number of generated masks.
+ (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
+ (np.ndarray): Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256.
+ >>> results = predictor(bboxes=[[0, 0, 100, 100]])
+ # Override prompts if any stored in self.prompts
+ bboxes = self.prompts.pop("bboxes", bboxes)
+ points = self.prompts.pop("points", points)
+ masks = self.prompts.pop("masks", masks)
+ labels = self.prompts.pop("labels", labels)
+ if all(i is None for i in [bboxes, points, masks]):
+ return self.generate(im, *args, **kwargs)
+ return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)
+ def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
+ Performs image segmentation inference based on input cues using SAM's specialized architecture.
+ This internal function leverages the Segment Anything Model (SAM) for prompt-based, real-time segmentation.
+ It processes various input prompts such as bounding boxes, points, and masks to generate segmentation masks.
+ im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
+ bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
+ points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
+ labels (np.ndarray | List | None): Point prompt labels with shape (N) or (N, num_points). 1 for foreground, 0 for background.
+ masks (np.ndarray | None): Low-res masks from previous predictions with shape (N, H, W). For SAM, H=W=256.
+ multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
+ AssertionError: If the number of points does not match the number of labels, when labels are passed.
+ (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
+ (np.ndarray): Quality scores predicted by the model for each mask, with length C.
+ >>> im = torch.rand(1, 3, 1024, 1024)
+ >>> masks, scores, logits = predictor.prompt_inference(im, bboxes=bboxes)
+ features = self.get_im_features(im) if self.features is None else self.features
+ bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
+ points = (points, labels) if points is not None else None
+ # Embed prompts
+ sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks)
+ # Predict masks
+ pred_masks, pred_scores = self.model.mask_decoder(
+ image_embeddings=features,
+ image_pe=self.model.prompt_encoder.get_dense_pe(),
+ # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
+ # `d` can be 1 or 3, depending on `multimask_output`.
+ return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
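+ # Minimal sketch of the flatten above, assuming N=2 prompts with d=3 mask candidates each
+ # (d is 1 or 3 depending on `multimask_output`):
+ import torch
+ cand_masks = torch.rand(2, 3, 256, 256) # (N, d, H, W)
+ cand_scores = torch.rand(2, 3) # (N, d)
+ flat_masks, flat_scores = cand_masks.flatten(0, 1), cand_scores.flatten(0, 1)
+ assert flat_masks.shape == (6, 256, 256) and flat_scores.shape == (6,)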
+ def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
+ Prepares and transforms the input prompts for processing based on the destination shape.
+ dst_shape (tuple): The target shape (height, width) for the prompts.
+ masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array.
+ (tuple): A tuple containing transformed bounding boxes, points, labels, and masks.
+ src_shape = self.batch[1][0].shape[:2]
+ r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
+ # Transform input prompts
+ points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
+ # Assume all points are foreground if the user does not pass labels.
+ labels = np.ones(points.shape[:-1])
+ assert points.shape[-2] == labels.shape[-1], (
+ f"Number of points {points.shape[-2]} should match number of labels {labels.shape[-1]}."
+ points *= r
+ if points.ndim == 2:
+ # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
+ points, labels = points[:, None, :], labels[:, None]
+ bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
+ bboxes *= r
+ masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
+ return bboxes, points, labels, masks
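+ # Minimal sketch of the prompt scaling above, assuming a 720x1280 source image and a
+ # 1024x1024 model input (shapes and coordinates are example assumptions):
+ import torch
+ src_h, src_w, dst_h, dst_w = 720, 1280, 1024, 1024
+ r = min(dst_h / src_h, dst_w / src_w) # letterbox scale ratio
+ points = torch.tensor([[100.0, 200.0], [300.0, 400.0]]) * r # (N, 2) points, scaled
+ labels = torch.ones(points.shape[:-1]) # assume foreground when no labels are given
+ points, labels = points[:, None, :], labels[:, None] # (N, 1, 2), (N, 1)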
+ def generate(
+ im,
+ crop_n_layers=0,
+ crop_overlap_ratio=512 / 1500,
+ crop_downscale_factor=1,
+ point_grids=None,
+ points_stride=32,
+ points_batch_size=64,
+ conf_thres=0.88,
+ stability_score_thresh=0.95,
+ stability_score_offset=0.95,
+ crop_nms_thresh=0.7,
+ Perform image segmentation using the Segment Anything Model (SAM).
+ This method segments an entire image into constituent parts by leveraging SAM's advanced architecture
+ and real-time performance capabilities. It can optionally work on image crops for finer segmentation.
+ im (torch.Tensor): Input tensor representing the preprocessed image with shape (N, C, H, W).
+ crop_n_layers (int): Number of layers for additional mask predictions on image crops.
+ crop_overlap_ratio (float): Overlap between crops, scaled down in subsequent layers.
+ crop_downscale_factor (int): Scaling factor for sampled points-per-side in each layer.
+ point_grids (List[np.ndarray] | None): Custom grids for point sampling normalized to [0,1].
+ points_stride (int): Number of points to sample along each side of the image.
+ points_batch_size (int): Batch size for the number of points processed simultaneously.
+ conf_thres (float): Confidence threshold [0,1] for filtering based on mask quality prediction.
+ stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on stability.
+ stability_score_offset (float): Offset value for calculating stability score.
+ crop_nms_thresh (float): IoU cutoff for NMS to remove duplicate masks between crops.
+ pred_masks (torch.Tensor): Segmented masks with shape (N, H, W).
+ pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,).
+ pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4).
+ >>> im = torch.rand(1, 3, 1024, 1024) # Example input image
+ >>> masks, scores, boxes = predictor.generate(im)
+ import torchvision # scope for faster 'import ultralytics'
+ self.segment_all = True
+ ih, iw = im.shape[2:]
+ crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio)
+ if point_grids is None:
+ point_grids = build_all_layer_point_grids(points_stride, crop_n_layers, crop_downscale_factor)
+ pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], []
+ for crop_region, layer_idx in zip(crop_regions, layer_idxs):
+ x1, y1, x2, y2 = crop_region
+ w, h = x2 - x1, y2 - y1
+ area = torch.tensor(w * h, device=im.device)
+ points_scale = np.array([[w, h]]) # w, h
+ # Crop image and interpolate to input size
+ crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode="bilinear", align_corners=False)
+ # (num_points, 2)
+ points_for_image = point_grids[layer_idx] * points_scale
+ crop_masks, crop_scores, crop_bboxes = [], [], []
+ for (points,) in batch_iterator(points_batch_size, points_for_image):
+ pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True)
+ # Interpolate predicted masks to input size
+ pred_mask = F.interpolate(pred_mask[None], (h, w), mode="bilinear", align_corners=False)[0]
+ idx = pred_score > conf_thres
+ pred_mask, pred_score = pred_mask[idx], pred_score[idx]
+ stability_score = calculate_stability_score(
+ pred_mask, self.model.mask_threshold, stability_score_offset
+ )
+ idx = stability_score > stability_score_thresh
+ pred_mask, pred_score = pred_mask[idx], pred_score[idx]
+ # Bool type is much more memory-efficient.
+ pred_mask = pred_mask > self.model.mask_threshold
+ # (N, 4)
+ pred_bbox = batched_mask_to_box(pred_mask).float()
+ keep_mask = ~is_box_near_crop_edge(pred_bbox, crop_region, [0, 0, iw, ih])
+ if not torch.all(keep_mask):
+ pred_bbox, pred_mask, pred_score = pred_bbox[keep_mask], pred_mask[keep_mask], pred_score[keep_mask]
+ crop_masks.append(pred_mask)
+ crop_bboxes.append(pred_bbox)
+ crop_scores.append(pred_score)
+ # Do nms within this crop
+ crop_masks = torch.cat(crop_masks)
+ crop_bboxes = torch.cat(crop_bboxes)
+ crop_scores = torch.cat(crop_scores)
+ keep = torchvision.ops.nms(crop_bboxes, crop_scores, self.args.iou) # NMS
+ crop_bboxes = uncrop_boxes_xyxy(crop_bboxes[keep], crop_region)
+ crop_masks = uncrop_masks(crop_masks[keep], crop_region, ih, iw)
+ crop_scores = crop_scores[keep]
+ pred_masks.append(crop_masks)
+ pred_bboxes.append(crop_bboxes)
+ pred_scores.append(crop_scores)
+ region_areas.append(area.expand(len(crop_masks)))
+ pred_masks = torch.cat(pred_masks)
+ pred_bboxes = torch.cat(pred_bboxes)
+ pred_scores = torch.cat(pred_scores)
+ region_areas = torch.cat(region_areas)
+ # Remove duplicate masks between crops
+ if len(crop_regions) > 1:
+ scores = 1 / region_areas
+ keep = torchvision.ops.nms(pred_bboxes, scores, crop_nms_thresh)
+ pred_masks, pred_bboxes, pred_scores = pred_masks[keep], pred_bboxes[keep], pred_scores[keep]
+ return pred_masks, pred_scores, pred_bboxes
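+ # Minimal sketch of the crop de-duplication step above, assuming two near-identical boxes
+ # found in a full-image crop and a small crop; scoring by 1 / crop_area makes NMS keep the
+ # detection from the smaller (finer) crop:
+ import torch, torchvision
+ boxes = torch.tensor([[0.0, 0.0, 50.0, 50.0], [1.0, 1.0, 51.0, 51.0]])
+ region_areas = torch.tensor([1024.0 * 1024.0, 256.0 * 256.0]) # full image vs. small crop
+ keep = torchvision.ops.nms(boxes, 1 / region_areas, 0.7)
+ # keep == tensor([1]): the near-duplicate from the smaller crop wins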
+ def setup_model(self, model=None, verbose=True):
+ Initializes the Segment Anything Model (SAM) for inference.
+ This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary
+ parameters for image normalization and other Ultralytics compatibility settings.
+ model (torch.nn.Module | None): A pretrained SAM model. If None, a new model is built based on config.
+ verbose (bool): If True, prints selected device information.
+ >>> predictor.setup_model(model=sam_model, verbose=True)
+ device = select_device(self.args.device, verbose=verbose)
+ if model is None:
+ model = self.get_model()
+ model.eval()
+ self.model = model.to(device)
+ self.device = device
+ self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)
+ self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)
+ # Ultralytics compatibility settings
+ self.model.pt = False
+ self.model.triton = False
+ self.model.stride = 32
+ self.model.fp16 = False
+ self.done_warmup = True
+ def get_model(self):
+ """Retrieves or builds the Segment Anything Model (SAM) for image segmentation tasks."""
+ return build_sam(self.args.model)
+ Post-processes SAM's inference outputs to generate object detection masks and bounding boxes.
+ This method scales masks and boxes to the original image size and applies a threshold to the mask
+ predictions. It leverages SAM's advanced architecture for real-time, promptable segmentation tasks.
+ preds (Tuple[torch.Tensor]): The output from SAM model inference, containing:
+ - pred_masks (torch.Tensor): Predicted masks with shape (N, 1, H, W).
+ - pred_scores (torch.Tensor): Confidence scores for each mask with shape (N, 1).
+ - pred_bboxes (torch.Tensor, optional): Predicted bounding boxes if segment_all is True.
+ img (torch.Tensor): The processed input image tensor with shape (C, H, W).
+ orig_imgs (List[np.ndarray] | torch.Tensor): The original, unprocessed images.
+ results (List[Results]): List of Results objects containing detection masks, bounding boxes, and other
+ metadata for each processed image.
+ >>> preds = predictor.inference(img)
+ >>> results = predictor.postprocess(preds, img, orig_imgs)
+ # (N, 1, H, W), (N, 1)
+ pred_masks, pred_scores = preds[:2]
+ pred_bboxes = preds[2] if self.segment_all else None
+ names = dict(enumerate(str(i) for i in range(len(pred_masks))))
+ for masks, orig_img, img_path in zip([pred_masks], orig_imgs, self.batch[0]):
+ if len(masks) == 0:
+ masks, pred_bboxes = None, torch.zeros((0, 6), device=pred_masks.device)
+ else:
+ masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
+ masks = masks > self.model.mask_threshold # to bool
+ if pred_bboxes is not None:
+ pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)
+ else:
+ pred_bboxes = batched_mask_to_box(masks)
+ # NOTE: SAM models do not return cls info. This `cls` here is just a placeholder for consistency.
+ cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
+ pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)
+ results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
+ # Reset segment-all mode.
+ self.segment_all = False
+ return results
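+ # Minimal sketch of the (N, 6) [x1, y1, x2, y2, conf, cls] box layout built above for
+ # Results, assuming one mask with a placeholder class id (numbers are example assumptions):
+ import torch
+ xyxy = torch.tensor([[10.0, 10.0, 80.0, 120.0]]) # (N, 4)
+ conf = torch.tensor([0.93]) # (N,)
+ cls = torch.arange(len(xyxy), dtype=torch.float32) # placeholder class ids
+ boxes = torch.cat([xyxy, conf[:, None], cls[:, None]], dim=-1) # (N, 6)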
+ def setup_source(self, source):
+ Sets up the data source for inference.
+ This method configures the data source from which images will be fetched for inference. It supports
+ various input types such as image files, directories, video files, and other compatible data sources.
+ source (str | Path | None): The path or identifier for the image data source. Can be a file path,
+ directory path, URL, or other supported source types.
+ >>> predictor.setup_source("path/to/images")
+ >>> predictor.setup_source("video.mp4")
+ >>> predictor.setup_source(None) # Uses default source if available
+ - If source is None, the method may use a default source if configured.
+ - The method adapts to different source types and prepares them for subsequent inference steps.
+ - Supported source types may include local files, directories, URLs, and video streams.
+ if source is not None:
+ super().setup_source(source)
+ def set_image(self, image):
+ Preprocesses and sets a single image for inference.
+ This method prepares the model for inference on a single image by setting up the model if not already
+ initialized, configuring the data source, and preprocessing the image for feature extraction. It
+ ensures that only one image is set at a time and extracts image features for subsequent use.
+ image (str | np.ndarray): Path to the image file as a string, or a numpy array representing
+ an image read by cv2.
+ AssertionError: If an attempt is made to set more than one image.
+ >>> predictor.set_image("path/to/image.jpg")
+ >>> predictor.set_image(cv2.imread("path/to/image.jpg"))
+ - This method should be called before performing inference on a new image.
+ - The extracted features are stored in the `self.features` attribute for later use.
+ if self.model is None:
+ self.setup_model(model=None)
+ self.setup_source(image)
+ assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
+ for batch in self.dataset:
+ im = self.preprocess(batch[1])
+ self.features = self.get_im_features(im)
+ def get_im_features(self, im):
+ """Extracts image features using the SAM model's image encoder for subsequent mask prediction."""
+ assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], (
+ f"SAM models only support square image size, but got {self.imgsz}."
+ self.model.set_imgsz(self.imgsz)
+ return self.model.image_encoder(im)
+ """Sets prompts for subsequent inference operations."""
+ def reset_image(self):
+ """Resets the current image and its features, clearing them for subsequent inference."""
+ self.im = None
+ self.features = None
+ def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
+ Remove small disconnected regions and holes from segmentation masks.
+ This function performs post-processing on segmentation masks generated by the Segment Anything Model (SAM).
+ It removes small disconnected regions and holes from the input masks, and then performs Non-Maximum
+ Suppression (NMS) to eliminate any newly created duplicate boxes.
+ masks (torch.Tensor): Segmentation masks to be processed, with shape (N, H, W) where N is the number of
+ masks, H is height, and W is width.
+ min_area (int): Minimum area threshold for removing disconnected regions and holes. Regions smaller than
+ this will be removed.
+ nms_thresh (float): IoU threshold for the NMS algorithm to remove duplicate boxes.
+ new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W).
+ keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes.
+ >>> masks = torch.rand(5, 640, 640) > 0.5 # 5 random binary masks
+ >>> new_masks, keep = remove_small_regions(masks, min_area=100, nms_thresh=0.7)
+ >>> print(f"Original masks: {masks.shape}, Processed masks: {new_masks.shape}")
+ >>> print(f"Indices of kept masks: {keep}")
+ # Filter small disconnected regions and holes
+ new_masks = []
+ scores = []
+ for mask in masks:
+ mask = mask.cpu().numpy().astype(np.uint8)
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
+ unchanged = not changed
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
+ unchanged = unchanged and not changed
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
+ # Give score=0 to changed masks and 1 to unchanged masks so NMS prefers masks not needing postprocessing
+ scores.append(float(unchanged))
+ # Recalculate boxes and remove any new duplicates
+ new_masks = torch.cat(new_masks, dim=0)
+ boxes = batched_mask_to_box(new_masks)
+ keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh)
+ return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep
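+ # Minimal sketch of the idea behind `batched_mask_to_box`, shown for a single binary mask:
+ import torch
+ mask = torch.zeros(64, 64, dtype=torch.bool)
+ mask[10:20, 30:50] = True
+ ys, xs = torch.where(mask)
+ box = torch.stack([xs.min(), ys.min(), xs.max(), ys.max()]).float()
+ # box == tensor([30., 10., 49., 19.]) in XYXY format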
+class SAM2Predictor(Predictor):
+ SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture.
+ This class extends the base Predictor class to implement SAM2-specific functionality for image
+ segmentation tasks. It provides methods for model initialization, feature extraction, and
+ prompt-based inference.
+ _bb_feat_sizes (List[Tuple[int, int]]): Feature sizes for different backbone levels.
+ model (torch.nn.Module): The loaded SAM2 model.
+ features (Dict[str, torch.Tensor]): Cached image features for efficient inference.
+ segment_all (bool): Flag to indicate if all segments should be predicted.
+ prompts (Dict): Dictionary to store various types of prompts for inference.
+ get_model: Retrieves and initializes the SAM2 model.
+ prompt_inference: Performs image segmentation inference based on various prompts.
+ set_image: Preprocesses and sets a single image for inference.
+ get_im_features: Extracts and processes image features using SAM2's image encoder.
+ >>> predictor = SAM2Predictor(cfg)
+ >>> result = predictor(bboxes=bboxes)[0]
+ >>> print(f"Predicted {len(result.masks)} masks with average score {result.boxes.conf.mean():.2f}")
+ _bb_feat_sizes = [
+ (256, 256),
+ (128, 128),
+ (64, 64),
+ """Retrieves and initializes the Segment Anything Model 2 (SAM2) for image segmentation tasks."""
+ def prompt_inference(
+ bboxes=None,
+ points=None,
+ labels=None,
+ masks=None,
+ img_idx=-1,
+ Performs image segmentation inference based on various prompts using SAM2 architecture.
+ This method leverages the Segment Anything Model 2 (SAM2) to generate segmentation masks for input images
+ based on provided prompts such as bounding boxes, points, or existing masks. It supports both single and
+ multi-object prediction scenarios.
+ bboxes (np.ndarray | List[List[float]] | None): Bounding boxes in XYXY format with shape (N, 4).
+ points (np.ndarray | List[List[float]] | None): Object location points with shape (N, 2), in pixels.
+ labels (np.ndarray | List[int] | None): Point prompt labels with shape (N,). 1 = foreground, 0 = background.
+ masks (np.ndarray | None): Low-resolution masks from previous predictions with shape (N, H, W).
+ img_idx (int): Index of the image in the batch to process.
+ (np.ndarray): Quality scores for each mask, with length C.
+ >>> result = predictor(image, bboxes=bboxes)[0]
+ >>> print(f"Generated {result.masks.shape[0]} masks with average score {result.boxes.conf.mean():.2f}")
+ - The method supports batched inference for multiple objects when points or bboxes are provided.
+ - Input prompts (bboxes, points) are automatically scaled to match the input image dimensions.
+ - When both bboxes and points are provided, they are merged into a single 'points' input for the model.
+ - SAM 2 paper: "SAM 2: Segment Anything in Images and Videos", https://arxiv.org/abs/2408.00714
+ points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
+ sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
+ points=points,
+ masks=masks,
+ batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction
+ high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]]
+ pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder(
+ image_embeddings=features["image_embed"][img_idx].unsqueeze(0),
+ image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
+ repeat_image=batched_mode,
+ labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background.
+ (tuple): A tuple containing transformed points, labels, and masks.
+ bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks)
+ bboxes = bboxes.view(-1, 2, 2)
+ bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
+ # NOTE: merge "boxes" and "points" into a single "points" input
+ # (where boxes are added at the beginning) to model.sam_prompt_encoder
+ if points is not None:
+ points = torch.cat([bboxes, points], dim=1)
+ labels = torch.cat([bbox_labels, labels], dim=1)
+ else:
+ points, labels = bboxes, bbox_labels
+ return points, labels, masks
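+ # Minimal sketch of the box-to-corner-points encoding above, assuming one box and one
+ # foreground click: SAM 2 represents a box as two points with the reserved labels
+ # 2 (top-left) and 3 (bottom-right) and merges them with ordinary point prompts.
+ import torch
+ bboxes = torch.tensor([[10.0, 20.0, 110.0, 220.0]]).view(-1, 2, 2) # (N, 2, 2)
+ bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32).expand(len(bboxes), -1)
+ clicks = torch.tensor([[[60.0, 120.0]]]) # (N, 1, 2) foreground click
+ click_labels = torch.ones((1, 1), dtype=torch.int32)
+ merged_points = torch.cat([bboxes, clicks], dim=1) # (N, 3, 2)
+ merged_labels = torch.cat([bbox_labels, click_labels], dim=1) # (N, 3)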
+ Preprocesses and sets a single image for inference using the SAM2 model.
+ This method initializes the model if not already done, configures the data source to the specified image,
+ and preprocesses the image for feature extraction. It supports setting only one image at a time.
+ image (str | np.ndarray): Path to the image file as a string, or a numpy array representing the image.
+ >>> predictor = SAM2Predictor()
+ >>> predictor.set_image(np.array([...])) # Using a numpy array
+ - This method must be called before performing any inference on a new image.
+ - The method caches the extracted features for efficient subsequent inferences on the same image.
+ - Only one image can be set at a time. To process multiple images, call this method for each new image.
+ """Extracts image features from the SAM image encoder for subsequent processing."""
+ f"SAM 2 models only support square image size, but got {self.imgsz}."
+ self._bb_feat_sizes = [[x // (4 * i) for x in self.imgsz] for i in [1, 2, 4]]
+ backbone_out = self.model.forward_image(im)
+ _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
+ if self.model.directly_add_no_mem_embed:
+ vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
+ feats = [
+ feat.permute(1, 2, 0).view(1, -1, *feat_size)
+ for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
+ ][::-1]
+ return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
+class SAM2VideoPredictor(SAM2Predictor):
+ SAM2VideoPredictor to handle user interactions with videos and manage inference states.
+ This class extends the functionality of SAM2Predictor to support video processing and maintains
+ the state of inference operations. It includes configurations for managing non-overlapping masks,
+ clearing memory for non-conditional inputs, and setting up callbacks for prediction events.
+ inference_state (Dict): A dictionary to store the current state of inference operations.
+ non_overlap_masks (bool): A flag indicating whether masks should be non-overlapping.
+ clear_non_cond_mem_around_input (bool): A flag to control clearing non-conditional memory around inputs.
+ clear_non_cond_mem_for_multi_obj (bool): A flag to control clearing non-conditional memory for multi-object scenarios.
+ callbacks (Dict): A dictionary of callbacks for various prediction lifecycle events.
+ cfg (Dict, Optional): Configuration settings for the predictor. Defaults to DEFAULT_CFG.
+ overrides (Dict, Optional): Additional configuration overrides. Defaults to None.
+ _callbacks (List, Optional): Custom callbacks to be added. Defaults to None.
+ The `fill_hole_area` attribute is defined but not used in the current implementation.
+ # fill_hole_area = 8 # not used
+ Initialize the predictor with configuration and optional overrides.
+ This constructor initializes the SAM2VideoPredictor with a given configuration, applies any
+ specified overrides, and sets up the inference state along with certain flags
+ that control the behavior of the predictor.
+ >>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG)
+ >>> predictor_example_with_imgsz = SAM2VideoPredictor(overrides={"imgsz": 640})
+ >>> predictor_example_with_callback = SAM2VideoPredictor(_callbacks={"on_predict_start": custom_callback})
+ self.inference_state = {}
+ self.non_overlap_masks = True
+ self.clear_non_cond_mem_around_input = False
+ self.clear_non_cond_mem_for_multi_obj = False
+ self.callbacks["on_predict_start"].append(self.init_state)
+ Retrieves and configures the model with binarization enabled.
+ This method overrides the base class implementation to set the binarize flag to True.
+ model = super().get_model()
+ model.set_binarize(True)
+ def inference(self, im, bboxes=None, points=None, labels=None, masks=None):
+ Perform image segmentation inference based on the given input cues, using the currently loaded image. This
+ method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and
+ mask decoder for real-time and promptable segmentation tasks.
+ masks (np.ndarray, optional): Low-resolution masks from previous predictions, with shape (N, H, W). For SAM, H=W=256.
+ (np.ndarray): The output masks in shape CxHxW, where C is the number of generated masks.
+ frame = self.dataset.frame
+ self.inference_state["im"] = im
+ output_dict = self.inference_state["output_dict"]
+ if len(output_dict["cond_frame_outputs"]) == 0: # initialize prompts
+ for i in range(len(points)):
+ self.add_new_prompts(obj_id=i, points=points[[i]], labels=labels[[i]], frame_idx=frame)
+ for i in range(len(masks)):
+ self.add_new_prompts(obj_id=i, masks=masks[[i]], frame_idx=frame)
+ self.propagate_in_video_preflight()
+ consolidated_frame_inds = self.inference_state["consolidated_frame_inds"]
+ batch_size = len(self.inference_state["obj_idx_to_id"])
+ if len(output_dict["cond_frame_outputs"]) == 0:
+ raise RuntimeError("No points are provided; please add points first")
+ if frame in consolidated_frame_inds["cond_frame_outputs"]:
+ storage_key = "cond_frame_outputs"
+ current_out = output_dict[storage_key][frame]
+ if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1):
+ # clear non-conditioning memory of the surrounding frames
+ self._clear_non_cond_mem_around_input(frame)
+ elif frame in consolidated_frame_inds["non_cond_frame_outputs"]:
+ storage_key = "non_cond_frame_outputs"
+ current_out = self._run_single_frame_inference(
+ frame_idx=frame,
+ batch_size=batch_size,
+ is_init_cond_frame=False,
+ reverse=False,
+ output_dict[storage_key][frame] = current_out
+ # Create slices of per-object outputs for subsequent interaction with each
+ # individual object after tracking.
+ self._add_output_per_object(frame, current_out, storage_key)
+ self.inference_state["frames_already_tracked"].append(frame)
+ pred_masks = current_out["pred_masks"].flatten(0, 1)
+ pred_masks = pred_masks[(pred_masks > self.model.mask_threshold).sum((1, 2)) > 0] # filter blank masks
+ return pred_masks, torch.ones(len(pred_masks), dtype=pred_masks.dtype, device=pred_masks.device)
+ Post-processes the predictions to apply non-overlapping constraints if required.
+ This method extends the post-processing functionality by applying non-overlapping constraints
+ to the predicted masks if the `non_overlap_masks` flag is set to True. This ensures that
+ the masks do not overlap, which can be useful for certain applications.
+ preds (Tuple[torch.Tensor]): The predictions from the model.
+ img (torch.Tensor): The processed image tensor.
+ orig_imgs (List[np.ndarray]): The original images before processing.
+ results (list): The post-processed predictions.
+ If `non_overlap_masks` is True, the method applies constraints to ensure non-overlapping masks.
+ if self.non_overlap_masks:
+ for result in results:
+ if result.masks is None or len(result.masks) == 0:
+ continue
+ result.masks.data = self.model._apply_non_overlapping_constraints(result.masks.data.unsqueeze(0))[0]
+ @smart_inference_mode()
+ def add_new_prompts(
+ obj_id,
+ frame_idx=0,
+ Adds new points or masks to a specific frame for a given object ID.
+ This method updates the inference state with new prompts (points or masks) for a specified
+ object and frame index. It ensures that the prompts are either points or masks, but not both,
+ and updates the internal state accordingly. It also handles the generation of new segmentations
+ based on the provided prompts and the existing state.
+ obj_id (int): The ID of the object to which the prompts are associated.
+ points (torch.Tensor, Optional): The coordinates of the points of interest. Defaults to None.
+ labels (torch.Tensor, Optional): The labels corresponding to the points. Defaults to None.
+ masks (torch.Tensor, optional): Binary masks for the object. Defaults to None.
+ frame_idx (int, optional): The index of the frame to which the prompts are applied. Defaults to 0.
+ (tuple): A tuple containing the flattened predicted masks and a tensor of ones indicating the number of objects.
+ AssertionError: If both `masks` and `points` are provided, or neither is provided.
+ - Only one type of prompt (either points or masks) can be added per call.
+ - If the frame is being tracked for the first time, it is treated as an initial conditioning frame.
+ - The method handles the consolidation of outputs and resizing of masks to the original video resolution.
+ assert (masks is None) ^ (points is None), "'masks' and 'points' prompts are not compatible with each other."
+ obj_idx = self._obj_id_to_idx(obj_id)
+ point_inputs = None
+ pop_key = "point_inputs_per_obj"
+ point_inputs = {"point_coords": points, "point_labels": labels}
+ self.inference_state["point_inputs_per_obj"][obj_idx][frame_idx] = point_inputs
+ pop_key = "mask_inputs_per_obj"
+ self.inference_state["mask_inputs_per_obj"][obj_idx][frame_idx] = masks
+ self.inference_state[pop_key][obj_idx].pop(frame_idx, None)
+ # If this frame hasn't been tracked before, we treat it as an initial conditioning
+ # frame, meaning that the input points are used to generate segments on this frame without
+ # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
+ # the input points will be used to correct the already tracked masks.
+ is_init_cond_frame = frame_idx not in self.inference_state["frames_already_tracked"]
+ obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx]
+ obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx]
+ # Add a frame to conditioning output if it's an initial conditioning frame or
+ # if the model sees all frames receiving clicks/mask as conditioning frames.
+ is_cond = is_init_cond_frame or self.model.add_all_frames_to_correct_as_cond
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+ # Get any previously predicted mask logits on this object and feed it along with
+ # the new clicks into the SAM mask decoder.
+ prev_sam_mask_logits = None
+ # lookup temporary output dict first, which contains the most recent output
+ # (if not found, then lookup conditioning and non-conditioning frame output)
+ prev_out = (
+ obj_temp_output_dict[storage_key].get(frame_idx)
+ or obj_output_dict["cond_frame_outputs"].get(frame_idx)
+ or obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
+ if prev_out is not None and prev_out.get("pred_masks") is not None:
+ prev_sam_mask_logits = prev_out["pred_masks"].to(device=self.device, non_blocking=True)
+ # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
+ prev_sam_mask_logits.clamp_(-32.0, 32.0)
+ output_dict=obj_output_dict, # run on the slice of a single object
+ batch_size=1, # run on the slice of a single object
+ mask_inputs=masks,
+ # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
+ # at the beginning of `propagate_in_video` (after user finalize their clicks). This
+ # allows us to enforce non-overlapping constraints on all objects before encoding
+ # them into memory.
+ run_mem_encoder=False,
+ prev_sam_mask_logits=prev_sam_mask_logits,
+ # Add the output to the output dict (to be used as future memory)
+ obj_temp_output_dict[storage_key][frame_idx] = current_out
+ # Resize the output mask to the original video resolution
+ consolidated_out = self._consolidate_temp_output_across_obj(
+ is_cond=is_cond,
+ pred_masks = consolidated_out["pred_masks"].flatten(0, 1)
+ return pred_masks.flatten(0, 1), torch.ones(1, dtype=pred_masks.dtype, device=pred_masks.device)
+ def propagate_in_video_preflight(self):
+ Prepare inference_state and consolidate temporary outputs before tracking.
+ This method marks the start of tracking, disallowing the addition of new objects until the session is reset.
+ It consolidates temporary outputs from `temp_output_dict_per_obj` and merges them into `output_dict`.
+ Additionally, it clears non-conditioning memory around input frames and ensures that the state is consistent
+ with the provided inputs.
+ # Tracking has started and we don't allow adding new objects until session is reset.
+ self.inference_state["tracking_has_started"] = True
+ # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
+ # add them into "output_dict".
+ temp_output_dict_per_obj = self.inference_state["temp_output_dict_per_obj"]
+ # "consolidated_frame_inds" contains indices of those frames where consolidated
+ # temporary outputs have been added (either in this call or any previous calls
+ # to `propagate_in_video_preflight`).
+ for is_cond in {False, True}:
+ # Separately consolidate conditioning and non-conditioning temp outputs
+ # Find all the frames that contain temporary outputs for any objects
+ # (these should be the frames that have just received clicks or mask inputs
+ # via `add_new_points` or `add_new_mask`)
+ temp_frame_inds = set()
+ for obj_temp_output_dict in temp_output_dict_per_obj.values():
+ temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
+ consolidated_frame_inds[storage_key].update(temp_frame_inds)
+ # consolidate the temporary output across all objects on this frame
+ for frame_idx in temp_frame_inds:
+ frame_idx, is_cond=is_cond, run_mem_encoder=True
+ # merge them into "output_dict" and also create per-object slices
+ output_dict[storage_key][frame_idx] = consolidated_out
+ self._add_output_per_object(frame_idx, consolidated_out, storage_key)
+ self._clear_non_cond_mem_around_input(frame_idx)
+ # clear temporary outputs in `temp_output_dict_per_obj`
+ obj_temp_output_dict[storage_key].clear()
+ # edge case: if an output is added to "cond_frame_outputs", we remove any prior
+ # output on the same frame in "non_cond_frame_outputs"
+ for frame_idx in output_dict["cond_frame_outputs"]:
+ output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
+ for obj_output_dict in self.inference_state["output_dict_per_obj"].values():
+ for frame_idx in obj_output_dict["cond_frame_outputs"]:
+ obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
+ for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
+ assert frame_idx in output_dict["cond_frame_outputs"]
+ consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
+ # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
+ # with either points or mask inputs (which should be true under a correct workflow).
+ all_consolidated_frame_inds = (
+ consolidated_frame_inds["cond_frame_outputs"] | consolidated_frame_inds["non_cond_frame_outputs"]
+ input_frames_inds = set()
+ for point_inputs_per_frame in self.inference_state["point_inputs_per_obj"].values():
+ input_frames_inds.update(point_inputs_per_frame.keys())
+ for mask_inputs_per_frame in self.inference_state["mask_inputs_per_obj"].values():
+ input_frames_inds.update(mask_inputs_per_frame.keys())
+ assert all_consolidated_frame_inds == input_frames_inds
+ def init_state(predictor):
+ Initialize an inference state for the predictor.
+ This function sets up the initial state required for performing inference on video data.
+ It includes initializing various dictionaries and ordered dictionaries that will store
+ inputs, outputs, and other metadata relevant to the tracking process.
+ predictor (SAM2VideoPredictor): The predictor object for which to initialize the state.
+ if len(predictor.inference_state) > 0: # means initialized
+ return
+ assert predictor.dataset is not None
+ assert predictor.dataset.mode == "video"
+ inference_state = {
+ "num_frames": predictor.dataset.frames,
+ "point_inputs_per_obj": {}, # inputs points on each frame
+ "mask_inputs_per_obj": {}, # inputs mask on each frame
+ "constants": {}, # values that don't change across frames (so we only need to hold one copy of them)
+ # mapping between client-side object id and model-side object index
+ "obj_id_to_idx": OrderedDict(),
+ "obj_idx_to_id": OrderedDict(),
+ "obj_ids": [],
+ # A storage to hold the model's tracking results and states on each frame
+ "output_dict": {
+ "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
+ },
+ # Slice (view) of each object tracking results, sharing the same memory with "output_dict"
+ "output_dict_per_obj": {},
+ # A temporary storage to hold new outputs when the user interacts with a frame
+ # to add clicks or mask (it's merged into "output_dict" before propagation starts)
+ "temp_output_dict_per_obj": {},
+ # Frames that already hold consolidated outputs from click or mask inputs
+ # (we directly use their consolidated outputs during tracking)
+ "consolidated_frame_inds": {
+ "cond_frame_outputs": set(), # set containing frame indices
+ "non_cond_frame_outputs": set(), # set containing frame indices
+ # metadata for each tracking frame (e.g. which direction it's tracked)
+ "tracking_has_started": False,
+ "frames_already_tracked": [],
+ predictor.inference_state = inference_state
+ def get_im_features(self, im, batch=1):
+ Extracts and processes image features using SAM2's image encoder for subsequent segmentation tasks.
+ im (torch.Tensor): The input image tensor.
+ batch (int, optional): The batch size for expanding features if there are multiple prompts. Defaults to 1.
+ vis_feats (torch.Tensor): The visual features extracted from the image.
+ vis_pos_embed (torch.Tensor): The positional embeddings for the visual features.
+ feat_sizes (List(Tuple[int])): A list containing the sizes of the extracted features.
+ - If `batch` is greater than 1, the features are expanded to fit the batch size.
+ - The method leverages the model's `_prepare_backbone_features` method to prepare the backbone features.
+ if batch > 1: # expand features if there's more than one prompt
+ for i, feat in enumerate(backbone_out["backbone_fpn"]):
+ backbone_out["backbone_fpn"][i] = feat.expand(batch, -1, -1, -1)
+ for i, pos in enumerate(backbone_out["vision_pos_enc"]):
+ pos = pos.expand(batch, -1, -1, -1)
+ backbone_out["vision_pos_enc"][i] = pos
+ _, vis_feats, vis_pos_embed, feat_sizes = self.model._prepare_backbone_features(backbone_out)
+ return vis_feats, vis_pos_embed, feat_sizes
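+ # Minimal sketch of the batch expansion above, assuming 4 prompts on one frame: `expand`
+ # broadcasts the per-frame features across the prompt batch as a view, without copying
+ # memory (unlike `repeat`).
+ import torch
+ feat = torch.rand(1, 256, 64, 64) # features computed once for the frame
+ batched = feat.expand(4, -1, -1, -1) # batch of 4 prompts, same underlying storage
+ assert batched.shape == (4, 256, 64, 64) and batched.data_ptr() == feat.data_ptr()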
+ def _obj_id_to_idx(self, obj_id):
+ Map client-side object id to model-side object index.
+ obj_id (int): The unique identifier of the object provided by the client side.
+ obj_idx (int): The index of the object on the model side.
+ RuntimeError: If an attempt is made to add a new object after tracking has started.
+ - The method updates or retrieves mappings between object IDs and indices stored in
+ `inference_state`.
+ - It ensures that new objects can only be added before tracking commences.
+ - It maintains two-way mappings between IDs and indices (`obj_id_to_idx` and `obj_idx_to_id`).
+ - Additional data structures are initialized for the new object to store inputs and outputs.
+ obj_idx = self.inference_state["obj_id_to_idx"].get(obj_id, None)
+ if obj_idx is not None:
+ return obj_idx
+ # This is a new object id that has not been tracked before. We only allow adding
+ # new objects *before* tracking starts.
+ allow_new_object = not self.inference_state["tracking_has_started"]
+ if allow_new_object:
+ # get the next object slot
+ obj_idx = len(self.inference_state["obj_id_to_idx"])
+ self.inference_state["obj_id_to_idx"][obj_id] = obj_idx
+ self.inference_state["obj_idx_to_id"][obj_idx] = obj_id
+ self.inference_state["obj_ids"] = list(self.inference_state["obj_id_to_idx"])
+ # set up input and output structures for this object
+ self.inference_state["point_inputs_per_obj"][obj_idx] = {}
+ self.inference_state["mask_inputs_per_obj"][obj_idx] = {}
+ self.inference_state["output_dict_per_obj"][obj_idx] = {
+ self.inference_state["temp_output_dict_per_obj"][obj_idx] = {
+ raise RuntimeError(
+ f"Cannot add new object id {obj_id} after tracking starts. "
+ f"All existing object ids: {self.inference_state['obj_ids']}. "
+ f"Please call 'reset_state' to restart from scratch."
+ def _run_single_frame_inference(
+ batch_size,
+ reverse,
+ Run tracking on a single frame based on current inputs and previous memory.
+ output_dict (Dict): The dictionary containing the output states of the tracking process.
+ frame_idx (int): The index of the current frame.
+ batch_size (int): The batch size for processing the frame.
+ is_init_cond_frame (bool): Indicates if the current frame is an initial conditioning frame.
+ point_inputs (Dict, Optional): Input points and their labels. Defaults to None.
+ mask_inputs (torch.Tensor, Optional): Input binary masks. Defaults to None.
+ reverse (bool): Indicates if the tracking should be performed in reverse order.
+ run_mem_encoder (bool): Indicates if the memory encoder should be executed.
+ prev_sam_mask_logits (torch.Tensor, Optional): Previous mask logits for the current object. Defaults to None.
+ current_out (dict): A dictionary containing the output of the tracking step, including updated features and predictions.
+ AssertionError: If both `point_inputs` and `mask_inputs` are provided, or neither is provided.
+ - The method assumes that `point_inputs` and `mask_inputs` are mutually exclusive.
+ - The method retrieves image features using the `get_im_features` method.
+ - The `maskmem_pos_enc` is assumed to be constant across frames, hence only one copy is stored.
+ - The `fill_holes_in_mask_scores` function is commented out and currently unsupported due to CUDA extension requirements.
+ # Retrieve correct image features
+ current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(
+ self.inference_state["im"], batch_size
+ # point and mask should not appear as input simultaneously on the same frame
+ assert point_inputs is None or mask_inputs is None
+ current_out = self.model.track_step(
+ current_vision_pos_embeds=current_vision_pos_embeds,
+ num_frames=self.inference_state["num_frames"],
+ track_in_reverse=reverse,
+ run_mem_encoder=run_mem_encoder,
+ maskmem_features = current_out["maskmem_features"]
+ if maskmem_features is not None:
+ current_out["maskmem_features"] = maskmem_features.to(
+ dtype=torch.float16, device=self.device, non_blocking=True
+ # NOTE: Do not support the `fill_holes_in_mask_scores` function since it needs cuda extensions
+ # potentially fill holes in the predicted masks
+ # if self.fill_hole_area > 0:
+ # pred_masks = current_out["pred_masks"].to(self.device, non_blocking=True)
+ # pred_masks = fill_holes_in_mask_scores(pred_masks, self.fill_hole_area)
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
+ current_out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(current_out["maskmem_pos_enc"])
+ def _get_maskmem_pos_enc(self, out_maskmem_pos_enc):
+ Caches and manages the positional encoding for mask memory across frames and objects.
+ This method optimizes storage by caching the positional encoding (`maskmem_pos_enc`) for
+ mask memory, which is constant across frames and objects, thus reducing the amount of
+ redundant information stored during an inference session. It checks if the positional
+ encoding has already been cached; if not, it caches a slice of the provided encoding.
+ If the batch size is greater than one, it expands the cached positional encoding to match
+ the current batch size.
+ out_maskmem_pos_enc (List[torch.Tensor] or None): The positional encoding for mask memory.
+ Should be a list of tensors or None.
+ out_maskmem_pos_enc (List[torch.Tensor]): The positional encoding for mask memory, either cached or expanded.
+ - The method assumes that `out_maskmem_pos_enc` is a list of tensors or None.
+ - Only a single object's slice is cached since the encoding is the same across objects.
+ - The method checks if the positional encoding has already been cached in the session's constants.
+ - If the batch size is greater than one, the cached encoding is expanded to fit the batch size.
+ model_constants = self.inference_state["constants"]
+ # "out_maskmem_pos_enc" should be either a list of tensors or None
+ if out_maskmem_pos_enc is not None:
+ if "maskmem_pos_enc" not in model_constants:
+ assert isinstance(out_maskmem_pos_enc, list)
+ # only take the slice for one object, since it's same across objects
+ maskmem_pos_enc = [x[:1].clone() for x in out_maskmem_pos_enc]
+ model_constants["maskmem_pos_enc"] = maskmem_pos_enc
+ maskmem_pos_enc = model_constants["maskmem_pos_enc"]
+ # expand the cached maskmem_pos_enc to the actual batch size
+ batch_size = out_maskmem_pos_enc[0].size(0)
+ if batch_size > 1:
+ out_maskmem_pos_enc = [x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc]
+ return out_maskmem_pos_enc
+ def _consolidate_temp_output_across_obj(
+ is_cond=False,
+ Consolidates per-object temporary outputs into a single output for all objects.
+ This method combines the temporary outputs for each object on a given frame into a unified
+ output. It fills in any missing objects either from the main output dictionary or leaves
+ placeholders if they do not exist in the main output. Optionally, it can re-run the memory
+ encoder after applying non-overlapping constraints to the object scores.
+ frame_idx (int): The index of the frame for which to consolidate outputs.
+ is_cond (bool, Optional): Indicates if the frame is considered a conditioning frame.
+ Defaults to False.
+ run_mem_encoder (bool, Optional): Specifies whether to run the memory encoder after
+ consolidating the outputs. Defaults to False.
+ consolidated_out (dict): A consolidated output dictionary containing the combined results for all objects.
+ - The method initializes the consolidated output with placeholder values for missing objects.
+ - It searches for outputs in both the temporary and main output dictionaries.
+ - If `run_mem_encoder` is True, it applies non-overlapping constraints and re-runs the memory encoder.
+ - The `maskmem_features` and `maskmem_pos_enc` are only populated when `run_mem_encoder` is True.
+ # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
+ # will be added when rerunning the memory encoder after applying non-overlapping
+ # constraints to object scores. Its "pred_masks" are prefilled with a large
+ # negative value (NO_OBJ_SCORE) to represent missing objects.
+ consolidated_out = {
+ "maskmem_features": None,
+ "maskmem_pos_enc": None,
+ "pred_masks": torch.full(
+ size=(batch_size, 1, self.imgsz[0] // 4, self.imgsz[1] // 4),
+ fill_value=-1024.0,
+ dtype=torch.float32,
+ device=self.device,
+ "obj_ptr": torch.full(
+ size=(batch_size, self.model.hidden_dim),
+ "object_score_logits": torch.full(
+ size=(batch_size, 1),
+ # default to 10.0 for object_score_logits, i.e. assuming the object is
+ # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
+ fill_value=10.0,
+ for obj_idx in range(batch_size):
+ out = (
+ # If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
+ # we fall back and look up its previous output in "output_dict_per_obj".
+ # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
+ # "output_dict_per_obj" to find a previous output for this object.
+ # If the object doesn't appear in "output_dict_per_obj" either, we skip it
+ # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
+ # placeholder above) and set its object pointer to be a dummy pointer.
+ # Fill in dummy object pointers for those objects without any inputs or
+ # tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
+ # i.e. when we need to build the memory for tracking).
+ if run_mem_encoder:
+ # fill object pointer with a dummy pointer (based on an empty mask)
+ consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = self._get_empty_mask_ptr(frame_idx)
+ # Add the temporary object output mask to consolidated output mask
+ consolidated_out["pred_masks"][obj_idx : obj_idx + 1] = out["pred_masks"]
+ consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
+ # Optionally, apply non-overlapping constraints on the consolidated scores and rerun the memory encoder
+ high_res_masks = F.interpolate(
+ consolidated_out["pred_masks"],
+ size=self.imgsz,
+ if self.model.non_overlap_masks_for_mem_enc:
+ high_res_masks = self.model._apply_non_overlapping_constraints(high_res_masks)
+ consolidated_out["maskmem_features"], consolidated_out["maskmem_pos_enc"] = self._run_memory_encoder(
+ high_res_masks=high_res_masks,
+ is_mask_from_pts=True, # these frames are what the user interacted with
+ object_score_logits=consolidated_out["object_score_logits"],
+ return consolidated_out
+ def _get_empty_mask_ptr(self, frame_idx):
+ Get a dummy object pointer based on an empty mask on the current frame.
+ frame_idx (int): The index of the current frame for which to generate the dummy object pointer.
+ (torch.Tensor): A tensor representing the dummy object pointer generated from the empty mask.
+ current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(self.inference_state["im"])
+ # Feed the empty mask and image feature above to get a dummy object pointer
+ is_init_cond_frame=True,
+ # A dummy (empty) mask with a single object
+ mask_inputs=torch.zeros((1, 1, *self.imgsz), dtype=torch.float32, device=self.device),
+ output_dict={},
+ track_in_reverse=False,
+ return current_out["obj_ptr"]
+ def _run_memory_encoder(self, batch_size, high_res_masks, object_score_logits, is_mask_from_pts):
+ Run the memory encoder on masks.
+ This is usually done after applying non-overlapping constraints to the object scores. Since the scores have
+ changed, the memory also needs to be recomputed with the memory encoder.
+ high_res_masks (torch.Tensor): High-resolution masks for which to compute the memory.
+ object_score_logits (torch.Tensor): Logits representing the object scores.
+ is_mask_from_pts (bool): Indicates if the mask is derived from point interactions.
+ (tuple[torch.Tensor, torch.Tensor]): A tuple containing the encoded mask features and positional encoding.
+ current_vision_feats, _, feat_sizes = self.get_im_features(self.inference_state["im"], batch_size)
+ maskmem_features, maskmem_pos_enc = self.model._encode_new_memory(
+ pred_masks_high_res=high_res_masks,
+ is_mask_from_pts=is_mask_from_pts,
+ maskmem_pos_enc = self._get_maskmem_pos_enc(maskmem_pos_enc)
+ return maskmem_features.to(dtype=torch.float16, device=self.device, non_blocking=True), maskmem_pos_enc
+ def _add_output_per_object(self, frame_idx, current_out, storage_key):
+ Split a multi-object output into per-object output slices and add them into `output_dict_per_obj`.
+ The resulting slices share the same tensor storage.
+ current_out (Dict): The current output dictionary containing multi-object outputs.
+ storage_key (str): The key used to store the output in the per-object output dictionary.
+ assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
+ maskmem_pos_enc = current_out["maskmem_pos_enc"]
+ assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
+ for obj_idx, obj_output_dict in self.inference_state["output_dict_per_obj"].items():
+ obj_slice = slice(obj_idx, obj_idx + 1)
+ obj_out = {
+ "pred_masks": current_out["pred_masks"][obj_slice],
+ "obj_ptr": current_out["obj_ptr"][obj_slice],
+ obj_out["maskmem_features"] = maskmem_features[obj_slice]
+ if maskmem_pos_enc is not None:
+ obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
+ obj_output_dict[storage_key][frame_idx] = obj_out
+ def _clear_non_cond_mem_around_input(self, frame_idx):
+ Remove the non-conditioning memory around the input frame.
+ When users provide correction clicks, the surrounding frames' non-conditioning memories can still contain outdated
+ object appearance information and could confuse the model. This method clears those non-conditioning memories
+ surrounding the interacted frame to avoid giving the model both old and new information about the object.
+ frame_idx (int): The index of the current frame where user interaction occurred.
+ r = self.model.memory_temporal_stride_for_eval
+ frame_idx_begin = frame_idx - r * self.model.num_maskmem
+ frame_idx_end = frame_idx + r * self.model.num_maskmem
+ for t in range(frame_idx_begin, frame_idx_end + 1):
+ self.inference_state["output_dict"]["non_cond_frame_outputs"].pop(t, None)
+ obj_output_dict["non_cond_frame_outputs"].pop(t, None)
@@ -0,0 +1,357 @@
+from ultralytics.utils.loss import FocalLoss, VarifocalLoss
+from ultralytics.utils.metrics import bbox_iou
+from .ops import HungarianMatcher
+class DETRLoss(nn.Module):
+ DETR (DEtection TRansformer) Loss class. This class calculates and returns the different loss components for the
+ DETR object detection model. It computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary
+ losses.
+ nc (int): The number of classes.
+ loss_gain (dict): Coefficients for different loss components.
+ aux_loss (bool): Whether to compute auxiliary losses.
+ use_fl (bool): Use FocalLoss or not.
+ use_vfl (bool): Use VarifocalLoss or not.
+ use_uni_match (bool): Whether to use a fixed layer to assign labels for the auxiliary branch.
+ uni_match_ind (int): The index of the fixed layer to use if `use_uni_match` is True.
+ matcher (HungarianMatcher): Object to compute matching cost and indices.
+ fl (FocalLoss or None): Focal Loss object if `use_fl` is True, otherwise None.
+ vfl (VarifocalLoss or None): Varifocal Loss object if `use_vfl` is True, otherwise None.
+ device (torch.device): Device on which tensors are stored.
+ self, nc=80, loss_gain=None, aux_loss=True, use_fl=True, use_vfl=False, use_uni_match=False, uni_match_ind=0
+ Initialize DETR loss function with customizable components and gains.
+ Uses default loss_gain if not provided. Initializes HungarianMatcher with
+ preset cost gains. Supports auxiliary losses and various loss types.
+ nc (int): Number of classes.
+ aux_loss (bool): Use auxiliary losses from each decoder layer.
+ use_fl (bool): Use FocalLoss.
+ use_vfl (bool): Use VarifocalLoss.
+ use_uni_match (bool): Use fixed layer for auxiliary branch label assignment.
+ uni_match_ind (int): Index of fixed layer for uni_match.
+ if loss_gain is None:
+ loss_gain = {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1}
+ self.nc = nc
+ self.matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
+ self.loss_gain = loss_gain
+ self.aux_loss = aux_loss
+ self.fl = FocalLoss() if use_fl else None
+ self.vfl = VarifocalLoss() if use_vfl else None
+ self.use_uni_match = use_uni_match
+ self.uni_match_ind = uni_match_ind
+ self.device = None
+ def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=""):
+ """Computes the classification loss based on predictions, target values, and ground truth scores."""
+ # Logits: [b, query, num_classes], gt_class: list[[n, 1]]
+ name_class = f"loss_class{postfix}"
+ bs, nq = pred_scores.shape[:2]
+ # one_hot = F.one_hot(targets, self.nc + 1)[..., :-1] # (bs, num_queries, num_classes)
+ one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)
+ one_hot.scatter_(2, targets.unsqueeze(-1), 1)
+ one_hot = one_hot[..., :-1]
+ gt_scores = gt_scores.view(bs, nq, 1) * one_hot
+ if self.fl:
+ if num_gts and self.vfl:
+ loss_cls = self.vfl(pred_scores, gt_scores, one_hot)
+ loss_cls = self.fl(pred_scores, one_hot.float())
+ loss_cls /= max(num_gts, 1) / nq
+ loss_cls = nn.BCEWithLogitsLoss(reduction="none")(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss
+ return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}
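+ # Minimal sketch (assumed shapes, not part of the original file) of how the one-hot
+ # targets and IoU-weighted gt_scores used above could be built for bs=1, nq=3, nc=2:
+ # import torch
+ # targets = torch.tensor([[0, 2, 1]])  # class index 2 is the "background" slot (== nc)
+ # one_hot = torch.zeros((1, 3, 3), dtype=torch.int64).scatter_(2, targets.unsqueeze(-1), 1)[..., :-1]
+ # gt_scores = torch.tensor([[0.9, 0.0, 0.4]]).view(1, 3, 1) * one_hot  # IoU-weighted soft targets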
+ def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""):
+ """Computes bounding box and GIoU losses for predicted and ground truth bounding boxes."""
+ # Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
+ name_bbox = f"loss_bbox{postfix}"
+ name_giou = f"loss_giou{postfix}"
+ loss = {}
+ if len(gt_bboxes) == 0:
+ loss[name_bbox] = torch.tensor(0.0, device=self.device)
+ loss[name_giou] = torch.tensor(0.0, device=self.device)
+ return loss
+ loss[name_bbox] = self.loss_gain["bbox"] * F.l1_loss(pred_bboxes, gt_bboxes, reduction="sum") / len(gt_bboxes)
+ loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
+ loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
+ loss[name_giou] = self.loss_gain["giou"] * loss[name_giou]
+ return {k: v.squeeze() for k, v in loss.items()}
+ # This function is for future RT-DETR Segment models
+ # def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''):
+ # # masks: [b, query, h, w], gt_mask: list[[n, H, W]]
+ # name_mask = f'loss_mask{postfix}'
+ # name_dice = f'loss_dice{postfix}'
+ #
+ # loss = {}
+ # if sum(len(a) for a in gt_mask) == 0:
+ # loss[name_mask] = torch.tensor(0., device=self.device)
+ # loss[name_dice] = torch.tensor(0., device=self.device)
+ # return loss
+ # num_gts = len(gt_mask)
+ # src_masks, target_masks = self._get_assigned_bboxes(masks, gt_mask, match_indices)
+ # src_masks = F.interpolate(src_masks.unsqueeze(0), size=target_masks.shape[-2:], mode='bilinear')[0]
+ # # TODO: torch does not have `sigmoid_focal_loss`, but it's not urgent since we don't use mask branch for now.
+ # loss[name_mask] = self.loss_gain['mask'] * F.sigmoid_focal_loss(src_masks, target_masks,
+ # torch.tensor([num_gts], dtype=torch.float32))
+ # loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)
+ # @staticmethod
+ # def _dice_loss(inputs, targets, num_gts):
+ # inputs = F.sigmoid(inputs).flatten(1)
+ # targets = targets.flatten(1)
+ # numerator = 2 * (inputs * targets).sum(1)
+ # denominator = inputs.sum(-1) + targets.sum(-1)
+ # loss = 1 - (numerator + 1) / (denominator + 1)
+ # return loss.sum() / num_gts
+ def _get_loss_aux(
+ pred_bboxes,
+ pred_scores,
+ gt_bboxes,
+ gt_cls,
+ gt_groups,
+ match_indices=None,
+ postfix="",
+ gt_mask=None,
+ """Get auxiliary losses."""
+ # NOTE: loss class, bbox, giou, mask, dice
+ loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
+ if match_indices is None and self.use_uni_match:
+ match_indices = self.matcher(
+ pred_bboxes[self.uni_match_ind],
+ pred_scores[self.uni_match_ind],
+ masks=masks[self.uni_match_ind] if masks is not None else None,
+ gt_mask=gt_mask,
+ for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):
+ aux_masks = masks[i] if masks is not None else None
+ loss_ = self._get_loss(
+ aux_bboxes,
+ aux_scores,
+ masks=aux_masks,
+ postfix=postfix,
+ match_indices=match_indices,
+ loss[0] += loss_[f"loss_class{postfix}"]
+ loss[1] += loss_[f"loss_bbox{postfix}"]
+ loss[2] += loss_[f"loss_giou{postfix}"]
+ # if masks is not None and gt_mask is not None:
+ # loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix)
+ # loss[3] += loss_[f'loss_mask{postfix}']
+ # loss[4] += loss_[f'loss_dice{postfix}']
+ loss = {
+ f"loss_class_aux{postfix}": loss[0],
+ f"loss_bbox_aux{postfix}": loss[1],
+ f"loss_giou_aux{postfix}": loss[2],
+ # loss[f'loss_mask_aux{postfix}'] = loss[3]
+ # loss[f'loss_dice_aux{postfix}'] = loss[4]
+ def _get_index(match_indices):
+ """Returns batch indices, source indices, and destination indices from provided match indices."""
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
+ src_idx = torch.cat([src for (src, _) in match_indices])
+ dst_idx = torch.cat([dst for (_, dst) in match_indices])
+ return (batch_idx, src_idx), dst_idx
+ def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
+ """Assigns predicted bounding boxes to ground truth bounding boxes based on the match indices."""
+ pred_assigned = torch.cat(
+ t[i] if len(i) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
+ for t, (i, _) in zip(pred_bboxes, match_indices)
+ gt_assigned = torch.cat(
+ t[j] if len(j) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
+ for t, (_, j) in zip(gt_bboxes, match_indices)
+ return pred_assigned, gt_assigned
+ def _get_loss(
+ """Get losses."""
+ if match_indices is None:
+ pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=masks, gt_mask=gt_mask
+ idx, gt_idx = self._get_index(match_indices)
+ pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]
+ targets = torch.full((bs, nq), self.nc, device=pred_scores.device, dtype=gt_cls.dtype)
+ targets[idx] = gt_cls[gt_idx]
+ gt_scores = torch.zeros([bs, nq], device=pred_scores.device)
+ if len(gt_bboxes):
+ gt_scores[idx] = bbox_iou(pred_bboxes.detach(), gt_bboxes, xywh=True).squeeze(-1)
+ **self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix),
+ **self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix),
+ # **(self._get_loss_mask(masks, gt_mask, match_indices, postfix) if masks is not None and gt_mask is not None else {})
+ def forward(self, pred_bboxes, pred_scores, batch, postfix="", **kwargs):
+ Calculate loss for predicted bounding boxes and scores.
+ pred_bboxes (torch.Tensor): Predicted bounding boxes, shape [l, b, query, 4].
+ pred_scores (torch.Tensor): Predicted class scores, shape [l, b, query, num_classes].
+ batch (dict): Batch information containing:
+ cls (torch.Tensor): Ground truth classes, shape [num_gts].
+ bboxes (torch.Tensor): Ground truth bounding boxes, shape [num_gts, 4].
+ gt_groups (List[int]): Number of ground truths for each image in the batch.
+ postfix (str): Postfix for loss names.
+ **kwargs (Any): Additional arguments, may include 'match_indices'.
+ (dict): Computed losses, including main and auxiliary (if enabled).
+ Uses last elements of pred_bboxes and pred_scores for main loss, and the rest for auxiliary losses if
+ self.aux_loss is True.
+ self.device = pred_bboxes.device
+ match_indices = kwargs.get("match_indices", None)
+ gt_cls, gt_bboxes, gt_groups = batch["cls"], batch["bboxes"], batch["gt_groups"]
+ total_loss = self._get_loss(
+ pred_bboxes[-1], pred_scores[-1], gt_bboxes, gt_cls, gt_groups, postfix=postfix, match_indices=match_indices
+ if self.aux_loss:
+ total_loss.update(
+ self._get_loss_aux(
+ pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, postfix
+ return total_loss
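+# Usage sketch (illustrative only; shapes and values below are assumptions, not taken from the
+# original file). It shows the inputs expected by DETRLoss.forward: stacked decoder outputs of
+# shape [l, b, query, 4] / [l, b, query, nc] and a batch dict with flattened ground truths.
+# import torch
+# criterion = DETRLoss(nc=80, aux_loss=True)
+# pred_bboxes = torch.rand(3, 2, 100, 4)   # 3 decoder layers, batch 2, 100 queries, normalized xywh
+# pred_scores = torch.rand(3, 2, 100, 80)
+# batch = {
+#     "cls": torch.tensor([1, 3, 7]),      # 3 gt objects flattened across the batch
+#     "bboxes": torch.rand(3, 4),
+#     "gt_groups": [2, 1],                 # 2 gts in image 0, 1 gt in image 1
+# }
+# losses = criterion(pred_bboxes, pred_scores, batch)  # dict with loss_class, loss_bbox, loss_giou (+ *_aux)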
+class RTDETRDetectionLoss(DETRLoss):
+ Real-Time DEtection TRansformer (RT-DETR) Detection Loss class that extends the DETRLoss.
+ This class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well as
+ an additional denoising training loss when provided with denoising metadata.
+ def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None):
+ Forward pass to compute the detection loss.
+ preds (tuple): Predicted bounding boxes and scores.
+ batch (dict): Batch data containing ground truth information.
+ dn_bboxes (torch.Tensor, optional): Denoising bounding boxes. Default is None.
+ dn_scores (torch.Tensor, optional): Denoising scores. Default is None.
+ dn_meta (dict, optional): Metadata for denoising. Default is None.
+ (dict): Dictionary containing the total loss and, if applicable, the denoising loss.
+ pred_bboxes, pred_scores = preds
+ total_loss = super().forward(pred_bboxes, pred_scores, batch)
+ # Check for denoising metadata to compute denoising training loss
+ if dn_meta is not None:
+ dn_pos_idx, dn_num_group = dn_meta["dn_pos_idx"], dn_meta["dn_num_group"]
+ assert len(batch["gt_groups"]) == len(dn_pos_idx)
+ # Get the match indices for denoising
+ match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch["gt_groups"])
+ # Compute the denoising training loss
+ dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix="_dn", match_indices=match_indices)
+ total_loss.update(dn_loss)
+ # If no denoising metadata is provided, set denoising loss to zero
+ total_loss.update({f"{k}_dn": torch.tensor(0.0, device=self.device) for k in total_loss.keys()})
+ def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups):
+ Get the match indices for denoising.
+ dn_pos_idx (List[torch.Tensor]): List of tensors containing positive indices for denoising.
+ dn_num_group (int): Number of denoising groups.
+ gt_groups (List[int]): List of integers representing the number of ground truths for each image.
+ (List[tuple]): List of tuples containing matched indices for denoising.
+ dn_match_indices = []
+ idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
+ for i, num_gt in enumerate(gt_groups):
+ if num_gt > 0:
+ gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
+ gt_idx = gt_idx.repeat(dn_num_group)
+ assert len(dn_pos_idx[i]) == len(gt_idx), (
+ f"Expected the same length, but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively."
+ )
+ dn_match_indices.append((dn_pos_idx[i], gt_idx))
+ dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
+ return dn_match_indices
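+ # Worked example (values assumed for illustration): with gt_groups=[2, 1] and dn_num_group=2,
+ # image 0 owns flattened gt indices {0, 1} and image 1 owns {2}, so the gt side of each tuple
+ # becomes tensor([0, 1, 0, 1]) for image 0 and tensor([2, 2]) for image 1, aligned one-to-one
+ # with the positive denoising query indices in dn_pos_idx.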
@@ -0,0 +1,259 @@
+from scipy.optimize import linear_sum_assignment
+from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh
+class HungarianMatcher(nn.Module):
+ A module implementing the HungarianMatcher, which is a differentiable module to solve the assignment problem in an
+ end-to-end fashion.
+ HungarianMatcher performs optimal assignment over the predicted and ground truth bounding boxes using a cost
+ function that considers classification scores, bounding box coordinates, and optionally, mask predictions.
+ cost_gain (dict): Dictionary of cost coefficients: 'class', 'bbox', 'giou', 'mask', and 'dice'.
+ use_fl (bool): Indicates whether to use Focal Loss for the classification cost calculation.
+ with_mask (bool): Indicates whether the model makes mask predictions.
+ num_sample_points (int): The number of sample points used in mask cost calculation.
+ alpha (float): The alpha factor in Focal Loss calculation.
+ gamma (float): The gamma factor in Focal Loss calculation.
+ forward(pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None): Computes the
+ assignment between predictions and ground truths for a batch.
+ _cost_mask(bs, num_gts, masks=None, gt_mask=None): Computes the mask cost and dice cost if masks are predicted.
+ def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
+ """Initializes a HungarianMatcher module for optimal assignment of predicted and ground truth bounding boxes."""
+ if cost_gain is None:
+ cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}
+ self.cost_gain = cost_gain
+ self.use_fl = use_fl
+ self.with_mask = with_mask
+ self.num_sample_points = num_sample_points
+ self.alpha = alpha
+ self.gamma = gamma
+ def forward(self, pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None):
+ Forward pass for HungarianMatcher. This function computes costs based on prediction and ground truth
+ (classification cost, L1 cost between boxes and GIoU cost between boxes) and finds the optimal matching between
+ predictions and ground truth based on these costs.
+ pred_bboxes (Tensor): Predicted bounding boxes with shape [batch_size, num_queries, 4].
+ pred_scores (Tensor): Predicted scores with shape [batch_size, num_queries, num_classes].
+ gt_cls (torch.Tensor): Ground truth classes with shape [num_gts, ].
+ gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape [num_gts, 4].
+ gt_groups (List[int]): List of length equal to batch size, containing the number of ground truths for
+ each image.
+ masks (Tensor, optional): Predicted masks with shape [batch_size, num_queries, height, width].
+ Defaults to None.
+ gt_mask (List[Tensor], optional): List of ground truth masks, each with shape [num_masks, Height, Width].
+ (List[Tuple[Tensor, Tensor]]): A list of size batch_size, each element is a tuple (index_i, index_j), where:
+ - index_i is the tensor of indices of the selected predictions (in order)
+ - index_j is the tensor of indices of the corresponding selected ground truth targets (in order)
+ For each batch element, it holds:
+ len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
+ bs, nq, nc = pred_scores.shape
+ if sum(gt_groups) == 0:
+ return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]
+ # We flatten to compute the cost matrices in a batch
+ # [batch_size * num_queries, num_classes]
+ pred_scores = pred_scores.detach().view(-1, nc)
+ pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)
+ # [batch_size * num_queries, 4]
+ pred_bboxes = pred_bboxes.detach().view(-1, 4)
+ # Compute the classification cost
+ pred_scores = pred_scores[:, gt_cls]
+ if self.use_fl:
+ neg_cost_class = (1 - self.alpha) * (pred_scores**self.gamma) * (-(1 - pred_scores + 1e-8).log())
+ pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log())
+ cost_class = pos_cost_class - neg_cost_class
+ cost_class = -pred_scores
+ # Compute the L1 cost between boxes
+ cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1) # (bs*num_queries, num_gt)
+ # Compute the GIoU cost between boxes, (bs*num_queries, num_gt)
+ cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)
+ # Final cost matrix
+ C = (
+ self.cost_gain["class"] * cost_class
+ + self.cost_gain["bbox"] * cost_bbox
+ + self.cost_gain["giou"] * cost_giou
+ # Compute the mask cost and dice cost
+ if self.with_mask:
+ C += self._cost_mask(bs, gt_groups, masks, gt_mask)
+ # Set invalid values (NaNs and infinities) to 0 (fixes ValueError: matrix contains invalid numeric entries)
+ C[C.isnan() | C.isinf()] = 0.0
+ C = C.view(bs, nq, -1).cpu()
+ indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
+ gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0) # (idx for queries, idx for gt)
+ return [
+ (torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k])
+ for k, (i, j) in enumerate(indices)
+ # def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
+ # assert masks is not None and gt_mask is not None, 'Make sure the input has `mask` and `gt_mask`'
+ # # all masks share the same set of points for efficient matching
+ # sample_points = torch.rand([bs, 1, self.num_sample_points, 2])
+ # sample_points = 2.0 * sample_points - 1.0
+ # out_mask = F.grid_sample(masks.detach(), sample_points, align_corners=False).squeeze(-2)
+ # out_mask = out_mask.flatten(0, 1)
+ # tgt_mask = torch.cat(gt_mask).unsqueeze(1)
+ # sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
+ # tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2])
+ # with torch.amp.autocast("cuda", enabled=False):
+ # # binary cross entropy cost
+ # pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
+ # neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')
+ # cost_mask = torch.matmul(pos_cost_mask, tgt_mask.T) + torch.matmul(neg_cost_mask, 1 - tgt_mask.T)
+ # cost_mask /= self.num_sample_points
+ # # dice cost
+ # out_mask = F.sigmoid(out_mask)
+ # numerator = 2 * torch.matmul(out_mask, tgt_mask.T)
+ # denominator = out_mask.sum(-1, keepdim=True) + tgt_mask.sum(-1).unsqueeze(0)
+ # cost_dice = 1 - (numerator + 1) / (denominator + 1)
+ # C = self.cost_gain['mask'] * cost_mask + self.cost_gain['dice'] * cost_dice
+ # return C
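+# Usage sketch (shapes are assumptions for illustration, not part of the original file): the
+# matcher returns one (pred_idx, gt_idx) tuple per image, with gt indices offset into the
+# flattened ground-truth tensors.
+# import torch
+# matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
+# pred_bboxes = torch.rand(2, 100, 4)   # batch 2, 100 queries, normalized xywh
+# pred_scores = torch.rand(2, 100, 80)  # raw logits, 80 classes
+# gt_bboxes = torch.rand(3, 4)          # 3 gts flattened across the batch
+# gt_cls = torch.tensor([0, 5, 2])
+# indices = matcher(pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups=[2, 1])
+# # e.g. indices[0] -> (tensor([q_a, q_b]), tensor([0, 1])), indices[1] -> (tensor([q_c]), tensor([2]))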
+def get_cdn_group(
+ batch, num_classes, num_queries, class_embed, num_dn=100, cls_noise_ratio=0.5, box_noise_scale=1.0, training=False
+ Get contrastive denoising training group. This function creates a contrastive denoising training group with positive
+ and negative samples from the ground truths (gt). It applies noise to the class labels and bounding box coordinates,
+ and returns the modified labels, bounding boxes, attention mask and meta information.
+ batch (dict): A dict that includes 'cls' (torch.Tensor with shape [num_gts, ]), 'bboxes'
+ (torch.Tensor with shape [num_gts, 4]), 'batch_idx' (torch.Tensor mapping each gt to its image),
+ and 'gt_groups' (List[int]), a list of batch-size length giving the number of gts per image.
+ num_classes (int): Number of classes.
+ num_queries (int): Number of queries.
+ class_embed (torch.Tensor): Embedding weights to map class labels to embedding space.
+ num_dn (int, optional): Total number of denoising queries. Defaults to 100.
+ cls_noise_ratio (float, optional): Noise ratio for class labels. Defaults to 0.5.
+ box_noise_scale (float, optional): Noise scale for bounding box coordinates. Defaults to 1.0.
+ training (bool, optional): Whether the model is in training mode. Defaults to False.
+ (Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Dict]]): The modified class embeddings,
+ bounding boxes, attention mask and meta information for denoising. If not in training mode or 'num_dn'
+ is less than or equal to 0, the function returns None for all elements in the tuple.
+ if (not training) or num_dn <= 0:
+ return None, None, None, None
+ gt_groups = batch["gt_groups"]
+ total_num = sum(gt_groups)
+ max_nums = max(gt_groups)
+ if max_nums == 0:
+ return None, None, None, None
+ num_group = num_dn // max_nums
+ num_group = 1 if num_group == 0 else num_group
+ # Pad gt to max_num of a batch
+ bs = len(gt_groups)
+ gt_cls = batch["cls"] # (bs*num, )
+ gt_bbox = batch["bboxes"] # bs*num, 4
+ b_idx = batch["batch_idx"]
+ # Each group has positive and negative queries.
+ dn_cls = gt_cls.repeat(2 * num_group) # (2*num_group*bs*num, )
+ dn_bbox = gt_bbox.repeat(2 * num_group, 1) # 2*num_group*bs*num, 4
+ dn_b_idx = b_idx.repeat(2 * num_group).view(-1) # (2*num_group*bs*num, )
+ # Positive and negative mask
+ # (bs*num*num_group, ), the second total_num*num_group part as negative samples
+ neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num
+ if cls_noise_ratio > 0:
+ # Apply class-label noise to a random ~(cls_noise_ratio / 2) fraction of the denoising samples
+ mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
+ idx = torch.nonzero(mask).squeeze(-1)
+ # Replace the selected labels with randomly drawn class indices
+ new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
+ dn_cls[idx] = new_label
+ if box_noise_scale > 0:
+ known_bbox = xywh2xyxy(dn_bbox)
+ diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale # 2*num_group*bs*num, 4
+ rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0
+ rand_part = torch.rand_like(dn_bbox)
+ rand_part[neg_idx] += 1.0
+ rand_part *= rand_sign
+ known_bbox += rand_part * diff
+ known_bbox.clip_(min=0.0, max=1.0)
+ dn_bbox = xyxy2xywh(known_bbox)
+ dn_bbox = torch.logit(dn_bbox, eps=1e-6) # inverse sigmoid
+ num_dn = int(max_nums * 2 * num_group) # total denoising queries
+ # class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)])
+ dn_cls_embed = class_embed[dn_cls] # bs*num * 2 * num_group, 256
+ padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
+ padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)
+ map_indices = torch.cat([torch.tensor(range(num), dtype=torch.long) for num in gt_groups])
+ pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0)
+ map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)])
+ padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed
+ padding_bbox[(dn_b_idx, map_indices)] = dn_bbox
+ tgt_size = num_dn + num_queries
+ attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)
+ # Matching queries cannot see the reconstruction (denoising) queries
+ attn_mask[num_dn:, :num_dn] = True
+ # Reconstruction queries in different groups cannot see each other
+ for i in range(num_group):
+ if i == 0:
+ attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
+ if i == num_group - 1:
+ attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * i * 2] = True
+ attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * 2 * i] = True
+ dn_meta = {
+ "dn_pos_idx": [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
+ "dn_num_group": num_group,
+ "dn_num_split": [num_dn, num_queries],
+ padding_cls.to(class_embed.device),
+ padding_bbox.to(class_embed.device),
+ attn_mask.to(class_embed.device),
+ dn_meta,
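+ # Shape sketch (illustrative; numbers are assumptions): with gt_groups=[2, 1], max_nums=2 and
+ # num_dn=100, num_group = 100 // 2 = 50, so each image gets 2 * 2 * 50 = 200 denoising queries.
+ # The returned padding_cls is then (bs, 200, emb_dim), padding_bbox is (bs, 200, 4) in logit
+ # space, attn_mask is (200 + num_queries, 200 + num_queries), and dn_meta carries the positive
+ # indices plus the [num_dn, num_queries] split used to separate denoising and matching parts.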
+from ultralytics.models.yolo import classify, detect, obb, pose, segment, world
+from .model import YOLO, YOLOWorld
+__all__ = "classify", "segment", "detect", "pose", "obb", "world", "YOLO", "YOLOWorld"
+from ultralytics.models.yolo.classify.predict import ClassificationPredictor
+from ultralytics.models.yolo.classify.train import ClassificationTrainer
+from ultralytics.models.yolo.classify.val import ClassificationValidator
+__all__ = "ClassificationPredictor", "ClassificationTrainer", "ClassificationValidator"
@@ -0,0 +1,60 @@
+import cv2
+class ClassificationPredictor(BasePredictor):
+ A class extending the BasePredictor class for prediction based on a classification model.
+ - Torchvision classification models can also be passed to the 'model' argument, e.g. model='resnet18'.
+ from ultralytics.models.yolo.classify import ClassificationPredictor
+ args = dict(model="yolov8n-cls.pt", source=ASSETS)
+ predictor = ClassificationPredictor(overrides=args)
+ """Initializes ClassificationPredictor setting the task to 'classify'."""
+ self.args.task = "classify"
+ self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor"
+ def preprocess(self, img):
+ """Converts input image to model-compatible data type."""
+ if not isinstance(img, torch.Tensor):
+ is_legacy_transform = any(
+ self._legacy_transform_name in str(transform) for transform in self.transforms.transforms
+ if is_legacy_transform: # to handle legacy transforms
+ img = torch.stack([self.transforms(im) for im in img], dim=0)
+ img = torch.stack(
+ [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
+ img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
+ return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
+ """Post-processes predictions to return Results objects."""
+ preds = preds[0] if isinstance(preds, (list, tuple)) else preds
+ Results(orig_img, path=img_path, names=self.model.names, probs=pred)
+ for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
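+ # Usage sketch (asset path and attribute access shown for illustration only):
+ # from ultralytics import YOLO
+ # results = YOLO("yolov8n-cls.pt")("ultralytics/assets/bus.jpg")
+ # print(results[0].probs.top1, results[0].probs.top1conf)  # top-1 class index and its confidence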
@@ -0,0 +1,153 @@
+from ultralytics.data import ClassificationDataset, build_dataloader
+from ultralytics.engine.trainer import BaseTrainer
+from ultralytics.models import yolo
+from ultralytics.nn.tasks import ClassificationModel
+from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
+from ultralytics.utils.plotting import plot_images, plot_results
+from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
+class ClassificationTrainer(BaseTrainer):
+ A class extending the BaseTrainer class for training based on a classification model.
+ from ultralytics.models.yolo.classify import ClassificationTrainer
+ args = dict(model="yolov8n-cls.pt", data="imagenet10", epochs=3)
+ trainer = ClassificationTrainer(overrides=args)
+ """Initialize a ClassificationTrainer object with optional configuration overrides and callbacks."""
+ overrides["task"] = "classify"
+ if overrides.get("imgsz") is None:
+ overrides["imgsz"] = 224
+ def set_model_attributes(self):
+ """Set the YOLO model's class names from the loaded dataset."""
+ self.model.names = self.data["names"]
+ """Returns a modified PyTorch model configured for training YOLO."""
+ model = ClassificationModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
+ for m in model.modules():
+ if not self.args.pretrained and hasattr(m, "reset_parameters"):
+ m.reset_parameters()
+ if isinstance(m, torch.nn.Dropout) and self.args.dropout:
+ m.p = self.args.dropout # set dropout
+ for p in model.parameters():
+ p.requires_grad = True # for training
+ def setup_model(self):
+ """Load, create or download model for any task."""
+ if str(self.model) in torchvision.models.__dict__:
+ self.model = torchvision.models.__dict__[self.model](
+ weights="IMAGENET1K_V1" if self.args.pretrained else None
+ ckpt = None
+ ckpt = super().setup_model()
+ ClassificationModel.reshape_outputs(self.model, self.data["nc"])
+ return ckpt
+ def build_dataset(self, img_path, mode="train", batch=None):
+ """Creates a ClassificationDataset instance given an image path, and mode (train/test etc.)."""
+ return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
+ def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
+ """Returns PyTorch DataLoader with transforms to preprocess images for inference."""
+ with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
+ dataset = self.build_dataset(dataset_path, mode)
+ loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)
+ # Attach inference transforms
+ if mode != "train":
+ if is_parallel(self.model):
+ self.model.module.transforms = loader.dataset.torch_transforms
+ self.model.transforms = loader.dataset.torch_transforms
+ return loader
+ """Preprocesses a batch of images and classes."""
+ batch["img"] = batch["img"].to(self.device)
+ batch["cls"] = batch["cls"].to(self.device)
+ def progress_string(self):
+ """Returns a formatted string showing training progress."""
+ return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
+ "Epoch",
+ "GPU_mem",
+ *self.loss_names,
+ "Instances",
+ "Size",
+ """Returns an instance of ClassificationValidator for validation."""
+ self.loss_names = ["loss"]
+ return yolo.classify.ClassificationValidator(
+ self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+ def label_loss_items(self, loss_items=None, prefix="train"):
+ Returns a loss dict with labelled training loss items tensor.
+ Not needed for classification but necessary for segmentation & detection
+ keys = [f"{prefix}/{x}" for x in self.loss_names]
+ if loss_items is None:
+ return keys
+ loss_items = [round(float(loss_items), 5)]
+ return dict(zip(keys, loss_items))
+ def plot_metrics(self):
+ """Plots metrics from a CSV file."""
+ plot_results(file=self.csv, classify=True, on_plot=self.on_plot) # save results.png
+ def final_eval(self):
+ """Evaluate trained model and save validation results."""
+ for f in self.last, self.best:
+ if f.exists():
+ strip_optimizer(f) # strip optimizers
+ if f is self.best:
+ LOGGER.info(f"\nValidating {f}...")
+ self.validator.args.data = self.args.data
+ self.validator.args.plots = self.args.plots
+ self.metrics = self.validator(model=f)
+ self.metrics.pop("fitness", None)
+ self.run_callbacks("on_fit_epoch_end")
+ def plot_training_samples(self, batch, ni):
+ """Plots training samples with their annotations."""
+ plot_images(
+ images=batch["img"],
+ batch_idx=torch.arange(len(batch["img"])),
+ cls=batch["cls"].view(-1), # warning: use .view(), not .squeeze() for Classify models
+ fname=self.save_dir / f"train_batch{ni}.jpg",
+ on_plot=self.on_plot,
@@ -0,0 +1,117 @@
+from ultralytics.engine.validator import BaseValidator
+from ultralytics.utils import LOGGER
+from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
+from ultralytics.utils.plotting import plot_images
+class ClassificationValidator(BaseValidator):
+ A class extending the BaseValidator class for validation based on a classification model.
+ from ultralytics.models.yolo.classify import ClassificationValidator
+ args = dict(model="yolov8n-cls.pt", data="imagenet10")
+ validator = ClassificationValidator(args=args)
+ """Initializes ClassificationValidator instance with args, dataloader, save_dir, and progress bar."""
+ self.targets = None
+ self.pred = None
+ self.metrics = ClassifyMetrics()
+ def get_desc(self):
+ """Returns a formatted string summarizing classification metrics."""
+ return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc")
+ def init_metrics(self, model):
+ """Initialize confusion matrix, class names, and top-1 and top-5 accuracy."""
+ self.names = model.names
+ self.nc = len(model.names)
+ self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, task="classify")
+ self.pred = []
+ self.targets = []
+ def preprocess(self, batch):
+ """Preprocesses input batch and returns it."""
+ batch["img"] = batch["img"].to(self.device, non_blocking=True)
+ batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
+ def update_metrics(self, preds, batch):
+ """Updates running metrics with model predictions and batch targets."""
+ n5 = min(len(self.names), 5)
+ self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
+ self.targets.append(batch["cls"].type(torch.int32).cpu())
+ def finalize_metrics(self, *args, **kwargs):
+ """Finalizes metrics of the model such as confusion_matrix and speed."""
+ self.confusion_matrix.process_cls_preds(self.pred, self.targets)
+ if self.args.plots:
+ for normalize in True, False:
+ self.confusion_matrix.plot(
+ save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
+ self.metrics.speed = self.speed
+ self.metrics.confusion_matrix = self.confusion_matrix
+ self.metrics.save_dir = self.save_dir
+ """Preprocesses the classification predictions."""
+ return preds[0] if isinstance(preds, (list, tuple)) else preds
+ def get_stats(self):
+ """Returns a dictionary of metrics obtained by processing targets and predictions."""
+ self.metrics.process(self.targets, self.pred)
+ return self.metrics.results_dict
+ def build_dataset(self, img_path):
+ """Creates and returns a ClassificationDataset instance using given image path and preprocessing parameters."""
+ return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
+ def get_dataloader(self, dataset_path, batch_size):
+ """Builds and returns a data loader for classification tasks with given parameters."""
+ dataset = self.build_dataset(dataset_path)
+ return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)
+ def print_results(self):
+ """Prints evaluation metrics for YOLO object detection model."""
+ pf = "%22s" + "%11.3g" * len(self.metrics.keys) # print format
+ LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))
+ def plot_val_samples(self, batch, ni):
+ """Plot validation image samples."""
+ fname=self.save_dir / f"val_batch{ni}_labels.jpg",
+ names=self.names,
+ def plot_predictions(self, batch, preds, ni):
+ """Plots predicted bounding boxes on input images and saves the result."""
+ batch["img"],
+ cls=torch.argmax(preds, dim=1),
+ fname=self.save_dir / f"val_batch{ni}_pred.jpg",
+ ) # pred
+from .predict import DetectionPredictor
+from .train import DetectionTrainer
+from .val import DetectionValidator
+__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator"
@@ -0,0 +1,41 @@
+class DetectionPredictor(BasePredictor):
+ A class extending the BasePredictor class for prediction based on a detection model.
+ from ultralytics.models.yolo.detect import DetectionPredictor
+ args = dict(model="yolo11n.pt", source=ASSETS)
+ predictor = DetectionPredictor(overrides=args)
+ """Post-processes predictions and returns a list of Results objects."""
+import random
+from ultralytics.data import build_dataloader, build_yolo_dataset
+from ultralytics.nn.tasks import DetectionModel
+from ultralytics.utils import LOGGER, RANK
+from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
+from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first
+class DetectionTrainer(BaseTrainer):
+ A class extending the BaseTrainer class for training based on a detection model.
+ from ultralytics.models.yolo.detect import DetectionTrainer
+ args = dict(model="yolo11n.pt", data="coco8.yaml", epochs=3)
+ trainer = DetectionTrainer(overrides=args)
+ Build YOLO Dataset.
+ gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
+ return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
+ """Construct and return dataloader."""
+ assert mode in {"train", "val"}, f"Mode must be 'train' or 'val', not {mode}."
+ dataset = self.build_dataset(dataset_path, mode, batch_size)
+ shuffle = mode == "train"
+ if getattr(dataset, "rect", False) and shuffle:
+ LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
+ shuffle = False
+ workers = self.args.workers if mode == "train" else self.args.workers * 2
+ return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader
+ """Preprocesses a batch of images by scaling and converting to float."""
+ batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
+ if self.args.multi_scale:
+ imgs = batch["img"]
+ sz = (
+ random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))
+ // self.stride
+ * self.stride
+ ) # size
+ sf = sz / max(imgs.shape[2:]) # scale factor
+ if sf != 1:
+ ns = [
+ math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]
+ ] # new shape (stretched to gs-multiple)
+ imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
+ batch["img"] = imgs
+ """Nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)."""
+ # self.args.box *= 3 / nl # scale to layers
+ # self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
+ # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
+ self.model.nc = self.data["nc"] # attach number of classes to model
+ self.model.names = self.data["names"] # attach class names to model
+ self.model.args = self.args # attach hyperparameters to model
+ # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
+ """Return a YOLO detection model."""
+ model = DetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
+ """Returns a DetectionValidator for YOLO model validation."""
+ self.loss_names = "box_loss", "cls_loss", "dfl_loss"
+ return yolo.detect.DetectionValidator(
+ self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+ if loss_items is not None:
+ loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats
+ """Returns a formatted string of training progress with epoch, GPU memory, loss, instances and size."""
+ batch_idx=batch["batch_idx"],
+ cls=batch["cls"].squeeze(-1),
+ bboxes=batch["bboxes"],
+ paths=batch["im_file"],
+ plot_results(file=self.csv, on_plot=self.on_plot) # save results.png
+ def plot_training_labels(self):
+ """Create a labeled training plot of the YOLO model."""
+ boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0)
+ cls = np.concatenate([lb["cls"] for lb in self.train_loader.dataset.labels], 0)
+ plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
+ def auto_batch(self):
+ """Get batch size by calculating memory occupation of model."""
+ train_dataset = self.build_dataset(self.trainset, mode="train", batch=16)
+ # 4 for mosaic augmentation
+ max_num_obj = max(len(label["cls"]) for label in train_dataset.labels) * 4
+ return super().auto_batch(max_num_obj)
@@ -0,0 +1,337 @@
+import os
+from ultralytics.data import build_dataloader, build_yolo_dataset, converter
+from ultralytics.utils import LOGGER, ops
+from ultralytics.utils.checks import check_requirements
+from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
+from ultralytics.utils.plotting import output_to_target, plot_images
+class DetectionValidator(BaseValidator):
+ A class extending the BaseValidator class for validation based on a detection model.
+ from ultralytics.models.yolo.detect import DetectionValidator
+ args = dict(model="yolo11n.pt", data="coco8.yaml")
+ validator = DetectionValidator(args=args)
+ """Initialize detection model with necessary variables and settings."""
+ self.nt_per_class = None
+ self.nt_per_image = None
+ self.is_coco = False
+ self.is_lvis = False
+ self.class_map = None
+ self.args.task = "detect"
+ self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
+ self.iouv = torch.linspace(0.5, 0.95, 10) # IoU vector for mAP@0.5:0.95
+ self.niou = self.iouv.numel()
+ self.lb = [] # for autolabelling
+ if self.args.save_hybrid:
+ LOGGER.warning(
+ "WARNING ⚠️ 'save_hybrid=True' will append ground truth to predictions for autolabelling.\n"
+ "WARNING ⚠️ 'save_hybrid=True' will cause incorrect mAP.\n"
+ """Preprocesses batch of images for YOLO training."""
+ batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
+ for k in ["batch_idx", "cls", "bboxes"]:
+ batch[k] = batch[k].to(self.device)
+ height, width = batch["img"].shape[2:]
+ nb = len(batch["img"])
+ bboxes = batch["bboxes"] * torch.tensor((width, height, width, height), device=self.device)
+ self.lb = [
+ torch.cat([batch["cls"][batch["batch_idx"] == i], bboxes[batch["batch_idx"] == i]], dim=-1)
+ for i in range(nb)
+ """Initialize evaluation metrics for YOLO."""
+ val = self.data.get(self.args.split, "") # validation path
+ self.is_coco = (
+ isinstance(val, str)
+ and "coco" in val
+ and (val.endswith(f"{os.sep}val2017.txt") or val.endswith(f"{os.sep}test-dev2017.txt"))
+ ) # is COCO
+ self.is_lvis = isinstance(val, str) and "lvis" in val and not self.is_coco # is LVIS
+ self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1, len(model.names) + 1))
+ self.args.save_json |= self.args.val and (self.is_coco or self.is_lvis) and not self.training # run final val
+ self.metrics.names = self.names
+ self.metrics.plot = self.args.plots
+ self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf)
+ self.seen = 0
+ self.jdict = []
+ self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
+ """Return a formatted string summarizing class metrics of YOLO model."""
+ return ("%22s" + "%11s" * 7) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP75","mAP50-95)")
+ multi_label=True,
+ """Prepares a batch of images and annotations for validation."""
+ bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]] # target boxes
+ ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad) # native-space labels
+ ops.scale_boxes(
+ pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
+ ) # native-space pred
+ return predn
+ """Metrics."""
+ for si, pred in enumerate(preds):
+ self.seen += 1
+ npr = len(pred)
+ stat = dict(
+ conf=torch.zeros(0, device=self.device),
+ pred_cls=torch.zeros(0, device=self.device),
+ tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+ pbatch = self._prepare_batch(si, batch)
+ cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
+ nl = len(cls)
+ stat["target_cls"] = cls
+ stat["target_img"] = cls.unique()
+ if npr == 0:
+ if nl:
+ for k in self.stats.keys():
+ self.stats[k].append(stat[k])
+ self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
+ # Predictions
+ if self.args.single_cls:
+ pred[:, 5] = 0
+ predn = self._prepare_pred(pred, pbatch)
+ stat["conf"] = predn[:, 4]
+ stat["pred_cls"] = predn[:, 5]
+ # Evaluate
+ stat["tp"] = self._process_batch(predn, bbox, cls)
+ self.confusion_matrix.process_batch(predn, bbox, cls)
+ # Save
+ if self.args.save_json:
+ self.pred_to_json(predn, batch["im_file"][si])
+ if self.args.save_txt:
+ self.save_one_txt(
+ predn,
+ self.args.save_conf,
+ pbatch["ori_shape"],
+ self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
+ """Set final values for metrics speed and confusion matrix."""
+ """Returns metrics statistics and results dictionary."""
+ stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()} # to numpy
+ self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=self.nc)
+ self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=self.nc)
+ stats.pop("target_img", None)
+ if len(stats) and stats["tp"].any():
+ self.metrics.process(**stats)
+ """Prints training/validation set metrics per class."""
+ pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys) # print format
+ LOGGER.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
+ if self.nt_per_class.sum() == 0:
+ LOGGER.warning(f"WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels")
+ # Print results per class
+ if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
+ for i, c in enumerate(self.metrics.ap_class_index):
+ LOGGER.info(
+ pf % (self.names[c], self.nt_per_image[c], self.nt_per_class[c], *self.metrics.class_result(i))
+ def _process_batch(self, detections, gt_bboxes, gt_cls):
+ Return correct prediction matrix.
+ detections (torch.Tensor): Tensor of shape (N, 6) representing detections where each detection is
+ (x1, y1, x2, y2, conf, class).
+ gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground-truth bounding box coordinates. Each
+ bounding box is of the format: (x1, y1, x2, y2).
+ gt_cls (torch.Tensor): Tensor of shape (M,) representing target class indices.
+ (torch.Tensor): Correct prediction matrix of shape (N, 10) for 10 IoU levels.
+ The function does not return any value directly usable for metrics calculation. Instead, it provides an
+ intermediate representation used for evaluating predictions against ground truth.
+ iou = box_iou(gt_bboxes, detections[:, :4])
+ return self.match_predictions(detections[:, 5], gt_cls, iou)
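+ # Sketch (assumed shapes): for 8 detections and 3 ground-truth boxes, `iou` is a (3, 8)
+ # matrix of pairwise IoUs; match_predictions then returns an (8, 10) boolean matrix marking,
+ # for each detection and each of the 10 IoU thresholds in self.iouv, whether it matches a
+ # ground-truth box of the same class.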
+ return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)
+ dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
+ return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1) # return dataloader
+ batch["batch_idx"],
+ batch["cls"].squeeze(-1),
+ batch["bboxes"],
+ *output_to_target(preds, max_det=self.args.max_det),
+ def save_one_txt(self, predn, save_conf, shape, file):
+ """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
+ from ultralytics.engine.results import Results
+ Results(
+ np.zeros((shape[0], shape[1]), dtype=np.uint8),
+ path=None,
+ boxes=predn[:, :6],
+ ).save_txt(file, save_conf=save_conf)
+ def pred_to_json(self, predn, filename):
+ """Serialize YOLO predictions to COCO json format."""
+ stem = Path(filename).stem
+ image_id = int(stem) if stem.isnumeric() else stem
+ box = ops.xyxy2xywh(predn[:, :4]) # xywh
+ box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
+ for p, b in zip(predn.tolist(), box.tolist()):
+ self.jdict.append(
+ {
+ "image_id": image_id,
+ "category_id": self.class_map[int(p[5])],
+ "bbox": [round(x, 3) for x in b],
+ "score": round(p[4], 5),
+ def eval_json(self, stats):
+ """Evaluates YOLO output in JSON format and returns performance statistics."""
+ if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
+ pred_json = self.save_dir / "predictions.json" # predictions
+ anno_json = (
+ self.data["path"]
+ / "annotations"
+ / ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
+ ) # annotations
+ pkg = "pycocotools" if self.is_coco else "lvis"
+ LOGGER.info(f"\nEvaluating {pkg} mAP using {pred_json} and {anno_json}...")
+ try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
+ for x in pred_json, anno_json:
+ assert x.is_file(), f"{x} file not found"
+ check_requirements("pycocotools>=2.0.6" if self.is_coco else "lvis>=0.5.3")
+ if self.is_coco:
+ from pycocotools.coco import COCO # noqa
+ from pycocotools.cocoeval import COCOeval # noqa
+ anno = COCO(str(anno_json)) # init annotations api
+ pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
+ val = COCOeval(anno, pred, "bbox")
+ from lvis import LVIS, LVISEval
+ anno = LVIS(str(anno_json)) # init annotations api
+ pred = anno._load_json(str(pred_json)) # init predictions api (must pass string, not Path)
+ val = LVISEval(anno, pred, "bbox")
+ val.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval
+ val.evaluate()
+ val.accumulate()
+ val.summarize()
+ if self.is_lvis:
+ val.print_results() # explicitly call print_results
+ # update mAP50-95 and mAP50
+ stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = (
+ val.stats[:2] if self.is_coco else [val.results["AP50"], val.results["AP"]]
+ except Exception as e:
+ LOGGER.warning(f"{pkg} unable to run: {e}")
+ return stats
@@ -0,0 +1,111 @@
+from ultralytics.nn.tasks import ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel
+from ultralytics.utils import ROOT, yaml_load
+class YOLO(Model):
+ """YOLO (You Only Look Once) object detection model."""
+ def __init__(self, model="yolo11n.pt", task=None, verbose=False):
+ """Initialize YOLO model, switching to YOLOWorld if model filename contains '-world'."""
+ path = Path(model)
+ if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}: # if YOLOWorld PyTorch model
+ new_instance = YOLOWorld(path, verbose=verbose)
+ self.__class__ = type(new_instance)
+ self.__dict__ = new_instance.__dict__
+ # Continue with default YOLO initialization
+ super().__init__(model=model, task=task, verbose=verbose)
+ """Map head to model, trainer, validator, and predictor classes."""
+ "classify": {
+ "model": ClassificationModel,
+ "trainer": yolo.classify.ClassificationTrainer,
+ "validator": yolo.classify.ClassificationValidator,
+ "predictor": yolo.classify.ClassificationPredictor,
+ "model": DetectionModel,
+ "trainer": yolo.detect.DetectionTrainer,
+ "validator": yolo.detect.DetectionValidator,
+ "predictor": yolo.detect.DetectionPredictor,
+ "segment": {
+ "model": SegmentationModel,
+ "trainer": yolo.segment.SegmentationTrainer,
+ "validator": yolo.segment.SegmentationValidator,
+ "predictor": yolo.segment.SegmentationPredictor,
+ "pose": {
+ "model": PoseModel,
+ "trainer": yolo.pose.PoseTrainer,
+ "validator": yolo.pose.PoseValidator,
+ "predictor": yolo.pose.PosePredictor,
+ "obb": {
+ "model": OBBModel,
+ "trainer": yolo.obb.OBBTrainer,
+ "validator": yolo.obb.OBBValidator,
+ "predictor": yolo.obb.OBBPredictor,
+class YOLOWorld(Model):
+ """YOLO-World object detection model."""
+ def __init__(self, model="yolov8s-world.pt", verbose=False) -> None:
+ Initialize YOLOv8-World model with a pre-trained model file.
+ Loads a YOLOv8-World model for object detection. If no custom class names are provided, it assigns default
+ COCO class names.
+ model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
+ verbose (bool): If True, prints additional information during initialization.
+ super().__init__(model=model, task="detect", verbose=verbose)
+ # Assign default COCO class names when there are no custom names
+ if not hasattr(self.model, "names"):
+ self.model.names = yaml_load(ROOT / "cfg/datasets/coco8.yaml").get("names")
+ """Map head to model, validator, and predictor classes."""
+ "model": WorldModel,
+ "trainer": yolo.world.WorldTrainer,
+ def set_classes(self, classes):
+ Set the model's class names (text prompts) for open-vocabulary detection.
+ classes (List[str]): A list of category names, e.g. ["person"].
+ self.model.set_classes(classes)
+ # Remove background if it's given
+ background = " "
+ if background in classes:
+ classes.remove(background)
+ self.model.names = classes
+ # Reset method class names
+ # self.predictor = None # reset predictor otherwise old names remain
+ if self.predictor:
+ self.predictor.model.names = classes
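+ # Usage sketch (model filename and prompts are illustrative):
+ # from ultralytics import YOLOWorld
+ # model = YOLOWorld("yolov8s-world.pt")
+ # model.set_classes(["person", "bicycle"])  # restrict open-vocabulary detection to these prompts
+ # results = model.predict("ultralytics/assets/bus.jpg")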
+from .predict import OBBPredictor
+from .train import OBBTrainer
+from .val import OBBValidator
+__all__ = "OBBPredictor", "OBBTrainer", "OBBValidator"
@@ -0,0 +1,53 @@
+from ultralytics.models.yolo.detect.predict import DetectionPredictor
+class OBBPredictor(DetectionPredictor):
+ A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.
+ from ultralytics.models.yolo.obb import OBBPredictor
+ args = dict(model="yolov8n-obb.pt", source=ASSETS)
+ predictor = OBBPredictor(overrides=args)
+ """Initializes OBBPredictor with optional model and data configuration overrides."""
+ self.args.task = "obb"
+ nc=len(self.model.names),
+ rotated=True,
+ rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1))
+ rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)
+ # xywh, r, conf, cls
+ obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1)
+ results.append(Results(orig_img, path=img_path, names=self.model.names, obb=obb))
@@ -0,0 +1,44 @@
+from ultralytics.nn.tasks import OBBModel
+from ultralytics.utils import DEFAULT_CFG, RANK
+class OBBTrainer(yolo.detect.DetectionTrainer):
+ A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.
+ from ultralytics.models.yolo.obb import OBBTrainer
+ args = dict(model="yolov8n-obb.pt", data="dota8.yaml", epochs=3)
+ trainer = OBBTrainer(overrides=args)
+ """Initialize a OBBTrainer object with given arguments."""
+ overrides["task"] = "obb"
+ """Return OBBModel initialized with specified config and weights."""
+ model = OBBModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
+ """Return an instance of OBBValidator for validation of YOLO model."""
+ return yolo.obb.OBBValidator(
@@ -0,0 +1,203 @@
+from ultralytics.utils.metrics import OBBMetrics, batch_probiou
+from ultralytics.utils.plotting import output_to_rotated_target, plot_images
+class OBBValidator(DetectionValidator):
+ A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.
+ from ultralytics.models.yolo.obb import OBBValidator
+ args = dict(model="yolov8n-obb.pt", data="dota8.yaml")
+ validator = OBBValidator(args=args)
+ validator(model=args["model"])
+ """Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics."""
+ self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True, on_plot=self.on_plot)
+ super().init_metrics(model)
+ self.is_dota = isinstance(val, str) and "DOTA" in val # is DOTA
+ nc=self.nc,
+ Perform computation of the correct prediction matrix for a batch of detections and ground truth bounding boxes.
+ detections (torch.Tensor): A tensor of shape (N, 7) representing the detected bounding boxes and associated
+ data. Each detection is represented as (x1, y1, x2, y2, conf, class, angle).
+ gt_bboxes (torch.Tensor): A tensor of shape (M, 5) representing the ground truth bounding boxes. Each box is
+ represented as (x1, y1, x2, y2, angle).
+ gt_cls (torch.Tensor): A tensor of shape (M,) representing class labels for the ground truth bounding boxes.
+ (torch.Tensor): The correct prediction matrix with shape (N, 10), which includes 10 IoU (Intersection over
+ Union) levels for each detection, indicating the accuracy of predictions compared to the ground truth.
+ detections = torch.rand(100, 7) # 100 sample detections
+ gt_bboxes = torch.rand(50, 5) # 50 sample ground truth boxes
+ gt_cls = torch.randint(0, 5, (50,)) # 50 ground truth class labels
+ correct_matrix = OBBValidator._process_batch(detections, gt_bboxes, gt_cls)
+ This method relies on `batch_probiou` to calculate IoU between detections and ground truth bounding boxes.
+ iou = batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
+ """Prepares and returns a batch for OBB validation."""
+ bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]) # target boxes
+ ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True) # native-space labels
+ """Prepares and returns a batch for OBB validation with scaled and padded bounding boxes."""
+ pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
+ *output_to_rotated_target(preds, max_det=self.args.max_det),
+ rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
+ poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
+ for i, (r, b) in enumerate(zip(rbox.tolist(), poly.tolist())):
+ "category_id": self.class_map[int(predn[i, 5].item())],
+ "score": round(predn[i, 4].item(), 5),
+ "rbox": [round(x, 3) for x in r],
+ "poly": [round(x, 3) for x in b],
+ import numpy as np
+ rboxes = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
+ obb = torch.cat([rboxes, predn[:, 4:6]], dim=-1)
+ obb=obb,
+ if self.args.save_json and self.is_dota and len(self.jdict):
+ import json
+ import re
+ from collections import defaultdict
+ pred_txt = self.save_dir / "predictions_txt" # predictions
+ pred_txt.mkdir(parents=True, exist_ok=True)
+ data = json.load(open(pred_json))
+ # Save split results
+ LOGGER.info(f"Saving predictions with DOTA format to {pred_txt}...")
+ for d in data:
+ image_id = d["image_id"]
+ score = d["score"]
+ classname = self.names[d["category_id"] - 1].replace(" ", "-")
+ p = d["poly"]
+ with open(f"{pred_txt / f'Task1_{classname}'}.txt", "a") as f:
+ f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")
+ # Save merged results. This could give a slightly lower mAP than the official merging script
+ # because of the probiou calculation.
+ pred_merged_txt = self.save_dir / "predictions_merged_txt" # predictions
+ pred_merged_txt.mkdir(parents=True, exist_ok=True)
+ merged_results = defaultdict(list)
+ LOGGER.info(f"Saving merged predictions with DOTA format to {pred_merged_txt}...")
+ image_id = d["image_id"].split("__")[0]
+ pattern = re.compile(r"\d+___\d+")
+ x, y = (int(c) for c in re.findall(pattern, d["image_id"])[0].split("___"))
+ bbox, score, cls = d["rbox"], d["score"], d["category_id"] - 1
+ bbox[0] += x
+ bbox[1] += y
+ bbox.extend([score, cls])
+ merged_results[image_id].append(bbox)
+ for image_id, bbox in merged_results.items():
+ bbox = torch.tensor(bbox)
+ max_wh = torch.max(bbox[:, :2]).item() * 2
+ c = bbox[:, 6:7] * max_wh # classes
+ scores = bbox[:, 5] # scores
+ b = bbox[:, :5].clone()
+ b[:, :2] += c
+ # An IoU threshold of 0.3 gives results close to (or slightly better than) the official merging script.
+ i = ops.nms_rotated(b, scores, 0.3)
+ bbox = bbox[i]
+ b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8)
+ for x in torch.cat([b, bbox[:, 5:7]], dim=-1).tolist():
+ classname = self.names[int(x[-1])].replace(" ", "-")
+ p = [round(i, 3) for i in x[:-2]] # poly
+ score = round(x[-2], 3)
+ with open(f"{pred_merged_txt / f'Task1_{classname}'}.txt", "a") as f:
+from .predict import PosePredictor
+from .train import PoseTrainer
+from .val import PoseValidator
+__all__ = "PoseTrainer", "PoseValidator", "PosePredictor"
@@ -0,0 +1,56 @@
+from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
+class PosePredictor(DetectionPredictor):
+ A class extending the DetectionPredictor class for prediction based on a pose model.
+ from ultralytics.models.yolo.pose import PosePredictor
+ args = dict(model="yolov8n-pose.pt", source=ASSETS)
+ predictor = PosePredictor(overrides=args)
+ """Initializes PosePredictor, sets task to 'pose' and logs a warning for using 'mps' as device."""
+ self.args.task = "pose"
+ if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
+ "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
+ "See https://github.com/ultralytics/ultralytics/issues/4031."
+ """Return detection results for a given input image or list of images."""
+ pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape).round()
+ pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
+ pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
+ results.append(
+ Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts)
@@ -0,0 +1,79 @@
+from ultralytics.nn.tasks import PoseModel
+from ultralytics.utils import DEFAULT_CFG, LOGGER
+class PoseTrainer(yolo.detect.DetectionTrainer):
+ A class extending the DetectionTrainer class for training based on a pose model.
+ from ultralytics.models.yolo.pose import PoseTrainer
+ args = dict(model="yolov8n-pose.pt", data="coco8-pose.yaml", epochs=3)
+ trainer = PoseTrainer(overrides=args)
+ """Initialize a PoseTrainer object with specified configurations and overrides."""
+ overrides["task"] = "pose"
+ """Get pose estimation model with specified configuration and weights."""
+ model = PoseModel(cfg, ch=3, nc=self.data["nc"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose)
+ """Sets keypoints shape attribute of PoseModel."""
+ super().set_model_attributes()
+ self.model.kpt_shape = self.data["kpt_shape"]
+ """Returns an instance of the PoseValidator class for validation."""
+ self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
+ return yolo.pose.PoseValidator(
+ """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
+ images = batch["img"]
+ kpts = batch["keypoints"]
+ cls = batch["cls"].squeeze(-1)
+ bboxes = batch["bboxes"]
+ paths = batch["im_file"]
+ images,
+ batch_idx,
+ cls,
+ bboxes,
+ kpts=kpts,
+ paths=paths,
+ """Plots training/val metrics."""
+ plot_results(file=self.csv, pose=True, on_plot=self.on_plot) # save results.png
@@ -0,0 +1,282 @@
+from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, box_iou, kpt_iou
+class PoseValidator(DetectionValidator):
+ A class extending the DetectionValidator class for validation based on a pose model.
+ from ultralytics.models.yolo.pose import PoseValidator
+ args = dict(model="yolov8n-pose.pt", data="coco8-pose.yaml")
+ validator = PoseValidator(args=args)
+ """Initialize a 'PoseValidator' object with custom parameters and assigned attributes."""
+ self.sigma = None
+ self.kpt_shape = None
+ self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
+ """Preprocesses the batch by converting the 'keypoints' data into a float and moving it to the device."""
+ batch = super().preprocess(batch)
+ batch["keypoints"] = batch["keypoints"].to(self.device).float()
+ """Returns description of evaluation metrics in string format."""
+ return ("%22s" + "%11s" * 10) % (
+ "Class",
+ "Images",
+ "Box(P",
+ "R",
+ "mAP50",
+ "mAP50-95)",
+ "Pose(P",
+ """Apply non-maximum suppression and return detections with high confidence scores."""
+ """Initiate pose estimation metrics for YOLO model."""
+ self.kpt_shape = self.data["kpt_shape"]
+ is_pose = self.kpt_shape == [17, 3]
+ nkpt = self.kpt_shape[0]
+ self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
+ self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
+ """Prepares a batch for processing by converting keypoints to float and moving to device."""
+ pbatch = super()._prepare_batch(si, batch)
+ kpts = batch["keypoints"][batch["batch_idx"] == si]
+ h, w = pbatch["imgsz"]
+ kpts = kpts.clone()
+ kpts[..., 0] *= w
+ kpts[..., 1] *= h
+ kpts = ops.scale_coords(pbatch["imgsz"], kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
+ pbatch["kpts"] = kpts
+ return pbatch
+ """Prepares and scales keypoints in a batch for pose processing."""
+ predn = super()._prepare_pred(pred, pbatch)
+ nk = pbatch["kpts"].shape[1]
+ pred_kpts = predn[:, 6:].view(len(predn), nk, -1)
+ ops.scale_coords(pbatch["imgsz"], pred_kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
+ return predn, pred_kpts
+ tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+ predn, pred_kpts = self._prepare_pred(pred, pbatch)
+ stat["tp_p"] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch["kpts"])
+ pred_kpts,
+ def _process_batch(self, detections, gt_bboxes, gt_cls, pred_kpts=None, gt_kpts=None):
+ Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
+ detections (torch.Tensor): Tensor with shape (N, 6) representing detection boxes and scores, where each
+ detection is of the format (x1, y1, x2, y2, conf, class).
+ gt_bboxes (torch.Tensor): Tensor with shape (M, 4) representing ground truth bounding boxes, where each
+ box is of the format (x1, y1, x2, y2).
+ gt_cls (torch.Tensor): Tensor with shape (M,) representing ground truth class indices.
+ pred_kpts (torch.Tensor | None): Optional tensor with shape (N, 51) representing predicted keypoints, where
+ 51 corresponds to 17 keypoints each having 3 values.
+ gt_kpts (torch.Tensor | None): Optional tensor with shape (M, 51) representing ground truth keypoints.
+ torch.Tensor: A tensor with shape (N, 10) representing the correct prediction matrix for 10 IoU levels,
+ where N is the number of detections.
+ detections = torch.rand(100, 6) # 100 predictions: (x1, y1, x2, y2, conf, class)
+ gt_bboxes = torch.rand(50, 4) # 50 ground truth boxes: (x1, y1, x2, y2)
+ gt_cls = torch.randint(0, 2, (50,)) # 50 ground truth class indices
+ pred_kpts = torch.rand(100, 51) # 100 predicted keypoints
+ gt_kpts = torch.rand(50, 51) # 50 ground truth keypoints
+ correct_preds = _process_batch(detections, gt_bboxes, gt_cls, pred_kpts, gt_kpts)
+ The `0.53` scale factor used in the area computation is referenced from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
+ if pred_kpts is not None and gt_kpts is not None:
+ # `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
+ area = ops.xyxy2xywh(gt_bboxes)[:, 2:].prod(1) * 0.53
+ iou = kpt_iou(gt_kpts, pred_kpts, sigma=self.sigma, area=area)
+ else: # boxes
+ """Plots and saves validation set samples with predicted bounding boxes and keypoints."""
+ kpts=batch["keypoints"],
+ """Plots predictions for YOLO model."""
+ pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
+ kpts=pred_kpts,
+ def save_one_txt(self, predn, pred_kpts, save_conf, shape, file):
+ keypoints=pred_kpts,
+ """Converts YOLO predictions to COCO JSON format."""
+ "keypoints": p[6:],
+ """Evaluates object detection model using COCO JSON format."""
+ if self.args.save_json and self.is_coco and len(self.jdict):
+ anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json" # annotations
+ LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
+ check_requirements("pycocotools>=2.0.6")
+ for x in anno_json, pred_json:
+ for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "keypoints")]):
+ eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval
+ eval.evaluate()
+ eval.accumulate()
+ eval.summarize()
+ idx = i * 4 + 2
+ stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50
+ LOGGER.warning(f"pycocotools unable to run: {e}")
+from .predict import SegmentationPredictor
+from .train import SegmentationTrainer
+from .val import SegmentationValidator
+__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator"
+class SegmentationPredictor(DetectionPredictor):
+ A class extending the DetectionPredictor class for prediction based on a segmentation model.
+ from ultralytics.models.yolo.segment import SegmentationPredictor
+ args = dict(model="yolov8n-seg.pt", source=ASSETS)
+ predictor = SegmentationPredictor(overrides=args)
+ """Initializes the SegmentationPredictor with the provided configuration, overrides, and callbacks."""
+ """Applies non-max suppression and processes detections for each image in an input batch."""
+ p = ops.non_max_suppression(
+ preds[0],
+ proto = preds[1][-1] if isinstance(preds[1], tuple) else preds[1] # tuple if PyTorch model or array if exported
+ for i, (pred, orig_img, img_path) in enumerate(zip(p, orig_imgs, self.batch[0])):
+ if not len(pred): # save empty boxes
+ masks = None
+ elif self.args.retina_masks:
+ masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
+ else:
+ masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
+ results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
@@ -0,0 +1,62 @@
+from ultralytics.nn.tasks import SegmentationModel
+class SegmentationTrainer(yolo.detect.DetectionTrainer):
+ A class extending the DetectionTrainer class for training based on a segmentation model.
+ from ultralytics.models.yolo.segment import SegmentationTrainer
+ args = dict(model="yolov8n-seg.pt", data="coco8-seg.yaml", epochs=3)
+ trainer = SegmentationTrainer(overrides=args)
+ """Initialize a SegmentationTrainer object with given arguments."""
+ overrides["task"] = "segment"
+ """Return SegmentationModel initialized with specified config and weights."""
+ model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
+ """Return an instance of SegmentationValidator for validation of YOLO model."""
+ self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss"
+ return yolo.segment.SegmentationValidator(
+ """Creates a plot of training sample images with labels and box coordinates."""
+ masks=batch["masks"],
+ plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png
@@ -0,0 +1,318 @@
+from multiprocessing.pool import ThreadPool
+from ultralytics.utils import LOGGER, NUM_THREADS, ops
+from ultralytics.utils.metrics import SegmentMetrics, box_iou, mask_iou
+class SegmentationValidator(DetectionValidator):
+ A class extending the DetectionValidator class for validation based on a segmentation model.
+ from ultralytics.models.yolo.segment import SegmentationValidator
+ args = dict(model="yolov8n-seg.pt", data="coco8-seg.yaml")
+ validator = SegmentationValidator(args=args)
+ """Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics."""
+ self.plot_masks = None
+ self.process = None
+ """Preprocesses batch by converting masks to float and sending to device."""
+ batch["masks"] = batch["masks"].to(self.device).float()
+ """Initialize metrics and select mask processing function based on save_json flag."""
+ self.plot_masks = []
+ # process_mask_native is more accurate but slower than process_mask
+ self.process = ops.process_mask_native if self.args.save_json or self.args.save_txt else ops.process_mask
+ self.stats = dict(tp_m=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
+ """Return a formatted description of evaluation metrics."""
+ "Mask(P",
+ """Post-processes YOLO predictions and returns output detections with proto."""
+ proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
+ return p, proto
+ """Prepares a batch for training or inference by processing images and targets."""
+ prepared_batch = super()._prepare_batch(si, batch)
+ midx = [si] if self.args.overlap_mask else batch["batch_idx"] == si
+ prepared_batch["masks"] = batch["masks"][midx]
+ return prepared_batch
+ def _prepare_pred(self, pred, pbatch, proto):
+ pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch["imgsz"])
+ return predn, pred_masks
+ for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
+ tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+ # Masks
+ gt_masks = pbatch.pop("masks")
+ predn, pred_masks = self._prepare_pred(pred, pbatch, proto)
+ stat["tp_m"] = self._process_batch(
+ predn, bbox, cls, pred_masks, gt_masks, self.args.overlap_mask, masks=True
+ pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
+ if self.args.plots and self.batch_i < 3:
+ self.plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot
+ self.pred_to_json(
+ batch["im_file"][si],
+ ops.scale_image(
+ pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
+ ratio_pad=batch["ratio_pad"][si],
+ pred_masks,
+ """Sets speed and confusion matrix for evaluation metrics."""
+ def _process_batch(self, detections, gt_bboxes, gt_cls, pred_masks=None, gt_masks=None, overlap=False, masks=False):
+ Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
+ detections (torch.Tensor): Tensor of shape (N, 6) representing detected bounding boxes and
+ associated confidence scores and class indices. Each row is of the format [x1, y1, x2, y2, conf, class].
+ gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground truth bounding box coordinates.
+ Each row is of the format [x1, y1, x2, y2].
+ gt_cls (torch.Tensor): Tensor of shape (M,) representing ground truth class indices.
+ pred_masks (torch.Tensor | None): Tensor representing predicted masks, if available. The shape should
+ match the ground truth masks.
+ gt_masks (torch.Tensor | None): Tensor of shape (M, H, W) representing ground truth masks, if available.
+ overlap (bool): Flag indicating if overlapping masks should be considered.
+ masks (bool): Flag indicating if the batch contains mask data.
+ (torch.Tensor): A correct prediction matrix of shape (N, 10), where 10 represents different IoU levels.
+ - If `masks` is True, the function computes IoU between predicted and ground truth masks.
+ - If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.
+ detections = torch.tensor([[25, 30, 200, 300, 0.8, 1], [50, 60, 180, 290, 0.75, 0]])
+ gt_bboxes = torch.tensor([[24, 29, 199, 299], [55, 65, 185, 295]])
+ gt_cls = torch.tensor([1, 0])
+ correct_preds = validator._process_batch(detections, gt_bboxes, gt_cls)
+ if masks:
+ if overlap:
+ nl = len(gt_cls)
+ index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
+ gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640)
+ gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
+ if gt_masks.shape[1:] != pred_masks.shape[1:]:
+ gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0]
+ gt_masks = gt_masks.gt_(0.5)
+ iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
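+ # NOTE (illustrative sketch, not part of the original file): with overlap_mask=True the ground truth
+ # arrives as a single index-encoded mask in which pixel value k belongs to instance k. The lines above
+ # expand it into one binary mask per instance before computing mask IoU. A tiny made-up example:
+ import torch
+ from ultralytics.utils.metrics import mask_iou
+ gt = torch.zeros(1, 8, 8)
+ gt[0, :4, :4] = 1 # pixels of instance 1
+ gt[0, 4:, 4:] = 2 # pixels of instance 2
+ nl = 2
+ index = torch.arange(nl).view(nl, 1, 1) + 1
+ gt_masks = (gt.repeat(nl, 1, 1) == index).float() # (2, 8, 8) binary masks, one per instance
+ pred_masks = gt_masks.clone() # pretend the predictions are perfect
+ print(mask_iou(gt_masks.view(nl, -1), pred_masks.view(nl, -1))) # 1.0 on the diagonal, 0.0 elsewhere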
+ """Plots validation samples with bounding box labels."""
+ """Plots batch predictions with masks and bounding boxes."""
+ *output_to_target(preds[0], max_det=15), # not set to self.args.max_det due to slow plotting speed
+ torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
+ self.plot_masks.clear()
+ def save_one_txt(self, predn, pred_masks, save_conf, shape, file):
+ masks=pred_masks,
+ def pred_to_json(self, predn, filename, pred_masks):
+ Save one JSON result.
+ >>> result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
+ from pycocotools.mask import encode # noqa
+ def single_encode(x):
+ """Encode predicted masks as RLE and append results to jdict."""
+ rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0]
+ rle["counts"] = rle["counts"].decode("utf-8")
+ return rle
+ pred_masks = np.transpose(pred_masks, (2, 0, 1))
+ with ThreadPool(NUM_THREADS) as pool:
+ rles = pool.map(single_encode, pred_masks)
+ for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())):
+ "segmentation": rles[i],
+ """Return COCO-style object detection evaluation metrics."""
+ anno_json = self.data["path"] / "annotations/instances_val2017.json" # annotations
+ for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "segm")]):
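+ # NOTE (illustrative sketch, not part of the original file): single_encode in pred_to_json above stores
+ # each mask as a COCO run-length encoding. A hedged round trip with pycocotools (assumed installed):
+ import numpy as np
+ from pycocotools.mask import decode, encode
+ mask = np.zeros((32, 32), dtype=np.uint8)
+ mask[8:24, 8:24] = 1
+ rle = encode(np.asarray(mask[:, :, None], order="F", dtype="uint8"))[0]
+ restored = decode(rle) # (32, 32) binary mask identical to the input
+ rle["counts"] = rle["counts"].decode("utf-8") # make the RLE JSON-serializable, as single_encode does
+ assert (restored == mask).all()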
@@ -0,0 +1,5 @@
+from .train import WorldTrainer
+__all__ = ["WorldTrainer"]
@@ -0,0 +1,92 @@
+from ultralytics.data import build_yolo_dataset
+from ultralytics.nn.tasks import WorldModel
+from ultralytics.utils import DEFAULT_CFG, RANK, checks
+from ultralytics.utils.torch_utils import de_parallel
+def on_pretrain_routine_end(trainer):
+ """Callback."""
+ if RANK in {-1, 0}:
+ # NOTE: for evaluation
+ names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
+ de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
+ device = next(trainer.model.parameters()).device
+ trainer.text_model, _ = trainer.clip.load("ViT-B/32", device=device)
+ for p in trainer.text_model.parameters():
+ p.requires_grad_(False)
+class WorldTrainer(yolo.detect.DetectionTrainer):
+ A class to fine-tune a world model on a closed-set dataset.
+ from ultralytics.models.yolo.world import WorldTrainer
+ args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=3)
+ trainer = WorldTrainer(overrides=args)
+ """Initialize a WorldTrainer object with given arguments."""
+ # Import and assign clip
+ self.clip = clip
+ """Return WorldModel initialized with specified config and weights."""
+ # NOTE: `nc` here is the maximum number of different text samples in one image, rather than the actual `nc`.
+ # NOTE: Following the official config, `nc` is capped at 80 for now.
+ model = WorldModel(
+ cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
+ ch=3,
+ nc=min(self.data["nc"], 80),
+ verbose=verbose and RANK == -1,
+ self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end)
+ return build_yolo_dataset(
+ self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
+ """Preprocesses a batch of images for YOLOWorld training, adjusting formatting and dimensions as needed."""
+ # NOTE: add text features
+ texts = list(itertools.chain(*batch["texts"]))
+ text_token = self.clip.tokenize(texts).to(batch["img"].device)
+ txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype) # torch.float32
+ txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
+ batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
@@ -0,0 +1,109 @@
+from ultralytics.data import YOLOConcatDataset, build_grounding, build_yolo_dataset
+from ultralytics.data.utils import check_det_dataset
+from ultralytics.models.yolo.world import WorldTrainer
+from ultralytics.utils import DEFAULT_CFG
+class WorldTrainerFromScratch(WorldTrainer):
+ A class extending the WorldTrainer class for training a world model from scratch on open-set datasets.
+ from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
+ from ultralytics import YOLOWorld
+ data = dict(
+ train=dict(
+ yolo_data=["Objects365.yaml"],
+ grounding_data=[
+ dict(
+ img_path="../datasets/flickr30k/images",
+ json_file="../datasets/flickr30k/final_flickr_separateGT_train.json",
+ img_path="../datasets/GQA/images",
+ json_file="../datasets/GQA/final_mixed_train_no_coco.json",
+ val=dict(yolo_data=["lvis.yaml"]),
+ model = YOLOWorld("yolov8s-worldv2.yaml")
+ model.train(data=data, trainer=WorldTrainerFromScratch)
+ img_path (List[str] | str): Path to the folder containing images.
+ dataset = [
+ build_yolo_dataset(self.args, im_path, batch, self.data, stride=gs, multi_modal=True)
+ if isinstance(im_path, str)
+ else build_grounding(self.args, im_path["img_path"], im_path["json_file"], batch, stride=gs)
+ for im_path in img_path
+ return YOLOConcatDataset(dataset) if len(dataset) > 1 else dataset[0]
+ def get_dataset(self):
+ Get train and val paths from the data dict if it exists.
+ Returns None if the data format is not recognized.
+ final_data = {}
+ data_yaml = self.args.data
+ assert data_yaml.get("train", False), "train dataset not found" # object365.yaml
+ assert data_yaml.get("val", False), "validation dataset not found" # lvis.yaml
+ data = {k: [check_det_dataset(d) for d in v.get("yolo_data", [])] for k, v in data_yaml.items()}
+ assert len(data["val"]) == 1, f"Only support validating on 1 dataset for now, but got {len(data['val'])}."
+ val_split = "minival" if "lvis" in data["val"][0]["val"] else "val"
+ for d in data["val"]:
+ if d.get("minival") is None: # for lvis dataset
+ d["minival"] = str(d["path"] / d["minival"])
+ for s in ["train", "val"]:
+ final_data[s] = [d["train" if s == "train" else val_split] for d in data[s]]
+ # Add grounding data if present
+ grounding_data = data_yaml[s].get("grounding_data")
+ if grounding_data is None:
+ continue
+ grounding_data = grounding_data if isinstance(grounding_data, list) else [grounding_data]
+ for g in grounding_data:
+ assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}"
+ final_data[s] += grounding_data
+ # NOTE: to make training work properly, set `nc` and `names`
+ final_data["nc"] = data["val"][0]["nc"]
+ final_data["names"] = data["val"][0]["names"]
+ self.data = final_data
+ return final_data["train"], final_data["val"][0]
+ """DO NOT plot labels."""
+ pass
+ """Performs final evaluation and validation for object detection YOLO-World model."""
+ val = self.args.data["val"]["yolo_data"][0]
+ self.validator.args.data = val
+ self.validator.args.split = "minival" if isinstance(val, str) and "lvis" in val else "val"
+ return super().final_eval()