فهرست منبع

Add files via upload

Mengqi Lei 2 ماه پیش
والد
کامیت
63f2ab2675
57فایلهای تغییر یافته به همراه11572 افزوده شده و 0 حذف شده
  1. 9 0
      ultralytics/models/__init__.py
  2. 7 0
      ultralytics/models/fastsam/__init__.py
  3. 55 0
      ultralytics/models/fastsam/model.py
  4. 150 0
      ultralytics/models/fastsam/predict.py
  5. 24 0
      ultralytics/models/fastsam/utils.py
  6. 40 0
      ultralytics/models/fastsam/val.py
  7. 7 0
      ultralytics/models/nas/__init__.py
  8. 94 0
      ultralytics/models/nas/model.py
  9. 57 0
      ultralytics/models/nas/predict.py
  10. 50 0
      ultralytics/models/nas/val.py
  11. 7 0
      ultralytics/models/rtdetr/__init__.py
  12. 54 0
      ultralytics/models/rtdetr/model.py
  13. 84 0
      ultralytics/models/rtdetr/predict.py
  14. 105 0
      ultralytics/models/rtdetr/train.py
  15. 135 0
      ultralytics/models/rtdetr/val.py
  16. 6 0
      ultralytics/models/sam/__init__.py
  17. 193 0
      ultralytics/models/sam/amg.py
  18. 358 0
      ultralytics/models/sam/build.py
  19. 175 0
      ultralytics/models/sam/model.py
  20. 1 0
      ultralytics/models/sam/modules/__init__.py
  21. 1129 0
      ultralytics/models/sam/modules/blocks.py
  22. 518 0
      ultralytics/models/sam/modules/decoders.py
  23. 794 0
      ultralytics/models/sam/modules/encoders.py
  24. 237 0
      ultralytics/models/sam/modules/memory_attention.py
  25. 1013 0
      ultralytics/models/sam/modules/sam.py
  26. 1013 0
      ultralytics/models/sam/modules/tiny_encoder.py
  27. 373 0
      ultralytics/models/sam/modules/transformer.py
  28. 293 0
      ultralytics/models/sam/modules/utils.py
  29. 1605 0
      ultralytics/models/sam/predict.py
  30. 1 0
      ultralytics/models/utils/__init__.py
  31. 357 0
      ultralytics/models/utils/loss.py
  32. 259 0
      ultralytics/models/utils/ops.py
  33. 7 0
      ultralytics/models/yolo/__init__.py
  34. 7 0
      ultralytics/models/yolo/classify/__init__.py
  35. 60 0
      ultralytics/models/yolo/classify/predict.py
  36. 153 0
      ultralytics/models/yolo/classify/train.py
  37. 117 0
      ultralytics/models/yolo/classify/val.py
  38. 7 0
      ultralytics/models/yolo/detect/__init__.py
  39. 41 0
      ultralytics/models/yolo/detect/predict.py
  40. 150 0
      ultralytics/models/yolo/detect/train.py
  41. 337 0
      ultralytics/models/yolo/detect/val.py
  42. 111 0
      ultralytics/models/yolo/model.py
  43. 7 0
      ultralytics/models/yolo/obb/__init__.py
  44. 53 0
      ultralytics/models/yolo/obb/predict.py
  45. 44 0
      ultralytics/models/yolo/obb/train.py
  46. 203 0
      ultralytics/models/yolo/obb/val.py
  47. 7 0
      ultralytics/models/yolo/pose/__init__.py
  48. 56 0
      ultralytics/models/yolo/pose/predict.py
  49. 79 0
      ultralytics/models/yolo/pose/train.py
  50. 282 0
      ultralytics/models/yolo/pose/val.py
  51. 7 0
      ultralytics/models/yolo/segment/__init__.py
  52. 55 0
      ultralytics/models/yolo/segment/predict.py
  53. 62 0
      ultralytics/models/yolo/segment/train.py
  54. 318 0
      ultralytics/models/yolo/segment/val.py
  55. 5 0
      ultralytics/models/yolo/world/__init__.py
  56. 92 0
      ultralytics/models/yolo/world/train.py
  57. 109 0
      ultralytics/models/yolo/world/train_world.py

+ 9 - 0
ultralytics/models/__init__.py

@@ -0,0 +1,9 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from .fastsam import FastSAM
+from .nas import NAS
+from .rtdetr import RTDETR
+from .sam import SAM
+from .yolo import YOLO, YOLOWorld
+
+__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld"  # allow simpler import

+ 7 - 0
ultralytics/models/fastsam/__init__.py

@@ -0,0 +1,7 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from .model import FastSAM
+from .predict import FastSAMPredictor
+from .val import FastSAMValidator
+
+__all__ = "FastSAMPredictor", "FastSAM", "FastSAMValidator"

+ 55 - 0
ultralytics/models/fastsam/model.py

@@ -0,0 +1,55 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from pathlib import Path
+
+from ultralytics.engine.model import Model
+
+from .predict import FastSAMPredictor
+from .val import FastSAMValidator
+
+
+class FastSAM(Model):
+    """
+    FastSAM model interface.
+
+    Example:
+        ```python
+        from ultralytics import FastSAM
+
+        model = FastSAM("last.pt")
+        results = model.predict("ultralytics/assets/bus.jpg")
+        ```
+    """
+
+    def __init__(self, model="FastSAM-x.pt"):
+        """Call the __init__ method of the parent class (YOLO) with the updated default model."""
+        if str(model) == "FastSAM.pt":
+            model = "FastSAM-x.pt"
+        assert Path(model).suffix not in {".yaml", ".yml"}, "FastSAM models only support pre-trained models."
+        super().__init__(model=model, task="segment")
+
+    def predict(self, source, stream=False, bboxes=None, points=None, labels=None, texts=None, **kwargs):
+        """
+        Perform segmentation prediction on image or video source.
+
+        Supports prompted segmentation with bounding boxes, points, labels, and texts.
+
+        Args:
+            source (str | PIL.Image | numpy.ndarray): Input source.
+            stream (bool): Enable real-time streaming.
+            bboxes (list): Bounding box coordinates for prompted segmentation.
+            points (list): Points for prompted segmentation.
+            labels (list): Labels for prompted segmentation.
+            texts (list): Texts for prompted segmentation.
+            **kwargs (Any): Additional keyword arguments.
+
+        Returns:
+            (list): Model predictions.
+        """
+        prompts = dict(bboxes=bboxes, points=points, labels=labels, texts=texts)
+        return super().predict(source, stream, prompts=prompts, **kwargs)
+
+    @property
+    def task_map(self):
+        """Returns a dictionary mapping segment task to corresponding predictor and validator classes."""
+        return {"segment": {"predictor": FastSAMPredictor, "validator": FastSAMValidator}}

+ 150 - 0
ultralytics/models/fastsam/predict.py

@@ -0,0 +1,150 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import torch
+from PIL import Image
+
+from ultralytics.models.yolo.segment import SegmentationPredictor
+from ultralytics.utils import DEFAULT_CFG, checks
+from ultralytics.utils.metrics import box_iou
+from ultralytics.utils.ops import scale_masks
+
+from .utils import adjust_bboxes_to_image_border
+
+
+class FastSAMPredictor(SegmentationPredictor):
+    """
+    FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks in Ultralytics
+    YOLO framework.
+
+    This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It
+    adjusts post-processing steps to incorporate mask prediction and non-max suppression while optimizing for single-
+    class segmentation.
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initializes a FastSAMPredictor for fast SAM segmentation tasks in Ultralytics YOLO framework."""
+        super().__init__(cfg, overrides, _callbacks)
+        self.prompts = {}
+
+    def postprocess(self, preds, img, orig_imgs):
+        """Applies box postprocess for FastSAM predictions."""
+        bboxes = self.prompts.pop("bboxes", None)
+        points = self.prompts.pop("points", None)
+        labels = self.prompts.pop("labels", None)
+        texts = self.prompts.pop("texts", None)
+        results = super().postprocess(preds, img, orig_imgs)
+        for result in results:
+            full_box = torch.tensor(
+                [0, 0, result.orig_shape[1], result.orig_shape[0]], device=preds[0].device, dtype=torch.float32
+            )
+            boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape)
+            idx = torch.nonzero(box_iou(full_box[None], boxes) > 0.9).flatten()
+            if idx.numel() != 0:
+                result.boxes.xyxy[idx] = full_box
+
+        return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)
+
+    def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):
+        """
+        Internal function for image segmentation inference based on cues like bounding boxes, points, and masks.
+        Leverages SAM's specialized architecture for prompt-based, real-time segmentation.
+
+        Args:
+            results (Results | List[Results]): The original inference results from FastSAM models without any prompts.
+            bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
+            points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
+            labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
+            texts (str | List[str], optional): Textual prompts, a list contains string objects.
+
+        Returns:
+            (List[Results]): The output results determined by prompts.
+        """
+        if bboxes is None and points is None and texts is None:
+            return results
+        prompt_results = []
+        if not isinstance(results, list):
+            results = [results]
+        for result in results:
+            if len(result) == 0:
+                prompt_results.append(result)
+                continue
+            masks = result.masks.data
+            if masks.shape[1:] != result.orig_shape:
+                masks = scale_masks(masks[None], result.orig_shape)[0]
+            # bboxes prompt
+            idx = torch.zeros(len(result), dtype=torch.bool, device=self.device)
+            if bboxes is not None:
+                bboxes = torch.as_tensor(bboxes, dtype=torch.int32, device=self.device)
+                bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
+                bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
+                mask_areas = torch.stack([masks[:, b[1] : b[3], b[0] : b[2]].sum(dim=(1, 2)) for b in bboxes])
+                full_mask_areas = torch.sum(masks, dim=(1, 2))
+
+                union = bbox_areas[:, None] + full_mask_areas - mask_areas
+                idx[torch.argmax(mask_areas / union, dim=1)] = True
+            if points is not None:
+                points = torch.as_tensor(points, dtype=torch.int32, device=self.device)
+                points = points[None] if points.ndim == 1 else points
+                if labels is None:
+                    labels = torch.ones(points.shape[0])
+                labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
+                assert len(labels) == len(points), (
+                    f"Excepted `labels` got same size as `point`, but got {len(labels)} and {len(points)}"
+                )
+                point_idx = (
+                    torch.ones(len(result), dtype=torch.bool, device=self.device)
+                    if labels.sum() == 0  # all negative points
+                    else torch.zeros(len(result), dtype=torch.bool, device=self.device)
+                )
+                for point, label in zip(points, labels):
+                    point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = bool(label)
+                idx |= point_idx
+            if texts is not None:
+                if isinstance(texts, str):
+                    texts = [texts]
+                crop_ims, filter_idx = [], []
+                for i, b in enumerate(result.boxes.xyxy.tolist()):
+                    x1, y1, x2, y2 = (int(x) for x in b)
+                    if masks[i].sum() <= 100:
+                        filter_idx.append(i)
+                        continue
+                    crop_ims.append(Image.fromarray(result.orig_img[y1:y2, x1:x2, ::-1]))
+                similarity = self._clip_inference(crop_ims, texts)
+                text_idx = torch.argmax(similarity, dim=-1)  # (M, )
+                if len(filter_idx):
+                    text_idx += (torch.tensor(filter_idx, device=self.device)[None] <= int(text_idx)).sum(0)
+                idx[text_idx] = True
+
+            prompt_results.append(result[idx])
+
+        return prompt_results
+
+    def _clip_inference(self, images, texts):
+        """
+        CLIP Inference process.
+
+        Args:
+            images (List[PIL.Image]): A list of source images and each of them should be PIL.Image type with RGB channel order.
+            texts (List[str]): A list of prompt texts and each of them should be string object.
+
+        Returns:
+            (torch.Tensor): The similarity between given images and texts.
+        """
+        try:
+            import clip
+        except ImportError:
+            checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
+            import clip
+        if (not hasattr(self, "clip_model")) or (not hasattr(self, "clip_preprocess")):
+            self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)
+        images = torch.stack([self.clip_preprocess(image).to(self.device) for image in images])
+        tokenized_text = clip.tokenize(texts).to(self.device)
+        image_features = self.clip_model.encode_image(images)
+        text_features = self.clip_model.encode_text(tokenized_text)
+        image_features /= image_features.norm(dim=-1, keepdim=True)  # (N, 512)
+        text_features /= text_features.norm(dim=-1, keepdim=True)  # (M, 512)
+        return (image_features * text_features[:, None]).sum(-1)  # (M, N)
+
+    def set_prompts(self, prompts):
+        """Set prompts in advance."""
+        self.prompts = prompts

+ 24 - 0
ultralytics/models/fastsam/utils.py

@@ -0,0 +1,24 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+
+def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
+    """
+    Adjust bounding boxes to stick to image border if they are within a certain threshold.
+
+    Args:
+        boxes (torch.Tensor): (n, 4)
+        image_shape (tuple): (height, width)
+        threshold (int): pixel threshold
+
+    Returns:
+        adjusted_boxes (torch.Tensor): adjusted bounding boxes
+    """
+    # Image dimensions
+    h, w = image_shape
+
+    # Adjust boxes
+    boxes[boxes[:, 0] < threshold, 0] = 0  # x1
+    boxes[boxes[:, 1] < threshold, 1] = 0  # y1
+    boxes[boxes[:, 2] > w - threshold, 2] = w  # x2
+    boxes[boxes[:, 3] > h - threshold, 3] = h  # y2
+    return boxes

+ 40 - 0
ultralytics/models/fastsam/val.py

@@ -0,0 +1,40 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from ultralytics.models.yolo.segment import SegmentationValidator
+from ultralytics.utils.metrics import SegmentMetrics
+
+
+class FastSAMValidator(SegmentationValidator):
+    """
+    Custom validation class for fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework.
+
+    Extends the SegmentationValidator class, customizing the validation process specifically for fast SAM. This class
+    sets the task to 'segment' and uses the SegmentMetrics for evaluation. Additionally, plotting features are disabled
+    to avoid errors during validation.
+
+    Attributes:
+        dataloader: The data loader object used for validation.
+        save_dir (str): The directory where validation results will be saved.
+        pbar: A progress bar object.
+        args: Additional arguments for customization.
+        _callbacks: List of callback functions to be invoked during validation.
+    """
+
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
+        """
+        Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics.
+
+        Args:
+            dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
+            save_dir (Path, optional): Directory to save results.
+            pbar (tqdm.tqdm): Progress bar for displaying progress.
+            args (SimpleNamespace): Configuration for the validator.
+            _callbacks (dict): Dictionary to store various callback functions.
+
+        Notes:
+            Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors.
+        """
+        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
+        self.args.task = "segment"
+        self.args.plots = False  # disable ConfusionMatrix and other plots to avoid errors
+        self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)

+ 7 - 0
ultralytics/models/nas/__init__.py

@@ -0,0 +1,7 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from .model import NAS
+from .predict import NASPredictor
+from .val import NASValidator
+
+__all__ = "NASPredictor", "NASValidator", "NAS"

+ 94 - 0
ultralytics/models/nas/model.py

@@ -0,0 +1,94 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+"""
+YOLO-NAS model interface.
+
+Example:
+    ```python
+    from ultralytics import NAS
+
+    model = NAS("yolo_nas_s")
+    results = model.predict("ultralytics/assets/bus.jpg")
+    ```
+"""
+
+from pathlib import Path
+
+import torch
+
+from ultralytics.engine.model import Model
+from ultralytics.utils.downloads import attempt_download_asset
+from ultralytics.utils.torch_utils import model_info
+
+from .predict import NASPredictor
+from .val import NASValidator
+
+
+class NAS(Model):
+    """
+    YOLO NAS model for object detection.
+
+    This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine.
+    It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.
+
+    Example:
+        ```python
+        from ultralytics import NAS
+
+        model = NAS("yolo_nas_s")
+        results = model.predict("ultralytics/assets/bus.jpg")
+        ```
+
+    Attributes:
+        model (str): Path to the pre-trained model or model name. Defaults to 'yolo_nas_s.pt'.
+
+    Note:
+        YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.
+    """
+
+    def __init__(self, model="yolo_nas_s.pt") -> None:
+        """Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model."""
+        assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models."
+        super().__init__(model, task="detect")
+
+    def _load(self, weights: str, task=None) -> None:
+        """Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided."""
+        import super_gradients
+
+        suffix = Path(weights).suffix
+        if suffix == ".pt":
+            self.model = torch.load(attempt_download_asset(weights))
+
+        elif suffix == "":
+            self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")
+
+        # Override the forward method to ignore additional arguments
+        def new_forward(x, *args, **kwargs):
+            """Ignore additional __call__ arguments."""
+            return self.model._original_forward(x)
+
+        self.model._original_forward = self.model.forward
+        self.model.forward = new_forward
+
+        # Standardize model
+        self.model.fuse = lambda verbose=True: self.model
+        self.model.stride = torch.tensor([32])
+        self.model.names = dict(enumerate(self.model._class_names))
+        self.model.is_fused = lambda: False  # for info()
+        self.model.yaml = {}  # for info()
+        self.model.pt_path = weights  # for export()
+        self.model.task = "detect"  # for export()
+
+    def info(self, detailed=False, verbose=True):
+        """
+        Logs model info.
+
+        Args:
+            detailed (bool): Show detailed information about model.
+            verbose (bool): Controls verbosity.
+        """
+        return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
+
+    @property
+    def task_map(self):
+        """Returns a dictionary mapping tasks to respective predictor and validator classes."""
+        return {"detect": {"predictor": NASPredictor, "validator": NASValidator}}

+ 57 - 0
ultralytics/models/nas/predict.py

@@ -0,0 +1,57 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import torch
+
+from ultralytics.engine.predictor import BasePredictor
+from ultralytics.engine.results import Results
+from ultralytics.utils import ops
+
+
+class NASPredictor(BasePredictor):
+    """
+    Ultralytics YOLO NAS Predictor for object detection.
+
+    This class extends the `BasePredictor` from Ultralytics engine and is responsible for post-processing the
+    raw predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and
+    scaling the bounding boxes to fit the original image dimensions.
+
+    Attributes:
+        args (Namespace): Namespace containing various configurations for post-processing.
+
+    Example:
+        ```python
+        from ultralytics import NAS
+
+        model = NAS("yolo_nas_s")
+        predictor = model.predictor
+        # Assumes that raw_preds, img, orig_imgs are available
+        results = predictor.postprocess(raw_preds, img, orig_imgs)
+        ```
+
+    Note:
+        Typically, this class is not instantiated directly. It is used internally within the `NAS` class.
+    """
+
+    def postprocess(self, preds_in, img, orig_imgs):
+        """Postprocess predictions and returns a list of Results objects."""
+        # Cat boxes and class scores
+        boxes = ops.xyxy2xywh(preds_in[0][0])
+        preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
+
+        preds = ops.non_max_suppression(
+            preds,
+            self.args.conf,
+            self.args.iou,
+            agnostic=self.args.agnostic_nms,
+            max_det=self.args.max_det,
+            classes=self.args.classes,
+        )
+
+        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
+            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+
+        results = []
+        for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
+            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
+        return results

+ 50 - 0
ultralytics/models/nas/val.py

@@ -0,0 +1,50 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import torch
+
+from ultralytics.models.yolo.detect import DetectionValidator
+from ultralytics.utils import ops
+
+__all__ = ["NASValidator"]
+
+
+class NASValidator(DetectionValidator):
+    """
+    Ultralytics YOLO NAS Validator for object detection.
+
+    Extends `DetectionValidator` from the Ultralytics models package and is designed to post-process the raw predictions
+    generated by YOLO NAS models. It performs non-maximum suppression to remove overlapping and low-confidence boxes,
+    ultimately producing the final detections.
+
+    Attributes:
+        args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU.
+        lb (torch.Tensor): Optional tensor for multilabel NMS.
+
+    Example:
+        ```python
+        from ultralytics import NAS
+
+        model = NAS("yolo_nas_s")
+        validator = model.validator
+        # Assumes that raw_preds are available
+        final_preds = validator.postprocess(raw_preds)
+        ```
+
+    Note:
+        This class is generally not instantiated directly but is used internally within the `NAS` class.
+    """
+
+    def postprocess(self, preds_in):
+        """Apply Non-maximum suppression to prediction outputs."""
+        boxes = ops.xyxy2xywh(preds_in[0][0])
+        preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
+        return ops.non_max_suppression(
+            preds,
+            self.args.conf,
+            self.args.iou,
+            labels=self.lb,
+            multi_label=False,
+            agnostic=self.args.single_cls or self.args.agnostic_nms,
+            max_det=self.args.max_det,
+            max_time_img=0.5,
+        )

+ 7 - 0
ultralytics/models/rtdetr/__init__.py

@@ -0,0 +1,7 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from .model import RTDETR
+from .predict import RTDETRPredictor
+from .val import RTDETRValidator
+
+__all__ = "RTDETRPredictor", "RTDETRValidator", "RTDETR"

+ 54 - 0
ultralytics/models/rtdetr/model.py

@@ -0,0 +1,54 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+"""
+Interface for Baidu's RT-DETR, a Vision Transformer-based real-time object detector. RT-DETR offers real-time
+performance and high accuracy, excelling in accelerated backends like CUDA with TensorRT. It features an efficient
+hybrid encoder and IoU-aware query selection for enhanced detection accuracy.
+
+For more information on RT-DETR, visit: https://arxiv.org/pdf/2304.08069.pdf
+"""
+
+from ultralytics.engine.model import Model
+from ultralytics.nn.tasks import RTDETRDetectionModel
+
+from .predict import RTDETRPredictor
+from .train import RTDETRTrainer
+from .val import RTDETRValidator
+
+
+class RTDETR(Model):
+    """
+    Interface for Baidu's RT-DETR model. This Vision Transformer-based object detector provides real-time performance
+    with high accuracy. It supports efficient hybrid encoding, IoU-aware query selection, and adaptable inference speed.
+
+    Attributes:
+        model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'.
+    """
+
+    def __init__(self, model="rtdetr-l.pt") -> None:
+        """
+        Initializes the RT-DETR model with the given pre-trained model file. Supports .pt and .yaml formats.
+
+        Args:
+            model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'.
+
+        Raises:
+            NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
+        """
+        super().__init__(model=model, task="detect")
+
+    @property
+    def task_map(self) -> dict:
+        """
+        Returns a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
+
+        Returns:
+            dict: A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
+        """
+        return {
+            "detect": {
+                "predictor": RTDETRPredictor,
+                "validator": RTDETRValidator,
+                "trainer": RTDETRTrainer,
+                "model": RTDETRDetectionModel,
+            }
+        }

+ 84 - 0
ultralytics/models/rtdetr/predict.py

@@ -0,0 +1,84 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import torch
+
+from ultralytics.data.augment import LetterBox
+from ultralytics.engine.predictor import BasePredictor
+from ultralytics.engine.results import Results
+from ultralytics.utils import ops
+
+
+class RTDETRPredictor(BasePredictor):
+    """
+    RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions using
+    Baidu's RT-DETR model.
+
+    This class leverages the power of Vision Transformers to provide real-time object detection while maintaining
+    high accuracy. It supports key features like efficient hybrid encoding and IoU-aware query selection.
+
+    Example:
+        ```python
+        from ultralytics.utils import ASSETS
+        from ultralytics.models.rtdetr import RTDETRPredictor
+
+        args = dict(model="rtdetr-l.pt", source=ASSETS)
+        predictor = RTDETRPredictor(overrides=args)
+        predictor.predict_cli()
+        ```
+
+    Attributes:
+        imgsz (int): Image size for inference (must be square and scale-filled).
+        args (dict): Argument overrides for the predictor.
+    """
+
+    def postprocess(self, preds, img, orig_imgs):
+        """
+        Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.
+
+        The method filters detections based on confidence and class if specified in `self.args`.
+
+        Args:
+            preds (list): List of [predictions, extra] from the model.
+            img (torch.Tensor): Processed input images.
+            orig_imgs (list or torch.Tensor): Original, unprocessed images.
+
+        Returns:
+            (list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence scores,
+                and class labels.
+        """
+        if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
+            preds = [preds, None]
+
+        nd = preds[0].shape[-1]
+        bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
+
+        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
+            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+
+        results = []
+        for bbox, score, orig_img, img_path in zip(bboxes, scores, orig_imgs, self.batch[0]):  # (300, 4)
+            bbox = ops.xywh2xyxy(bbox)
+            max_score, cls = score.max(-1, keepdim=True)  # (300, 1)
+            idx = max_score.squeeze(-1) > self.args.conf  # (300, )
+            if self.args.classes is not None:
+                idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
+            pred = torch.cat([bbox, max_score, cls], dim=-1)[idx]  # filter
+            oh, ow = orig_img.shape[:2]
+            pred[..., [0, 2]] *= ow
+            pred[..., [1, 3]] *= oh
+            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
+        return results
+
+    def pre_transform(self, im):
+        """
+        Pre-transforms the input images before feeding them into the model for inference. The input images are
+        letterboxed to ensure a square aspect ratio and scale-filled. The size must be square(640) and scaleFilled.
+
+        Args:
+            im (list[np.ndarray] |torch.Tensor): Input images of shape (N,3,h,w) for tensor, [(h,w,3) x N] for list.
+
+        Returns:
+            (list): List of pre-transformed images ready for model inference.
+        """
+        letterbox = LetterBox(self.imgsz, auto=False, scaleFill=True)
+        return [letterbox(image=x) for x in im]

+ 105 - 0
ultralytics/models/rtdetr/train.py

@@ -0,0 +1,105 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from copy import copy
+
+import torch
+
+from ultralytics.models.yolo.detect import DetectionTrainer
+from ultralytics.nn.tasks import RTDETRDetectionModel
+from ultralytics.utils import RANK, colorstr
+
+from .val import RTDETRDataset, RTDETRValidator
+
+
+class RTDETRTrainer(DetectionTrainer):
+    """
+    Trainer class for the RT-DETR model developed by Baidu for real-time object detection. Extends the DetectionTrainer
+    class for YOLO to adapt to the specific features and architecture of RT-DETR. This model leverages Vision
+    Transformers and has capabilities like IoU-aware query selection and adaptable inference speed.
+
+    Notes:
+        - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
+        - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
+
+    Example:
+        ```python
+        from ultralytics.models.rtdetr.train import RTDETRTrainer
+
+        args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)
+        trainer = RTDETRTrainer(overrides=args)
+        trainer.train()
+        ```
+    """
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """
+        Initialize and return an RT-DETR model for object detection tasks.
+
+        Args:
+            cfg (dict, optional): Model configuration. Defaults to None.
+            weights (str, optional): Path to pre-trained model weights. Defaults to None.
+            verbose (bool): Verbose logging if True. Defaults to True.
+
+        Returns:
+            (RTDETRDetectionModel): Initialized model.
+        """
+        model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
+        if weights:
+            model.load(weights)
+        return model
+
+    def build_dataset(self, img_path, mode="val", batch=None):
+        """
+        Build and return an RT-DETR dataset for training or validation.
+
+        Args:
+            img_path (str): Path to the folder containing images.
+            mode (str): Dataset mode, either 'train' or 'val'.
+            batch (int, optional): Batch size for rectangle training. Defaults to None.
+
+        Returns:
+            (RTDETRDataset): Dataset object for the specific mode.
+        """
+        return RTDETRDataset(
+            img_path=img_path,
+            imgsz=self.args.imgsz,
+            batch_size=batch,
+            augment=mode == "train",
+            hyp=self.args,
+            rect=False,
+            cache=self.args.cache or None,
+            single_cls=self.args.single_cls or False,
+            prefix=colorstr(f"{mode}: "),
+            classes=self.args.classes,
+            data=self.data,
+            fraction=self.args.fraction if mode == "train" else 1.0,
+        )
+
+    def get_validator(self):
+        """
+        Returns a DetectionValidator suitable for RT-DETR model validation.
+
+        Returns:
+            (RTDETRValidator): Validator object for model validation.
+        """
+        self.loss_names = "giou_loss", "cls_loss", "l1_loss"
+        return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
+
+    def preprocess_batch(self, batch):
+        """
+        Preprocess a batch of images. Scales and converts the images to float format.
+
+        Args:
+            batch (dict): Dictionary containing a batch of images, bboxes, and labels.
+
+        Returns:
+            (dict): Preprocessed batch.
+        """
+        batch = super().preprocess_batch(batch)
+        bs = len(batch["img"])
+        batch_idx = batch["batch_idx"]
+        gt_bbox, gt_class = [], []
+        for i in range(bs):
+            gt_bbox.append(batch["bboxes"][batch_idx == i].to(batch_idx.device))
+            gt_class.append(batch["cls"][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
+        return batch

+ 135 - 0
ultralytics/models/rtdetr/val.py

@@ -0,0 +1,135 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import torch
+
+from ultralytics.data import YOLODataset
+from ultralytics.data.augment import Compose, Format, v8_transforms
+from ultralytics.models.yolo.detect import DetectionValidator
+from ultralytics.utils import colorstr, ops
+
+__all__ = ("RTDETRValidator",)  # tuple or list
+
+
+class RTDETRDataset(YOLODataset):
+    """
+    Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.
+
+    This specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for
+    real-time detection and tracking tasks.
+    """
+
+    def __init__(self, *args, data=None, **kwargs):
+        """Initialize the RTDETRDataset class by inheriting from the YOLODataset class."""
+        super().__init__(*args, data=data, **kwargs)
+
+    # NOTE: add stretch version load_image for RTDETR mosaic
+    def load_image(self, i, rect_mode=False):
+        """Loads 1 image from dataset index 'i', returns (im, resized hw)."""
+        return super().load_image(i=i, rect_mode=rect_mode)
+
+    def build_transforms(self, hyp=None):
+        """Temporary, only for evaluation."""
+        if self.augment:
+            hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
+            hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
+            transforms = v8_transforms(self, self.imgsz, hyp, stretch=True)
+        else:
+            # transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scaleFill=True)])
+            transforms = Compose([])
+        transforms.append(
+            Format(
+                bbox_format="xywh",
+                normalize=True,
+                return_mask=self.use_segments,
+                return_keypoint=self.use_keypoints,
+                batch_idx=True,
+                mask_ratio=hyp.mask_ratio,
+                mask_overlap=hyp.overlap_mask,
+            )
+        )
+        return transforms
+
+
+class RTDETRValidator(DetectionValidator):
+    """
+    RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for
+    the RT-DETR (Real-Time DETR) object detection model.
+
+    The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for
+    post-processing, and updates evaluation metrics accordingly.
+
+    Example:
+        ```python
+        from ultralytics.models.rtdetr import RTDETRValidator
+
+        args = dict(model="rtdetr-l.pt", data="coco8.yaml")
+        validator = RTDETRValidator(args=args)
+        validator()
+        ```
+
+    Note:
+        For further details on the attributes and methods, refer to the parent DetectionValidator class.
+    """
+
+    def build_dataset(self, img_path, mode="val", batch=None):
+        """
+        Build an RTDETR Dataset.
+
+        Args:
+            img_path (str): Path to the folder containing images.
+            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
+            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
+        """
+        return RTDETRDataset(
+            img_path=img_path,
+            imgsz=self.args.imgsz,
+            batch_size=batch,
+            augment=False,  # no augmentation
+            hyp=self.args,
+            rect=False,  # no rect
+            cache=self.args.cache or None,
+            prefix=colorstr(f"{mode}: "),
+            data=self.data,
+        )
+
+    def postprocess(self, preds):
+        """Apply Non-maximum suppression to prediction outputs."""
+        if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
+            preds = [preds, None]
+
+        bs, _, nd = preds[0].shape
+        bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
+        bboxes *= self.args.imgsz
+        outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs
+        for i, bbox in enumerate(bboxes):  # (300, 4)
+            bbox = ops.xywh2xyxy(bbox)
+            score, cls = scores[i].max(-1)  # (300, )
+            # Do not need threshold for evaluation as only got 300 boxes here
+            # idx = score > self.args.conf
+            pred = torch.cat([bbox, score[..., None], cls[..., None]], dim=-1)  # filter
+            # Sort by confidence to correctly get internal metrics
+            pred = pred[score.argsort(descending=True)]
+            outputs[i] = pred  # [idx]
+
+        return outputs
+
+    def _prepare_batch(self, si, batch):
+        """Prepares a batch for training or inference by applying transformations."""
+        idx = batch["batch_idx"] == si
+        cls = batch["cls"][idx].squeeze(-1)
+        bbox = batch["bboxes"][idx]
+        ori_shape = batch["ori_shape"][si]
+        imgsz = batch["img"].shape[2:]
+        ratio_pad = batch["ratio_pad"][si]
+        if len(cls):
+            bbox = ops.xywh2xyxy(bbox)  # target boxes
+            bbox[..., [0, 2]] *= ori_shape[1]  # native-space pred
+            bbox[..., [1, 3]] *= ori_shape[0]  # native-space pred
+        return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
+
+    def _prepare_pred(self, pred, pbatch):
+        """Prepares and returns a batch with transformed bounding boxes and class labels."""
+        predn = pred.clone()
+        predn[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz  # native-space pred
+        predn[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz  # native-space pred
+        return predn.float()

+ 6 - 0
ultralytics/models/sam/__init__.py

@@ -0,0 +1,6 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from .model import SAM
+from .predict import Predictor, SAM2Predictor, SAM2VideoPredictor
+
+__all__ = "SAM", "Predictor", "SAM2Predictor", "SAM2VideoPredictor"  # tuple or list

+ 193 - 0
ultralytics/models/sam/amg.py

@@ -0,0 +1,193 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import math
+from itertools import product
+from typing import Any, Generator, List, Tuple
+
+import numpy as np
+import torch
+
+
+def is_box_near_crop_edge(
+    boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
+) -> torch.Tensor:
+    """Determines if bounding boxes are near the edge of a cropped image region using a specified tolerance."""
+    crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
+    orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
+    boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
+    near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
+    near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
+    near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
+    return torch.any(near_crop_edge, dim=1)
+
+
+def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
+    """Yields batches of data from input arguments with specified batch size for efficient processing."""
+    assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs."
+    n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
+    for b in range(n_batches):
+        yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
+
+
+def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
+    """
+    Computes the stability score for a batch of masks.
+
+    The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at
+    high and low values.
+
+    Args:
+        masks (torch.Tensor): Batch of predicted mask logits.
+        mask_threshold (float): Threshold value for creating binary masks.
+        threshold_offset (float): Offset applied to the threshold for creating high and low binary masks.
+
+    Returns:
+        (torch.Tensor): Stability scores for each mask in the batch.
+
+    Notes:
+        - One mask is always contained inside the other.
+        - Memory is saved by preventing unnecessary cast to torch.int64.
+
+    Examples:
+        >>> masks = torch.rand(10, 256, 256)  # Batch of 10 masks
+        >>> mask_threshold = 0.5
+        >>> threshold_offset = 0.1
+        >>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset)
+    """
+    intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
+    unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
+    return intersections / unions
+
+
+def build_point_grid(n_per_side: int) -> np.ndarray:
+    """Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1] for image segmentation tasks."""
+    offset = 1 / (2 * n_per_side)
+    points_one_side = np.linspace(offset, 1 - offset, n_per_side)
+    points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
+    points_y = np.tile(points_one_side[:, None], (1, n_per_side))
+    return np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
+
+
+def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
+    """Generates point grids for multiple crop layers with varying scales and densities."""
+    return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]
+
+
+def generate_crop_boxes(
+    im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
+) -> Tuple[List[List[int]], List[int]]:
+    """Generates crop boxes of varying sizes for multiscale image processing, with layered overlapping regions."""
+    crop_boxes, layer_idxs = [], []
+    im_h, im_w = im_size
+    short_side = min(im_h, im_w)
+
+    # Original image
+    crop_boxes.append([0, 0, im_w, im_h])
+    layer_idxs.append(0)
+
+    def crop_len(orig_len, n_crops, overlap):
+        """Crops bounding boxes to the size of the input image."""
+        return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
+
+    for i_layer in range(n_layers):
+        n_crops_per_side = 2 ** (i_layer + 1)
+        overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
+
+        crop_w = crop_len(im_w, n_crops_per_side, overlap)
+        crop_h = crop_len(im_h, n_crops_per_side, overlap)
+
+        crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
+        crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
+
+        # Crops in XYWH format
+        for x0, y0 in product(crop_box_x0, crop_box_y0):
+            box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
+            crop_boxes.append(box)
+            layer_idxs.append(i_layer + 1)
+
+    return crop_boxes, layer_idxs
+
+
+def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+    """Uncrop bounding boxes by adding the crop box offset to their coordinates."""
+    x0, y0, _, _ = crop_box
+    offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
+    # Check if boxes has a channel dimension
+    if len(boxes.shape) == 3:
+        offset = offset.unsqueeze(1)
+    return boxes + offset
+
+
+def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+    """Uncrop points by adding the crop box offset to their coordinates."""
+    x0, y0, _, _ = crop_box
+    offset = torch.tensor([[x0, y0]], device=points.device)
+    # Check if points has a channel dimension
+    if len(points.shape) == 3:
+        offset = offset.unsqueeze(1)
+    return points + offset
+
+
+def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor:
+    """Uncrop masks by padding them to the original image size, handling coordinate transformations."""
+    x0, y0, x1, y1 = crop_box
+    if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
+        return masks
+    # Coordinate transform masks
+    pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
+    pad = (x0, pad_x - x0, y0, pad_y - y0)
+    return torch.nn.functional.pad(masks, pad, value=0)
+
+
+def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
+    """Removes small disconnected regions or holes in a mask based on area threshold and mode."""
+    import cv2  # type: ignore
+
+    assert mode in {"holes", "islands"}, f"Provided mode {mode} is invalid"
+    correct_holes = mode == "holes"
+    working_mask = (correct_holes ^ mask).astype(np.uint8)
+    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
+    sizes = stats[:, -1][1:]  # Row 0 is background label
+    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
+    if not small_regions:
+        return mask, False
+    fill_labels = [0] + small_regions
+    if not correct_holes:
+        # If every region is below threshold, keep largest
+        fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1]
+    mask = np.isin(regions, fill_labels)
+    return mask, True
+
+
+def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
+    """Calculates bounding boxes in XYXY format around binary masks, handling empty masks and various input shapes."""
+    # torch.max below raises an error on empty inputs, just skip in this case
+    if torch.numel(masks) == 0:
+        return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
+
+    # Normalize shape to CxHxW
+    shape = masks.shape
+    h, w = shape[-2:]
+    masks = masks.flatten(0, -3) if len(shape) > 2 else masks.unsqueeze(0)
+    # Get top and bottom edges
+    in_height, _ = torch.max(masks, dim=-1)
+    in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
+    bottom_edges, _ = torch.max(in_height_coords, dim=-1)
+    in_height_coords = in_height_coords + h * (~in_height)
+    top_edges, _ = torch.min(in_height_coords, dim=-1)
+
+    # Get left and right edges
+    in_width, _ = torch.max(masks, dim=-2)
+    in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
+    right_edges, _ = torch.max(in_width_coords, dim=-1)
+    in_width_coords = in_width_coords + w * (~in_width)
+    left_edges, _ = torch.min(in_width_coords, dim=-1)
+
+    # If the mask is empty the right edge will be to the left of the left edge.
+    # Replace these boxes with [0, 0, 0, 0]
+    empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+    out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
+    out = out * (~empty_filter).unsqueeze(-1)
+
+    # Return to original shape
+    return out.reshape(*shape[:-2], 4) if len(shape) > 2 else out[0]

+ 358 - 0
ultralytics/models/sam/build.py

@@ -0,0 +1,358 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from functools import partial
+
+import torch
+
+from ultralytics.utils.downloads import attempt_download_asset
+
+from .modules.decoders import MaskDecoder
+from .modules.encoders import FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder
+from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
+from .modules.sam import SAM2Model, SAMModel
+from .modules.tiny_encoder import TinyViT
+from .modules.transformer import TwoWayTransformer
+
+
+def build_sam_vit_h(checkpoint=None):
+    """Builds and returns a Segment Anything Model (SAM) h-size model with specified encoder parameters."""
+    return _build_sam(
+        encoder_embed_dim=1280,
+        encoder_depth=32,
+        encoder_num_heads=16,
+        encoder_global_attn_indexes=[7, 15, 23, 31],
+        checkpoint=checkpoint,
+    )
+
+
+def build_sam_vit_l(checkpoint=None):
+    """Builds and returns a Segment Anything Model (SAM) l-size model with specified encoder parameters."""
+    return _build_sam(
+        encoder_embed_dim=1024,
+        encoder_depth=24,
+        encoder_num_heads=16,
+        encoder_global_attn_indexes=[5, 11, 17, 23],
+        checkpoint=checkpoint,
+    )
+
+
+def build_sam_vit_b(checkpoint=None):
+    """Constructs and returns a Segment Anything Model (SAM) with b-size architecture and optional checkpoint."""
+    return _build_sam(
+        encoder_embed_dim=768,
+        encoder_depth=12,
+        encoder_num_heads=12,
+        encoder_global_attn_indexes=[2, 5, 8, 11],
+        checkpoint=checkpoint,
+    )
+
+
+def build_mobile_sam(checkpoint=None):
+    """Builds and returns a Mobile Segment Anything Model (Mobile-SAM) for efficient image segmentation."""
+    return _build_sam(
+        encoder_embed_dim=[64, 128, 160, 320],
+        encoder_depth=[2, 2, 6, 2],
+        encoder_num_heads=[2, 4, 5, 10],
+        encoder_global_attn_indexes=None,
+        mobile_sam=True,
+        checkpoint=checkpoint,
+    )
+
+
+def build_sam2_t(checkpoint=None):
+    """Builds and returns a Segment Anything Model 2 (SAM2) tiny-size model with specified architecture parameters."""
+    return _build_sam2(
+        encoder_embed_dim=96,
+        encoder_stages=[1, 2, 7, 2],
+        encoder_num_heads=1,
+        encoder_global_att_blocks=[5, 7, 9],
+        encoder_window_spec=[8, 4, 14, 7],
+        encoder_backbone_channel_list=[768, 384, 192, 96],
+        checkpoint=checkpoint,
+    )
+
+
+def build_sam2_s(checkpoint=None):
+    """Builds and returns a small-size Segment Anything Model (SAM2) with specified architecture parameters."""
+    return _build_sam2(
+        encoder_embed_dim=96,
+        encoder_stages=[1, 2, 11, 2],
+        encoder_num_heads=1,
+        encoder_global_att_blocks=[7, 10, 13],
+        encoder_window_spec=[8, 4, 14, 7],
+        encoder_backbone_channel_list=[768, 384, 192, 96],
+        checkpoint=checkpoint,
+    )
+
+
+def build_sam2_b(checkpoint=None):
+    """Builds and returns a SAM2 base-size model with specified architecture parameters."""
+    return _build_sam2(
+        encoder_embed_dim=112,
+        encoder_stages=[2, 3, 16, 3],
+        encoder_num_heads=2,
+        encoder_global_att_blocks=[12, 16, 20],
+        encoder_window_spec=[8, 4, 14, 7],
+        encoder_window_spatial_size=[14, 14],
+        encoder_backbone_channel_list=[896, 448, 224, 112],
+        checkpoint=checkpoint,
+    )
+
+
+def build_sam2_l(checkpoint=None):
+    """Builds and returns a large-size Segment Anything Model (SAM2) with specified architecture parameters."""
+    return _build_sam2(
+        encoder_embed_dim=144,
+        encoder_stages=[2, 6, 36, 4],
+        encoder_num_heads=2,
+        encoder_global_att_blocks=[23, 33, 43],
+        encoder_window_spec=[8, 4, 16, 8],
+        encoder_backbone_channel_list=[1152, 576, 288, 144],
+        checkpoint=checkpoint,
+    )
+
+
+def _build_sam(
+    encoder_embed_dim,
+    encoder_depth,
+    encoder_num_heads,
+    encoder_global_attn_indexes,
+    checkpoint=None,
+    mobile_sam=False,
+):
+    """
+    Builds a Segment Anything Model (SAM) with specified encoder parameters.
+
+    Args:
+        encoder_embed_dim (int | List[int]): Embedding dimension for the encoder.
+        encoder_depth (int | List[int]): Depth of the encoder.
+        encoder_num_heads (int | List[int]): Number of attention heads in the encoder.
+        encoder_global_attn_indexes (List[int] | None): Indexes for global attention in the encoder.
+        checkpoint (str | None): Path to the model checkpoint file.
+        mobile_sam (bool): Whether to build a Mobile-SAM model.
+
+    Returns:
+        (SAMModel): A Segment Anything Model instance with the specified architecture.
+
+    Examples:
+        >>> sam = _build_sam(768, 12, 12, [2, 5, 8, 11])
+        >>> sam = _build_sam([64, 128, 160, 320], [2, 2, 6, 2], [2, 4, 5, 10], None, mobile_sam=True)
+    """
+    prompt_embed_dim = 256
+    image_size = 1024
+    vit_patch_size = 16
+    image_embedding_size = image_size // vit_patch_size
+    image_encoder = (
+        TinyViT(
+            img_size=1024,
+            in_chans=3,
+            num_classes=1000,
+            embed_dims=encoder_embed_dim,
+            depths=encoder_depth,
+            num_heads=encoder_num_heads,
+            window_sizes=[7, 7, 14, 7],
+            mlp_ratio=4.0,
+            drop_rate=0.0,
+            drop_path_rate=0.0,
+            use_checkpoint=False,
+            mbconv_expand_ratio=4.0,
+            local_conv_size=3,
+            layer_lr_decay=0.8,
+        )
+        if mobile_sam
+        else ImageEncoderViT(
+            depth=encoder_depth,
+            embed_dim=encoder_embed_dim,
+            img_size=image_size,
+            mlp_ratio=4,
+            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
+            num_heads=encoder_num_heads,
+            patch_size=vit_patch_size,
+            qkv_bias=True,
+            use_rel_pos=True,
+            global_attn_indexes=encoder_global_attn_indexes,
+            window_size=14,
+            out_chans=prompt_embed_dim,
+        )
+    )
+    sam = SAMModel(
+        image_encoder=image_encoder,
+        prompt_encoder=PromptEncoder(
+            embed_dim=prompt_embed_dim,
+            image_embedding_size=(image_embedding_size, image_embedding_size),
+            input_image_size=(image_size, image_size),
+            mask_in_chans=16,
+        ),
+        mask_decoder=MaskDecoder(
+            num_multimask_outputs=3,
+            transformer=TwoWayTransformer(
+                depth=2,
+                embedding_dim=prompt_embed_dim,
+                mlp_dim=2048,
+                num_heads=8,
+            ),
+            transformer_dim=prompt_embed_dim,
+            iou_head_depth=3,
+            iou_head_hidden_dim=256,
+        ),
+        pixel_mean=[123.675, 116.28, 103.53],
+        pixel_std=[58.395, 57.12, 57.375],
+    )
+    if checkpoint is not None:
+        checkpoint = attempt_download_asset(checkpoint)
+        with open(checkpoint, "rb") as f:
+            state_dict = torch.load(f)
+        sam.load_state_dict(state_dict)
+    sam.eval()
+    return sam
+
+
+def _build_sam2(
+    encoder_embed_dim=1280,
+    encoder_stages=[2, 6, 36, 4],
+    encoder_num_heads=2,
+    encoder_global_att_blocks=[7, 15, 23, 31],
+    encoder_backbone_channel_list=[1152, 576, 288, 144],
+    encoder_window_spatial_size=[7, 7],
+    encoder_window_spec=[8, 4, 16, 8],
+    checkpoint=None,
+):
+    """
+    Builds and returns a Segment Anything Model 2 (SAM2) with specified architecture parameters.
+
+    Args:
+        encoder_embed_dim (int): Embedding dimension for the encoder.
+        encoder_stages (List[int]): Number of blocks in each stage of the encoder.
+        encoder_num_heads (int): Number of attention heads in the encoder.
+        encoder_global_att_blocks (List[int]): Indices of global attention blocks in the encoder.
+        encoder_backbone_channel_list (List[int]): Channel dimensions for each level of the encoder backbone.
+        encoder_window_spatial_size (List[int]): Spatial size of the window for position embeddings.
+        encoder_window_spec (List[int]): Window specifications for each stage of the encoder.
+        checkpoint (str | None): Path to the checkpoint file for loading pre-trained weights.
+
+    Returns:
+        (SAM2Model): A configured and initialized SAM2 model.
+
+    Examples:
+        >>> sam2_model = _build_sam2(encoder_embed_dim=96, encoder_stages=[1, 2, 7, 2])
+        >>> sam2_model.eval()
+    """
+    image_encoder = ImageEncoder(
+        trunk=Hiera(
+            embed_dim=encoder_embed_dim,
+            num_heads=encoder_num_heads,
+            stages=encoder_stages,
+            global_att_blocks=encoder_global_att_blocks,
+            window_pos_embed_bkg_spatial_size=encoder_window_spatial_size,
+            window_spec=encoder_window_spec,
+        ),
+        neck=FpnNeck(
+            d_model=256,
+            backbone_channel_list=encoder_backbone_channel_list,
+            fpn_top_down_levels=[2, 3],
+            fpn_interp_model="nearest",
+        ),
+        scalp=1,
+    )
+    memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer())
+    memory_encoder = MemoryEncoder(out_dim=64)
+
+    is_sam2_1 = checkpoint is not None and "sam2.1" in checkpoint
+    sam2 = SAM2Model(
+        image_encoder=image_encoder,
+        memory_attention=memory_attention,
+        memory_encoder=memory_encoder,
+        num_maskmem=7,
+        image_size=1024,
+        sigmoid_scale_for_mem_enc=20.0,
+        sigmoid_bias_for_mem_enc=-10.0,
+        use_mask_input_as_output_without_sam=True,
+        directly_add_no_mem_embed=True,
+        use_high_res_features_in_sam=True,
+        multimask_output_in_sam=True,
+        iou_prediction_use_sigmoid=True,
+        use_obj_ptrs_in_encoder=True,
+        add_tpos_enc_to_obj_ptrs=True,
+        only_obj_ptrs_in_the_past_for_eval=True,
+        pred_obj_scores=True,
+        pred_obj_scores_mlp=True,
+        fixed_no_obj_ptr=True,
+        multimask_output_for_tracking=True,
+        use_multimask_token_for_obj_ptr=True,
+        multimask_min_pt_num=0,
+        multimask_max_pt_num=1,
+        use_mlp_for_obj_ptr_proj=True,
+        compile_image_encoder=False,
+        no_obj_embed_spatial=is_sam2_1,
+        proj_tpos_enc_in_obj_ptrs=is_sam2_1,
+        use_signed_tpos_enc_to_obj_ptrs=is_sam2_1,
+        sam_mask_decoder_extra_args=dict(
+            dynamic_multimask_via_stability=True,
+            dynamic_multimask_stability_delta=0.05,
+            dynamic_multimask_stability_thresh=0.98,
+        ),
+    )
+
+    if checkpoint is not None:
+        checkpoint = attempt_download_asset(checkpoint)
+        with open(checkpoint, "rb") as f:
+            state_dict = torch.load(f)["model"]
+        sam2.load_state_dict(state_dict)
+    sam2.eval()
+    return sam2
+
+
+sam_model_map = {
+    "sam_h.pt": build_sam_vit_h,
+    "sam_l.pt": build_sam_vit_l,
+    "sam_b.pt": build_sam_vit_b,
+    "mobile_sam.pt": build_mobile_sam,
+    "sam2_t.pt": build_sam2_t,
+    "sam2_s.pt": build_sam2_s,
+    "sam2_b.pt": build_sam2_b,
+    "sam2_l.pt": build_sam2_l,
+    "sam2.1_t.pt": build_sam2_t,
+    "sam2.1_s.pt": build_sam2_s,
+    "sam2.1_b.pt": build_sam2_b,
+    "sam2.1_l.pt": build_sam2_l,
+}
+
+
+def build_sam(ckpt="sam_b.pt"):
+    """
+    Builds and returns a Segment Anything Model (SAM) based on the provided checkpoint.
+
+    Args:
+        ckpt (str | Path): Path to the checkpoint file or name of a pre-defined SAM model.
+
+    Returns:
+        (SAMModel | SAM2Model): A configured and initialized SAM or SAM2 model instance.
+
+    Raises:
+        FileNotFoundError: If the provided checkpoint is not a supported SAM model.
+
+    Examples:
+        >>> sam_model = build_sam("sam_b.pt")
+        >>> sam_model = build_sam("path/to/custom_checkpoint.pt")
+
+    Notes:
+        Supported pre-defined models include:
+        - SAM: 'sam_h.pt', 'sam_l.pt', 'sam_b.pt', 'mobile_sam.pt'
+        - SAM2: 'sam2_t.pt', 'sam2_s.pt', 'sam2_b.pt', 'sam2_l.pt'
+    """
+    model_builder = None
+    ckpt = str(ckpt)  # to allow Path ckpt types
+    for k in sam_model_map.keys():
+        if ckpt.endswith(k):
+            model_builder = sam_model_map.get(k)
+
+    if not model_builder:
+        raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}")
+
+    return model_builder(ckpt)

+ 175 - 0
ultralytics/models/sam/model.py

@@ -0,0 +1,175 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+"""
+SAM model interface.
+
+This module provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for real-time image
+segmentation tasks. The SAM model allows for promptable segmentation with unparalleled versatility in image analysis,
+and has been trained on the SA-1B dataset. It features zero-shot performance capabilities, enabling it to adapt to new
+image distributions and tasks without prior knowledge.
+
+Key Features:
+    - Promptable segmentation
+    - Real-time performance
+    - Zero-shot transfer capabilities
+    - Trained on SA-1B dataset
+"""
+
+from pathlib import Path
+
+from ultralytics.engine.model import Model
+from ultralytics.utils.torch_utils import model_info
+
+from .build import build_sam
+from .predict import Predictor, SAM2Predictor
+
+
+class SAM(Model):
+    """
+    SAM (Segment Anything Model) interface class for real-time image segmentation tasks.
+
+    This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for
+    promptable segmentation with versatility in image analysis. It supports various prompts such as bounding
+    boxes, points, or labels, and features zero-shot performance capabilities.
+
+    Attributes:
+        model (torch.nn.Module): The loaded SAM model.
+        is_sam2 (bool): Indicates whether the model is SAM2 variant.
+        task (str): The task type, set to "segment" for SAM models.
+
+    Methods:
+        predict: Performs segmentation prediction on the given image or video source.
+        info: Logs information about the SAM model.
+
+    Examples:
+        >>> sam = SAM("sam_b.pt")
+        >>> results = sam.predict("image.jpg", points=[[500, 375]])
+        >>> for r in results:
+        >>>     print(f"Detected {len(r.masks)} masks")
+    """
+
+    def __init__(self, model="sam_b.pt") -> None:
+        """
+        Initializes the SAM (Segment Anything Model) instance.
+
+        Args:
+            model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.
+
+        Raises:
+            NotImplementedError: If the model file extension is not .pt or .pth.
+
+        Examples:
+            >>> sam = SAM("sam_b.pt")
+            >>> print(sam.is_sam2)
+        """
+        if model and Path(model).suffix not in {".pt", ".pth"}:
+            raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
+        self.is_sam2 = "sam2" in Path(model).stem
+        super().__init__(model=model, task="segment")
+
+    def _load(self, weights: str, task=None):
+        """
+        Loads the specified weights into the SAM model.
+
+        This method initializes the SAM model with the provided weights file, setting up the model architecture
+        and loading the pre-trained parameters.
+
+        Args:
+            weights (str): Path to the weights file. Should be a .pt or .pth file containing the model parameters.
+            task (str | None): Task name. If provided, it specifies the particular task the model is being loaded for.
+
+        Examples:
+            >>> sam = SAM("sam_b.pt")
+            >>> sam._load("path/to/custom_weights.pt")
+        """
+        self.model = build_sam(weights)
+
+    def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
+        """
+        Performs segmentation prediction on the given image or video source.
+
+        Args:
+            source (str | PIL.Image | numpy.ndarray): Path to the image or video file, or a PIL.Image object, or
+                a numpy.ndarray object.
+            stream (bool): If True, enables real-time streaming.
+            bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.
+            points (List[List[float]] | None): List of points for prompted segmentation.
+            labels (List[int] | None): List of labels for prompted segmentation.
+            **kwargs (Any): Additional keyword arguments for prediction.
+
+        Returns:
+            (List): The model predictions.
+
+        Examples:
+            >>> sam = SAM("sam_b.pt")
+            >>> results = sam.predict("image.jpg", points=[[500, 375]])
+            >>> for r in results:
+            ...     print(f"Detected {len(r.masks)} masks")
+        """
+        overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
+        kwargs = {**overrides, **kwargs}
+        prompts = dict(bboxes=bboxes, points=points, labels=labels)
+        return super().predict(source, stream, prompts=prompts, **kwargs)
+
+    def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
+        """
+        Performs segmentation prediction on the given image or video source.
+
+        This method is an alias for the 'predict' method, providing a convenient way to call the SAM model
+        for segmentation tasks.
+
+        Args:
+            source (str | PIL.Image | numpy.ndarray | None): Path to the image or video file, or a PIL.Image
+                object, or a numpy.ndarray object.
+            stream (bool): If True, enables real-time streaming.
+            bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.
+            points (List[List[float]] | None): List of points for prompted segmentation.
+            labels (List[int] | None): List of labels for prompted segmentation.
+            **kwargs (Any): Additional keyword arguments to be passed to the predict method.
+
+        Returns:
+            (List): The model predictions, typically containing segmentation masks and other relevant information.
+
+        Examples:
+            >>> sam = SAM("sam_b.pt")
+            >>> results = sam("image.jpg", points=[[500, 375]])
+            >>> print(f"Detected {len(results[0].masks)} masks")
+        """
+        return self.predict(source, stream, bboxes, points, labels, **kwargs)
+
+    def info(self, detailed=False, verbose=True):
+        """
+        Logs information about the SAM model.
+
+        This method provides details about the Segment Anything Model (SAM), including its architecture,
+        parameters, and computational requirements.
+
+        Args:
+            detailed (bool): If True, displays detailed information about the model layers and operations.
+            verbose (bool): If True, prints the information to the console.
+
+        Returns:
+            (tuple): A tuple containing the model's information (string representations of the model).
+
+        Examples:
+            >>> sam = SAM("sam_b.pt")
+            >>> info = sam.info()
+            >>> print(info[0])  # Print summary information
+        """
+        return model_info(self.model, detailed=detailed, verbose=verbose)
+
+    @property
+    def task_map(self):
+        """
+        Provides a mapping from the 'segment' task to its corresponding 'Predictor'.
+
+        Returns:
+            (Dict[str, Type[Predictor]]): A dictionary mapping the 'segment' task to its corresponding Predictor
+                class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.
+
+        Examples:
+            >>> sam = SAM("sam_b.pt")
+            >>> task_map = sam.task_map
+            >>> print(task_map)
+            {'segment': <class 'ultralytics.models.sam.predict.Predictor'>}
+        """
+        return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}

+ 1 - 0
ultralytics/models/sam/modules/__init__.py

@@ -0,0 +1 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+ 1129 - 0
ultralytics/models/sam/modules/blocks.py

@@ -0,0 +1,1129 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import copy
+import math
+from functools import partial
+from typing import Any, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from ultralytics.nn.modules import MLP, LayerNorm2d, MLPBlock
+
+from .transformer import Attention, TwoWayAttentionBlock, TwoWayTransformer
+from .utils import add_decomposed_rel_pos, apply_rotary_enc, compute_axial_cis, window_partition, window_unpartition
+
+
+class DropPath(nn.Module):
+    """
+    Implements stochastic depth regularization for neural networks during training.
+
+    Attributes:
+        drop_prob (float): Probability of dropping a path during training.
+        scale_by_keep (bool): Whether to scale the output by the keep probability.
+
+    Methods:
+        forward: Applies stochastic depth to input tensor during training, with optional scaling.
+
+    Examples:
+        >>> drop_path = DropPath(drop_prob=0.2, scale_by_keep=True)
+        >>> x = torch.randn(32, 64, 224, 224)
+        >>> output = drop_path(x)
+    """
+
+    def __init__(self, drop_prob=0.0, scale_by_keep=True):
+        """Initialize DropPath module for stochastic depth regularization during training."""
+        super().__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x):
+        """Applies stochastic depth to input tensor during training, with optional scaling."""
+        if self.drop_prob == 0.0 or not self.training:
+            return x
+        keep_prob = 1 - self.drop_prob
+        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+        if keep_prob > 0.0 and self.scale_by_keep:
+            random_tensor.div_(keep_prob)
+        return x * random_tensor
+
+
+class MaskDownSampler(nn.Module):
+    """
+    A mask downsampling and embedding module for efficient processing of input masks.
+
+    This class implements a mask downsampler that progressively reduces the spatial dimensions of input masks
+    while expanding their channel dimensions using convolutional layers, layer normalization, and activation
+    functions.
+
+    Attributes:
+        encoder (nn.Sequential): A sequential container of convolutional layers, layer normalization, and
+            activation functions for downsampling and embedding masks.
+
+    Methods:
+        forward: Downsamples and encodes input mask to embed_dim channels.
+
+    Examples:
+        >>> mask_downsampler = MaskDownSampler(embed_dim=256, kernel_size=4, stride=4, padding=0, total_stride=16)
+        >>> input_mask = torch.randn(1, 1, 256, 256)
+        >>> output = mask_downsampler(input_mask)
+        >>> print(output.shape)
+        torch.Size([1, 256, 16, 16])
+    """
+
+    def __init__(
+        self,
+        embed_dim=256,
+        kernel_size=4,
+        stride=4,
+        padding=0,
+        total_stride=16,
+        activation=nn.GELU,
+    ):
+        """Initializes a mask downsampler module for progressive downsampling and channel expansion."""
+        super().__init__()
+        num_layers = int(math.log2(total_stride) // math.log2(stride))
+        assert stride**num_layers == total_stride
+        self.encoder = nn.Sequential()
+        mask_in_chans, mask_out_chans = 1, 1
+        for _ in range(num_layers):
+            mask_out_chans = mask_in_chans * (stride**2)
+            self.encoder.append(
+                nn.Conv2d(
+                    mask_in_chans,
+                    mask_out_chans,
+                    kernel_size=kernel_size,
+                    stride=stride,
+                    padding=padding,
+                )
+            )
+            self.encoder.append(LayerNorm2d(mask_out_chans))
+            self.encoder.append(activation())
+            mask_in_chans = mask_out_chans
+
+        self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
+
+    def forward(self, x):
+        """Downsamples and encodes input mask to embed_dim channels using convolutional layers and LayerNorm2d."""
+        return self.encoder(x)
+
+
+class CXBlock(nn.Module):
+    """
+    ConvNeXt Block for efficient feature extraction in convolutional neural networks.
+
+    This block implements a modified version of the ConvNeXt architecture, offering improved performance and
+    flexibility in feature extraction.
+
+    Attributes:
+        dwconv (nn.Conv2d): Depthwise or standard 2D convolution layer.
+        norm (LayerNorm2d): Layer normalization applied to channels.
+        pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer.
+        act (nn.GELU): GELU activation function.
+        pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer.
+        gamma (nn.Parameter | None): Learnable scale parameter for layer scaling.
+        drop_path (nn.Module): DropPath layer for stochastic depth regularization.
+
+    Methods:
+        forward: Processes the input tensor through the ConvNeXt block.
+
+    Examples:
+        >>> import torch
+        >>> x = torch.randn(1, 64, 56, 56)
+        >>> block = CXBlock(dim=64, kernel_size=7, padding=3)
+        >>> output = block(x)
+        >>> print(output.shape)
+        torch.Size([1, 64, 56, 56])
+    """
+
+    def __init__(
+        self,
+        dim,
+        kernel_size=7,
+        padding=3,
+        drop_path=0.0,
+        layer_scale_init_value=1e-6,
+        use_dwconv=True,
+    ):
+        """
+        Initialize a ConvNeXt Block for efficient feature extraction in convolutional neural networks.
+
+        This block implements a modified version of the ConvNeXt architecture, offering improved performance and
+        flexibility in feature extraction.
+
+        Args:
+            dim (int): Number of input channels.
+            kernel_size (int): Size of the convolutional kernel.
+            padding (int): Padding size for the convolution.
+            drop_path (float): Stochastic depth rate.
+            layer_scale_init_value (float): Initial value for Layer Scale.
+            use_dwconv (bool): Whether to use depthwise convolution.
+
+        Examples:
+            >>> block = CXBlock(dim=64, kernel_size=7, padding=3)
+            >>> x = torch.randn(1, 64, 32, 32)
+            >>> output = block(x)
+            >>> print(output.shape)
+            torch.Size([1, 64, 32, 32])
+        """
+        super().__init__()
+        self.dwconv = nn.Conv2d(
+            dim,
+            dim,
+            kernel_size=kernel_size,
+            padding=padding,
+            groups=dim if use_dwconv else 1,
+        )  # depthwise conv
+        self.norm = LayerNorm2d(dim, eps=1e-6)
+        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
+        self.act = nn.GELU()
+        self.pwconv2 = nn.Linear(4 * dim, dim)
+        self.gamma = (
+            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
+            if layer_scale_init_value > 0
+            else None
+        )
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+    def forward(self, x):
+        """Applies ConvNeXt block operations to input tensor, including convolutions and residual connection."""
+        input = x
+        x = self.dwconv(x)
+        x = self.norm(x)
+        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
+        x = self.pwconv1(x)
+        x = self.act(x)
+        x = self.pwconv2(x)
+        if self.gamma is not None:
+            x = self.gamma * x
+        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
+
+        x = input + self.drop_path(x)
+        return x
+
+
+class Fuser(nn.Module):
+    """
+    A module for fusing features through multiple layers of a neural network.
+
+    This class applies a series of identical layers to an input tensor, optionally projecting the input first.
+
+    Attributes:
+        proj (nn.Module): An optional input projection layer. Identity if no projection is needed.
+        layers (nn.ModuleList): A list of identical layers to be applied sequentially.
+
+    Methods:
+        forward: Applies the fuser to an input tensor.
+
+    Examples:
+        >>> layer = CXBlock(dim=256)
+        >>> fuser = Fuser(layer, num_layers=3, dim=256, input_projection=True)
+        >>> x = torch.randn(1, 256, 32, 32)
+        >>> output = fuser(x)
+        >>> print(output.shape)
+        torch.Size([1, 256, 32, 32])
+    """
+
+    def __init__(self, layer, num_layers, dim=None, input_projection=False):
+        """
+        Initializes the Fuser module for feature fusion through multiple layers.
+
+        This module creates a sequence of identical layers and optionally applies an input projection.
+
+        Args:
+            layer (nn.Module): The layer to be replicated in the fuser.
+            num_layers (int): The number of times to replicate the layer.
+            dim (int | None): The dimension for input projection, if used.
+            input_projection (bool): Whether to use input projection.
+
+        Examples:
+            >>> layer = nn.Linear(64, 64)
+            >>> fuser = Fuser(layer, num_layers=3, dim=64, input_projection=True)
+            >>> input_tensor = torch.randn(1, 64)
+            >>> output = fuser(input_tensor)
+        """
+        super().__init__()
+        self.proj = nn.Identity()
+        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
+
+        if input_projection:
+            assert dim is not None
+            self.proj = nn.Conv2d(dim, dim, kernel_size=1)
+
+    def forward(self, x):
+        """Applies a series of layers to the input tensor, optionally projecting it first."""
+        x = self.proj(x)
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+
+class SAM2TwoWayAttentionBlock(TwoWayAttentionBlock):
+    """
+    A two-way attention block for performing self-attention and cross-attention in both directions.
+
+    This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on
+    sparse inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and
+    cross-attention from dense to sparse inputs.
+
+    Attributes:
+        self_attn (Attention): Self-attention layer for queries.
+        norm1 (nn.LayerNorm): Layer normalization after the first attention block.
+        cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
+        norm2 (nn.LayerNorm): Layer normalization after the second attention block.
+        mlp (MLP): MLP block for transforming query embeddings.
+        norm3 (nn.LayerNorm): Layer normalization after the MLP block.
+        norm4 (nn.LayerNorm): Layer normalization after the third attention block.
+        cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
+        skip_first_layer_pe (bool): Flag to skip positional encoding in the first layer.
+
+    Methods:
+        forward: Processes input through the attention blocks and MLP.
+
+    Examples:
+        >>> block = SAM2TwoWayAttentionBlock(embedding_dim=256, num_heads=8)
+        >>> sparse_input = torch.randn(1, 100, 256)
+        >>> dense_input = torch.randn(1, 256, 16, 16)
+        >>> sparse_output, dense_output = block(sparse_input, dense_input)
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_heads: int,
+        mlp_dim: int = 2048,
+        activation: Type[nn.Module] = nn.ReLU,
+        attention_downsample_rate: int = 2,
+        skip_first_layer_pe: bool = False,
+    ) -> None:
+        """
+        Initializes a SAM2TwoWayAttentionBlock for performing self-attention and cross-attention in two directions.
+
+        This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on sparse
+        inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and cross-attention
+        from dense to sparse inputs.
+
+        Args:
+            embedding_dim (int): The channel dimension of the embeddings.
+            num_heads (int): The number of heads in the attention layers.
+            mlp_dim (int): The hidden dimension of the MLP block.
+            activation (Type[nn.Module]): The activation function of the MLP block.
+            attention_downsample_rate (int): The downsample rate for attention computations.
+            skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
+
+        Examples:
+            >>> block = SAM2TwoWayAttentionBlock(embedding_dim=256, num_heads=8, mlp_dim=2048)
+            >>> sparse_inputs = torch.randn(1, 100, 256)
+            >>> dense_inputs = torch.randn(1, 256, 32, 32)
+            >>> sparse_outputs, dense_outputs = block(sparse_inputs, dense_inputs)
+        """
+        super().__init__(embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate, skip_first_layer_pe)
+        self.mlp = MLP(embedding_dim, mlp_dim, embedding_dim, num_layers=2, act=activation)
+
+
+class SAM2TwoWayTransformer(TwoWayTransformer):
+    """
+    A Two-Way Transformer module for simultaneous attention to image and query points.
+
+    This class extends the TwoWayTransformer, implementing a specialized transformer decoder that attends to an
+    input image using queries with supplied positional embeddings. It is particularly useful for tasks like
+    object detection, image segmentation, and point cloud processing.
+
+    Attributes:
+        depth (int): Number of layers in the transformer.
+        embedding_dim (int): Channel dimension for input embeddings.
+        num_heads (int): Number of heads for multihead attention.
+        mlp_dim (int): Internal channel dimension for the MLP block.
+        layers (nn.ModuleList): List of SAM2TwoWayAttentionBlock layers comprising the transformer.
+        final_attn_token_to_image (Attention): Final attention layer from queries to image.
+        norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
+
+    Methods:
+        forward: Processes input image embeddings and query embeddings through the transformer.
+
+    Examples:
+        >>> transformer = SAM2TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)
+        >>> image_embedding = torch.randn(1, 256, 64, 64)
+        >>> query_embedding = torch.randn(1, 100, 256)
+        >>> output = transformer(image_embedding, query_embedding)
+        >>> print(output[0].shape, output[1].shape)
+        torch.Size([1, 100, 256]) torch.Size([1, 256, 64, 64])
+    """
+
+    def __init__(
+        self,
+        depth: int,
+        embedding_dim: int,
+        num_heads: int,
+        mlp_dim: int,
+        activation: Type[nn.Module] = nn.ReLU,
+        attention_downsample_rate: int = 2,
+    ) -> None:
+        """
+        Initializes a SAM2TwoWayTransformer instance.
+
+        This transformer decoder attends to an input image using queries with supplied positional embeddings.
+        It is designed for tasks like object detection, image segmentation, and point cloud processing.
+
+        Args:
+            depth (int): Number of layers in the transformer.
+            embedding_dim (int): Channel dimension for the input embeddings.
+            num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.
+            mlp_dim (int): Channel dimension internal to the MLP block.
+            activation (Type[nn.Module]): Activation function to use in the MLP block.
+            attention_downsample_rate (int): Downsampling rate for attention computations.
+
+        Examples:
+            >>> transformer = SAM2TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)
+            >>> transformer
+            SAM2TwoWayTransformer(
+              (layers): ModuleList(
+                (0-4): 5 x SAM2TwoWayAttentionBlock(...)
+              )
+              (final_attn_token_to_image): Attention(...)
+              (norm_final_attn): LayerNorm(...)
+            )
+        """
+        super().__init__(depth, embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate)
+        self.layers = nn.ModuleList()
+        for i in range(depth):
+            self.layers.append(
+                SAM2TwoWayAttentionBlock(
+                    embedding_dim=embedding_dim,
+                    num_heads=num_heads,
+                    mlp_dim=mlp_dim,
+                    activation=activation,
+                    attention_downsample_rate=attention_downsample_rate,
+                    skip_first_layer_pe=(i == 0),
+                )
+            )
+
+
+class RoPEAttention(Attention):
+    """
+    Implements rotary position encoding for attention mechanisms in transformer architectures.
+
+    This class extends the base Attention class by incorporating Rotary Position Encoding (RoPE) to enhance
+    the positional awareness of the attention mechanism.
+
+    Attributes:
+        compute_cis (Callable): Function to compute axial complex numbers for rotary encoding.
+        freqs_cis (Tensor): Precomputed frequency tensor for rotary encoding.
+        rope_k_repeat (bool): Flag to repeat query RoPE to match key length for cross-attention to memories.
+
+    Methods:
+        forward: Applies rotary position encoding and computes attention between query, key, and value tensors.
+
+    Examples:
+        >>> rope_attn = RoPEAttention(embedding_dim=256, num_heads=8, rope_theta=10000.0, feat_sizes=(32, 32))
+        >>> q = torch.randn(1, 1024, 256)
+        >>> k = torch.randn(1, 1024, 256)
+        >>> v = torch.randn(1, 1024, 256)
+        >>> output = rope_attn(q, k, v)
+        >>> print(output.shape)
+        torch.Size([1, 1024, 256])
+    """
+
+    def __init__(
+        self,
+        *args,
+        rope_theta=10000.0,
+        rope_k_repeat=False,
+        feat_sizes=(32, 32),  # [w, h] for stride 16 feats at 512 resolution
+        **kwargs,
+    ):
+        """Initializes RoPEAttention with rotary position encoding for enhanced positional awareness."""
+        super().__init__(*args, **kwargs)
+
+        self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta)
+        freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
+        self.freqs_cis = freqs_cis
+        self.rope_k_repeat = rope_k_repeat  # repeat q rope to match k length, needed for cross-attention to memories
+
+    def forward(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0) -> Tensor:
+        """Applies rotary position encoding and computes attention between query, key, and value tensors."""
+        q = self.q_proj(q)
+        k = self.k_proj(k)
+        v = self.v_proj(v)
+
+        # Separate into heads
+        q = self._separate_heads(q, self.num_heads)
+        k = self._separate_heads(k, self.num_heads)
+        v = self._separate_heads(v, self.num_heads)
+
+        # Apply rotary position encoding
+        w = h = math.sqrt(q.shape[-2])
+        self.freqs_cis = self.freqs_cis.to(q.device)
+        if self.freqs_cis.shape[0] != q.shape[-2]:
+            self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
+        if q.shape[-2] != k.shape[-2]:
+            assert self.rope_k_repeat
+
+        num_k_rope = k.size(-2) - num_k_exclude_rope
+        q, k[:, :, :num_k_rope] = apply_rotary_enc(
+            q,
+            k[:, :, :num_k_rope],
+            freqs_cis=self.freqs_cis,
+            repeat_freqs_k=self.rope_k_repeat,
+        )
+
+        # Attention
+        _, _, _, c_per_head = q.shape
+        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
+        attn = attn / math.sqrt(c_per_head)
+        attn = torch.softmax(attn, dim=-1)
+
+        # Get output
+        out = attn @ v
+
+        out = self._recombine_heads(out)
+        out = self.out_proj(out)
+
+        return out
+
+
+def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
+    """Applies pooling and optional normalization to a tensor, handling spatial dimension permutations."""
+    if pool is None:
+        return x
+    # (B, H, W, C) -> (B, C, H, W)
+    x = x.permute(0, 3, 1, 2)
+    x = pool(x)
+    # (B, C, H', W') -> (B, H', W', C)
+    x = x.permute(0, 2, 3, 1)
+    if norm:
+        x = norm(x)
+
+    return x
+
+
+class MultiScaleAttention(nn.Module):
+    """
+    Implements multiscale self-attention with optional query pooling for efficient feature extraction.
+
+    This class provides a flexible implementation of multiscale attention, allowing for optional
+    downsampling of query features through pooling. It's designed to enhance the model's ability to
+    capture multiscale information in visual tasks.
+
+    Attributes:
+        dim (int): Input dimension of the feature map.
+        dim_out (int): Output dimension of the attention module.
+        num_heads (int): Number of attention heads.
+        scale (float): Scaling factor for dot-product attention.
+        q_pool (nn.Module | None): Optional pooling module for query features.
+        qkv (nn.Linear): Linear projection for query, key, and value.
+        proj (nn.Linear): Output projection.
+
+    Methods:
+        forward: Applies multiscale attention to the input tensor.
+
+    Examples:
+        >>> import torch
+        >>> from torch import nn
+        >>> x = torch.randn(1, 64, 64, 256)
+        >>> msa = MultiScaleAttention(dim=256, dim_out=256, num_heads=8)
+        >>> output = msa(x)
+        >>> print(output.shape)
+        torch.Size([1, 64, 64, 256])
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        dim_out: int,
+        num_heads: int,
+        q_pool: nn.Module = None,
+    ):
+        """Initializes multiscale attention with optional query pooling for efficient feature extraction."""
+        super().__init__()
+
+        self.dim = dim
+        self.dim_out = dim_out
+
+        self.num_heads = num_heads
+        head_dim = dim_out // num_heads
+        self.scale = head_dim**-0.5
+
+        self.q_pool = q_pool
+        self.qkv = nn.Linear(dim, dim_out * 3)
+        self.proj = nn.Linear(dim_out, dim_out)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Applies multiscale attention with optional query pooling to extract multiscale features."""
+        B, H, W, _ = x.shape
+        # qkv with shape (B, H * W, 3, nHead, C)
+        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
+        # q, k, v with shape (B, H * W, nheads, C)
+        q, k, v = torch.unbind(qkv, 2)
+
+        # Q pooling (for downsample at stage changes)
+        if self.q_pool:
+            q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
+            H, W = q.shape[1:3]  # downsampled shape
+            q = q.reshape(B, H * W, self.num_heads, -1)
+
+        # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
+        x = F.scaled_dot_product_attention(
+            q.transpose(1, 2),
+            k.transpose(1, 2),
+            v.transpose(1, 2),
+        )
+        # Transpose back
+        x = x.transpose(1, 2)
+        x = x.reshape(B, H, W, -1)
+
+        x = self.proj(x)
+
+        return x
+
+
+class MultiScaleBlock(nn.Module):
+    """
+    A multiscale attention block with window partitioning and query pooling for efficient vision transformers.
+
+    This class implements a multiscale attention mechanism with optional window partitioning and downsampling,
+    designed for use in vision transformer architectures.
+
+    Attributes:
+        dim (int): Input dimension of the block.
+        dim_out (int): Output dimension of the block.
+        norm1 (nn.Module): First normalization layer.
+        window_size (int): Size of the window for partitioning.
+        pool (nn.Module | None): Pooling layer for query downsampling.
+        q_stride (Tuple[int, int] | None): Stride for query pooling.
+        attn (MultiScaleAttention): Multi-scale attention module.
+        drop_path (nn.Module): Drop path layer for regularization.
+        norm2 (nn.Module): Second normalization layer.
+        mlp (MLP): Multi-layer perceptron module.
+        proj (nn.Linear | None): Projection layer for dimension mismatch.
+
+    Methods:
+        forward: Processes input tensor through the multiscale block.
+
+    Examples:
+        >>> block = MultiScaleBlock(dim=256, dim_out=512, num_heads=8, window_size=7)
+        >>> x = torch.randn(1, 56, 56, 256)
+        >>> output = block(x)
+        >>> print(output.shape)
+        torch.Size([1, 28, 28, 512])
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        dim_out: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        drop_path: float = 0.0,
+        norm_layer: Union[nn.Module, str] = "LayerNorm",
+        q_stride: Tuple[int, int] = None,
+        act_layer: nn.Module = nn.GELU,
+        window_size: int = 0,
+    ):
+        """Initializes a multiscale attention block with window partitioning and optional query pooling."""
+        super().__init__()
+
+        if isinstance(norm_layer, str):
+            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
+
+        self.dim = dim
+        self.dim_out = dim_out
+        self.norm1 = norm_layer(dim)
+
+        self.window_size = window_size
+
+        self.pool, self.q_stride = None, q_stride
+        if self.q_stride:
+            self.pool = nn.MaxPool2d(kernel_size=q_stride, stride=q_stride, ceil_mode=False)
+
+        self.attn = MultiScaleAttention(
+            dim,
+            dim_out,
+            num_heads=num_heads,
+            q_pool=self.pool,
+        )
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+        self.norm2 = norm_layer(dim_out)
+        self.mlp = MLP(
+            dim_out,
+            int(dim_out * mlp_ratio),
+            dim_out,
+            num_layers=2,
+            act=act_layer,
+        )
+
+        if dim != dim_out:
+            self.proj = nn.Linear(dim, dim_out)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Processes input through multiscale attention and MLP, with optional windowing and downsampling."""
+        shortcut = x  # B, H, W, C
+        x = self.norm1(x)
+
+        # Skip connection
+        if self.dim != self.dim_out:
+            shortcut = do_pool(self.proj(x), self.pool)
+
+        # Window partition
+        window_size = self.window_size
+        if window_size > 0:
+            H, W = x.shape[1], x.shape[2]
+            x, pad_hw = window_partition(x, window_size)
+
+        # Window Attention + Q Pooling (if stage change)
+        x = self.attn(x)
+        if self.q_stride:
+            # Shapes have changed due to Q pooling
+            window_size = self.window_size // self.q_stride[0]
+            H, W = shortcut.shape[1:3]
+
+            pad_h = (window_size - H % window_size) % window_size
+            pad_w = (window_size - W % window_size) % window_size
+            pad_hw = (H + pad_h, W + pad_w)
+
+        # Reverse window partition
+        if self.window_size > 0:
+            x = window_unpartition(x, window_size, pad_hw, (H, W))
+
+        x = shortcut + self.drop_path(x)
+        # MLP
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class PositionEmbeddingSine(nn.Module):
+    """
+    A module for generating sinusoidal positional embeddings for 2D inputs like images.
+
+    This class implements sinusoidal position encoding for 2D spatial positions, which can be used in
+    transformer-based models for computer vision tasks.
+
+    Attributes:
+        num_pos_feats (int): Number of positional features (half of the embedding dimension).
+        temperature (int): Temperature parameter for the sinusoidal functions.
+        normalize (bool): Whether to normalize the positional embeddings.
+        scale (float): Scaling factor for the embeddings when normalize is True.
+        cache (Dict): Cache for storing precomputed embeddings.
+
+    Methods:
+        _encode_xy: Encodes 2D positions using sine and cosine functions.
+        encode_boxes: Encodes box coordinates and dimensions into positional embeddings.
+        encode_points: Encodes 2D point coordinates with sinusoidal positional embeddings.
+        forward: Generates sinusoidal position embeddings for 2D inputs.
+
+    Examples:
+        >>> pos_emb = PositionEmbeddingSine(num_pos_feats=128)
+        >>> x = torch.randn(1, 3, 224, 224)
+        >>> embeddings = pos_emb(x)
+        >>> print(embeddings.shape)
+        torch.Size([1, 256, 224, 224])
+    """
+
+    def __init__(
+        self,
+        num_pos_feats,
+        temperature: int = 10000,
+        normalize: bool = True,
+        scale: Optional[float] = None,
+    ):
+        """Initializes sinusoidal position embeddings for 2D image inputs."""
+        super().__init__()
+        assert num_pos_feats % 2 == 0, "Expecting even model width"
+        self.num_pos_feats = num_pos_feats // 2
+        self.temperature = temperature
+        self.normalize = normalize
+        if scale is not None and not normalize:
+            raise ValueError("normalize should be True if scale is passed")
+        if scale is None:
+            scale = 2 * math.pi
+        self.scale = scale
+
+        self.cache = {}
+
+    def _encode_xy(self, x, y):
+        """Encodes 2D positions using sine/cosine functions for transformer positional embeddings."""
+        assert len(x) == len(y) and x.ndim == y.ndim == 1
+        x_embed = x * self.scale
+        y_embed = y * self.scale
+
+        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+        pos_x = x_embed[:, None] / dim_t
+        pos_y = y_embed[:, None] / dim_t
+        pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
+        pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)
+        return pos_x, pos_y
+
+    @torch.no_grad()
+    def encode_boxes(self, x, y, w, h):
+        """Encodes box coordinates and dimensions into positional embeddings for detection."""
+        pos_x, pos_y = self._encode_xy(x, y)
+        return torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
+
+    encode = encode_boxes  # Backwards compatibility
+
+    @torch.no_grad()
+    def encode_points(self, x, y, labels):
+        """Encodes 2D points with sinusoidal embeddings and appends labels."""
+        (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
+        assert bx == by and nx == ny and bx == bl and nx == nl
+        pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
+        pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
+        return torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
+
+    @torch.no_grad()
+    def forward(self, x: torch.Tensor):
+        """Generates sinusoidal position embeddings for 2D inputs like images."""
+        cache_key = (x.shape[-2], x.shape[-1])
+        if cache_key in self.cache:
+            return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
+        y_embed = (
+            torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
+            .view(1, -1, 1)
+            .repeat(x.shape[0], 1, x.shape[-1])
+        )
+        x_embed = (
+            torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
+            .view(1, 1, -1)
+            .repeat(x.shape[0], x.shape[-2], 1)
+        )
+
+        if self.normalize:
+            eps = 1e-6
+            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+        pos_x = x_embed[:, :, :, None] / dim_t
+        pos_y = y_embed[:, :, :, None] / dim_t
+        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+        self.cache[cache_key] = pos[0]
+        return pos
+
+
+class PositionEmbeddingRandom(nn.Module):
+    """
+    Positional encoding using random spatial frequencies.
+
+    This class generates positional embeddings for input coordinates using random spatial frequencies. It is
+    particularly useful for transformer-based models that require position information.
+
+    Attributes:
+        positional_encoding_gaussian_matrix (torch.Tensor): A buffer containing random values for encoding.
+
+    Methods:
+        _pe_encoding: Positionally encodes points that are normalized to [0,1].
+        forward: Generates positional encoding for a grid of the specified size.
+        forward_with_coords: Positionally encodes points that are not normalized to [0,1].
+
+    Examples:
+        >>> pe = PositionEmbeddingRandom(num_pos_feats=64)
+        >>> size = (32, 32)
+        >>> encoding = pe(size)
+        >>> print(encoding.shape)
+        torch.Size([128, 32, 32])
+    """
+
+    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
+        """Initializes random spatial frequency position embedding for transformers."""
+        super().__init__()
+        if scale is None or scale <= 0.0:
+            scale = 1.0
+        self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats)))
+
+        # Set non-deterministic for forward() error 'cumsum_cuda_kernel does not have a deterministic implementation'
+        torch.use_deterministic_algorithms(False)
+        torch.backends.cudnn.deterministic = False
+
+    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
+        """Encodes normalized [0,1] coordinates using random spatial frequencies."""
+        # Assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+        coords = 2 * coords - 1
+        coords = coords @ self.positional_encoding_gaussian_matrix
+        coords = 2 * np.pi * coords
+        # Outputs d_1 x ... x d_n x C shape
+        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
+
+    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
+        """Generates positional encoding for a grid using random spatial frequencies."""
+        h, w = size
+        device: Any = self.positional_encoding_gaussian_matrix.device
+        grid = torch.ones((h, w), device=device, dtype=torch.float32)
+        y_embed = grid.cumsum(dim=0) - 0.5
+        x_embed = grid.cumsum(dim=1) - 0.5
+        y_embed = y_embed / h
+        x_embed = x_embed / w
+
+        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
+        return pe.permute(2, 0, 1)  # C x H x W
+
+    def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
+        """Positionally encodes input coordinates, normalizing them to [0,1] based on the given image size."""
+        coords = coords_input.clone()
+        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
+        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
+        return self._pe_encoding(coords.to(torch.float))  # B x N x C
+
+
+class Block(nn.Module):
+    """
+    Transformer block with support for window attention and residual propagation.
+
+    This class implements a transformer block that can use either global or windowed self-attention,
+    followed by a feed-forward network. It supports relative positional embeddings and is designed
+    for use in vision transformer architectures.
+
+    Attributes:
+        norm1 (nn.Module): First normalization layer.
+        attn (REAttention): Self-attention layer with optional relative positional encoding.
+        norm2 (nn.Module): Second normalization layer.
+        mlp (MLPBlock): Multi-layer perceptron block.
+        window_size (int): Size of attention window. If 0, global attention is used.
+
+    Methods:
+        forward: Processes input through the transformer block.
+
+    Examples:
+        >>> import torch
+        >>> block = Block(dim=256, num_heads=8, window_size=7)
+        >>> x = torch.randn(1, 56, 56, 256)
+        >>> output = block(x)
+        >>> print(output.shape)
+        torch.Size([1, 56, 56, 256])
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        qkv_bias: bool = True,
+        norm_layer: Type[nn.Module] = nn.LayerNorm,
+        act_layer: Type[nn.Module] = nn.GELU,
+        use_rel_pos: bool = False,
+        rel_pos_zero_init: bool = True,
+        window_size: int = 0,
+        input_size: Optional[Tuple[int, int]] = None,
+    ) -> None:
+        """
+        Initializes a transformer block with optional window attention and relative positional embeddings.
+
+        This constructor sets up a transformer block that can use either global or windowed self-attention,
+        followed by a feed-forward network. It supports relative positional embeddings and is designed
+        for use in vision transformer architectures.
+
+        Args:
+            dim (int): Number of input channels.
+            num_heads (int): Number of attention heads in the self-attention layer.
+            mlp_ratio (float): Ratio of mlp hidden dimension to embedding dimension.
+            qkv_bias (bool): If True, adds a learnable bias to query, key, value projections.
+            norm_layer (Type[nn.Module]): Type of normalization layer to use.
+            act_layer (Type[nn.Module]): Type of activation function to use in the MLP block.
+            use_rel_pos (bool): If True, uses relative positional embeddings in attention.
+            rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.
+            window_size (int): Size of attention window. If 0, uses global attention.
+            input_size (Optional[Tuple[int, int]]): Input resolution for calculating relative positional parameter size.
+
+        Examples:
+            >>> block = Block(dim=256, num_heads=8, window_size=7)
+            >>> x = torch.randn(1, 56, 56, 256)
+            >>> output = block(x)
+            >>> print(output.shape)
+            torch.Size([1, 56, 56, 256])
+        """
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = REAttention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            use_rel_pos=use_rel_pos,
+            rel_pos_zero_init=rel_pos_zero_init,
+            input_size=input_size if window_size == 0 else (window_size, window_size),
+        )
+
+        self.norm2 = norm_layer(dim)
+        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
+
+        self.window_size = window_size
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Processes input through transformer block with optional windowed self-attention and residual connection."""
+        shortcut = x
+        x = self.norm1(x)
+        # Window partition
+        if self.window_size > 0:
+            H, W = x.shape[1], x.shape[2]
+            x, pad_hw = window_partition(x, self.window_size)
+
+        x = self.attn(x)
+        # Reverse window partition
+        if self.window_size > 0:
+            x = window_unpartition(x, self.window_size, pad_hw, (H, W))
+
+        x = shortcut + x
+        return x + self.mlp(self.norm2(x))
+
+
+class REAttention(nn.Module):
+    """
+    Rotary Embedding Attention module for efficient self-attention in transformer architectures.
+
+    This class implements a multi-head attention mechanism with rotary positional embeddings, designed
+    for use in vision transformer models. It supports optional query pooling and window partitioning
+    for efficient processing of large inputs.
+
+    Attributes:
+        compute_cis (Callable): Function to compute axial complex numbers for rotary encoding.
+        freqs_cis (Tensor): Precomputed frequency tensor for rotary encoding.
+        rope_k_repeat (bool): Flag to repeat query RoPE to match key length for cross-attention to memories.
+        q_proj (nn.Linear): Linear projection for query.
+        k_proj (nn.Linear): Linear projection for key.
+        v_proj (nn.Linear): Linear projection for value.
+        out_proj (nn.Linear): Output projection.
+        num_heads (int): Number of attention heads.
+        internal_dim (int): Internal dimension for attention computation.
+
+    Methods:
+        forward: Applies rotary position encoding and computes attention between query, key, and value tensors.
+
+    Examples:
+        >>> rope_attn = REAttention(embedding_dim=256, num_heads=8, rope_theta=10000.0, feat_sizes=(32, 32))
+        >>> q = torch.randn(1, 1024, 256)
+        >>> k = torch.randn(1, 1024, 256)
+        >>> v = torch.randn(1, 1024, 256)
+        >>> output = rope_attn(q, k, v)
+        >>> print(output.shape)
+        torch.Size([1, 1024, 256])
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        qkv_bias: bool = True,
+        use_rel_pos: bool = False,
+        rel_pos_zero_init: bool = True,
+        input_size: Optional[Tuple[int, int]] = None,
+    ) -> None:
+        """
+        Initializes a Relative Position Attention module for transformer-based architectures.
+
+        This module implements multi-head attention with optional relative positional encodings, designed
+        specifically for vision tasks in transformer models.
+
+        Args:
+            dim (int): Number of input channels.
+            num_heads (int): Number of attention heads. Default is 8.
+            qkv_bias (bool): If True, adds a learnable bias to query, key, value projections. Default is True.
+            use_rel_pos (bool): If True, uses relative positional encodings. Default is False.
+            rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero. Default is True.
+            input_size (Tuple[int, int] | None): Input resolution for calculating relative positional parameter size.
+                Required if use_rel_pos is True. Default is None.
+
+        Examples:
+            >>> attention = REAttention(dim=256, num_heads=8, input_size=(32, 32))
+            >>> x = torch.randn(1, 32, 32, 256)
+            >>> output = attention(x)
+            >>> print(output.shape)
+            torch.Size([1, 32, 32, 256])
+        """
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = head_dim**-0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.proj = nn.Linear(dim, dim)
+
+        self.use_rel_pos = use_rel_pos
+        if self.use_rel_pos:
+            assert input_size is not None, "Input size must be provided if using relative positional encoding."
+            # Initialize relative positional embeddings
+            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
+            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Applies multi-head attention with optional relative positional encoding to input tensor."""
+        B, H, W, _ = x.shape
+        # qkv with shape (3, B, nHead, H * W, C)
+        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        # q, k, v with shape (B * nHead, H * W, C)
+        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
+
+        attn = (q * self.scale) @ k.transpose(-2, -1)
+
+        if self.use_rel_pos:
+            attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
+
+        attn = attn.softmax(dim=-1)
+        x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
+        return self.proj(x)
+
+
+class PatchEmbed(nn.Module):
+    """
+    Image to Patch Embedding module for vision transformer architectures.
+
+    This module converts an input image into a sequence of patch embeddings using a convolutional layer.
+    It is commonly used as the first layer in vision transformer architectures to transform image data
+    into a suitable format for subsequent transformer blocks.
+
+    Attributes:
+        proj (nn.Conv2d): Convolutional layer for projecting image patches to embeddings.
+
+    Methods:
+        forward: Applies patch embedding to the input tensor.
+
+    Examples:
+        >>> patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768)
+        >>> x = torch.randn(1, 3, 224, 224)
+        >>> output = patch_embed(x)
+        >>> print(output.shape)
+        torch.Size([1, 768, 14, 14])
+    """
+
+    def __init__(
+        self,
+        kernel_size: Tuple[int, int] = (16, 16),
+        stride: Tuple[int, int] = (16, 16),
+        padding: Tuple[int, int] = (0, 0),
+        in_chans: int = 3,
+        embed_dim: int = 768,
+    ) -> None:
+        """
+        Initializes the PatchEmbed module for converting image patches to embeddings.
+
+        This module is typically used as the first layer in vision transformer architectures to transform
+        image data into a suitable format for subsequent transformer blocks.
+
+        Args:
+            kernel_size (Tuple[int, int]): Size of the convolutional kernel for patch extraction.
+            stride (Tuple[int, int]): Stride of the convolutional operation.
+            padding (Tuple[int, int]): Padding applied to the input before convolution.
+            in_chans (int): Number of input image channels.
+            embed_dim (int): Dimensionality of the output patch embeddings.
+
+        Examples:
+            >>> patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768)
+            >>> x = torch.randn(1, 3, 224, 224)
+            >>> output = patch_embed(x)
+            >>> print(output.shape)
+            torch.Size([1, 768, 14, 14])
+        """
+        super().__init__()
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Computes patch embedding by applying convolution and transposing resulting tensor."""
+        return self.proj(x).permute(0, 2, 3, 1)  # B C H W -> B H W C

+ 518 - 0
ultralytics/models/sam/modules/decoders.py

@@ -0,0 +1,518 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+
+from ultralytics.nn.modules import MLP, LayerNorm2d
+
+
+class MaskDecoder(nn.Module):
+    """
+    Decoder module for generating masks and their associated quality scores using a transformer architecture.
+
+    This class predicts masks given image and prompt embeddings, utilizing a transformer to process the inputs and
+    generate mask predictions along with their quality scores.
+
+    Attributes:
+        transformer_dim (int): Channel dimension for the transformer module.
+        transformer (nn.Module): Transformer module used for mask prediction.
+        num_multimask_outputs (int): Number of masks to predict for disambiguating masks.
+        iou_token (nn.Embedding): Embedding for the IoU token.
+        num_mask_tokens (int): Number of mask tokens.
+        mask_tokens (nn.Embedding): Embedding for the mask tokens.
+        output_upscaling (nn.Sequential): Neural network sequence for upscaling the output.
+        output_hypernetworks_mlps (nn.ModuleList): Hypernetwork MLPs for generating masks.
+        iou_prediction_head (nn.Module): MLP for predicting mask quality.
+
+    Methods:
+        forward: Predicts masks given image and prompt embeddings.
+        predict_masks: Internal method for mask prediction.
+
+    Examples:
+        >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)
+        >>> masks, iou_pred = decoder(
+        ...     image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, multimask_output=True
+        ... )
+        >>> print(f"Predicted masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}")
+    """
+
+    def __init__(
+        self,
+        transformer_dim: int,
+        transformer: nn.Module,
+        num_multimask_outputs: int = 3,
+        activation: Type[nn.Module] = nn.GELU,
+        iou_head_depth: int = 3,
+        iou_head_hidden_dim: int = 256,
+    ) -> None:
+        """
+        Initializes the MaskDecoder module for generating masks and their quality scores.
+
+        Args:
+            transformer_dim (int): Channel dimension for the transformer module.
+            transformer (nn.Module): Transformer module used for mask prediction.
+            num_multimask_outputs (int): Number of masks to predict for disambiguating masks.
+            activation (Type[nn.Module]): Type of activation to use when upscaling masks.
+            iou_head_depth (int): Depth of the MLP used to predict mask quality.
+            iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.
+
+        Examples:
+            >>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6)
+            >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer)
+            >>> print(decoder)
+        """
+        super().__init__()
+        self.transformer_dim = transformer_dim
+        self.transformer = transformer
+
+        self.num_multimask_outputs = num_multimask_outputs
+
+        self.iou_token = nn.Embedding(1, transformer_dim)
+        self.num_mask_tokens = num_multimask_outputs + 1
+        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
+
+        self.output_upscaling = nn.Sequential(
+            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
+            LayerNorm2d(transformer_dim // 4),
+            activation(),
+            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
+            activation(),
+        )
+        self.output_hypernetworks_mlps = nn.ModuleList(
+            [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
+        )
+
+        self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth)
+
+    def forward(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        multimask_output: bool,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Predicts masks given image and prompt embeddings.
+
+        Args:
+            image_embeddings (torch.Tensor): Embeddings from the image encoder.
+            image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings.
+            sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes.
+            dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs.
+            multimask_output (bool): Whether to return multiple masks or a single mask.
+
+        Returns:
+            (Tuple[torch.Tensor, torch.Tensor]): A tuple containing:
+                - masks (torch.Tensor): Batched predicted masks.
+                - iou_pred (torch.Tensor): Batched predictions of mask quality.
+
+        Examples:
+            >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)
+            >>> image_emb = torch.rand(1, 256, 64, 64)
+            >>> image_pe = torch.rand(1, 256, 64, 64)
+            >>> sparse_emb = torch.rand(1, 2, 256)
+            >>> dense_emb = torch.rand(1, 256, 64, 64)
+            >>> masks, iou_pred = decoder(image_emb, image_pe, sparse_emb, dense_emb, multimask_output=True)
+            >>> print(f"Masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}")
+        """
+        masks, iou_pred = self.predict_masks(
+            image_embeddings=image_embeddings,
+            image_pe=image_pe,
+            sparse_prompt_embeddings=sparse_prompt_embeddings,
+            dense_prompt_embeddings=dense_prompt_embeddings,
+        )
+
+        # Select the correct mask or masks for output
+        mask_slice = slice(1, None) if multimask_output else slice(0, 1)
+        masks = masks[:, mask_slice, :, :]
+        iou_pred = iou_pred[:, mask_slice]
+
+        # Prepare output
+        return masks, iou_pred
+
+    def predict_masks(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Predicts masks and quality scores using image and prompt embeddings via transformer architecture."""
+        # Concatenate output tokens
+        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
+        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.shape[0], -1, -1)
+        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+
+        # Expand per-image data in batch direction to be per-mask
+        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+        src = src + dense_prompt_embeddings
+        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+        b, c, h, w = src.shape
+
+        # Run the transformer
+        hs, src = self.transformer(src, pos_src, tokens)
+        iou_token_out = hs[:, 0, :]
+        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
+
+        # Upscale mask embeddings and predict masks using the mask tokens
+        src = src.transpose(1, 2).view(b, c, h, w)
+        upscaled_embedding = self.output_upscaling(src)
+        hyper_in_list: List[torch.Tensor] = [
+            self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
+        ]
+        hyper_in = torch.stack(hyper_in_list, dim=1)
+        b, c, h, w = upscaled_embedding.shape
+        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
+
+        # Generate mask quality predictions
+        iou_pred = self.iou_prediction_head(iou_token_out)
+
+        return masks, iou_pred
+
+
+class SAM2MaskDecoder(nn.Module):
+    """
+    Transformer-based decoder for predicting instance segmentation masks from image and prompt embeddings.
+
+    This class extends the functionality of the MaskDecoder, incorporating additional features such as
+    high-resolution feature processing, dynamic multimask output, and object score prediction.
+
+    Attributes:
+        transformer_dim (int): Channel dimension of the transformer.
+        transformer (nn.Module): Transformer used to predict masks.
+        num_multimask_outputs (int): Number of masks to predict when disambiguating masks.
+        iou_token (nn.Embedding): Embedding for IOU token.
+        num_mask_tokens (int): Total number of mask tokens.
+        mask_tokens (nn.Embedding): Embedding for mask tokens.
+        pred_obj_scores (bool): Whether to predict object scores.
+        obj_score_token (nn.Embedding): Embedding for object score token.
+        use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.
+        output_upscaling (nn.Sequential): Upscaling layers for output.
+        use_high_res_features (bool): Whether to use high-resolution features.
+        conv_s0 (nn.Conv2d): Convolutional layer for high-resolution features (s0).
+        conv_s1 (nn.Conv2d): Convolutional layer for high-resolution features (s1).
+        output_hypernetworks_mlps (nn.ModuleList): List of MLPs for output hypernetworks.
+        iou_prediction_head (MLP): MLP for IOU prediction.
+        pred_obj_score_head (nn.Linear | MLP): Linear layer or MLP for object score prediction.
+        dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.
+        dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.
+        dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.
+
+    Methods:
+        forward: Predicts masks given image and prompt embeddings.
+        predict_masks: Predicts instance segmentation masks from image and prompt embeddings.
+        _get_stability_scores: Computes mask stability scores based on IoU between thresholds.
+        _dynamic_multimask_via_stability: Dynamically selects the most stable mask output.
+
+    Examples:
+        >>> image_embeddings = torch.rand(1, 256, 64, 64)
+        >>> image_pe = torch.rand(1, 256, 64, 64)
+        >>> sparse_prompt_embeddings = torch.rand(1, 2, 256)
+        >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
+        >>> decoder = SAM2MaskDecoder(256, transformer)
+        >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
+        ...     image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False
+        ... )
+    """
+
+    def __init__(
+        self,
+        transformer_dim: int,
+        transformer: nn.Module,
+        num_multimask_outputs: int = 3,
+        activation: Type[nn.Module] = nn.GELU,
+        iou_head_depth: int = 3,
+        iou_head_hidden_dim: int = 256,
+        use_high_res_features: bool = False,
+        iou_prediction_use_sigmoid=False,
+        dynamic_multimask_via_stability=False,
+        dynamic_multimask_stability_delta=0.05,
+        dynamic_multimask_stability_thresh=0.98,
+        pred_obj_scores: bool = False,
+        pred_obj_scores_mlp: bool = False,
+        use_multimask_token_for_obj_ptr: bool = False,
+    ) -> None:
+        """
+        Initializes the SAM2MaskDecoder module for predicting instance segmentation masks.
+
+        This decoder extends the functionality of MaskDecoder, incorporating additional features such as
+        high-resolution feature processing, dynamic multimask output, and object score prediction.
+
+        Args:
+            transformer_dim (int): Channel dimension of the transformer.
+            transformer (nn.Module): Transformer used to predict masks.
+            num_multimask_outputs (int): Number of masks to predict when disambiguating masks.
+            activation (Type[nn.Module]): Type of activation to use when upscaling masks.
+            iou_head_depth (int): Depth of the MLP used to predict mask quality.
+            iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.
+            use_high_res_features (bool): Whether to use high-resolution features.
+            iou_prediction_use_sigmoid (bool): Whether to use sigmoid for IOU prediction.
+            dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.
+            dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.
+            dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.
+            pred_obj_scores (bool): Whether to predict object scores.
+            pred_obj_scores_mlp (bool): Whether to use MLP for object score prediction.
+            use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.
+
+        Examples:
+            >>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6)
+            >>> decoder = SAM2MaskDecoder(transformer_dim=256, transformer=transformer)
+            >>> print(decoder)
+        """
+        super().__init__()
+        self.transformer_dim = transformer_dim
+        self.transformer = transformer
+
+        self.num_multimask_outputs = num_multimask_outputs
+
+        self.iou_token = nn.Embedding(1, transformer_dim)
+        self.num_mask_tokens = num_multimask_outputs + 1
+        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
+
+        self.pred_obj_scores = pred_obj_scores
+        if self.pred_obj_scores:
+            self.obj_score_token = nn.Embedding(1, transformer_dim)
+        self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
+
+        self.output_upscaling = nn.Sequential(
+            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
+            LayerNorm2d(transformer_dim // 4),
+            activation(),
+            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
+            activation(),
+        )
+        self.use_high_res_features = use_high_res_features
+        if use_high_res_features:
+            self.conv_s0 = nn.Conv2d(transformer_dim, transformer_dim // 8, kernel_size=1, stride=1)
+            self.conv_s1 = nn.Conv2d(transformer_dim, transformer_dim // 4, kernel_size=1, stride=1)
+
+        self.output_hypernetworks_mlps = nn.ModuleList(
+            [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
+        )
+
+        self.iou_prediction_head = MLP(
+            transformer_dim,
+            iou_head_hidden_dim,
+            self.num_mask_tokens,
+            iou_head_depth,
+            sigmoid=iou_prediction_use_sigmoid,
+        )
+        if self.pred_obj_scores:
+            self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
+            if pred_obj_scores_mlp:
+                self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
+
+        # When outputting a single mask, optionally we can dynamically fall back to the best
+        # multimask output token if the single mask output token gives low stability scores.
+        self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
+        self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
+        self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
+
+    def forward(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        multimask_output: bool,
+        repeat_image: bool,
+        high_res_features: Optional[List[torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Predicts masks given image and prompt embeddings.
+
+        Args:
+            image_embeddings (torch.Tensor): Embeddings from the image encoder with shape (B, C, H, W).
+            image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings (B, C, H, W).
+            sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes with shape (B, N, C).
+            dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs with shape (B, C, H, W).
+            multimask_output (bool): Whether to return multiple masks or a single mask.
+            repeat_image (bool): Flag to repeat the image embeddings.
+            high_res_features (List[torch.Tensor] | None): Optional high-resolution features.
+
+        Returns:
+            (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing:
+                - masks (torch.Tensor): Batched predicted masks with shape (B, N, H, W).
+                - iou_pred (torch.Tensor): Batched predictions of mask quality with shape (B, N).
+                - sam_tokens_out (torch.Tensor): Batched SAM token for mask output with shape (B, N, C).
+                - object_score_logits (torch.Tensor): Batched object score logits with shape (B, 1).
+
+        Examples:
+            >>> image_embeddings = torch.rand(1, 256, 64, 64)
+            >>> image_pe = torch.rand(1, 256, 64, 64)
+            >>> sparse_prompt_embeddings = torch.rand(1, 2, 256)
+            >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
+            >>> decoder = SAM2MaskDecoder(256, transformer)
+            >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
+            ...     image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False
+            ... )
+        """
+        masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
+            image_embeddings=image_embeddings,
+            image_pe=image_pe,
+            sparse_prompt_embeddings=sparse_prompt_embeddings,
+            dense_prompt_embeddings=dense_prompt_embeddings,
+            repeat_image=repeat_image,
+            high_res_features=high_res_features,
+        )
+
+        # Select the correct mask or masks for output
+        if multimask_output:
+            masks = masks[:, 1:, :, :]
+            iou_pred = iou_pred[:, 1:]
+        elif self.dynamic_multimask_via_stability and not self.training:
+            masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
+        else:
+            masks = masks[:, 0:1, :, :]
+            iou_pred = iou_pred[:, 0:1]
+
+        if multimask_output and self.use_multimask_token_for_obj_ptr:
+            sam_tokens_out = mask_tokens_out[:, 1:]  # [b, 3, c] shape
+        else:
+            # Take the mask output token. Here we *always* use the token for single mask output.
+            # At test time, even if we track after 1-click (and using multimask_output=True),
+            # we still take the single mask token here. The rationale is that we always track
+            # after multiple clicks during training, so the past tokens seen during training
+            # are always the single mask token (and we'll let it be the object-memory token).
+            sam_tokens_out = mask_tokens_out[:, 0:1]  # [b, 1, c] shape
+
+        # Prepare output
+        return masks, iou_pred, sam_tokens_out, object_score_logits
+
+    def predict_masks(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        repeat_image: bool,
+        high_res_features: Optional[List[torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Predicts instance segmentation masks from image and prompt embeddings using a transformer."""
+        # Concatenate output tokens
+        s = 0
+        if self.pred_obj_scores:
+            output_tokens = torch.cat(
+                [
+                    self.obj_score_token.weight,
+                    self.iou_token.weight,
+                    self.mask_tokens.weight,
+                ],
+                dim=0,
+            )
+            s = 1
+        else:
+            output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
+        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
+        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+
+        # Expand per-image data in batch direction to be per-mask
+        if repeat_image:
+            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+        else:
+            assert image_embeddings.shape[0] == tokens.shape[0]
+            src = image_embeddings
+        src = src + dense_prompt_embeddings
+        assert image_pe.size(0) == 1, "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
+        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+        b, c, h, w = src.shape
+
+        # Run the transformer
+        hs, src = self.transformer(src, pos_src, tokens)
+        iou_token_out = hs[:, s, :]
+        mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
+
+        # Upscale mask embeddings and predict masks using the mask tokens
+        src = src.transpose(1, 2).view(b, c, h, w)
+        if not self.use_high_res_features:
+            upscaled_embedding = self.output_upscaling(src)
+        else:
+            dc1, ln1, act1, dc2, act2 = self.output_upscaling
+            feat_s0, feat_s1 = high_res_features
+            upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
+            upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
+
+        hyper_in_list: List[torch.Tensor] = [
+            self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
+        ]
+        hyper_in = torch.stack(hyper_in_list, dim=1)
+        b, c, h, w = upscaled_embedding.shape
+        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
+
+        # Generate mask quality predictions
+        iou_pred = self.iou_prediction_head(iou_token_out)
+        if self.pred_obj_scores:
+            assert s == 1
+            object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
+        else:
+            # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
+            object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
+
+        return masks, iou_pred, mask_tokens_out, object_score_logits
+
+    def _get_stability_scores(self, mask_logits):
+        """Computes mask stability scores based on IoU between upper and lower thresholds."""
+        mask_logits = mask_logits.flatten(-2)
+        stability_delta = self.dynamic_multimask_stability_delta
+        area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
+        area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
+        return torch.where(area_u > 0, area_i / area_u, 1.0)
+
+    def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
+        """
+        Dynamically selects the most stable mask output based on stability scores and IoU predictions.
+
+        This method is used when outputting a single mask. If the stability score from the current single-mask
+        output (based on output token 0) falls below a threshold, it instead selects from multi-mask outputs
+        (based on output tokens 1-3) the mask with the highest predicted IoU score. This ensures a valid mask
+        for both clicking and tracking scenarios.
+
+        Args:
+            all_mask_logits (torch.Tensor): Logits for all predicted masks, shape (B, N, H, W) where B is
+                batch size, N is number of masks (typically 4), and H, W are mask dimensions.
+            all_iou_scores (torch.Tensor): Predicted IoU scores for all masks, shape (B, N).
+
+        Returns:
+            (Tuple[torch.Tensor, torch.Tensor]):
+                - mask_logits_out (torch.Tensor): Selected mask logits, shape (B, 1, H, W).
+                - iou_scores_out (torch.Tensor): Selected IoU scores, shape (B, 1).
+
+        Examples:
+            >>> decoder = SAM2MaskDecoder(...)
+            >>> all_mask_logits = torch.rand(2, 4, 256, 256)  # 2 images, 4 masks each
+            >>> all_iou_scores = torch.rand(2, 4)
+            >>> mask_logits, iou_scores = decoder._dynamic_multimask_via_stability(all_mask_logits, all_iou_scores)
+            >>> print(mask_logits.shape, iou_scores.shape)
+            torch.Size([2, 1, 256, 256]) torch.Size([2, 1])
+        """
+        # The best mask from multimask output tokens (1~3)
+        multimask_logits = all_mask_logits[:, 1:, :, :]
+        multimask_iou_scores = all_iou_scores[:, 1:]
+        best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
+        batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device)
+        best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
+        best_multimask_logits = best_multimask_logits.unsqueeze(1)
+        best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
+        best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
+
+        # The mask from singlemask output token 0 and its stability score
+        singlemask_logits = all_mask_logits[:, 0:1, :, :]
+        singlemask_iou_scores = all_iou_scores[:, 0:1]
+        stability_scores = self._get_stability_scores(singlemask_logits)
+        is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
+
+        # Dynamically fall back to best multimask output upon low stability scores.
+        mask_logits_out = torch.where(
+            is_stable[..., None, None].expand_as(singlemask_logits),
+            singlemask_logits,
+            best_multimask_logits,
+        )
+        iou_scores_out = torch.where(
+            is_stable.expand_as(singlemask_iou_scores),
+            singlemask_iou_scores,
+            best_multimask_iou_scores,
+        )
+        return mask_logits_out, iou_scores_out

+ 794 - 0
ultralytics/models/sam/modules/encoders.py

@@ -0,0 +1,794 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ultralytics.nn.modules import LayerNorm2d
+
+from .blocks import (
+    Block,
+    CXBlock,
+    Fuser,
+    MaskDownSampler,
+    MultiScaleBlock,
+    PatchEmbed,
+    PositionEmbeddingRandom,
+    PositionEmbeddingSine,
+)
+
+
+class ImageEncoderViT(nn.Module):
+    """
+    An image encoder using Vision Transformer (ViT) architecture for encoding images into a compact latent space.
+
+    This class processes images by splitting them into patches, applying transformer blocks, and generating a final
+    encoded representation through a neck module.
+
+    Attributes:
+        img_size (int): Dimension of input images, assumed to be square.
+        patch_embed (PatchEmbed): Module for patch embedding.
+        pos_embed (nn.Parameter | None): Absolute positional embedding for patches.
+        blocks (nn.ModuleList): List of transformer blocks for processing patch embeddings.
+        neck (nn.Sequential): Neck module to further process the output.
+
+    Methods:
+        forward: Processes input through patch embedding, positional embedding, blocks, and neck.
+
+    Examples:
+        >>> import torch
+        >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
+        >>> input_image = torch.randn(1, 3, 224, 224)
+        >>> output = encoder(input_image)
+        >>> print(output.shape)
+    """
+
+    def __init__(
+        self,
+        img_size: int = 1024,
+        patch_size: int = 16,
+        in_chans: int = 3,
+        embed_dim: int = 768,
+        depth: int = 12,
+        num_heads: int = 12,
+        mlp_ratio: float = 4.0,
+        out_chans: int = 256,
+        qkv_bias: bool = True,
+        norm_layer: Type[nn.Module] = nn.LayerNorm,
+        act_layer: Type[nn.Module] = nn.GELU,
+        use_abs_pos: bool = True,
+        use_rel_pos: bool = False,
+        rel_pos_zero_init: bool = True,
+        window_size: int = 0,
+        global_attn_indexes: Tuple[int, ...] = (),
+    ) -> None:
+        """
+        Initializes an ImageEncoderViT instance for encoding images using Vision Transformer architecture.
+
+        Args:
+            img_size (int): Input image size, assumed to be square.
+            patch_size (int): Size of image patches.
+            in_chans (int): Number of input image channels.
+            embed_dim (int): Dimension of patch embeddings.
+            depth (int): Number of transformer blocks.
+            num_heads (int): Number of attention heads in each block.
+            mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
+            out_chans (int): Number of output channels from the neck module.
+            qkv_bias (bool): If True, adds learnable bias to query, key, value projections.
+            norm_layer (Type[nn.Module]): Type of normalization layer to use.
+            act_layer (Type[nn.Module]): Type of activation layer to use.
+            use_abs_pos (bool): If True, uses absolute positional embeddings.
+            use_rel_pos (bool): If True, adds relative positional embeddings to attention maps.
+            rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.
+            window_size (int): Size of attention window for windowed attention blocks.
+            global_attn_indexes (Tuple[int, ...]): Indices of blocks that use global attention.
+
+        Attributes:
+            img_size (int): Dimension of input images.
+            patch_embed (PatchEmbed): Module for patch embedding.
+            pos_embed (nn.Parameter | None): Absolute positional embedding for patches.
+            blocks (nn.ModuleList): List of transformer blocks.
+            neck (nn.Sequential): Neck module for final processing.
+
+        Examples:
+            >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
+            >>> input_image = torch.randn(1, 3, 224, 224)
+            >>> output = encoder(input_image)
+            >>> print(output.shape)
+        """
+        super().__init__()
+        self.img_size = img_size
+
+        self.patch_embed = PatchEmbed(
+            kernel_size=(patch_size, patch_size),
+            stride=(patch_size, patch_size),
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+        )
+
+        self.pos_embed: Optional[nn.Parameter] = None
+        if use_abs_pos:
+            # Initialize absolute positional embedding with pretrain image size.
+            self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim))
+
+        self.blocks = nn.ModuleList()
+        for i in range(depth):
+            block = Block(
+                dim=embed_dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+                use_rel_pos=use_rel_pos,
+                rel_pos_zero_init=rel_pos_zero_init,
+                window_size=window_size if i not in global_attn_indexes else 0,
+                input_size=(img_size // patch_size, img_size // patch_size),
+            )
+            self.blocks.append(block)
+
+        self.neck = nn.Sequential(
+            nn.Conv2d(
+                embed_dim,
+                out_chans,
+                kernel_size=1,
+                bias=False,
+            ),
+            LayerNorm2d(out_chans),
+            nn.Conv2d(
+                out_chans,
+                out_chans,
+                kernel_size=3,
+                padding=1,
+                bias=False,
+            ),
+            LayerNorm2d(out_chans),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Processes input through patch embedding, positional embedding, transformer blocks, and neck module."""
+        x = self.patch_embed(x)
+        if self.pos_embed is not None:
+            pos_embed = (
+                F.interpolate(self.pos_embed.permute(0, 3, 1, 2), scale_factor=self.img_size / 1024).permute(0, 2, 3, 1)
+                if self.img_size != 1024
+                else self.pos_embed
+            )
+            x = x + pos_embed
+        for blk in self.blocks:
+            x = blk(x)
+        return self.neck(x.permute(0, 3, 1, 2))
+
+
+class PromptEncoder(nn.Module):
+    """
+    Encodes different types of prompts for input to SAM's mask decoder, producing sparse and dense embeddings.
+
+    Attributes:
+        embed_dim (int): Dimension of the embeddings.
+        input_image_size (Tuple[int, int]): Size of the input image as (H, W).
+        image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W).
+        pe_layer (PositionEmbeddingRandom): Module for random position embedding.
+        num_point_embeddings (int): Number of point embeddings for different types of points.
+        point_embeddings (nn.ModuleList): List of point embeddings.
+        not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label.
+        mask_input_size (Tuple[int, int]): Size of the input mask.
+        mask_downscaling (nn.Sequential): Neural network for downscaling the mask.
+        no_mask_embed (nn.Embedding): Embedding for cases where no mask is provided.
+
+    Methods:
+        get_dense_pe: Returns the positional encoding used to encode point prompts.
+        forward: Embeds different types of prompts, returning both sparse and dense embeddings.
+
+    Examples:
+        >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
+        >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
+        >>> boxes = torch.rand(1, 2, 2)
+        >>> masks = torch.rand(1, 1, 256, 256)
+        >>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)
+        >>> print(sparse_embeddings.shape, dense_embeddings.shape)
+        torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        image_embedding_size: Tuple[int, int],
+        input_image_size: Tuple[int, int],
+        mask_in_chans: int,
+        activation: Type[nn.Module] = nn.GELU,
+    ) -> None:
+        """
+        Initializes the PromptEncoder module for encoding various types of prompts.
+
+        This module encodes different types of prompts (points, boxes, masks) for input to SAM's mask decoder,
+        producing both sparse and dense embeddings.
+
+        Args:
+            embed_dim (int): The dimension of the embeddings.
+            image_embedding_size (Tuple[int, int]): The spatial size of the image embedding as (H, W).
+            input_image_size (Tuple[int, int]): The padded size of the input image as (H, W).
+            mask_in_chans (int): The number of hidden channels used for encoding input masks.
+            activation (Type[nn.Module]): The activation function to use when encoding input masks.
+
+        Attributes:
+            embed_dim (int): Dimension of the embeddings.
+            input_image_size (Tuple[int, int]): Size of the input image as (H, W).
+            image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W).
+            pe_layer (PositionEmbeddingRandom): Module for random position embedding.
+            num_point_embeddings (int): Number of point embeddings for different types of points.
+            point_embeddings (nn.ModuleList): List of point embeddings.
+            not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label.
+            mask_input_size (Tuple[int, int]): Size of the input mask.
+            mask_downscaling (nn.Sequential): Neural network for downscaling the mask.
+
+        Examples:
+            >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
+            >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
+            >>> boxes = torch.rand(1, 2, 2)
+            >>> masks = torch.rand(1, 1, 256, 256)
+            >>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)
+            >>> print(sparse_embeddings.shape, dense_embeddings.shape)
+            torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
+        """
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.input_image_size = input_image_size
+        self.image_embedding_size = image_embedding_size
+        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
+
+        self.num_point_embeddings: int = 4  # pos/neg point + 2 box corners
+        point_embeddings = [nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)]
+        self.point_embeddings = nn.ModuleList(point_embeddings)
+        self.not_a_point_embed = nn.Embedding(1, embed_dim)
+
+        self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
+        self.mask_downscaling = nn.Sequential(
+            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
+            LayerNorm2d(mask_in_chans // 4),
+            activation(),
+            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
+            LayerNorm2d(mask_in_chans),
+            activation(),
+            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
+        )
+        self.no_mask_embed = nn.Embedding(1, embed_dim)
+
+    def get_dense_pe(self) -> torch.Tensor:
+        """
+        Returns the dense positional encoding used for encoding point prompts.
+
+        This method generates a positional encoding for a dense set of points matching the shape of the image
+        encoding. The encoding is used to provide spatial information to the model when processing point prompts.
+
+        Returns:
+            (torch.Tensor): Positional encoding tensor with shape (1, embed_dim, H, W), where H and W are the
+                height and width of the image embedding size, respectively.
+
+        Examples:
+            >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
+            >>> dense_pe = prompt_encoder.get_dense_pe()
+            >>> print(dense_pe.shape)
+            torch.Size([1, 256, 64, 64])
+        """
+        return self.pe_layer(self.image_embedding_size).unsqueeze(0)
+
+    def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
+        """Embeds point prompts by applying positional encoding and label-specific embeddings."""
+        points = points + 0.5  # Shift to center of pixel
+        if pad:
+            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
+            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
+            points = torch.cat([points, padding_point], dim=1)
+            labels = torch.cat([labels, padding_label], dim=1)
+        point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
+        point_embedding[labels == -1] = 0.0
+        point_embedding[labels == -1] += self.not_a_point_embed.weight
+        point_embedding[labels == 0] += self.point_embeddings[0].weight
+        point_embedding[labels == 1] += self.point_embeddings[1].weight
+        point_embedding[labels == 2] += self.point_embeddings[2].weight
+        point_embedding[labels == 3] += self.point_embeddings[3].weight
+        return point_embedding
+
+    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
+        """Embeds box prompts by applying positional encoding and adding corner embeddings."""
+        boxes = boxes + 0.5  # Shift to center of pixel
+        coords = boxes.reshape(-1, 2, 2)
+        corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
+        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
+        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
+        return corner_embedding
+
+    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
+        """Embeds mask inputs by downscaling and processing through convolutional layers."""
+        return self.mask_downscaling(masks)
+
+    @staticmethod
+    def _get_batch_size(
+        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+        boxes: Optional[torch.Tensor],
+        masks: Optional[torch.Tensor],
+    ) -> int:
+        """Gets the batch size of the output given the batch size of the input prompts."""
+        if points is not None:
+            return points[0].shape[0]
+        elif boxes is not None:
+            return boxes.shape[0]
+        elif masks is not None:
+            return masks.shape[0]
+        else:
+            return 1
+
+    def _get_device(self) -> torch.device:
+        """Returns the device of the first point embedding's weight tensor."""
+        return self.point_embeddings[0].weight.device
+
+    def forward(
+        self,
+        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+        boxes: Optional[torch.Tensor],
+        masks: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Embeds different types of prompts, returning both sparse and dense embeddings.
+
+        Args:
+            points (Tuple[torch.Tensor, torch.Tensor] | None): Point coordinates and labels to embed. The first
+                tensor contains coordinates with shape (B, N, 2), and the second tensor contains labels with
+                shape (B, N).
+            boxes (torch.Tensor | None): Boxes to embed with shape (B, M, 2, 2), where M is the number of boxes.
+            masks (torch.Tensor | None): Masks to embed with shape (B, 1, H, W).
+
+        Returns:
+            (Tuple[torch.Tensor, torch.Tensor]): A tuple containing:
+                - sparse_embeddings (torch.Tensor): Sparse embeddings for points and boxes with shape (B, N, embed_dim).
+                - dense_embeddings (torch.Tensor): Dense embeddings for masks of shape (B, embed_dim, embed_H, embed_W).
+
+        Examples:
+            >>> encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
+            >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
+            >>> boxes = torch.rand(1, 2, 2, 2)
+            >>> masks = torch.rand(1, 1, 256, 256)
+            >>> sparse_emb, dense_emb = encoder(points, boxes, masks)
+            >>> print(sparse_emb.shape, dense_emb.shape)
+            torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
+        """
+        bs = self._get_batch_size(points, boxes, masks)
+        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
+        if points is not None:
+            coords, labels = points
+            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
+            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
+        if boxes is not None:
+            box_embeddings = self._embed_boxes(boxes)
+            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
+
+        if masks is not None:
+            dense_embeddings = self._embed_masks(masks)
+        else:
+            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+            )
+
+        return sparse_embeddings, dense_embeddings
+
+
+class MemoryEncoder(nn.Module):
+    """
+    Encodes pixel features and masks into a memory representation for efficient image segmentation.
+
+    This class processes pixel-level features and masks, fusing them to generate encoded memory representations
+    suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).
+
+    Attributes:
+        mask_downsampler (MaskDownSampler): Module for downsampling input masks.
+        pix_feat_proj (nn.Conv2d): Convolutional layer for projecting pixel features.
+        fuser (Fuser): Module for fusing pixel features and masks.
+        position_encoding (PositionEmbeddingSine): Module for adding positional encoding to features.
+        out_proj (nn.Module): Output projection layer, either nn.Identity or nn.Conv2d.
+
+    Methods:
+        forward: Processes input pixel features and masks to generate encoded memory representations.
+
+    Examples:
+        >>> import torch
+        >>> encoder = MemoryEncoder(out_dim=256, in_dim=256)
+        >>> pix_feat = torch.randn(1, 256, 64, 64)
+        >>> masks = torch.randn(1, 1, 64, 64)
+        >>> encoded_feat, pos = encoder(pix_feat, masks)
+        >>> print(encoded_feat.shape, pos.shape)
+        torch.Size([1, 256, 64, 64]) torch.Size([1, 128, 64, 64])
+    """
+
+    def __init__(
+        self,
+        out_dim,
+        in_dim=256,  # in_dim of pix_feats
+    ):
+        """Initializes the MemoryEncoder for encoding pixel features and masks into memory representations."""
+        super().__init__()
+
+        self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)
+
+        self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
+        self.fuser = Fuser(CXBlock(dim=256), num_layers=2)
+        self.position_encoding = PositionEmbeddingSine(num_pos_feats=64)
+        self.out_proj = nn.Identity()
+        if out_dim != in_dim:
+            self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
+
+    def forward(
+        self,
+        pix_feat: torch.Tensor,
+        masks: torch.Tensor,
+        skip_mask_sigmoid: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Processes pixel features and masks to generate encoded memory representations for segmentation."""
+        if not skip_mask_sigmoid:
+            masks = F.sigmoid(masks)
+        masks = self.mask_downsampler(masks)
+
+        # Fuse pix_feats and downsampled masks, in case the visual features are on CPU, cast them to CUDA
+        pix_feat = pix_feat.to(masks.device)
+
+        x = self.pix_feat_proj(pix_feat)
+        x = x + masks
+        x = self.fuser(x)
+        x = self.out_proj(x)
+
+        pos = self.position_encoding(x).to(x.dtype)
+
+        return {"vision_features": x, "vision_pos_enc": [pos]}
+
+
+class ImageEncoder(nn.Module):
+    """
+    Encodes images using a trunk-neck architecture, producing multiscale features and positional encodings.
+
+    This class combines a trunk network for feature extraction with a neck network for feature refinement
+    and positional encoding generation. It can optionally discard the lowest resolution features.
+
+    Attributes:
+        trunk (nn.Module): The trunk network for initial feature extraction.
+        neck (nn.Module): The neck network for feature refinement and positional encoding generation.
+        scalp (int): Number of lowest resolution feature levels to discard.
+
+    Methods:
+        forward: Processes the input image through the trunk and neck networks.
+
+    Examples:
+        >>> trunk = SomeTrunkNetwork()
+        >>> neck = SomeNeckNetwork()
+        >>> encoder = ImageEncoder(trunk, neck, scalp=1)
+        >>> image = torch.randn(1, 3, 224, 224)
+        >>> output = encoder(image)
+        >>> print(output.keys())
+        dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn'])
+    """
+
+    def __init__(
+        self,
+        trunk: nn.Module,
+        neck: nn.Module,
+        scalp: int = 0,
+    ):
+        """Initializes the ImageEncoder with trunk and neck networks for feature extraction and refinement."""
+        super().__init__()
+        self.trunk = trunk
+        self.neck = neck
+        self.scalp = scalp
+        assert self.trunk.channel_list == self.neck.backbone_channel_list, (
+            f"Channel dims of trunk {self.trunk.channel_list} and neck {self.neck.backbone_channel_list} do not match."
+        )
+
+    def forward(self, sample: torch.Tensor):
+        """Encodes input through patch embedding, positional embedding, transformer blocks, and neck module."""
+        features, pos = self.neck(self.trunk(sample))
+        if self.scalp > 0:
+            # Discard the lowest resolution features
+            features, pos = features[: -self.scalp], pos[: -self.scalp]
+
+        src = features[-1]
+        return {
+            "vision_features": src,
+            "vision_pos_enc": pos,
+            "backbone_fpn": features,
+        }
+
+
+class FpnNeck(nn.Module):
+    """
+    A Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models.
+
+    This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
+    similar to ViT positional embedding interpolation.
+
+    Attributes:
+        position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding module.
+        convs (nn.ModuleList): List of convolutional layers for each backbone level.
+        backbone_channel_list (List[int]): List of channel dimensions from the backbone.
+        fpn_interp_model (str): Interpolation mode for FPN feature resizing.
+        fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
+        fpn_top_down_levels (List[int]): Levels to have top-down features in outputs.
+
+    Methods:
+        forward: Performs forward pass through the FPN neck.
+
+    Examples:
+        >>> backbone_channels = [64, 128, 256, 512]
+        >>> fpn_neck = FpnNeck(256, backbone_channels)
+        >>> inputs = [torch.rand(1, c, 32, 32) for c in backbone_channels]
+        >>> outputs, positions = fpn_neck(inputs)
+        >>> print(len(outputs), len(positions))
+        4 4
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        backbone_channel_list: List[int],
+        kernel_size: int = 1,
+        stride: int = 1,
+        padding: int = 0,
+        fpn_interp_model: str = "bilinear",
+        fuse_type: str = "sum",
+        fpn_top_down_levels: Optional[List[int]] = None,
+    ):
+        """
+        Initializes a modified Feature Pyramid Network (FPN) neck.
+
+        This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
+        similar to ViT positional embedding interpolation.
+
+        Args:
+            d_model (int): Dimension of the model.
+            backbone_channel_list (List[int]): List of channel dimensions from the backbone.
+            kernel_size (int): Kernel size for the convolutional layers.
+            stride (int): Stride for the convolutional layers.
+            padding (int): Padding for the convolutional layers.
+            fpn_interp_model (str): Interpolation mode for FPN feature resizing.
+            fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
+            fpn_top_down_levels (Optional[List[int]]): Levels to have top-down features in outputs.
+
+        Examples:
+            >>> backbone_channels = [64, 128, 256, 512]
+            >>> fpn_neck = FpnNeck(256, backbone_channels)
+            >>> print(fpn_neck)
+        """
+        super().__init__()
+        self.position_encoding = PositionEmbeddingSine(num_pos_feats=256)
+        self.convs = nn.ModuleList()
+        self.backbone_channel_list = backbone_channel_list
+        for dim in backbone_channel_list:
+            current = nn.Sequential()
+            current.add_module(
+                "conv",
+                nn.Conv2d(
+                    in_channels=dim,
+                    out_channels=d_model,
+                    kernel_size=kernel_size,
+                    stride=stride,
+                    padding=padding,
+                ),
+            )
+
+            self.convs.append(current)
+        self.fpn_interp_model = fpn_interp_model
+        assert fuse_type in {"sum", "avg"}
+        self.fuse_type = fuse_type
+
+        # levels to have top-down features in its outputs
+        # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
+        # have top-down propagation, while outputs of level 0 and level 1 have only
+        # lateral features from the same backbone level.
+        if fpn_top_down_levels is None:
+            # default is to have top-down features on all levels
+            fpn_top_down_levels = range(len(self.convs))
+        self.fpn_top_down_levels = list(fpn_top_down_levels)
+
+    def forward(self, xs: List[torch.Tensor]):
+        """
+        Performs forward pass through the Feature Pyramid Network (FPN) neck.
+
+        This method processes a list of input tensors from the backbone through the FPN, applying lateral connections
+        and top-down feature fusion. It generates output feature maps and corresponding positional encodings.
+
+        Args:
+            xs (List[torch.Tensor]): List of input tensors from the backbone, each with shape (B, C, H, W).
+
+        Returns:
+            (Tuple[List[torch.Tensor], List[torch.Tensor]]): A tuple containing:
+                - out (List[torch.Tensor]): List of output feature maps after FPN processing, each with shape
+                  (B, d_model, H, W).
+                - pos (List[torch.Tensor]): List of positional encodings corresponding to each output feature map.
+
+        Examples:
+            >>> fpn_neck = FpnNeck(d_model=256, backbone_channel_list=[64, 128, 256, 512])
+            >>> inputs = [torch.rand(1, c, 32, 32) for c in [64, 128, 256, 512]]
+            >>> outputs, positions = fpn_neck(inputs)
+            >>> print(len(outputs), len(positions))
+            4 4
+        """
+        out = [None] * len(self.convs)
+        pos = [None] * len(self.convs)
+        assert len(xs) == len(self.convs)
+        # fpn forward pass
+        # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
+        prev_features = None
+        # forward in top-down order (from low to high resolution)
+        n = len(self.convs) - 1
+        for i in range(n, -1, -1):
+            x = xs[i]
+            lateral_features = self.convs[n - i](x)
+            if i in self.fpn_top_down_levels and prev_features is not None:
+                top_down_features = F.interpolate(
+                    prev_features.to(dtype=torch.float32),
+                    scale_factor=2.0,
+                    mode=self.fpn_interp_model,
+                    align_corners=(None if self.fpn_interp_model == "nearest" else False),
+                    antialias=False,
+                )
+                prev_features = lateral_features + top_down_features
+                if self.fuse_type == "avg":
+                    prev_features /= 2
+            else:
+                prev_features = lateral_features
+            x_out = prev_features
+            out[i] = x_out
+            pos[i] = self.position_encoding(x_out).to(x_out.dtype)
+
+        return out, pos
+
+
+class Hiera(nn.Module):
+    """
+    Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks.
+
+    This class implements a Hiera model, which is a hierarchical vision transformer architecture designed for
+    efficient multiscale feature extraction. It uses a series of transformer blocks organized into stages,
+    with optional pooling and global attention mechanisms.
+
+    Attributes:
+        window_spec (Tuple[int, ...]): Window sizes for each stage.
+        q_stride (Tuple[int, int]): Downsampling stride between stages.
+        stage_ends (List[int]): Indices of the last block in each stage.
+        q_pool_blocks (List[int]): Indices of blocks where pooling is applied.
+        return_interm_layers (bool): Whether to return intermediate layer outputs.
+        patch_embed (PatchEmbed): Module for patch embedding.
+        global_att_blocks (Tuple[int, ...]): Indices of blocks with global attention.
+        window_pos_embed_bkg_spatial_size (Tuple[int, int]): Spatial size for window positional embedding background.
+        pos_embed (nn.Parameter): Positional embedding for the background.
+        pos_embed_window (nn.Parameter): Positional embedding for the window.
+        blocks (nn.ModuleList): List of MultiScaleBlock modules.
+        channel_list (List[int]): List of output channel dimensions for each stage.
+
+    Methods:
+        _get_pos_embed: Generates positional embeddings by interpolating and combining window and background embeddings.
+        forward: Performs the forward pass through the Hiera model.
+
+    Examples:
+        >>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
+        >>> input_tensor = torch.randn(1, 3, 224, 224)
+        >>> output_features = model(input_tensor)
+        >>> for feat in output_features:
+        ...     print(feat.shape)
+    """
+
+    def __init__(
+        self,
+        embed_dim: int = 96,  # initial embed dim
+        num_heads: int = 1,  # initial number of heads
+        drop_path_rate: float = 0.0,  # stochastic depth
+        q_pool: int = 3,  # number of q_pool stages
+        q_stride: Tuple[int, int] = (2, 2),  # downsample stride bet. stages
+        stages: Tuple[int, ...] = (2, 3, 16, 3),  # blocks per stage
+        dim_mul: float = 2.0,  # dim_mul factor at stage shift
+        head_mul: float = 2.0,  # head_mul factor at stage shift
+        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
+        # window size per stage, when not using global att.
+        window_spec: Tuple[int, ...] = (
+            8,
+            4,
+            14,
+            7,
+        ),
+        # global attn in these blocks
+        global_att_blocks: Tuple[int, ...] = (
+            12,
+            16,
+            20,
+        ),
+        return_interm_layers=True,  # return feats from every stage
+    ):
+        """Initializes the Hiera model, configuring its hierarchical vision transformer architecture."""
+        super().__init__()
+
+        assert len(stages) == len(window_spec)
+        self.window_spec = window_spec
+
+        depth = sum(stages)
+        self.q_stride = q_stride
+        self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
+        assert 0 <= q_pool <= len(self.stage_ends[:-1])
+        self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
+        self.return_interm_layers = return_interm_layers
+
+        self.patch_embed = PatchEmbed(
+            embed_dim=embed_dim,
+            kernel_size=(7, 7),
+            stride=(4, 4),
+            padding=(3, 3),
+        )
+        # Which blocks have global att?
+        self.global_att_blocks = global_att_blocks
+
+        # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
+        self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
+        self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size))
+        self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]))
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+
+        cur_stage = 1
+        self.blocks = nn.ModuleList()
+
+        for i in range(depth):
+            dim_out = embed_dim
+            # lags by a block, so first block of
+            # next stage uses an initial window size
+            # of previous stage and final window size of current stage
+            window_size = self.window_spec[cur_stage - 1]
+
+            if self.global_att_blocks is not None:
+                window_size = 0 if i in self.global_att_blocks else window_size
+
+            if i - 1 in self.stage_ends:
+                dim_out = int(embed_dim * dim_mul)
+                num_heads = int(num_heads * head_mul)
+                cur_stage += 1
+
+            block = MultiScaleBlock(
+                dim=embed_dim,
+                dim_out=dim_out,
+                num_heads=num_heads,
+                drop_path=dpr[i],
+                q_stride=self.q_stride if i in self.q_pool_blocks else None,
+                window_size=window_size,
+            )
+
+            embed_dim = dim_out
+            self.blocks.append(block)
+
+        self.channel_list = (
+            [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
+            if return_interm_layers
+            else [self.blocks[-1].dim_out]
+        )
+
+    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
+        """Generates positional embeddings by interpolating and combining window and background embeddings."""
+        h, w = hw
+        window_embed = self.pos_embed_window
+        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
+        pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])
+        pos_embed = pos_embed.permute(0, 2, 3, 1)
+        return pos_embed
+
+    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+        """Performs forward pass through Hiera model, extracting multiscale features from input images."""
+        x = self.patch_embed(x)
+        # x: (B, H, W, C)
+
+        # Add pos embed
+        x = x + self._get_pos_embed(x.shape[1:3])
+
+        outputs = []
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers):
+                feats = x.permute(0, 3, 1, 2)
+                outputs.append(feats)
+
+        return outputs

+ 237 - 0
ultralytics/models/sam/modules/memory_attention.py

@@ -0,0 +1,237 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import copy
+from typing import Optional
+
+import torch
+from torch import Tensor, nn
+
+from .blocks import RoPEAttention
+
+
+class MemoryAttentionLayer(nn.Module):
+    """
+    Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.
+
+    This class combines self-attention, cross-attention, and feedforward components to process input tensors and
+    generate memory-based attention outputs.
+
+    Attributes:
+        d_model (int): Dimensionality of the model.
+        dim_feedforward (int): Dimensionality of the feedforward network.
+        dropout_value (float): Dropout rate for regularization.
+        self_attn (RoPEAttention): Self-attention mechanism using RoPE (Rotary Position Embedding).
+        cross_attn_image (RoPEAttention): Cross-attention mechanism for image processing.
+        linear1 (nn.Linear): First linear layer of the feedforward network.
+        linear2 (nn.Linear): Second linear layer of the feedforward network.
+        norm1 (nn.LayerNorm): Layer normalization for self-attention output.
+        norm2 (nn.LayerNorm): Layer normalization for cross-attention output.
+        norm3 (nn.LayerNorm): Layer normalization for feedforward network output.
+        dropout1 (nn.Dropout): Dropout layer after self-attention.
+        dropout2 (nn.Dropout): Dropout layer after cross-attention.
+        dropout3 (nn.Dropout): Dropout layer after feedforward network.
+        activation (nn.ReLU): Activation function for the feedforward network.
+        pos_enc_at_attn (bool): Flag to add positional encoding at attention.
+        pos_enc_at_cross_attn_queries (bool): Flag to add positional encoding to cross-attention queries.
+        pos_enc_at_cross_attn_keys (bool): Flag to add positional encoding to cross-attention keys.
+
+    Methods:
+        forward: Performs the full memory attention operation on input tensors.
+        _forward_sa: Performs self-attention on input tensor.
+        _forward_ca: Performs cross-attention between target and memory tensors.
+
+    Examples:
+        >>> layer = MemoryAttentionLayer(d_model=256, dim_feedforward=2048, dropout=0.1)
+        >>> tgt = torch.randn(1, 100, 256)
+        >>> memory = torch.randn(1, 100, 64)
+        >>> pos = torch.randn(1, 100, 256)
+        >>> query_pos = torch.randn(1, 100, 256)
+        >>> output = layer(tgt, memory, pos, query_pos)
+        >>> print(output.shape)
+        torch.Size([1, 100, 256])
+    """
+
+    def __init__(
+        self,
+        d_model: int = 256,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        pos_enc_at_attn: bool = False,
+        pos_enc_at_cross_attn_keys: bool = True,
+        pos_enc_at_cross_attn_queries: bool = False,
+    ):
+        """Initializes a memory attention layer with self-attention, cross-attention, and feedforward components."""
+        super().__init__()
+        self.d_model = d_model
+        self.dim_feedforward = dim_feedforward
+        self.dropout_value = dropout
+        self.self_attn = RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1)
+        self.cross_attn_image = RoPEAttention(
+            rope_k_repeat=True,
+            embedding_dim=256,
+            num_heads=1,
+            downsample_rate=1,
+            kv_in_dim=64,
+        )
+
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.dropout3 = nn.Dropout(dropout)
+
+        self.activation = nn.ReLU()
+
+        # Where to add pos enc
+        self.pos_enc_at_attn = pos_enc_at_attn
+        self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
+        self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
+
+    def _forward_sa(self, tgt, query_pos):
+        """Performs self-attention on input tensor using positional encoding and RoPE attention mechanism."""
+        tgt2 = self.norm1(tgt)
+        q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
+        tgt2 = self.self_attn(q, k, v=tgt2)
+        tgt = tgt + self.dropout1(tgt2)
+        return tgt
+
+    def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
+        """Performs cross-attention between target and memory tensors using RoPEAttention mechanism."""
+        kwds = {}
+        if num_k_exclude_rope > 0:
+            assert isinstance(self.cross_attn_image, RoPEAttention)
+            kwds = {"num_k_exclude_rope": num_k_exclude_rope}
+
+        # Cross-Attention
+        tgt2 = self.norm2(tgt)
+        tgt2 = self.cross_attn_image(
+            q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
+            k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
+            v=memory,
+            **kwds,
+        )
+        tgt = tgt + self.dropout2(tgt2)
+        return tgt
+
+    def forward(
+        self,
+        tgt,
+        memory,
+        pos: Optional[Tensor] = None,
+        query_pos: Optional[Tensor] = None,
+        num_k_exclude_rope: int = 0,
+    ) -> torch.Tensor:
+        """Processes input tensors using self-attention, cross-attention, and MLP for memory-based attention."""
+        tgt = self._forward_sa(tgt, query_pos)
+        tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
+        # MLP
+        tgt2 = self.norm3(tgt)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+        tgt = tgt + self.dropout3(tgt2)
+        return tgt
+
+
+class MemoryAttention(nn.Module):
+    """
+    Memory attention module for processing sequential data with self and cross-attention mechanisms.
+
+    This class implements a multi-layer attention mechanism that combines self-attention and cross-attention
+    for processing sequential data, particularly useful in transformer-like architectures.
+
+    Attributes:
+        d_model (int): The dimension of the model's hidden state.
+        layers (nn.ModuleList): A list of MemoryAttentionLayer modules.
+        num_layers (int): The number of attention layers.
+        norm (nn.LayerNorm): Layer normalization applied to the output.
+        pos_enc_at_input (bool): Whether to apply positional encoding at the input.
+        batch_first (bool): Whether the input tensors are in batch-first format.
+
+    Methods:
+        forward: Processes input tensors through the attention layers.
+
+    Examples:
+        >>> d_model = 256
+        >>> layer = MemoryAttentionLayer(d_model)
+        >>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
+        >>> curr = torch.randn(10, 32, d_model)  # (seq_len, batch_size, d_model)
+        >>> memory = torch.randn(20, 32, d_model)  # (mem_len, batch_size, d_model)
+        >>> curr_pos = torch.randn(10, 32, d_model)
+        >>> memory_pos = torch.randn(20, 32, d_model)
+        >>> output = attention(curr, memory, curr_pos, memory_pos)
+        >>> print(output.shape)
+        torch.Size([10, 32, 256])
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        pos_enc_at_input: bool,
+        layer: nn.Module,
+        num_layers: int,
+        batch_first: bool = True,  # Do layers expect batch first input?
+    ):
+        """Initializes MemoryAttention module with layers and normalization for attention processing."""
+        super().__init__()
+        self.d_model = d_model
+        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
+        self.num_layers = num_layers
+        self.norm = nn.LayerNorm(d_model)
+        self.pos_enc_at_input = pos_enc_at_input
+        self.batch_first = batch_first
+
+    def forward(
+        self,
+        curr: torch.Tensor,  # self-attention inputs
+        memory: torch.Tensor,  # cross-attention inputs
+        curr_pos: Optional[Tensor] = None,  # pos_enc for self-attention inputs
+        memory_pos: Optional[Tensor] = None,  # pos_enc for cross-attention inputs
+        num_obj_ptr_tokens: int = 0,  # number of object pointer *tokens*
+    ):
+        """Processes input tensors through multiple attention layers, applying self and cross-attention mechanisms."""
+        if isinstance(curr, list):
+            assert isinstance(curr_pos, list)
+            assert len(curr) == len(curr_pos) == 1
+            curr, curr_pos = (
+                curr[0],
+                curr_pos[0],
+            )
+
+        assert curr.shape[1] == memory.shape[1], "Batch size must be the same for curr and memory"
+
+        output = curr
+        if self.pos_enc_at_input and curr_pos is not None:
+            output = output + 0.1 * curr_pos
+
+        if self.batch_first:
+            # Convert to batch first
+            output = output.transpose(0, 1)
+            curr_pos = curr_pos.transpose(0, 1)
+            memory = memory.transpose(0, 1)
+            memory_pos = memory_pos.transpose(0, 1)
+
+        for layer in self.layers:
+            kwds = {}
+            if isinstance(layer.cross_attn_image, RoPEAttention):
+                kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
+
+            output = layer(
+                tgt=output,
+                memory=memory,
+                pos=memory_pos,
+                query_pos=curr_pos,
+                **kwds,
+            )
+        normed_output = self.norm(output)
+
+        if self.batch_first:
+            # Convert back to seq first
+            normed_output = normed_output.transpose(0, 1)
+            curr_pos = curr_pos.transpose(0, 1)
+
+        return normed_output

+ 1013 - 0
ultralytics/models/sam/modules/sam.py

@@ -0,0 +1,1013 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn.init import trunc_normal_
+
+from ultralytics.nn.modules import MLP
+
+from .blocks import SAM2TwoWayTransformer
+from .decoders import MaskDecoder, SAM2MaskDecoder
+from .encoders import ImageEncoderViT, PromptEncoder
+from .utils import get_1d_sine_pe, select_closest_cond_frames
+
+# a large negative value as a placeholder score for missing objects
+NO_OBJ_SCORE = -1024.0
+
+
+class SAMModel(nn.Module):
+    """
+    Segment Anything Model (SAM) for object segmentation tasks.
+
+    This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images
+    and input prompts.
+
+    Attributes:
+        mask_threshold (float): Threshold value for mask prediction.
+        image_encoder (ImageEncoderViT): Backbone for encoding images into embeddings.
+        prompt_encoder (PromptEncoder): Encoder for various types of input prompts.
+        mask_decoder (MaskDecoder): Predicts object masks from image and prompt embeddings.
+
+    Methods:
+        __init__: Initializes the SAMModel with encoders, decoder, and normalization parameters.
+
+    Examples:
+        >>> image_encoder = ImageEncoderViT(...)
+        >>> prompt_encoder = PromptEncoder(...)
+        >>> mask_decoder = MaskDecoder(...)
+        >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)
+        >>> # Further usage depends on SAMPredictor class
+
+    Notes:
+        All forward() operations are implemented in the SAMPredictor class.
+    """
+
+    mask_threshold: float = 0.0
+
+    def __init__(
+        self,
+        image_encoder: ImageEncoderViT,
+        prompt_encoder: PromptEncoder,
+        mask_decoder: MaskDecoder,
+        pixel_mean: List[float] = (123.675, 116.28, 103.53),
+        pixel_std: List[float] = (58.395, 57.12, 57.375),
+    ) -> None:
+        """
+        Initialize the SAMModel class to predict object masks from an image and input prompts.
+
+        Args:
+            image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings.
+            prompt_encoder (PromptEncoder): Encodes various types of input prompts.
+            mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
+            pixel_mean (List[float]): Mean values for normalizing pixels in the input image.
+            pixel_std (List[float]): Std values for normalizing pixels in the input image.
+
+        Examples:
+            >>> image_encoder = ImageEncoderViT(...)
+            >>> prompt_encoder = PromptEncoder(...)
+            >>> mask_decoder = MaskDecoder(...)
+            >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)
+            >>> # Further usage depends on SAMPredictor class
+
+        Notes:
+            All forward() operations moved to SAMPredictor.
+        """
+        super().__init__()
+        self.image_encoder = image_encoder
+        self.prompt_encoder = prompt_encoder
+        self.mask_decoder = mask_decoder
+        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
+
+    def set_imgsz(self, imgsz):
+        """
+        Set image size to make model compatible with different image sizes.
+
+        Args:
+            imgsz (Tuple[int, int]): The size of the input image.
+        """
+        if hasattr(self.image_encoder, "set_imgsz"):
+            self.image_encoder.set_imgsz(imgsz)
+        self.prompt_encoder.input_image_size = imgsz
+        self.prompt_encoder.image_embedding_size = [x // 16 for x in imgsz]  # 16 is fixed as patch size of ViT model
+        self.image_encoder.img_size = imgsz[0]
+
+
+class SAM2Model(torch.nn.Module):
+    """
+    SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.
+
+    This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms
+    for temporal consistency and efficient tracking of objects across frames.
+
+    Attributes:
+        mask_threshold (float): Threshold value for mask prediction.
+        image_encoder (ImageEncoderViT): Visual encoder for extracting image features.
+        memory_attention (nn.Module): Module for attending to memory features.
+        memory_encoder (nn.Module): Encoder for generating memory representations.
+        num_maskmem (int): Number of accessible memory frames.
+        image_size (int): Size of input images.
+        backbone_stride (int): Stride of the backbone network output.
+        sam_prompt_embed_dim (int): Dimension of SAM prompt embeddings.
+        sam_image_embedding_size (int): Size of SAM image embeddings.
+        sam_prompt_encoder (PromptEncoder): Encoder for processing input prompts.
+        sam_mask_decoder (SAM2MaskDecoder): Decoder for generating object masks.
+        obj_ptr_proj (nn.Module): Projection layer for object pointers.
+        obj_ptr_tpos_proj (nn.Module): Projection for temporal positional encoding in object pointers.
+
+    Methods:
+        forward_image: Processes image batch through encoder to extract multi-level features.
+        track_step: Performs a single tracking step, updating object masks and memory features.
+
+    Examples:
+        >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)
+        >>> image_batch = torch.rand(1, 3, 512, 512)
+        >>> features = model.forward_image(image_batch)
+        >>> track_results = model.track_step(0, True, features, None, None, None, {})
+    """
+
+    mask_threshold: float = 0.0
+
+    def __init__(
+        self,
+        image_encoder,
+        memory_attention,
+        memory_encoder,
+        num_maskmem=7,
+        image_size=512,
+        backbone_stride=16,
+        sigmoid_scale_for_mem_enc=1.0,
+        sigmoid_bias_for_mem_enc=0.0,
+        binarize_mask_from_pts_for_mem_enc=False,
+        use_mask_input_as_output_without_sam=False,
+        max_cond_frames_in_attn=-1,
+        directly_add_no_mem_embed=False,
+        use_high_res_features_in_sam=False,
+        multimask_output_in_sam=False,
+        multimask_min_pt_num=1,
+        multimask_max_pt_num=1,
+        multimask_output_for_tracking=False,
+        use_multimask_token_for_obj_ptr: bool = False,
+        iou_prediction_use_sigmoid=False,
+        memory_temporal_stride_for_eval=1,
+        non_overlap_masks_for_mem_enc=False,
+        use_obj_ptrs_in_encoder=False,
+        max_obj_ptrs_in_encoder=16,
+        add_tpos_enc_to_obj_ptrs=True,
+        proj_tpos_enc_in_obj_ptrs=False,
+        use_signed_tpos_enc_to_obj_ptrs=False,
+        only_obj_ptrs_in_the_past_for_eval=False,
+        pred_obj_scores: bool = False,
+        pred_obj_scores_mlp: bool = False,
+        fixed_no_obj_ptr: bool = False,
+        soft_no_obj_ptr: bool = False,
+        use_mlp_for_obj_ptr_proj: bool = False,
+        no_obj_embed_spatial: bool = False,
+        sam_mask_decoder_extra_args=None,
+        compile_image_encoder: bool = False,
+    ):
+        """
+        Initializes the SAM2Model for video object segmentation with memory-based tracking.
+
+        Args:
+            image_encoder (nn.Module): Visual encoder for extracting image features.
+            memory_attention (nn.Module): Module for attending to memory features.
+            memory_encoder (nn.Module): Encoder for generating memory representations.
+            num_maskmem (int): Number of accessible memory frames. Default is 7 (1 input frame + 6 previous frames).
+            image_size (int): Size of input images.
+            backbone_stride (int): Stride of the image backbone output.
+            sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.
+            sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.
+            binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames
+                with clicks during evaluation.
+            use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM
+                prompt encoder and mask decoder on frames with mask input.
+            max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.
+                -1 means no limit.
+            directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the
+                first frame.
+            use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.
+            multimask_output_in_sam (bool): Whether to output multiple (3) masks for the first click on initial
+                conditioning frames.
+            multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.
+            multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.
+            multimask_output_for_tracking (bool): Whether to use multimask output for tracking.
+            use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
+            iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
+            memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
+            non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in
+                memory encoder during evaluation.
+            use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.
+            max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder
+                cross-attention.
+            add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in
+                the encoder.
+            proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
+                encoding in object pointers.
+            use_signed_tpos_enc_to_obj_ptrs (bool): whether to use signed distance (instead of unsigned absolute distance)
+                in the temporal positional encoding in the object pointers, only relevant when both `use_obj_ptrs_in_encoder=True`
+                and `add_tpos_enc_to_obj_ptrs=True`.
+            only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past
+                during evaluation.
+            pred_obj_scores (bool): Whether to predict if there is an object in the frame.
+            pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.
+            fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.
+            soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation.
+            use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.
+            no_obj_embed_spatial (bool): Whether add no obj embedding to spatial frames.
+            sam_mask_decoder_extra_args (Dict | None): Extra arguments for constructing the SAM mask decoder.
+            compile_image_encoder (bool): Whether to compile the image encoder for faster inference.
+
+        Examples:
+            >>> image_encoder = ImageEncoderViT(...)
+            >>> memory_attention = SAM2TwoWayTransformer(...)
+            >>> memory_encoder = nn.Sequential(...)
+            >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)
+            >>> image_batch = torch.rand(1, 3, 512, 512)
+            >>> features = model.forward_image(image_batch)
+            >>> track_results = model.track_step(0, True, features, None, None, None, {})
+        """
+        super().__init__()
+
+        # Part 1: the image backbone
+        self.image_encoder = image_encoder
+        # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
+        self.use_high_res_features_in_sam = use_high_res_features_in_sam
+        self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
+        self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
+        self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
+        if use_obj_ptrs_in_encoder:
+            # A conv layer to downsample the mask prompt to stride 4 (the same stride as
+            # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
+            # so that it can be fed into the SAM mask decoder to generate a pointer.
+            self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
+        self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
+        if proj_tpos_enc_in_obj_ptrs:
+            assert add_tpos_enc_to_obj_ptrs  # these options need to be used together
+        self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
+        self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs
+        self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
+
+        # Part 2: memory attention to condition current frame's visual features
+        # with memories (and obj ptrs) from past frames
+        self.memory_attention = memory_attention
+        self.hidden_dim = memory_attention.d_model
+
+        # Part 3: memory encoder for the previous frame's outputs
+        self.memory_encoder = memory_encoder
+        self.mem_dim = self.hidden_dim
+        if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"):
+            # if there is compression of memories along channel dim
+            self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
+        self.num_maskmem = num_maskmem  # Number of memories accessible
+        # Temporal encoding of the memories
+        self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim))
+        trunc_normal_(self.maskmem_tpos_enc, std=0.02)
+        # a single token to indicate no memory embedding from previous frames
+        self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
+        self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
+        trunc_normal_(self.no_mem_embed, std=0.02)
+        trunc_normal_(self.no_mem_pos_enc, std=0.02)
+        self.directly_add_no_mem_embed = directly_add_no_mem_embed
+        # Apply sigmoid to the output raw mask logits (to turn them from
+        # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
+        self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
+        self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
+        self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
+        self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
+        self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
+        # On frames with mask input, whether to directly output the input mask without
+        # using a SAM prompt encoder + mask decoder
+        self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
+        self.multimask_output_in_sam = multimask_output_in_sam
+        self.multimask_min_pt_num = multimask_min_pt_num
+        self.multimask_max_pt_num = multimask_max_pt_num
+        self.multimask_output_for_tracking = multimask_output_for_tracking
+        self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
+        self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid
+
+        # Part 4: SAM-style prompt encoder (for both mask and point inputs)
+        # and SAM-style mask decoder for the final mask output
+        self.image_size = image_size
+        self.backbone_stride = backbone_stride
+        self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
+        self.pred_obj_scores = pred_obj_scores
+        self.pred_obj_scores_mlp = pred_obj_scores_mlp
+        self.fixed_no_obj_ptr = fixed_no_obj_ptr
+        self.soft_no_obj_ptr = soft_no_obj_ptr
+        if self.fixed_no_obj_ptr:
+            assert self.pred_obj_scores
+            assert self.use_obj_ptrs_in_encoder
+        if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
+            self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
+            trunc_normal_(self.no_obj_ptr, std=0.02)
+        self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
+        self.no_obj_embed_spatial = None
+        if no_obj_embed_spatial:
+            self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
+            trunc_normal_(self.no_obj_embed_spatial, std=0.02)
+
+        self._build_sam_heads()
+        self.max_cond_frames_in_attn = max_cond_frames_in_attn
+
+        # Model compilation
+        if compile_image_encoder:
+            # Compile the forward function (not the full module) to allow loading checkpoints.
+            print("Image encoder compilation is enabled. First forward pass will be slow.")
+            self.image_encoder.forward = torch.compile(
+                self.image_encoder.forward,
+                mode="max-autotune",
+                fullgraph=True,
+                dynamic=False,
+            )
+
+    @property
+    def device(self):
+        """Returns the device on which the model's parameters are stored."""
+        return next(self.parameters()).device
+
+    def forward(self, *args, **kwargs):
+        """Processes image and prompt inputs to generate object masks and scores in video sequences."""
+        raise NotImplementedError(
+            "Please use the corresponding methods in SAM2VideoPredictor for inference."
+            "See notebooks/video_predictor_example.ipynb for an example."
+        )
+
+    def _build_sam_heads(self):
+        """Builds SAM-style prompt encoder and mask decoder for image segmentation tasks."""
+        self.sam_prompt_embed_dim = self.hidden_dim
+        self.sam_image_embedding_size = self.image_size // self.backbone_stride
+
+        # Build PromptEncoder and MaskDecoder from SAM (hyperparameters like `mask_in_chans=16` are from SAM code)
+        self.sam_prompt_encoder = PromptEncoder(
+            embed_dim=self.sam_prompt_embed_dim,
+            image_embedding_size=(
+                self.sam_image_embedding_size,
+                self.sam_image_embedding_size,
+            ),
+            input_image_size=(self.image_size, self.image_size),
+            mask_in_chans=16,
+        )
+        self.sam_mask_decoder = SAM2MaskDecoder(
+            num_multimask_outputs=3,
+            transformer=SAM2TwoWayTransformer(
+                depth=2,
+                embedding_dim=self.sam_prompt_embed_dim,
+                mlp_dim=2048,
+                num_heads=8,
+            ),
+            transformer_dim=self.sam_prompt_embed_dim,
+            iou_head_depth=3,
+            iou_head_hidden_dim=256,
+            use_high_res_features=self.use_high_res_features_in_sam,
+            iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
+            pred_obj_scores=self.pred_obj_scores,
+            pred_obj_scores_mlp=self.pred_obj_scores_mlp,
+            use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
+            **(self.sam_mask_decoder_extra_args or {}),
+        )
+        if self.use_obj_ptrs_in_encoder:
+            # a linear projection on SAM output tokens to turn them into object pointers
+            self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
+            if self.use_mlp_for_obj_ptr_proj:
+                self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3)
+        else:
+            self.obj_ptr_proj = torch.nn.Identity()
+        if self.proj_tpos_enc_in_obj_ptrs:
+            # a linear projection on temporal positional encoding in object pointers to
+            # avoid potential interference with spatial positional encoding
+            self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
+        else:
+            self.obj_ptr_tpos_proj = torch.nn.Identity()
+
+    def _forward_sam_heads(
+        self,
+        backbone_features,
+        point_inputs=None,
+        mask_inputs=None,
+        high_res_features=None,
+        multimask_output=False,
+    ):
+        """
+        Forward pass through SAM prompt encoders and mask heads.
+
+        This method processes image features and optional point/mask inputs to generate object masks and scores.
+
+        Args:
+            backbone_features (torch.Tensor): Image features with shape (B, C, H, W).
+            point_inputs (Dict[str, torch.Tensor] | None): Dictionary containing point prompts.
+                'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute
+                    pixel-unit coordinates in (x, y) format for P input points.
+                'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks,
+                    0 means negative clicks, and -1 means padding.
+            mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the
+                same spatial size as the image.
+            high_res_features (List[torch.Tensor] | None): List of two feature maps with shapes
+                (B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps
+                for SAM decoder.
+            multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False,
+                output only 1 mask and its IoU estimate.
+
+        Returns:
+            (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]):
+                low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
+                high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits.
+                ious: Tensor of shape (B, M) with estimated IoU for each output mask.
+                low_res_masks: Tensor of shape (B, 1, H*4, W*4) with the best low-resolution mask.
+                high_res_masks: Tensor of shape (B, 1, H*16, W*16) with the best high-resolution mask.
+                obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask.
+                object_score_logits: Tensor of shape (B) with object score logits.
+
+            Where M is 3 if multimask_output=True, and 1 if multimask_output=False.
+
+        Examples:
+            >>> backbone_features = torch.rand(1, 256, 32, 32)
+            >>> point_inputs = {"point_coords": torch.rand(1, 2, 2), "point_labels": torch.tensor([[1, 0]])}
+            >>> mask_inputs = torch.rand(1, 1, 512, 512)
+            >>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs)
+            >>> (
+            ...     low_res_multimasks,
+            ...     high_res_multimasks,
+            ...     ious,
+            ...     low_res_masks,
+            ...     high_res_masks,
+            ...     obj_ptr,
+            ...     object_score_logits,
+            ... ) = results
+        """
+        B = backbone_features.size(0)
+        device = backbone_features.device
+        assert backbone_features.size(1) == self.sam_prompt_embed_dim
+        assert backbone_features.size(2) == self.sam_image_embedding_size
+        assert backbone_features.size(3) == self.sam_image_embedding_size
+
+        # a) Handle point prompts
+        if point_inputs is not None:
+            sam_point_coords = point_inputs["point_coords"]
+            sam_point_labels = point_inputs["point_labels"]
+            assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
+        else:
+            # If no points are provide, pad with an empty point (with label -1)
+            sam_point_coords = torch.zeros(B, 1, 2, device=device)
+            sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
+
+        # b) Handle mask prompts
+        if mask_inputs is not None:
+            # If mask_inputs is provided, downsize it into low-res mask input if needed
+            # and feed it as a dense mask prompt into the SAM mask encoder
+            assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
+            if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
+                sam_mask_prompt = F.interpolate(
+                    mask_inputs.float(),
+                    size=self.sam_prompt_encoder.mask_input_size,
+                    align_corners=False,
+                    mode="bilinear",
+                    antialias=True,  # use antialias for downsampling
+                )
+            else:
+                sam_mask_prompt = mask_inputs
+        else:
+            # Otherwise, simply feed None (and SAM's prompt encoder will add
+            # a learned `no_mask_embed` to indicate no mask input in this case).
+            sam_mask_prompt = None
+
+        sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
+            points=(sam_point_coords, sam_point_labels),
+            boxes=None,
+            masks=sam_mask_prompt,
+        )
+        low_res_multimasks, ious, sam_output_tokens, object_score_logits = self.sam_mask_decoder(
+            image_embeddings=backbone_features,
+            image_pe=self.sam_prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embeddings,
+            dense_prompt_embeddings=dense_embeddings,
+            multimask_output=multimask_output,
+            repeat_image=False,  # the image is already batched
+            high_res_features=high_res_features,
+        )
+        if self.pred_obj_scores:
+            is_obj_appearing = object_score_logits > 0
+
+            # Spatial memory mask is a *hard* choice between obj and no obj, consistent with actual mask prediction
+            low_res_multimasks = torch.where(is_obj_appearing[:, None, None], low_res_multimasks, NO_OBJ_SCORE)
+
+        # convert masks from possibly bfloat16 (or float16) to float32
+        # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
+        low_res_multimasks = low_res_multimasks.float()
+        high_res_multimasks = F.interpolate(
+            low_res_multimasks,
+            size=(self.image_size, self.image_size),
+            mode="bilinear",
+            align_corners=False,
+        )
+
+        sam_output_token = sam_output_tokens[:, 0]
+        if multimask_output:
+            # take the best mask prediction (with the highest IoU estimation)
+            best_iou_inds = torch.argmax(ious, dim=-1)
+            batch_inds = torch.arange(B, device=device)
+            low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
+            high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
+            if sam_output_tokens.size(1) > 1:
+                sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
+        else:
+            low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
+
+        # Extract object pointer from the SAM output token (with occlusion handling)
+        obj_ptr = self.obj_ptr_proj(sam_output_token)
+        if self.pred_obj_scores:
+            # Allow *soft* no obj ptr, unlike for masks
+            if self.soft_no_obj_ptr:
+                lambda_is_obj_appearing = object_score_logits.sigmoid()
+            else:
+                lambda_is_obj_appearing = is_obj_appearing.float()
+
+            if self.fixed_no_obj_ptr:
+                obj_ptr = lambda_is_obj_appearing * obj_ptr
+            obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
+
+        return (
+            low_res_multimasks,
+            high_res_multimasks,
+            ious,
+            low_res_masks,
+            high_res_masks,
+            obj_ptr,
+            object_score_logits,
+        )
+
+    def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
+        """Processes mask inputs directly as output, bypassing SAM encoder/decoder."""
+        # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
+        out_scale, out_bias = 20.0, -10.0  # sigmoid(-10.0)=4.5398e-05
+        mask_inputs_float = mask_inputs.float()
+        high_res_masks = mask_inputs_float * out_scale + out_bias
+        low_res_masks = F.interpolate(
+            high_res_masks,
+            size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
+            align_corners=False,
+            mode="bilinear",
+            antialias=True,  # use antialias for downsampling
+        )
+        # a dummy IoU prediction of all 1's under mask input
+        ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
+        if not self.use_obj_ptrs_in_encoder:
+            # all zeros as a dummy object pointer (of shape [B, C])
+            obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device)
+        else:
+            # produce an object pointer using the SAM decoder from the mask input
+            _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
+                backbone_features=backbone_features,
+                mask_inputs=self.mask_downsample(mask_inputs_float),
+                high_res_features=high_res_features,
+            )
+        # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
+        # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying
+        # on the object_scores from the SAM decoder.
+        is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
+        is_obj_appearing = is_obj_appearing[..., None]
+        lambda_is_obj_appearing = is_obj_appearing.float()
+        object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
+        if self.pred_obj_scores:
+            if self.fixed_no_obj_ptr:
+                obj_ptr = lambda_is_obj_appearing * obj_ptr
+            obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
+
+        return (
+            low_res_masks,
+            high_res_masks,
+            ious,
+            low_res_masks,
+            high_res_masks,
+            obj_ptr,
+            object_score_logits,
+        )
+
+    def forward_image(self, img_batch: torch.Tensor):
+        """Processes image batch through encoder to extract multi-level features for SAM model."""
+        backbone_out = self.image_encoder(img_batch)
+        if self.use_high_res_features_in_sam:
+            # precompute projected level 0 and level 1 features in SAM decoder
+            # to avoid running it again on every SAM click
+            backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
+            backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
+        return backbone_out
+
+    def _prepare_backbone_features(self, backbone_out):
+        """Prepares and flattens visual features from the image backbone output for further processing."""
+        assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
+        assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
+
+        feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
+        vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
+
+        feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
+        # flatten NxCxHxW to HWxNxC
+        vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
+        vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
+
+        return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
+
+    def _prepare_memory_conditioned_features(
+        self,
+        frame_idx,
+        is_init_cond_frame,
+        current_vision_feats,
+        current_vision_pos_embeds,
+        feat_sizes,
+        output_dict,
+        num_frames,
+        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
+    ):
+        """Prepares memory-conditioned features by fusing current frame's visual features with previous memories."""
+        B = current_vision_feats[-1].size(1)  # batch size on this frame
+        C = self.hidden_dim
+        H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
+        device = current_vision_feats[-1].device
+        # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
+        # In this case, we skip the fusion with any memory.
+        if self.num_maskmem == 0:  # Disable memory and skip fusion
+            return current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
+        num_obj_ptr_tokens = 0
+        tpos_sign_mul = -1 if track_in_reverse else 1
+        # Step 1: condition the visual features of the current frame on previous memories
+        if not is_init_cond_frame:
+            # Retrieve the memories encoded with the maskmem backbone
+            to_cat_memory, to_cat_memory_pos_embed = [], []
+            # Add conditioning frame's output first (all cond frames have t_pos=0 for
+            # when getting temporal positional embedding below)
+            assert len(output_dict["cond_frame_outputs"]) > 0
+            # Select a maximum number of temporally closest cond frames for cross attention
+            cond_outputs = output_dict["cond_frame_outputs"]
+            selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
+                frame_idx, cond_outputs, self.max_cond_frames_in_attn
+            )
+            t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
+            # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
+            # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
+            # We also allow taking the memory frame non-consecutively (with r>1), in which case
+            # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame.
+            r = 1 if self.training else self.memory_temporal_stride_for_eval
+            for t_pos in range(1, self.num_maskmem):
+                t_rel = self.num_maskmem - t_pos  # how many frames before current frame
+                if t_rel == 1:
+                    # for t_rel == 1, we take the last frame (regardless of r)
+                    prev_frame_idx = frame_idx + t_rel if track_in_reverse else frame_idx - t_rel
+                elif not track_in_reverse:
+                    # first find the nearest frame among every r-th frames before this frame
+                    # for r=1, this would be (frame_idx - 2)
+                    prev_frame_idx = ((frame_idx - 2) // r) * r
+                    # then seek further among every r-th frames
+                    prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
+                else:
+                    # first find the nearest frame among every r-th frames after this frame
+                    # for r=1, this would be (frame_idx + 2)
+                    prev_frame_idx = -(-(frame_idx + 2) // r) * r
+                    # then seek further among every r-th frames
+                    prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
+                out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
+                if out is None:
+                    # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
+                    # frames, we still attend to it as if it's a non-conditioning frame.
+                    out = unselected_cond_outputs.get(prev_frame_idx, None)
+                t_pos_and_prevs.append((t_pos, out))
+
+            for t_pos, prev in t_pos_and_prevs:
+                if prev is None:
+                    continue  # skip padding frames
+                # "maskmem_features" might have been offloaded to CPU in demo use cases,
+                # so we load it back to inference device (it's a no-op if it's already on device).
+                feats = prev["maskmem_features"].to(device=device, non_blocking=True)
+                to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
+                # Spatial positional encoding (it might have been offloaded to CPU in eval)
+                maskmem_enc = prev["maskmem_pos_enc"][-1].to(device=device)
+                maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
+                # Temporal positional encoding
+                maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
+                to_cat_memory_pos_embed.append(maskmem_enc)
+
+            # Construct the list of past object pointers
+            if self.use_obj_ptrs_in_encoder:
+                max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
+                # First add those object pointers from selected conditioning frames
+                # (optionally, only include object pointers in the past during evaluation)
+                if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
+                    ptr_cond_outputs = {
+                        t: out
+                        for t, out in selected_cond_outputs.items()
+                        if (t >= frame_idx if track_in_reverse else t <= frame_idx)
+                    }
+                else:
+                    ptr_cond_outputs = selected_cond_outputs
+                pos_and_ptrs = [
+                    # Temporal pos encoding contains how far away each pointer is from current frame
+                    (
+                        (
+                            (frame_idx - t) * tpos_sign_mul
+                            if self.use_signed_tpos_enc_to_obj_ptrs
+                            else abs(frame_idx - t)
+                        ),
+                        out["obj_ptr"],
+                    )
+                    for t, out in ptr_cond_outputs.items()
+                ]
+                # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
+                for t_diff in range(1, max_obj_ptrs_in_encoder):
+                    t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
+                    if t < 0 or (num_frames is not None and t >= num_frames):
+                        break
+                    out = output_dict["non_cond_frame_outputs"].get(t, unselected_cond_outputs.get(t, None))
+                    if out is not None:
+                        pos_and_ptrs.append((t_diff, out["obj_ptr"]))
+                # If we have at least one object pointer, add them to the across attention
+                if pos_and_ptrs:
+                    pos_list, ptrs_list = zip(*pos_and_ptrs)
+                    # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
+                    obj_ptrs = torch.stack(ptrs_list, dim=0)
+                    # a temporal positional embedding based on how far each object pointer is from
+                    # the current frame (sine embedding normalized by the max pointer num).
+                    if self.add_tpos_enc_to_obj_ptrs:
+                        t_diff_max = max_obj_ptrs_in_encoder - 1
+                        tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
+                        obj_pos = torch.tensor(pos_list, device=device)
+                        obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
+                        obj_pos = self.obj_ptr_tpos_proj(obj_pos)
+                        obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
+                    else:
+                        obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
+                    if self.mem_dim < C:
+                        # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
+                        obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim)
+                        obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
+                        obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
+                    to_cat_memory.append(obj_ptrs)
+                    to_cat_memory_pos_embed.append(obj_pos)
+                    num_obj_ptr_tokens = obj_ptrs.shape[0]
+                else:
+                    num_obj_ptr_tokens = 0
+        else:
+            # for initial conditioning frames, encode them without using any previous memory
+            if self.directly_add_no_mem_embed:
+                # directly add no-mem embedding (instead of using the transformer encoder)
+                pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
+                pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
+                return pix_feat_with_mem
+
+            # Use a dummy token on the first frame (to avoid empty memory input to transformer encoder)
+            to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
+            to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
+
+        # Step 2: Concatenate the memories and forward through the transformer encoder
+        memory = torch.cat(to_cat_memory, dim=0)
+        memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
+
+        pix_feat_with_mem = self.memory_attention(
+            curr=current_vision_feats,
+            curr_pos=current_vision_pos_embeds,
+            memory=memory,
+            memory_pos=memory_pos_embed,
+            num_obj_ptr_tokens=num_obj_ptr_tokens,
+        )
+        # reshape the output (HW)BC => BCHW
+        pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
+        return pix_feat_with_mem
+
+    def _encode_new_memory(
+        self,
+        current_vision_feats,
+        feat_sizes,
+        pred_masks_high_res,
+        object_score_logits,
+        is_mask_from_pts,
+    ):
+        """Encodes frame features and masks into a new memory representation for video segmentation."""
+        B = current_vision_feats[-1].size(1)  # batch size on this frame
+        C = self.hidden_dim
+        H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
+        # top-level feature, (HW)BC => BCHW
+        pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
+        if self.non_overlap_masks_for_mem_enc and not self.training:
+            # optionally, apply non-overlapping constraints to the masks (it's applied
+            # in the batch dimension and should only be used during eval, where all
+            # the objects come from the same video under batch size 1).
+            pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res)
+        # scale the raw mask logits with a temperature before applying sigmoid
+        binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
+        if binarize and not self.training:
+            mask_for_mem = (pred_masks_high_res > 0).float()
+        else:
+            # apply sigmoid on the raw mask logits to turn them into range (0, 1)
+            mask_for_mem = torch.sigmoid(pred_masks_high_res)
+        # apply scale and bias terms to the sigmoid probabilities
+        if self.sigmoid_scale_for_mem_enc != 1.0:
+            mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
+        if self.sigmoid_bias_for_mem_enc != 0.0:
+            mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
+        maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True)  # sigmoid already applied
+        maskmem_features = maskmem_out["vision_features"]
+        maskmem_pos_enc = maskmem_out["vision_pos_enc"]
+        # add a no-object embedding to the spatial memory to indicate that the frame
+        # is predicted to be occluded (i.e. no object is appearing in the frame)
+        if self.no_obj_embed_spatial is not None:
+            is_obj_appearing = (object_score_logits > 0).float()
+            maskmem_features += (1 - is_obj_appearing[..., None, None]) * self.no_obj_embed_spatial[
+                ..., None, None
+            ].expand(*maskmem_features.shape)
+
+        return maskmem_features, maskmem_pos_enc
+
+    def _track_step(
+        self,
+        frame_idx,
+        is_init_cond_frame,
+        current_vision_feats,
+        current_vision_pos_embeds,
+        feat_sizes,
+        point_inputs,
+        mask_inputs,
+        output_dict,
+        num_frames,
+        track_in_reverse,
+        prev_sam_mask_logits,
+    ):
+        """Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
+        current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
+        # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
+        if len(current_vision_feats) > 1:
+            high_res_features = [
+                x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
+                for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
+            ]
+        else:
+            high_res_features = None
+        if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
+            # When use_mask_input_as_output_without_sam=True, we directly output the mask input
+            # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
+            pix_feat = current_vision_feats[-1].permute(1, 2, 0)
+            pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
+            sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
+        else:
+            # fused the visual feature with previous memory features in the memory bank
+            pix_feat = self._prepare_memory_conditioned_features(
+                frame_idx=frame_idx,
+                is_init_cond_frame=is_init_cond_frame,
+                current_vision_feats=current_vision_feats[-1:],
+                current_vision_pos_embeds=current_vision_pos_embeds[-1:],
+                feat_sizes=feat_sizes[-1:],
+                output_dict=output_dict,
+                num_frames=num_frames,
+                track_in_reverse=track_in_reverse,
+            )
+            # apply SAM-style segmentation head
+            # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
+            # e.g. in demo where such logits come from earlier interaction instead of correction sampling
+            # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
+            if prev_sam_mask_logits is not None:
+                assert point_inputs is not None and mask_inputs is None
+                mask_inputs = prev_sam_mask_logits
+            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
+            sam_outputs = self._forward_sam_heads(
+                backbone_features=pix_feat,
+                point_inputs=point_inputs,
+                mask_inputs=mask_inputs,
+                high_res_features=high_res_features,
+                multimask_output=multimask_output,
+            )
+        return current_out, sam_outputs, high_res_features, pix_feat
+
+    def _encode_memory_in_output(
+        self,
+        current_vision_feats,
+        feat_sizes,
+        point_inputs,
+        run_mem_encoder,
+        high_res_masks,
+        object_score_logits,
+        current_out,
+    ):
+        """Finally run the memory encoder on the predicted mask to encode, it into a new memory feature (that can be
+        used in future frames).
+        """
+        if run_mem_encoder and self.num_maskmem > 0:
+            high_res_masks_for_mem_enc = high_res_masks
+            maskmem_features, maskmem_pos_enc = self._encode_new_memory(
+                current_vision_feats=current_vision_feats,
+                feat_sizes=feat_sizes,
+                pred_masks_high_res=high_res_masks_for_mem_enc,
+                object_score_logits=object_score_logits,
+                is_mask_from_pts=(point_inputs is not None),
+            )
+            current_out["maskmem_features"] = maskmem_features
+            current_out["maskmem_pos_enc"] = maskmem_pos_enc
+        else:
+            current_out["maskmem_features"] = None
+            current_out["maskmem_pos_enc"] = None
+
+    def track_step(
+        self,
+        frame_idx,
+        is_init_cond_frame,
+        current_vision_feats,
+        current_vision_pos_embeds,
+        feat_sizes,
+        point_inputs,
+        mask_inputs,
+        output_dict,
+        num_frames,
+        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
+        # Whether to run the memory encoder on the predicted masks. Sometimes we might want
+        # to skip the memory encoder with `run_mem_encoder=False`. For example,
+        # in demo we might call `track_step` multiple times for each user click,
+        # and only encode the memory when the user finalizes their clicks. And in ablation
+        # settings like SAM training on static images, we don't need the memory encoder.
+        run_mem_encoder=True,
+        # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
+        prev_sam_mask_logits=None,
+    ):
+        """Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
+        current_out, sam_outputs, _, _ = self._track_step(
+            frame_idx,
+            is_init_cond_frame,
+            current_vision_feats,
+            current_vision_pos_embeds,
+            feat_sizes,
+            point_inputs,
+            mask_inputs,
+            output_dict,
+            num_frames,
+            track_in_reverse,
+            prev_sam_mask_logits,
+        )
+        _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = sam_outputs
+
+        current_out["pred_masks"] = low_res_masks
+        current_out["pred_masks_high_res"] = high_res_masks
+        current_out["obj_ptr"] = obj_ptr
+        if not self.training:
+            # Only add this in inference (to avoid unused param in activation checkpointing;
+            # it's mainly used in the demo to encode spatial memories w/ consolidated masks)
+            current_out["object_score_logits"] = object_score_logits
+
+        # Run memory encoder on the predicted mask to encode it into a new memory feature (for use in future frames)
+        self._encode_memory_in_output(
+            current_vision_feats,
+            feat_sizes,
+            point_inputs,
+            run_mem_encoder,
+            high_res_masks,
+            object_score_logits,
+            current_out,
+        )
+
+        return current_out
+
+    def _use_multimask(self, is_init_cond_frame, point_inputs):
+        """Determines whether to use multiple mask outputs in the SAM head based on configuration and inputs."""
+        num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
+        return (
+            self.multimask_output_in_sam
+            and (is_init_cond_frame or self.multimask_output_for_tracking)
+            and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
+        )
+
+    @staticmethod
+    def _apply_non_overlapping_constraints(pred_masks):
+        """Applies non-overlapping constraints to masks, keeping the highest scoring object per location."""
+        batch_size = pred_masks.size(0)
+        if batch_size == 1:
+            return pred_masks
+
+        device = pred_masks.device
+        # "max_obj_inds": object index of the object with the highest score at each location
+        max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
+        # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
+        batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
+        keep = max_obj_inds == batch_obj_inds
+        # suppress overlapping regions' scores below -10.0 so that the foreground regions
+        # don't overlap (here sigmoid(-10.0)=4.5398e-05)
+        pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
+        return pred_masks
+
+    def set_binarize(self, binarize=False):
+        """Set binarize for VideoPredictor."""
+        self.binarize_mask_from_pts_for_mem_enc = binarize
+
+    def set_imgsz(self, imgsz):
+        """
+        Set image size to make model compatible with different image sizes.
+
+        Args:
+            imgsz (Tuple[int, int]): The size of the input image.
+        """
+        self.image_size = imgsz[0]
+        self.sam_prompt_encoder.input_image_size = imgsz
+        self.sam_prompt_encoder.image_embedding_size = [x // 16 for x in imgsz]  # fixed ViT patch size of 16

+ 1013 - 0
ultralytics/models/sam/modules/tiny_encoder.py

@@ -0,0 +1,1013 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# --------------------------------------------------------
+# TinyViT Model Architecture
+# Copyright (c) 2022 Microsoft
+# Adapted from LeViT and Swin Transformer
+#   LeViT: (https://github.com/facebookresearch/levit)
+#   Swin: (https://github.com/microsoft/swin-transformer)
+# Build the TinyViT Model
+# --------------------------------------------------------
+
+import itertools
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+
+from ultralytics.nn.modules import LayerNorm2d
+from ultralytics.utils.instance import to_2tuple
+
+
+class Conv2d_BN(torch.nn.Sequential):
+    """
+    A sequential container that performs 2D convolution followed by batch normalization.
+
+    Attributes:
+        c (torch.nn.Conv2d): 2D convolution layer.
+        1 (torch.nn.BatchNorm2d): Batch normalization layer.
+
+    Methods:
+        __init__: Initializes the Conv2d_BN with specified parameters.
+
+    Args:
+        a (int): Number of input channels.
+        b (int): Number of output channels.
+        ks (int): Kernel size for the convolution. Defaults to 1.
+        stride (int): Stride for the convolution. Defaults to 1.
+        pad (int): Padding for the convolution. Defaults to 0.
+        dilation (int): Dilation factor for the convolution. Defaults to 1.
+        groups (int): Number of groups for the convolution. Defaults to 1.
+        bn_weight_init (float): Initial value for batch normalization weight. Defaults to 1.
+
+    Examples:
+        >>> conv_bn = Conv2d_BN(3, 64, ks=3, stride=1, pad=1)
+        >>> input_tensor = torch.randn(1, 3, 224, 224)
+        >>> output = conv_bn(input_tensor)
+        >>> print(output.shape)
+    """
+
+    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
+        """Initializes a sequential container with 2D convolution followed by batch normalization."""
+        super().__init__()
+        self.add_module("c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
+        bn = torch.nn.BatchNorm2d(b)
+        torch.nn.init.constant_(bn.weight, bn_weight_init)
+        torch.nn.init.constant_(bn.bias, 0)
+        self.add_module("bn", bn)
+
+
+class PatchEmbed(nn.Module):
+    """
+    Embeds images into patches and projects them into a specified embedding dimension.
+
+    Attributes:
+        patches_resolution (Tuple[int, int]): Resolution of the patches after embedding.
+        num_patches (int): Total number of patches.
+        in_chans (int): Number of input channels.
+        embed_dim (int): Dimension of the embedding.
+        seq (nn.Sequential): Sequence of convolutional and activation layers for patch embedding.
+
+    Methods:
+        forward: Processes the input tensor through the patch embedding sequence.
+
+    Examples:
+        >>> import torch
+        >>> patch_embed = PatchEmbed(in_chans=3, embed_dim=96, resolution=224, activation=nn.GELU)
+        >>> x = torch.randn(1, 3, 224, 224)
+        >>> output = patch_embed(x)
+        >>> print(output.shape)
+    """
+
+    def __init__(self, in_chans, embed_dim, resolution, activation):
+        """Initializes patch embedding with convolutional layers for image-to-patch conversion and projection."""
+        super().__init__()
+        img_size: Tuple[int, int] = to_2tuple(resolution)
+        self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
+        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+        n = embed_dim
+        self.seq = nn.Sequential(
+            Conv2d_BN(in_chans, n // 2, 3, 2, 1),
+            activation(),
+            Conv2d_BN(n // 2, n, 3, 2, 1),
+        )
+
+    def forward(self, x):
+        """Processes input tensor through patch embedding sequence, converting images to patch embeddings."""
+        return self.seq(x)
+
+
+class MBConv(nn.Module):
+    """
+    Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture.
+
+    Attributes:
+        in_chans (int): Number of input channels.
+        hidden_chans (int): Number of hidden channels.
+        out_chans (int): Number of output channels.
+        conv1 (Conv2d_BN): First convolutional layer.
+        act1 (nn.Module): First activation function.
+        conv2 (Conv2d_BN): Depthwise convolutional layer.
+        act2 (nn.Module): Second activation function.
+        conv3 (Conv2d_BN): Final convolutional layer.
+        act3 (nn.Module): Third activation function.
+        drop_path (nn.Module): Drop path layer (Identity for inference).
+
+    Methods:
+        forward: Performs the forward pass through the MBConv layer.
+
+    Examples:
+        >>> in_chans, out_chans = 32, 64
+        >>> mbconv = MBConv(in_chans, out_chans, expand_ratio=4, activation=nn.ReLU, drop_path=0.1)
+        >>> x = torch.randn(1, in_chans, 56, 56)
+        >>> output = mbconv(x)
+        >>> print(output.shape)
+        torch.Size([1, 64, 56, 56])
+    """
+
+    def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
+        """Initializes the MBConv layer with specified input/output channels, expansion ratio, and activation."""
+        super().__init__()
+        self.in_chans = in_chans
+        self.hidden_chans = int(in_chans * expand_ratio)
+        self.out_chans = out_chans
+
+        self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1)
+        self.act1 = activation()
+
+        self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans, ks=3, stride=1, pad=1, groups=self.hidden_chans)
+        self.act2 = activation()
+
+        self.conv3 = Conv2d_BN(self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0)
+        self.act3 = activation()
+
+        # NOTE: `DropPath` is needed only for training.
+        # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.drop_path = nn.Identity()
+
+    def forward(self, x):
+        """Implements the forward pass of MBConv, applying convolutions and skip connection."""
+        shortcut = x
+        x = self.conv1(x)
+        x = self.act1(x)
+        x = self.conv2(x)
+        x = self.act2(x)
+        x = self.conv3(x)
+        x = self.drop_path(x)
+        x += shortcut
+        return self.act3(x)
+
+
+class PatchMerging(nn.Module):
+    """
+    Merges neighboring patches in the feature map and projects to a new dimension.
+
+    This class implements a patch merging operation that combines spatial information and adjusts the feature
+    dimension. It uses a series of convolutional layers with batch normalization to achieve this.
+
+    Attributes:
+        input_resolution (Tuple[int, int]): The input resolution (height, width) of the feature map.
+        dim (int): The input dimension of the feature map.
+        out_dim (int): The output dimension after merging and projection.
+        act (nn.Module): The activation function used between convolutions.
+        conv1 (Conv2d_BN): The first convolutional layer for dimension projection.
+        conv2 (Conv2d_BN): The second convolutional layer for spatial merging.
+        conv3 (Conv2d_BN): The third convolutional layer for final projection.
+
+    Methods:
+        forward: Applies the patch merging operation to the input tensor.
+
+    Examples:
+        >>> input_resolution = (56, 56)
+        >>> patch_merging = PatchMerging(input_resolution, dim=64, out_dim=128, activation=nn.ReLU)
+        >>> x = torch.randn(4, 64, 56, 56)
+        >>> output = patch_merging(x)
+        >>> print(output.shape)
+    """
+
+    def __init__(self, input_resolution, dim, out_dim, activation):
+        """Initializes the PatchMerging module for merging and projecting neighboring patches in feature maps."""
+        super().__init__()
+
+        self.input_resolution = input_resolution
+        self.dim = dim
+        self.out_dim = out_dim
+        self.act = activation()
+        self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
+        stride_c = 1 if out_dim in {320, 448, 576} else 2
+        self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
+        self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
+
+    def forward(self, x):
+        """Applies patch merging and dimension projection to the input feature map."""
+        if x.ndim == 3:
+            H, W = self.input_resolution
+            B = len(x)
+            # (B, C, H, W)
+            x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
+
+        x = self.conv1(x)
+        x = self.act(x)
+
+        x = self.conv2(x)
+        x = self.act(x)
+        x = self.conv3(x)
+        return x.flatten(2).transpose(1, 2)
+
+
+class ConvLayer(nn.Module):
+    """
+    Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).
+
+    This layer optionally applies downsample operations to the output and supports gradient checkpointing.
+
+    Attributes:
+        dim (int): Dimensionality of the input and output.
+        input_resolution (Tuple[int, int]): Resolution of the input image.
+        depth (int): Number of MBConv layers in the block.
+        use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
+        blocks (nn.ModuleList): List of MBConv layers.
+        downsample (Optional[Callable]): Function for downsampling the output.
+
+    Methods:
+        forward: Processes the input through the convolutional layers.
+
+    Examples:
+        >>> input_tensor = torch.randn(1, 64, 56, 56)
+        >>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU)
+        >>> output = conv_layer(input_tensor)
+        >>> print(output.shape)
+    """
+
+    def __init__(
+        self,
+        dim,
+        input_resolution,
+        depth,
+        activation,
+        drop_path=0.0,
+        downsample=None,
+        use_checkpoint=False,
+        out_dim=None,
+        conv_expand_ratio=4.0,
+    ):
+        """
+        Initializes the ConvLayer with the given dimensions and settings.
+
+        This layer consists of multiple MobileNetV3-style inverted bottleneck convolutions (MBConv) and
+        optionally applies downsampling to the output.
+
+        Args:
+            dim (int): The dimensionality of the input and output.
+            input_resolution (Tuple[int, int]): The resolution of the input image.
+            depth (int): The number of MBConv layers in the block.
+            activation (Callable): Activation function applied after each convolution.
+            drop_path (float | List[float]): Drop path rate. Single float or a list of floats for each MBConv.
+            downsample (Optional[Callable]): Function for downsampling the output. None to skip downsampling.
+            use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
+            out_dim (Optional[int]): The dimensionality of the output. None means it will be the same as `dim`.
+            conv_expand_ratio (float): Expansion ratio for the MBConv layers.
+
+        Examples:
+            >>> input_tensor = torch.randn(1, 64, 56, 56)
+            >>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU)
+            >>> output = conv_layer(input_tensor)
+            >>> print(output.shape)
+        """
+        super().__init__()
+        self.dim = dim
+        self.input_resolution = input_resolution
+        self.depth = depth
+        self.use_checkpoint = use_checkpoint
+
+        # Build blocks
+        self.blocks = nn.ModuleList(
+            [
+                MBConv(
+                    dim,
+                    dim,
+                    conv_expand_ratio,
+                    activation,
+                    drop_path[i] if isinstance(drop_path, list) else drop_path,
+                )
+                for i in range(depth)
+            ]
+        )
+
+        # Patch merging layer
+        self.downsample = (
+            None
+            if downsample is None
+            else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
+        )
+
+    def forward(self, x):
+        """Processes input through convolutional layers, applying MBConv blocks and optional downsampling."""
+        for blk in self.blocks:
+            x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
+        return x if self.downsample is None else self.downsample(x)
+
+
+class Mlp(nn.Module):
+    """
+    Multi-layer Perceptron (MLP) module for transformer architectures.
+
+    This module applies layer normalization, two fully-connected layers with an activation function in between,
+    and dropout. It is commonly used in transformer-based architectures.
+
+    Attributes:
+        norm (nn.LayerNorm): Layer normalization applied to the input.
+        fc1 (nn.Linear): First fully-connected layer.
+        fc2 (nn.Linear): Second fully-connected layer.
+        act (nn.Module): Activation function applied after the first fully-connected layer.
+        drop (nn.Dropout): Dropout layer applied after the activation function.
+
+    Methods:
+        forward: Applies the MLP operations on the input tensor.
+
+    Examples:
+        >>> import torch
+        >>> from torch import nn
+        >>> mlp = Mlp(in_features=256, hidden_features=512, out_features=256, act_layer=nn.GELU, drop=0.1)
+        >>> x = torch.randn(32, 100, 256)
+        >>> output = mlp(x)
+        >>> print(output.shape)
+        torch.Size([32, 100, 256])
+    """
+
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
+        """Initializes a multi-layer perceptron with configurable input, hidden, and output dimensions."""
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.norm = nn.LayerNorm(in_features)
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.act = act_layer()
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        """Applies MLP operations: layer norm, FC layers, activation, and dropout to the input tensor."""
+        x = self.norm(x)
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        return self.drop(x)
+
+
+class Attention(torch.nn.Module):
+    """
+    Multi-head attention module with spatial awareness and trainable attention biases.
+
+    This module implements a multi-head attention mechanism with support for spatial awareness, applying
+    attention biases based on spatial resolution. It includes trainable attention biases for each unique
+    offset between spatial positions in the resolution grid.
+
+    Attributes:
+        num_heads (int): Number of attention heads.
+        scale (float): Scaling factor for attention scores.
+        key_dim (int): Dimensionality of the keys and queries.
+        nh_kd (int): Product of num_heads and key_dim.
+        d (int): Dimensionality of the value vectors.
+        dh (int): Product of d and num_heads.
+        attn_ratio (float): Attention ratio affecting the dimensions of the value vectors.
+        norm (nn.LayerNorm): Layer normalization applied to input.
+        qkv (nn.Linear): Linear layer for computing query, key, and value projections.
+        proj (nn.Linear): Linear layer for final projection.
+        attention_biases (nn.Parameter): Learnable attention biases.
+        attention_bias_idxs (Tensor): Indices for attention biases.
+        ab (Tensor): Cached attention biases for inference, deleted during training.
+
+    Methods:
+        train: Sets the module in training mode and handles the 'ab' attribute.
+        forward: Performs the forward pass of the attention mechanism.
+
+    Examples:
+        >>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))
+        >>> x = torch.randn(1, 196, 256)
+        >>> output = attn(x)
+        >>> print(output.shape)
+        torch.Size([1, 196, 256])
+    """
+
+    def __init__(
+        self,
+        dim,
+        key_dim,
+        num_heads=8,
+        attn_ratio=4,
+        resolution=(14, 14),
+    ):
+        """
+        Initializes the Attention module for multi-head attention with spatial awareness.
+
+        This module implements a multi-head attention mechanism with support for spatial awareness, applying
+        attention biases based on spatial resolution. It includes trainable attention biases for each unique
+        offset between spatial positions in the resolution grid.
+
+        Args:
+            dim (int): The dimensionality of the input and output.
+            key_dim (int): The dimensionality of the keys and queries.
+            num_heads (int): Number of attention heads. Default is 8.
+            attn_ratio (float): Attention ratio, affecting the dimensions of the value vectors. Default is 4.
+            resolution (Tuple[int, int]): Spatial resolution of the input feature map. Default is (14, 14).
+
+        Raises:
+            AssertionError: If 'resolution' is not a tuple of length 2.
+
+        Examples:
+            >>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))
+            >>> x = torch.randn(1, 196, 256)
+            >>> output = attn(x)
+            >>> print(output.shape)
+            torch.Size([1, 196, 256])
+        """
+        super().__init__()
+
+        assert isinstance(resolution, tuple) and len(resolution) == 2, "'resolution' argument not tuple of length 2"
+        self.num_heads = num_heads
+        self.scale = key_dim**-0.5
+        self.key_dim = key_dim
+        self.nh_kd = nh_kd = key_dim * num_heads
+        self.d = int(attn_ratio * key_dim)
+        self.dh = int(attn_ratio * key_dim) * num_heads
+        self.attn_ratio = attn_ratio
+        h = self.dh + nh_kd * 2
+
+        self.norm = nn.LayerNorm(dim)
+        self.qkv = nn.Linear(dim, h)
+        self.proj = nn.Linear(self.dh, dim)
+
+        points = list(itertools.product(range(resolution[0]), range(resolution[1])))
+        N = len(points)
+        attention_offsets = {}
+        idxs = []
+        for p1 in points:
+            for p2 in points:
+                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
+                if offset not in attention_offsets:
+                    attention_offsets[offset] = len(attention_offsets)
+                idxs.append(attention_offsets[offset])
+        self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
+        self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(N, N), persistent=False)
+
+    @torch.no_grad()
+    def train(self, mode=True):
+        """Performs multi-head attention with spatial awareness and trainable attention biases."""
+        super().train(mode)
+        if mode and hasattr(self, "ab"):
+            del self.ab
+        else:
+            self.ab = self.attention_biases[:, self.attention_bias_idxs]
+
+    def forward(self, x):  # x
+        """Applies multi-head attention with spatial awareness and trainable attention biases."""
+        B, N, _ = x.shape  # B, N, C
+
+        # Normalization
+        x = self.norm(x)
+
+        qkv = self.qkv(x)
+        # (B, N, num_heads, d)
+        q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)
+        # (B, num_heads, N, d)
+        q = q.permute(0, 2, 1, 3)
+        k = k.permute(0, 2, 1, 3)
+        v = v.permute(0, 2, 1, 3)
+        self.ab = self.ab.to(self.attention_biases.device)
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale + (
+            self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
+        )
+        attn = attn.softmax(dim=-1)
+        x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
+        return self.proj(x)
+
+
+class TinyViTBlock(nn.Module):
+    """
+    TinyViT Block that applies self-attention and a local convolution to the input.
+
+    This block is a key component of the TinyViT architecture, combining self-attention mechanisms with
+    local convolutions to process input features efficiently.
+
+    Attributes:
+        dim (int): The dimensionality of the input and output.
+        input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
+        num_heads (int): Number of attention heads.
+        window_size (int): Size of the attention window.
+        mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
+        drop_path (nn.Module): Stochastic depth layer, identity function during inference.
+        attn (Attention): Self-attention module.
+        mlp (Mlp): Multi-layer perceptron module.
+        local_conv (Conv2d_BN): Depth-wise local convolution layer.
+
+    Methods:
+        forward: Processes the input through the TinyViT block.
+        extra_repr: Returns a string with extra information about the block's parameters.
+
+    Examples:
+        >>> input_tensor = torch.randn(1, 196, 192)
+        >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3)
+        >>> output = block(input_tensor)
+        >>> print(output.shape)
+        torch.Size([1, 196, 192])
+    """
+
+    def __init__(
+        self,
+        dim,
+        input_resolution,
+        num_heads,
+        window_size=7,
+        mlp_ratio=4.0,
+        drop=0.0,
+        drop_path=0.0,
+        local_conv_size=3,
+        activation=nn.GELU,
+    ):
+        """
+        Initializes a TinyViT block with self-attention and local convolution.
+
+        This block is a key component of the TinyViT architecture, combining self-attention mechanisms with
+        local convolutions to process input features efficiently.
+
+        Args:
+            dim (int): Dimensionality of the input and output features.
+            input_resolution (Tuple[int, int]): Spatial resolution of the input feature map (height, width).
+            num_heads (int): Number of attention heads.
+            window_size (int): Size of the attention window. Must be greater than 0.
+            mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
+            drop (float): Dropout rate.
+            drop_path (float): Stochastic depth rate.
+            local_conv_size (int): Kernel size of the local convolution.
+            activation (torch.nn.Module): Activation function for MLP.
+
+        Raises:
+            AssertionError: If window_size is not greater than 0.
+            AssertionError: If dim is not divisible by num_heads.
+
+        Examples:
+            >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3)
+            >>> input_tensor = torch.randn(1, 196, 192)
+            >>> output = block(input_tensor)
+            >>> print(output.shape)
+            torch.Size([1, 196, 192])
+        """
+        super().__init__()
+        self.dim = dim
+        self.input_resolution = input_resolution
+        self.num_heads = num_heads
+        assert window_size > 0, "window_size must be greater than 0"
+        self.window_size = window_size
+        self.mlp_ratio = mlp_ratio
+
+        # NOTE: `DropPath` is needed only for training.
+        # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.drop_path = nn.Identity()
+
+        assert dim % num_heads == 0, "dim must be divisible by num_heads"
+        head_dim = dim // num_heads
+
+        window_resolution = (window_size, window_size)
+        self.attn = Attention(dim, head_dim, num_heads, attn_ratio=1, resolution=window_resolution)
+
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        mlp_activation = activation
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=mlp_activation, drop=drop)
+
+        pad = local_conv_size // 2
+        self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
+
+    def forward(self, x):
+        """Applies self-attention, local convolution, and MLP operations to the input tensor."""
+        h, w = self.input_resolution
+        b, hw, c = x.shape  # batch, height*width, channels
+        assert hw == h * w, "input feature has wrong size"
+        res_x = x
+        if h == self.window_size and w == self.window_size:
+            x = self.attn(x)
+        else:
+            x = x.view(b, h, w, c)
+            pad_b = (self.window_size - h % self.window_size) % self.window_size
+            pad_r = (self.window_size - w % self.window_size) % self.window_size
+            padding = pad_b > 0 or pad_r > 0
+            if padding:
+                x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
+
+            pH, pW = h + pad_b, w + pad_r
+            nH = pH // self.window_size
+            nW = pW // self.window_size
+
+            # Window partition
+            x = (
+                x.view(b, nH, self.window_size, nW, self.window_size, c)
+                .transpose(2, 3)
+                .reshape(b * nH * nW, self.window_size * self.window_size, c)
+            )
+            x = self.attn(x)
+
+            # Window reverse
+            x = x.view(b, nH, nW, self.window_size, self.window_size, c).transpose(2, 3).reshape(b, pH, pW, c)
+            if padding:
+                x = x[:, :h, :w].contiguous()
+
+            x = x.view(b, hw, c)
+
+        x = res_x + self.drop_path(x)
+        x = x.transpose(1, 2).reshape(b, c, h, w)
+        x = self.local_conv(x)
+        x = x.view(b, c, hw).transpose(1, 2)
+
+        return x + self.drop_path(self.mlp(x))
+
+    def extra_repr(self) -> str:
+        """
+        Returns a string representation of the TinyViTBlock's parameters.
+
+        This method provides a formatted string containing key information about the TinyViTBlock, including its
+        dimension, input resolution, number of attention heads, window size, and MLP ratio.
+
+        Returns:
+            (str): A formatted string containing the block's parameters.
+
+        Examples:
+            >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0)
+            >>> print(block.extra_repr())
+            dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0
+        """
+        return (
+            f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
+            f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
+        )
+
+
+class BasicLayer(nn.Module):
+    """
+    A basic TinyViT layer for one stage in a TinyViT architecture.
+
+    This class represents a single layer in the TinyViT model, consisting of multiple TinyViT blocks
+    and an optional downsampling operation.
+
+    Attributes:
+        dim (int): The dimensionality of the input and output features.
+        input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
+        depth (int): Number of TinyViT blocks in this layer.
+        use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
+        blocks (nn.ModuleList): List of TinyViT blocks that make up this layer.
+        downsample (nn.Module | None): Downsample layer at the end of the layer, if specified.
+
+    Methods:
+        forward: Processes the input through the layer's blocks and optional downsampling.
+        extra_repr: Returns a string with the layer's parameters for printing.
+
+    Examples:
+        >>> input_tensor = torch.randn(1, 3136, 192)
+        >>> layer = BasicLayer(dim=192, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)
+        >>> output = layer(input_tensor)
+        >>> print(output.shape)
+        torch.Size([1, 784, 384])
+    """
+
+    def __init__(
+        self,
+        dim,
+        input_resolution,
+        depth,
+        num_heads,
+        window_size,
+        mlp_ratio=4.0,
+        drop=0.0,
+        drop_path=0.0,
+        downsample=None,
+        use_checkpoint=False,
+        local_conv_size=3,
+        activation=nn.GELU,
+        out_dim=None,
+    ):
+        """
+        Initializes a BasicLayer in the TinyViT architecture.
+
+        This layer consists of multiple TinyViT blocks and an optional downsampling operation. It is designed to
+        process feature maps at a specific resolution and dimensionality within the TinyViT model.
+
+        Args:
+            dim (int): Dimensionality of the input and output features.
+            input_resolution (Tuple[int, int]): Spatial resolution of the input feature map (height, width).
+            depth (int): Number of TinyViT blocks in this layer.
+            num_heads (int): Number of attention heads in each TinyViT block.
+            window_size (int): Size of the local window for attention computation.
+            mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
+            drop (float): Dropout rate.
+            drop_path (float | List[float]): Stochastic depth rate. Can be a float or a list of floats for each block.
+            downsample (nn.Module | None): Downsampling layer at the end of the layer. None to skip downsampling.
+            use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
+            local_conv_size (int): Kernel size for the local convolution in each TinyViT block.
+            activation (nn.Module): Activation function used in the MLP.
+            out_dim (int | None): Output dimension after downsampling. None means it will be the same as `dim`.
+
+        Raises:
+            ValueError: If `drop_path` is a list and its length doesn't match `depth`.
+
+        Examples:
+            >>> layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)
+            >>> x = torch.randn(1, 56 * 56, 96)
+            >>> output = layer(x)
+            >>> print(output.shape)
+        """
+        super().__init__()
+        self.dim = dim
+        self.input_resolution = input_resolution
+        self.depth = depth
+        self.use_checkpoint = use_checkpoint
+
+        # Build blocks
+        self.blocks = nn.ModuleList(
+            [
+                TinyViTBlock(
+                    dim=dim,
+                    input_resolution=input_resolution,
+                    num_heads=num_heads,
+                    window_size=window_size,
+                    mlp_ratio=mlp_ratio,
+                    drop=drop,
+                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                    local_conv_size=local_conv_size,
+                    activation=activation,
+                )
+                for i in range(depth)
+            ]
+        )
+
+        # Patch merging layer
+        self.downsample = (
+            None
+            if downsample is None
+            else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
+        )
+
+    def forward(self, x):
+        """Processes input through TinyViT blocks and optional downsampling."""
+        for blk in self.blocks:
+            x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
+        return x if self.downsample is None else self.downsample(x)
+
+    def extra_repr(self) -> str:
+        """Returns a string with the layer's parameters for printing."""
+        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+
+class TinyViT(nn.Module):
+    """
+    TinyViT: A compact vision transformer architecture for efficient image classification and feature extraction.
+
+    This class implements the TinyViT model, which combines elements of vision transformers and convolutional
+    neural networks for improved efficiency and performance on vision tasks.
+
+    Attributes:
+        img_size (int): Input image size.
+        num_classes (int): Number of classification classes.
+        depths (List[int]): Number of blocks in each stage.
+        num_layers (int): Total number of layers in the network.
+        mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
+        patch_embed (PatchEmbed): Module for patch embedding.
+        patches_resolution (Tuple[int, int]): Resolution of embedded patches.
+        layers (nn.ModuleList): List of network layers.
+        norm_head (nn.LayerNorm): Layer normalization for the classifier head.
+        head (nn.Linear): Linear layer for final classification.
+        neck (nn.Sequential): Neck module for feature refinement.
+
+    Methods:
+        set_layer_lr_decay: Sets layer-wise learning rate decay.
+        _init_weights: Initializes weights for linear and normalization layers.
+        no_weight_decay_keywords: Returns keywords for parameters that should not use weight decay.
+        forward_features: Processes input through the feature extraction layers.
+        forward: Performs a forward pass through the entire network.
+
+    Examples:
+        >>> model = TinyViT(img_size=224, num_classes=1000)
+        >>> x = torch.randn(1, 3, 224, 224)
+        >>> features = model.forward_features(x)
+        >>> print(features.shape)
+        torch.Size([1, 256, 64, 64])
+    """
+
+    def __init__(
+        self,
+        img_size=224,
+        in_chans=3,
+        num_classes=1000,
+        embed_dims=(96, 192, 384, 768),
+        depths=(2, 2, 6, 2),
+        num_heads=(3, 6, 12, 24),
+        window_sizes=(7, 7, 14, 7),
+        mlp_ratio=4.0,
+        drop_rate=0.0,
+        drop_path_rate=0.1,
+        use_checkpoint=False,
+        mbconv_expand_ratio=4.0,
+        local_conv_size=3,
+        layer_lr_decay=1.0,
+    ):
+        """
+        Initializes the TinyViT model.
+
+        This constructor sets up the TinyViT architecture, including patch embedding, multiple layers of
+        attention and convolution blocks, and a classification head.
+
+        Args:
+            img_size (int): Size of the input image. Default is 224.
+            in_chans (int): Number of input channels. Default is 3.
+            num_classes (int): Number of classes for classification. Default is 1000.
+            embed_dims (Tuple[int, int, int, int]): Embedding dimensions for each stage.
+                Default is (96, 192, 384, 768).
+            depths (Tuple[int, int, int, int]): Number of blocks in each stage. Default is (2, 2, 6, 2).
+            num_heads (Tuple[int, int, int, int]): Number of attention heads in each stage.
+                Default is (3, 6, 12, 24).
+            window_sizes (Tuple[int, int, int, int]): Window sizes for each stage. Default is (7, 7, 14, 7).
+            mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. Default is 4.0.
+            drop_rate (float): Dropout rate. Default is 0.0.
+            drop_path_rate (float): Stochastic depth rate. Default is 0.1.
+            use_checkpoint (bool): Whether to use checkpointing to save memory. Default is False.
+            mbconv_expand_ratio (float): Expansion ratio for MBConv layer. Default is 4.0.
+            local_conv_size (int): Kernel size for local convolutions. Default is 3.
+            layer_lr_decay (float): Layer-wise learning rate decay factor. Default is 1.0.
+
+        Examples:
+            >>> model = TinyViT(img_size=224, num_classes=1000)
+            >>> x = torch.randn(1, 3, 224, 224)
+            >>> output = model(x)
+            >>> print(output.shape)
+            torch.Size([1, 1000])
+        """
+        super().__init__()
+        self.img_size = img_size
+        self.num_classes = num_classes
+        self.depths = depths
+        self.num_layers = len(depths)
+        self.mlp_ratio = mlp_ratio
+
+        activation = nn.GELU
+
+        self.patch_embed = PatchEmbed(
+            in_chans=in_chans, embed_dim=embed_dims[0], resolution=img_size, activation=activation
+        )
+
+        patches_resolution = self.patch_embed.patches_resolution
+        self.patches_resolution = patches_resolution
+
+        # Stochastic depth
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
+
+        # Build layers
+        self.layers = nn.ModuleList()
+        for i_layer in range(self.num_layers):
+            kwargs = dict(
+                dim=embed_dims[i_layer],
+                input_resolution=(
+                    patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
+                    patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
+                ),
+                #   input_resolution=(patches_resolution[0] // (2 ** i_layer),
+                #                     patches_resolution[1] // (2 ** i_layer)),
+                depth=depths[i_layer],
+                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
+                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+                use_checkpoint=use_checkpoint,
+                out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)],
+                activation=activation,
+            )
+            if i_layer == 0:
+                layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs)
+            else:
+                layer = BasicLayer(
+                    num_heads=num_heads[i_layer],
+                    window_size=window_sizes[i_layer],
+                    mlp_ratio=self.mlp_ratio,
+                    drop=drop_rate,
+                    local_conv_size=local_conv_size,
+                    **kwargs,
+                )
+            self.layers.append(layer)
+
+        # Classifier head
+        self.norm_head = nn.LayerNorm(embed_dims[-1])
+        self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
+
+        # Init weights
+        self.apply(self._init_weights)
+        self.set_layer_lr_decay(layer_lr_decay)
+        self.neck = nn.Sequential(
+            nn.Conv2d(
+                embed_dims[-1],
+                256,
+                kernel_size=1,
+                bias=False,
+            ),
+            LayerNorm2d(256),
+            nn.Conv2d(
+                256,
+                256,
+                kernel_size=3,
+                padding=1,
+                bias=False,
+            ),
+            LayerNorm2d(256),
+        )
+
+    def set_layer_lr_decay(self, layer_lr_decay):
+        """Sets layer-wise learning rate decay for the TinyViT model based on depth."""
+        decay_rate = layer_lr_decay
+
+        # Layers -> blocks (depth)
+        depth = sum(self.depths)
+        lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]
+
+        def _set_lr_scale(m, scale):
+            """Sets the learning rate scale for each layer in the model based on the layer's depth."""
+            for p in m.parameters():
+                p.lr_scale = scale
+
+        self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0]))
+        i = 0
+        for layer in self.layers:
+            for block in layer.blocks:
+                block.apply(lambda x: _set_lr_scale(x, lr_scales[i]))
+                i += 1
+            if layer.downsample is not None:
+                layer.downsample.apply(lambda x: _set_lr_scale(x, lr_scales[i - 1]))
+        assert i == depth
+        for m in [self.norm_head, self.head]:
+            m.apply(lambda x: _set_lr_scale(x, lr_scales[-1]))
+
+        for k, p in self.named_parameters():
+            p.param_name = k
+
+        def _check_lr_scale(m):
+            """Checks if the learning rate scale attribute is present in module's parameters."""
+            for p in m.parameters():
+                assert hasattr(p, "lr_scale"), p.param_name
+
+        self.apply(_check_lr_scale)
+
+    @staticmethod
+    def _init_weights(m):
+        """Initializes weights for linear and normalization layers in the TinyViT model."""
+        if isinstance(m, nn.Linear):
+            # NOTE: This initialization is needed only for training.
+            # trunc_normal_(m.weight, std=.02)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    @torch.jit.ignore
+    def no_weight_decay_keywords(self):
+        """Returns a set of keywords for parameters that should not use weight decay."""
+        return {"attention_biases"}
+
+    def forward_features(self, x):
+        """Processes input through feature extraction layers, returning spatial features."""
+        x = self.patch_embed(x)  # x input is (N, C, H, W)
+
+        x = self.layers[0](x)
+        start_i = 1
+
+        for i in range(start_i, len(self.layers)):
+            layer = self.layers[i]
+            x = layer(x)
+        batch, _, channel = x.shape
+        x = x.view(batch, self.patches_resolution[0] // 4, self.patches_resolution[1] // 4, channel)
+        x = x.permute(0, 3, 1, 2)
+        return self.neck(x)
+
+    def forward(self, x):
+        """Performs the forward pass through the TinyViT model, extracting features from the input image."""
+        return self.forward_features(x)
+
+    def set_imgsz(self, imgsz=[1024, 1024]):
+        """
+        Set image size to make model compatible with different image sizes.
+
+        Args:
+            imgsz (Tuple[int, int]): The size of the input image.
+        """
+        imgsz = [s // 4 for s in imgsz]
+        self.patches_resolution = imgsz
+        for i, layer in enumerate(self.layers):
+            input_resolution = (
+                imgsz[0] // (2 ** (i - 1 if i == 3 else i)),
+                imgsz[1] // (2 ** (i - 1 if i == 3 else i)),
+            )
+            layer.input_resolution = input_resolution
+            if layer.downsample is not None:
+                layer.downsample.input_resolution = input_resolution
+            if isinstance(layer, BasicLayer):
+                for b in layer.blocks:
+                    b.input_resolution = input_resolution

+ 373 - 0
ultralytics/models/sam/modules/transformer.py

@@ -0,0 +1,373 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import math
+from typing import Tuple, Type
+
+import torch
+from torch import Tensor, nn
+
+from ultralytics.nn.modules import MLPBlock
+
+
+class TwoWayTransformer(nn.Module):
+    """
+    A Two-Way Transformer module for simultaneous attention to image and query points.
+
+    This class implements a specialized transformer decoder that attends to an input image using queries with
+    supplied positional embeddings. It's useful for tasks like object detection, image segmentation, and point
+    cloud processing.
+
+    Attributes:
+        depth (int): Number of layers in the transformer.
+        embedding_dim (int): Channel dimension for input embeddings.
+        num_heads (int): Number of heads for multihead attention.
+        mlp_dim (int): Internal channel dimension for the MLP block.
+        layers (nn.ModuleList): List of TwoWayAttentionBlock layers composing the transformer.
+        final_attn_token_to_image (Attention): Final attention layer from queries to image.
+        norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
+
+    Methods:
+        forward: Processes image and point embeddings through the transformer.
+
+    Examples:
+        >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
+        >>> image_embedding = torch.randn(1, 256, 32, 32)
+        >>> image_pe = torch.randn(1, 256, 32, 32)
+        >>> point_embedding = torch.randn(1, 100, 256)
+        >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
+        >>> print(output_queries.shape, output_image.shape)
+    """
+
+    def __init__(
+        self,
+        depth: int,
+        embedding_dim: int,
+        num_heads: int,
+        mlp_dim: int,
+        activation: Type[nn.Module] = nn.ReLU,
+        attention_downsample_rate: int = 2,
+    ) -> None:
+        """
+        Initialize a Two-Way Transformer for simultaneous attention to image and query points.
+
+        Args:
+            depth (int): Number of layers in the transformer.
+            embedding_dim (int): Channel dimension for input embeddings.
+            num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.
+            mlp_dim (int): Internal channel dimension for the MLP block.
+            activation (Type[nn.Module]): Activation function to use in the MLP block.
+            attention_downsample_rate (int): Downsampling rate for attention mechanism.
+
+        Attributes:
+            depth (int): Number of layers in the transformer.
+            embedding_dim (int): Channel dimension for input embeddings.
+            num_heads (int): Number of heads for multihead attention.
+            mlp_dim (int): Internal channel dimension for the MLP block.
+            layers (nn.ModuleList): List of TwoWayAttentionBlock layers.
+            final_attn_token_to_image (Attention): Final attention layer from queries to image.
+            norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
+
+        Examples:
+            >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
+            >>> image_embedding = torch.randn(1, 256, 32, 32)
+            >>> image_pe = torch.randn(1, 256, 32, 32)
+            >>> point_embedding = torch.randn(1, 100, 256)
+            >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
+            >>> print(output_queries.shape, output_image.shape)
+        """
+        super().__init__()
+        self.depth = depth
+        self.embedding_dim = embedding_dim
+        self.num_heads = num_heads
+        self.mlp_dim = mlp_dim
+        self.layers = nn.ModuleList()
+
+        for i in range(depth):
+            self.layers.append(
+                TwoWayAttentionBlock(
+                    embedding_dim=embedding_dim,
+                    num_heads=num_heads,
+                    mlp_dim=mlp_dim,
+                    activation=activation,
+                    attention_downsample_rate=attention_downsample_rate,
+                    skip_first_layer_pe=(i == 0),
+                )
+            )
+
+        self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
+        self.norm_final_attn = nn.LayerNorm(embedding_dim)
+
+    def forward(
+        self,
+        image_embedding: Tensor,
+        image_pe: Tensor,
+        point_embedding: Tensor,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Processes image and point embeddings through the Two-Way Transformer.
+
+        Args:
+            image_embedding (torch.Tensor): Image to attend to, with shape (B, embedding_dim, H, W).
+            image_pe (torch.Tensor): Positional encoding to add to the image, with same shape as image_embedding.
+            point_embedding (torch.Tensor): Embedding to add to query points, with shape (B, N_points, embedding_dim).
+
+        Returns:
+            (Tuple[torch.Tensor, torch.Tensor]): Processed point_embedding and image_embedding.
+
+        Examples:
+            >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
+            >>> image_embedding = torch.randn(1, 256, 32, 32)
+            >>> image_pe = torch.randn(1, 256, 32, 32)
+            >>> point_embedding = torch.randn(1, 100, 256)
+            >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
+            >>> print(output_queries.shape, output_image.shape)
+        """
+        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
+        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
+        image_pe = image_pe.flatten(2).permute(0, 2, 1)
+
+        # Prepare queries
+        queries = point_embedding
+        keys = image_embedding
+
+        # Apply transformer blocks and final layernorm
+        for layer in self.layers:
+            queries, keys = layer(
+                queries=queries,
+                keys=keys,
+                query_pe=point_embedding,
+                key_pe=image_pe,
+            )
+
+        # Apply the final attention layer from the points to the image
+        q = queries + point_embedding
+        k = keys + image_pe
+        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
+        queries = queries + attn_out
+        queries = self.norm_final_attn(queries)
+
+        return queries, keys
+
+
+class TwoWayAttentionBlock(nn.Module):
+    """
+    A two-way attention block for simultaneous attention to image and query points.
+
+    This class implements a specialized transformer block with four main layers: self-attention on sparse inputs,
+    cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense
+    inputs to sparse inputs.
+
+    Attributes:
+        self_attn (Attention): Self-attention layer for queries.
+        norm1 (nn.LayerNorm): Layer normalization after self-attention.
+        cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
+        norm2 (nn.LayerNorm): Layer normalization after token-to-image attention.
+        mlp (MLPBlock): MLP block for transforming query embeddings.
+        norm3 (nn.LayerNorm): Layer normalization after MLP block.
+        norm4 (nn.LayerNorm): Layer normalization after image-to-token attention.
+        cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
+        skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
+
+    Methods:
+        forward: Applies self-attention and cross-attention to queries and keys.
+
+    Examples:
+        >>> embedding_dim, num_heads = 256, 8
+        >>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
+        >>> queries = torch.randn(1, 100, embedding_dim)
+        >>> keys = torch.randn(1, 1000, embedding_dim)
+        >>> query_pe = torch.randn(1, 100, embedding_dim)
+        >>> key_pe = torch.randn(1, 1000, embedding_dim)
+        >>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_heads: int,
+        mlp_dim: int = 2048,
+        activation: Type[nn.Module] = nn.ReLU,
+        attention_downsample_rate: int = 2,
+        skip_first_layer_pe: bool = False,
+    ) -> None:
+        """
+        Initializes a TwoWayAttentionBlock for simultaneous attention to image and query points.
+
+        This block implements a specialized transformer layer with four main components: self-attention on sparse
+        inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention
+        of dense inputs to sparse inputs.
+
+        Args:
+            embedding_dim (int): Channel dimension of the embeddings.
+            num_heads (int): Number of attention heads in the attention layers.
+            mlp_dim (int): Hidden dimension of the MLP block.
+            activation (Type[nn.Module]): Activation function for the MLP block.
+            attention_downsample_rate (int): Downsampling rate for the attention mechanism.
+            skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
+
+        Examples:
+            >>> embedding_dim, num_heads = 256, 8
+            >>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
+            >>> queries = torch.randn(1, 100, embedding_dim)
+            >>> keys = torch.randn(1, 1000, embedding_dim)
+            >>> query_pe = torch.randn(1, 100, embedding_dim)
+            >>> key_pe = torch.randn(1, 1000, embedding_dim)
+            >>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
+        """
+        super().__init__()
+        self.self_attn = Attention(embedding_dim, num_heads)
+        self.norm1 = nn.LayerNorm(embedding_dim)
+
+        self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
+        self.norm2 = nn.LayerNorm(embedding_dim)
+
+        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
+        self.norm3 = nn.LayerNorm(embedding_dim)
+
+        self.norm4 = nn.LayerNorm(embedding_dim)
+        self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
+
+        self.skip_first_layer_pe = skip_first_layer_pe
+
+    def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
+        """Applies two-way attention to process query and key embeddings in a transformer block."""
+        # Self attention block
+        if self.skip_first_layer_pe:
+            queries = self.self_attn(q=queries, k=queries, v=queries)
+        else:
+            q = queries + query_pe
+            attn_out = self.self_attn(q=q, k=q, v=queries)
+            queries = queries + attn_out
+        queries = self.norm1(queries)
+
+        # Cross attention block, tokens attending to image embedding
+        q = queries + query_pe
+        k = keys + key_pe
+        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
+        queries = queries + attn_out
+        queries = self.norm2(queries)
+
+        # MLP block
+        mlp_out = self.mlp(queries)
+        queries = queries + mlp_out
+        queries = self.norm3(queries)
+
+        # Cross attention block, image embedding attending to tokens
+        q = queries + query_pe
+        k = keys + key_pe
+        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
+        keys = keys + attn_out
+        keys = self.norm4(keys)
+
+        return queries, keys
+
+
+class Attention(nn.Module):
+    """
+    An attention layer with downscaling capability for embedding size after projection.
+
+    This class implements a multi-head attention mechanism with the option to downsample the internal
+    dimension of queries, keys, and values.
+
+    Attributes:
+        embedding_dim (int): Dimensionality of input embeddings.
+        kv_in_dim (int): Dimensionality of key and value inputs.
+        internal_dim (int): Internal dimension after downsampling.
+        num_heads (int): Number of attention heads.
+        q_proj (nn.Linear): Linear projection for queries.
+        k_proj (nn.Linear): Linear projection for keys.
+        v_proj (nn.Linear): Linear projection for values.
+        out_proj (nn.Linear): Linear projection for output.
+
+    Methods:
+        _separate_heads: Separates input tensor into attention heads.
+        _recombine_heads: Recombines separated attention heads.
+        forward: Computes attention output for given query, key, and value tensors.
+
+    Examples:
+        >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
+        >>> q = torch.randn(1, 100, 256)
+        >>> k = v = torch.randn(1, 50, 256)
+        >>> output = attn(q, k, v)
+        >>> print(output.shape)
+        torch.Size([1, 100, 256])
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_heads: int,
+        downsample_rate: int = 1,
+        kv_in_dim: int = None,
+    ) -> None:
+        """
+        Initializes the Attention module with specified dimensions and settings.
+
+        This class implements a multi-head attention mechanism with optional downsampling of the internal
+        dimension for queries, keys, and values.
+
+        Args:
+            embedding_dim (int): Dimensionality of input embeddings.
+            num_heads (int): Number of attention heads.
+            downsample_rate (int): Factor by which internal dimensions are downsampled. Defaults to 1.
+            kv_in_dim (int | None): Dimensionality of key and value inputs. If None, uses embedding_dim.
+
+        Raises:
+            AssertionError: If num_heads does not evenly divide the internal dim (embedding_dim / downsample_rate).
+
+        Examples:
+            >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
+            >>> q = torch.randn(1, 100, 256)
+            >>> k = v = torch.randn(1, 50, 256)
+            >>> output = attn(q, k, v)
+            >>> print(output.shape)
+            torch.Size([1, 100, 256])
+        """
+        super().__init__()
+        self.embedding_dim = embedding_dim
+        self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
+        self.internal_dim = embedding_dim // downsample_rate
+        self.num_heads = num_heads
+        assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
+
+        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+        self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+        self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
+
+    @staticmethod
+    def _separate_heads(x: Tensor, num_heads: int) -> Tensor:
+        """Separates the input tensor into the specified number of attention heads."""
+        b, n, c = x.shape
+        x = x.reshape(b, n, num_heads, c // num_heads)
+        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head
+
+    @staticmethod
+    def _recombine_heads(x: Tensor) -> Tensor:
+        """Recombines separated attention heads into a single tensor."""
+        b, n_heads, n_tokens, c_per_head = x.shape
+        x = x.transpose(1, 2)
+        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C
+
+    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+        """Applies multi-head attention to query, key, and value tensors with optional downsampling."""
+        # Input projections
+        q = self.q_proj(q)
+        k = self.k_proj(k)
+        v = self.v_proj(v)
+
+        # Separate into heads
+        q = self._separate_heads(q, self.num_heads)
+        k = self._separate_heads(k, self.num_heads)
+        v = self._separate_heads(v, self.num_heads)
+
+        # Attention
+        _, _, _, c_per_head = q.shape
+        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
+        attn = attn / math.sqrt(c_per_head)
+        attn = torch.softmax(attn, dim=-1)
+
+        # Get output
+        out = attn @ v
+        out = self._recombine_heads(out)
+        return self.out_proj(out)

+ 293 - 0
ultralytics/models/sam/modules/utils.py

@@ -0,0 +1,293 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from typing import Tuple
+
+import torch
+import torch.nn.functional as F
+
+
+def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
+    """
+    Selects the closest conditioning frames to a given frame index.
+
+    Args:
+        frame_idx (int): Current frame index.
+        cond_frame_outputs (Dict[int, Any]): Dictionary of conditioning frame outputs keyed by frame indices.
+        max_cond_frame_num (int): Maximum number of conditioning frames to select.
+
+    Returns:
+        (Tuple[Dict[int, Any], Dict[int, Any]]): A tuple containing two dictionaries:
+            - selected_outputs: Selected items from cond_frame_outputs.
+            - unselected_outputs: Items not selected from cond_frame_outputs.
+
+    Examples:
+        >>> frame_idx = 5
+        >>> cond_frame_outputs = {1: "a", 3: "b", 7: "c", 9: "d"}
+        >>> max_cond_frame_num = 2
+        >>> selected, unselected = select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num)
+        >>> print(selected)
+        {3: 'b', 7: 'c'}
+        >>> print(unselected)
+        {1: 'a', 9: 'd'}
+    """
+    if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
+        selected_outputs = cond_frame_outputs
+        unselected_outputs = {}
+    else:
+        assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
+        selected_outputs = {}
+
+        # the closest conditioning frame before `frame_idx` (if any)
+        idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
+        if idx_before is not None:
+            selected_outputs[idx_before] = cond_frame_outputs[idx_before]
+
+        # the closest conditioning frame after `frame_idx` (if any)
+        idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
+        if idx_after is not None:
+            selected_outputs[idx_after] = cond_frame_outputs[idx_after]
+
+        # add other temporally closest conditioning frames until reaching a total
+        # of `max_cond_frame_num` conditioning frames.
+        num_remain = max_cond_frame_num - len(selected_outputs)
+        inds_remain = sorted(
+            (t for t in cond_frame_outputs if t not in selected_outputs),
+            key=lambda x: abs(x - frame_idx),
+        )[:num_remain]
+        selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
+        unselected_outputs = {t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs}
+
+    return selected_outputs, unselected_outputs
+
+
+def get_1d_sine_pe(pos_inds, dim, temperature=10000):
+    """Generates 1D sinusoidal positional embeddings for given positions and dimensions."""
+    pe_dim = dim // 2
+    dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
+    dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
+
+    pos_embed = pos_inds.unsqueeze(-1) / dim_t
+    pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
+    return pos_embed
+
+
+def init_t_xy(end_x: int, end_y: int):
+    """Initializes 1D and 2D coordinate tensors for a grid of specified dimensions."""
+    t = torch.arange(end_x * end_y, dtype=torch.float32)
+    t_x = (t % end_x).float()
+    t_y = torch.div(t, end_x, rounding_mode="floor").float()
+    return t_x, t_y
+
+
+def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
+    """Computes axial complex exponential positional encodings for 2D spatial positions in a grid."""
+    freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+    freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+
+    t_x, t_y = init_t_xy(end_x, end_y)
+    freqs_x = torch.outer(t_x, freqs_x)
+    freqs_y = torch.outer(t_y, freqs_y)
+    freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
+    freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
+    return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+    """Reshapes frequency tensor for broadcasting with input tensor, ensuring dimensional compatibility."""
+    ndim = x.ndim
+    assert 0 <= 1 < ndim
+    assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
+    shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
+    return freqs_cis.view(*shape)
+
+
+def apply_rotary_enc(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    freqs_cis: torch.Tensor,
+    repeat_freqs_k: bool = False,
+):
+    """Applies rotary positional encoding to query and key tensors using complex-valued frequency components."""
+    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) if xk.shape[-2] != 0 else None
+    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+    if xk_ is None:
+        # no keys to rotate, due to dropout
+        return xq_out.type_as(xq).to(xq.device), xk
+    # repeat freqs along seq_len dim to match k seq_len
+    if repeat_freqs_k:
+        r = xk_.shape[-2] // xq_.shape[-2]
+        freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
+    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+    return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
+
+
+def window_partition(x, window_size):
+    """
+    Partitions input tensor into non-overlapping windows with padding if needed.
+
+    Args:
+        x (torch.Tensor): Input tensor with shape (B, H, W, C).
+        window_size (int): Size of each window.
+
+    Returns:
+        (Tuple[torch.Tensor, Tuple[int, int]]): A tuple containing:
+            - windows (torch.Tensor): Partitioned windows with shape (B * num_windows, window_size, window_size, C).
+            - (Hp, Wp) (Tuple[int, int]): Padded height and width before partition.
+
+    Examples:
+        >>> x = torch.randn(1, 16, 16, 3)
+        >>> windows, (Hp, Wp) = window_partition(x, window_size=4)
+        >>> print(windows.shape, Hp, Wp)
+        torch.Size([16, 4, 4, 3]) 16 16
+    """
+    B, H, W, C = x.shape
+
+    pad_h = (window_size - H % window_size) % window_size
+    pad_w = (window_size - W % window_size) % window_size
+    if pad_h > 0 or pad_w > 0:
+        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+    Hp, Wp = H + pad_h, W + pad_w
+
+    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    return windows, (Hp, Wp)
+
+
+def window_unpartition(windows, window_size, pad_hw, hw):
+    """
+    Unpartitions windowed sequences into original sequences and removes padding.
+
+    This function reverses the windowing process, reconstructing the original input from windowed segments
+    and removing any padding that was added during the windowing process.
+
+    Args:
+        windows (torch.Tensor): Input tensor of windowed sequences with shape (B * num_windows, window_size,
+            window_size, C), where B is the batch size, num_windows is the number of windows, window_size is
+            the size of each window, and C is the number of channels.
+        window_size (int): Size of each window.
+        pad_hw (Tuple[int, int]): Padded height and width (Hp, Wp) of the input before windowing.
+        hw (Tuple[int, int]): Original height and width (H, W) of the input before padding and windowing.
+
+    Returns:
+        (torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W
+            are the original height and width, and C is the number of channels.
+
+    Examples:
+        >>> windows = torch.rand(32, 8, 8, 64)  # 32 windows of size 8x8 with 64 channels
+        >>> pad_hw = (16, 16)  # Padded height and width
+        >>> hw = (15, 14)  # Original height and width
+        >>> x = window_unpartition(windows, window_size=8, pad_hw=pad_hw, hw=hw)
+        >>> print(x.shape)
+        torch.Size([1, 15, 14, 64])
+    """
+    Hp, Wp = pad_hw
+    H, W = hw
+    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+
+    if Hp > H or Wp > W:
+        x = x[:, :H, :W, :].contiguous()
+    return x
+
+
+def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
+    """
+    Extracts relative positional embeddings based on query and key sizes.
+
+    Args:
+        q_size (int): Size of the query.
+        k_size (int): Size of the key.
+        rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative
+            distance and C is the embedding dimension.
+
+    Returns:
+        (torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size,
+            k_size, C).
+
+    Examples:
+        >>> q_size, k_size = 8, 16
+        >>> rel_pos = torch.randn(31, 64)  # 31 = 2 * max(8, 16) - 1
+        >>> extracted_pos = get_rel_pos(q_size, k_size, rel_pos)
+        >>> print(extracted_pos.shape)
+        torch.Size([8, 16, 64])
+    """
+    max_rel_dist = int(2 * max(q_size, k_size) - 1)
+    # Interpolate rel pos if needed.
+    if rel_pos.shape[0] != max_rel_dist:
+        # Interpolate rel pos.
+        rel_pos_resized = F.interpolate(
+            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
+            size=max_rel_dist,
+            mode="linear",
+        )
+        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+    else:
+        rel_pos_resized = rel_pos
+
+    # Scale the coords with short length if shapes for q and k are different.
+    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+    return rel_pos_resized[relative_coords.long()]
+
+
+def add_decomposed_rel_pos(
+    attn: torch.Tensor,
+    q: torch.Tensor,
+    rel_pos_h: torch.Tensor,
+    rel_pos_w: torch.Tensor,
+    q_size: Tuple[int, int],
+    k_size: Tuple[int, int],
+) -> torch.Tensor:
+    """
+    Adds decomposed Relative Positional Embeddings to the attention map.
+
+    This function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2
+    paper. It enhances the attention mechanism by incorporating spatial relationships between query and key
+    positions.
+
+    Args:
+        attn (torch.Tensor): Attention map with shape (B, q_h * q_w, k_h * k_w).
+        q (torch.Tensor): Query tensor in the attention layer with shape (B, q_h * q_w, C).
+        rel_pos_h (torch.Tensor): Relative position embeddings for height axis with shape (Lh, C).
+        rel_pos_w (torch.Tensor): Relative position embeddings for width axis with shape (Lw, C).
+        q_size (Tuple[int, int]): Spatial sequence size of query q as (q_h, q_w).
+        k_size (Tuple[int, int]): Spatial sequence size of key k as (k_h, k_w).
+
+    Returns:
+        (torch.Tensor): Updated attention map with added relative positional embeddings, shape
+            (B, q_h * q_w, k_h * k_w).
+
+    Examples:
+        >>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8
+        >>> attn = torch.rand(B, q_h * q_w, k_h * k_w)
+        >>> q = torch.rand(B, q_h * q_w, C)
+        >>> rel_pos_h = torch.rand(2 * max(q_h, k_h) - 1, C)
+        >>> rel_pos_w = torch.rand(2 * max(q_w, k_w) - 1, C)
+        >>> q_size, k_size = (q_h, q_w), (k_h, k_w)
+        >>> updated_attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size)
+        >>> print(updated_attn.shape)
+        torch.Size([1, 64, 64])
+
+    References:
+        https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py
+    """
+    q_h, q_w = q_size
+    k_h, k_w = k_size
+    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
+    Rw = get_rel_pos(q_w, k_w, rel_pos_w)
+
+    B, _, dim = q.shape
+    r_q = q.reshape(B, q_h, q_w, dim)
+    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+
+    attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
+        B, q_h * q_w, k_h * k_w
+    )
+
+    return attn

+ 1605 - 0
ultralytics/models/sam/predict.py

@@ -0,0 +1,1605 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+"""
+Generate predictions using the Segment Anything Model (SAM).
+
+SAM is an advanced image segmentation model offering features like promptable segmentation and zero-shot performance.
+This module contains the implementation of the prediction logic and auxiliary utilities required to perform segmentation
+using SAM. It forms an integral part of the Ultralytics framework and is designed for high-performance, real-time image
+segmentation tasks.
+"""
+
+from collections import OrderedDict
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from ultralytics.data.augment import LetterBox
+from ultralytics.engine.predictor import BasePredictor
+from ultralytics.engine.results import Results
+from ultralytics.utils import DEFAULT_CFG, ops
+from ultralytics.utils.torch_utils import select_device, smart_inference_mode
+
+from .amg import (
+    batch_iterator,
+    batched_mask_to_box,
+    build_all_layer_point_grids,
+    calculate_stability_score,
+    generate_crop_boxes,
+    is_box_near_crop_edge,
+    remove_small_regions,
+    uncrop_boxes_xyxy,
+    uncrop_masks,
+)
+from .build import build_sam
+
+
+class Predictor(BasePredictor):
+    """
+    Predictor class for SAM, enabling real-time image segmentation with promptable capabilities.
+
+    This class extends BasePredictor and implements the Segment Anything Model (SAM) for advanced image
+    segmentation tasks. It supports various input prompts like points, bounding boxes, and masks for
+    fine-grained control over segmentation results.
+
+    Attributes:
+        args (SimpleNamespace): Configuration arguments for the predictor.
+        model (torch.nn.Module): The loaded SAM model.
+        device (torch.device): The device (CPU or GPU) on which the model is loaded.
+        im (torch.Tensor): The preprocessed input image.
+        features (torch.Tensor): Extracted image features.
+        prompts (Dict): Dictionary to store various types of prompts (e.g., bboxes, points, masks).
+        segment_all (bool): Flag to indicate if full image segmentation should be performed.
+        mean (torch.Tensor): Mean values for image normalization.
+        std (torch.Tensor): Standard deviation values for image normalization.
+
+    Methods:
+        preprocess: Prepares input images for model inference.
+        pre_transform: Performs initial transformations on the input image.
+        inference: Performs segmentation inference based on input prompts.
+        prompt_inference: Internal function for prompt-based segmentation inference.
+        generate: Generates segmentation masks for an entire image.
+        setup_model: Initializes the SAM model for inference.
+        get_model: Builds and returns a SAM model.
+        postprocess: Post-processes model outputs to generate final results.
+        setup_source: Sets up the data source for inference.
+        set_image: Sets and preprocesses a single image for inference.
+        get_im_features: Extracts image features using the SAM image encoder.
+        set_prompts: Sets prompts for subsequent inference.
+        reset_image: Resets the current image and its features.
+        remove_small_regions: Removes small disconnected regions and holes from masks.
+
+    Examples:
+        >>> predictor = Predictor()
+        >>> predictor.setup_model(model_path="sam_model.pt")
+        >>> predictor.set_image("image.jpg")
+        >>> bboxes = [[100, 100, 200, 200]]
+        >>> results = predictor(bboxes=bboxes)
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """
+        Initialize the Predictor with configuration, overrides, and callbacks.
+
+        Sets up the Predictor object for SAM (Segment Anything Model) and applies any configuration overrides or
+        callbacks provided. Initializes task-specific settings for SAM, such as retina_masks being set to True
+        for optimal results.
+
+        Args:
+            cfg (Dict): Configuration dictionary containing default settings.
+            overrides (Dict | None): Dictionary of values to override default configuration.
+            _callbacks (Dict | None): Dictionary of callback functions to customize behavior.
+
+        Examples:
+            >>> predictor_example = Predictor(cfg=DEFAULT_CFG)
+            >>> predictor_example_with_imgsz = Predictor(overrides={"imgsz": 640})
+            >>> predictor_example_with_callback = Predictor(_callbacks={"on_predict_start": custom_callback})
+        """
+        if overrides is None:
+            overrides = {}
+        overrides.update(dict(task="segment", mode="predict", batch=1))
+        super().__init__(cfg, overrides, _callbacks)
+        self.args.retina_masks = True
+        self.im = None
+        self.features = None
+        self.prompts = {}
+        self.segment_all = False
+
+    def preprocess(self, im):
+        """
+        Preprocess the input image for model inference.
+
+        This method prepares the input image by applying transformations and normalization. It supports both
+        torch.Tensor and list of np.ndarray as input formats.
+
+        Args:
+            im (torch.Tensor | List[np.ndarray]): Input image(s) in BCHW tensor format or list of HWC numpy arrays.
+
+        Returns:
+            im (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype.
+
+        Examples:
+            >>> predictor = Predictor()
+            >>> image = torch.rand(1, 3, 640, 640)
+            >>> preprocessed_image = predictor.preprocess(image)
+        """
+        if self.im is not None:
+            return self.im
+        not_tensor = not isinstance(im, torch.Tensor)
+        if not_tensor:
+            im = np.stack(self.pre_transform(im))
+            im = im[..., ::-1].transpose((0, 3, 1, 2))
+            im = np.ascontiguousarray(im)
+            im = torch.from_numpy(im)
+
+        im = im.to(self.device)
+        im = im.half() if self.model.fp16 else im.float()
+        if not_tensor:
+            im = (im - self.mean) / self.std
+        return im
+
+    def pre_transform(self, im):
+        """
+        Perform initial transformations on the input image for preprocessing.
+
+        This method applies transformations such as resizing to prepare the image for further preprocessing.
+        Currently, batched inference is not supported; hence the list length should be 1.
+
+        Args:
+            im (List[np.ndarray]): List containing a single image in HWC numpy array format.
+
+        Returns:
+            (List[np.ndarray]): List containing the transformed image.
+
+        Raises:
+            AssertionError: If the input list contains more than one image.
+
+        Examples:
+            >>> predictor = Predictor()
+            >>> image = np.random.rand(480, 640, 3)  # Single HWC image
+            >>> transformed = predictor.pre_transform([image])
+            >>> print(len(transformed))
+            1
+        """
+        assert len(im) == 1, "SAM model does not currently support batched inference"
+        letterbox = LetterBox(self.args.imgsz, auto=False, center=False)
+        return [letterbox(image=x) for x in im]
+
+    def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
+        """
+        Perform image segmentation inference based on the given input cues, using the currently loaded image.
+
+        This method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt
+        encoder, and mask decoder for real-time and promptable segmentation tasks.
+
+        Args:
+            im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
+            bboxes (np.ndarray | List | None): Bounding boxes with shape (N, 4), in XYXY format.
+            points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.
+            labels (np.ndarray | List | None): Labels for point prompts, shape (N,). 1 = foreground, 0 = background.
+            masks (np.ndarray | None): Low-resolution masks from previous predictions, shape (N, H, W). For SAM H=W=256.
+            multimask_output (bool): Flag to return multiple masks. Helpful for ambiguous prompts.
+            *args (Any): Additional positional arguments.
+            **kwargs (Any): Additional keyword arguments.
+
+        Returns:
+            (np.ndarray): The output masks in shape (C, H, W), where C is the number of generated masks.
+            (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
+            (np.ndarray): Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256.
+
+        Examples:
+            >>> predictor = Predictor()
+            >>> predictor.setup_model(model_path="sam_model.pt")
+            >>> predictor.set_image("image.jpg")
+            >>> results = predictor(bboxes=[[0, 0, 100, 100]])
+        """
+        # Override prompts if any stored in self.prompts
+        bboxes = self.prompts.pop("bboxes", bboxes)
+        points = self.prompts.pop("points", points)
+        masks = self.prompts.pop("masks", masks)
+        labels = self.prompts.pop("labels", labels)
+
+        if all(i is None for i in [bboxes, points, masks]):
+            return self.generate(im, *args, **kwargs)
+
+        return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)
+
+    def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
+        """
+        Performs image segmentation inference based on input cues using SAM's specialized architecture.
+
+        This internal function leverages the Segment Anything Model (SAM) for prompt-based, real-time segmentation.
+        It processes various input prompts such as bounding boxes, points, and masks to generate segmentation masks.
+
+        Args:
+            im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
+            bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
+            points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
+            labels (np.ndarray | List | None): Point prompt labels with shape (N) or (N, num_points). 1 for foreground, 0 for background.
+            masks (np.ndarray | None): Low-res masks from previous predictions with shape (N, H, W). For SAM, H=W=256.
+            multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
+
+        Raises:
+            AssertionError: If the number of points don't match the number of labels, in case labels were passed.
+
+        Returns:
+            (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
+            (np.ndarray): Quality scores predicted by the model for each mask, with length C.
+
+        Examples:
+            >>> predictor = Predictor()
+            >>> im = torch.rand(1, 3, 1024, 1024)
+            >>> bboxes = [[100, 100, 200, 200]]
+            >>> masks, scores, logits = predictor.prompt_inference(im, bboxes=bboxes)
+        """
+        features = self.get_im_features(im) if self.features is None else self.features
+
+        bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
+        points = (points, labels) if points is not None else None
+        # Embed prompts
+        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks)
+
+        # Predict masks
+        pred_masks, pred_scores = self.model.mask_decoder(
+            image_embeddings=features,
+            image_pe=self.model.prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embeddings,
+            dense_prompt_embeddings=dense_embeddings,
+            multimask_output=multimask_output,
+        )
+
+        # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
+        # `d` could be 1 or 3 depends on `multimask_output`.
+        return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
+
+    def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
+        """
+        Prepares and transforms the input prompts for processing based on the destination shape.
+
+        Args:
+            dst_shape (tuple): The target shape (height, width) for the prompts.
+            bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
+            points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
+            labels (np.ndarray | List | None): Point prompt labels with shape (N) or (N, num_points). 1 for foreground, 0 for background.
+            masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array.
+
+        Raises:
+            AssertionError: If the number of points don't match the number of labels, in case labels were passed.
+
+        Returns:
+            (tuple): A tuple containing transformed bounding boxes, points, labels, and masks.
+        """
+        src_shape = self.batch[1][0].shape[:2]
+        r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
+        # Transform input prompts
+        if points is not None:
+            points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
+            points = points[None] if points.ndim == 1 else points
+            # Assuming labels are all positive if users don't pass labels.
+            if labels is None:
+                labels = np.ones(points.shape[:-1])
+            labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
+            assert points.shape[-2] == labels.shape[-1], (
+                f"Number of points {points.shape[-2]} should match number of labels {labels.shape[-1]}."
+            )
+            points *= r
+            if points.ndim == 2:
+                # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
+                points, labels = points[:, None, :], labels[:, None]
+        if bboxes is not None:
+            bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
+            bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
+            bboxes *= r
+        if masks is not None:
+            masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
+        return bboxes, points, labels, masks
+
+    def generate(
+        self,
+        im,
+        crop_n_layers=0,
+        crop_overlap_ratio=512 / 1500,
+        crop_downscale_factor=1,
+        point_grids=None,
+        points_stride=32,
+        points_batch_size=64,
+        conf_thres=0.88,
+        stability_score_thresh=0.95,
+        stability_score_offset=0.95,
+        crop_nms_thresh=0.7,
+    ):
+        """
+        Perform image segmentation using the Segment Anything Model (SAM).
+
+        This method segments an entire image into constituent parts by leveraging SAM's advanced architecture
+        and real-time performance capabilities. It can optionally work on image crops for finer segmentation.
+
+        Args:
+            im (torch.Tensor): Input tensor representing the preprocessed image with shape (N, C, H, W).
+            crop_n_layers (int): Number of layers for additional mask predictions on image crops.
+            crop_overlap_ratio (float): Overlap between crops, scaled down in subsequent layers.
+            crop_downscale_factor (int): Scaling factor for sampled points-per-side in each layer.
+            point_grids (List[np.ndarray] | None): Custom grids for point sampling normalized to [0,1].
+            points_stride (int): Number of points to sample along each side of the image.
+            points_batch_size (int): Batch size for the number of points processed simultaneously.
+            conf_thres (float): Confidence threshold [0,1] for filtering based on mask quality prediction.
+            stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on stability.
+            stability_score_offset (float): Offset value for calculating stability score.
+            crop_nms_thresh (float): IoU cutoff for NMS to remove duplicate masks between crops.
+
+        Returns:
+            pred_masks (torch.Tensor): Segmented masks with shape (N, H, W).
+            pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,).
+            pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4).
+
+        Examples:
+            >>> predictor = Predictor()
+            >>> im = torch.rand(1, 3, 1024, 1024)  # Example input image
+            >>> masks, scores, boxes = predictor.generate(im)
+        """
+        import torchvision  # scope for faster 'import ultralytics'
+
+        self.segment_all = True
+        ih, iw = im.shape[2:]
+        crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio)
+        if point_grids is None:
+            point_grids = build_all_layer_point_grids(points_stride, crop_n_layers, crop_downscale_factor)
+        pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], []
+        for crop_region, layer_idx in zip(crop_regions, layer_idxs):
+            x1, y1, x2, y2 = crop_region
+            w, h = x2 - x1, y2 - y1
+            area = torch.tensor(w * h, device=im.device)
+            points_scale = np.array([[w, h]])  # w, h
+            # Crop image and interpolate to input size
+            crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode="bilinear", align_corners=False)
+            # (num_points, 2)
+            points_for_image = point_grids[layer_idx] * points_scale
+            crop_masks, crop_scores, crop_bboxes = [], [], []
+            for (points,) in batch_iterator(points_batch_size, points_for_image):
+                pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True)
+                # Interpolate predicted masks to input size
+                pred_mask = F.interpolate(pred_mask[None], (h, w), mode="bilinear", align_corners=False)[0]
+                idx = pred_score > conf_thres
+                pred_mask, pred_score = pred_mask[idx], pred_score[idx]
+
+                stability_score = calculate_stability_score(
+                    pred_mask, self.model.mask_threshold, stability_score_offset
+                )
+                idx = stability_score > stability_score_thresh
+                pred_mask, pred_score = pred_mask[idx], pred_score[idx]
+                # Bool type is much more memory-efficient.
+                pred_mask = pred_mask > self.model.mask_threshold
+                # (N, 4)
+                pred_bbox = batched_mask_to_box(pred_mask).float()
+                keep_mask = ~is_box_near_crop_edge(pred_bbox, crop_region, [0, 0, iw, ih])
+                if not torch.all(keep_mask):
+                    pred_bbox, pred_mask, pred_score = pred_bbox[keep_mask], pred_mask[keep_mask], pred_score[keep_mask]
+
+                crop_masks.append(pred_mask)
+                crop_bboxes.append(pred_bbox)
+                crop_scores.append(pred_score)
+
+            # Do nms within this crop
+            crop_masks = torch.cat(crop_masks)
+            crop_bboxes = torch.cat(crop_bboxes)
+            crop_scores = torch.cat(crop_scores)
+            keep = torchvision.ops.nms(crop_bboxes, crop_scores, self.args.iou)  # NMS
+            crop_bboxes = uncrop_boxes_xyxy(crop_bboxes[keep], crop_region)
+            crop_masks = uncrop_masks(crop_masks[keep], crop_region, ih, iw)
+            crop_scores = crop_scores[keep]
+
+            pred_masks.append(crop_masks)
+            pred_bboxes.append(crop_bboxes)
+            pred_scores.append(crop_scores)
+            region_areas.append(area.expand(len(crop_masks)))
+
+        pred_masks = torch.cat(pred_masks)
+        pred_bboxes = torch.cat(pred_bboxes)
+        pred_scores = torch.cat(pred_scores)
+        region_areas = torch.cat(region_areas)
+
+        # Remove duplicate masks between crops
+        if len(crop_regions) > 1:
+            scores = 1 / region_areas
+            keep = torchvision.ops.nms(pred_bboxes, scores, crop_nms_thresh)
+            pred_masks, pred_bboxes, pred_scores = pred_masks[keep], pred_bboxes[keep], pred_scores[keep]
+
+        return pred_masks, pred_scores, pred_bboxes
+
+    def setup_model(self, model=None, verbose=True):
+        """
+        Initializes the Segment Anything Model (SAM) for inference.
+
+        This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary
+        parameters for image normalization and other Ultralytics compatibility settings.
+
+        Args:
+            model (torch.nn.Module | None): A pretrained SAM model. If None, a new model is built based on config.
+            verbose (bool): If True, prints selected device information.
+
+        Examples:
+            >>> predictor = Predictor()
+            >>> predictor.setup_model(model=sam_model, verbose=True)
+        """
+        device = select_device(self.args.device, verbose=verbose)
+        if model is None:
+            model = self.get_model()
+        model.eval()
+        self.model = model.to(device)
+        self.device = device
+        self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)
+        self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)
+
+        # Ultralytics compatibility settings
+        self.model.pt = False
+        self.model.triton = False
+        self.model.stride = 32
+        self.model.fp16 = False
+        self.done_warmup = True
+
+    def get_model(self):
+        """Retrieves or builds the Segment Anything Model (SAM) for image segmentation tasks."""
+        return build_sam(self.args.model)
+
+    def postprocess(self, preds, img, orig_imgs):
+        """
+        Post-processes SAM's inference outputs to generate object detection masks and bounding boxes.
+
+        This method scales masks and boxes to the original image size and applies a threshold to the mask
+        predictions. It leverages SAM's advanced architecture for real-time, promptable segmentation tasks.
+
+        Args:
+            preds (Tuple[torch.Tensor]): The output from SAM model inference, containing:
+                - pred_masks (torch.Tensor): Predicted masks with shape (N, 1, H, W).
+                - pred_scores (torch.Tensor): Confidence scores for each mask with shape (N, 1).
+                - pred_bboxes (torch.Tensor, optional): Predicted bounding boxes if segment_all is True.
+            img (torch.Tensor): The processed input image tensor with shape (C, H, W).
+            orig_imgs (List[np.ndarray] | torch.Tensor): The original, unprocessed images.
+
+        Returns:
+            results (List[Results]): List of Results objects containing detection masks, bounding boxes, and other
+                metadata for each processed image.
+
+        Examples:
+            >>> predictor = Predictor()
+            >>> preds = predictor.inference(img)
+            >>> results = predictor.postprocess(preds, img, orig_imgs)
+        """
+        # (N, 1, H, W), (N, 1)
+        pred_masks, pred_scores = preds[:2]
+        pred_bboxes = preds[2] if self.segment_all else None
+        names = dict(enumerate(str(i) for i in range(len(pred_masks))))
+
+        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
+            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+
+        results = []
+        for masks, orig_img, img_path in zip([pred_masks], orig_imgs, self.batch[0]):
+            if len(masks) == 0:
+                masks, pred_bboxes = None, torch.zeros((0, 6), device=pred_masks.device)
+            else:
+                masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
+                masks = masks > self.model.mask_threshold  # to bool
+                if pred_bboxes is not None:
+                    pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)
+                else:
+                    pred_bboxes = batched_mask_to_box(masks)
+                # NOTE: SAM models do not return cls info. This `cls` here is just a placeholder for consistency.
+                cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
+                pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)
+            results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
+        # Reset segment-all mode.
+        self.segment_all = False
+        return results
+
+    def setup_source(self, source):
+        """
+        Sets up the data source for inference.
+
+        This method configures the data source from which images will be fetched for inference. It supports
+        various input types such as image files, directories, video files, and other compatible data sources.
+
+        Args:
+            source (str | Path | None): The path or identifier for the image data source. Can be a file path,
+                directory path, URL, or other supported source types.
+
+        Examples:
+            >>> predictor = Predictor()
+            >>> predictor.setup_source("path/to/images")
+            >>> predictor.setup_source("video.mp4")
+            >>> predictor.setup_source(None)  # Uses default source if available
+
+        Notes:
+            - If source is None, the method may use a default source if configured.
+            - The method adapts to different source types and prepares them for subsequent inference steps.
+            - Supported source types may include local files, directories, URLs, and video streams.
+        """
+        if source is not None:
+            super().setup_source(source)
+
+    def set_image(self, image):
+        """
+        Preprocesses and sets a single image for inference.
+
+        This method prepares the model for inference on a single image by setting up the model if not already
+        initialized, configuring the data source, and preprocessing the image for feature extraction. It
+        ensures that only one image is set at a time and extracts image features for subsequent use.
+
+        Args:
+            image (str | np.ndarray): Path to the image file as a string, or a numpy array representing
+                an image read by cv2.
+
+        Raises:
+            AssertionError: If more than one image is attempted to be set.
+
+        Examples:
+            >>> predictor = Predictor()
+            >>> predictor.set_image("path/to/image.jpg")
+            >>> predictor.set_image(cv2.imread("path/to/image.jpg"))
+
+        Notes:
+            - This method should be called before performing inference on a new image.
+            - The extracted features are stored in the `self.features` attribute for later use.
+        """
+        if self.model is None:
+            self.setup_model(model=None)
+        self.setup_source(image)
+        assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
+        for batch in self.dataset:
+            im = self.preprocess(batch[1])
+            self.features = self.get_im_features(im)
+            break
+
+    def get_im_features(self, im):
+        """Extracts image features using the SAM model's image encoder for subsequent mask prediction."""
+        assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], (
+            f"SAM models only support square image size, but got {self.imgsz}."
+        )
+        self.model.set_imgsz(self.imgsz)
+        return self.model.image_encoder(im)
+
+    def set_prompts(self, prompts):
+        """Sets prompts for subsequent inference operations."""
+        self.prompts = prompts
+
+    def reset_image(self):
+        """Resets the current image and its features, clearing them for subsequent inference."""
+        self.im = None
+        self.features = None
+
+    @staticmethod
+    def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
+        """
+        Remove small disconnected regions and holes from segmentation masks.
+
+        This function performs post-processing on segmentation masks generated by the Segment Anything Model (SAM).
+        It removes small disconnected regions and holes from the input masks, and then performs Non-Maximum
+        Suppression (NMS) to eliminate any newly created duplicate boxes.
+
+        Args:
+            masks (torch.Tensor): Segmentation masks to be processed, with shape (N, H, W) where N is the number of
+                masks, H is height, and W is width.
+            min_area (int): Minimum area threshold for removing disconnected regions and holes. Regions smaller than
+                this will be removed.
+            nms_thresh (float): IoU threshold for the NMS algorithm to remove duplicate boxes.
+
+        Returns:
+            new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W).
+            keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes.
+
+        Examples:
+            >>> masks = torch.rand(5, 640, 640) > 0.5  # 5 random binary masks
+            >>> new_masks, keep = remove_small_regions(masks, min_area=100, nms_thresh=0.7)
+            >>> print(f"Original masks: {masks.shape}, Processed masks: {new_masks.shape}")
+            >>> print(f"Indices of kept masks: {keep}")
+        """
+        import torchvision  # scope for faster 'import ultralytics'
+
+        if len(masks) == 0:
+            return masks
+
+        # Filter small disconnected regions and holes
+        new_masks = []
+        scores = []
+        for mask in masks:
+            mask = mask.cpu().numpy().astype(np.uint8)
+            mask, changed = remove_small_regions(mask, min_area, mode="holes")
+            unchanged = not changed
+            mask, changed = remove_small_regions(mask, min_area, mode="islands")
+            unchanged = unchanged and not changed
+
+            new_masks.append(torch.as_tensor(mask).unsqueeze(0))
+            # Give score=0 to changed masks and 1 to unchanged masks so NMS prefers masks not needing postprocessing
+            scores.append(float(unchanged))
+
+        # Recalculate boxes and remove any new duplicates
+        new_masks = torch.cat(new_masks, dim=0)
+        boxes = batched_mask_to_box(new_masks)
+        keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh)
+
+        return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep
+
+
+class SAM2Predictor(Predictor):
+    """
+    SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture.
+
+    This class extends the base Predictor class to implement SAM2-specific functionality for image
+    segmentation tasks. It provides methods for model initialization, feature extraction, and
+    prompt-based inference.
+
+    Attributes:
+        _bb_feat_sizes (List[Tuple[int, int]]): Feature sizes for different backbone levels.
+        model (torch.nn.Module): The loaded SAM2 model.
+        device (torch.device): The device (CPU or GPU) on which the model is loaded.
+        features (Dict[str, torch.Tensor]): Cached image features for efficient inference.
+        segment_all (bool): Flag to indicate if all segments should be predicted.
+        prompts (Dict): Dictionary to store various types of prompts for inference.
+
+    Methods:
+        get_model: Retrieves and initializes the SAM2 model.
+        prompt_inference: Performs image segmentation inference based on various prompts.
+        set_image: Preprocesses and sets a single image for inference.
+        get_im_features: Extracts and processes image features using SAM2's image encoder.
+
+    Examples:
+        >>> predictor = SAM2Predictor(cfg)
+        >>> predictor.set_image("path/to/image.jpg")
+        >>> bboxes = [[100, 100, 200, 200]]
+        >>> result = predictor(bboxes=bboxes)[0]
+        >>> print(f"Predicted {len(result.masks)} masks with average score {result.boxes.conf.mean():.2f}")
+    """
+
+    _bb_feat_sizes = [
+        (256, 256),
+        (128, 128),
+        (64, 64),
+    ]
+
+    def get_model(self):
+        """Retrieves and initializes the Segment Anything Model 2 (SAM2) for image segmentation tasks."""
+        return build_sam(self.args.model)
+
+    def prompt_inference(
+        self,
+        im,
+        bboxes=None,
+        points=None,
+        labels=None,
+        masks=None,
+        multimask_output=False,
+        img_idx=-1,
+    ):
+        """
+        Performs image segmentation inference based on various prompts using SAM2 architecture.
+
+        This method leverages the Segment Anything Model 2 (SAM2) to generate segmentation masks for input images
+        based on provided prompts such as bounding boxes, points, or existing masks. It supports both single and
+        multi-object prediction scenarios.
+
+        Args:
+            im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
+            bboxes (np.ndarray | List[List[float]] | None): Bounding boxes in XYXY format with shape (N, 4).
+            points (np.ndarray | List[List[float]] | None): Object location points with shape (N, 2), in pixels.
+            labels (np.ndarray | List[int] | None): Point prompt labels with shape (N,). 1 = foreground, 0 = background.
+            masks (np.ndarray | None): Low-resolution masks from previous predictions with shape (N, H, W).
+            multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
+            img_idx (int): Index of the image in the batch to process.
+
+        Returns:
+            (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
+            (np.ndarray): Quality scores for each mask, with length C.
+
+        Examples:
+            >>> predictor = SAM2Predictor(cfg)
+            >>> image = torch.rand(1, 3, 640, 640)
+            >>> bboxes = [[100, 100, 200, 200]]
+            >>> result = predictor(image, bboxes=bboxes)[0]
+            >>> print(f"Generated {result.masks.shape[0]} masks with average score {result.boxes.conf.mean():.2f}")
+
+        Notes:
+            - The method supports batched inference for multiple objects when points or bboxes are provided.
+            - Input prompts (bboxes, points) are automatically scaled to match the input image dimensions.
+            - When both bboxes and points are provided, they are merged into a single 'points' input for the model.
+
+        References:
+            - SAM2 Paper: [Add link to SAM2 paper when available]
+        """
+        features = self.get_im_features(im) if self.features is None else self.features
+
+        points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
+        points = (points, labels) if points is not None else None
+
+        sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
+            points=points,
+            boxes=None,
+            masks=masks,
+        )
+        # Predict masks
+        batched_mode = points is not None and points[0].shape[0] > 1  # multi object prediction
+        high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]]
+        pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder(
+            image_embeddings=features["image_embed"][img_idx].unsqueeze(0),
+            image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embeddings,
+            dense_prompt_embeddings=dense_embeddings,
+            multimask_output=multimask_output,
+            repeat_image=batched_mode,
+            high_res_features=high_res_features,
+        )
+        # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
+        # `d` could be 1 or 3 depends on `multimask_output`.
+        return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
+
+    def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
+        """
+        Prepares and transforms the input prompts for processing based on the destination shape.
+
+        Args:
+            dst_shape (tuple): The target shape (height, width) for the prompts.
+            bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
+            points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
+            labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background.
+            masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array.
+
+        Raises:
+            AssertionError: If the number of points don't match the number of labels, in case labels were passed.
+
+        Returns:
+            (tuple): A tuple containing transformed points, labels, and masks.
+        """
+        bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks)
+        if bboxes is not None:
+            bboxes = bboxes.view(-1, 2, 2)
+            bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
+            # NOTE: merge "boxes" and "points" into a single "points" input
+            # (where boxes are added at the beginning) to model.sam_prompt_encoder
+            if points is not None:
+                points = torch.cat([bboxes, points], dim=1)
+                labels = torch.cat([bbox_labels, labels], dim=1)
+            else:
+                points, labels = bboxes, bbox_labels
+        return points, labels, masks
+
+    def set_image(self, image):
+        """
+        Preprocesses and sets a single image for inference using the SAM2 model.
+
+        This method initializes the model if not already done, configures the data source to the specified image,
+        and preprocesses the image for feature extraction. It supports setting only one image at a time.
+
+        Args:
+            image (str | np.ndarray): Path to the image file as a string, or a numpy array representing the image.
+
+        Raises:
+            AssertionError: If more than one image is attempted to be set.
+
+        Examples:
+            >>> predictor = SAM2Predictor()
+            >>> predictor.set_image("path/to/image.jpg")
+            >>> predictor.set_image(np.array([...]))  # Using a numpy array
+
+        Notes:
+            - This method must be called before performing any inference on a new image.
+            - The method caches the extracted features for efficient subsequent inferences on the same image.
+            - Only one image can be set at a time. To process multiple images, call this method for each new image.
+        """
+        if self.model is None:
+            self.setup_model(model=None)
+        self.setup_source(image)
+        assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
+        for batch in self.dataset:
+            im = self.preprocess(batch[1])
+            self.features = self.get_im_features(im)
+            break
+
+    def get_im_features(self, im):
+        """Extracts image features from the SAM image encoder for subsequent processing."""
+        assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], (
+            f"SAM 2 models only support square image size, but got {self.imgsz}."
+        )
+        self.model.set_imgsz(self.imgsz)
+        self._bb_feat_sizes = [[x // (4 * i) for x in self.imgsz] for i in [1, 2, 4]]
+
+        backbone_out = self.model.forward_image(im)
+        _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
+        if self.model.directly_add_no_mem_embed:
+            vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
+        feats = [
+            feat.permute(1, 2, 0).view(1, -1, *feat_size)
+            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
+        ][::-1]
+        return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
+
+
+class SAM2VideoPredictor(SAM2Predictor):
+    """
+    SAM2VideoPredictor to handle user interactions with videos and manage inference states.
+
+    This class extends the functionality of SAM2Predictor to support video processing and maintains
+    the state of inference operations. It includes configurations for managing non-overlapping masks,
+    clearing memory for non-conditional inputs, and setting up callbacks for prediction events.
+
+    Attributes:
+        inference_state (Dict): A dictionary to store the current state of inference operations.
+        non_overlap_masks (bool): A flag indicating whether masks should be non-overlapping.
+        clear_non_cond_mem_around_input (bool): A flag to control clearing non-conditional memory around inputs.
+        clear_non_cond_mem_for_multi_obj (bool): A flag to control clearing non-conditional memory for multi-object scenarios.
+        callbacks (Dict): A dictionary of callbacks for various prediction lifecycle events.
+
+    Args:
+        cfg (Dict, Optional): Configuration settings for the predictor. Defaults to DEFAULT_CFG.
+        overrides (Dict, Optional): Additional configuration overrides. Defaults to None.
+        _callbacks (List, Optional): Custom callbacks to be added. Defaults to None.
+
+    Note:
+        The `fill_hole_area` attribute is defined but not used in the current implementation.
+    """
+
+    # fill_hole_area = 8  # not used
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """
+        Initialize the predictor with configuration and optional overrides.
+
+        This constructor initializes the SAM2VideoPredictor with a given configuration, applies any
+        specified overrides, and sets up the inference state along with certain flags
+        that control the behavior of the predictor.
+
+        Args:
+            cfg (Dict): Configuration dictionary containing default settings.
+            overrides (Dict | None): Dictionary of values to override default configuration.
+            _callbacks (Dict | None): Dictionary of callback functions to customize behavior.
+
+        Examples:
+            >>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG)
+            >>> predictor_example_with_imgsz = SAM2VideoPredictor(overrides={"imgsz": 640})
+            >>> predictor_example_with_callback = SAM2VideoPredictor(_callbacks={"on_predict_start": custom_callback})
+        """
+        super().__init__(cfg, overrides, _callbacks)
+        self.inference_state = {}
+        self.non_overlap_masks = True
+        self.clear_non_cond_mem_around_input = False
+        self.clear_non_cond_mem_for_multi_obj = False
+        self.callbacks["on_predict_start"].append(self.init_state)
+
+    def get_model(self):
+        """
+        Retrieves and configures the model with binarization enabled.
+
+        Note:
+            This method overrides the base class implementation to set the binarize flag to True.
+        """
+        model = super().get_model()
+        model.set_binarize(True)
+        return model
+
+    def inference(self, im, bboxes=None, points=None, labels=None, masks=None):
+        """
+        Perform image segmentation inference based on the given input cues, using the currently loaded image. This
+        method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and
+        mask decoder for real-time and promptable segmentation tasks.
+
+        Args:
+            im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
+            bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
+            points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
+            labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
+            masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.
+
+        Returns:
+            (np.ndarray): The output masks in shape CxHxW, where C is the number of generated masks.
+            (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
+        """
+        # Override prompts if any stored in self.prompts
+        bboxes = self.prompts.pop("bboxes", bboxes)
+        points = self.prompts.pop("points", points)
+        masks = self.prompts.pop("masks", masks)
+
+        frame = self.dataset.frame
+        self.inference_state["im"] = im
+        output_dict = self.inference_state["output_dict"]
+        if len(output_dict["cond_frame_outputs"]) == 0:  # initialize prompts
+            points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
+            if points is not None:
+                for i in range(len(points)):
+                    self.add_new_prompts(obj_id=i, points=points[[i]], labels=labels[[i]], frame_idx=frame)
+            elif masks is not None:
+                for i in range(len(masks)):
+                    self.add_new_prompts(obj_id=i, masks=masks[[i]], frame_idx=frame)
+        self.propagate_in_video_preflight()
+
+        consolidated_frame_inds = self.inference_state["consolidated_frame_inds"]
+        batch_size = len(self.inference_state["obj_idx_to_id"])
+        if len(output_dict["cond_frame_outputs"]) == 0:
+            raise RuntimeError("No points are provided; please add points first")
+
+        if frame in consolidated_frame_inds["cond_frame_outputs"]:
+            storage_key = "cond_frame_outputs"
+            current_out = output_dict[storage_key][frame]
+            if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1):
+                # clear non-conditioning memory of the surrounding frames
+                self._clear_non_cond_mem_around_input(frame)
+        elif frame in consolidated_frame_inds["non_cond_frame_outputs"]:
+            storage_key = "non_cond_frame_outputs"
+            current_out = output_dict[storage_key][frame]
+        else:
+            storage_key = "non_cond_frame_outputs"
+            current_out = self._run_single_frame_inference(
+                output_dict=output_dict,
+                frame_idx=frame,
+                batch_size=batch_size,
+                is_init_cond_frame=False,
+                point_inputs=None,
+                mask_inputs=None,
+                reverse=False,
+                run_mem_encoder=True,
+            )
+            output_dict[storage_key][frame] = current_out
+        # Create slices of per-object outputs for subsequent interaction with each
+        # individual object after tracking.
+        self._add_output_per_object(frame, current_out, storage_key)
+        self.inference_state["frames_already_tracked"].append(frame)
+        pred_masks = current_out["pred_masks"].flatten(0, 1)
+        pred_masks = pred_masks[(pred_masks > self.model.mask_threshold).sum((1, 2)) > 0]  # filter blank masks
+
+        return pred_masks, torch.ones(len(pred_masks), dtype=pred_masks.dtype, device=pred_masks.device)
+
+    def postprocess(self, preds, img, orig_imgs):
+        """
+        Post-processes the predictions to apply non-overlapping constraints if required.
+
+        This method extends the post-processing functionality by applying non-overlapping constraints
+        to the predicted masks if the `non_overlap_masks` flag is set to True. This ensures that
+        the masks do not overlap, which can be useful for certain applications.
+
+        Args:
+            preds (Tuple[torch.Tensor]): The predictions from the model.
+            img (torch.Tensor): The processed image tensor.
+            orig_imgs (List[np.ndarray]): The original images before processing.
+
+        Returns:
+            results (list): The post-processed predictions.
+
+        Note:
+            If `non_overlap_masks` is True, the method applies constraints to ensure non-overlapping masks.
+        """
+        results = super().postprocess(preds, img, orig_imgs)
+        if self.non_overlap_masks:
+            for result in results:
+                if result.masks is None or len(result.masks) == 0:
+                    continue
+                result.masks.data = self.model._apply_non_overlapping_constraints(result.masks.data.unsqueeze(0))[0]
+        return results
+
+    @smart_inference_mode()
+    def add_new_prompts(
+        self,
+        obj_id,
+        points=None,
+        labels=None,
+        masks=None,
+        frame_idx=0,
+    ):
+        """
+        Adds new points or masks to a specific frame for a given object ID.
+
+        This method updates the inference state with new prompts (points or masks) for a specified
+        object and frame index. It ensures that the prompts are either points or masks, but not both,
+        and updates the internal state accordingly. It also handles the generation of new segmentations
+        based on the provided prompts and the existing state.
+
+        Args:
+            obj_id (int): The ID of the object to which the prompts are associated.
+            points (torch.Tensor, Optional): The coordinates of the points of interest. Defaults to None.
+            labels (torch.Tensor, Optional): The labels corresponding to the points. Defaults to None.
+            masks (torch.Tensor, optional): Binary masks for the object. Defaults to None.
+            frame_idx (int, optional): The index of the frame to which the prompts are applied. Defaults to 0.
+
+        Returns:
+            (tuple): A tuple containing the flattened predicted masks and a tensor of ones indicating the number of objects.
+
+        Raises:
+            AssertionError: If both `masks` and `points` are provided, or neither is provided.
+
+        Note:
+            - Only one type of prompt (either points or masks) can be added per call.
+            - If the frame is being tracked for the first time, it is treated as an initial conditioning frame.
+            - The method handles the consolidation of outputs and resizing of masks to the original video resolution.
+        """
+        assert (masks is None) ^ (points is None), "'masks' and 'points' prompts are not compatible with each other."
+        obj_idx = self._obj_id_to_idx(obj_id)
+
+        point_inputs = None
+        pop_key = "point_inputs_per_obj"
+        if points is not None:
+            point_inputs = {"point_coords": points, "point_labels": labels}
+            self.inference_state["point_inputs_per_obj"][obj_idx][frame_idx] = point_inputs
+            pop_key = "mask_inputs_per_obj"
+        self.inference_state["mask_inputs_per_obj"][obj_idx][frame_idx] = masks
+        self.inference_state[pop_key][obj_idx].pop(frame_idx, None)
+        # If this frame hasn't been tracked before, we treat it as an initial conditioning
+        # frame, meaning that the inputs points are to generate segments on this frame without
+        # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
+        # the input points will be used to correct the already tracked masks.
+        is_init_cond_frame = frame_idx not in self.inference_state["frames_already_tracked"]
+        obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx]
+        obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx]
+        # Add a frame to conditioning output if it's an initial conditioning frame or
+        # if the model sees all frames receiving clicks/mask as conditioning frames.
+        is_cond = is_init_cond_frame or self.model.add_all_frames_to_correct_as_cond
+        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+
+        # Get any previously predicted mask logits on this object and feed it along with
+        # the new clicks into the SAM mask decoder.
+        prev_sam_mask_logits = None
+        # lookup temporary output dict first, which contains the most recent output
+        # (if not found, then lookup conditioning and non-conditioning frame output)
+        if point_inputs is not None:
+            prev_out = (
+                obj_temp_output_dict[storage_key].get(frame_idx)
+                or obj_output_dict["cond_frame_outputs"].get(frame_idx)
+                or obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
+            )
+
+            if prev_out is not None and prev_out.get("pred_masks") is not None:
+                prev_sam_mask_logits = prev_out["pred_masks"].to(device=self.device, non_blocking=True)
+                # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
+                prev_sam_mask_logits.clamp_(-32.0, 32.0)
+        current_out = self._run_single_frame_inference(
+            output_dict=obj_output_dict,  # run on the slice of a single object
+            frame_idx=frame_idx,
+            batch_size=1,  # run on the slice of a single object
+            is_init_cond_frame=is_init_cond_frame,
+            point_inputs=point_inputs,
+            mask_inputs=masks,
+            reverse=False,
+            # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
+            # at the beginning of `propagate_in_video` (after user finalize their clicks). This
+            # allows us to enforce non-overlapping constraints on all objects before encoding
+            # them into memory.
+            run_mem_encoder=False,
+            prev_sam_mask_logits=prev_sam_mask_logits,
+        )
+        # Add the output to the output dict (to be used as future memory)
+        obj_temp_output_dict[storage_key][frame_idx] = current_out
+
+        # Resize the output mask to the original video resolution
+        consolidated_out = self._consolidate_temp_output_across_obj(
+            frame_idx,
+            is_cond=is_cond,
+            run_mem_encoder=False,
+        )
+        pred_masks = consolidated_out["pred_masks"].flatten(0, 1)
+        return pred_masks.flatten(0, 1), torch.ones(1, dtype=pred_masks.dtype, device=pred_masks.device)
+
+    @smart_inference_mode()
+    def propagate_in_video_preflight(self):
+        """
+        Prepare inference_state and consolidate temporary outputs before tracking.
+
+        This method marks the start of tracking, disallowing the addition of new objects until the session is reset.
+        It consolidates temporary outputs from `temp_output_dict_per_obj` and merges them into `output_dict`.
+        Additionally, it clears non-conditioning memory around input frames and ensures that the state is consistent
+        with the provided inputs.
+        """
+        # Tracking has started and we don't allow adding new objects until session is reset.
+        self.inference_state["tracking_has_started"] = True
+        batch_size = len(self.inference_state["obj_idx_to_id"])
+
+        # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
+        # add them into "output_dict".
+        temp_output_dict_per_obj = self.inference_state["temp_output_dict_per_obj"]
+        output_dict = self.inference_state["output_dict"]
+        # "consolidated_frame_inds" contains indices of those frames where consolidated
+        # temporary outputs have been added (either in this call or any previous calls
+        # to `propagate_in_video_preflight`).
+        consolidated_frame_inds = self.inference_state["consolidated_frame_inds"]
+        for is_cond in {False, True}:
+            # Separately consolidate conditioning and non-conditioning temp outputs
+            storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+            # Find all the frames that contain temporary outputs for any objects
+            # (these should be the frames that have just received clicks for mask inputs
+            # via `add_new_points` or `add_new_mask`)
+            temp_frame_inds = set()
+            for obj_temp_output_dict in temp_output_dict_per_obj.values():
+                temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
+            consolidated_frame_inds[storage_key].update(temp_frame_inds)
+            # consolidate the temporary output across all objects on this frame
+            for frame_idx in temp_frame_inds:
+                consolidated_out = self._consolidate_temp_output_across_obj(
+                    frame_idx, is_cond=is_cond, run_mem_encoder=True
+                )
+                # merge them into "output_dict" and also create per-object slices
+                output_dict[storage_key][frame_idx] = consolidated_out
+                self._add_output_per_object(frame_idx, consolidated_out, storage_key)
+                if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1):
+                    # clear non-conditioning memory of the surrounding frames
+                    self._clear_non_cond_mem_around_input(frame_idx)
+
+            # clear temporary outputs in `temp_output_dict_per_obj`
+            for obj_temp_output_dict in temp_output_dict_per_obj.values():
+                obj_temp_output_dict[storage_key].clear()
+
+        # edge case: if an output is added to "cond_frame_outputs", we remove any prior
+        # output on the same frame in "non_cond_frame_outputs"
+        for frame_idx in output_dict["cond_frame_outputs"]:
+            output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
+        for obj_output_dict in self.inference_state["output_dict_per_obj"].values():
+            for frame_idx in obj_output_dict["cond_frame_outputs"]:
+                obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
+        for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
+            assert frame_idx in output_dict["cond_frame_outputs"]
+            consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
+
+        # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
+        # with either points or mask inputs (which should be true under a correct workflow).
+        all_consolidated_frame_inds = (
+            consolidated_frame_inds["cond_frame_outputs"] | consolidated_frame_inds["non_cond_frame_outputs"]
+        )
+        input_frames_inds = set()
+        for point_inputs_per_frame in self.inference_state["point_inputs_per_obj"].values():
+            input_frames_inds.update(point_inputs_per_frame.keys())
+        for mask_inputs_per_frame in self.inference_state["mask_inputs_per_obj"].values():
+            input_frames_inds.update(mask_inputs_per_frame.keys())
+        assert all_consolidated_frame_inds == input_frames_inds
+
+    @staticmethod
+    def init_state(predictor):
+        """
+        Initialize an inference state for the predictor.
+
+        This function sets up the initial state required for performing inference on video data.
+        It includes initializing various dictionaries and ordered dictionaries that will store
+        inputs, outputs, and other metadata relevant to the tracking process.
+
+        Args:
+            predictor (SAM2VideoPredictor): The predictor object for which to initialize the state.
+        """
+        if len(predictor.inference_state) > 0:  # means initialized
+            return
+        assert predictor.dataset is not None
+        assert predictor.dataset.mode == "video"
+
+        inference_state = {
+            "num_frames": predictor.dataset.frames,
+            "point_inputs_per_obj": {},  # inputs points on each frame
+            "mask_inputs_per_obj": {},  # inputs mask on each frame
+            "constants": {},  # values that don't change across frames (so we only need to hold one copy of them)
+            # mapping between client-side object id and model-side object index
+            "obj_id_to_idx": OrderedDict(),
+            "obj_idx_to_id": OrderedDict(),
+            "obj_ids": [],
+            # A storage to hold the model's tracking results and states on each frame
+            "output_dict": {
+                "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+                "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+            },
+            # Slice (view) of each object tracking results, sharing the same memory with "output_dict"
+            "output_dict_per_obj": {},
+            # A temporary storage to hold new outputs when user interact with a frame
+            # to add clicks or mask (it's merged into "output_dict" before propagation starts)
+            "temp_output_dict_per_obj": {},
+            # Frames that already holds consolidated outputs from click or mask inputs
+            # (we directly use their consolidated outputs during tracking)
+            "consolidated_frame_inds": {
+                "cond_frame_outputs": set(),  # set containing frame indices
+                "non_cond_frame_outputs": set(),  # set containing frame indices
+            },
+            # metadata for each tracking frame (e.g. which direction it's tracked)
+            "tracking_has_started": False,
+            "frames_already_tracked": [],
+        }
+        predictor.inference_state = inference_state
+
+    def get_im_features(self, im, batch=1):
+        """
+        Extracts and processes image features using SAM2's image encoder for subsequent segmentation tasks.
+
+        Args:
+            im (torch.Tensor): The input image tensor.
+            batch (int, optional): The batch size for expanding features if there are multiple prompts. Defaults to 1.
+
+        Returns:
+            vis_feats (torch.Tensor): The visual features extracted from the image.
+            vis_pos_embed (torch.Tensor): The positional embeddings for the visual features.
+            feat_sizes (List(Tuple[int])): A list containing the sizes of the extracted features.
+
+        Note:
+            - If `batch` is greater than 1, the features are expanded to fit the batch size.
+            - The method leverages the model's `_prepare_backbone_features` method to prepare the backbone features.
+        """
+        backbone_out = self.model.forward_image(im)
+        if batch > 1:  # expand features if there's more than one prompt
+            for i, feat in enumerate(backbone_out["backbone_fpn"]):
+                backbone_out["backbone_fpn"][i] = feat.expand(batch, -1, -1, -1)
+            for i, pos in enumerate(backbone_out["vision_pos_enc"]):
+                pos = pos.expand(batch, -1, -1, -1)
+                backbone_out["vision_pos_enc"][i] = pos
+        _, vis_feats, vis_pos_embed, feat_sizes = self.model._prepare_backbone_features(backbone_out)
+        return vis_feats, vis_pos_embed, feat_sizes
+
+    def _obj_id_to_idx(self, obj_id):
+        """
+        Map client-side object id to model-side object index.
+
+        Args:
+            obj_id (int): The unique identifier of the object provided by the client side.
+
+        Returns:
+            obj_idx (int): The index of the object on the model side.
+
+        Raises:
+            RuntimeError: If an attempt is made to add a new object after tracking has started.
+
+        Note:
+            - The method updates or retrieves mappings between object IDs and indices stored in
+              `inference_state`.
+            - It ensures that new objects can only be added before tracking commences.
+            - It maintains two-way mappings between IDs and indices (`obj_id_to_idx` and `obj_idx_to_id`).
+            - Additional data structures are initialized for the new object to store inputs and outputs.
+        """
+        obj_idx = self.inference_state["obj_id_to_idx"].get(obj_id, None)
+        if obj_idx is not None:
+            return obj_idx
+
+        # This is a new object id not sent to the server before. We only allow adding
+        # new objects *before* the tracking starts.
+        allow_new_object = not self.inference_state["tracking_has_started"]
+        if allow_new_object:
+            # get the next object slot
+            obj_idx = len(self.inference_state["obj_id_to_idx"])
+            self.inference_state["obj_id_to_idx"][obj_id] = obj_idx
+            self.inference_state["obj_idx_to_id"][obj_idx] = obj_id
+            self.inference_state["obj_ids"] = list(self.inference_state["obj_id_to_idx"])
+            # set up input and output structures for this object
+            self.inference_state["point_inputs_per_obj"][obj_idx] = {}
+            self.inference_state["mask_inputs_per_obj"][obj_idx] = {}
+            self.inference_state["output_dict_per_obj"][obj_idx] = {
+                "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+                "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+            }
+            self.inference_state["temp_output_dict_per_obj"][obj_idx] = {
+                "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+                "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+            }
+            return obj_idx
+        else:
+            raise RuntimeError(
+                f"Cannot add new object id {obj_id} after tracking starts. "
+                f"All existing object ids: {self.inference_state['obj_ids']}. "
+                f"Please call 'reset_state' to restart from scratch."
+            )
+
+    def _run_single_frame_inference(
+        self,
+        output_dict,
+        frame_idx,
+        batch_size,
+        is_init_cond_frame,
+        point_inputs,
+        mask_inputs,
+        reverse,
+        run_mem_encoder,
+        prev_sam_mask_logits=None,
+    ):
+        """
+        Run tracking on a single frame based on current inputs and previous memory.
+
+        Args:
+            output_dict (Dict): The dictionary containing the output states of the tracking process.
+            frame_idx (int): The index of the current frame.
+            batch_size (int): The batch size for processing the frame.
+            is_init_cond_frame (bool): Indicates if the current frame is an initial conditioning frame.
+            point_inputs (Dict, Optional): Input points and their labels. Defaults to None.
+            mask_inputs (torch.Tensor, Optional): Input binary masks. Defaults to None.
+            reverse (bool): Indicates if the tracking should be performed in reverse order.
+            run_mem_encoder (bool): Indicates if the memory encoder should be executed.
+            prev_sam_mask_logits (torch.Tensor, Optional): Previous mask logits for the current object. Defaults to None.
+
+        Returns:
+            current_out (dict): A dictionary containing the output of the tracking step, including updated features and predictions.
+
+        Raises:
+            AssertionError: If both `point_inputs` and `mask_inputs` are provided, or neither is provided.
+
+        Note:
+            - The method assumes that `point_inputs` and `mask_inputs` are mutually exclusive.
+            - The method retrieves image features using the `get_im_features` method.
+            - The `maskmem_pos_enc` is assumed to be constant across frames, hence only one copy is stored.
+            - The `fill_holes_in_mask_scores` function is commented out and currently unsupported due to CUDA extension requirements.
+        """
+        # Retrieve correct image features
+        current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(
+            self.inference_state["im"], batch_size
+        )
+
+        # point and mask should not appear as input simultaneously on the same frame
+        assert point_inputs is None or mask_inputs is None
+        current_out = self.model.track_step(
+            frame_idx=frame_idx,
+            is_init_cond_frame=is_init_cond_frame,
+            current_vision_feats=current_vision_feats,
+            current_vision_pos_embeds=current_vision_pos_embeds,
+            feat_sizes=feat_sizes,
+            point_inputs=point_inputs,
+            mask_inputs=mask_inputs,
+            output_dict=output_dict,
+            num_frames=self.inference_state["num_frames"],
+            track_in_reverse=reverse,
+            run_mem_encoder=run_mem_encoder,
+            prev_sam_mask_logits=prev_sam_mask_logits,
+        )
+
+        maskmem_features = current_out["maskmem_features"]
+        if maskmem_features is not None:
+            current_out["maskmem_features"] = maskmem_features.to(
+                dtype=torch.float16, device=self.device, non_blocking=True
+            )
+        # NOTE: Do not support the `fill_holes_in_mask_scores` function since it needs cuda extensions
+        # potentially fill holes in the predicted masks
+        # if self.fill_hole_area > 0:
+        #     pred_masks = current_out["pred_masks"].to(self.device, non_blocking=True)
+        #     pred_masks = fill_holes_in_mask_scores(pred_masks, self.fill_hole_area)
+
+        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
+        current_out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(current_out["maskmem_pos_enc"])
+        return current_out
+
+    def _get_maskmem_pos_enc(self, out_maskmem_pos_enc):
+        """
+        Caches and manages the positional encoding for mask memory across frames and objects.
+
+        This method optimizes storage by caching the positional encoding (`maskmem_pos_enc`) for
+        mask memory, which is constant across frames and objects, thus reducing the amount of
+        redundant information stored during an inference session. It checks if the positional
+        encoding has already been cached; if not, it caches a slice of the provided encoding.
+        If the batch size is greater than one, it expands the cached positional encoding to match
+        the current batch size.
+
+        Args:
+            out_maskmem_pos_enc (List[torch.Tensor] or None): The positional encoding for mask memory.
+                Should be a list of tensors or None.
+
+        Returns:
+            out_maskmem_pos_enc (List[torch.Tensor]): The positional encoding for mask memory, either cached or expanded.
+
+        Note:
+            - The method assumes that `out_maskmem_pos_enc` is a list of tensors or None.
+            - Only a single object's slice is cached since the encoding is the same across objects.
+            - The method checks if the positional encoding has already been cached in the session's constants.
+            - If the batch size is greater than one, the cached encoding is expanded to fit the batch size.
+        """
+        model_constants = self.inference_state["constants"]
+        # "out_maskmem_pos_enc" should be either a list of tensors or None
+        if out_maskmem_pos_enc is not None:
+            if "maskmem_pos_enc" not in model_constants:
+                assert isinstance(out_maskmem_pos_enc, list)
+                # only take the slice for one object, since it's same across objects
+                maskmem_pos_enc = [x[:1].clone() for x in out_maskmem_pos_enc]
+                model_constants["maskmem_pos_enc"] = maskmem_pos_enc
+            else:
+                maskmem_pos_enc = model_constants["maskmem_pos_enc"]
+            # expand the cached maskmem_pos_enc to the actual batch size
+            batch_size = out_maskmem_pos_enc[0].size(0)
+            if batch_size > 1:
+                out_maskmem_pos_enc = [x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc]
+        return out_maskmem_pos_enc
+
+    def _consolidate_temp_output_across_obj(
+        self,
+        frame_idx,
+        is_cond=False,
+        run_mem_encoder=False,
+    ):
+        """
+        Consolidates per-object temporary outputs into a single output for all objects.
+
+        This method combines the temporary outputs for each object on a given frame into a unified
+        output. It fills in any missing objects either from the main output dictionary or leaves
+        placeholders if they do not exist in the main output. Optionally, it can re-run the memory
+        encoder after applying non-overlapping constraints to the object scores.
+
+        Args:
+            frame_idx (int): The index of the frame for which to consolidate outputs.
+            is_cond (bool, Optional): Indicates if the frame is considered a conditioning frame.
+                Defaults to False.
+            run_mem_encoder (bool, Optional): Specifies whether to run the memory encoder after
+                consolidating the outputs. Defaults to False.
+
+        Returns:
+            consolidated_out (dict): A consolidated output dictionary containing the combined results for all objects.
+
+        Note:
+            - The method initializes the consolidated output with placeholder values for missing objects.
+            - It searches for outputs in both the temporary and main output dictionaries.
+            - If `run_mem_encoder` is True, it applies non-overlapping constraints and re-runs the memory encoder.
+            - The `maskmem_features` and `maskmem_pos_enc` are only populated when `run_mem_encoder` is True.
+        """
+        batch_size = len(self.inference_state["obj_idx_to_id"])
+        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+
+        # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
+        # will be added when rerunning the memory encoder after applying non-overlapping
+        # constraints to object scores. Its "pred_masks" are prefilled with a large
+        # negative value (NO_OBJ_SCORE) to represent missing objects.
+        consolidated_out = {
+            "maskmem_features": None,
+            "maskmem_pos_enc": None,
+            "pred_masks": torch.full(
+                size=(batch_size, 1, self.imgsz[0] // 4, self.imgsz[1] // 4),
+                fill_value=-1024.0,
+                dtype=torch.float32,
+                device=self.device,
+            ),
+            "obj_ptr": torch.full(
+                size=(batch_size, self.model.hidden_dim),
+                fill_value=-1024.0,
+                dtype=torch.float32,
+                device=self.device,
+            ),
+            "object_score_logits": torch.full(
+                size=(batch_size, 1),
+                # default to 10.0 for object_score_logits, i.e. assuming the object is
+                # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
+                fill_value=10.0,
+                dtype=torch.float32,
+                device=self.device,
+            ),
+        }
+        for obj_idx in range(batch_size):
+            obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx]
+            obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx]
+            out = (
+                obj_temp_output_dict[storage_key].get(frame_idx)
+                # If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
+                # we fall back and look up its previous output in "output_dict_per_obj".
+                # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
+                # "output_dict_per_obj" to find a previous output for this object.
+                or obj_output_dict["cond_frame_outputs"].get(frame_idx)
+                or obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
+            )
+            # If the object doesn't appear in "output_dict_per_obj" either, we skip it
+            # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
+            # placeholder above) and set its object pointer to be a dummy pointer.
+            if out is None:
+                # Fill in dummy object pointers for those objects without any inputs or
+                # tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
+                # i.e. when we need to build the memory for tracking).
+                if run_mem_encoder:
+                    # fill object pointer with a dummy pointer (based on an empty mask)
+                    consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = self._get_empty_mask_ptr(frame_idx)
+                continue
+            # Add the temporary object output mask to consolidated output mask
+            consolidated_out["pred_masks"][obj_idx : obj_idx + 1] = out["pred_masks"]
+            consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
+
+        # Optionally, apply non-overlapping constraints on the consolidated scores and rerun the memory encoder
+        if run_mem_encoder:
+            high_res_masks = F.interpolate(
+                consolidated_out["pred_masks"],
+                size=self.imgsz,
+                mode="bilinear",
+                align_corners=False,
+            )
+            if self.model.non_overlap_masks_for_mem_enc:
+                high_res_masks = self.model._apply_non_overlapping_constraints(high_res_masks)
+            consolidated_out["maskmem_features"], consolidated_out["maskmem_pos_enc"] = self._run_memory_encoder(
+                batch_size=batch_size,
+                high_res_masks=high_res_masks,
+                is_mask_from_pts=True,  # these frames are what the user interacted with
+                object_score_logits=consolidated_out["object_score_logits"],
+            )
+
+        return consolidated_out
+
+    def _get_empty_mask_ptr(self, frame_idx):
+        """
+        Get a dummy object pointer based on an empty mask on the current frame.
+
+        Args:
+            frame_idx (int): The index of the current frame for which to generate the dummy object pointer.
+
+        Returns:
+            (torch.Tensor): A tensor representing the dummy object pointer generated from the empty mask.
+        """
+        # Retrieve correct image features
+        current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(self.inference_state["im"])
+
+        # Feed the empty mask and image feature above to get a dummy object pointer
+        current_out = self.model.track_step(
+            frame_idx=frame_idx,
+            is_init_cond_frame=True,
+            current_vision_feats=current_vision_feats,
+            current_vision_pos_embeds=current_vision_pos_embeds,
+            feat_sizes=feat_sizes,
+            point_inputs=None,
+            # A dummy (empty) mask with a single object
+            mask_inputs=torch.zeros((1, 1, *self.imgsz), dtype=torch.float32, device=self.device),
+            output_dict={},
+            num_frames=self.inference_state["num_frames"],
+            track_in_reverse=False,
+            run_mem_encoder=False,
+            prev_sam_mask_logits=None,
+        )
+        return current_out["obj_ptr"]
+
+    def _run_memory_encoder(self, batch_size, high_res_masks, object_score_logits, is_mask_from_pts):
+        """
+        Run the memory encoder on masks.
+
+        This is usually after applying non-overlapping constraints to object scores. Since their scores changed, their
+        memory also needs to be computed again with the memory encoder.
+
+        Args:
+            batch_size (int): The batch size for processing the frame.
+            high_res_masks (torch.Tensor): High-resolution masks for which to compute the memory.
+            object_score_logits (torch.Tensor): Logits representing the object scores.
+            is_mask_from_pts (bool): Indicates if the mask is derived from point interactions.
+
+        Returns:
+            (tuple[torch.Tensor, torch.Tensor]): A tuple containing the encoded mask features and positional encoding.
+        """
+        # Retrieve correct image features
+        current_vision_feats, _, feat_sizes = self.get_im_features(self.inference_state["im"], batch_size)
+        maskmem_features, maskmem_pos_enc = self.model._encode_new_memory(
+            current_vision_feats=current_vision_feats,
+            feat_sizes=feat_sizes,
+            pred_masks_high_res=high_res_masks,
+            is_mask_from_pts=is_mask_from_pts,
+            object_score_logits=object_score_logits,
+        )
+
+        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
+        maskmem_pos_enc = self._get_maskmem_pos_enc(maskmem_pos_enc)
+        return maskmem_features.to(dtype=torch.float16, device=self.device, non_blocking=True), maskmem_pos_enc
+
+    def _add_output_per_object(self, frame_idx, current_out, storage_key):
+        """
+        Split a multi-object output into per-object output slices and add them into Output_Dict_Per_Obj.
+
+        The resulting slices share the same tensor storage.
+
+        Args:
+            frame_idx (int): The index of the current frame.
+            current_out (Dict): The current output dictionary containing multi-object outputs.
+            storage_key (str): The key used to store the output in the per-object output dictionary.
+        """
+        maskmem_features = current_out["maskmem_features"]
+        assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
+
+        maskmem_pos_enc = current_out["maskmem_pos_enc"]
+        assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
+
+        for obj_idx, obj_output_dict in self.inference_state["output_dict_per_obj"].items():
+            obj_slice = slice(obj_idx, obj_idx + 1)
+            obj_out = {
+                "maskmem_features": None,
+                "maskmem_pos_enc": None,
+                "pred_masks": current_out["pred_masks"][obj_slice],
+                "obj_ptr": current_out["obj_ptr"][obj_slice],
+            }
+            if maskmem_features is not None:
+                obj_out["maskmem_features"] = maskmem_features[obj_slice]
+            if maskmem_pos_enc is not None:
+                obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
+            obj_output_dict[storage_key][frame_idx] = obj_out
+
+    def _clear_non_cond_mem_around_input(self, frame_idx):
+        """
+        Remove the non-conditioning memory around the input frame.
+
+        When users provide correction clicks, the surrounding frames' non-conditioning memories can still contain outdated
+        object appearance information and could confuse the model. This method clears those non-conditioning memories
+        surrounding the interacted frame to avoid giving the model both old and new information about the object.
+
+        Args:
+            frame_idx (int): The index of the current frame where user interaction occurred.
+        """
+        r = self.model.memory_temporal_stride_for_eval
+        frame_idx_begin = frame_idx - r * self.model.num_maskmem
+        frame_idx_end = frame_idx + r * self.model.num_maskmem
+        for t in range(frame_idx_begin, frame_idx_end + 1):
+            self.inference_state["output_dict"]["non_cond_frame_outputs"].pop(t, None)
+            for obj_output_dict in self.inference_state["output_dict_per_obj"].values():
+                obj_output_dict["non_cond_frame_outputs"].pop(t, None)

+ 1 - 0
ultralytics/models/utils/__init__.py

@@ -0,0 +1 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+ 357 - 0
ultralytics/models/utils/loss.py

@@ -0,0 +1,357 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ultralytics.utils.loss import FocalLoss, VarifocalLoss
+from ultralytics.utils.metrics import bbox_iou
+
+from .ops import HungarianMatcher
+
+
+class DETRLoss(nn.Module):
+    """
+    DETR (DEtection TRansformer) Loss class. This class calculates and returns the different loss components for the
+    DETR object detection model. It computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary
+    losses.
+
+    Attributes:
+        nc (int): The number of classes.
+        loss_gain (dict): Coefficients for different loss components.
+        aux_loss (bool): Whether to compute auxiliary losses.
+        use_fl (bool): Use FocalLoss or not.
+        use_vfl (bool): Use VarifocalLoss or not.
+        use_uni_match (bool): Whether to use a fixed layer to assign labels for the auxiliary branch.
+        uni_match_ind (int): The fixed indices of a layer to use if `use_uni_match` is True.
+        matcher (HungarianMatcher): Object to compute matching cost and indices.
+        fl (FocalLoss or None): Focal Loss object if `use_fl` is True, otherwise None.
+        vfl (VarifocalLoss or None): Varifocal Loss object if `use_vfl` is True, otherwise None.
+        device (torch.device): Device on which tensors are stored.
+    """
+
+    def __init__(
+        self, nc=80, loss_gain=None, aux_loss=True, use_fl=True, use_vfl=False, use_uni_match=False, uni_match_ind=0
+    ):
+        """
+        Initialize DETR loss function with customizable components and gains.
+
+        Uses default loss_gain if not provided. Initializes HungarianMatcher with
+        preset cost gains. Supports auxiliary losses and various loss types.
+
+        Args:
+            nc (int): Number of classes.
+            loss_gain (dict): Coefficients for different loss components.
+            aux_loss (bool): Use auxiliary losses from each decoder layer.
+            use_fl (bool): Use FocalLoss.
+            use_vfl (bool): Use VarifocalLoss.
+            use_uni_match (bool): Use fixed layer for auxiliary branch label assignment.
+            uni_match_ind (int): Index of fixed layer for uni_match.
+        """
+        super().__init__()
+
+        if loss_gain is None:
+            loss_gain = {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1}
+        self.nc = nc
+        self.matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
+        self.loss_gain = loss_gain
+        self.aux_loss = aux_loss
+        self.fl = FocalLoss() if use_fl else None
+        self.vfl = VarifocalLoss() if use_vfl else None
+
+        self.use_uni_match = use_uni_match
+        self.uni_match_ind = uni_match_ind
+        self.device = None
+
+    def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=""):
+        """Computes the classification loss based on predictions, target values, and ground truth scores."""
+        # Logits: [b, query, num_classes], gt_class: list[[n, 1]]
+        name_class = f"loss_class{postfix}"
+        bs, nq = pred_scores.shape[:2]
+        # one_hot = F.one_hot(targets, self.nc + 1)[..., :-1]  # (bs, num_queries, num_classes)
+        one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)
+        one_hot.scatter_(2, targets.unsqueeze(-1), 1)
+        one_hot = one_hot[..., :-1]
+        gt_scores = gt_scores.view(bs, nq, 1) * one_hot
+
+        if self.fl:
+            if num_gts and self.vfl:
+                loss_cls = self.vfl(pred_scores, gt_scores, one_hot)
+            else:
+                loss_cls = self.fl(pred_scores, one_hot.float())
+            loss_cls /= max(num_gts, 1) / nq
+        else:
+            loss_cls = nn.BCEWithLogitsLoss(reduction="none")(pred_scores, gt_scores).mean(1).sum()  # YOLO CLS loss
+
+        return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}
+
+    def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""):
+        """Computes bounding box and GIoU losses for predicted and ground truth bounding boxes."""
+        # Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
+        name_bbox = f"loss_bbox{postfix}"
+        name_giou = f"loss_giou{postfix}"
+
+        loss = {}
+        if len(gt_bboxes) == 0:
+            loss[name_bbox] = torch.tensor(0.0, device=self.device)
+            loss[name_giou] = torch.tensor(0.0, device=self.device)
+            return loss
+
+        loss[name_bbox] = self.loss_gain["bbox"] * F.l1_loss(pred_bboxes, gt_bboxes, reduction="sum") / len(gt_bboxes)
+        loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
+        loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
+        loss[name_giou] = self.loss_gain["giou"] * loss[name_giou]
+        return {k: v.squeeze() for k, v in loss.items()}
+
+    # This function is for future RT-DETR Segment models
+    # def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''):
+    #     # masks: [b, query, h, w], gt_mask: list[[n, H, W]]
+    #     name_mask = f'loss_mask{postfix}'
+    #     name_dice = f'loss_dice{postfix}'
+    #
+    #     loss = {}
+    #     if sum(len(a) for a in gt_mask) == 0:
+    #         loss[name_mask] = torch.tensor(0., device=self.device)
+    #         loss[name_dice] = torch.tensor(0., device=self.device)
+    #         return loss
+    #
+    #     num_gts = len(gt_mask)
+    #     src_masks, target_masks = self._get_assigned_bboxes(masks, gt_mask, match_indices)
+    #     src_masks = F.interpolate(src_masks.unsqueeze(0), size=target_masks.shape[-2:], mode='bilinear')[0]
+    #     # TODO: torch does not have `sigmoid_focal_loss`, but it's not urgent since we don't use mask branch for now.
+    #     loss[name_mask] = self.loss_gain['mask'] * F.sigmoid_focal_loss(src_masks, target_masks,
+    #                                                                     torch.tensor([num_gts], dtype=torch.float32))
+    #     loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)
+    #     return loss
+
+    # This function is for future RT-DETR Segment models
+    # @staticmethod
+    # def _dice_loss(inputs, targets, num_gts):
+    #     inputs = F.sigmoid(inputs).flatten(1)
+    #     targets = targets.flatten(1)
+    #     numerator = 2 * (inputs * targets).sum(1)
+    #     denominator = inputs.sum(-1) + targets.sum(-1)
+    #     loss = 1 - (numerator + 1) / (denominator + 1)
+    #     return loss.sum() / num_gts
+
+    def _get_loss_aux(
+        self,
+        pred_bboxes,
+        pred_scores,
+        gt_bboxes,
+        gt_cls,
+        gt_groups,
+        match_indices=None,
+        postfix="",
+        masks=None,
+        gt_mask=None,
+    ):
+        """Get auxiliary losses."""
+        # NOTE: loss class, bbox, giou, mask, dice
+        loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
+        if match_indices is None and self.use_uni_match:
+            match_indices = self.matcher(
+                pred_bboxes[self.uni_match_ind],
+                pred_scores[self.uni_match_ind],
+                gt_bboxes,
+                gt_cls,
+                gt_groups,
+                masks=masks[self.uni_match_ind] if masks is not None else None,
+                gt_mask=gt_mask,
+            )
+        for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):
+            aux_masks = masks[i] if masks is not None else None
+            loss_ = self._get_loss(
+                aux_bboxes,
+                aux_scores,
+                gt_bboxes,
+                gt_cls,
+                gt_groups,
+                masks=aux_masks,
+                gt_mask=gt_mask,
+                postfix=postfix,
+                match_indices=match_indices,
+            )
+            loss[0] += loss_[f"loss_class{postfix}"]
+            loss[1] += loss_[f"loss_bbox{postfix}"]
+            loss[2] += loss_[f"loss_giou{postfix}"]
+            # if masks is not None and gt_mask is not None:
+            #     loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix)
+            #     loss[3] += loss_[f'loss_mask{postfix}']
+            #     loss[4] += loss_[f'loss_dice{postfix}']
+
+        loss = {
+            f"loss_class_aux{postfix}": loss[0],
+            f"loss_bbox_aux{postfix}": loss[1],
+            f"loss_giou_aux{postfix}": loss[2],
+        }
+        # if masks is not None and gt_mask is not None:
+        #     loss[f'loss_mask_aux{postfix}'] = loss[3]
+        #     loss[f'loss_dice_aux{postfix}'] = loss[4]
+        return loss
+
+    @staticmethod
+    def _get_index(match_indices):
+        """Returns batch indices, source indices, and destination indices from provided match indices."""
+        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
+        src_idx = torch.cat([src for (src, _) in match_indices])
+        dst_idx = torch.cat([dst for (_, dst) in match_indices])
+        return (batch_idx, src_idx), dst_idx
+
+    def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
+        """Assigns predicted bounding boxes to ground truth bounding boxes based on the match indices."""
+        pred_assigned = torch.cat(
+            [
+                t[i] if len(i) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
+                for t, (i, _) in zip(pred_bboxes, match_indices)
+            ]
+        )
+        gt_assigned = torch.cat(
+            [
+                t[j] if len(j) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
+                for t, (_, j) in zip(gt_bboxes, match_indices)
+            ]
+        )
+        return pred_assigned, gt_assigned
+
+    def _get_loss(
+        self,
+        pred_bboxes,
+        pred_scores,
+        gt_bboxes,
+        gt_cls,
+        gt_groups,
+        masks=None,
+        gt_mask=None,
+        postfix="",
+        match_indices=None,
+    ):
+        """Get losses."""
+        if match_indices is None:
+            match_indices = self.matcher(
+                pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=masks, gt_mask=gt_mask
+            )
+
+        idx, gt_idx = self._get_index(match_indices)
+        pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]
+
+        bs, nq = pred_scores.shape[:2]
+        targets = torch.full((bs, nq), self.nc, device=pred_scores.device, dtype=gt_cls.dtype)
+        targets[idx] = gt_cls[gt_idx]
+
+        gt_scores = torch.zeros([bs, nq], device=pred_scores.device)
+        if len(gt_bboxes):
+            gt_scores[idx] = bbox_iou(pred_bboxes.detach(), gt_bboxes, xywh=True).squeeze(-1)
+
+        return {
+            **self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix),
+            **self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix),
+            # **(self._get_loss_mask(masks, gt_mask, match_indices, postfix) if masks is not None and gt_mask is not None else {})
+        }
+
+    def forward(self, pred_bboxes, pred_scores, batch, postfix="", **kwargs):
+        """
+        Calculate loss for predicted bounding boxes and scores.
+
+        Args:
+            pred_bboxes (torch.Tensor): Predicted bounding boxes, shape [l, b, query, 4].
+            pred_scores (torch.Tensor): Predicted class scores, shape [l, b, query, num_classes].
+            batch (dict): Batch information containing:
+                cls (torch.Tensor): Ground truth classes, shape [num_gts].
+                bboxes (torch.Tensor): Ground truth bounding boxes, shape [num_gts, 4].
+                gt_groups (List[int]): Number of ground truths for each image in the batch.
+            postfix (str): Postfix for loss names.
+            **kwargs (Any): Additional arguments, may include 'match_indices'.
+
+        Returns:
+            (dict): Computed losses, including main and auxiliary (if enabled).
+
+        Note:
+            Uses last elements of pred_bboxes and pred_scores for main loss, and the rest for auxiliary losses if
+            self.aux_loss is True.
+        """
+        self.device = pred_bboxes.device
+        match_indices = kwargs.get("match_indices", None)
+        gt_cls, gt_bboxes, gt_groups = batch["cls"], batch["bboxes"], batch["gt_groups"]
+
+        total_loss = self._get_loss(
+            pred_bboxes[-1], pred_scores[-1], gt_bboxes, gt_cls, gt_groups, postfix=postfix, match_indices=match_indices
+        )
+
+        if self.aux_loss:
+            total_loss.update(
+                self._get_loss_aux(
+                    pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, postfix
+                )
+            )
+
+        return total_loss
+
+
+class RTDETRDetectionLoss(DETRLoss):
+    """
+    Real-Time DeepTracker (RT-DETR) Detection Loss class that extends the DETRLoss.
+
+    This class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well as
+    an additional denoising training loss when provided with denoising metadata.
+    """
+
+    def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None):
+        """
+        Forward pass to compute the detection loss.
+
+        Args:
+            preds (tuple): Predicted bounding boxes and scores.
+            batch (dict): Batch data containing ground truth information.
+            dn_bboxes (torch.Tensor, optional): Denoising bounding boxes. Default is None.
+            dn_scores (torch.Tensor, optional): Denoising scores. Default is None.
+            dn_meta (dict, optional): Metadata for denoising. Default is None.
+
+        Returns:
+            (dict): Dictionary containing the total loss and, if applicable, the denoising loss.
+        """
+        pred_bboxes, pred_scores = preds
+        total_loss = super().forward(pred_bboxes, pred_scores, batch)
+
+        # Check for denoising metadata to compute denoising training loss
+        if dn_meta is not None:
+            dn_pos_idx, dn_num_group = dn_meta["dn_pos_idx"], dn_meta["dn_num_group"]
+            assert len(batch["gt_groups"]) == len(dn_pos_idx)
+
+            # Get the match indices for denoising
+            match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch["gt_groups"])
+
+            # Compute the denoising training loss
+            dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix="_dn", match_indices=match_indices)
+            total_loss.update(dn_loss)
+        else:
+            # If no denoising metadata is provided, set denoising loss to zero
+            total_loss.update({f"{k}_dn": torch.tensor(0.0, device=self.device) for k in total_loss.keys()})
+
+        return total_loss
+
+    @staticmethod
+    def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups):
+        """
+        Get the match indices for denoising.
+
+        Args:
+            dn_pos_idx (List[torch.Tensor]): List of tensors containing positive indices for denoising.
+            dn_num_group (int): Number of denoising groups.
+            gt_groups (List[int]): List of integers representing the number of ground truths for each image.
+
+        Returns:
+            (List[tuple]): List of tuples containing matched indices for denoising.
+        """
+        dn_match_indices = []
+        idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
+        for i, num_gt in enumerate(gt_groups):
+            if num_gt > 0:
+                gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
+                gt_idx = gt_idx.repeat(dn_num_group)
+                assert len(dn_pos_idx[i]) == len(gt_idx), "Expected the same length, "
+                f"but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively."
+                dn_match_indices.append((dn_pos_idx[i], gt_idx))
+            else:
+                dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
+        return dn_match_indices

+ 259 - 0
ultralytics/models/utils/ops.py

@@ -0,0 +1,259 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from scipy.optimize import linear_sum_assignment
+
+from ultralytics.utils.metrics import bbox_iou
+from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh
+
+
+class HungarianMatcher(nn.Module):
+    """
+    A module implementing the HungarianMatcher, which is a differentiable module to solve the assignment problem in an
+    end-to-end fashion.
+
+    HungarianMatcher performs optimal assignment over the predicted and ground truth bounding boxes using a cost
+    function that considers classification scores, bounding box coordinates, and optionally, mask predictions.
+
+    Attributes:
+        cost_gain (dict): Dictionary of cost coefficients: 'class', 'bbox', 'giou', 'mask', and 'dice'.
+        use_fl (bool): Indicates whether to use Focal Loss for the classification cost calculation.
+        with_mask (bool): Indicates whether the model makes mask predictions.
+        num_sample_points (int): The number of sample points used in mask cost calculation.
+        alpha (float): The alpha factor in Focal Loss calculation.
+        gamma (float): The gamma factor in Focal Loss calculation.
+
+    Methods:
+        forward(pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None): Computes the
+            assignment between predictions and ground truths for a batch.
+        _cost_mask(bs, num_gts, masks=None, gt_mask=None): Computes the mask cost and dice cost if masks are predicted.
+    """
+
+    def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
+        """Initializes a HungarianMatcher module for optimal assignment of predicted and ground truth bounding boxes."""
+        super().__init__()
+        if cost_gain is None:
+            cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}
+        self.cost_gain = cost_gain
+        self.use_fl = use_fl
+        self.with_mask = with_mask
+        self.num_sample_points = num_sample_points
+        self.alpha = alpha
+        self.gamma = gamma
+
+    def forward(self, pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None):
+        """
+        Forward pass for HungarianMatcher. This function computes costs based on prediction and ground truth
+        (classification cost, L1 cost between boxes and GIoU cost between boxes) and finds the optimal matching between
+        predictions and ground truth based on these costs.
+
+        Args:
+            pred_bboxes (Tensor): Predicted bounding boxes with shape [batch_size, num_queries, 4].
+            pred_scores (Tensor): Predicted scores with shape [batch_size, num_queries, num_classes].
+            gt_cls (torch.Tensor): Ground truth classes with shape [num_gts, ].
+            gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape [num_gts, 4].
+            gt_groups (List[int]): List of length equal to batch size, containing the number of ground truths for
+                each image.
+            masks (Tensor, optional): Predicted masks with shape [batch_size, num_queries, height, width].
+                Defaults to None.
+            gt_mask (List[Tensor], optional): List of ground truth masks, each with shape [num_masks, Height, Width].
+                Defaults to None.
+
+        Returns:
+            (List[Tuple[Tensor, Tensor]]): A list of size batch_size, each element is a tuple (index_i, index_j), where:
+                - index_i is the tensor of indices of the selected predictions (in order)
+                - index_j is the tensor of indices of the corresponding selected ground truth targets (in order)
+                For each batch element, it holds:
+                    len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
+        """
+        bs, nq, nc = pred_scores.shape
+
+        if sum(gt_groups) == 0:
+            return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]
+
+        # We flatten to compute the cost matrices in a batch
+        # [batch_size * num_queries, num_classes]
+        pred_scores = pred_scores.detach().view(-1, nc)
+        pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)
+        # [batch_size * num_queries, 4]
+        pred_bboxes = pred_bboxes.detach().view(-1, 4)
+
+        # Compute the classification cost
+        pred_scores = pred_scores[:, gt_cls]
+        if self.use_fl:
+            neg_cost_class = (1 - self.alpha) * (pred_scores**self.gamma) * (-(1 - pred_scores + 1e-8).log())
+            pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log())
+            cost_class = pos_cost_class - neg_cost_class
+        else:
+            cost_class = -pred_scores
+
+        # Compute the L1 cost between boxes
+        cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1)  # (bs*num_queries, num_gt)
+
+        # Compute the GIoU cost between boxes, (bs*num_queries, num_gt)
+        cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)
+
+        # Final cost matrix
+        C = (
+            self.cost_gain["class"] * cost_class
+            + self.cost_gain["bbox"] * cost_bbox
+            + self.cost_gain["giou"] * cost_giou
+        )
+        # Compute the mask cost and dice cost
+        if self.with_mask:
+            C += self._cost_mask(bs, gt_groups, masks, gt_mask)
+
+        # Set invalid values (NaNs and infinities) to 0 (fixes ValueError: matrix contains invalid numeric entries)
+        C[C.isnan() | C.isinf()] = 0.0
+
+        C = C.view(bs, nq, -1).cpu()
+        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
+        gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)  # (idx for queries, idx for gt)
+        return [
+            (torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k])
+            for k, (i, j) in enumerate(indices)
+        ]
+
+    # This function is for future RT-DETR Segment models
+    # def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
+    #     assert masks is not None and gt_mask is not None, 'Make sure the input has `mask` and `gt_mask`'
+    #     # all masks share the same set of points for efficient matching
+    #     sample_points = torch.rand([bs, 1, self.num_sample_points, 2])
+    #     sample_points = 2.0 * sample_points - 1.0
+    #
+    #     out_mask = F.grid_sample(masks.detach(), sample_points, align_corners=False).squeeze(-2)
+    #     out_mask = out_mask.flatten(0, 1)
+    #
+    #     tgt_mask = torch.cat(gt_mask).unsqueeze(1)
+    #     sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
+    #     tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2])
+    #
+    #     with torch.amp.autocast("cuda", enabled=False):
+    #         # binary cross entropy cost
+    #         pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
+    #         neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')
+    #         cost_mask = torch.matmul(pos_cost_mask, tgt_mask.T) + torch.matmul(neg_cost_mask, 1 - tgt_mask.T)
+    #         cost_mask /= self.num_sample_points
+    #
+    #         # dice cost
+    #         out_mask = F.sigmoid(out_mask)
+    #         numerator = 2 * torch.matmul(out_mask, tgt_mask.T)
+    #         denominator = out_mask.sum(-1, keepdim=True) + tgt_mask.sum(-1).unsqueeze(0)
+    #         cost_dice = 1 - (numerator + 1) / (denominator + 1)
+    #
+    #         C = self.cost_gain['mask'] * cost_mask + self.cost_gain['dice'] * cost_dice
+    #     return C
+
+
+def get_cdn_group(
+    batch, num_classes, num_queries, class_embed, num_dn=100, cls_noise_ratio=0.5, box_noise_scale=1.0, training=False
+):
+    """
+    Get contrastive denoising training group. This function creates a contrastive denoising training group with positive
+    and negative samples from the ground truths (gt). It applies noise to the class labels and bounding box coordinates,
+    and returns the modified labels, bounding boxes, attention mask and meta information.
+
+    Args:
+        batch (dict): A dict that includes 'gt_cls' (torch.Tensor with shape [num_gts, ]), 'gt_bboxes'
+            (torch.Tensor with shape [num_gts, 4]), 'gt_groups' (List(int)) which is a list of batch size length
+            indicating the number of gts of each image.
+        num_classes (int): Number of classes.
+        num_queries (int): Number of queries.
+        class_embed (torch.Tensor): Embedding weights to map class labels to embedding space.
+        num_dn (int, optional): Number of denoising. Defaults to 100.
+        cls_noise_ratio (float, optional): Noise ratio for class labels. Defaults to 0.5.
+        box_noise_scale (float, optional): Noise scale for bounding box coordinates. Defaults to 1.0.
+        training (bool, optional): If it's in training mode. Defaults to False.
+
+    Returns:
+        (Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Dict]]): The modified class embeddings,
+            bounding boxes, attention mask and meta information for denoising. If not in training mode or 'num_dn'
+            is less than or equal to 0, the function returns None for all elements in the tuple.
+    """
+    if (not training) or num_dn <= 0:
+        return None, None, None, None
+    gt_groups = batch["gt_groups"]
+    total_num = sum(gt_groups)
+    max_nums = max(gt_groups)
+    if max_nums == 0:
+        return None, None, None, None
+
+    num_group = num_dn // max_nums
+    num_group = 1 if num_group == 0 else num_group
+    # Pad gt to max_num of a batch
+    bs = len(gt_groups)
+    gt_cls = batch["cls"]  # (bs*num, )
+    gt_bbox = batch["bboxes"]  # bs*num, 4
+    b_idx = batch["batch_idx"]
+
+    # Each group has positive and negative queries.
+    dn_cls = gt_cls.repeat(2 * num_group)  # (2*num_group*bs*num, )
+    dn_bbox = gt_bbox.repeat(2 * num_group, 1)  # 2*num_group*bs*num, 4
+    dn_b_idx = b_idx.repeat(2 * num_group).view(-1)  # (2*num_group*bs*num, )
+
+    # Positive and negative mask
+    # (bs*num*num_group, ), the second total_num*num_group part as negative samples
+    neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num
+
+    if cls_noise_ratio > 0:
+        # Half of bbox prob
+        mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
+        idx = torch.nonzero(mask).squeeze(-1)
+        # Randomly put a new one here
+        new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
+        dn_cls[idx] = new_label
+
+    if box_noise_scale > 0:
+        known_bbox = xywh2xyxy(dn_bbox)
+
+        diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale  # 2*num_group*bs*num, 4
+
+        rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0
+        rand_part = torch.rand_like(dn_bbox)
+        rand_part[neg_idx] += 1.0
+        rand_part *= rand_sign
+        known_bbox += rand_part * diff
+        known_bbox.clip_(min=0.0, max=1.0)
+        dn_bbox = xyxy2xywh(known_bbox)
+        dn_bbox = torch.logit(dn_bbox, eps=1e-6)  # inverse sigmoid
+
+    num_dn = int(max_nums * 2 * num_group)  # total denoising queries
+    # class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)])
+    dn_cls_embed = class_embed[dn_cls]  # bs*num * 2 * num_group, 256
+    padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
+    padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)
+
+    map_indices = torch.cat([torch.tensor(range(num), dtype=torch.long) for num in gt_groups])
+    pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0)
+
+    map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)])
+    padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed
+    padding_bbox[(dn_b_idx, map_indices)] = dn_bbox
+
+    tgt_size = num_dn + num_queries
+    attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)
+    # Match query cannot see the reconstruct
+    attn_mask[num_dn:, :num_dn] = True
+    # Reconstruct cannot see each other
+    for i in range(num_group):
+        if i == 0:
+            attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
+        if i == num_group - 1:
+            attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * i * 2] = True
+        else:
+            attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
+            attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * 2 * i] = True
+    dn_meta = {
+        "dn_pos_idx": [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
+        "dn_num_group": num_group,
+        "dn_num_split": [num_dn, num_queries],
+    }
+
+    return (
+        padding_cls.to(class_embed.device),
+        padding_bbox.to(class_embed.device),
+        attn_mask.to(class_embed.device),
+        dn_meta,
+    )

+ 7 - 0
ultralytics/models/yolo/__init__.py

@@ -0,0 +1,7 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from ultralytics.models.yolo import classify, detect, obb, pose, segment, world
+
+from .model import YOLO, YOLOWorld
+
+__all__ = "classify", "segment", "detect", "pose", "obb", "world", "YOLO", "YOLOWorld"

+ 7 - 0
ultralytics/models/yolo/classify/__init__.py

@@ -0,0 +1,7 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from ultralytics.models.yolo.classify.predict import ClassificationPredictor
+from ultralytics.models.yolo.classify.train import ClassificationTrainer
+from ultralytics.models.yolo.classify.val import ClassificationValidator
+
+__all__ = "ClassificationPredictor", "ClassificationTrainer", "ClassificationValidator"

+ 60 - 0
ultralytics/models/yolo/classify/predict.py

@@ -0,0 +1,60 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import cv2
+import torch
+from PIL import Image
+
+from ultralytics.engine.predictor import BasePredictor
+from ultralytics.engine.results import Results
+from ultralytics.utils import DEFAULT_CFG, ops
+
+
+class ClassificationPredictor(BasePredictor):
+    """
+    A class extending the BasePredictor class for prediction based on a classification model.
+
+    Notes:
+        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
+
+    Example:
+        ```python
+        from ultralytics.utils import ASSETS
+        from ultralytics.models.yolo.classify import ClassificationPredictor
+
+        args = dict(model="yolov8n-cls.pt", source=ASSETS)
+        predictor = ClassificationPredictor(overrides=args)
+        predictor.predict_cli()
+        ```
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initializes ClassificationPredictor setting the task to 'classify'."""
+        super().__init__(cfg, overrides, _callbacks)
+        self.args.task = "classify"
+        self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor"
+
+    def preprocess(self, img):
+        """Converts input image to model-compatible data type."""
+        if not isinstance(img, torch.Tensor):
+            is_legacy_transform = any(
+                self._legacy_transform_name in str(transform) for transform in self.transforms.transforms
+            )
+            if is_legacy_transform:  # to handle legacy transforms
+                img = torch.stack([self.transforms(im) for im in img], dim=0)
+            else:
+                img = torch.stack(
+                    [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
+                )
+        img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
+        return img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32
+
+    def postprocess(self, preds, img, orig_imgs):
+        """Post-processes predictions to return Results objects."""
+        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
+            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+
+        preds = preds[0] if isinstance(preds, (list, tuple)) else preds
+        return [
+            Results(orig_img, path=img_path, names=self.model.names, probs=pred)
+            for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
+        ]

+ 153 - 0
ultralytics/models/yolo/classify/train.py

@@ -0,0 +1,153 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from copy import copy
+
+import torch
+
+from ultralytics.data import ClassificationDataset, build_dataloader
+from ultralytics.engine.trainer import BaseTrainer
+from ultralytics.models import yolo
+from ultralytics.nn.tasks import ClassificationModel
+from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
+from ultralytics.utils.plotting import plot_images, plot_results
+from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
+
+
+class ClassificationTrainer(BaseTrainer):
+    """
+    A class extending the BaseTrainer class for training based on a classification model.
+
+    Notes:
+        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.classify import ClassificationTrainer
+
+        args = dict(model="yolov8n-cls.pt", data="imagenet10", epochs=3)
+        trainer = ClassificationTrainer(overrides=args)
+        trainer.train()
+        ```
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initialize a ClassificationTrainer object with optional configuration overrides and callbacks."""
+        if overrides is None:
+            overrides = {}
+        overrides["task"] = "classify"
+        if overrides.get("imgsz") is None:
+            overrides["imgsz"] = 224
+        super().__init__(cfg, overrides, _callbacks)
+
+    def set_model_attributes(self):
+        """Set the YOLO model's class names from the loaded dataset."""
+        self.model.names = self.data["names"]
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """Returns a modified PyTorch model configured for training YOLO."""
+        model = ClassificationModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
+        if weights:
+            model.load(weights)
+
+        for m in model.modules():
+            if not self.args.pretrained and hasattr(m, "reset_parameters"):
+                m.reset_parameters()
+            if isinstance(m, torch.nn.Dropout) and self.args.dropout:
+                m.p = self.args.dropout  # set dropout
+        for p in model.parameters():
+            p.requires_grad = True  # for training
+        return model
+
+    def setup_model(self):
+        """Load, create or download model for any task."""
+        import torchvision  # scope for faster 'import ultralytics'
+
+        if str(self.model) in torchvision.models.__dict__:
+            self.model = torchvision.models.__dict__[self.model](
+                weights="IMAGENET1K_V1" if self.args.pretrained else None
+            )
+            ckpt = None
+        else:
+            ckpt = super().setup_model()
+        ClassificationModel.reshape_outputs(self.model, self.data["nc"])
+        return ckpt
+
+    def build_dataset(self, img_path, mode="train", batch=None):
+        """Creates a ClassificationDataset instance given an image path, and mode (train/test etc.)."""
+        return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
+
+    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
+        """Returns PyTorch DataLoader with transforms to preprocess images for inference."""
+        with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
+            dataset = self.build_dataset(dataset_path, mode)
+
+        loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)
+        # Attach inference transforms
+        if mode != "train":
+            if is_parallel(self.model):
+                self.model.module.transforms = loader.dataset.torch_transforms
+            else:
+                self.model.transforms = loader.dataset.torch_transforms
+        return loader
+
+    def preprocess_batch(self, batch):
+        """Preprocesses a batch of images and classes."""
+        batch["img"] = batch["img"].to(self.device)
+        batch["cls"] = batch["cls"].to(self.device)
+        return batch
+
+    def progress_string(self):
+        """Returns a formatted string showing training progress."""
+        return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
+            "Epoch",
+            "GPU_mem",
+            *self.loss_names,
+            "Instances",
+            "Size",
+        )
+
+    def get_validator(self):
+        """Returns an instance of ClassificationValidator for validation."""
+        self.loss_names = ["loss"]
+        return yolo.classify.ClassificationValidator(
+            self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+        )
+
+    def label_loss_items(self, loss_items=None, prefix="train"):
+        """
+        Returns a loss dict with labelled training loss items tensor.
+
+        Not needed for classification but necessary for segmentation & detection
+        """
+        keys = [f"{prefix}/{x}" for x in self.loss_names]
+        if loss_items is None:
+            return keys
+        loss_items = [round(float(loss_items), 5)]
+        return dict(zip(keys, loss_items))
+
+    def plot_metrics(self):
+        """Plots metrics from a CSV file."""
+        plot_results(file=self.csv, classify=True, on_plot=self.on_plot)  # save results.png
+
+    def final_eval(self):
+        """Evaluate trained model and save validation results."""
+        for f in self.last, self.best:
+            if f.exists():
+                strip_optimizer(f)  # strip optimizers
+                if f is self.best:
+                    LOGGER.info(f"\nValidating {f}...")
+                    self.validator.args.data = self.args.data
+                    self.validator.args.plots = self.args.plots
+                    self.metrics = self.validator(model=f)
+                    self.metrics.pop("fitness", None)
+                    self.run_callbacks("on_fit_epoch_end")
+
+    def plot_training_samples(self, batch, ni):
+        """Plots training samples with their annotations."""
+        plot_images(
+            images=batch["img"],
+            batch_idx=torch.arange(len(batch["img"])),
+            cls=batch["cls"].view(-1),  # warning: use .view(), not .squeeze() for Classify models
+            fname=self.save_dir / f"train_batch{ni}.jpg",
+            on_plot=self.on_plot,
+        )

+ 117 - 0
ultralytics/models/yolo/classify/val.py

@@ -0,0 +1,117 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import torch
+
+from ultralytics.data import ClassificationDataset, build_dataloader
+from ultralytics.engine.validator import BaseValidator
+from ultralytics.utils import LOGGER
+from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
+from ultralytics.utils.plotting import plot_images
+
+
+class ClassificationValidator(BaseValidator):
+    """
+    A class extending the BaseValidator class for validation based on a classification model.
+
+    Notes:
+        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.classify import ClassificationValidator
+
+        args = dict(model="yolov8n-cls.pt", data="imagenet10")
+        validator = ClassificationValidator(args=args)
+        validator()
+        ```
+    """
+
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
+        """Initializes ClassificationValidator instance with args, dataloader, save_dir, and progress bar."""
+        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
+        self.targets = None
+        self.pred = None
+        self.args.task = "classify"
+        self.metrics = ClassifyMetrics()
+
+    def get_desc(self):
+        """Returns a formatted string summarizing classification metrics."""
+        return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc")
+
+    def init_metrics(self, model):
+        """Initialize confusion matrix, class names, and top-1 and top-5 accuracy."""
+        self.names = model.names
+        self.nc = len(model.names)
+        self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, task="classify")
+        self.pred = []
+        self.targets = []
+
+    def preprocess(self, batch):
+        """Preprocesses input batch and returns it."""
+        batch["img"] = batch["img"].to(self.device, non_blocking=True)
+        batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
+        batch["cls"] = batch["cls"].to(self.device)
+        return batch
+
+    def update_metrics(self, preds, batch):
+        """Updates running metrics with model predictions and batch targets."""
+        n5 = min(len(self.names), 5)
+        self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
+        self.targets.append(batch["cls"].type(torch.int32).cpu())
+
+    def finalize_metrics(self, *args, **kwargs):
+        """Finalizes metrics of the model such as confusion_matrix and speed."""
+        self.confusion_matrix.process_cls_preds(self.pred, self.targets)
+        if self.args.plots:
+            for normalize in True, False:
+                self.confusion_matrix.plot(
+                    save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
+                )
+        self.metrics.speed = self.speed
+        self.metrics.confusion_matrix = self.confusion_matrix
+        self.metrics.save_dir = self.save_dir
+
+    def postprocess(self, preds):
+        """Preprocesses the classification predictions."""
+        return preds[0] if isinstance(preds, (list, tuple)) else preds
+
+    def get_stats(self):
+        """Returns a dictionary of metrics obtained by processing targets and predictions."""
+        self.metrics.process(self.targets, self.pred)
+        return self.metrics.results_dict
+
+    def build_dataset(self, img_path):
+        """Creates and returns a ClassificationDataset instance using given image path and preprocessing parameters."""
+        return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
+
+    def get_dataloader(self, dataset_path, batch_size):
+        """Builds and returns a data loader for classification tasks with given parameters."""
+        dataset = self.build_dataset(dataset_path)
+        return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)
+
+    def print_results(self):
+        """Prints evaluation metrics for YOLO object detection model."""
+        pf = "%22s" + "%11.3g" * len(self.metrics.keys)  # print format
+        LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))
+
+    def plot_val_samples(self, batch, ni):
+        """Plot validation image samples."""
+        plot_images(
+            images=batch["img"],
+            batch_idx=torch.arange(len(batch["img"])),
+            cls=batch["cls"].view(-1),  # warning: use .view(), not .squeeze() for Classify models
+            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
+            names=self.names,
+            on_plot=self.on_plot,
+        )
+
+    def plot_predictions(self, batch, preds, ni):
+        """Plots predicted bounding boxes on input images and saves the result."""
+        plot_images(
+            batch["img"],
+            batch_idx=torch.arange(len(batch["img"])),
+            cls=torch.argmax(preds, dim=1),
+            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
+            names=self.names,
+            on_plot=self.on_plot,
+        )  # pred

+ 7 - 0
ultralytics/models/yolo/detect/__init__.py

@@ -0,0 +1,7 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from .predict import DetectionPredictor
+from .train import DetectionTrainer
+from .val import DetectionValidator
+
+__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator"

+ 41 - 0
ultralytics/models/yolo/detect/predict.py

@@ -0,0 +1,41 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from ultralytics.engine.predictor import BasePredictor
+from ultralytics.engine.results import Results
+from ultralytics.utils import ops
+
+
+class DetectionPredictor(BasePredictor):
+    """
+    A class extending the BasePredictor class for prediction based on a detection model.
+
+    Example:
+        ```python
+        from ultralytics.utils import ASSETS
+        from ultralytics.models.yolo.detect import DetectionPredictor
+
+        args = dict(model="yolo11n.pt", source=ASSETS)
+        predictor = DetectionPredictor(overrides=args)
+        predictor.predict_cli()
+        ```
+    """
+
+    def postprocess(self, preds, img, orig_imgs):
+        """Post-processes predictions and returns a list of Results objects."""
+        preds = ops.non_max_suppression(
+            preds,
+            self.args.conf,
+            self.args.iou,
+            agnostic=self.args.agnostic_nms,
+            max_det=self.args.max_det,
+            classes=self.args.classes,
+        )
+
+        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
+            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+
+        results = []
+        for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
+            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
+        return results

+ 150 - 0
ultralytics/models/yolo/detect/train.py

@@ -0,0 +1,150 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import math
+import random
+from copy import copy
+
+import numpy as np
+import torch.nn as nn
+
+from ultralytics.data import build_dataloader, build_yolo_dataset
+from ultralytics.engine.trainer import BaseTrainer
+from ultralytics.models import yolo
+from ultralytics.nn.tasks import DetectionModel
+from ultralytics.utils import LOGGER, RANK
+from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
+from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first
+
+
+class DetectionTrainer(BaseTrainer):
+    """
+    A class extending the BaseTrainer class for training based on a detection model.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.detect import DetectionTrainer
+
+        args = dict(model="yolo11n.pt", data="coco8.yaml", epochs=3)
+        trainer = DetectionTrainer(overrides=args)
+        trainer.train()
+        ```
+    """
+
+    def build_dataset(self, img_path, mode="train", batch=None):
+        """
+        Build YOLO Dataset.
+
+        Args:
+            img_path (str): Path to the folder containing images.
+            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
+            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
+        """
+        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
+        return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
+
+    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
+        """Construct and return dataloader."""
+        assert mode in {"train", "val"}, f"Mode must be 'train' or 'val', not {mode}."
+        with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
+            dataset = self.build_dataset(dataset_path, mode, batch_size)
+        shuffle = mode == "train"
+        if getattr(dataset, "rect", False) and shuffle:
+            LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
+            shuffle = False
+        workers = self.args.workers if mode == "train" else self.args.workers * 2
+        return build_dataloader(dataset, batch_size, workers, shuffle, rank)  # return dataloader
+
+    def preprocess_batch(self, batch):
+        """Preprocesses a batch of images by scaling and converting to float."""
+        batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
+        if self.args.multi_scale:
+            imgs = batch["img"]
+            sz = (
+                random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))
+                // self.stride
+                * self.stride
+            )  # size
+            sf = sz / max(imgs.shape[2:])  # scale factor
+            if sf != 1:
+                ns = [
+                    math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]
+                ]  # new shape (stretched to gs-multiple)
+                imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
+            batch["img"] = imgs
+        return batch
+
+    def set_model_attributes(self):
+        """Nl = de_parallel(self.model).model[-1].nl  # number of detection layers (to scale hyps)."""
+        # self.args.box *= 3 / nl  # scale to layers
+        # self.args.cls *= self.data["nc"] / 80 * 3 / nl  # scale to classes and layers
+        # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
+        self.model.nc = self.data["nc"]  # attach number of classes to model
+        self.model.names = self.data["names"]  # attach class names to model
+        self.model.args = self.args  # attach hyperparameters to model
+        # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """Return a YOLO detection model."""
+        model = DetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
+        if weights:
+            model.load(weights)
+        return model
+
+    def get_validator(self):
+        """Returns a DetectionValidator for YOLO model validation."""
+        self.loss_names = "box_loss", "cls_loss", "dfl_loss"
+        return yolo.detect.DetectionValidator(
+            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+        )
+
+    def label_loss_items(self, loss_items=None, prefix="train"):
+        """
+        Returns a loss dict with labelled training loss items tensor.
+
+        Not needed for classification but necessary for segmentation & detection
+        """
+        keys = [f"{prefix}/{x}" for x in self.loss_names]
+        if loss_items is not None:
+            loss_items = [round(float(x), 5) for x in loss_items]  # convert tensors to 5 decimal place floats
+            return dict(zip(keys, loss_items))
+        else:
+            return keys
+
+    def progress_string(self):
+        """Returns a formatted string of training progress with epoch, GPU memory, loss, instances and size."""
+        return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
+            "Epoch",
+            "GPU_mem",
+            *self.loss_names,
+            "Instances",
+            "Size",
+        )
+
+    def plot_training_samples(self, batch, ni):
+        """Plots training samples with their annotations."""
+        plot_images(
+            images=batch["img"],
+            batch_idx=batch["batch_idx"],
+            cls=batch["cls"].squeeze(-1),
+            bboxes=batch["bboxes"],
+            paths=batch["im_file"],
+            fname=self.save_dir / f"train_batch{ni}.jpg",
+            on_plot=self.on_plot,
+        )
+
+    def plot_metrics(self):
+        """Plots metrics from a CSV file."""
+        plot_results(file=self.csv, on_plot=self.on_plot)  # save results.png
+
+    def plot_training_labels(self):
+        """Create a labeled training plot of the YOLO model."""
+        boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0)
+        cls = np.concatenate([lb["cls"] for lb in self.train_loader.dataset.labels], 0)
+        plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
+
+    def auto_batch(self):
+        """Get batch size by calculating memory occupation of model."""
+        train_dataset = self.build_dataset(self.trainset, mode="train", batch=16)
+        # 4 for mosaic augmentation
+        max_num_obj = max(len(label["cls"]) for label in train_dataset.labels) * 4
+        return super().auto_batch(max_num_obj)

+ 337 - 0
ultralytics/models/yolo/detect/val.py

@@ -0,0 +1,337 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import os
+from pathlib import Path
+
+import numpy as np
+import torch
+
+from ultralytics.data import build_dataloader, build_yolo_dataset, converter
+from ultralytics.engine.validator import BaseValidator
+from ultralytics.utils import LOGGER, ops
+from ultralytics.utils.checks import check_requirements
+from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
+from ultralytics.utils.plotting import output_to_target, plot_images
+
+
+class DetectionValidator(BaseValidator):
+    """
+    A class extending the BaseValidator class for validation based on a detection model.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.detect import DetectionValidator
+
+        args = dict(model="yolo11n.pt", data="coco8.yaml")
+        validator = DetectionValidator(args=args)
+        validator()
+        ```
+    """
+
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
+        """Initialize detection model with necessary variables and settings."""
+        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
+        self.nt_per_class = None
+        self.nt_per_image = None
+        self.is_coco = False
+        self.is_lvis = False
+        self.class_map = None
+        self.args.task = "detect"
+        self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
+        self.iouv = torch.linspace(0.5, 0.95, 10)  # IoU vector for mAP@0.5:0.95
+        self.niou = self.iouv.numel()
+        self.lb = []  # for autolabelling
+        if self.args.save_hybrid:
+            LOGGER.warning(
+                "WARNING ⚠️ 'save_hybrid=True' will append ground truth to predictions for autolabelling.\n"
+                "WARNING ⚠️ 'save_hybrid=True' will cause incorrect mAP.\n"
+            )
+
+    def preprocess(self, batch):
+        """Preprocesses batch of images for YOLO training."""
+        batch["img"] = batch["img"].to(self.device, non_blocking=True)
+        batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
+        for k in ["batch_idx", "cls", "bboxes"]:
+            batch[k] = batch[k].to(self.device)
+
+        if self.args.save_hybrid:
+            height, width = batch["img"].shape[2:]
+            nb = len(batch["img"])
+            bboxes = batch["bboxes"] * torch.tensor((width, height, width, height), device=self.device)
+            self.lb = [
+                torch.cat([batch["cls"][batch["batch_idx"] == i], bboxes[batch["batch_idx"] == i]], dim=-1)
+                for i in range(nb)
+            ]
+
+        return batch
+
+    def init_metrics(self, model):
+        """Initialize evaluation metrics for YOLO."""
+        val = self.data.get(self.args.split, "")  # validation path
+        self.is_coco = (
+            isinstance(val, str)
+            and "coco" in val
+            and (val.endswith(f"{os.sep}val2017.txt") or val.endswith(f"{os.sep}test-dev2017.txt"))
+        )  # is COCO
+        self.is_lvis = isinstance(val, str) and "lvis" in val and not self.is_coco  # is LVIS
+        self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1, len(model.names) + 1))
+        self.args.save_json |= self.args.val and (self.is_coco or self.is_lvis) and not self.training  # run final val
+        self.names = model.names
+        self.nc = len(model.names)
+        self.metrics.names = self.names
+        self.metrics.plot = self.args.plots
+        self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf)
+        self.seen = 0
+        self.jdict = []
+        self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
+
+    def get_desc(self):
+        """Return a formatted string summarizing class metrics of YOLO model."""
+        return ("%22s" + "%11s" * 7) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP75","mAP50-95)")
+
+    def postprocess(self, preds):
+        """Apply Non-maximum suppression to prediction outputs."""
+        return ops.non_max_suppression(
+            preds,
+            self.args.conf,
+            self.args.iou,
+            labels=self.lb,
+            multi_label=True,
+            agnostic=self.args.single_cls or self.args.agnostic_nms,
+            max_det=self.args.max_det,
+        )
+
+    def _prepare_batch(self, si, batch):
+        """Prepares a batch of images and annotations for validation."""
+        idx = batch["batch_idx"] == si
+        cls = batch["cls"][idx].squeeze(-1)
+        bbox = batch["bboxes"][idx]
+        ori_shape = batch["ori_shape"][si]
+        imgsz = batch["img"].shape[2:]
+        ratio_pad = batch["ratio_pad"][si]
+        if len(cls):
+            bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]  # target boxes
+            ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad)  # native-space labels
+        return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
+
+    def _prepare_pred(self, pred, pbatch):
+        """Prepares a batch of images and annotations for validation."""
+        predn = pred.clone()
+        ops.scale_boxes(
+            pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
+        )  # native-space pred
+        return predn
+
+    def update_metrics(self, preds, batch):
+        """Metrics."""
+        for si, pred in enumerate(preds):
+            self.seen += 1
+            npr = len(pred)
+            stat = dict(
+                conf=torch.zeros(0, device=self.device),
+                pred_cls=torch.zeros(0, device=self.device),
+                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+            )
+            pbatch = self._prepare_batch(si, batch)
+            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
+            nl = len(cls)
+            stat["target_cls"] = cls
+            stat["target_img"] = cls.unique()
+            if npr == 0:
+                if nl:
+                    for k in self.stats.keys():
+                        self.stats[k].append(stat[k])
+                    if self.args.plots:
+                        self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
+                continue
+
+            # Predictions
+            if self.args.single_cls:
+                pred[:, 5] = 0
+            predn = self._prepare_pred(pred, pbatch)
+            stat["conf"] = predn[:, 4]
+            stat["pred_cls"] = predn[:, 5]
+
+            # Evaluate
+            if nl:
+                stat["tp"] = self._process_batch(predn, bbox, cls)
+            if self.args.plots:
+                self.confusion_matrix.process_batch(predn, bbox, cls)
+            for k in self.stats.keys():
+                self.stats[k].append(stat[k])
+
+            # Save
+            if self.args.save_json:
+                self.pred_to_json(predn, batch["im_file"][si])
+            if self.args.save_txt:
+                self.save_one_txt(
+                    predn,
+                    self.args.save_conf,
+                    pbatch["ori_shape"],
+                    self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
+                )
+
+    def finalize_metrics(self, *args, **kwargs):
+        """Set final values for metrics speed and confusion matrix."""
+        self.metrics.speed = self.speed
+        self.metrics.confusion_matrix = self.confusion_matrix
+
+    def get_stats(self):
+        """Returns metrics statistics and results dictionary."""
+        stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()}  # to numpy
+        self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=self.nc)
+        self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=self.nc)
+        stats.pop("target_img", None)
+        if len(stats) and stats["tp"].any():
+            self.metrics.process(**stats)
+        return self.metrics.results_dict
+
+    def print_results(self):
+        """Prints training/validation set metrics per class."""
+        pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys)  # print format
+        LOGGER.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
+        if self.nt_per_class.sum() == 0:
+            LOGGER.warning(f"WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels")
+
+        # Print results per class
+        if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
+            for i, c in enumerate(self.metrics.ap_class_index):
+                LOGGER.info(
+                    pf % (self.names[c], self.nt_per_image[c], self.nt_per_class[c], *self.metrics.class_result(i))
+                )
+
+        if self.args.plots:
+            for normalize in True, False:
+                self.confusion_matrix.plot(
+                    save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
+                )
+
+    def _process_batch(self, detections, gt_bboxes, gt_cls):
+        """
+        Return correct prediction matrix.
+
+        Args:
+            detections (torch.Tensor): Tensor of shape (N, 6) representing detections where each detection is
+                (x1, y1, x2, y2, conf, class).
+            gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground-truth bounding box coordinates. Each
+                bounding box is of the format: (x1, y1, x2, y2).
+            gt_cls (torch.Tensor): Tensor of shape (M,) representing target class indices.
+
+        Returns:
+            (torch.Tensor): Correct prediction matrix of shape (N, 10) for 10 IoU levels.
+
+        Note:
+            The function does not return any value directly usable for metrics calculation. Instead, it provides an
+            intermediate representation used for evaluating predictions against ground truth.
+        """
+        iou = box_iou(gt_bboxes, detections[:, :4])
+        return self.match_predictions(detections[:, 5], gt_cls, iou)
+
+    def build_dataset(self, img_path, mode="val", batch=None):
+        """
+        Build YOLO Dataset.
+
+        Args:
+            img_path (str): Path to the folder containing images.
+            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
+            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
+        """
+        return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)
+
+    def get_dataloader(self, dataset_path, batch_size):
+        """Construct and return dataloader."""
+        dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
+        return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1)  # return dataloader
+
+    def plot_val_samples(self, batch, ni):
+        """Plot validation image samples."""
+        plot_images(
+            batch["img"],
+            batch["batch_idx"],
+            batch["cls"].squeeze(-1),
+            batch["bboxes"],
+            paths=batch["im_file"],
+            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
+            names=self.names,
+            on_plot=self.on_plot,
+        )
+
+    def plot_predictions(self, batch, preds, ni):
+        """Plots predicted bounding boxes on input images and saves the result."""
+        plot_images(
+            batch["img"],
+            *output_to_target(preds, max_det=self.args.max_det),
+            paths=batch["im_file"],
+            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
+            names=self.names,
+            on_plot=self.on_plot,
+        )  # pred
+
+    def save_one_txt(self, predn, save_conf, shape, file):
+        """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
+        from ultralytics.engine.results import Results
+
+        Results(
+            np.zeros((shape[0], shape[1]), dtype=np.uint8),
+            path=None,
+            names=self.names,
+            boxes=predn[:, :6],
+        ).save_txt(file, save_conf=save_conf)
+
+    def pred_to_json(self, predn, filename):
+        """Serialize YOLO predictions to COCO json format."""
+        stem = Path(filename).stem
+        image_id = int(stem) if stem.isnumeric() else stem
+        box = ops.xyxy2xywh(predn[:, :4])  # xywh
+        box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
+        for p, b in zip(predn.tolist(), box.tolist()):
+            self.jdict.append(
+                {
+                    "image_id": image_id,
+                    "category_id": self.class_map[int(p[5])],
+                    "bbox": [round(x, 3) for x in b],
+                    "score": round(p[4], 5),
+                }
+            )
+
+    def eval_json(self, stats):
+        """Evaluates YOLO output in JSON format and returns performance statistics."""
+        if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
+            pred_json = self.save_dir / "predictions.json"  # predictions
+            anno_json = (
+                self.data["path"]
+                / "annotations"
+                / ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
+            )  # annotations
+            pkg = "pycocotools" if self.is_coco else "lvis"
+            LOGGER.info(f"\nEvaluating {pkg} mAP using {pred_json} and {anno_json}...")
+            try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
+                for x in pred_json, anno_json:
+                    assert x.is_file(), f"{x} file not found"
+                check_requirements("pycocotools>=2.0.6" if self.is_coco else "lvis>=0.5.3")
+                if self.is_coco:
+                    from pycocotools.coco import COCO  # noqa
+                    from pycocotools.cocoeval import COCOeval  # noqa
+
+                    anno = COCO(str(anno_json))  # init annotations api
+                    pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
+                    val = COCOeval(anno, pred, "bbox")
+                else:
+                    from lvis import LVIS, LVISEval
+
+                    anno = LVIS(str(anno_json))  # init annotations api
+                    pred = anno._load_json(str(pred_json))  # init predictions api (must pass string, not Path)
+                    val = LVISEval(anno, pred, "bbox")
+                val.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # images to eval
+                val.evaluate()
+                val.accumulate()
+                val.summarize()
+                if self.is_lvis:
+                    val.print_results()  # explicitly call print_results
+                # update mAP50-95 and mAP50
+                stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = (
+                    val.stats[:2] if self.is_coco else [val.results["AP50"], val.results["AP"]]
+                )
+            except Exception as e:
+                LOGGER.warning(f"{pkg} unable to run: {e}")
+        return stats

+ 111 - 0
ultralytics/models/yolo/model.py

@@ -0,0 +1,111 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from pathlib import Path
+
+from ultralytics.engine.model import Model
+from ultralytics.models import yolo
+from ultralytics.nn.tasks import ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel
+from ultralytics.utils import ROOT, yaml_load
+
+
+class YOLO(Model):
+    """YOLO (You Only Look Once) object detection model."""
+
+    def __init__(self, model="yolo11n.pt", task=None, verbose=False):
+        """Initialize YOLO model, switching to YOLOWorld if model filename contains '-world'."""
+        path = Path(model)
+        if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}:  # if YOLOWorld PyTorch model
+            new_instance = YOLOWorld(path, verbose=verbose)
+            self.__class__ = type(new_instance)
+            self.__dict__ = new_instance.__dict__
+        else:
+            # Continue with default YOLO initialization
+            super().__init__(model=model, task=task, verbose=verbose)
+
+    @property
+    def task_map(self):
+        """Map head to model, trainer, validator, and predictor classes."""
+        return {
+            "classify": {
+                "model": ClassificationModel,
+                "trainer": yolo.classify.ClassificationTrainer,
+                "validator": yolo.classify.ClassificationValidator,
+                "predictor": yolo.classify.ClassificationPredictor,
+            },
+            "detect": {
+                "model": DetectionModel,
+                "trainer": yolo.detect.DetectionTrainer,
+                "validator": yolo.detect.DetectionValidator,
+                "predictor": yolo.detect.DetectionPredictor,
+            },
+            "segment": {
+                "model": SegmentationModel,
+                "trainer": yolo.segment.SegmentationTrainer,
+                "validator": yolo.segment.SegmentationValidator,
+                "predictor": yolo.segment.SegmentationPredictor,
+            },
+            "pose": {
+                "model": PoseModel,
+                "trainer": yolo.pose.PoseTrainer,
+                "validator": yolo.pose.PoseValidator,
+                "predictor": yolo.pose.PosePredictor,
+            },
+            "obb": {
+                "model": OBBModel,
+                "trainer": yolo.obb.OBBTrainer,
+                "validator": yolo.obb.OBBValidator,
+                "predictor": yolo.obb.OBBPredictor,
+            },
+        }
+
+
+class YOLOWorld(Model):
+    """YOLO-World object detection model."""
+
+    def __init__(self, model="yolov8s-world.pt", verbose=False) -> None:
+        """
+        Initialize YOLOv8-World model with a pre-trained model file.
+
+        Loads a YOLOv8-World model for object detection. If no custom class names are provided, it assigns default
+        COCO class names.
+
+        Args:
+            model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
+            verbose (bool): If True, prints additional information during initialization.
+        """
+        super().__init__(model=model, task="detect", verbose=verbose)
+
+        # Assign default COCO class names when there are no custom names
+        if not hasattr(self.model, "names"):
+            self.model.names = yaml_load(ROOT / "cfg/datasets/coco8.yaml").get("names")
+
+    @property
+    def task_map(self):
+        """Map head to model, validator, and predictor classes."""
+        return {
+            "detect": {
+                "model": WorldModel,
+                "validator": yolo.detect.DetectionValidator,
+                "predictor": yolo.detect.DetectionPredictor,
+                "trainer": yolo.world.WorldTrainer,
+            }
+        }
+
+    def set_classes(self, classes):
+        """
+        Set classes.
+
+        Args:
+            classes (List(str)): A list of categories i.e. ["person"].
+        """
+        self.model.set_classes(classes)
+        # Remove background if it's given
+        background = " "
+        if background in classes:
+            classes.remove(background)
+        self.model.names = classes
+
+        # Reset method class names
+        # self.predictor = None  # reset predictor otherwise old names remain
+        if self.predictor:
+            self.predictor.model.names = classes

+ 7 - 0
ultralytics/models/yolo/obb/__init__.py

@@ -0,0 +1,7 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from .predict import OBBPredictor
+from .train import OBBTrainer
+from .val import OBBValidator
+
+__all__ = "OBBPredictor", "OBBTrainer", "OBBValidator"

+ 53 - 0
ultralytics/models/yolo/obb/predict.py

@@ -0,0 +1,53 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import torch
+
+from ultralytics.engine.results import Results
+from ultralytics.models.yolo.detect.predict import DetectionPredictor
+from ultralytics.utils import DEFAULT_CFG, ops
+
+
+class OBBPredictor(DetectionPredictor):
+    """
+    A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.
+
+    Example:
+        ```python
+        from ultralytics.utils import ASSETS
+        from ultralytics.models.yolo.obb import OBBPredictor
+
+        args = dict(model="yolov8n-obb.pt", source=ASSETS)
+        predictor = OBBPredictor(overrides=args)
+        predictor.predict_cli()
+        ```
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initializes OBBPredictor with optional model and data configuration overrides."""
+        super().__init__(cfg, overrides, _callbacks)
+        self.args.task = "obb"
+
+    def postprocess(self, preds, img, orig_imgs):
+        """Post-processes predictions and returns a list of Results objects."""
+        preds = ops.non_max_suppression(
+            preds,
+            self.args.conf,
+            self.args.iou,
+            agnostic=self.args.agnostic_nms,
+            max_det=self.args.max_det,
+            nc=len(self.model.names),
+            classes=self.args.classes,
+            rotated=True,
+        )
+
+        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
+            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+
+        results = []
+        for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
+            rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1))
+            rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)
+            # xywh, r, conf, cls
+            obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1)
+            results.append(Results(orig_img, path=img_path, names=self.model.names, obb=obb))
+        return results

+ 44 - 0
ultralytics/models/yolo/obb/train.py

@@ -0,0 +1,44 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from copy import copy
+
+from ultralytics.models import yolo
+from ultralytics.nn.tasks import OBBModel
+from ultralytics.utils import DEFAULT_CFG, RANK
+
+
+class OBBTrainer(yolo.detect.DetectionTrainer):
+    """
+    A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.obb import OBBTrainer
+
+        args = dict(model="yolov8n-obb.pt", data="dota8.yaml", epochs=3)
+        trainer = OBBTrainer(overrides=args)
+        trainer.train()
+        ```
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initialize a OBBTrainer object with given arguments."""
+        if overrides is None:
+            overrides = {}
+        overrides["task"] = "obb"
+        super().__init__(cfg, overrides, _callbacks)
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """Return OBBModel initialized with specified config and weights."""
+        model = OBBModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
+        if weights:
+            model.load(weights)
+
+        return model
+
+    def get_validator(self):
+        """Return an instance of OBBValidator for validation of YOLO model."""
+        self.loss_names = "box_loss", "cls_loss", "dfl_loss"
+        return yolo.obb.OBBValidator(
+            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+        )

+ 203 - 0
ultralytics/models/yolo/obb/val.py

@@ -0,0 +1,203 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from pathlib import Path
+
+import torch
+
+from ultralytics.models.yolo.detect import DetectionValidator
+from ultralytics.utils import LOGGER, ops
+from ultralytics.utils.metrics import OBBMetrics, batch_probiou
+from ultralytics.utils.plotting import output_to_rotated_target, plot_images
+
+
+class OBBValidator(DetectionValidator):
+    """
+    A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.obb import OBBValidator
+
+        args = dict(model="yolov8n-obb.pt", data="dota8.yaml")
+        validator = OBBValidator(args=args)
+        validator(model=args["model"])
+        ```
+    """
+
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
+        """Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics."""
+        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
+        self.args.task = "obb"
+        self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True, on_plot=self.on_plot)
+
+    def init_metrics(self, model):
+        """Initialize evaluation metrics for YOLO."""
+        super().init_metrics(model)
+        val = self.data.get(self.args.split, "")  # validation path
+        self.is_dota = isinstance(val, str) and "DOTA" in val  # is COCO
+
+    def postprocess(self, preds):
+        """Apply Non-maximum suppression to prediction outputs."""
+        return ops.non_max_suppression(
+            preds,
+            self.args.conf,
+            self.args.iou,
+            labels=self.lb,
+            nc=self.nc,
+            multi_label=True,
+            agnostic=self.args.single_cls or self.args.agnostic_nms,
+            max_det=self.args.max_det,
+            rotated=True,
+        )
+
+    def _process_batch(self, detections, gt_bboxes, gt_cls):
+        """
+        Perform computation of the correct prediction matrix for a batch of detections and ground truth bounding boxes.
+
+        Args:
+            detections (torch.Tensor): A tensor of shape (N, 7) representing the detected bounding boxes and associated
+                data. Each detection is represented as (x1, y1, x2, y2, conf, class, angle).
+            gt_bboxes (torch.Tensor): A tensor of shape (M, 5) representing the ground truth bounding boxes. Each box is
+                represented as (x1, y1, x2, y2, angle).
+            gt_cls (torch.Tensor): A tensor of shape (M,) representing class labels for the ground truth bounding boxes.
+
+        Returns:
+            (torch.Tensor): The correct prediction matrix with shape (N, 10), which includes 10 IoU (Intersection over
+                Union) levels for each detection, indicating the accuracy of predictions compared to the ground truth.
+
+        Example:
+            ```python
+            detections = torch.rand(100, 7)  # 100 sample detections
+            gt_bboxes = torch.rand(50, 5)  # 50 sample ground truth boxes
+            gt_cls = torch.randint(0, 5, (50,))  # 50 ground truth class labels
+            correct_matrix = OBBValidator._process_batch(detections, gt_bboxes, gt_cls)
+            ```
+
+        Note:
+            This method relies on `batch_probiou` to calculate IoU between detections and ground truth bounding boxes.
+        """
+        iou = batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
+        return self.match_predictions(detections[:, 5], gt_cls, iou)
+
+    def _prepare_batch(self, si, batch):
+        """Prepares and returns a batch for OBB validation."""
+        idx = batch["batch_idx"] == si
+        cls = batch["cls"][idx].squeeze(-1)
+        bbox = batch["bboxes"][idx]
+        ori_shape = batch["ori_shape"][si]
+        imgsz = batch["img"].shape[2:]
+        ratio_pad = batch["ratio_pad"][si]
+        if len(cls):
+            bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]])  # target boxes
+            ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True)  # native-space labels
+        return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
+
+    def _prepare_pred(self, pred, pbatch):
+        """Prepares and returns a batch for OBB validation with scaled and padded bounding boxes."""
+        predn = pred.clone()
+        ops.scale_boxes(
+            pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
+        )  # native-space pred
+        return predn
+
+    def plot_predictions(self, batch, preds, ni):
+        """Plots predicted bounding boxes on input images and saves the result."""
+        plot_images(
+            batch["img"],
+            *output_to_rotated_target(preds, max_det=self.args.max_det),
+            paths=batch["im_file"],
+            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
+            names=self.names,
+            on_plot=self.on_plot,
+        )  # pred
+
+    def pred_to_json(self, predn, filename):
+        """Serialize YOLO predictions to COCO json format."""
+        stem = Path(filename).stem
+        image_id = int(stem) if stem.isnumeric() else stem
+        rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
+        poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
+        for i, (r, b) in enumerate(zip(rbox.tolist(), poly.tolist())):
+            self.jdict.append(
+                {
+                    "image_id": image_id,
+                    "category_id": self.class_map[int(predn[i, 5].item())],
+                    "score": round(predn[i, 4].item(), 5),
+                    "rbox": [round(x, 3) for x in r],
+                    "poly": [round(x, 3) for x in b],
+                }
+            )
+
+    def save_one_txt(self, predn, save_conf, shape, file):
+        """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
+        import numpy as np
+
+        from ultralytics.engine.results import Results
+
+        rboxes = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
+        # xywh, r, conf, cls
+        obb = torch.cat([rboxes, predn[:, 4:6]], dim=-1)
+        Results(
+            np.zeros((shape[0], shape[1]), dtype=np.uint8),
+            path=None,
+            names=self.names,
+            obb=obb,
+        ).save_txt(file, save_conf=save_conf)
+
+    def eval_json(self, stats):
+        """Evaluates YOLO output in JSON format and returns performance statistics."""
+        if self.args.save_json and self.is_dota and len(self.jdict):
+            import json
+            import re
+            from collections import defaultdict
+
+            pred_json = self.save_dir / "predictions.json"  # predictions
+            pred_txt = self.save_dir / "predictions_txt"  # predictions
+            pred_txt.mkdir(parents=True, exist_ok=True)
+            data = json.load(open(pred_json))
+            # Save split results
+            LOGGER.info(f"Saving predictions with DOTA format to {pred_txt}...")
+            for d in data:
+                image_id = d["image_id"]
+                score = d["score"]
+                classname = self.names[d["category_id"] - 1].replace(" ", "-")
+                p = d["poly"]
+
+                with open(f"{pred_txt / f'Task1_{classname}'}.txt", "a") as f:
+                    f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")
+            # Save merged results, this could result slightly lower map than using official merging script,
+            # because of the probiou calculation.
+            pred_merged_txt = self.save_dir / "predictions_merged_txt"  # predictions
+            pred_merged_txt.mkdir(parents=True, exist_ok=True)
+            merged_results = defaultdict(list)
+            LOGGER.info(f"Saving merged predictions with DOTA format to {pred_merged_txt}...")
+            for d in data:
+                image_id = d["image_id"].split("__")[0]
+                pattern = re.compile(r"\d+___\d+")
+                x, y = (int(c) for c in re.findall(pattern, d["image_id"])[0].split("___"))
+                bbox, score, cls = d["rbox"], d["score"], d["category_id"] - 1
+                bbox[0] += x
+                bbox[1] += y
+                bbox.extend([score, cls])
+                merged_results[image_id].append(bbox)
+            for image_id, bbox in merged_results.items():
+                bbox = torch.tensor(bbox)
+                max_wh = torch.max(bbox[:, :2]).item() * 2
+                c = bbox[:, 6:7] * max_wh  # classes
+                scores = bbox[:, 5]  # scores
+                b = bbox[:, :5].clone()
+                b[:, :2] += c
+                # 0.3 could get results close to the ones from official merging script, even slightly better.
+                i = ops.nms_rotated(b, scores, 0.3)
+                bbox = bbox[i]
+
+                b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8)
+                for x in torch.cat([b, bbox[:, 5:7]], dim=-1).tolist():
+                    classname = self.names[int(x[-1])].replace(" ", "-")
+                    p = [round(i, 3) for i in x[:-2]]  # poly
+                    score = round(x[-2], 3)
+
+                    with open(f"{pred_merged_txt / f'Task1_{classname}'}.txt", "a") as f:
+                        f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")
+
+        return stats

+ 7 - 0
ultralytics/models/yolo/pose/__init__.py

@@ -0,0 +1,7 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from .predict import PosePredictor
+from .train import PoseTrainer
+from .val import PoseValidator
+
+__all__ = "PoseTrainer", "PoseValidator", "PosePredictor"

+ 56 - 0
ultralytics/models/yolo/pose/predict.py

@@ -0,0 +1,56 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from ultralytics.engine.results import Results
+from ultralytics.models.yolo.detect.predict import DetectionPredictor
+from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
+
+
+class PosePredictor(DetectionPredictor):
+    """
+    A class extending the DetectionPredictor class for prediction based on a pose model.
+
+    Example:
+        ```python
+        from ultralytics.utils import ASSETS
+        from ultralytics.models.yolo.pose import PosePredictor
+
+        args = dict(model="yolov8n-pose.pt", source=ASSETS)
+        predictor = PosePredictor(overrides=args)
+        predictor.predict_cli()
+        ```
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initializes PosePredictor, sets task to 'pose' and logs a warning for using 'mps' as device."""
+        super().__init__(cfg, overrides, _callbacks)
+        self.args.task = "pose"
+        if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
+            LOGGER.warning(
+                "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
+                "See https://github.com/ultralytics/ultralytics/issues/4031."
+            )
+
+    def postprocess(self, preds, img, orig_imgs):
+        """Return detection results for a given input image or list of images."""
+        preds = ops.non_max_suppression(
+            preds,
+            self.args.conf,
+            self.args.iou,
+            agnostic=self.args.agnostic_nms,
+            max_det=self.args.max_det,
+            classes=self.args.classes,
+            nc=len(self.model.names),
+        )
+
+        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
+            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+
+        results = []
+        for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
+            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape).round()
+            pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
+            pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
+            results.append(
+                Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts)
+            )
+        return results

+ 79 - 0
ultralytics/models/yolo/pose/train.py

@@ -0,0 +1,79 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from copy import copy
+
+from ultralytics.models import yolo
+from ultralytics.nn.tasks import PoseModel
+from ultralytics.utils import DEFAULT_CFG, LOGGER
+from ultralytics.utils.plotting import plot_images, plot_results
+
+
+class PoseTrainer(yolo.detect.DetectionTrainer):
+    """
+    A class extending the DetectionTrainer class for training based on a pose model.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.pose import PoseTrainer
+
+        args = dict(model="yolov8n-pose.pt", data="coco8-pose.yaml", epochs=3)
+        trainer = PoseTrainer(overrides=args)
+        trainer.train()
+        ```
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initialize a PoseTrainer object with specified configurations and overrides."""
+        if overrides is None:
+            overrides = {}
+        overrides["task"] = "pose"
+        super().__init__(cfg, overrides, _callbacks)
+
+        if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
+            LOGGER.warning(
+                "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
+                "See https://github.com/ultralytics/ultralytics/issues/4031."
+            )
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """Get pose estimation model with specified configuration and weights."""
+        model = PoseModel(cfg, ch=3, nc=self.data["nc"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose)
+        if weights:
+            model.load(weights)
+
+        return model
+
+    def set_model_attributes(self):
+        """Sets keypoints shape attribute of PoseModel."""
+        super().set_model_attributes()
+        self.model.kpt_shape = self.data["kpt_shape"]
+
+    def get_validator(self):
+        """Returns an instance of the PoseValidator class for validation."""
+        self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
+        return yolo.pose.PoseValidator(
+            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+        )
+
+    def plot_training_samples(self, batch, ni):
+        """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
+        images = batch["img"]
+        kpts = batch["keypoints"]
+        cls = batch["cls"].squeeze(-1)
+        bboxes = batch["bboxes"]
+        paths = batch["im_file"]
+        batch_idx = batch["batch_idx"]
+        plot_images(
+            images,
+            batch_idx,
+            cls,
+            bboxes,
+            kpts=kpts,
+            paths=paths,
+            fname=self.save_dir / f"train_batch{ni}.jpg",
+            on_plot=self.on_plot,
+        )
+
+    def plot_metrics(self):
+        """Plots training/val metrics."""
+        plot_results(file=self.csv, pose=True, on_plot=self.on_plot)  # save results.png

+ 282 - 0
ultralytics/models/yolo/pose/val.py

@@ -0,0 +1,282 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from pathlib import Path
+
+import numpy as np
+import torch
+
+from ultralytics.models.yolo.detect import DetectionValidator
+from ultralytics.utils import LOGGER, ops
+from ultralytics.utils.checks import check_requirements
+from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, box_iou, kpt_iou
+from ultralytics.utils.plotting import output_to_target, plot_images
+
+
+class PoseValidator(DetectionValidator):
+    """
+    A class extending the DetectionValidator class for validation based on a pose model.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.pose import PoseValidator
+
+        args = dict(model="yolov8n-pose.pt", data="coco8-pose.yaml")
+        validator = PoseValidator(args=args)
+        validator()
+        ```
+    """
+
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
+        """Initialize a 'PoseValidator' object with custom parameters and assigned attributes."""
+        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
+        self.sigma = None
+        self.kpt_shape = None
+        self.args.task = "pose"
+        self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
+        if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
+            LOGGER.warning(
+                "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
+                "See https://github.com/ultralytics/ultralytics/issues/4031."
+            )
+
+    def preprocess(self, batch):
+        """Preprocesses the batch by converting the 'keypoints' data into a float and moving it to the device."""
+        batch = super().preprocess(batch)
+        batch["keypoints"] = batch["keypoints"].to(self.device).float()
+        return batch
+
+    def get_desc(self):
+        """Returns description of evaluation metrics in string format."""
+        return ("%22s" + "%11s" * 10) % (
+            "Class",
+            "Images",
+            "Instances",
+            "Box(P",
+            "R",
+            "mAP50",
+            "mAP50-95)",
+            "Pose(P",
+            "R",
+            "mAP50",
+            "mAP50-95)",
+        )
+
+    def postprocess(self, preds):
+        """Apply non-maximum suppression and return detections with high confidence scores."""
+        return ops.non_max_suppression(
+            preds,
+            self.args.conf,
+            self.args.iou,
+            labels=self.lb,
+            multi_label=True,
+            agnostic=self.args.single_cls or self.args.agnostic_nms,
+            max_det=self.args.max_det,
+            nc=self.nc,
+        )
+
+    def init_metrics(self, model):
+        """Initiate pose estimation metrics for YOLO model."""
+        super().init_metrics(model)
+        self.kpt_shape = self.data["kpt_shape"]
+        is_pose = self.kpt_shape == [17, 3]
+        nkpt = self.kpt_shape[0]
+        self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
+        self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
+
+    def _prepare_batch(self, si, batch):
+        """Prepares a batch for processing by converting keypoints to float and moving to device."""
+        pbatch = super()._prepare_batch(si, batch)
+        kpts = batch["keypoints"][batch["batch_idx"] == si]
+        h, w = pbatch["imgsz"]
+        kpts = kpts.clone()
+        kpts[..., 0] *= w
+        kpts[..., 1] *= h
+        kpts = ops.scale_coords(pbatch["imgsz"], kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
+        pbatch["kpts"] = kpts
+        return pbatch
+
+    def _prepare_pred(self, pred, pbatch):
+        """Prepares and scales keypoints in a batch for pose processing."""
+        predn = super()._prepare_pred(pred, pbatch)
+        nk = pbatch["kpts"].shape[1]
+        pred_kpts = predn[:, 6:].view(len(predn), nk, -1)
+        ops.scale_coords(pbatch["imgsz"], pred_kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
+        return predn, pred_kpts
+
+    def update_metrics(self, preds, batch):
+        """Metrics."""
+        for si, pred in enumerate(preds):
+            self.seen += 1
+            npr = len(pred)
+            stat = dict(
+                conf=torch.zeros(0, device=self.device),
+                pred_cls=torch.zeros(0, device=self.device),
+                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+                tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+            )
+            pbatch = self._prepare_batch(si, batch)
+            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
+            nl = len(cls)
+            stat["target_cls"] = cls
+            stat["target_img"] = cls.unique()
+            if npr == 0:
+                if nl:
+                    for k in self.stats.keys():
+                        self.stats[k].append(stat[k])
+                    if self.args.plots:
+                        self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
+                continue
+
+            # Predictions
+            if self.args.single_cls:
+                pred[:, 5] = 0
+            predn, pred_kpts = self._prepare_pred(pred, pbatch)
+            stat["conf"] = predn[:, 4]
+            stat["pred_cls"] = predn[:, 5]
+
+            # Evaluate
+            if nl:
+                stat["tp"] = self._process_batch(predn, bbox, cls)
+                stat["tp_p"] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch["kpts"])
+            if self.args.plots:
+                self.confusion_matrix.process_batch(predn, bbox, cls)
+
+            for k in self.stats.keys():
+                self.stats[k].append(stat[k])
+
+            # Save
+            if self.args.save_json:
+                self.pred_to_json(predn, batch["im_file"][si])
+            if self.args.save_txt:
+                self.save_one_txt(
+                    predn,
+                    pred_kpts,
+                    self.args.save_conf,
+                    pbatch["ori_shape"],
+                    self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
+                )
+
+    def _process_batch(self, detections, gt_bboxes, gt_cls, pred_kpts=None, gt_kpts=None):
+        """
+        Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
+
+        Args:
+            detections (torch.Tensor): Tensor with shape (N, 6) representing detection boxes and scores, where each
+                detection is of the format (x1, y1, x2, y2, conf, class).
+            gt_bboxes (torch.Tensor): Tensor with shape (M, 4) representing ground truth bounding boxes, where each
+                box is of the format (x1, y1, x2, y2).
+            gt_cls (torch.Tensor): Tensor with shape (M,) representing ground truth class indices.
+            pred_kpts (torch.Tensor | None): Optional tensor with shape (N, 51) representing predicted keypoints, where
+                51 corresponds to 17 keypoints each having 3 values.
+            gt_kpts (torch.Tensor | None): Optional tensor with shape (N, 51) representing ground truth keypoints.
+
+        Returns:
+            torch.Tensor: A tensor with shape (N, 10) representing the correct prediction matrix for 10 IoU levels,
+                where N is the number of detections.
+
+        Example:
+            ```python
+            detections = torch.rand(100, 6)  # 100 predictions: (x1, y1, x2, y2, conf, class)
+            gt_bboxes = torch.rand(50, 4)  # 50 ground truth boxes: (x1, y1, x2, y2)
+            gt_cls = torch.randint(0, 2, (50,))  # 50 ground truth class indices
+            pred_kpts = torch.rand(100, 51)  # 100 predicted keypoints
+            gt_kpts = torch.rand(50, 51)  # 50 ground truth keypoints
+            correct_preds = _process_batch(detections, gt_bboxes, gt_cls, pred_kpts, gt_kpts)
+            ```
+
+        Note:
+            `0.53` scale factor used in area computation is referenced from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
+        """
+        if pred_kpts is not None and gt_kpts is not None:
+            # `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
+            area = ops.xyxy2xywh(gt_bboxes)[:, 2:].prod(1) * 0.53
+            iou = kpt_iou(gt_kpts, pred_kpts, sigma=self.sigma, area=area)
+        else:  # boxes
+            iou = box_iou(gt_bboxes, detections[:, :4])
+
+        return self.match_predictions(detections[:, 5], gt_cls, iou)
+
+    def plot_val_samples(self, batch, ni):
+        """Plots and saves validation set samples with predicted bounding boxes and keypoints."""
+        plot_images(
+            batch["img"],
+            batch["batch_idx"],
+            batch["cls"].squeeze(-1),
+            batch["bboxes"],
+            kpts=batch["keypoints"],
+            paths=batch["im_file"],
+            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
+            names=self.names,
+            on_plot=self.on_plot,
+        )
+
+    def plot_predictions(self, batch, preds, ni):
+        """Plots predictions for YOLO model."""
+        pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
+        plot_images(
+            batch["img"],
+            *output_to_target(preds, max_det=self.args.max_det),
+            kpts=pred_kpts,
+            paths=batch["im_file"],
+            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
+            names=self.names,
+            on_plot=self.on_plot,
+        )  # pred
+
+    def save_one_txt(self, predn, pred_kpts, save_conf, shape, file):
+        """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
+        from ultralytics.engine.results import Results
+
+        Results(
+            np.zeros((shape[0], shape[1]), dtype=np.uint8),
+            path=None,
+            names=self.names,
+            boxes=predn[:, :6],
+            keypoints=pred_kpts,
+        ).save_txt(file, save_conf=save_conf)
+
+    def pred_to_json(self, predn, filename):
+        """Converts YOLO predictions to COCO JSON format."""
+        stem = Path(filename).stem
+        image_id = int(stem) if stem.isnumeric() else stem
+        box = ops.xyxy2xywh(predn[:, :4])  # xywh
+        box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
+        for p, b in zip(predn.tolist(), box.tolist()):
+            self.jdict.append(
+                {
+                    "image_id": image_id,
+                    "category_id": self.class_map[int(p[5])],
+                    "bbox": [round(x, 3) for x in b],
+                    "keypoints": p[6:],
+                    "score": round(p[4], 5),
+                }
+            )
+
+    def eval_json(self, stats):
+        """Evaluates object detection model using COCO JSON format."""
+        if self.args.save_json and self.is_coco and len(self.jdict):
+            anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json"  # annotations
+            pred_json = self.save_dir / "predictions.json"  # predictions
+            LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
+            try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
+                check_requirements("pycocotools>=2.0.6")
+                from pycocotools.coco import COCO  # noqa
+                from pycocotools.cocoeval import COCOeval  # noqa
+
+                for x in anno_json, pred_json:
+                    assert x.is_file(), f"{x} file not found"
+                anno = COCO(str(anno_json))  # init annotations api
+                pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
+                for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "keypoints")]):
+                    if self.is_coco:
+                        eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # im to eval
+                    eval.evaluate()
+                    eval.accumulate()
+                    eval.summarize()
+                    idx = i * 4 + 2
+                    stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[
+                        :2
+                    ]  # update mAP50-95 and mAP50
+            except Exception as e:
+                LOGGER.warning(f"pycocotools unable to run: {e}")
+        return stats

+ 7 - 0
ultralytics/models/yolo/segment/__init__.py

@@ -0,0 +1,7 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from .predict import SegmentationPredictor
+from .train import SegmentationTrainer
+from .val import SegmentationValidator
+
+__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator"

+ 55 - 0
ultralytics/models/yolo/segment/predict.py

@@ -0,0 +1,55 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from ultralytics.engine.results import Results
+from ultralytics.models.yolo.detect.predict import DetectionPredictor
+from ultralytics.utils import DEFAULT_CFG, ops
+
+
+class SegmentationPredictor(DetectionPredictor):
+    """
+    A class extending the DetectionPredictor class for prediction based on a segmentation model.
+
+    Example:
+        ```python
+        from ultralytics.utils import ASSETS
+        from ultralytics.models.yolo.segment import SegmentationPredictor
+
+        args = dict(model="yolov8n-seg.pt", source=ASSETS)
+        predictor = SegmentationPredictor(overrides=args)
+        predictor.predict_cli()
+        ```
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initializes the SegmentationPredictor with the provided configuration, overrides, and callbacks."""
+        super().__init__(cfg, overrides, _callbacks)
+        self.args.task = "segment"
+
+    def postprocess(self, preds, img, orig_imgs):
+        """Applies non-max suppression and processes detections for each image in an input batch."""
+        p = ops.non_max_suppression(
+            preds[0],
+            self.args.conf,
+            self.args.iou,
+            agnostic=self.args.agnostic_nms,
+            max_det=self.args.max_det,
+            nc=len(self.model.names),
+            classes=self.args.classes,
+        )
+
+        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
+            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+
+        results = []
+        proto = preds[1][-1] if isinstance(preds[1], tuple) else preds[1]  # tuple if PyTorch model or array if exported
+        for i, (pred, orig_img, img_path) in enumerate(zip(p, orig_imgs, self.batch[0])):
+            if not len(pred):  # save empty boxes
+                masks = None
+            elif self.args.retina_masks:
+                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+                masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
+            else:
+                masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
+                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
+        return results

+ 62 - 0
ultralytics/models/yolo/segment/train.py

@@ -0,0 +1,62 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from copy import copy
+
+from ultralytics.models import yolo
+from ultralytics.nn.tasks import SegmentationModel
+from ultralytics.utils import DEFAULT_CFG, RANK
+from ultralytics.utils.plotting import plot_images, plot_results
+
+
+class SegmentationTrainer(yolo.detect.DetectionTrainer):
+    """
+    A class extending the DetectionTrainer class for training based on a segmentation model.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.segment import SegmentationTrainer
+
+        args = dict(model="yolov8n-seg.pt", data="coco8-seg.yaml", epochs=3)
+        trainer = SegmentationTrainer(overrides=args)
+        trainer.train()
+        ```
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initialize a SegmentationTrainer object with given arguments."""
+        if overrides is None:
+            overrides = {}
+        overrides["task"] = "segment"
+        super().__init__(cfg, overrides, _callbacks)
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """Return SegmentationModel initialized with specified config and weights."""
+        model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
+        if weights:
+            model.load(weights)
+
+        return model
+
+    def get_validator(self):
+        """Return an instance of SegmentationValidator for validation of YOLO model."""
+        self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss"
+        return yolo.segment.SegmentationValidator(
+            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+        )
+
+    def plot_training_samples(self, batch, ni):
+        """Creates a plot of training sample images with labels and box coordinates."""
+        plot_images(
+            batch["img"],
+            batch["batch_idx"],
+            batch["cls"].squeeze(-1),
+            batch["bboxes"],
+            masks=batch["masks"],
+            paths=batch["im_file"],
+            fname=self.save_dir / f"train_batch{ni}.jpg",
+            on_plot=self.on_plot,
+        )
+
+    def plot_metrics(self):
+        """Plots training/val metrics."""
+        plot_results(file=self.csv, segment=True, on_plot=self.on_plot)  # save results.png

+ 318 - 0
ultralytics/models/yolo/segment/val.py

@@ -0,0 +1,318 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from multiprocessing.pool import ThreadPool
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from ultralytics.models.yolo.detect import DetectionValidator
+from ultralytics.utils import LOGGER, NUM_THREADS, ops
+from ultralytics.utils.checks import check_requirements
+from ultralytics.utils.metrics import SegmentMetrics, box_iou, mask_iou
+from ultralytics.utils.plotting import output_to_target, plot_images
+
+
+class SegmentationValidator(DetectionValidator):
+    """
+    A class extending the DetectionValidator class for validation based on a segmentation model.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.segment import SegmentationValidator
+
+        args = dict(model="yolov8n-seg.pt", data="coco8-seg.yaml")
+        validator = SegmentationValidator(args=args)
+        validator()
+        ```
+    """
+
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
+        """Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics."""
+        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
+        self.plot_masks = None
+        self.process = None
+        self.args.task = "segment"
+        self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
+
+    def preprocess(self, batch):
+        """Preprocesses batch by converting masks to float and sending to device."""
+        batch = super().preprocess(batch)
+        batch["masks"] = batch["masks"].to(self.device).float()
+        return batch
+
+    def init_metrics(self, model):
+        """Initialize metrics and select mask processing function based on save_json flag."""
+        super().init_metrics(model)
+        self.plot_masks = []
+        if self.args.save_json:
+            check_requirements("pycocotools>=2.0.6")
+        # more accurate vs faster
+        self.process = ops.process_mask_native if self.args.save_json or self.args.save_txt else ops.process_mask
+        self.stats = dict(tp_m=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
+
+    def get_desc(self):
+        """Return a formatted description of evaluation metrics."""
+        return ("%22s" + "%11s" * 10) % (
+            "Class",
+            "Images",
+            "Instances",
+            "Box(P",
+            "R",
+            "mAP50",
+            "mAP50-95)",
+            "Mask(P",
+            "R",
+            "mAP50",
+            "mAP50-95)",
+        )
+
+    def postprocess(self, preds):
+        """Post-processes YOLO predictions and returns output detections with proto."""
+        p = ops.non_max_suppression(
+            preds[0],
+            self.args.conf,
+            self.args.iou,
+            labels=self.lb,
+            multi_label=True,
+            agnostic=self.args.single_cls or self.args.agnostic_nms,
+            max_det=self.args.max_det,
+            nc=self.nc,
+        )
+        proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
+        return p, proto
+
+    def _prepare_batch(self, si, batch):
+        """Prepares a batch for training or inference by processing images and targets."""
+        prepared_batch = super()._prepare_batch(si, batch)
+        midx = [si] if self.args.overlap_mask else batch["batch_idx"] == si
+        prepared_batch["masks"] = batch["masks"][midx]
+        return prepared_batch
+
+    def _prepare_pred(self, pred, pbatch, proto):
+        """Prepares a batch for training or inference by processing images and targets."""
+        predn = super()._prepare_pred(pred, pbatch)
+        pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch["imgsz"])
+        return predn, pred_masks
+
+    def update_metrics(self, preds, batch):
+        """Metrics."""
+        for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
+            self.seen += 1
+            npr = len(pred)
+            stat = dict(
+                conf=torch.zeros(0, device=self.device),
+                pred_cls=torch.zeros(0, device=self.device),
+                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+                tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+            )
+            pbatch = self._prepare_batch(si, batch)
+            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
+            nl = len(cls)
+            stat["target_cls"] = cls
+            stat["target_img"] = cls.unique()
+            if npr == 0:
+                if nl:
+                    for k in self.stats.keys():
+                        self.stats[k].append(stat[k])
+                    if self.args.plots:
+                        self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
+                continue
+
+            # Masks
+            gt_masks = pbatch.pop("masks")
+            # Predictions
+            if self.args.single_cls:
+                pred[:, 5] = 0
+            predn, pred_masks = self._prepare_pred(pred, pbatch, proto)
+            stat["conf"] = predn[:, 4]
+            stat["pred_cls"] = predn[:, 5]
+
+            # Evaluate
+            if nl:
+                stat["tp"] = self._process_batch(predn, bbox, cls)
+                stat["tp_m"] = self._process_batch(
+                    predn, bbox, cls, pred_masks, gt_masks, self.args.overlap_mask, masks=True
+                )
+            if self.args.plots:
+                self.confusion_matrix.process_batch(predn, bbox, cls)
+
+            for k in self.stats.keys():
+                self.stats[k].append(stat[k])
+
+            pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
+            if self.args.plots and self.batch_i < 3:
+                self.plot_masks.append(pred_masks[:15].cpu())  # filter top 15 to plot
+
+            # Save
+            if self.args.save_json:
+                self.pred_to_json(
+                    predn,
+                    batch["im_file"][si],
+                    ops.scale_image(
+                        pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
+                        pbatch["ori_shape"],
+                        ratio_pad=batch["ratio_pad"][si],
+                    ),
+                )
+            if self.args.save_txt:
+                self.save_one_txt(
+                    predn,
+                    pred_masks,
+                    self.args.save_conf,
+                    pbatch["ori_shape"],
+                    self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
+                )
+
+    def finalize_metrics(self, *args, **kwargs):
+        """Sets speed and confusion matrix for evaluation metrics."""
+        self.metrics.speed = self.speed
+        self.metrics.confusion_matrix = self.confusion_matrix
+
+    def _process_batch(self, detections, gt_bboxes, gt_cls, pred_masks=None, gt_masks=None, overlap=False, masks=False):
+        """
+        Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
+
+        Args:
+            detections (torch.Tensor): Tensor of shape (N, 6) representing detected bounding boxes and
+                associated confidence scores and class indices. Each row is of the format [x1, y1, x2, y2, conf, class].
+            gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground truth bounding box coordinates.
+                Each row is of the format [x1, y1, x2, y2].
+            gt_cls (torch.Tensor): Tensor of shape (M,) representing ground truth class indices.
+            pred_masks (torch.Tensor | None): Tensor representing predicted masks, if available. The shape should
+                match the ground truth masks.
+            gt_masks (torch.Tensor | None): Tensor of shape (M, H, W) representing ground truth masks, if available.
+            overlap (bool): Flag indicating if overlapping masks should be considered.
+            masks (bool): Flag indicating if the batch contains mask data.
+
+        Returns:
+            (torch.Tensor): A correct prediction matrix of shape (N, 10), where 10 represents different IoU levels.
+
+        Note:
+            - If `masks` is True, the function computes IoU between predicted and ground truth masks.
+            - If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.
+
+        Example:
+            ```python
+            detections = torch.tensor([[25, 30, 200, 300, 0.8, 1], [50, 60, 180, 290, 0.75, 0]])
+            gt_bboxes = torch.tensor([[24, 29, 199, 299], [55, 65, 185, 295]])
+            gt_cls = torch.tensor([1, 0])
+            correct_preds = validator._process_batch(detections, gt_bboxes, gt_cls)
+            ```
+        """
+        if masks:
+            if overlap:
+                nl = len(gt_cls)
+                index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
+                gt_masks = gt_masks.repeat(nl, 1, 1)  # shape(1,640,640) -> (n,640,640)
+                gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
+            if gt_masks.shape[1:] != pred_masks.shape[1:]:
+                gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0]
+                gt_masks = gt_masks.gt_(0.5)
+            iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
+        else:  # boxes
+            iou = box_iou(gt_bboxes, detections[:, :4])
+
+        return self.match_predictions(detections[:, 5], gt_cls, iou)
+
+    def plot_val_samples(self, batch, ni):
+        """Plots validation samples with bounding box labels."""
+        plot_images(
+            batch["img"],
+            batch["batch_idx"],
+            batch["cls"].squeeze(-1),
+            batch["bboxes"],
+            masks=batch["masks"],
+            paths=batch["im_file"],
+            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
+            names=self.names,
+            on_plot=self.on_plot,
+        )
+
+    def plot_predictions(self, batch, preds, ni):
+        """Plots batch predictions with masks and bounding boxes."""
+        plot_images(
+            batch["img"],
+            *output_to_target(preds[0], max_det=15),  # not set to self.args.max_det due to slow plotting speed
+            torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
+            paths=batch["im_file"],
+            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
+            names=self.names,
+            on_plot=self.on_plot,
+        )  # pred
+        self.plot_masks.clear()
+
+    def save_one_txt(self, predn, pred_masks, save_conf, shape, file):
+        """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
+        from ultralytics.engine.results import Results
+
+        Results(
+            np.zeros((shape[0], shape[1]), dtype=np.uint8),
+            path=None,
+            names=self.names,
+            boxes=predn[:, :6],
+            masks=pred_masks,
+        ).save_txt(file, save_conf=save_conf)
+
+    def pred_to_json(self, predn, filename, pred_masks):
+        """
+        Save one JSON result.
+
+        Examples:
+             >>> result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
+        """
+        from pycocotools.mask import encode  # noqa
+
+        def single_encode(x):
+            """Encode predicted masks as RLE and append results to jdict."""
+            rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0]
+            rle["counts"] = rle["counts"].decode("utf-8")
+            return rle
+
+        stem = Path(filename).stem
+        image_id = int(stem) if stem.isnumeric() else stem
+        box = ops.xyxy2xywh(predn[:, :4])  # xywh
+        box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
+        pred_masks = np.transpose(pred_masks, (2, 0, 1))
+        with ThreadPool(NUM_THREADS) as pool:
+            rles = pool.map(single_encode, pred_masks)
+        for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())):
+            self.jdict.append(
+                {
+                    "image_id": image_id,
+                    "category_id": self.class_map[int(p[5])],
+                    "bbox": [round(x, 3) for x in b],
+                    "score": round(p[4], 5),
+                    "segmentation": rles[i],
+                }
+            )
+
+    def eval_json(self, stats):
+        """Return COCO-style object detection evaluation metrics."""
+        if self.args.save_json and self.is_coco and len(self.jdict):
+            anno_json = self.data["path"] / "annotations/instances_val2017.json"  # annotations
+            pred_json = self.save_dir / "predictions.json"  # predictions
+            LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
+            try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
+                check_requirements("pycocotools>=2.0.6")
+                from pycocotools.coco import COCO  # noqa
+                from pycocotools.cocoeval import COCOeval  # noqa
+
+                for x in anno_json, pred_json:
+                    assert x.is_file(), f"{x} file not found"
+                anno = COCO(str(anno_json))  # init annotations api
+                pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
+                for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "segm")]):
+                    if self.is_coco:
+                        eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # im to eval
+                    eval.evaluate()
+                    eval.accumulate()
+                    eval.summarize()
+                    idx = i * 4 + 2
+                    stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[
+                        :2
+                    ]  # update mAP50-95 and mAP50
+            except Exception as e:
+                LOGGER.warning(f"pycocotools unable to run: {e}")
+        return stats

+ 5 - 0
ultralytics/models/yolo/world/__init__.py

@@ -0,0 +1,5 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from .train import WorldTrainer
+
+__all__ = ["WorldTrainer"]

+ 92 - 0
ultralytics/models/yolo/world/train.py

@@ -0,0 +1,92 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import itertools
+
+from ultralytics.data import build_yolo_dataset
+from ultralytics.models import yolo
+from ultralytics.nn.tasks import WorldModel
+from ultralytics.utils import DEFAULT_CFG, RANK, checks
+from ultralytics.utils.torch_utils import de_parallel
+
+
+def on_pretrain_routine_end(trainer):
+    """Callback."""
+    if RANK in {-1, 0}:
+        # NOTE: for evaluation
+        names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
+        de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
+    device = next(trainer.model.parameters()).device
+    trainer.text_model, _ = trainer.clip.load("ViT-B/32", device=device)
+    for p in trainer.text_model.parameters():
+        p.requires_grad_(False)
+
+
+class WorldTrainer(yolo.detect.DetectionTrainer):
+    """
+    A class to fine-tune a world model on a close-set dataset.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.world import WorldModel
+
+        args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=3)
+        trainer = WorldTrainer(overrides=args)
+        trainer.train()
+        ```
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initialize a WorldTrainer object with given arguments."""
+        if overrides is None:
+            overrides = {}
+        super().__init__(cfg, overrides, _callbacks)
+
+        # Import and assign clip
+        try:
+            import clip
+        except ImportError:
+            checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
+            import clip
+        self.clip = clip
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """Return WorldModel initialized with specified config and weights."""
+        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
+        # NOTE: Following the official config, nc hard-coded to 80 for now.
+        model = WorldModel(
+            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
+            ch=3,
+            nc=min(self.data["nc"], 80),
+            verbose=verbose and RANK == -1,
+        )
+        if weights:
+            model.load(weights)
+        self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end)
+
+        return model
+
+    def build_dataset(self, img_path, mode="train", batch=None):
+        """
+        Build YOLO Dataset.
+
+        Args:
+            img_path (str): Path to the folder containing images.
+            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
+            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
+        """
+        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
+        return build_yolo_dataset(
+            self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
+        )
+
+    def preprocess_batch(self, batch):
+        """Preprocesses a batch of images for YOLOWorld training, adjusting formatting and dimensions as needed."""
+        batch = super().preprocess_batch(batch)
+
+        # NOTE: add text features
+        texts = list(itertools.chain(*batch["texts"]))
+        text_token = self.clip.tokenize(texts).to(batch["img"].device)
+        txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype)  # torch.float32
+        txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
+        batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
+        return batch

+ 109 - 0
ultralytics/models/yolo/world/train_world.py

@@ -0,0 +1,109 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from ultralytics.data import YOLOConcatDataset, build_grounding, build_yolo_dataset
+from ultralytics.data.utils import check_det_dataset
+from ultralytics.models.yolo.world import WorldTrainer
+from ultralytics.utils import DEFAULT_CFG
+from ultralytics.utils.torch_utils import de_parallel
+
+
+class WorldTrainerFromScratch(WorldTrainer):
+    """
+    A class extending the WorldTrainer class for training a world model from scratch on open-set dataset.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
+        from ultralytics import YOLOWorld
+
+        data = dict(
+            train=dict(
+                yolo_data=["Objects365.yaml"],
+                grounding_data=[
+                    dict(
+                        img_path="../datasets/flickr30k/images",
+                        json_file="../datasets/flickr30k/final_flickr_separateGT_train.json",
+                    ),
+                    dict(
+                        img_path="../datasets/GQA/images",
+                        json_file="../datasets/GQA/final_mixed_train_no_coco.json",
+                    ),
+                ],
+            ),
+            val=dict(yolo_data=["lvis.yaml"]),
+        )
+
+        model = YOLOWorld("yolov8s-worldv2.yaml")
+        model.train(data=data, trainer=WorldTrainerFromScratch)
+        ```
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initialize a WorldTrainer object with given arguments."""
+        if overrides is None:
+            overrides = {}
+        super().__init__(cfg, overrides, _callbacks)
+
+    def build_dataset(self, img_path, mode="train", batch=None):
+        """
+        Build YOLO Dataset.
+
+        Args:
+            img_path (List[str] | str): Path to the folder containing images.
+            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
+            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
+        """
+        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
+        if mode != "train":
+            return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
+        dataset = [
+            build_yolo_dataset(self.args, im_path, batch, self.data, stride=gs, multi_modal=True)
+            if isinstance(im_path, str)
+            else build_grounding(self.args, im_path["img_path"], im_path["json_file"], batch, stride=gs)
+            for im_path in img_path
+        ]
+        return YOLOConcatDataset(dataset) if len(dataset) > 1 else dataset[0]
+
+    def get_dataset(self):
+        """
+        Get train, val path from data dict if it exists.
+
+        Returns None if data format is not recognized.
+        """
+        final_data = {}
+        data_yaml = self.args.data
+        assert data_yaml.get("train", False), "train dataset not found"  # object365.yaml
+        assert data_yaml.get("val", False), "validation dataset not found"  # lvis.yaml
+        data = {k: [check_det_dataset(d) for d in v.get("yolo_data", [])] for k, v in data_yaml.items()}
+        assert len(data["val"]) == 1, f"Only support validating on 1 dataset for now, but got {len(data['val'])}."
+        val_split = "minival" if "lvis" in data["val"][0]["val"] else "val"
+        for d in data["val"]:
+            if d.get("minival") is None:  # for lvis dataset
+                continue
+            d["minival"] = str(d["path"] / d["minival"])
+        for s in ["train", "val"]:
+            final_data[s] = [d["train" if s == "train" else val_split] for d in data[s]]
+            # save grounding data if there's one
+            grounding_data = data_yaml[s].get("grounding_data")
+            if grounding_data is None:
+                continue
+            grounding_data = grounding_data if isinstance(grounding_data, list) else [grounding_data]
+            for g in grounding_data:
+                assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}"
+            final_data[s] += grounding_data
+        # NOTE: to make training work properly, set `nc` and `names`
+        final_data["nc"] = data["val"][0]["nc"]
+        final_data["names"] = data["val"][0]["names"]
+        self.data = final_data
+        return final_data["train"], final_data["val"][0]
+
+    def plot_training_labels(self):
+        """DO NOT plot labels."""
+        pass
+
+    def final_eval(self):
+        """Performs final evaluation and validation for object detection YOLO-World model."""
+        val = self.args.data["val"]["yolo_data"][0]
+        self.validator.args.data = val
+        self.validator.args.split = "minival" if isinstance(val, str) and "lvis" in val else "val"
+        return super().final_eval()