
Add files via upload

Mengqi Lei 2 months ago
parent
commit
7acd8396ff

+ 1331 - 0
ultralytics/utils/__init__.py

@@ -0,0 +1,1331 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import contextlib
+import importlib.metadata
+import inspect
+import json
+import logging.config
+import os
+import platform
+import re
+import subprocess
+import sys
+import threading
+import time
+import uuid
+from pathlib import Path
+from threading import Lock
+from types import SimpleNamespace
+from typing import Union
+from urllib.parse import unquote
+
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import yaml
+from tqdm import tqdm as tqdm_original
+
+from ultralytics import __version__
+
+# PyTorch Multi-GPU DDP Constants
+RANK = int(os.getenv("RANK", -1))
+LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1))  # https://pytorch.org/docs/stable/elastic/run.html
+
+# Other Constants
+ARGV = sys.argv or ["", ""]  # sometimes sys.argv = []
+FILE = Path(__file__).resolve()
+ROOT = FILE.parents[1]  # YOLO
+ASSETS = ROOT / "assets"  # default images
+ASSETS_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0"  # assets GitHub URL
+DEFAULT_CFG_PATH = ROOT / "cfg/default.yaml"
+DEFAULT_SOL_CFG_PATH = ROOT / "cfg/solutions/default.yaml"  # Ultralytics solutions yaml path
+NUM_THREADS = min(8, max(1, os.cpu_count() - 1))  # number of YOLO multiprocessing threads
+AUTOINSTALL = str(os.getenv("YOLO_AUTOINSTALL", True)).lower() == "true"  # global auto-install mode
+VERBOSE = str(os.getenv("YOLO_VERBOSE", True)).lower() == "true"  # global verbose mode
+TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}" if VERBOSE else None  # tqdm bar format
+LOGGING_NAME = "ultralytics"
+MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"])  # environment booleans
+ARM64 = platform.machine() in {"arm64", "aarch64"}  # ARM64 booleans
+PYTHON_VERSION = platform.python_version()
+TORCH_VERSION = torch.__version__
+TORCHVISION_VERSION = importlib.metadata.version("torchvision")  # faster than importing torchvision
+IS_VSCODE = os.environ.get("TERM_PROGRAM", False) == "vscode"
+HELP_MSG = """
+    Examples for running Ultralytics:
+
+    1. Install the ultralytics package:
+
+        pip install ultralytics
+
+    2. Use the Python SDK:
+
+        from ultralytics import YOLO
+
+        # Load a model
+        model = YOLO("yolo11n.yaml")  # build a new model from scratch
+        model = YOLO("yolo11n.pt")  # load a pretrained model (recommended for training)
+
+        # Use the model
+        results = model.train(data="coco8.yaml", epochs=3)  # train the model
+        results = model.val()  # evaluate model performance on the validation set
+        results = model("https://ultralytics.com/images/bus.jpg")  # predict on an image
+        success = model.export(format="onnx")  # export the model to ONNX format
+
+    3. Use the command line interface (CLI):
+
+        Ultralytics 'yolo' CLI commands use the following syntax:
+
+            yolo TASK MODE ARGS
+
+            Where   TASK (optional) is one of [detect, segment, classify, pose, obb]
+                    MODE (required) is one of [train, val, predict, export, track, benchmark]
+                    ARGS (optional) are any number of custom "arg=value" pairs like "imgsz=320" that override defaults.
+                        See all ARGS at https://docs.ultralytics.com/usage/cfg or with "yolo cfg"
+
+        - Train a detection model for 10 epochs with an initial learning_rate of 0.01
+            yolo detect train data=coco8.yaml model=yolo11n.pt epochs=10 lr0=0.01
+
+        - Predict a YouTube video using a pretrained segmentation model at image size 320:
+            yolo segment predict model=yolo11n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320
+
+        - Val a pretrained detection model at batch-size 1 and image size 640:
+            yolo detect val model=yolo11n.pt data=coco8.yaml batch=1 imgsz=640
+
+        - Export a YOLO11n classification model to ONNX format at image size 224 by 128 (no TASK required)
+            yolo export model=yolo11n-cls.pt format=onnx imgsz=224,128
+
+        - Run special commands:
+            yolo help
+            yolo checks
+            yolo version
+            yolo settings
+            yolo copy-cfg
+            yolo cfg
+
+    Docs: https://docs.ultralytics.com
+    Community: https://community.ultralytics.com
+    GitHub: https://github.com/ultralytics/ultralytics
+    """
+
+# Settings and Environment Variables
+torch.set_printoptions(linewidth=320, precision=4, profile="default")
+np.set_printoptions(linewidth=320, formatter={"float_kind": "{:11.5g}".format})  # format short g, %precision=5
+cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
+os.environ["NUMEXPR_MAX_THREADS"] = str(NUM_THREADS)  # NumExpr max threads
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # for deterministic training to avoid CUDA warning
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # suppress verbose TF compiler warnings in Colab
+os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"  # suppress "NNPACK.cpp could not initialize NNPACK" warnings
+os.environ["KINETO_LOG_LEVEL"] = "5"  # suppress verbose PyTorch profiler output when computing FLOPs
+
+
+class TQDM(tqdm_original):
+    """
+    A custom TQDM progress bar class that extends the original tqdm functionality.
+
+    This class modifies the behavior of the original tqdm progress bar based on global settings and provides
+    additional customization options.
+
+    Attributes:
+        disable (bool): Whether to disable the progress bar. Determined by the global VERBOSE setting and
+            any passed 'disable' argument.
+        bar_format (str): The format string for the progress bar. Uses the global TQDM_BAR_FORMAT if not
+            explicitly set.
+
+    Methods:
+        __init__: Initializes the TQDM object with custom settings.
+
+    Examples:
+        >>> from ultralytics.utils import TQDM
+        >>> for i in TQDM(range(100)):
+        ...     # Your processing code here
+        ...     pass
+    """
+
+    def __init__(self, *args, **kwargs):
+        """
+        Initializes a custom TQDM progress bar.
+
+        This class extends the original tqdm class to provide customized behavior for Ultralytics projects.
+
+        Args:
+            *args (Any): Variable length argument list to be passed to the original tqdm constructor.
+            **kwargs (Any): Arbitrary keyword arguments to be passed to the original tqdm constructor.
+
+        Notes:
+            - The progress bar is disabled if VERBOSE is False or if 'disable' is explicitly set to True in kwargs.
+            - The default bar format is set to TQDM_BAR_FORMAT unless overridden in kwargs.
+
+        Examples:
+            >>> from ultralytics.utils import TQDM
+            >>> for i in TQDM(range(100)):
+            ...     # Your code here
+            ...     pass
+        """
+        kwargs["disable"] = not VERBOSE or kwargs.get("disable", False)  # logical 'and' with default value if passed
+        kwargs.setdefault("bar_format", TQDM_BAR_FORMAT)  # override default value if passed
+        super().__init__(*args, **kwargs)
+
+
+class SimpleClass:
+    """
+    A simple base class for creating objects with string representations of their attributes.
+
+    This class provides a foundation for creating objects that can be easily printed or represented as strings,
+    showing all their non-callable attributes. It's useful for debugging and introspection of object states.
+
+    Methods:
+        __str__: Returns a human-readable string representation of the object.
+        __repr__: Returns a machine-readable string representation of the object.
+        __getattr__: Provides a custom attribute access error message with helpful information.
+
+    Examples:
+        >>> class MyClass(SimpleClass):
+        ...     def __init__(self):
+        ...         self.x = 10
+        ...         self.y = "hello"
+        >>> obj = MyClass()
+        >>> print(obj)
+        __main__.MyClass object with attributes:
+
+        x: 10
+        y: 'hello'
+
+    Notes:
+        - This class is designed to be subclassed. It provides a convenient way to inspect object attributes.
+        - The string representation includes the module and class name of the object.
+        - Callable attributes and attributes starting with an underscore are excluded from the string representation.
+    """
+
+    def __str__(self):
+        """Return a human-readable string representation of the object."""
+        attr = []
+        for a in dir(self):
+            v = getattr(self, a)
+            if not callable(v) and not a.startswith("_"):
+                if isinstance(v, SimpleClass):
+                    # Display only the module and class name for subclasses
+                    s = f"{a}: {v.__module__}.{v.__class__.__name__} object"
+                else:
+                    s = f"{a}: {repr(v)}"
+                attr.append(s)
+        return f"{self.__module__}.{self.__class__.__name__} object with attributes:\n\n" + "\n".join(attr)
+
+    def __repr__(self):
+        """Return a machine-readable string representation of the object."""
+        return self.__str__()
+
+    def __getattr__(self, attr):
+        """Custom attribute access error message with helpful information."""
+        name = self.__class__.__name__
+        raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
+
+
+class IterableSimpleNamespace(SimpleNamespace):
+    """
+    An iterable SimpleNamespace class that provides enhanced functionality for attribute access and iteration.
+
+    This class extends the SimpleNamespace class with additional methods for iteration, string representation,
+    and attribute access. It is designed to be used as a convenient container for storing and accessing
+    configuration parameters.
+
+    Methods:
+        __iter__: Returns an iterator of key-value pairs from the namespace's attributes.
+        __str__: Returns a human-readable string representation of the object.
+        __getattr__: Provides a custom attribute access error message with helpful information.
+        get: Retrieves the value of a specified key, or a default value if the key doesn't exist.
+
+    Examples:
+        >>> cfg = IterableSimpleNamespace(a=1, b=2, c=3)
+        >>> for k, v in cfg:
+        ...     print(f"{k}: {v}")
+        a: 1
+        b: 2
+        c: 3
+        >>> print(cfg)
+        a=1
+        b=2
+        c=3
+        >>> cfg.get("b")
+        2
+        >>> cfg.get("d", "default")
+        'default'
+
+    Notes:
+        This class is particularly useful for storing configuration parameters in a more accessible
+        and iterable format compared to a standard dictionary.
+    """
+
+    def __iter__(self):
+        """Return an iterator of key-value pairs from the namespace's attributes."""
+        return iter(vars(self).items())
+
+    def __str__(self):
+        """Return a human-readable string representation of the object."""
+        return "\n".join(f"{k}={v}" for k, v in vars(self).items())
+
+    def __getattr__(self, attr):
+        """Custom attribute access error message with helpful information."""
+        name = self.__class__.__name__
+        raise AttributeError(
+            f"""
+            '{name}' object has no attribute '{attr}'. This may be caused by a modified or out of date ultralytics
+            'default.yaml' file.\nPlease update your code with 'pip install -U ultralytics' and if necessary replace
+            {DEFAULT_CFG_PATH} with the latest version from
+            https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/default.yaml
+            """
+        )
+
+    def get(self, key, default=None):
+        """Return the value of the specified key if it exists; otherwise, return the default value."""
+        return getattr(self, key, default)
+
+
+def plt_settings(rcparams=None, backend="Agg"):
+    """
+    Decorator to temporarily set rc parameters and the backend for a plotting function.
+
+    Example:
+        decorator: @plt_settings({"font.size": 12})
+
+    Args:
+        rcparams (dict): Dictionary of rc parameters to set.
+        backend (str, optional): Name of the backend to use. Defaults to 'Agg'.
+
+    Returns:
+        (Callable): Decorated function with temporarily set rc parameters and backend. This decorator can be
+            applied to any function that needs to have specific matplotlib rc parameters and backend for its execution.
+    """
+    if rcparams is None:
+        rcparams = {"font.size": 11}
+
+    def decorator(func):
+        """Decorator to apply temporary rc parameters and backend to a function."""
+
+        def wrapper(*args, **kwargs):
+            """Sets rc parameters and backend, calls the original function, and restores the settings."""
+            original_backend = plt.get_backend()
+            switch = backend.lower() != original_backend.lower()
+            if switch:
+                plt.close("all")  # auto-close()ing of figures upon backend switching is deprecated since 3.8
+                plt.switch_backend(backend)
+
+            # Plot with backend and always revert to original backend
+            try:
+                with plt.rc_context(rcparams):
+                    result = func(*args, **kwargs)
+            finally:
+                if switch:
+                    plt.close("all")
+                    plt.switch_backend(original_backend)
+            return result
+
+        return wrapper
+
+    return decorator
+
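+# A minimal usage sketch (not part of this module) showing how plt_settings() is meant to wrap a plotting helper;
+# the function name and file path below are hypothetical:
+#
+#     @plt_settings({"font.size": 12}, backend="Agg")
+#     def save_histogram(values, path="hist.png"):
+#         plt.hist(values, bins=10)  # drawn under the temporary 'Agg' backend and rcParams
+#         plt.savefig(path)
+#         plt.close()
+#
+#     save_histogram([0.1, 0.5, 0.9])  # the original backend and rcParams are restored afterwards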
+
+def set_logging(name="LOGGING_NAME", verbose=True):
+    """
+    Sets up logging with UTF-8 encoding and configurable verbosity.
+
+    This function configures logging for the Ultralytics library, setting the appropriate logging level and
+    formatter based on the verbosity flag and the current process rank. It handles special cases for Windows
+    environments where UTF-8 encoding might not be the default.
+
+    Args:
+        name (str): Name of the logger. Defaults to "LOGGING_NAME".
+        verbose (bool): Flag to set logging level to INFO if True, ERROR otherwise. Defaults to True.
+
+    Examples:
+        >>> set_logging(name="ultralytics", verbose=True)
+        >>> logger = logging.getLogger("ultralytics")
+        >>> logger.info("This is an info message")
+
+    Notes:
+        - On Windows, this function attempts to reconfigure stdout to use UTF-8 encoding if possible.
+        - If reconfiguration is not possible, it falls back to a custom formatter that handles non-UTF-8 environments.
+        - The function sets up a StreamHandler with the appropriate formatter and level.
+        - The logger's propagate flag is set to False to prevent duplicate logging in parent loggers.
+    """
+    level = logging.INFO if verbose and RANK in {-1, 0} else logging.ERROR  # rank in world for Multi-GPU trainings
+
+    # Configure the console (stdout) encoding to UTF-8, with checks for compatibility
+    formatter = logging.Formatter("%(message)s")  # Default formatter
+    if WINDOWS and hasattr(sys.stdout, "encoding") and sys.stdout.encoding != "utf-8":
+
+        class CustomFormatter(logging.Formatter):
+            def format(self, record):
+                """Sets up logging with UTF-8 encoding and configurable verbosity."""
+                return emojis(super().format(record))
+
+        try:
+            # Attempt to reconfigure stdout to use UTF-8 encoding if possible
+            if hasattr(sys.stdout, "reconfigure"):
+                sys.stdout.reconfigure(encoding="utf-8")
+            # For environments where reconfigure is not available, wrap stdout in a TextIOWrapper
+            elif hasattr(sys.stdout, "buffer"):
+                import io
+
+                sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")
+            else:
+                formatter = CustomFormatter("%(message)s")
+        except Exception as e:
+            print(f"Creating custom formatter for non UTF-8 environments due to {e}")
+            formatter = CustomFormatter("%(message)s")
+
+    # Create and configure the StreamHandler with the appropriate formatter and level
+    stream_handler = logging.StreamHandler(sys.stdout)
+    stream_handler.setFormatter(formatter)
+    stream_handler.setLevel(level)
+
+    # Set up the logger
+    logger = logging.getLogger(name)
+    logger.setLevel(level)
+    logger.addHandler(stream_handler)
+    logger.propagate = False
+    return logger
+
+
+# Set logger
+LOGGER = set_logging(LOGGING_NAME, verbose=VERBOSE)  # define globally (used in train.py, val.py, predict.py, etc.)
+for logger in "sentry_sdk", "urllib3.connectionpool":
+    logging.getLogger(logger).setLevel(logging.CRITICAL + 1)
+
+
+def emojis(string=""):
+    """Return platform-dependent emoji-safe version of string."""
+    return string.encode().decode("ascii", "ignore") if WINDOWS else string
+
+
+class ThreadingLocked:
+    """
+    A decorator class for ensuring thread-safe execution of a function or method. If the decorated callable is invoked
+    from multiple threads, only one thread at a time is allowed to execute it.
+
+    Attributes:
+        lock (threading.Lock): A lock object used to manage access to the decorated function.
+
+    Example:
+        ```python
+        from ultralytics.utils import ThreadingLocked
+
+        @ThreadingLocked()
+        def my_function():
+            # Your code here
+            pass
+        ```
+    """
+
+    def __init__(self):
+        """Initializes the decorator class for thread-safe execution of a function or method."""
+        self.lock = threading.Lock()
+
+    def __call__(self, f):
+        """Run thread-safe execution of function or method."""
+        from functools import wraps
+
+        @wraps(f)
+        def decorated(*args, **kwargs):
+            """Applies thread-safety to the decorated function or method."""
+            with self.lock:
+                return f(*args, **kwargs)
+
+        return decorated
+
+
+def yaml_save(file="data.yaml", data=None, header=""):
+    """
+    Save YAML data to a file.
+
+    Args:
+        file (str, optional): File name. Default is 'data.yaml'.
+        data (dict): Data to save in YAML format.
+        header (str, optional): YAML header to add.
+
+    Returns:
+        (None): Data is saved to the specified file.
+    """
+    if data is None:
+        data = {}
+    file = Path(file)
+    if not file.parent.exists():
+        # Create parent directories if they don't exist
+        file.parent.mkdir(parents=True, exist_ok=True)
+
+    # Convert Path objects to strings
+    valid_types = int, float, str, bool, list, tuple, dict, type(None)
+    for k, v in data.items():
+        if not isinstance(v, valid_types):
+            data[k] = str(v)
+
+    # Dump data to file in YAML format
+    with open(file, "w", errors="ignore", encoding="utf-8") as f:
+        if header:
+            f.write(header)
+        yaml.safe_dump(data, f, sort_keys=False, allow_unicode=True)
+
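+# A hedged example of yaml_save(); the file name, keys, and header are illustrative only:
+#
+#     yaml_save("runs/config.yaml", {"epochs": 3, "weights": Path("yolo11n.pt")}, header="# Example config\n")
+#     # parent directories are created if missing, and the Path value is converted to a plain string before dumping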
+
+def yaml_load(file="data.yaml", append_filename=False):
+    """
+    Load YAML data from a file.
+
+    Args:
+        file (str, optional): File name. Default is 'data.yaml'.
+        append_filename (bool): Add the YAML filename to the YAML dictionary. Default is False.
+
+    Returns:
+        (dict): YAML data and file name.
+    """
+    assert Path(file).suffix in {".yaml", ".yml"}, f"Attempting to load non-YAML file {file} with yaml_load()"
+    with open(file, errors="ignore", encoding="utf-8") as f:
+        s = f.read()  # string
+
+        # Remove special characters
+        if not s.isprintable():
+            s = re.sub(r"[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+", "", s)
+
+        # Add YAML filename to dict and return
+        data = yaml.safe_load(s) or {}  # always return a dict (yaml.safe_load() may return None for empty files)
+        if append_filename:
+            data["yaml_file"] = str(file)
+        return data
+
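+# Reading the file back with yaml_load(), continuing the hypothetical yaml_save() example above:
+#
+#     cfg = yaml_load("runs/config.yaml", append_filename=True)
+#     # cfg -> {"epochs": 3, "weights": "yolo11n.pt", "yaml_file": "runs/config.yaml"}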
+
+def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
+    """
+    Pretty prints a YAML file or a YAML-formatted dictionary.
+
+    Args:
+        yaml_file: The file path of the YAML file or a YAML-formatted dictionary.
+
+    Returns:
+        (None)
+    """
+    yaml_dict = yaml_load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file
+    dump = yaml.dump(yaml_dict, sort_keys=False, allow_unicode=True, width=float("inf"))
+    LOGGER.info(f"Printing '{colorstr('bold', 'black', yaml_file)}'\n\n{dump}")
+
+
+# Default configuration
+DEFAULT_CFG_DICT = yaml_load(DEFAULT_CFG_PATH)
+DEFAULT_SOL_DICT = yaml_load(DEFAULT_SOL_CFG_PATH)  # Ultralytics solutions configuration
+for k, v in DEFAULT_CFG_DICT.items():
+    if isinstance(v, str) and v.lower() == "none":
+        DEFAULT_CFG_DICT[k] = None
+DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
+DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)
+
+
+def read_device_model() -> str:
+    """
+    Reads the device model information from the system. The result is stored in the DEVICE_MODEL constant and used by
+    is_jetson() and is_raspberrypi().
+
+    Returns:
+        (str): Kernel release information.
+    """
+    return platform.release().lower()
+
+
+def is_ubuntu() -> bool:
+    """
+    Check if the OS is Ubuntu.
+
+    Returns:
+        (bool): True if OS is Ubuntu, False otherwise.
+    """
+    try:
+        with open("/etc/os-release") as f:
+            return "ID=ubuntu" in f.read()
+    except FileNotFoundError:
+        return False
+
+
+def is_colab():
+    """
+    Check if the current script is running inside a Google Colab notebook.
+
+    Returns:
+        (bool): True if running inside a Colab notebook, False otherwise.
+    """
+    return "COLAB_RELEASE_TAG" in os.environ or "COLAB_BACKEND_VERSION" in os.environ
+
+
+def is_kaggle():
+    """
+    Check if the current script is running inside a Kaggle kernel.
+
+    Returns:
+        (bool): True if running inside a Kaggle kernel, False otherwise.
+    """
+    return os.environ.get("PWD") == "/kaggle/working" and os.environ.get("KAGGLE_URL_BASE") == "https://www.kaggle.com"
+
+
+def is_jupyter():
+    """
+    Check if the current script is running inside a Jupyter Notebook.
+
+    Returns:
+        (bool): True if running inside a Jupyter Notebook, False otherwise.
+
+    Note:
+        - Only works on Colab and Kaggle; other environments like JupyterLab and Paperspace are not reliably detectable.
+        - The '"get_ipython" in globals()' check suffers false positives when the IPython package is installed manually.
+    """
+    return IS_COLAB or IS_KAGGLE
+
+
+def is_runpod():
+    """
+    Check if the current script is running inside a RunPod container.
+
+    Returns:
+        (bool): True if running in RunPod, False otherwise.
+    """
+    return "RUNPOD_POD_ID" in os.environ
+
+
+def is_docker() -> bool:
+    """
+    Determine if the script is running inside a Docker container.
+
+    Returns:
+        (bool): True if the script is running inside a Docker container, False otherwise.
+    """
+    try:
+        with open("/proc/self/cgroup") as f:
+            return "docker" in f.read()
+    except Exception:
+        return False
+
+
+def is_raspberrypi() -> bool:
+    """
+    Determines if the Python environment is running on a Raspberry Pi by checking the device model information.
+
+    Returns:
+        (bool): True if running on a Raspberry Pi, False otherwise.
+    """
+    return "rpi" in DEVICE_MODEL
+
+
+def is_jetson() -> bool:
+    """
+    Determines if the Python environment is running on an NVIDIA Jetson device by checking the device model information.
+
+    Returns:
+        (bool): True if running on an NVIDIA Jetson device, False otherwise.
+    """
+    return "tegra" in DEVICE_MODEL
+
+
+def is_online() -> bool:
+    """
+    Check internet connectivity by attempting to connect to a known online host.
+
+    Returns:
+        (bool): True if connection is successful, False otherwise.
+    """
+    try:
+        assert str(os.getenv("YOLO_OFFLINE", "")).lower() != "true"  # check if ENV var YOLO_OFFLINE="True"
+        import socket
+
+        for dns in ("1.1.1.1", "8.8.8.8"):  # check Cloudflare and Google DNS
+            socket.create_connection(address=(dns, 80), timeout=2.0).close()
+            return True
+    except Exception:
+        return False
+
+
+def is_pip_package(filepath: str = __name__) -> bool:
+    """
+    Determines if the file at the given filepath is part of a pip package.
+
+    Args:
+        filepath (str): The filepath to check.
+
+    Returns:
+        (bool): True if the file is part of a pip package, False otherwise.
+    """
+    import importlib.util
+
+    # Get the spec for the module
+    spec = importlib.util.find_spec(filepath)
+
+    # Return whether the spec is not None and the origin is not None (indicating it is a package)
+    return spec is not None and spec.origin is not None
+
+
+def is_dir_writeable(dir_path: Union[str, Path]) -> bool:
+    """
+    Check if a directory is writeable.
+
+    Args:
+        dir_path (str | Path): The path to the directory.
+
+    Returns:
+        (bool): True if the directory is writeable, False otherwise.
+    """
+    return os.access(str(dir_path), os.W_OK)
+
+
+def is_pytest_running():
+    """
+    Determines whether pytest is currently running or not.
+
+    Returns:
+        (bool): True if pytest is running, False otherwise.
+    """
+    return ("PYTEST_CURRENT_TEST" in os.environ) or ("pytest" in sys.modules) or ("pytest" in Path(ARGV[0]).stem)
+
+
+def is_github_action_running() -> bool:
+    """
+    Determine if the current environment is a GitHub Actions runner.
+
+    Returns:
+        (bool): True if the current environment is a GitHub Actions runner, False otherwise.
+    """
+    return "GITHUB_ACTIONS" in os.environ and "GITHUB_WORKFLOW" in os.environ and "RUNNER_OS" in os.environ
+
+
+def get_git_dir():
+    """
+    Determines whether the current file is part of a git repository and if so, returns the repository root directory. If
+    the current file is not part of a git repository, returns None.
+
+    Returns:
+        (Path | None): Git root directory if found or None if not found.
+    """
+    for d in Path(__file__).parents:
+        if (d / ".git").is_dir():
+            return d
+
+
+def is_git_dir():
+    """
+    Determines whether the current file is part of a git repository.
+
+    Returns:
+        (bool): True if the current file is part of a git repository, False otherwise.
+    """
+    return GIT_DIR is not None
+
+
+def get_git_origin_url():
+    """
+    Retrieves the origin URL of a git repository.
+
+    Returns:
+        (str | None): The origin URL of the git repository or None if not git directory.
+    """
+    if IS_GIT_DIR:
+        try:
+            origin = subprocess.check_output(["git", "config", "--get", "remote.origin.url"])
+            return origin.decode().strip()
+        except subprocess.CalledProcessError:
+            return None
+
+
+def get_git_branch():
+    """
+    Returns the current git branch name. If not in a git repository, returns None.
+
+    Returns:
+        (str | None): The current git branch name or None if not a git directory.
+    """
+    if IS_GIT_DIR:
+        try:
+            origin = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
+            return origin.decode().strip()
+        except subprocess.CalledProcessError:
+            return None
+
+
+def get_default_args(func):
+    """
+    Returns a dictionary of default arguments for a function.
+
+    Args:
+        func (callable): The function to inspect.
+
+    Returns:
+        (dict): A dictionary where each key is a parameter name, and each value is the default value of that parameter.
+    """
+    signature = inspect.signature(func)
+    return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
+
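+# A small illustration of get_default_args() on an arbitrary function (hypothetical, not part of this module):
+#
+#     def train(model, epochs=3, lr0=0.01):
+#         ...
+#
+#     get_default_args(train)  # -> {"epochs": 3, "lr0": 0.01}; 'model' has no default so it is omitted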
+
+def get_ubuntu_version():
+    """
+    Retrieve the Ubuntu version if the OS is Ubuntu.
+
+    Returns:
+        (str): Ubuntu version or None if not an Ubuntu OS.
+    """
+    if is_ubuntu():
+        try:
+            with open("/etc/os-release") as f:
+                return re.search(r'VERSION_ID="(\d+\.\d+)"', f.read())[1]
+        except (FileNotFoundError, AttributeError):
+            return None
+
+
+def get_user_config_dir(sub_dir="yolov13"):
+    """
+    Return the appropriate config directory based on the environment operating system.
+
+    Args:
+        sub_dir (str): The name of the subdirectory to create.
+
+    Returns:
+        (Path): The path to the user config directory.
+    """
+    if WINDOWS:
+        path = Path.home() / "AppData" / "Roaming" / sub_dir
+    elif MACOS:  # macOS
+        path = Path.home() / "Library" / "Application Support" / sub_dir
+    elif LINUX:
+        path = Path.home() / ".config" / sub_dir
+    else:
+        raise ValueError(f"Unsupported operating system: {platform.system()}")
+
+    # GCP and AWS lambda fix, only /tmp is writeable
+    if not is_dir_writeable(path.parent):
+        LOGGER.warning(
+            f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD."
+            "Alternatively you can define a YOLO_CONFIG_DIR environment variable for this path."
+        )
+        path = Path("/tmp") / sub_dir if is_dir_writeable("/tmp") else Path().cwd() / sub_dir
+
+    # Create the subdirectory if it does not exist
+    path.mkdir(parents=True, exist_ok=True)
+
+    return path
+
+
+# Define constants (required below)
+DEVICE_MODEL = read_device_model()  # is_jetson() and is_raspberrypi() depend on this constant
+ONLINE = is_online()
+IS_COLAB = is_colab()
+IS_KAGGLE = is_kaggle()
+IS_DOCKER = is_docker()
+IS_JETSON = is_jetson()
+IS_JUPYTER = is_jupyter()
+IS_PIP_PACKAGE = is_pip_package()
+IS_RASPBERRYPI = is_raspberrypi()
+GIT_DIR = get_git_dir()
+IS_GIT_DIR = is_git_dir()
+USER_CONFIG_DIR = Path(os.getenv("YOLO_CONFIG_DIR") or get_user_config_dir())  # Ultralytics settings dir
+SETTINGS_FILE = USER_CONFIG_DIR / "settings.json"
+
+
+def colorstr(*input):
+    r"""
+    Colors a string based on the provided color and style arguments. Utilizes ANSI escape codes.
+    See https://en.wikipedia.org/wiki/ANSI_escape_code for more details.
+
+    This function can be called in two ways:
+        - colorstr('color', 'style', 'your string')
+        - colorstr('your string')
+
+    In the second form, 'blue' and 'bold' will be applied by default.
+
+    Args:
+        *input (str | Path): A sequence of strings where the first n-1 strings are color and style arguments,
+                      and the last string is the one to be colored.
+
+    Supported Colors and Styles:
+        Basic Colors: 'black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'
+        Bright Colors: 'bright_black', 'bright_red', 'bright_green', 'bright_yellow',
+                       'bright_blue', 'bright_magenta', 'bright_cyan', 'bright_white'
+        Misc: 'end', 'bold', 'underline'
+
+    Returns:
+        (str): The input string wrapped with ANSI escape codes for the specified color and style.
+
+    Examples:
+        >>> colorstr("blue", "bold", "hello world")
+        >>> "\033[34m\033[1mhello world\033[0m"
+    """
+    *args, string = input if len(input) > 1 else ("blue", "bold", input[0])  # color arguments, string
+    colors = {
+        "black": "\033[30m",  # basic colors
+        "red": "\033[31m",
+        "green": "\033[32m",
+        "yellow": "\033[33m",
+        "blue": "\033[34m",
+        "magenta": "\033[35m",
+        "cyan": "\033[36m",
+        "white": "\033[37m",
+        "bright_black": "\033[90m",  # bright colors
+        "bright_red": "\033[91m",
+        "bright_green": "\033[92m",
+        "bright_yellow": "\033[93m",
+        "bright_blue": "\033[94m",
+        "bright_magenta": "\033[95m",
+        "bright_cyan": "\033[96m",
+        "bright_white": "\033[97m",
+        "end": "\033[0m",  # misc
+        "bold": "\033[1m",
+        "underline": "\033[4m",
+    }
+    return "".join(colors[x] for x in args) + f"{string}" + colors["end"]
+
+
+def remove_colorstr(input_string):
+    """
+    Removes ANSI escape codes from a string, effectively un-coloring it.
+
+    Args:
+        input_string (str): The string to remove color and style from.
+
+    Returns:
+        (str): A new string with all ANSI escape codes removed.
+
+    Examples:
+        >>> remove_colorstr(colorstr("blue", "bold", "hello world"))
+        >>> "hello world"
+    """
+    ansi_escape = re.compile(r"\x1B\[[0-9;]*[A-Za-z]")
+    return ansi_escape.sub("", input_string)
+
+
+class TryExcept(contextlib.ContextDecorator):
+    """
+    Ultralytics TryExcept class. Use as @TryExcept() decorator or 'with TryExcept():' context manager.
+
+    Examples:
+        As a decorator:
+        >>> @TryExcept(msg="Error occurred in func", verbose=True)
+        ... def func():
+        ...     # Function logic here
+        ...     pass
+
+        As a context manager:
+        >>> with TryExcept(msg="Error occurred in block", verbose=True):
+        ...     # Code block here
+        ...     pass
+    """
+
+    def __init__(self, msg="", verbose=True):
+        """Initialize TryExcept class with optional message and verbosity settings."""
+        self.msg = msg
+        self.verbose = verbose
+
+    def __enter__(self):
+        """Executes when entering TryExcept context, initializes instance."""
+        pass
+
+    def __exit__(self, exc_type, value, traceback):
+        """Defines behavior when exiting a 'with' block, prints error message if necessary."""
+        if self.verbose and value:
+            print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}"))
+        return True
+
+
+class Retry(contextlib.ContextDecorator):
+    """
+    Retry class for function execution with exponential backoff.
+
+    Can be used as a decorator to retry a function on exceptions, up to a specified number of times with an
+    exponentially increasing delay between retries.
+
+    Examples:
+        Example usage as a decorator:
+        >>> @Retry(times=3, delay=2)
+        ... def test_func():
+        ...     # Replace with function logic that may raise exceptions
+        ...     return True
+    """
+
+    def __init__(self, times=3, delay=2):
+        """Initialize Retry class with specified number of retries and delay."""
+        self.times = times
+        self.delay = delay
+        self._attempts = 0
+
+    def __call__(self, func):
+        """Decorator implementation for Retry with exponential backoff."""
+
+        def wrapped_func(*args, **kwargs):
+            """Applies retries to the decorated function or method."""
+            self._attempts = 0
+            while self._attempts < self.times:
+                try:
+                    return func(*args, **kwargs)
+                except Exception as e:
+                    self._attempts += 1
+                    print(f"Retry {self._attempts}/{self.times} failed: {e}")
+                    if self._attempts >= self.times:
+                        raise e
+                    time.sleep(self.delay * (2**self._attempts))  # exponential backoff delay
+
+        return wrapped_func
+
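+# A hedged sketch of Retry in use; flaky_download() is a hypothetical callable that may raise:
+#
+#     @Retry(times=3, delay=2)
+#     def flaky_download(url):
+#         ...  # on failure, retried with sleeps of delay * 2**attempt (4s, then 8s) before the final raise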
+
+def threaded(func):
+    """
+    Multi-threads a target function by default and returns the thread or function result.
+
+    Use as @threaded decorator. The function runs in a separate thread unless 'threaded=False' is passed.
+    """
+
+    def wrapper(*args, **kwargs):
+        """Multi-threads a given function based on 'threaded' kwarg and returns the thread or function result."""
+        if kwargs.pop("threaded", True):  # run in thread
+            thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
+            thread.start()
+            return thread
+        else:
+            return func(*args, **kwargs)
+
+    return wrapper
+
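+# A minimal sketch of the @threaded decorator; log_metrics() is a hypothetical callable:
+#
+#     @threaded
+#     def log_metrics(step, loss):
+#         ...  # executed in a daemon thread by default
+#
+#     thread = log_metrics(1, 0.5)                 # returns the started threading.Thread
+#     value = log_metrics(1, 0.5, threaded=False)  # runs synchronously and returns the function result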
+
+def set_sentry():
+    """
+    Initialize the Sentry SDK for error tracking and reporting. Only used if sentry_sdk package is installed and
+    sync=True in settings. Run 'yolo settings' to see and update settings.
+
+    Conditions required to send errors (ALL conditions must be met or no errors will be reported):
+        - sentry_sdk package is installed
+        - sync=True in YOLO settings
+        - pytest is not running
+        - running in a pip package installation
+        - running in a non-git directory
+        - running with rank -1 or 0
+        - online environment
+        - CLI used to run package (checked with 'yolo' as the name of the main CLI command)
+
+    The function also configures Sentry SDK to ignore KeyboardInterrupt and FileNotFoundError exceptions and to exclude
+    events with 'out of memory' in their exception message.
+
+    Additionally, the function sets custom tags and user information for Sentry events.
+    """
+    if (
+        not SETTINGS["sync"]
+        or RANK not in {-1, 0}
+        or Path(ARGV[0]).name != "yolo"
+        or TESTS_RUNNING
+        or not ONLINE
+        or not IS_PIP_PACKAGE
+        or IS_GIT_DIR
+    ):
+        return
+    # If sentry_sdk package is not installed then return and do not use Sentry
+    try:
+        import sentry_sdk  # noqa
+    except ImportError:
+        return
+
+    def before_send(event, hint):
+        """
+        Modify the event before sending it to Sentry based on specific exception types and messages.
+
+        Args:
+            event (dict): The event dictionary containing information about the error.
+            hint (dict): A dictionary containing additional information about the error.
+
+        Returns:
+            dict: The modified event or None if the event should not be sent to Sentry.
+        """
+        if "exc_info" in hint:
+            exc_type, exc_value, _ = hint["exc_info"]
+            if exc_type in {KeyboardInterrupt, FileNotFoundError} or "out of memory" in str(exc_value):
+                return None  # do not send event
+
+        event["tags"] = {
+            "sys_argv": ARGV[0],
+            "sys_argv_name": Path(ARGV[0]).name,
+            "install": "git" if IS_GIT_DIR else "pip" if IS_PIP_PACKAGE else "other",
+            "os": ENVIRONMENT,
+        }
+        return event
+
+    sentry_sdk.init(
+        dsn="https://888e5a0778212e1d0314c37d4b9aae5d@o4504521589325824.ingest.us.sentry.io/4504521592406016",
+        debug=False,
+        auto_enabling_integrations=False,
+        traces_sample_rate=1.0,
+        release=__version__,
+        environment="runpod" if is_runpod() else "production",
+        before_send=before_send,
+        ignore_errors=[KeyboardInterrupt, FileNotFoundError],
+    )
+    sentry_sdk.set_user({"id": SETTINGS["uuid"]})  # SHA-256 anonymized UUID hash
+
+
+class JSONDict(dict):
+    """
+    A dictionary-like class that provides JSON persistence for its contents.
+
+    This class extends the built-in dictionary to automatically save its contents to a JSON file whenever they are
+    modified. It ensures thread-safe operations using a lock.
+
+    Attributes:
+        file_path (Path): The path to the JSON file used for persistence.
+        lock (threading.Lock): A lock object to ensure thread-safe operations.
+
+    Methods:
+        _load: Loads the data from the JSON file into the dictionary.
+        _save: Saves the current state of the dictionary to the JSON file.
+        __setitem__: Stores a key-value pair and persists it to disk.
+        __delitem__: Removes an item and updates the persistent storage.
+        update: Updates the dictionary and persists changes.
+        clear: Clears all entries and updates the persistent storage.
+
+    Examples:
+        >>> json_dict = JSONDict("data.json")
+        >>> json_dict["key"] = "value"
+        >>> print(json_dict["key"])
+        value
+        >>> del json_dict["key"]
+        >>> json_dict.update({"new_key": "new_value"})
+        >>> json_dict.clear()
+    """
+
+    def __init__(self, file_path: Union[str, Path] = "data.json"):
+        """Initialize a JSONDict object with a specified file path for JSON persistence."""
+        super().__init__()
+        self.file_path = Path(file_path)
+        self.lock = Lock()
+        self._load()
+
+    def _load(self):
+        """Load the data from the JSON file into the dictionary."""
+        try:
+            if self.file_path.exists():
+                with open(self.file_path) as f:
+                    self.update(json.load(f))
+        except json.JSONDecodeError:
+            print(f"Error decoding JSON from {self.file_path}. Starting with an empty dictionary.")
+        except Exception as e:
+            print(f"Error reading from {self.file_path}: {e}")
+
+    def _save(self):
+        """Save the current state of the dictionary to the JSON file."""
+        try:
+            self.file_path.parent.mkdir(parents=True, exist_ok=True)
+            with open(self.file_path, "w") as f:
+                json.dump(dict(self), f, indent=2, default=self._json_default)
+        except Exception as e:
+            print(f"Error writing to {self.file_path}: {e}")
+
+    @staticmethod
+    def _json_default(obj):
+        """Handle JSON serialization of Path objects."""
+        if isinstance(obj, Path):
+            return str(obj)
+        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
+
+    def __setitem__(self, key, value):
+        """Store a key-value pair and persist to disk."""
+        with self.lock:
+            super().__setitem__(key, value)
+            self._save()
+
+    def __delitem__(self, key):
+        """Remove an item and update the persistent storage."""
+        with self.lock:
+            super().__delitem__(key)
+            self._save()
+
+    def __str__(self):
+        """Return a pretty-printed JSON string representation of the dictionary."""
+        contents = json.dumps(dict(self), indent=2, ensure_ascii=False, default=self._json_default)
+        return f'JSONDict("{self.file_path}"):\n{contents}'
+
+    def update(self, *args, **kwargs):
+        """Update the dictionary and persist changes."""
+        with self.lock:
+            super().update(*args, **kwargs)
+            self._save()
+
+    def clear(self):
+        """Clear all entries and update the persistent storage."""
+        with self.lock:
+            super().clear()
+            self._save()
+
+
+class SettingsManager(JSONDict):
+    """
+    SettingsManager class for managing and persisting Ultralytics settings.
+
+    This class extends JSONDict to provide JSON persistence for settings, ensuring thread-safe operations and default
+    values. It validates settings on initialization and provides methods to update or reset settings.
+
+    Attributes:
+        file (Path): The path to the JSON file used for persistence.
+        version (str): The version of the settings schema.
+        defaults (Dict): A dictionary containing default settings.
+        help_msg (str): A help message for users on how to view and update settings.
+
+    Methods:
+        _validate_settings: Validates the current settings and resets if necessary.
+        update: Updates settings, validating keys and types.
+        reset: Resets the settings to default and saves them.
+
+    Examples:
+        Initialize and update settings:
+        >>> settings = SettingsManager()
+        >>> settings.update(runs_dir="/new/runs/dir")
+        >>> print(settings["runs_dir"])
+        /new/runs/dir
+    """
+
+    def __init__(self, file=SETTINGS_FILE, version="0.0.6"):
+        """Initializes the SettingsManager with default settings and loads user settings."""
+        import hashlib
+
+        from ultralytics.utils.torch_utils import torch_distributed_zero_first
+
+        root = GIT_DIR or Path()
+        datasets_root = (root.parent if GIT_DIR and is_dir_writeable(root.parent) else root).resolve()
+
+        self.file = Path(file)
+        self.version = version
+        self.defaults = {
+            "settings_version": version,  # Settings schema version
+            "datasets_dir": str(datasets_root / "datasets"),  # Datasets directory
+            "weights_dir": str(root / "weights"),  # Model weights directory
+            "runs_dir": str(root / "runs"),  # Experiment runs directory
+            "uuid": hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(),  # SHA-256 anonymized UUID hash
+            "sync": True,  # Enable synchronization
+            "api_key": "",  # Ultralytics API Key
+            "openai_api_key": "",  # OpenAI API Key
+            "clearml": True,  # ClearML integration
+            "comet": True,  # Comet integration
+            "dvc": True,  # DVC integration
+            "hub": True,  # Ultralytics HUB integration
+            "mlflow": True,  # MLflow integration
+            "neptune": True,  # Neptune integration
+            "raytune": True,  # Ray Tune integration
+            "tensorboard": True,  # TensorBoard logging
+            "wandb": False,  # Weights & Biases logging
+            "vscode_msg": True,  # VSCode messaging
+        }
+
+        self.help_msg = (
+            f"\nView Ultralytics Settings with 'yolo settings' or at '{self.file}'"
+            "\nUpdate Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. "
+            "For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings."
+        )
+
+        with torch_distributed_zero_first(RANK):
+            super().__init__(self.file)
+
+            if not self.file.exists() or not self:  # Check if file doesn't exist or is empty
+                LOGGER.info(f"Creating new Ultralytics Settings v{version} file ✅ {self.help_msg}")
+                self.reset()
+
+            self._validate_settings()
+
+    def _validate_settings(self):
+        """Validate the current settings and reset if necessary."""
+        correct_keys = set(self.keys()) == set(self.defaults.keys())
+        correct_types = all(isinstance(self.get(k), type(v)) for k, v in self.defaults.items())
+        correct_version = self.get("settings_version", "") == self.version
+
+        if not (correct_keys and correct_types and correct_version):
+            LOGGER.warning(
+                "WARNING ⚠️ Ultralytics settings reset to default values. This may be due to a possible problem "
+                f"with your settings or a recent ultralytics package update. {self.help_msg}"
+            )
+            self.reset()
+
+        if self.get("datasets_dir") == self.get("runs_dir"):
+            LOGGER.warning(
+                f"WARNING ⚠️ Ultralytics setting 'datasets_dir: {self.get('datasets_dir')}' "
+                f"must be different than 'runs_dir: {self.get('runs_dir')}'. "
+                f"Please change one to avoid possible issues during training. {self.help_msg}"
+            )
+
+    def __setitem__(self, key, value):
+        """Updates one key: value pair."""
+        self.update({key: value})
+
+    def update(self, *args, **kwargs):
+        """Updates settings, validating keys and types."""
+        for arg in args:
+            if isinstance(arg, dict):
+                kwargs.update(arg)
+        for k, v in kwargs.items():
+            if k not in self.defaults:
+                raise KeyError(f"No Ultralytics setting '{k}'. {self.help_msg}")
+            t = type(self.defaults[k])
+            if not isinstance(v, t):
+                raise TypeError(
+                    f"Ultralytics setting '{k}' must be '{t.__name__}' type, not '{type(v).__name__}'. {self.help_msg}"
+                )
+        super().update(*args, **kwargs)
+
+    def reset(self):
+        """Resets the settings to default and saves them."""
+        self.clear()
+        self.update(self.defaults)
+
+
+def deprecation_warn(arg, new_arg=None):
+    """Issue a deprecation warning when a deprecated argument is used, suggesting an updated argument."""
+    msg = f"WARNING ⚠️ '{arg}' is deprecated and will be removed in in the future."
+    if new_arg is not None:
+        msg += f" Use '{new_arg}' instead."
+    LOGGER.warning(msg)
+
+
+def clean_url(url):
+    """Strip auth from URL, i.e. https://url.com/file.txt?auth -> https://url.com/file.txt."""
+    url = Path(url).as_posix().replace(":/", "://")  # Pathlib turns :// -> :/, as_posix() for Windows
+    return unquote(url).split("?")[0]  # '%2F' to '/', split https://url.com/file.txt?auth
+
+
+def url2file(url):
+    """Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt."""
+    return Path(clean_url(url)).name
+
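+# Hedged examples of clean_url() and url2file() on an illustrative URL:
+#
+#     clean_url("https://ultralytics.com/assets/bus.jpg?auth=123")  # -> 'https://ultralytics.com/assets/bus.jpg'
+#     url2file("https://ultralytics.com/assets/bus.jpg?auth=123")   # -> 'bus.jpg'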
+
+def vscode_msg(ext="ultralytics.ultralytics-snippets") -> str:
+    """Display a message to install Ultralytics-Snippets for VS Code if not already installed."""
+    path = (USER_CONFIG_DIR.parents[2] if WINDOWS else USER_CONFIG_DIR.parents[1]) / ".vscode/extensions"
+    obs_file = path / ".obsolete"  # file tracks uninstalled extensions, while source directory remains
+    installed = any(path.glob(f"{ext}*")) and ext not in (obs_file.read_text("utf-8") if obs_file.exists() else "")
+    url = "https://docs.ultralytics.com/integrations/vscode"
+    return "" if installed else f"{colorstr('VS Code:')} view Ultralytics VS Code Extension ⚡ at {url}"
+
+
+# Run below code on utils init ------------------------------------------------------------------------------------
+
+# Check first-install steps
+PREFIX = colorstr("Ultralytics: ")
+SETTINGS = SettingsManager()  # initialize settings
+PERSISTENT_CACHE = JSONDict(USER_CONFIG_DIR / "persistent_cache.json")  # initialize persistent cache
+DATASETS_DIR = Path(SETTINGS["datasets_dir"])  # global datasets directory
+WEIGHTS_DIR = Path(SETTINGS["weights_dir"])  # global weights directory
+RUNS_DIR = Path(SETTINGS["runs_dir"])  # global runs directory
+ENVIRONMENT = (
+    "Colab"
+    if IS_COLAB
+    else "Kaggle"
+    if IS_KAGGLE
+    else "Jupyter"
+    if IS_JUPYTER
+    else "Docker"
+    if IS_DOCKER
+    else platform.system()
+)
+TESTS_RUNNING = is_pytest_running() or is_github_action_running()
+set_sentry()
+
+# Apply monkey patches
+from ultralytics.utils.patches import imread, imshow, imwrite, torch_load, torch_save
+
+torch.load = torch_load
+torch.save = torch_save
+if WINDOWS:
+    # Apply cv2 patches for non-ASCII and non-UTF characters in image paths
+    cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow

+ 106 - 0
ultralytics/utils/autobatch.py

@@ -0,0 +1,106 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+"""Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch."""
+
+import os
+from copy import deepcopy
+
+import numpy as np
+import torch
+
+from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr
+from ultralytics.utils.torch_utils import autocast, profile
+
+
+def check_train_batch_size(model, imgsz=640, amp=True, batch=-1, max_num_obj=1):
+    """
+    Compute optimal YOLO training batch size using the autobatch() function.
+
+    Args:
+        model (torch.nn.Module): YOLO model to check batch size for.
+        imgsz (int, optional): Image size used for training.
+        amp (bool, optional): Use automatic mixed precision if True.
+        batch (float, optional): Fraction of GPU memory to use. If -1, use default.
+        max_num_obj (int, optional): The maximum number of objects from dataset.
+
+    Returns:
+        (int): Optimal batch size computed using the autobatch() function.
+
+    Note:
+        If 0.0 < batch < 1.0, it's used as the fraction of GPU memory to use.
+        Otherwise, a default fraction of 0.6 is used.
+    """
+    with autocast(enabled=amp):
+        return autobatch(
+            deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6, max_num_obj=max_num_obj
+        )
+
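+# A hedged usage sketch; loading a model via the YOLO class and the 0.70 memory fraction are illustrative only:
+#
+#     from ultralytics import YOLO
+#     model = YOLO("yolo11n.pt").model.cuda()  # autobatch requires a CUDA device
+#     batch = check_train_batch_size(model, imgsz=640, amp=True, batch=0.70)  # target ~70% of CUDA memory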
+
+def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch, max_num_obj=1):
+    """
+    Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.
+
+    Args:
+        model (torch.nn.module): YOLO model to compute batch size for.
+        imgsz (int, optional): The image size used as input for the YOLO model. Defaults to 640.
+        fraction (float, optional): The fraction of available CUDA memory to use. Defaults to 0.60.
+        batch_size (int, optional): The default batch size to use if an error is detected. Defaults to 16.
+        max_num_obj (int, optional): The maximum number of objects from dataset.
+
+    Returns:
+        (int): The optimal batch size.
+    """
+    # Check device
+    prefix = colorstr("AutoBatch: ")
+    LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz} at {fraction * 100}% CUDA memory utilization.")
+    device = next(model.parameters()).device  # get model device
+    if device.type in {"cpu", "mps"}:
+        LOGGER.info(f"{prefix} ⚠️ intended for CUDA devices, using default batch-size {batch_size}")
+        return batch_size
+    if torch.backends.cudnn.benchmark:
+        LOGGER.info(f"{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}")
+        return batch_size
+
+    # Inspect CUDA memory
+    gb = 1 << 30  # bytes to GiB (1024 ** 3)
+    d = f"CUDA:{os.getenv('CUDA_VISIBLE_DEVICES', '0').strip()[0]}"  # 'CUDA:0'
+    properties = torch.cuda.get_device_properties(device)  # device properties
+    t = properties.total_memory / gb  # GiB total
+    r = torch.cuda.memory_reserved(device) / gb  # GiB reserved
+    a = torch.cuda.memory_allocated(device) / gb  # GiB allocated
+    f = t - (r + a)  # GiB free
+    LOGGER.info(f"{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free")
+
+    # Profile batch sizes
+    batch_sizes = [1, 2, 4, 8, 16] if t < 16 else [1, 2, 4, 8, 16, 32, 64]
+    try:
+        img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
+        results = profile(img, model, n=1, device=device, max_num_obj=max_num_obj)
+
+        # Fit a solution
+        xy = [
+            [x, y[2]]
+            for i, (x, y) in enumerate(zip(batch_sizes, results))
+            if y  # valid result
+            and isinstance(y[2], (int, float))  # is numeric
+            and 0 < y[2] < t  # between 0 and GPU limit
+            and (i == 0 or not results[i - 1] or y[2] > results[i - 1][2])  # first item or increasing memory
+        ]
+        fit_x, fit_y = zip(*xy) if xy else ([], [])
+        p = np.polyfit(np.log(fit_x), np.log(fit_y), deg=1)  # first-degree polynomial fit in log space
+        b = int(round(np.exp((np.log(f * fraction) - p[1]) / p[0])))  # y intercept (optimal batch size)
+        if None in results:  # some sizes failed
+            i = results.index(None)  # first fail index
+            if b >= batch_sizes[i]:  # y intercept above failure point
+                b = batch_sizes[max(i - 1, 0)]  # select prior safe point
+        if b < 1 or b > 1024:  # b outside of safe range
+            LOGGER.info(f"{prefix}WARNING ⚠️ batch={b} outside safe range, using default batch-size {batch_size}.")
+            b = batch_size
+
+        fraction = (np.exp(np.polyval(p, np.log(b))) + r + a) / t  # predicted fraction
+        LOGGER.info(f"{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅")
+        return b
+    except Exception as e:
+        LOGGER.warning(f"{prefix}WARNING ⚠️ error detected: {e},  using default batch-size {batch_size}.")
+        return batch_size
+    finally:
+        torch.cuda.empty_cache()
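+# How the log-log fit above recovers a batch size, as a hedged numeric illustration (values are made up):
+# if batch sizes [1, 2, 4, 8] measured [0.5, 1.0, 2.0, 4.0] GiB, polyfit gives slope p[0]=1 and intercept
+# p[1]=log(0.5), so for a target of f * fraction = 6 GiB the solver returns
+# b = exp((log(6) - log(0.5)) / 1) = 12, which is then rounded, checked against failures, and clipped to [1, 1024].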

+ 583 - 0
ultralytics/utils/benchmarks.py

@@ -0,0 +1,583 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+"""
+Benchmark YOLO model formats for speed and accuracy.
+
+Usage:
+    from ultralytics.utils.benchmarks import ProfileModels, benchmark
+    ProfileModels(['yolov8n.yaml', 'yolov8s.yaml']).profile()
+    benchmark(model='yolov8n.pt', imgsz=160)
+
+Format                  | `format=argument`         | Model
+---                     | ---                       | ---
+PyTorch                 | -                         | yolov8n.pt
+TorchScript             | `torchscript`             | yolov8n.torchscript
+ONNX                    | `onnx`                    | yolov8n.onnx
+OpenVINO                | `openvino`                | yolov8n_openvino_model/
+TensorRT                | `engine`                  | yolov8n.engine
+CoreML                  | `coreml`                  | yolov8n.mlpackage
+TensorFlow SavedModel   | `saved_model`             | yolov8n_saved_model/
+TensorFlow GraphDef     | `pb`                      | yolov8n.pb
+TensorFlow Lite         | `tflite`                  | yolov8n.tflite
+TensorFlow Edge TPU     | `edgetpu`                 | yolov8n_edgetpu.tflite
+TensorFlow.js           | `tfjs`                    | yolov8n_web_model/
+PaddlePaddle            | `paddle`                  | yolov8n_paddle_model/
+MNN                     | `mnn`                     | yolov8n.mnn
+NCNN                    | `ncnn`                    | yolov8n_ncnn_model/
+"""
+
+import glob
+import os
+import platform
+import re
+import shutil
+import time
+from pathlib import Path
+
+import numpy as np
+import torch.cuda
+import yaml
+
+from ultralytics import YOLO, YOLOWorld
+from ultralytics.cfg import TASK2DATA, TASK2METRIC
+from ultralytics.engine.exporter import export_formats
+from ultralytics.utils import ARM64, ASSETS, IS_JETSON, IS_RASPBERRYPI, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR
+from ultralytics.utils.checks import IS_PYTHON_3_12, check_requirements, check_yolo
+from ultralytics.utils.downloads import safe_download
+from ultralytics.utils.files import file_size
+from ultralytics.utils.torch_utils import get_cpu_info, select_device
+
+
+def benchmark(
+    model=WEIGHTS_DIR / "yolo11n.pt",
+    data=None,
+    imgsz=160,
+    half=False,
+    int8=False,
+    device="cpu",
+    verbose=False,
+    eps=1e-3,
+):
+    """
+    Benchmark a YOLO model across different formats for speed and accuracy.
+
+    Args:
+        model (str | Path): Path to the model file or directory.
+        data (str | None): Dataset to evaluate on, inherited from TASK2DATA if not passed.
+        imgsz (int): Image size for the benchmark.
+        half (bool): Use half-precision for the model if True.
+        int8 (bool): Use int8-precision for the model if True.
+        device (str): Device to run the benchmark on, either 'cpu' or 'cuda'.
+        verbose (bool | float): If True or a float, assert benchmarks pass with given metric.
+        eps (float): Epsilon value for divide by zero prevention.
+
+    Returns:
+        (pandas.DataFrame): A pandas DataFrame with benchmark results for each format, including file size, metric,
+            and inference time.
+
+    Examples:
+        Benchmark a YOLO model with default settings:
+        >>> from ultralytics.utils.benchmarks import benchmark
+        >>> benchmark(model="yolo11n.pt", imgsz=640)
+    """
+    import pandas as pd  # scope for faster 'import ultralytics'
+
+    pd.options.display.max_columns = 10
+    pd.options.display.width = 120
+    device = select_device(device, verbose=False)
+    if isinstance(model, (str, Path)):
+        model = YOLO(model)
+    is_end2end = getattr(model.model.model[-1], "end2end", False)
+
+    y = []
+    t0 = time.time()
+    for i, (name, format, suffix, cpu, gpu, _) in enumerate(zip(*export_formats().values())):
+        emoji, filename = "❌", None  # export defaults
+        try:
+            # Checks
+            if i == 7:  # TF GraphDef
+                assert model.task != "obb", "TensorFlow GraphDef not supported for OBB task"
+            elif i == 9:  # Edge TPU
+                assert LINUX and not ARM64, "Edge TPU export only supported on non-aarch64 Linux"
+            elif i in {5, 10}:  # CoreML and TF.js
+                assert MACOS or LINUX, "CoreML and TF.js export only supported on macOS and Linux"
+                assert not IS_RASPBERRYPI, "CoreML and TF.js export not supported on Raspberry Pi"
+                assert not IS_JETSON, "CoreML and TF.js export not supported on NVIDIA Jetson"
+            if i in {5}:  # CoreML
+                assert not IS_PYTHON_3_12, "CoreML not supported on Python 3.12"
+            if i in {6, 7, 8}:  # TF SavedModel, TF GraphDef, and TFLite
+                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
+            if i in {9, 10}:  # TF EdgeTPU and TF.js
+                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
+            if i == 11:  # Paddle
+                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet"
+                assert not is_end2end, "End-to-end models not supported by PaddlePaddle yet"
+                assert LINUX or MACOS, "Windows Paddle exports not supported yet"
+            if i == 12:  # MNN
+                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 MNN exports not supported yet"
+            if i == 13:  # NCNN
+                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet"
+            if i == 14:  # IMX
+                assert not is_end2end
+                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 IMX exports not supported"
+                assert model.task == "detect", "IMX only supported for detection task"
+                assert "C2f" in model.__str__(), "IMX only supported for YOLOv8"
+            if "cpu" in device.type:
+                assert cpu, "inference not supported on CPU"
+            if "cuda" in device.type:
+                assert gpu, "inference not supported on GPU"
+
+            # Export
+            if format == "-":
+                filename = model.ckpt_path or model.cfg
+                exported_model = model  # PyTorch format
+            else:
+                filename = model.export(imgsz=imgsz, format=format, half=half, int8=int8, device=device, verbose=False)
+                exported_model = YOLO(filename, task=model.task)
+                assert suffix in str(filename), "export failed"
+            emoji = "❎"  # indicates export succeeded
+
+            # Predict
+            assert model.task != "pose" or i != 7, "GraphDef Pose inference is not supported"
+            assert i not in {9, 10}, "inference not supported"  # Edge TPU and TF.js are unsupported
+            assert i != 5 or platform.system() == "Darwin", "inference only supported on macOS>=10.13"  # CoreML
+            if i in {13}:
+                assert not is_end2end, "End-to-end torch.topk operation is not supported for NCNN prediction yet"
+            exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half)
+
+            # Validate
+            data = data or TASK2DATA[model.task]  # task to dataset, i.e. coco8.yaml for task=detect
+            key = TASK2METRIC[model.task]  # task to metric, i.e. metrics/mAP50-95(B) for task=detect
+            results = exported_model.val(
+                data=data, batch=1, imgsz=imgsz, plots=False, device=device, half=half, int8=int8, verbose=False
+            )
+            metric, speed = results.results_dict[key], results.speed["inference"]
+            fps = round(1000 / (speed + eps), 2)  # frames per second
+            y.append([name, "✅", round(file_size(filename), 1), round(metric, 4), round(speed, 2), fps])
+        except Exception as e:
+            if verbose:
+                assert type(e) is AssertionError, f"Benchmark failure for {name}: {e}"
+            LOGGER.warning(f"ERROR ❌️ Benchmark failure for {name}: {e}")
+            y.append([name, emoji, round(file_size(filename), 1), None, None, None])  # mAP, t_inference
+
+    # Print results
+    check_yolo(device=device)  # print system info
+    df = pd.DataFrame(y, columns=["Format", "Status❔", "Size (MB)", key, "Inference time (ms/im)", "FPS"])
+
+    name = Path(model.ckpt_path).name
+    s = f"\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n"
+    LOGGER.info(s)
+    with open("benchmarks.log", "a", errors="ignore", encoding="utf-8") as f:
+        f.write(s)
+
+    if verbose and isinstance(verbose, float):
+        metrics = df[key].array  # values to compare to floor
+        floor = verbose  # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
+        assert all(x > floor for x in metrics if pd.notna(x)), f"Benchmark failure: metric(s) < floor {floor}"
+
+    return df
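+
+
+# Illustrative sketch, not part of the original module: the returned DataFrame can be filtered for
+# formats that exported and ran successfully, using the column names assembled above.
+#     >>> df = benchmark(model="yolo11n.pt", imgsz=160, device="cpu")
+#     >>> df[df["Status❔"] == "✅"][["Format", "FPS"]]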
+
+
+class RF100Benchmark:
+    """Benchmark YOLO model performance across various formats for speed and accuracy."""
+
+    def __init__(self):
+        """Initialize the RF100Benchmark class for benchmarking YOLO model performance across various formats."""
+        self.ds_names = []
+        self.ds_cfg_list = []
+        self.rf = None
+        self.val_metrics = ["class", "images", "targets", "precision", "recall", "map50", "map95"]
+
+    def set_key(self, api_key):
+        """
+        Set Roboflow API key for processing.
+
+        Args:
+            api_key (str): The API key.
+
+        Examples:
+            Set the Roboflow API key for accessing datasets:
+            >>> benchmark = RF100Benchmark()
+            >>> benchmark.set_key("your_roboflow_api_key")
+        """
+        check_requirements("roboflow")
+        from roboflow import Roboflow
+
+        self.rf = Roboflow(api_key=api_key)
+
+    def parse_dataset(self, ds_link_txt="datasets_links.txt"):
+        """
+        Parse dataset links and download datasets.
+
+        Args:
+            ds_link_txt (str): Path to the file containing dataset links.
+
+        Examples:
+            >>> benchmark = RF100Benchmark()
+            >>> benchmark.set_key("api_key")
+            >>> benchmark.parse_dataset("datasets_links.txt")
+        """
+        (shutil.rmtree("rf-100"), os.mkdir("rf-100")) if os.path.exists("rf-100") else os.mkdir("rf-100")
+        os.chdir("rf-100")
+        os.mkdir("ultralytics-benchmarks")
+        safe_download("https://github.com/ultralytics/assets/releases/download/v0.0.0/datasets_links.txt")
+
+        with open(ds_link_txt) as file:
+            for line in file:
+                try:
+                    _, url, workspace, project, version = re.split("/+", line.strip())
+                    self.ds_names.append(project)
+                    proj_version = f"{project}-{version}"
+                    if not Path(proj_version).exists():
+                        self.rf.workspace(workspace).project(project).version(version).download("yolov8")
+                    else:
+                        print("Dataset already downloaded.")
+                    self.ds_cfg_list.append(Path.cwd() / proj_version / "data.yaml")
+                except Exception:
+                    continue
+
+        return self.ds_names, self.ds_cfg_list
+
+    @staticmethod
+    def fix_yaml(path):
+        """
+        Fixes the train and validation paths in a given YAML file.
+
+        Args:
+            path (str): Path to the YAML file to be fixed.
+
+        Examples:
+            >>> RF100Benchmark.fix_yaml("path/to/data.yaml")
+        """
+        with open(path) as file:
+            yaml_data = yaml.safe_load(file)
+        yaml_data["train"] = "train/images"
+        yaml_data["val"] = "valid/images"
+        with open(path, "w") as file:
+            yaml.safe_dump(yaml_data, file)
+
+    def evaluate(self, yaml_path, val_log_file, eval_log_file, list_ind):
+        """
+        Evaluate model performance on validation results.
+
+        Args:
+            yaml_path (str): Path to the YAML configuration file.
+            val_log_file (str): Path to the validation log file.
+            eval_log_file (str): Path to the evaluation log file.
+            list_ind (int): Index of the current dataset in the list.
+
+        Returns:
+            (float): The mean average precision (mAP) value for the evaluated model.
+
+        Examples:
+            Evaluate a model on a specific dataset
+            >>> benchmark = RF100Benchmark()
+            >>> benchmark.evaluate("path/to/data.yaml", "path/to/val_log.txt", "path/to/eval_log.txt", 0)
+        """
+        skip_symbols = ["🚀", "⚠️", "💡", "❌"]
+        with open(yaml_path) as stream:
+            class_names = yaml.safe_load(stream)["names"]
+        with open(val_log_file, encoding="utf-8") as f:
+            lines = f.readlines()
+            eval_lines = []
+            for line in lines:
+                if any(symbol in line for symbol in skip_symbols):
+                    continue
+                entries = line.split(" ")
+                entries = list(filter(lambda val: val != "", entries))
+                entries = [e.strip("\n") for e in entries]
+                eval_lines.extend(
+                    {
+                        "class": entries[0],
+                        "images": entries[1],
+                        "targets": entries[2],
+                        "precision": entries[3],
+                        "recall": entries[4],
+                        "map50": entries[5],
+                        "map95": entries[6],
+                    }
+                    for e in entries
+                    if e in class_names or (e == "all" and "(AP)" not in entries and "(AR)" not in entries)
+                )
+        map_val = 0.0
+        if len(eval_lines) > 1:
+            print("There's more dicts")
+            for lst in eval_lines:
+                if lst["class"] == "all":
+                    map_val = lst["map50"]
+        else:
+            print("There's only one dict res")
+            map_val = [res["map50"] for res in eval_lines][0]
+
+        with open(eval_log_file, "a") as f:
+            f.write(f"{self.ds_names[list_ind]}: {map_val}\n")
+
+
+class ProfileModels:
+    """
+    ProfileModels class for profiling different models on ONNX and TensorRT.
+
+    This class profiles the performance of different models, returning results such as model speed and FLOPs.
+
+    Attributes:
+        paths (List[str]): Paths of the models to profile.
+        num_timed_runs (int): Number of timed runs for the profiling.
+        num_warmup_runs (int): Number of warmup runs before profiling.
+        min_time (float): Minimum number of seconds to profile for.
+        imgsz (int): Image size used in the models.
+        half (bool): Flag to indicate whether to use FP16 half-precision for TensorRT profiling.
+        trt (bool): Flag to indicate whether to profile using TensorRT.
+        device (torch.device): Device used for profiling.
+
+    Methods:
+        profile: Profiles the models and prints the result.
+
+    Examples:
+        Profile models and print results
+        >>> from ultralytics.utils.benchmarks import ProfileModels
+        >>> profiler = ProfileModels(["yolov8n.yaml", "yolov8s.yaml"], imgsz=640)
+        >>> profiler.profile()
+    """
+
+    def __init__(
+        self,
+        paths: list,
+        num_timed_runs=100,
+        num_warmup_runs=10,
+        min_time=60,
+        imgsz=640,
+        half=True,
+        trt=True,
+        device=None,
+    ):
+        """
+        Initialize the ProfileModels class for profiling models.
+
+        Args:
+            paths (List[str]): List of paths of the models to be profiled.
+            num_timed_runs (int): Number of timed runs for the profiling.
+            num_warmup_runs (int): Number of warmup runs before the actual profiling starts.
+            min_time (float): Minimum time in seconds for profiling a model.
+            imgsz (int): Size of the image used during profiling.
+            half (bool): Flag to indicate whether to use FP16 half-precision for TensorRT profiling.
+            trt (bool): Flag to indicate whether to profile using TensorRT.
+            device (torch.device | None): Device used for profiling. If None, it is determined automatically.
+
+        Notes:
+            FP16 'half' argument option removed for ONNX as slower on CPU than FP32.
+
+        Examples:
+            Initialize and profile models
+            >>> from ultralytics.utils.benchmarks import ProfileModels
+            >>> profiler = ProfileModels(["yolov8n.yaml", "yolov8s.yaml"], imgsz=640)
+            >>> profiler.profile()
+        """
+        self.paths = paths
+        self.num_timed_runs = num_timed_runs
+        self.num_warmup_runs = num_warmup_runs
+        self.min_time = min_time
+        self.imgsz = imgsz
+        self.half = half
+        self.trt = trt  # run TensorRT profiling
+        self.device = device or torch.device(0 if torch.cuda.is_available() else "cpu")
+
+    def profile(self):
+        """Profiles YOLO models for speed and accuracy across various formats including ONNX and TensorRT."""
+        files = self.get_files()
+
+        if not files:
+            print("No matching *.pt or *.onnx files found.")
+            return
+
+        table_rows = []
+        output = []
+        for file in files:
+            engine_file = file.with_suffix(".engine")
+            if file.suffix in {".pt", ".yaml", ".yml"}:
+                model = YOLO(str(file))
+                model.fuse()  # to report correct params and GFLOPs in model.info()
+                model_info = model.info()
+                if self.trt and self.device.type != "cpu" and not engine_file.is_file():
+                    engine_file = model.export(
+                        format="engine",
+                        half=self.half,
+                        imgsz=self.imgsz,
+                        device=self.device,
+                        verbose=False,
+                    )
+                onnx_file = model.export(
+                    format="onnx",
+                    imgsz=self.imgsz,
+                    device=self.device,
+                    verbose=False,
+                )
+            elif file.suffix == ".onnx":
+                model_info = self.get_onnx_model_info(file)
+                onnx_file = file
+            else:
+                continue
+
+            t_engine = self.profile_tensorrt_model(str(engine_file))
+            t_onnx = self.profile_onnx_model(str(onnx_file))
+            table_rows.append(self.generate_table_row(file.stem, t_onnx, t_engine, model_info))
+            output.append(self.generate_results_dict(file.stem, t_onnx, t_engine, model_info))
+
+        self.print_table(table_rows)
+        return output
+
+    def get_files(self):
+        """Returns a list of paths for all relevant model files given by the user."""
+        files = []
+        for path in self.paths:
+            path = Path(path)
+            if path.is_dir():
+                extensions = ["*.pt", "*.onnx", "*.yaml"]
+                files.extend([file for ext in extensions for file in glob.glob(str(path / ext))])
+            elif path.suffix in {".pt", ".yaml", ".yml"}:  # add non-existing
+                files.append(str(path))
+            else:
+                files.extend(glob.glob(str(path)))
+
+        print(f"Profiling: {sorted(files)}")
+        return [Path(file) for file in sorted(files)]
+
+    @staticmethod
+    def get_onnx_model_info(onnx_file: str):
+        """Extracts metadata from an ONNX model file including parameters, GFLOPs, and input shape."""
+        return 0.0, 0.0, 0.0, 0.0  # return (num_layers, num_params, num_gradients, num_flops)
+
+    @staticmethod
+    def iterative_sigma_clipping(data, sigma=2, max_iters=3):
+        """Applies iterative sigma clipping to data to remove outliers based on specified sigma and iteration count."""
+        data = np.array(data)
+        for _ in range(max_iters):
+            mean, std = np.mean(data), np.std(data)
+            clipped_data = data[(data > mean - sigma * std) & (data < mean + sigma * std)]
+            if len(clipped_data) == len(data):
+                break
+            data = clipped_data
+        return data
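+
+    # Illustrative sketch, not part of the original class: sigma clipping drops timing outliers
+    # before the mean/std are reported. With hypothetical timings such as
+    #     >>> ProfileModels.iterative_sigma_clipping(np.array([4.9, 5.0, 5.1, 5.0, 4.8, 5.2, 5.0, 20.0]), sigma=2)
+    # the single 20.0 ms outlier falls outside mean ± 2*std on the first pass and is removed,
+    # leaving only the ~5 ms measurements.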
+
+    def profile_tensorrt_model(self, engine_file: str, eps: float = 1e-3):
+        """Profiles YOLO model performance with TensorRT, measuring average run time and standard deviation."""
+        if not self.trt or not Path(engine_file).is_file():
+            return 0.0, 0.0
+
+        # Model and input
+        model = YOLO(engine_file)
+        input_data = np.zeros((self.imgsz, self.imgsz, 3), dtype=np.uint8)  # use uint8 for Classify
+
+        # Warmup runs
+        elapsed = 0.0
+        for _ in range(3):
+            start_time = time.time()
+            for _ in range(self.num_warmup_runs):
+                model(input_data, imgsz=self.imgsz, verbose=False)
+            elapsed = time.time() - start_time
+
+        # Compute number of runs as higher of min_time or num_timed_runs
+        num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs * 50)
+
+        # Timed runs
+        run_times = []
+        for _ in TQDM(range(num_runs), desc=engine_file):
+            results = model(input_data, imgsz=self.imgsz, verbose=False)
+            run_times.append(results[0].speed["inference"])  # inference speed is already in milliseconds
+
+        run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3)  # sigma clipping
+        return np.mean(run_times), np.std(run_times)
+
+    def profile_onnx_model(self, onnx_file: str, eps: float = 1e-3):
+        """Profiles an ONNX model, measuring average inference time and standard deviation across multiple runs."""
+        check_requirements("onnxruntime")
+        import onnxruntime as ort
+
+        # Session with either 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
+        sess_options = ort.SessionOptions()
+        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+        sess_options.intra_op_num_threads = 8  # Limit the number of threads
+        sess = ort.InferenceSession(onnx_file, sess_options, providers=["CPUExecutionProvider"])
+
+        input_tensor = sess.get_inputs()[0]
+        input_type = input_tensor.type
+        dynamic = not all(isinstance(dim, int) and dim >= 0 for dim in input_tensor.shape)  # dynamic input shape
+        input_shape = (1, 3, self.imgsz, self.imgsz) if dynamic else input_tensor.shape
+
+        # Mapping ONNX datatype to numpy datatype
+        if "float16" in input_type:
+            input_dtype = np.float16
+        elif "float" in input_type:
+            input_dtype = np.float32
+        elif "double" in input_type:
+            input_dtype = np.float64
+        elif "int64" in input_type:
+            input_dtype = np.int64
+        elif "int32" in input_type:
+            input_dtype = np.int32
+        else:
+            raise ValueError(f"Unsupported ONNX datatype {input_type}")
+
+        input_data = np.random.rand(*input_shape).astype(input_dtype)
+        input_name = input_tensor.name
+        output_name = sess.get_outputs()[0].name
+
+        # Warmup runs
+        elapsed = 0.0
+        for _ in range(3):
+            start_time = time.time()
+            for _ in range(self.num_warmup_runs):
+                sess.run([output_name], {input_name: input_data})
+            elapsed = time.time() - start_time
+
+        # Compute number of runs as higher of min_time or num_timed_runs
+        num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs)
+
+        # Timed runs
+        run_times = []
+        for _ in TQDM(range(num_runs), desc=onnx_file):
+            start_time = time.time()
+            sess.run([output_name], {input_name: input_data})
+            run_times.append((time.time() - start_time) * 1000)  # Convert to milliseconds
+
+        run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=5)  # sigma clipping
+        return np.mean(run_times), np.std(run_times)
+
+    def generate_table_row(self, model_name, t_onnx, t_engine, model_info):
+        """Generates a table row string with model performance metrics including inference times and model details."""
+        layers, params, gradients, flops = model_info
+        return (
+            f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.1f}±{t_onnx[1]:.1f} ms | {t_engine[0]:.1f}±"
+            f"{t_engine[1]:.1f} ms | {params / 1e6:.1f} | {flops:.1f} |"
+        )
+
+    @staticmethod
+    def generate_results_dict(model_name, t_onnx, t_engine, model_info):
+        """Generates a dictionary of profiling results including model name, parameters, GFLOPs, and speed metrics."""
+        layers, params, gradients, flops = model_info
+        return {
+            "model/name": model_name,
+            "model/parameters": params,
+            "model/GFLOPs": round(flops, 3),
+            "model/speed_ONNX(ms)": round(t_onnx[0], 3),
+            "model/speed_TensorRT(ms)": round(t_engine[0], 3),
+        }
+
+    @staticmethod
+    def print_table(table_rows):
+        """Prints a formatted table of model profiling results, including speed and accuracy metrics."""
+        gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "GPU"
+        headers = [
+            "Model",
+            "size<br><sup>(pixels)",
+            "mAP<sup>val<br>50-95",
+            f"Speed<br><sup>CPU ({get_cpu_info()}) ONNX<br>(ms)",
+            f"Speed<br><sup>{gpu} TensorRT<br>(ms)",
+            "params<br><sup>(M)",
+            "FLOPs<br><sup>(B)",
+        ]
+        header = "|" + "|".join(f" {h} " for h in headers) + "|"
+        separator = "|" + "|".join("-" * (len(h) + 2) for h in headers) + "|"
+
+        print(f"\n\n{header}")
+        print(separator)
+        for row in table_rows:
+            print(row)

+ 5 - 0
ultralytics/utils/callbacks/__init__.py

@@ -0,0 +1,5 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from .base import add_integration_callbacks, default_callbacks, get_default_callbacks
+
+__all__ = "add_integration_callbacks", "default_callbacks", "get_default_callbacks"

+ 217 - 0
ultralytics/utils/callbacks/base.py

@@ -0,0 +1,217 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+"""Base callbacks."""
+
+from collections import defaultdict
+from copy import deepcopy
+
+# Trainer callbacks ----------------------------------------------------------------------------------------------------
+
+
+def on_pretrain_routine_start(trainer):
+    """Called before the pretraining routine starts."""
+    pass
+
+
+def on_pretrain_routine_end(trainer):
+    """Called after the pretraining routine ends."""
+    pass
+
+
+def on_train_start(trainer):
+    """Called when the training starts."""
+    pass
+
+
+def on_train_epoch_start(trainer):
+    """Called at the start of each training epoch."""
+    pass
+
+
+def on_train_batch_start(trainer):
+    """Called at the start of each training batch."""
+    pass
+
+
+def optimizer_step(trainer):
+    """Called when the optimizer takes a step."""
+    pass
+
+
+def on_before_zero_grad(trainer):
+    """Called before the gradients are set to zero."""
+    pass
+
+
+def on_train_batch_end(trainer):
+    """Called at the end of each training batch."""
+    pass
+
+
+def on_train_epoch_end(trainer):
+    """Called at the end of each training epoch."""
+    pass
+
+
+def on_fit_epoch_end(trainer):
+    """Called at the end of each fit epoch (train + val)."""
+    pass
+
+
+def on_model_save(trainer):
+    """Called when the model is saved."""
+    pass
+
+
+def on_train_end(trainer):
+    """Called when the training ends."""
+    pass
+
+
+def on_params_update(trainer):
+    """Called when the model parameters are updated."""
+    pass
+
+
+def teardown(trainer):
+    """Called during the teardown of the training process."""
+    pass
+
+
+# Validator callbacks --------------------------------------------------------------------------------------------------
+
+
+def on_val_start(validator):
+    """Called when the validation starts."""
+    pass
+
+
+def on_val_batch_start(validator):
+    """Called at the start of each validation batch."""
+    pass
+
+
+def on_val_batch_end(validator):
+    """Called at the end of each validation batch."""
+    pass
+
+
+def on_val_end(validator):
+    """Called when the validation ends."""
+    pass
+
+
+# Predictor callbacks --------------------------------------------------------------------------------------------------
+
+
+def on_predict_start(predictor):
+    """Called when the prediction starts."""
+    pass
+
+
+def on_predict_batch_start(predictor):
+    """Called at the start of each prediction batch."""
+    pass
+
+
+def on_predict_batch_end(predictor):
+    """Called at the end of each prediction batch."""
+    pass
+
+
+def on_predict_postprocess_end(predictor):
+    """Called after the post-processing of the prediction ends."""
+    pass
+
+
+def on_predict_end(predictor):
+    """Called when the prediction ends."""
+    pass
+
+
+# Exporter callbacks ---------------------------------------------------------------------------------------------------
+
+
+def on_export_start(exporter):
+    """Called when the model export starts."""
+    pass
+
+
+def on_export_end(exporter):
+    """Called when the model export ends."""
+    pass
+
+
+default_callbacks = {
+    # Run in trainer
+    "on_pretrain_routine_start": [on_pretrain_routine_start],
+    "on_pretrain_routine_end": [on_pretrain_routine_end],
+    "on_train_start": [on_train_start],
+    "on_train_epoch_start": [on_train_epoch_start],
+    "on_train_batch_start": [on_train_batch_start],
+    "optimizer_step": [optimizer_step],
+    "on_before_zero_grad": [on_before_zero_grad],
+    "on_train_batch_end": [on_train_batch_end],
+    "on_train_epoch_end": [on_train_epoch_end],
+    "on_fit_epoch_end": [on_fit_epoch_end],  # fit = train + val
+    "on_model_save": [on_model_save],
+    "on_train_end": [on_train_end],
+    "on_params_update": [on_params_update],
+    "teardown": [teardown],
+    # Run in validator
+    "on_val_start": [on_val_start],
+    "on_val_batch_start": [on_val_batch_start],
+    "on_val_batch_end": [on_val_batch_end],
+    "on_val_end": [on_val_end],
+    # Run in predictor
+    "on_predict_start": [on_predict_start],
+    "on_predict_batch_start": [on_predict_batch_start],
+    "on_predict_postprocess_end": [on_predict_postprocess_end],
+    "on_predict_batch_end": [on_predict_batch_end],
+    "on_predict_end": [on_predict_end],
+    # Run in exporter
+    "on_export_start": [on_export_start],
+    "on_export_end": [on_export_end],
+}
+
+
+def get_default_callbacks():
+    """
+    Return a copy of the default_callbacks dictionary with lists as default values.
+
+    Returns:
+        (defaultdict): A defaultdict with keys from default_callbacks and empty lists as default values.
+    """
+    return defaultdict(list, deepcopy(default_callbacks))
+
+
+def add_integration_callbacks(instance):
+    """
+    Add integration callbacks from various sources to the instance's callbacks.
+
+    Args:
+        instance (Trainer, Predictor, Validator, Exporter): An object with a 'callbacks' attribute that is a dictionary
+            of callback lists.
+    """
+    # Load HUB callbacks
+    from .hub import callbacks as hub_cb
+
+    callbacks_list = [hub_cb]
+
+    # Load training callbacks
+    if "Trainer" in instance.__class__.__name__:
+        from .clearml import callbacks as clear_cb
+        from .comet import callbacks as comet_cb
+        from .dvc import callbacks as dvc_cb
+        from .mlflow import callbacks as mlflow_cb
+        from .neptune import callbacks as neptune_cb
+        from .raytune import callbacks as tune_cb
+        from .tensorboard import callbacks as tb_cb
+        from .wb import callbacks as wb_cb
+
+        callbacks_list.extend([clear_cb, comet_cb, dvc_cb, mlflow_cb, neptune_cb, tune_cb, tb_cb, wb_cb])
+
+    # Add the callbacks to the callbacks dictionary
+    for callbacks in callbacks_list:
+        for k, v in callbacks.items():
+            if v not in instance.callbacks[k]:
+                instance.callbacks[k].append(v)
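+
+
+# Minimal sketch, not part of the original module: user code can hook into any of the events above
+# by appending its own function, e.g. via the model-level helper (model name is only an example).
+#     >>> from ultralytics import YOLO
+#     >>> def log_epoch(trainer):
+#     ...     print(f"epoch {trainer.epoch} finished with loss {trainer.tloss}")
+#     >>> model = YOLO("yolo11n.pt")
+#     >>> model.add_callback("on_train_epoch_end", log_epoch)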

+ 153 - 0
ultralytics/utils/callbacks/clearml.py

@@ -0,0 +1,153 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
+
+try:
+    assert not TESTS_RUNNING  # do not log pytest
+    assert SETTINGS["clearml"] is True  # verify integration is enabled
+    import clearml
+    from clearml import Task
+
+    assert hasattr(clearml, "__version__")  # verify package is not directory
+
+except (ImportError, AssertionError):
+    clearml = None
+
+
+def _log_debug_samples(files, title="Debug Samples") -> None:
+    """
+    Log files (images) as debug samples in the ClearML task.
+
+    Args:
+        files (list): A list of file paths in PosixPath format.
+        title (str): A title that groups together images with the same values.
+    """
+    import re
+
+    if task := Task.current_task():
+        for f in files:
+            if f.exists():
+                it = re.search(r"_batch(\d+)", f.name)
+                iteration = int(it.groups()[0]) if it else 0
+                task.get_logger().report_image(
+                    title=title, series=f.name.replace(it.group(), ""), local_path=str(f), iteration=iteration
+                )
+
+
+def _log_plot(title, plot_path) -> None:
+    """
+    Log an image as a plot in the plot section of ClearML.
+
+    Args:
+        title (str): The title of the plot.
+        plot_path (str): The path to the saved image file.
+    """
+    import matplotlib.image as mpimg
+    import matplotlib.pyplot as plt
+
+    img = mpimg.imread(plot_path)
+    fig = plt.figure()
+    ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[])  # no ticks
+    ax.imshow(img)
+
+    Task.current_task().get_logger().report_matplotlib_figure(
+        title=title, series="", figure=fig, report_interactive=False
+    )
+
+
+def on_pretrain_routine_start(trainer):
+    """Runs at start of pretraining routine; initializes and connects/ logs task to ClearML."""
+    try:
+        if task := Task.current_task():
+            # WARNING: make sure the automatic pytorch and matplotlib bindings are disabled!
+            # We are logging these plots and model files manually in the integration
+            from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
+            from clearml.binding.matplotlib_bind import PatchedMatplotlib
+
+            PatchPyTorchModelIO.update_current_task(None)
+            PatchedMatplotlib.update_current_task(None)
+        else:
+            task = Task.init(
+                project_name=trainer.args.project or "Ultralytics",
+                task_name=trainer.args.name,
+                tags=["Ultralytics"],
+                output_uri=True,
+                reuse_last_task_id=False,
+                auto_connect_frameworks={"pytorch": False, "matplotlib": False},
+            )
+            LOGGER.warning(
+                "ClearML Initialized a new task. If you want to run remotely, "
+                "please add clearml-init and connect your arguments before initializing YOLO."
+            )
+        task.connect(vars(trainer.args), name="General")
+    except Exception as e:
+        LOGGER.warning(f"WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}")
+
+
+def on_train_epoch_end(trainer):
+    """Logs debug samples for the first epoch of YOLO training and report current training progress."""
+    if task := Task.current_task():
+        # Log debug samples
+        if trainer.epoch == 1:
+            _log_debug_samples(sorted(trainer.save_dir.glob("train_batch*.jpg")), "Mosaic")
+        # Report the current training progress
+        for k, v in trainer.label_loss_items(trainer.tloss, prefix="train").items():
+            task.get_logger().report_scalar("train", k, v, iteration=trainer.epoch)
+        for k, v in trainer.lr.items():
+            task.get_logger().report_scalar("lr", k, v, iteration=trainer.epoch)
+
+
+def on_fit_epoch_end(trainer):
+    """Reports model information to logger at the end of an epoch."""
+    if task := Task.current_task():
+        # You should have access to the validation bboxes under jdict
+        task.get_logger().report_scalar(
+            title="Epoch Time", series="Epoch Time", value=trainer.epoch_time, iteration=trainer.epoch
+        )
+        for k, v in trainer.metrics.items():
+            task.get_logger().report_scalar("val", k, v, iteration=trainer.epoch)
+        if trainer.epoch == 0:
+            from ultralytics.utils.torch_utils import model_info_for_loggers
+
+            for k, v in model_info_for_loggers(trainer).items():
+                task.get_logger().report_single_value(k, v)
+
+
+def on_val_end(validator):
+    """Logs validation results including labels and predictions."""
+    if Task.current_task():
+        # Log val_labels and val_pred
+        _log_debug_samples(sorted(validator.save_dir.glob("val*.jpg")), "Validation")
+
+
+def on_train_end(trainer):
+    """Logs final model and its name on training completion."""
+    if task := Task.current_task():
+        # Log final results, CM matrix + PR plots
+        files = [
+            "results.png",
+            "confusion_matrix.png",
+            "confusion_matrix_normalized.png",
+            *(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
+        ]
+        files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()]  # filter
+        for f in files:
+            _log_plot(title=f.stem, plot_path=f)
+        # Report final metrics
+        for k, v in trainer.validator.metrics.results_dict.items():
+            task.get_logger().report_single_value(k, v)
+        # Log the final model
+        task.update_output_model(model_path=str(trainer.best), model_name=trainer.args.name, auto_delete_file=False)
+
+
+callbacks = (
+    {
+        "on_pretrain_routine_start": on_pretrain_routine_start,
+        "on_train_epoch_end": on_train_epoch_end,
+        "on_fit_epoch_end": on_fit_epoch_end,
+        "on_val_end": on_val_end,
+        "on_train_end": on_train_end,
+    }
+    if clearml
+    else {}
+)

+ 397 - 0
ultralytics/utils/callbacks/comet.py

@@ -0,0 +1,397 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops
+from ultralytics.utils.metrics import ClassifyMetrics, DetMetrics, OBBMetrics, PoseMetrics, SegmentMetrics
+
+try:
+    assert not TESTS_RUNNING  # do not log pytest
+    assert SETTINGS["comet"] is True  # verify integration is enabled
+    import comet_ml
+
+    assert hasattr(comet_ml, "__version__")  # verify package is not directory
+
+    import os
+    from pathlib import Path
+
+    # Ensures certain logging functions only run for supported tasks
+    COMET_SUPPORTED_TASKS = ["detect"]
+
+    # Names of plots created by Ultralytics that are logged to Comet
+    CONFUSION_MATRIX_PLOT_NAMES = "confusion_matrix", "confusion_matrix_normalized"
+    EVALUATION_PLOT_NAMES = "F1_curve", "P_curve", "R_curve", "PR_curve"
+    LABEL_PLOT_NAMES = "labels", "labels_correlogram"
+    SEGMENT_METRICS_PLOT_PREFIX = "Box", "Mask"
+    POSE_METRICS_PLOT_PREFIX = "Box", "Pose"
+
+    _comet_image_prediction_count = 0
+
+except (ImportError, AssertionError):
+    comet_ml = None
+
+
+def _get_comet_mode():
+    """Returns the mode of comet set in the environment variables, defaults to 'online' if not set."""
+    return os.getenv("COMET_MODE", "online")
+
+
+def _get_comet_model_name():
+    """Returns the model name for Comet from the environment variable COMET_MODEL_NAME or defaults to 'Ultralytics'."""
+    return os.getenv("COMET_MODEL_NAME", "Ultralytics")
+
+
+def _get_eval_batch_logging_interval():
+    """Get the evaluation batch logging interval from environment variable or use default value 1."""
+    return int(os.getenv("COMET_EVAL_BATCH_LOGGING_INTERVAL", 1))
+
+
+def _get_max_image_predictions_to_log():
+    """Get the maximum number of image predictions to log from the environment variables."""
+    return int(os.getenv("COMET_MAX_IMAGE_PREDICTIONS", 100))
+
+
+def _scale_confidence_score(score):
+    """Scales the given confidence score by a factor specified in an environment variable."""
+    scale = float(os.getenv("COMET_MAX_CONFIDENCE_SCORE", 100.0))
+    return score * scale
+
+
+def _should_log_confusion_matrix():
+    """Determines if the confusion matrix should be logged based on the environment variable settings."""
+    return os.getenv("COMET_EVAL_LOG_CONFUSION_MATRIX", "false").lower() == "true"
+
+
+def _should_log_image_predictions():
+    """Determines whether to log image predictions based on a specified environment variable."""
+    return os.getenv("COMET_EVAL_LOG_IMAGE_PREDICTIONS", "true").lower() == "true"
+
+
+def _get_experiment_type(mode, project_name):
+    """Return an experiment based on mode and project name."""
+    if mode == "offline":
+        return comet_ml.OfflineExperiment(project_name=project_name)
+
+    return comet_ml.Experiment(project_name=project_name)
+
+
+def _create_experiment(args):
+    """Ensures that the experiment object is only created in a single process during distributed training."""
+    if RANK not in {-1, 0}:
+        return
+    try:
+        comet_mode = _get_comet_mode()
+        _project_name = os.getenv("COMET_PROJECT_NAME", args.project)
+        experiment = _get_experiment_type(comet_mode, _project_name)
+        experiment.log_parameters(vars(args))
+        experiment.log_others(
+            {
+                "eval_batch_logging_interval": _get_eval_batch_logging_interval(),
+                "log_confusion_matrix_on_eval": _should_log_confusion_matrix(),
+                "log_image_predictions": _should_log_image_predictions(),
+                "max_image_predictions": _get_max_image_predictions_to_log(),
+            }
+        )
+        experiment.log_other("Created from", "ultralytics")
+
+    except Exception as e:
+        LOGGER.warning(f"WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}")
+
+
+def _fetch_trainer_metadata(trainer):
+    """Returns metadata for YOLO training including epoch and asset saving status."""
+    curr_epoch = trainer.epoch + 1
+
+    train_num_steps_per_epoch = len(trainer.train_loader.dataset) // trainer.batch_size
+    curr_step = curr_epoch * train_num_steps_per_epoch
+    final_epoch = curr_epoch == trainer.epochs
+
+    save = trainer.args.save
+    save_period = trainer.args.save_period
+    save_interval = curr_epoch % save_period == 0
+    save_assets = save and save_period > 0 and save_interval and not final_epoch
+
+    return dict(curr_epoch=curr_epoch, curr_step=curr_step, save_assets=save_assets, final_epoch=final_epoch)
+
+
+def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad):
+    """
+    YOLO resizes images during training and the label values are normalized based on this resized shape.
+
+    This function rescales the bounding box labels to the original image shape.
+    """
+    resized_image_height, resized_image_width = resized_image_shape
+
+    # Convert normalized xywh format predictions to xyxy in resized scale format
+    box = ops.xywhn2xyxy(box, h=resized_image_height, w=resized_image_width)
+    # Scale box predictions from resized image scale back to original image scale
+    box = ops.scale_boxes(resized_image_shape, box, original_image_shape, ratio_pad)
+    # Convert bounding box format from xyxy to xywh for Comet logging
+    box = ops.xyxy2xywh(box)
+    # Adjust xy center to correspond top-left corner
+    box[:2] -= box[2:] / 2
+    box = box.tolist()
+
+    return box
+
+
+def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None):
+    """Format ground truth annotations for detection."""
+    indices = batch["batch_idx"] == img_idx
+    bboxes = batch["bboxes"][indices]
+    if len(bboxes) == 0:
+        LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes labels")
+        return None
+
+    cls_labels = batch["cls"][indices].squeeze(1).tolist()
+    if class_name_map:
+        cls_labels = [str(class_name_map[label]) for label in cls_labels]
+
+    original_image_shape = batch["ori_shape"][img_idx]
+    resized_image_shape = batch["resized_shape"][img_idx]
+    ratio_pad = batch["ratio_pad"][img_idx]
+
+    data = []
+    for box, label in zip(bboxes, cls_labels):
+        box = _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad)
+        data.append(
+            {
+                "boxes": [box],
+                "label": f"gt_{label}",
+                "score": _scale_confidence_score(1.0),
+            }
+        )
+
+    return {"name": "ground_truth", "data": data}
+
+
+def _format_prediction_annotations_for_detection(image_path, metadata, class_label_map=None):
+    """Format YOLO predictions for object detection visualization."""
+    stem = image_path.stem
+    image_id = int(stem) if stem.isnumeric() else stem
+
+    predictions = metadata.get(image_id)
+    if not predictions:
+        LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes predictions")
+        return None
+
+    data = []
+    for prediction in predictions:
+        boxes = prediction["bbox"]
+        score = _scale_confidence_score(prediction["score"])
+        cls_label = prediction["category_id"]
+        if class_label_map:
+            cls_label = str(class_label_map[cls_label])
+
+        data.append({"boxes": [boxes], "label": cls_label, "score": score})
+
+    return {"name": "prediction", "data": data}
+
+
+def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, class_label_map):
+    """Join the ground truth and prediction annotations if they exist."""
+    ground_truth_annotations = _format_ground_truth_annotations_for_detection(
+        img_idx, image_path, batch, class_label_map
+    )
+    prediction_annotations = _format_prediction_annotations_for_detection(
+        image_path, prediction_metadata_map, class_label_map
+    )
+
+    annotations = [
+        annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None
+    ]
+    return [annotations] if annotations else None
+
+
+def _create_prediction_metadata_map(model_predictions):
+    """Create metadata map for model predictions by groupings them based on image ID."""
+    pred_metadata_map = {}
+    for prediction in model_predictions:
+        pred_metadata_map.setdefault(prediction["image_id"], [])
+        pred_metadata_map[prediction["image_id"]].append(prediction)
+
+    return pred_metadata_map
+
+
+def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch):
+    """Log the confusion matrix to Comet experiment."""
+    conf_mat = trainer.validator.confusion_matrix.matrix
+    names = list(trainer.data["names"].values()) + ["background"]
+    experiment.log_confusion_matrix(
+        matrix=conf_mat, labels=names, max_categories=len(names), epoch=curr_epoch, step=curr_step
+    )
+
+
+def _log_images(experiment, image_paths, curr_step, annotations=None):
+    """Logs images to the experiment with optional annotations."""
+    if annotations:
+        for image_path, annotation in zip(image_paths, annotations):
+            experiment.log_image(image_path, name=image_path.stem, step=curr_step, annotations=annotation)
+
+    else:
+        for image_path in image_paths:
+            experiment.log_image(image_path, name=image_path.stem, step=curr_step)
+
+
+def _log_image_predictions(experiment, validator, curr_step):
+    """Logs predicted boxes for a single image during training."""
+    global _comet_image_prediction_count
+
+    task = validator.args.task
+    if task not in COMET_SUPPORTED_TASKS:
+        return
+
+    jdict = validator.jdict
+    if not jdict:
+        return
+
+    predictions_metadata_map = _create_prediction_metadata_map(jdict)
+    dataloader = validator.dataloader
+    class_label_map = validator.names
+
+    batch_logging_interval = _get_eval_batch_logging_interval()
+    max_image_predictions = _get_max_image_predictions_to_log()
+
+    for batch_idx, batch in enumerate(dataloader):
+        if (batch_idx + 1) % batch_logging_interval != 0:
+            continue
+
+        image_paths = batch["im_file"]
+        for img_idx, image_path in enumerate(image_paths):
+            if _comet_image_prediction_count >= max_image_predictions:
+                return
+
+            image_path = Path(image_path)
+            annotations = _fetch_annotations(
+                img_idx,
+                image_path,
+                batch,
+                predictions_metadata_map,
+                class_label_map,
+            )
+            _log_images(
+                experiment,
+                [image_path],
+                curr_step,
+                annotations=annotations,
+            )
+            _comet_image_prediction_count += 1
+
+
+def _log_plots(experiment, trainer):
+    """Logs evaluation plots and label plots for the experiment."""
+    plot_filenames = None
+    if isinstance(trainer.validator.metrics, SegmentMetrics) and trainer.validator.metrics.task == "segment":
+        plot_filenames = [
+            trainer.save_dir / f"{prefix}{plots}.png"
+            for plots in EVALUATION_PLOT_NAMES
+            for prefix in SEGMENT_METRICS_PLOT_PREFIX
+        ]
+    elif isinstance(trainer.validator.metrics, PoseMetrics):
+        plot_filenames = [
+            trainer.save_dir / f"{prefix}{plots}.png"
+            for plots in EVALUATION_PLOT_NAMES
+            for prefix in POSE_METRICS_PLOT_PREFIX
+        ]
+    elif isinstance(trainer.validator.metrics, (DetMetrics, OBBMetrics)):
+        plot_filenames = [trainer.save_dir / f"{plots}.png" for plots in EVALUATION_PLOT_NAMES]
+
+    if plot_filenames is not None:
+        _log_images(experiment, plot_filenames, None)
+
+    confusion_matrix_filenames = [trainer.save_dir / f"{plots}.png" for plots in CONFUSION_MATRIX_PLOT_NAMES]
+    _log_images(experiment, confusion_matrix_filenames, None)
+
+    if not isinstance(trainer.validator.metrics, ClassifyMetrics):
+        label_plot_filenames = [trainer.save_dir / f"{labels}.jpg" for labels in LABEL_PLOT_NAMES]
+        _log_images(experiment, label_plot_filenames, None)
+
+
+def _log_model(experiment, trainer):
+    """Log the best-trained model to Comet.ml."""
+    model_name = _get_comet_model_name()
+    experiment.log_model(model_name, file_or_folder=str(trainer.best), file_name="best.pt", overwrite=True)
+
+
+def on_pretrain_routine_start(trainer):
+    """Creates or resumes a CometML experiment at the start of a YOLO pre-training routine."""
+    experiment = comet_ml.get_global_experiment()
+    is_alive = getattr(experiment, "alive", False)
+    if not experiment or not is_alive:
+        _create_experiment(trainer.args)
+
+
+def on_train_epoch_end(trainer):
+    """Log metrics and save batch images at the end of training epochs."""
+    experiment = comet_ml.get_global_experiment()
+    if not experiment:
+        return
+
+    metadata = _fetch_trainer_metadata(trainer)
+    curr_epoch = metadata["curr_epoch"]
+    curr_step = metadata["curr_step"]
+
+    experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix="train"), step=curr_step, epoch=curr_epoch)
+
+
+def on_fit_epoch_end(trainer):
+    """Logs model assets at the end of each epoch."""
+    experiment = comet_ml.get_global_experiment()
+    if not experiment:
+        return
+
+    metadata = _fetch_trainer_metadata(trainer)
+    curr_epoch = metadata["curr_epoch"]
+    curr_step = metadata["curr_step"]
+    save_assets = metadata["save_assets"]
+
+    experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch)
+    experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch)
+    if curr_epoch == 1:
+        from ultralytics.utils.torch_utils import model_info_for_loggers
+
+        experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch)
+
+    if not save_assets:
+        return
+
+    _log_model(experiment, trainer)
+    if _should_log_confusion_matrix():
+        _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)
+    if _should_log_image_predictions():
+        _log_image_predictions(experiment, trainer.validator, curr_step)
+
+
+def on_train_end(trainer):
+    """Perform operations at the end of training."""
+    experiment = comet_ml.get_global_experiment()
+    if not experiment:
+        return
+
+    metadata = _fetch_trainer_metadata(trainer)
+    curr_epoch = metadata["curr_epoch"]
+    curr_step = metadata["curr_step"]
+    plots = trainer.args.plots
+
+    _log_model(experiment, trainer)
+    if plots:
+        _log_plots(experiment, trainer)
+
+    _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)
+    _log_image_predictions(experiment, trainer.validator, curr_step)
+    _log_images(experiment, trainer.save_dir.glob("train_batch*.jpg"), curr_step)
+    _log_images(experiment, trainer.save_dir.glob("val_batch*.jpg"), curr_step)
+    experiment.end()
+
+    global _comet_image_prediction_count
+    _comet_image_prediction_count = 0
+
+
+callbacks = (
+    {
+        "on_pretrain_routine_start": on_pretrain_routine_start,
+        "on_train_epoch_end": on_train_epoch_end,
+        "on_fit_epoch_end": on_fit_epoch_end,
+        "on_train_end": on_train_end,
+    }
+    if comet_ml
+    else {}
+)
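+
+
+# Illustrative note, not part of the original module: the helpers above read their behaviour from
+# environment variables, so logging volume can be tuned without code changes, e.g.:
+#     >>> import os
+#     >>> os.environ["COMET_MODE"] = "offline"  # log to an offline experiment
+#     >>> os.environ["COMET_EVAL_BATCH_LOGGING_INTERVAL"] = "4"  # log every 4th eval batch
+#     >>> os.environ["COMET_MAX_IMAGE_PREDICTIONS"] = "25"  # cap logged prediction images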

+ 145 - 0
ultralytics/utils/callbacks/dvc.py

@@ -0,0 +1,145 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, checks
+
+try:
+    assert not TESTS_RUNNING  # do not log pytest
+    assert SETTINGS["dvc"] is True  # verify integration is enabled
+    import dvclive
+
+    assert checks.check_version("dvclive", "2.11.0", verbose=True)
+
+    import os
+    import re
+    from pathlib import Path
+
+    # DVCLive logger instance
+    live = None
+    _processed_plots = {}
+
+    # `on_fit_epoch_end` is called on final validation (this probably needs to be fixed); for now it is how we
+    # distinguish the final evaluation of the best model from last-epoch validation
+    _training_epoch = False
+
+except (ImportError, AssertionError, TypeError):
+    dvclive = None
+
+
+def _log_images(path, prefix=""):
+    """Logs images at specified path with an optional prefix using DVCLive."""
+    if live:
+        name = path.name
+
+        # Group images by batch to enable sliders in UI
+        if m := re.search(r"_batch(\d+)", name):
+            ni = m[1]
+            new_stem = re.sub(r"_batch(\d+)", "_batch", path.stem)
+            name = (Path(new_stem) / ni).with_suffix(path.suffix)
+
+        live.log_image(os.path.join(prefix, name), path)
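+
+
+# Illustrative note, not part of the original module: the regex grouping above maps e.g.
+# "train_batch2.jpg" to "train_batch/2.jpg" (and to "train/train_batch/2.jpg" with prefix="train"),
+# so DVCLive renders every batch of one plot group as frames of a single slider:
+#     >>> import re
+#     >>> from pathlib import Path
+#     >>> (Path(re.sub(r"_batch(\d+)", "_batch", "train_batch2")) / "2").with_suffix(".jpg")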
+
+
+def _log_plots(plots, prefix=""):
+    """Logs plot images for training progress if they have not been previously processed."""
+    for name, params in plots.items():
+        timestamp = params["timestamp"]
+        if _processed_plots.get(name) != timestamp:
+            _log_images(name, prefix)
+            _processed_plots[name] = timestamp
+
+
+def _log_confusion_matrix(validator):
+    """Logs the confusion matrix for the given validator using DVCLive."""
+    targets = []
+    preds = []
+    matrix = validator.confusion_matrix.matrix
+    names = list(validator.names.values())
+    if validator.confusion_matrix.task == "detect":
+        names += ["background"]
+
+    for ti, pred in enumerate(matrix.T.astype(int)):
+        for pi, num in enumerate(pred):
+            targets.extend([names[ti]] * num)
+            preds.extend([names[pi]] * num)
+
+    live.log_sklearn_plot("confusion_matrix", targets, preds, name="cf.json", normalized=True)
+
+
+def on_pretrain_routine_start(trainer):
+    """Initializes DVCLive logger for training metadata during pre-training routine."""
+    try:
+        global live
+        live = dvclive.Live(save_dvc_exp=True, cache_images=True)
+        LOGGER.info("DVCLive is detected and auto logging is enabled (run 'yolo settings dvc=False' to disable).")
+    except Exception as e:
+        LOGGER.warning(f"WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}")
+
+
+def on_pretrain_routine_end(trainer):
+    """Logs plots related to the training process at the end of the pretraining routine."""
+    _log_plots(trainer.plots, "train")
+
+
+def on_train_start(trainer):
+    """Logs the training parameters if DVCLive logging is active."""
+    if live:
+        live.log_params(trainer.args)
+
+
+def on_train_epoch_start(trainer):
+    """Sets the global variable _training_epoch value to True at the start of training each epoch."""
+    global _training_epoch
+    _training_epoch = True
+
+
+def on_fit_epoch_end(trainer):
+    """Logs training metrics and model info, and advances to next step on the end of each fit epoch."""
+    global _training_epoch
+    if live and _training_epoch:
+        all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
+        for metric, value in all_metrics.items():
+            live.log_metric(metric, value)
+
+        if trainer.epoch == 0:
+            from ultralytics.utils.torch_utils import model_info_for_loggers
+
+            for metric, value in model_info_for_loggers(trainer).items():
+                live.log_metric(metric, value, plot=False)
+
+        _log_plots(trainer.plots, "train")
+        _log_plots(trainer.validator.plots, "val")
+
+        live.next_step()
+        _training_epoch = False
+
+
+def on_train_end(trainer):
+    """Logs the best metrics, plots, and confusion matrix at the end of training if DVCLive is active."""
+    if live:
+        # At the end log the best metrics. It runs validator on the best model internally.
+        all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
+        for metric, value in all_metrics.items():
+            live.log_metric(metric, value, plot=False)
+
+        _log_plots(trainer.plots, "val")
+        _log_plots(trainer.validator.plots, "val")
+        _log_confusion_matrix(trainer.validator)
+
+        if trainer.best.exists():
+            live.log_artifact(trainer.best, copy=True, type="model")
+
+        live.end()
+
+
+callbacks = (
+    {
+        "on_pretrain_routine_start": on_pretrain_routine_start,
+        "on_pretrain_routine_end": on_pretrain_routine_end,
+        "on_train_start": on_train_start,
+        "on_train_epoch_start": on_train_epoch_start,
+        "on_fit_epoch_end": on_fit_epoch_end,
+        "on_train_end": on_train_end,
+    }
+    if dvclive
+    else {}
+)
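
The `_batch(\d+)` rename in `_log_images` above is easiest to see on a concrete filename. A minimal, self-contained sketch (the filename `val_batch2_pred.jpg` is a hypothetical example, not taken from this diff):

```python
# Standalone sketch of the batch-grouping rename used in _log_images; the filename is hypothetical.
import re
from pathlib import Path

path = Path("val_batch2_pred.jpg")
name = path.name
if m := re.search(r"_batch(\d+)", name):
    ni = m[1]  # batch index captured as a string, e.g. "2"
    new_stem = re.sub(r"_batch(\d+)", "_batch", path.stem)  # "val_batch2_pred" -> "val_batch_pred"
    name = (Path(new_stem) / ni).with_suffix(path.suffix)

print(name)  # val_batch_pred/2.jpg, so the DVC UI can group these images under one name with a slider
```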

+ 108 - 0
ultralytics/utils/callbacks/hub.py

@@ -0,0 +1,108 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import json
+from time import time
+
+from ultralytics.hub import HUB_WEB_ROOT, PREFIX, HUBTrainingSession, events
+from ultralytics.utils import LOGGER, RANK, SETTINGS
+
+
+def on_pretrain_routine_start(trainer):
+    """Create a remote Ultralytics HUB session to log local model training."""
+    if RANK in {-1, 0} and SETTINGS["hub"] is True and SETTINGS["api_key"] and trainer.hub_session is None:
+        trainer.hub_session = HUBTrainingSession.create_session(trainer.args.model, trainer.args)
+
+
+def on_pretrain_routine_end(trainer):
+    """Logs info before starting timer for upload rate limit."""
+    if session := getattr(trainer, "hub_session", None):
+        # Start timer for upload rate limit
+        session.timers = {"metrics": time(), "ckpt": time()}  # start timer on session.rate_limit
+
+
+def on_fit_epoch_end(trainer):
+    """Uploads training progress metrics at the end of each epoch."""
+    if session := getattr(trainer, "hub_session", None):
+        # Upload metrics after val end
+        all_plots = {
+            **trainer.label_loss_items(trainer.tloss, prefix="train"),
+            **trainer.metrics,
+        }
+        if trainer.epoch == 0:
+            from ultralytics.utils.torch_utils import model_info_for_loggers
+
+            all_plots = {**all_plots, **model_info_for_loggers(trainer)}
+
+        session.metrics_queue[trainer.epoch] = json.dumps(all_plots)
+
+        # If any metrics fail to upload, add them to the queue to attempt uploading again.
+        if session.metrics_upload_failed_queue:
+            session.metrics_queue.update(session.metrics_upload_failed_queue)
+
+        if time() - session.timers["metrics"] > session.rate_limits["metrics"]:
+            session.upload_metrics()
+            session.timers["metrics"] = time()  # reset timer
+            session.metrics_queue = {}  # reset queue
+
+
+def on_model_save(trainer):
+    """Saves checkpoints to Ultralytics HUB with rate limiting."""
+    if session := getattr(trainer, "hub_session", None):
+        # Upload checkpoints with rate limiting
+        is_best = trainer.best_fitness == trainer.fitness
+        if time() - session.timers["ckpt"] > session.rate_limits["ckpt"]:
+            LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model.id}")
+            session.upload_model(trainer.epoch, trainer.last, is_best)
+            session.timers["ckpt"] = time()  # reset timer
+
+
+def on_train_end(trainer):
+    """Upload final model and metrics to Ultralytics HUB at the end of training."""
+    if session := getattr(trainer, "hub_session", None):
+        # Upload final model and metrics with exponential backoff
+        LOGGER.info(f"{PREFIX}Syncing final model...")
+        session.upload_model(
+            trainer.epoch,
+            trainer.best,
+            map=trainer.metrics.get("metrics/mAP50-95(B)", 0),
+            final=True,
+        )
+        session.alive = False  # stop heartbeats
+        LOGGER.info(f"{PREFIX}Done ✅\n{PREFIX}View model at {session.model_url} 🚀")
+
+
+def on_train_start(trainer):
+    """Run events on train start."""
+    events(trainer.args)
+
+
+def on_val_start(validator):
+    """Runs events on validation start."""
+    events(validator.args)
+
+
+def on_predict_start(predictor):
+    """Run events on predict start."""
+    events(predictor.args)
+
+
+def on_export_start(exporter):
+    """Run events on export start."""
+    events(exporter.args)
+
+
+callbacks = (
+    {
+        "on_pretrain_routine_start": on_pretrain_routine_start,
+        "on_pretrain_routine_end": on_pretrain_routine_end,
+        "on_fit_epoch_end": on_fit_epoch_end,
+        "on_model_save": on_model_save,
+        "on_train_end": on_train_end,
+        "on_train_start": on_train_start,
+        "on_val_start": on_val_start,
+        "on_predict_start": on_predict_start,
+        "on_export_start": on_export_start,
+    }
+    if SETTINGS["hub"] is True
+    else {}
+)  # verify enabled
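
The metric and checkpoint uploads above are gated by per-session timers rather than sent on every epoch. A generic sketch of that timer-and-queue pattern; the 300-second window, function name, and `print()` call are placeholders for the real HUB session limits and upload calls:

```python
# Generic sketch of the timer-based rate limiting used in on_fit_epoch_end/on_model_save above.
from time import time

timers = {"metrics": time()}
rate_limits = {"metrics": 300}  # assumed seconds between uploads
metrics_queue = {}

def queue_and_maybe_upload(epoch, payload):
    """Queue a payload and flush the queue only once the rate-limit window has elapsed."""
    metrics_queue[epoch] = payload
    if time() - timers["metrics"] > rate_limits["metrics"]:
        print(f"uploading {len(metrics_queue)} queued payloads")  # stands in for session.upload_metrics()
        timers["metrics"] = time()  # reset timer
        metrics_queue.clear()       # reset queue
```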

+ 137 - 0
ultralytics/utils/callbacks/mlflow.py

@@ -0,0 +1,137 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+"""
+MLflow Logging for Ultralytics YOLO.
+
+This module enables MLflow logging for Ultralytics YOLO. It logs metrics, parameters, and model artifacts.
+For setting up, a tracking URI should be specified. The logging can be customized using environment variables.
+
+Commands:
+    1. To set a project name:
+        `export MLFLOW_EXPERIMENT_NAME=<your_experiment_name>` or use the project=<project> argument
+
+    2. To set a run name:
+        `export MLFLOW_RUN=<your_run_name>` or use the name=<name> argument
+
+    3. To start a local MLflow server:
+        mlflow server --backend-store-uri runs/mlflow
+       By default this starts a local server at http://127.0.0.1:5000.
+       To specify a different URI, set the MLFLOW_TRACKING_URI environment variable.
+
+    4. To kill all running MLflow server instances:
+        ps aux | grep 'mlflow' | grep -v 'grep' | awk '{print $2}' | xargs kill -9
+"""
+
+from ultralytics.utils import LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorstr
+
+try:
+    import os
+
+    assert not TESTS_RUNNING or "test_mlflow" in os.environ.get("PYTEST_CURRENT_TEST", "")  # do not log pytest
+    assert SETTINGS["mlflow"] is True  # verify integration is enabled
+    import mlflow
+
+    assert hasattr(mlflow, "__version__")  # verify package is not directory
+    from pathlib import Path
+
+    PREFIX = colorstr("MLflow: ")
+
+except (ImportError, AssertionError):
+    mlflow = None
+
+
+def sanitize_dict(x):
+    """Sanitize dictionary keys by removing parentheses and converting values to floats."""
+    return {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()}
+
+
+def on_pretrain_routine_end(trainer):
+    """
+    Log training parameters to MLflow at the end of the pretraining routine.
+
+    This function sets up MLflow logging based on environment variables and trainer arguments. It sets the tracking URI,
+    experiment name, and run name, then starts the MLflow run if not already active. It finally logs the parameters
+    from the trainer.
+
+    Args:
+        trainer (ultralytics.engine.trainer.BaseTrainer): The training object with arguments and parameters to log.
+
+    Global:
+        mlflow: The imported mlflow module to use for logging.
+
+    Environment Variables:
+        MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'.
+        MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project.
+        MLFLOW_RUN: The name of the MLflow run. If not set, defaults to trainer.args.name.
+        MLFLOW_KEEP_RUN_ACTIVE: Boolean indicating whether to keep the MLflow run active after the end of training.
+    """
+    global mlflow
+
+    uri = os.environ.get("MLFLOW_TRACKING_URI") or str(RUNS_DIR / "mlflow")
+    LOGGER.debug(f"{PREFIX} tracking uri: {uri}")
+    mlflow.set_tracking_uri(uri)
+
+    # Set experiment and run names
+    experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") or trainer.args.project or "/Shared/Ultralytics"
+    run_name = os.environ.get("MLFLOW_RUN") or trainer.args.name
+    mlflow.set_experiment(experiment_name)
+
+    mlflow.autolog()
+    try:
+        active_run = mlflow.active_run() or mlflow.start_run(run_name=run_name)
+        LOGGER.info(f"{PREFIX}logging run_id({active_run.info.run_id}) to {uri}")
+        if Path(uri).is_dir():
+            LOGGER.info(f"{PREFIX}view at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri {uri}'")
+        LOGGER.info(f"{PREFIX}disable with 'yolo settings mlflow=False'")
+        mlflow.log_params(dict(trainer.args))
+    except Exception as e:
+        LOGGER.warning(f"{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n{PREFIX}WARNING ⚠️ Not tracking this run")
+
+
+def on_train_epoch_end(trainer):
+    """Log training metrics at the end of each train epoch to MLflow."""
+    if mlflow:
+        mlflow.log_metrics(
+            metrics={
+                **sanitize_dict(trainer.lr),
+                **sanitize_dict(trainer.label_loss_items(trainer.tloss, prefix="train")),
+            },
+            step=trainer.epoch,
+        )
+
+
+def on_fit_epoch_end(trainer):
+    """Log training metrics at the end of each fit epoch to MLflow."""
+    if mlflow:
+        mlflow.log_metrics(metrics=sanitize_dict(trainer.metrics), step=trainer.epoch)
+
+
+def on_train_end(trainer):
+    """Log model artifacts at the end of the training."""
+    if not mlflow:
+        return
+    mlflow.log_artifact(str(trainer.best.parent))  # log save_dir/weights directory with best.pt and last.pt
+    for f in trainer.save_dir.glob("*"):  # log all other files in save_dir
+        if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
+            mlflow.log_artifact(str(f))
+    keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() == "true"
+    if keep_run_active:
+        LOGGER.info(f"{PREFIX}mlflow run still alive, remember to close it using mlflow.end_run()")
+    else:
+        mlflow.end_run()
+        LOGGER.debug(f"{PREFIX}mlflow run ended")
+
+    LOGGER.info(
+        f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n{PREFIX}disable with 'yolo settings mlflow=False'"
+    )
+
+
+callbacks = (
+    {
+        "on_pretrain_routine_end": on_pretrain_routine_end,
+        "on_train_epoch_end": on_train_epoch_end,
+        "on_fit_epoch_end": on_fit_epoch_end,
+        "on_train_end": on_train_end,
+    }
+    if mlflow
+    else {}
+)
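
As the module docstring above notes, the run is configured through environment variables read in `on_pretrain_routine_end`. A usage sketch, assuming the mlflow integration is enabled in the Ultralytics settings; the experiment name, run name, and tracking URI below are placeholder values:

```python
# Hedged usage sketch: configure MLflow through environment variables before training.
import os

os.environ["MLFLOW_EXPERIMENT_NAME"] = "my-experiment"       # else trainer.args.project is used
os.environ["MLFLOW_RUN"] = "baseline"                         # else trainer.args.name is used
os.environ["MLFLOW_TRACKING_URI"] = "http://127.0.0.1:5000"   # else runs/mlflow is used

from ultralytics import YOLO

YOLO("yolo11n.pt").train(data="coco8.yaml", epochs=3)  # params, metrics and artifacts logged by the callbacks above
```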

+ 116 - 0
ultralytics/utils/callbacks/neptune.py

@@ -0,0 +1,116 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
+
+try:
+    assert not TESTS_RUNNING  # do not log pytest
+    assert SETTINGS["neptune"] is True  # verify integration is enabled
+    import neptune
+    from neptune.types import File
+
+    assert hasattr(neptune, "__version__")
+
+    run = None  # NeptuneAI experiment logger instance
+
+except (ImportError, AssertionError):
+    neptune = None
+
+
+def _log_scalars(scalars, step=0):
+    """Log scalars to the NeptuneAI experiment logger."""
+    if run:
+        for k, v in scalars.items():
+            run[k].append(value=v, step=step)
+
+
+def _log_images(imgs_dict, group=""):
+    """Log scalars to the NeptuneAI experiment logger."""
+    if run:
+        for k, v in imgs_dict.items():
+            run[f"{group}/{k}"].upload(File(v))
+
+
+def _log_plot(title, plot_path):
+    """
+    Log plots to the NeptuneAI experiment logger.
+
+    Args:
+        title (str): Title of the plot.
+        plot_path (PosixPath | str): Path to the saved image file.
+    """
+    import matplotlib.image as mpimg
+    import matplotlib.pyplot as plt
+
+    img = mpimg.imread(plot_path)
+    fig = plt.figure()
+    ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[])  # no ticks
+    ax.imshow(img)
+    run[f"Plots/{title}"].upload(fig)
+
+
+def on_pretrain_routine_start(trainer):
+    """Callback function called before the training routine starts."""
+    try:
+        global run
+        run = neptune.init_run(
+            project=trainer.args.project or "Ultralytics",
+            name=trainer.args.name,
+            tags=["Ultralytics"],
+        )
+        run["Configuration/Hyperparameters"] = {k: "" if v is None else v for k, v in vars(trainer.args).items()}
+    except Exception as e:
+        LOGGER.warning(f"WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}")
+
+
+def on_train_epoch_end(trainer):
+    """Callback function called at end of each training epoch."""
+    _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
+    _log_scalars(trainer.lr, trainer.epoch + 1)
+    if trainer.epoch == 1:
+        _log_images({f.stem: str(f) for f in trainer.save_dir.glob("train_batch*.jpg")}, "Mosaic")
+
+
+def on_fit_epoch_end(trainer):
+    """Callback function called at end of each fit (train+val) epoch."""
+    if run and trainer.epoch == 0:
+        from ultralytics.utils.torch_utils import model_info_for_loggers
+
+        run["Configuration/Model"] = model_info_for_loggers(trainer)
+    _log_scalars(trainer.metrics, trainer.epoch + 1)
+
+
+def on_val_end(validator):
+    """Callback function called at end of each validation."""
+    if run:
+        # Log val_labels and val_pred
+        _log_images({f.stem: str(f) for f in validator.save_dir.glob("val*.jpg")}, "Validation")
+
+
+def on_train_end(trainer):
+    """Callback function called at end of training."""
+    if run:
+        # Log final results, CM matrix + PR plots
+        files = [
+            "results.png",
+            "confusion_matrix.png",
+            "confusion_matrix_normalized.png",
+            *(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
+        ]
+        files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()]  # filter
+        for f in files:
+            _log_plot(title=f.stem, plot_path=f)
+        # Log the final model
+        run[f"weights/{trainer.args.name or trainer.args.task}/{trainer.best.name}"].upload(File(str(trainer.best)))
+
+
+callbacks = (
+    {
+        "on_pretrain_routine_start": on_pretrain_routine_start,
+        "on_train_epoch_end": on_train_epoch_end,
+        "on_fit_epoch_end": on_fit_epoch_end,
+        "on_val_end": on_val_end,
+        "on_train_end": on_train_end,
+    }
+    if neptune
+    else {}
+)
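
The helpers above append scalars to the run with an explicit step, which is what yields per-epoch charts in the Neptune UI. A minimal sketch of that pattern, assuming valid Neptune credentials; the project, metric names, and values are placeholders:

```python
# Minimal sketch of the Neptune logging pattern used by _log_scalars; all values below are placeholders.
import neptune

run = neptune.init_run(project="workspace/Ultralytics", name="exp1", tags=["Ultralytics"])
run["Configuration/Hyperparameters"] = {"epochs": 3, "imgsz": 640}
run["train/box_loss"].append(value=0.5, step=1)  # one point per epoch, as in _log_scalars
run.stop()
```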

+ 28 - 0
ultralytics/utils/callbacks/raytune.py

@@ -0,0 +1,28 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from ultralytics.utils import SETTINGS
+
+try:
+    assert SETTINGS["raytune"] is True  # verify integration is enabled
+    import ray
+    from ray import tune
+    from ray.air import session
+
+except (ImportError, AssertionError):
+    tune = None
+
+
+def on_fit_epoch_end(trainer):
+    """Sends training metrics to Ray Tune at end of each epoch."""
+    if ray.train._internal.session.get_session():  # replacement for deprecated ray.tune.is_session_enabled()
+        metrics = trainer.metrics
+        session.report({**metrics, **{"epoch": trainer.epoch + 1}})
+
+
+callbacks = (
+    {
+        "on_fit_epoch_end": on_fit_epoch_end,
+    }
+    if tune
+    else {}
+)
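
This callback only reports metrics when a Ray session is active, i.e. when training is launched from Ray Tune. A hedged usage sketch, assuming Ray Tune is installed and that `model.tune` accepts the `use_ray` option:

```python
# Hedged sketch: run Ultralytics hyperparameter tuning on Ray Tune so the callback above
# can report per-epoch metrics to each trial. Assumes `pip install "ray[tune]"`.
from ultralytics import YOLO

model = YOLO("yolo11n.pt")
result_grid = model.tune(data="coco8.yaml", use_ray=True, iterations=4)  # iteration count is illustrative
```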

+ 106 - 0
ultralytics/utils/callbacks/tensorboard.py

@@ -0,0 +1,106 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr
+
+try:
+    # WARNING: do not move SummaryWriter import due to protobuf bug https://github.com/ultralytics/ultralytics/pull/4674
+    from torch.utils.tensorboard import SummaryWriter
+
+    assert not TESTS_RUNNING  # do not log pytest
+    assert SETTINGS["tensorboard"] is True  # verify integration is enabled
+    WRITER = None  # TensorBoard SummaryWriter instance
+    PREFIX = colorstr("TensorBoard: ")
+
+    # Imports below only required if TensorBoard enabled
+    import warnings
+    from copy import deepcopy
+
+    from ultralytics.utils.torch_utils import de_parallel, torch
+
+except (ImportError, AssertionError, TypeError, AttributeError):
+    # TypeError for handling 'Descriptors cannot not be created directly.' protobuf errors in Windows
+    # AttributeError: module 'tensorflow' has no attribute 'io' if 'tensorflow' not installed
+    SummaryWriter = None
+
+
+def _log_scalars(scalars, step=0):
+    """Logs scalar values to TensorBoard."""
+    if WRITER:
+        for k, v in scalars.items():
+            WRITER.add_scalar(k, v, step)
+
+
+def _log_tensorboard_graph(trainer):
+    """Log model graph to TensorBoard."""
+    # Input image
+    imgsz = trainer.args.imgsz
+    imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
+    p = next(trainer.model.parameters())  # for device, type
+    im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype)  # input image (must be zeros, not empty)
+
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", category=UserWarning)  # suppress jit trace warning
+        warnings.simplefilter("ignore", category=torch.jit.TracerWarning)  # suppress jit trace warning
+
+        # Try simple method first (YOLO)
+        try:
+            trainer.model.eval()  # place in .eval() mode to avoid BatchNorm statistics changes
+            WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
+            LOGGER.info(f"{PREFIX}model graph visualization added ✅")
+            return
+
+        except Exception:
+            # Fallback to TorchScript export steps (RTDETR)
+            try:
+                model = deepcopy(de_parallel(trainer.model))
+                model.eval()
+                model = model.fuse(verbose=False)
+                for m in model.modules():
+                    if hasattr(m, "export"):  # Detect, RTDETRDecoder (Segment and Pose use Detect base class)
+                        m.export = True
+                        m.format = "torchscript"
+                model(im)  # dry run
+                WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])
+                LOGGER.info(f"{PREFIX}model graph visualization added ✅")
+            except Exception as e:
+                LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard graph visualization failure {e}")
+
+
+def on_pretrain_routine_start(trainer):
+    """Initialize TensorBoard logging with SummaryWriter."""
+    if SummaryWriter:
+        try:
+            global WRITER
+            WRITER = SummaryWriter(str(trainer.save_dir))
+            LOGGER.info(f"{PREFIX}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
+        except Exception as e:
+            LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}")
+
+
+def on_train_start(trainer):
+    """Log TensorBoard graph."""
+    if WRITER:
+        _log_tensorboard_graph(trainer)
+
+
+def on_train_epoch_end(trainer):
+    """Logs scalar statistics at the end of a training epoch."""
+    _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
+    _log_scalars(trainer.lr, trainer.epoch + 1)
+
+
+def on_fit_epoch_end(trainer):
+    """Logs epoch metrics at end of training epoch."""
+    _log_scalars(trainer.metrics, trainer.epoch + 1)
+
+
+callbacks = (
+    {
+        "on_pretrain_routine_start": on_pretrain_routine_start,
+        "on_train_start": on_train_start,
+        "on_fit_epoch_end": on_fit_epoch_end,
+        "on_train_epoch_end": on_train_epoch_end,
+    }
+    if SummaryWriter
+    else {}
+)
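
The graph logging above reduces to tracing the model with a dummy zero input and handing the trace to `SummaryWriter.add_graph`. A self-contained sketch of that step using a toy model rather than a YOLO model; the log directory is a placeholder:

```python
# Minimal sketch of the add_graph path used in _log_tensorboard_graph; model and log dir are placeholders.
import torch
from torch.utils.tensorboard import SummaryWriter

model = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3, padding=1), torch.nn.ReLU()).eval()
im = torch.zeros(1, 3, 64, 64)  # dummy zero input, as in _log_tensorboard_graph
writer = SummaryWriter("runs/example")
writer.add_graph(torch.jit.trace(model, im, strict=False), [])
writer.close()  # view with: tensorboard --logdir runs/example
```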

+ 170 - 0
ultralytics/utils/callbacks/wb.py

@@ -0,0 +1,170 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from ultralytics.utils import SETTINGS, TESTS_RUNNING
+from ultralytics.utils.torch_utils import model_info_for_loggers
+
+try:
+    assert not TESTS_RUNNING  # do not log pytest
+    assert SETTINGS["wandb"] is True  # verify integration is enabled
+    import wandb as wb
+
+    assert hasattr(wb, "__version__")  # verify package is not directory
+    _processed_plots = {}
+
+except (ImportError, AssertionError):
+    wb = None
+
+
+def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"):
+    """
+    Create and log a custom metric visualization to wandb.plot.pr_curve.
+
+    This function crafts a custom metric visualization that mimics the behavior of the default wandb precision-recall
+    curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across
+    different classes.
+
+    Args:
+        x (List): Values for the x-axis; expected to have length N.
+        y (List): Corresponding values for the y-axis; also expected to have length N.
+        classes (List): Labels identifying the class of each point; length N.
+        title (str, optional): Title for the plot; defaults to 'Precision Recall Curve'.
+        x_title (str, optional): Label for the x-axis; defaults to 'Recall'.
+        y_title (str, optional): Label for the y-axis; defaults to 'Precision'.
+
+    Returns:
+        (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
+    """
+    import pandas  # scope for faster 'import ultralytics'
+
+    df = pandas.DataFrame({"class": classes, "y": y, "x": x}).round(3)
+    fields = {"x": "x", "y": "y", "class": "class"}
+    string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
+    return wb.plot_table(
+        "wandb/area-under-curve/v0", wb.Table(dataframe=df), fields=fields, string_fields=string_fields
+    )
+
+
+def _plot_curve(
+    x,
+    y,
+    names=None,
+    id="precision-recall",
+    title="Precision Recall Curve",
+    x_title="Recall",
+    y_title="Precision",
+    num_x=100,
+    only_mean=False,
+):
+    """
+    Log a metric curve visualization.
+
+    This function generates a metric curve based on input data and logs the visualization to wandb.
+    The curve can represent aggregated data (mean) or individual class data, depending on the 'only_mean' flag.
+
+    Args:
+        x (np.ndarray): Data points for the x-axis with length N.
+        y (np.ndarray): Corresponding data points for the y-axis with shape CxN, where C is the number of classes.
+        names (list, optional): Names of the classes corresponding to the y-axis data; length C. Defaults to [].
+        id (str, optional): Unique identifier for the logged data in wandb. Defaults to 'precision-recall'.
+        title (str, optional): Title for the visualization plot. Defaults to 'Precision Recall Curve'.
+        x_title (str, optional): Label for the x-axis. Defaults to 'Recall'.
+        y_title (str, optional): Label for the y-axis. Defaults to 'Precision'.
+        num_x (int, optional): Number of interpolated data points for visualization. Defaults to 100.
+        only_mean (bool, optional): Flag to indicate if only the mean curve should be plotted. Defaults to False.
+
+    Note:
+        The function leverages the '_custom_table' function to generate the actual visualization.
+    """
+    import numpy as np
+
+    # Create new x
+    if names is None:
+        names = []
+    x_new = np.linspace(x[0], x[-1], num_x).round(5)
+
+    # Create arrays for logging
+    x_log = x_new.tolist()
+    y_log = np.interp(x_new, x, np.mean(y, axis=0)).round(3).tolist()
+
+    if only_mean:
+        table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title])
+        wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)})
+    else:
+        classes = ["mean"] * len(x_log)
+        for i, yi in enumerate(y):
+            x_log.extend(x_new)  # add new x
+            y_log.extend(np.interp(x_new, x, yi))  # interpolate y to new x
+            classes.extend([names[i]] * len(x_new))  # add class names
+        wb.log({id: _custom_table(x_log, y_log, classes, title, x_title, y_title)}, commit=False)
+
+
+def _log_plots(plots, step):
+    """Logs plots from the input dictionary if they haven't been logged already at the specified step."""
+    for name, params in plots.copy().items():  # shallow copy to prevent plots dict changing during iteration
+        timestamp = params["timestamp"]
+        if _processed_plots.get(name) != timestamp:
+            wb.run.log({name.stem: wb.Image(str(name))}, step=step)
+            _processed_plots[name] = timestamp
+
+
+def on_pretrain_routine_start(trainer):
+    """Initiate and start project if module is present."""
+    if not wb.run:
+        wb.init(
+            project=str(trainer.args.project).replace("/", "-") if trainer.args.project else "Ultralytics",
+            name=str(trainer.args.name).replace("/", "-"),
+            config=vars(trainer.args),
+        )
+
+
+def on_fit_epoch_end(trainer):
+    """Logs training metrics and model information at the end of an epoch."""
+    wb.run.log(trainer.metrics, step=trainer.epoch + 1)
+    _log_plots(trainer.plots, step=trainer.epoch + 1)
+    _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
+    if trainer.epoch == 0:
+        wb.run.log(model_info_for_loggers(trainer), step=trainer.epoch + 1)
+
+
+def on_train_epoch_end(trainer):
+    """Log metrics and save images at the end of each training epoch."""
+    wb.run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1)
+    wb.run.log(trainer.lr, step=trainer.epoch + 1)
+    if trainer.epoch == 1:
+        _log_plots(trainer.plots, step=trainer.epoch + 1)
+
+
+def on_train_end(trainer):
+    """Save the best model as an artifact at end of training."""
+    _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
+    _log_plots(trainer.plots, step=trainer.epoch + 1)
+    art = wb.Artifact(type="model", name=f"run_{wb.run.id}_model")
+    if trainer.best.exists():
+        art.add_file(trainer.best)
+        wb.run.log_artifact(art, aliases=["best"])
+    # Check if we actually have plots to save
+    if trainer.args.plots and hasattr(trainer.validator.metrics, "curves_results"):
+        for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results):
+            x, y, x_title, y_title = curve_values
+            _plot_curve(
+                x,
+                y,
+                names=list(trainer.validator.metrics.names.values()),
+                id=f"curves/{curve_name}",
+                title=curve_name,
+                x_title=x_title,
+                y_title=y_title,
+            )
+    wb.run.finish()  # required or run continues on dashboard
+
+
+callbacks = (
+    {
+        "on_pretrain_routine_start": on_pretrain_routine_start,
+        "on_train_epoch_end": on_train_epoch_end,
+        "on_fit_epoch_end": on_fit_epoch_end,
+        "on_train_end": on_train_end,
+    }
+    if wb
+    else {}
+)
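
Because each class curve passed to `_plot_curve` is sampled on its own x values, the helper resamples everything onto a common grid with `np.linspace` and `np.interp` before logging. A standalone sketch of that resampling on synthetic data:

```python
# Standalone sketch of the resampling step in _plot_curve above; the curve data here is synthetic.
import numpy as np

x = np.array([0.0, 0.5, 1.0])              # original x samples (e.g. recall)
y = np.array([[1.0, 0.8, 0.2],             # per-class curves, shape (C, N)
              [0.9, 0.6, 0.1]])
x_new = np.linspace(x[0], x[-1], 5).round(5)               # common x grid
y_mean = np.interp(x_new, x, np.mean(y, axis=0)).round(3)  # mean curve resampled onto x_new
print(x_new.tolist())   # [0.0, 0.25, 0.5, 0.75, 1.0]
print(y_mean.tolist())  # [0.95, 0.825, 0.7, 0.425, 0.15]
```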

+ 803 - 0
ultralytics/utils/checks.py

@@ -0,0 +1,803 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import glob
+import inspect
+import math
+import os
+import platform
+import re
+import shutil
+import subprocess
+import time
+from importlib import metadata
+from pathlib import Path
+from typing import Optional
+
+import cv2
+import numpy as np
+import requests
+import torch
+
+from ultralytics.utils import (
+    ASSETS,
+    AUTOINSTALL,
+    IS_COLAB,
+    IS_GIT_DIR,
+    IS_KAGGLE,
+    IS_PIP_PACKAGE,
+    LINUX,
+    LOGGER,
+    MACOS,
+    ONLINE,
+    PYTHON_VERSION,
+    ROOT,
+    TORCHVISION_VERSION,
+    USER_CONFIG_DIR,
+    WINDOWS,
+    Retry,
+    SimpleNamespace,
+    ThreadingLocked,
+    TryExcept,
+    clean_url,
+    colorstr,
+    downloads,
+    emojis,
+    is_github_action_running,
+    url2file,
+)
+
+
+def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
+    """
+    Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'.
+
+    Args:
+        file_path (Path): Path to the requirements.txt file.
+        package (str, optional): Python package to use instead of requirements.txt file, i.e. package='ultralytics'.
+
+    Returns:
+        (List[Dict[str, str]]): List of parsed requirements as dictionaries with `name` and `specifier` keys.
+
+    Example:
+        ```python
+        from ultralytics.utils.checks import parse_requirements
+
+        parse_requirements(package="ultralytics")
+        ```
+    """
+    if package:
+        requires = [x for x in metadata.distribution(package).requires if "extra == " not in x]
+    else:
+        requires = Path(file_path).read_text().splitlines()
+
+    requirements = []
+    for line in requires:
+        line = line.strip()
+        if line and not line.startswith("#"):
+            line = line.split("#")[0].strip()  # ignore inline comments
+            if match := re.match(r"([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?", line):
+                requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else ""))
+
+    return requirements
+
+
+def parse_version(version="0.0.0") -> tuple:
+    """
+    Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version. This
+    function replaces deprecated 'pkg_resources.parse_version(v)'.
+
+    Args:
+        version (str): Version string, i.e. '2.0.1+cpu'
+
+    Returns:
+        (tuple): Tuple of integers representing the numeric part of the version, i.e. (2, 0, 1)
+    """
+    try:
+        return tuple(map(int, re.findall(r"\d+", version)[:3]))  # '2.0.1+cpu' -> (2, 0, 1)
+    except Exception as e:
+        LOGGER.warning(f"WARNING ⚠️ failure for parse_version({version}), returning (0, 0, 0): {e}")
+        return 0, 0, 0
+
+
+def is_ascii(s) -> bool:
+    """
+    Check if a string is composed of only ASCII characters.
+
+    Args:
+        s (str): String to be checked.
+
+    Returns:
+        (bool): True if the string is composed only of ASCII characters, False otherwise.
+    """
+    # Convert list, tuple, None, etc. to string
+    s = str(s)
+
+    # Check if the string is composed of only ASCII characters
+    return all(ord(c) < 128 for c in s)
+
+
+def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
+    """
+    Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the
+    stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.
+
+    Args:
+        imgsz (int | List[int]): Image size.
+        stride (int): Stride value.
+        min_dim (int): Minimum number of dimensions.
+        max_dim (int): Maximum number of dimensions.
+        floor (int): Minimum allowed value for image size.
+
+    Returns:
+        (List[int]): Updated image size.
+    """
+    # Convert stride to integer if it is a tensor
+    stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)
+
+    # Convert image size to list if it is an integer
+    if isinstance(imgsz, int):
+        imgsz = [imgsz]
+    elif isinstance(imgsz, (list, tuple)):
+        imgsz = list(imgsz)
+    elif isinstance(imgsz, str):  # i.e. '640' or '[640,640]'
+        imgsz = [int(imgsz)] if imgsz.isnumeric() else eval(imgsz)
+    else:
+        raise TypeError(
+            f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
+            f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'"
+        )
+
+    # Apply max_dim
+    if len(imgsz) > max_dim:
+        msg = (
+            "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list "
+            "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
+        )
+        if max_dim != 1:
+            raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}")
+        LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. {msg}")
+        imgsz = [max(imgsz)]
+    # Make image size a multiple of the stride
+    sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]
+
+    # Print warning message if image size was updated
+    if sz != imgsz:
+        LOGGER.warning(f"WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}")
+
+    # Add missing dimensions if necessary
+    sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz
+
+    return sz
+
+
+def check_version(
+    current: str = "0.0.0",
+    required: str = "0.0.0",
+    name: str = "version",
+    hard: bool = False,
+    verbose: bool = False,
+    msg: str = "",
+) -> bool:
+    """
+    Check current version against the required version or range.
+
+    Args:
+        current (str): Current version or package name to get version from.
+        required (str): Required version or range (in pip-style format).
+        name (str, optional): Name to be used in warning message.
+        hard (bool, optional): If True, raise an AssertionError if the requirement is not met.
+        verbose (bool, optional): If True, print warning message if requirement is not met.
+        msg (str, optional): Extra message to display if verbose.
+
+    Returns:
+        (bool): True if requirement is met, False otherwise.
+
+    Example:
+        ```python
+        # Check if current version is exactly 22.04
+        check_version(current="22.04", required="==22.04")
+
+        # Check if current version is greater than or equal to 22.04
+        check_version(current="22.10", required="22.04")  # assumes '>=' inequality if none passed
+
+        # Check if current version is less than or equal to 22.04
+        check_version(current="22.04", required="<=22.04")
+
+        # Check if current version is between 20.04 (inclusive) and 22.04 (exclusive)
+        check_version(current="21.10", required=">20.04,<22.04")
+        ```
+    """
+    if not current:  # if current is '' or None
+        LOGGER.warning(f"WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.")
+        return True
+    elif not current[0].isdigit():  # current is package name rather than version string, i.e. current='ultralytics'
+        try:
+            name = current  # assigned package name to 'name' arg
+            current = metadata.version(current)  # get version string from package name
+        except metadata.PackageNotFoundError as e:
+            if hard:
+                raise ModuleNotFoundError(emojis(f"WARNING ⚠️ {current} package is required but not installed")) from e
+            else:
+                return False
+
+    if not required:  # if required is '' or None
+        return True
+
+    if "sys_platform" in required and (  # i.e. required='<2.4.0,>=1.8.0; sys_platform == "win32"'
+        (WINDOWS and "win32" not in required)
+        or (LINUX and "linux" not in required)
+        or (MACOS and "macos" not in required and "darwin" not in required)
+    ):
+        return True
+
+    op = ""
+    version = ""
+    result = True
+    c = parse_version(current)  # '1.2.3' -> (1, 2, 3)
+    for r in required.strip(",").split(","):
+        op, version = re.match(r"([^0-9]*)([\d.]+)", r).groups()  # split '>=22.04' -> ('>=', '22.04')
+        if not op:
+            op = ">="  # assume >= if no op passed
+        v = parse_version(version)  # '1.2.3' -> (1, 2, 3)
+        if op == "==" and c != v:
+            result = False
+        elif op == "!=" and c == v:
+            result = False
+        elif op == ">=" and not (c >= v):
+            result = False
+        elif op == "<=" and not (c <= v):
+            result = False
+        elif op == ">" and not (c > v):
+            result = False
+        elif op == "<" and not (c < v):
+            result = False
+    if not result:
+        warning = f"WARNING ⚠️ {name}{op}{version} is required, but {name}=={current} is currently installed {msg}"
+        if hard:
+            raise ModuleNotFoundError(emojis(warning))  # assert version requirements met
+        if verbose:
+            LOGGER.warning(warning)
+    return result
+
+
+def check_latest_pypi_version(package_name="ultralytics"):
+    """
+    Returns the latest version of a PyPI package without downloading or installing it.
+
+    Args:
+        package_name (str): The name of the package to find the latest version for.
+
+    Returns:
+        (str): The latest version of the package.
+    """
+    try:
+        requests.packages.urllib3.disable_warnings()  # Disable the InsecureRequestWarning
+        response = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3)
+        if response.status_code == 200:
+            return response.json()["info"]["version"]
+    except Exception:
+        return None
+
+
+def check_pip_update_available():
+    """
+    Checks if a new version of the ultralytics package is available on PyPI.
+
+    Returns:
+        (bool): True if an update is available, False otherwise.
+    """
+    if ONLINE and IS_PIP_PACKAGE:
+        try:
+            from ultralytics import __version__
+
+            latest = check_latest_pypi_version()
+            if check_version(__version__, f"<{latest}"):  # check if current version is < latest version
+                LOGGER.info(
+                    f"New https://pypi.org/project/ultralytics/{latest} available 😃 "
+                    f"Update with 'pip install -U ultralytics'"
+                )
+                return True
+        except Exception:
+            pass
+    return False
+
+
+@ThreadingLocked()
+def check_font(font="Arial.ttf"):
+    """
+    Find font locally or download to user's configuration directory if it does not already exist.
+
+    Args:
+        font (str): Path or name of font.
+
+    Returns:
+        file (Path): Resolved font file path.
+    """
+    from matplotlib import font_manager
+
+    # Check USER_CONFIG_DIR
+    name = Path(font).name
+    file = USER_CONFIG_DIR / name
+    if file.exists():
+        return file
+
+    # Check system fonts
+    matches = [s for s in font_manager.findSystemFonts() if font in s]
+    if any(matches):
+        return matches[0]
+
+    # Download to USER_CONFIG_DIR if missing
+    url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{name}"
+    if downloads.is_url(url, check=True):
+        downloads.safe_download(url=url, file=file)
+        return file
+
+
+def check_python(minimum: str = "3.8.0", hard: bool = True, verbose: bool = False) -> bool:
+    """
+    Check current python version against the required minimum version.
+
+    Args:
+        minimum (str): Required minimum version of python.
+        hard (bool, optional): If True, raise an AssertionError if the requirement is not met.
+        verbose (bool, optional): If True, print warning message if requirement is not met.
+
+    Returns:
+        (bool): Whether the installed Python version meets the minimum constraints.
+    """
+    return check_version(PYTHON_VERSION, minimum, name="Python", hard=hard, verbose=verbose)
+
+
+@TryExcept()
+def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""):
+    """
+    Check if installed dependencies meet Ultralytics YOLO requirements and attempt to auto-update if needed.
+
+    Args:
+        requirements (Union[Path, str, List[str]]): Path to a requirements.txt file, a single package requirement as a
+            string, or a list of package requirements as strings.
+        exclude (Tuple[str]): Tuple of package names to exclude from checking.
+        install (bool): If True, attempt to auto-update packages that don't meet requirements.
+        cmds (str): Additional commands to pass to the pip install command when auto-updating.
+
+    Example:
+        ```python
+        from ultralytics.utils.checks import check_requirements
+
+        # Check a requirements.txt file
+        check_requirements("path/to/requirements.txt")
+
+        # Check a single package
+        check_requirements("ultralytics>=8.0.0")
+
+        # Check multiple packages
+        check_requirements(["numpy", "ultralytics>=8.0.0"])
+        ```
+    """
+    prefix = colorstr("red", "bold", "requirements:")
+    if isinstance(requirements, Path):  # requirements.txt file
+        file = requirements.resolve()
+        assert file.exists(), f"{prefix} {file} not found, check failed."
+        requirements = [f"{x.name}{x.specifier}" for x in parse_requirements(file) if x.name not in exclude]
+    elif isinstance(requirements, str):
+        requirements = [requirements]
+
+    pkgs = []
+    for r in requirements:
+        r_stripped = r.split("/")[-1].replace(".git", "")  # replace git+https://org/repo.git -> 'repo'
+        match = re.match(r"([a-zA-Z0-9-_]+)([<>!=~]+.*)?", r_stripped)
+        name, required = match[1], match[2].strip() if match[2] else ""
+        try:
+            assert check_version(metadata.version(name), required)  # exception if requirements not met
+        except (AssertionError, metadata.PackageNotFoundError):
+            pkgs.append(r)
+
+    @Retry(times=2, delay=1)
+    def attempt_install(packages, commands):
+        """Attempt pip install command with retries on failure."""
+        return subprocess.check_output(f"pip install --no-cache-dir {packages} {commands}", shell=True).decode()
+
+    s = " ".join(f'"{x}"' for x in pkgs)  # console string
+    if s:
+        if install and AUTOINSTALL:  # check environment variable
+            n = len(pkgs)  # number of package updates
+            LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
+            try:
+                t = time.time()
+                assert ONLINE, "AutoUpdate skipped (offline)"
+                LOGGER.info(attempt_install(s, cmds))
+                dt = time.time() - t
+                LOGGER.info(
+                    f"{prefix} AutoUpdate success ✅ {dt:.1f}s, installed {n} package{'s' * (n > 1)}: {pkgs}\n"
+                    f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
+                )
+            except Exception as e:
+                LOGGER.warning(f"{prefix} ❌ {e}")
+                return False
+        else:
+            return False
+
+    return True
+
+
+def check_torchvision():
+    """
+    Checks the installed versions of PyTorch and Torchvision to ensure they're compatible.
+
+    This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according
+    to the provided compatibility table based on:
+    https://github.com/pytorch/vision#installation.
+
+    The compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
+    Torchvision versions.
+    """
+    # Compatibility table
+    compatibility_table = {
+        "2.5": ["0.20"],
+        "2.4": ["0.19"],
+        "2.3": ["0.18"],
+        "2.2": ["0.17"],
+        "2.1": ["0.16"],
+        "2.0": ["0.15"],
+        "1.13": ["0.14"],
+        "1.12": ["0.13"],
+    }
+
+    # Extract only the major and minor versions
+    v_torch = ".".join(torch.__version__.split("+")[0].split(".")[:2])
+    if v_torch in compatibility_table:
+        compatible_versions = compatibility_table[v_torch]
+        v_torchvision = ".".join(TORCHVISION_VERSION.split("+")[0].split(".")[:2])
+        if all(v_torchvision != v for v in compatible_versions):
+            print(
+                f"WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n"
+                f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
+                "'pip install -U torch torchvision' to update both.\n"
+                "For a full compatibility table see https://github.com/pytorch/vision#installation"
+            )
+
+
+def check_suffix(file="yolo11n.pt", suffix=".pt", msg=""):
+    """Check file(s) for acceptable suffix."""
+    if file and suffix:
+        if isinstance(suffix, str):
+            suffix = (suffix,)
+        for f in file if isinstance(file, (list, tuple)) else [file]:
+            s = Path(f).suffix.lower().strip()  # file suffix
+            if len(s):
+                assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}, not {s}"
+
+
+def check_yolov5u_filename(file: str, verbose: bool = True):
+    """Replace legacy YOLOv5 filenames with updated YOLOv5u filenames."""
+    if "yolov3" in file or "yolov5" in file:
+        if "u.yaml" in file:
+            file = file.replace("u.yaml", ".yaml")  # i.e. yolov5nu.yaml -> yolov5n.yaml
+        elif ".pt" in file and "u" not in file:
+            original_file = file
+            file = re.sub(r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt", file)  # i.e. yolov5n.pt -> yolov5nu.pt
+            file = re.sub(r"(.*yolov5([nsmlx])6)\.pt", "\\1u.pt", file)  # i.e. yolov5n6.pt -> yolov5n6u.pt
+            file = re.sub(r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt", file)  # i.e. yolov3-spp.pt -> yolov3-sppu.pt
+            if file != original_file and verbose:
+                LOGGER.info(
+                    f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
+                    f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs "
+                    f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n"
+                )
+    return file
+
+
+def check_model_file_from_stem(model="yolov8n"):
+    """Return a model filename from a valid model stem."""
+    if model and not Path(model).suffix and Path(model).stem in downloads.GITHUB_ASSETS_STEMS:
+        return Path(model).with_suffix(".pt")  # add suffix, i.e. yolov8n -> yolov8n.pt
+    else:
+        return model
+
+
+def check_file(file, suffix="", download=True, download_dir=".", hard=True):
+    """Search/download file (if necessary) and return path."""
+    check_suffix(file, suffix)  # optional
+    file = str(file).strip()  # convert to string and strip spaces
+    file = check_yolov5u_filename(file)  # yolov5n -> yolov5nu
+    if (
+        not file
+        or ("://" not in file and Path(file).exists())  # '://' check required in Windows Python<3.10
+        or file.lower().startswith("grpc://")
+    ):  # file exists or gRPC Triton images
+        return file
+    elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):  # download
+        url = file  # warning: Pathlib turns :// -> :/
+        file = Path(download_dir) / url2file(file)  # '%2F' to '/', split https://url.com/file.txt?auth
+        if file.exists():
+            LOGGER.info(f"Found {clean_url(url)} locally at {file}")  # file already exists
+        else:
+            downloads.safe_download(url=url, file=file, unzip=False)
+        return str(file)
+    else:  # search
+        files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file))  # find file
+        if not files and hard:
+            raise FileNotFoundError(f"'{file}' does not exist")
+        elif len(files) > 1 and hard:
+            raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}")
+        return files[0] if len(files) else []  # return file
+
+
+def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
+    """Search/download YAML file (if necessary) and return path, checking suffix."""
+    return check_file(file, suffix, hard=hard)
+
+
+def check_is_path_safe(basedir, path):
+    """
+    Check if the resolved path is under the intended directory to prevent path traversal.
+
+    Args:
+        basedir (Path | str): The intended directory.
+        path (Path | str): The path to check.
+
+    Returns:
+        (bool): True if the path is safe, False otherwise.
+    """
+    base_dir_resolved = Path(basedir).resolve()
+    path_resolved = Path(path).resolve()
+
+    return path_resolved.exists() and path_resolved.parts[: len(base_dir_resolved.parts)] == base_dir_resolved.parts
+
+
+def check_imshow(warn=False):
+    """Check if environment supports image displays."""
+    try:
+        if LINUX:
+            assert not IS_COLAB and not IS_KAGGLE
+            assert "DISPLAY" in os.environ, "The DISPLAY environment variable isn't set."
+        cv2.imshow("test", np.zeros((8, 8, 3), dtype=np.uint8))  # show a small 8-pixel image
+        cv2.waitKey(1)
+        cv2.destroyAllWindows()
+        cv2.waitKey(1)
+        return True
+    except Exception as e:
+        if warn:
+            LOGGER.warning(f"WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}")
+        return False
+
+
+def check_yolo(verbose=True, device=""):
+    """Return a human-readable YOLO software and hardware summary."""
+    import psutil
+
+    from ultralytics.utils.torch_utils import select_device
+
+    if IS_COLAB:
+        shutil.rmtree("sample_data", ignore_errors=True)  # remove colab /sample_data directory
+
+    if verbose:
+        # System info
+        gib = 1 << 30  # bytes per GiB
+        ram = psutil.virtual_memory().total
+        total, used, free = shutil.disk_usage("/")
+        s = f"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)"
+        try:
+            from IPython import display
+
+            display.clear_output()  # clear display if notebook
+        except ImportError:
+            pass
+    else:
+        s = ""
+
+    select_device(device=device, newline=False)
+    LOGGER.info(f"Setup complete ✅ {s}")
+
+
+def collect_system_info():
+    """Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA."""
+    import psutil
+
+    from ultralytics.utils import ENVIRONMENT  # scope to avoid circular import
+    from ultralytics.utils.torch_utils import get_cpu_info, get_gpu_info
+
+    gib = 1 << 30  # bytes per GiB
+    cuda = torch and torch.cuda.is_available()
+    check_yolo()
+    total, used, free = shutil.disk_usage("/")
+
+    info_dict = {
+        "OS": platform.platform(),
+        "Environment": ENVIRONMENT,
+        "Python": PYTHON_VERSION,
+        "Install": "git" if IS_GIT_DIR else "pip" if IS_PIP_PACKAGE else "other",
+        "RAM": f"{psutil.virtual_memory().total / gib:.2f} GB",
+        "Disk": f"{(total - free) / gib:.1f}/{total / gib:.1f} GB",
+        "CPU": get_cpu_info(),
+        "CPU count": os.cpu_count(),
+        "GPU": get_gpu_info(index=0) if cuda else None,
+        "GPU count": torch.cuda.device_count() if cuda else None,
+        "CUDA": torch.version.cuda if cuda else None,
+    }
+    LOGGER.info("\n" + "\n".join(f"{k:<20}{v}" for k, v in info_dict.items()) + "\n")
+
+    package_info = {}
+    for r in parse_requirements(package="ultralytics"):
+        try:
+            current = metadata.version(r.name)
+            is_met = "✅ " if check_version(current, str(r.specifier), name=r.name, hard=True) else "❌ "
+        except metadata.PackageNotFoundError:
+            current = "(not installed)"
+            is_met = "❌ "
+        package_info[r.name] = f"{is_met}{current}{r.specifier}"
+        LOGGER.info(f"{r.name:<20}{package_info[r.name]}")
+
+    info_dict["Package Info"] = package_info
+
+    if is_github_action_running():
+        github_info = {
+            "RUNNER_OS": os.getenv("RUNNER_OS"),
+            "GITHUB_EVENT_NAME": os.getenv("GITHUB_EVENT_NAME"),
+            "GITHUB_WORKFLOW": os.getenv("GITHUB_WORKFLOW"),
+            "GITHUB_ACTOR": os.getenv("GITHUB_ACTOR"),
+            "GITHUB_REPOSITORY": os.getenv("GITHUB_REPOSITORY"),
+            "GITHUB_REPOSITORY_OWNER": os.getenv("GITHUB_REPOSITORY_OWNER"),
+        }
+        LOGGER.info("\n" + "\n".join(f"{k}: {v}" for k, v in github_info.items()))
+        info_dict["GitHub Info"] = github_info
+
+    return info_dict
+
+
+def check_amp(model):
+    """
+    Checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLO11 model. If the checks fail, it means
+    there are anomalies with AMP on the system that may cause NaN losses or zero-mAP results, so AMP will be disabled
+    during training.
+
+    Args:
+        model (nn.Module): A YOLO11 model instance.
+
+    Example:
+        ```python
+        from ultralytics import YOLO
+        from ultralytics.utils.checks import check_amp
+
+        model = YOLO("yolo11n.pt").model.cuda()
+        check_amp(model)
+        ```
+
+    Returns:
+        (bool): Returns True if the AMP functionality works correctly with YOLO11 model, else False.
+    """
+    from ultralytics.utils.torch_utils import autocast
+
+    device = next(model.parameters()).device  # get model device
+    prefix = colorstr("AMP: ")
+    if device.type in {"cpu", "mps"}:
+        return False  # AMP only used on CUDA devices
+    else:
+        # GPUs that have issues with AMP
+        pattern = re.compile(
+            r"(nvidia|geforce|quadro|tesla).*?(1660|1650|1630|t400|t550|t600|t1000|t1200|t2000|k40m)", re.IGNORECASE
+        )
+
+        gpu = torch.cuda.get_device_name(device)
+        if bool(pattern.search(gpu)):
+            LOGGER.warning(
+                f"{prefix}checks failed ❌. AMP training on {gpu} GPU may cause "
+                f"NaN losses or zero-mAP results, so AMP will be disabled during training."
+            )
+            return False
+
+    def amp_allclose(m, im):
+        """All close FP32 vs AMP results."""
+        batch = [im] * 8
+        imgsz = max(256, int(model.stride.max() * 4))  # max stride P5-32 and P6-64
+        a = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data  # FP32 inference
+        with autocast(enabled=True):
+            b = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data  # AMP inference
+        del m
+        return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5)  # close to 0.5 absolute tolerance
+
+    im = ASSETS / "bus.jpg"  # image to check
+    LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks...")
+    warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
+    try:
+        from ultralytics import YOLO
+
+        # assert amp_allclose(YOLO("yolo11n.pt"), im)
+        assert amp_allclose(YOLO("yolov13n.pt"), im)
+        LOGGER.info(f"{prefix}checks passed ✅")
+    except ConnectionError:
+        LOGGER.warning(
+            f"{prefix}checks skipped ⚠️. Offline and unable to download YOLO11n for AMP checks. {warning_msg}"
+        )
+    except (AttributeError, ModuleNotFoundError):
+        LOGGER.warning(
+            f"{prefix}checks skipped ⚠️. "
+            f"Unable to load YOLO11n for AMP checks due to possible Ultralytics package modifications. {warning_msg}"
+        )
+    except AssertionError:
+        LOGGER.warning(
+            f"{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to "
+            f"NaN losses or zero-mAP results, so AMP will be disabled during training."
+        )
+        return False
+    return True
+
+
+def git_describe(path=ROOT):  # path must be a directory
+    """Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe."""
+    try:
+        return subprocess.check_output(f"git -C {path} describe --tags --long --always", shell=True).decode()[:-1]
+    except Exception:
+        return ""
+
+
+def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
+    """Print function arguments (optional args dict)."""
+
+    def strip_auth(v):
+        """Clean longer Ultralytics HUB URLs by stripping potential authentication information."""
+        return clean_url(v) if (isinstance(v, str) and v.startswith("http") and len(v) > 100) else v
+
+    x = inspect.currentframe().f_back  # previous frame
+    file, _, func, _, _ = inspect.getframeinfo(x)
+    if args is None:  # get args automatically
+        args, _, _, frm = inspect.getargvalues(x)
+        args = {k: v for k, v in frm.items() if k in args}
+    try:
+        file = Path(file).resolve().relative_to(ROOT).with_suffix("")
+    except ValueError:
+        file = Path(file).stem
+    s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "")
+    LOGGER.info(colorstr(s) + ", ".join(f"{k}={strip_auth(v)}" for k, v in args.items()))
+
+
+def cuda_device_count() -> int:
+    """
+    Get the number of NVIDIA GPUs available in the environment.
+
+    Returns:
+        (int): The number of NVIDIA GPUs available.
+    """
+    try:
+        # Run the nvidia-smi command and capture its output
+        output = subprocess.check_output(
+            ["nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits"], encoding="utf-8"
+        )
+
+        # Take the first line and strip any leading/trailing white space
+        first_line = output.strip().split("\n")[0]
+
+        return int(first_line)
+    except (subprocess.CalledProcessError, FileNotFoundError, ValueError):
+        # If the command fails, nvidia-smi is not found, or output is not an integer, assume no GPUs are available
+        return 0
+
+
+def cuda_is_available() -> bool:
+    """
+    Check if CUDA is available in the environment.
+
+    Returns:
+        (bool): True if one or more NVIDIA GPUs are available, False otherwise.
+    """
+    return cuda_device_count() > 0
+
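+# Minimal usage sketch (illustrative): picking a device based on the NVIDIA GPUs that nvidia-smi reports.
+# >>> device = "cuda:0" if cuda_is_available() else "cpu"
+# >>> print(f"{cuda_device_count()} GPU(s) detected, using {device}")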
+
+def is_sudo_available() -> bool:
+    """
+    Check if the sudo command is available in the environment.
+
+    Returns:
+        (bool): True if the sudo command is available, False otherwise.
+    """
+    if WINDOWS:
+        return False
+    cmd = "sudo --version"
+    return subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0
+
+
+# Run checks and define constants
+check_python("3.8", hard=False, verbose=True)  # check python version
+check_torchvision()  # check torch-torchvision compatibility
+IS_PYTHON_MINIMUM_3_10 = check_python("3.10", hard=False)
+IS_PYTHON_3_12 = PYTHON_VERSION.startswith("3.12")

+ 72 - 0
ultralytics/utils/dist.py

@@ -0,0 +1,72 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import os
+import shutil
+import socket
+import sys
+import tempfile
+
+from . import USER_CONFIG_DIR
+from .torch_utils import TORCH_1_9
+
+
+def find_free_network_port() -> int:
+    """
+    Finds a free port on localhost.
+
+    It is useful in single-node training when we don't want to connect to a real main node but have to set the
+    `MASTER_PORT` environment variable.
+    """
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("127.0.0.1", 0))
+        return s.getsockname()[1]  # port
+
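+# Minimal usage sketch (illustrative): reserving a port for single-node DDP before spawning workers.
+# >>> import os
+# >>> os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
+# >>> os.environ.setdefault("MASTER_PORT", str(find_free_network_port()))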
+
+def generate_ddp_file(trainer):
+    """Generates a DDP file and returns its file name."""
+    module, name = f"{trainer.__class__.__module__}.{trainer.__class__.__name__}".rsplit(".", 1)
+
+    content = f"""
+# Ultralytics Multi-GPU training temp file (should be automatically deleted after use)
+overrides = {vars(trainer.args)}
+
+if __name__ == "__main__":
+    from {module} import {name}
+    from ultralytics.utils import DEFAULT_CFG_DICT
+
+    cfg = DEFAULT_CFG_DICT.copy()
+    cfg.update(save_dir='')   # handle the extra key 'save_dir'
+    trainer = {name}(cfg=cfg, overrides=overrides)
+    trainer.args.model = "{getattr(trainer.hub_session, "model_url", trainer.args.model)}"
+    results = trainer.train()
+"""
+    (USER_CONFIG_DIR / "DDP").mkdir(exist_ok=True)
+    with tempfile.NamedTemporaryFile(
+        prefix="_temp_",
+        suffix=f"{id(trainer)}.py",
+        mode="w+",
+        encoding="utf-8",
+        dir=USER_CONFIG_DIR / "DDP",
+        delete=False,
+    ) as file:
+        file.write(content)
+    return file.name
+
+
+def generate_ddp_command(world_size, trainer):
+    """Generates and returns command for distributed training."""
+    import __main__  # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
+
+    if not trainer.resume:
+        shutil.rmtree(trainer.save_dir)  # remove the save_dir
+    file = generate_ddp_file(trainer)
+    dist_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch"
+    port = find_free_network_port()
+    cmd = [sys.executable, "-m", dist_cmd, "--nproc_per_node", f"{world_size}", "--master_port", f"{port}", file]
+    return cmd, file
+
+
+def ddp_cleanup(trainer, file):
+    """Delete temp file if created."""
+    if f"{id(trainer)}.py" in file:  # if temp_file suffix in file
+        os.remove(file)

+ 510 - 0
ultralytics/utils/downloads.py

@@ -0,0 +1,510 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import re
+import shutil
+import subprocess
+from itertools import repeat
+from multiprocessing.pool import ThreadPool
+from pathlib import Path
+from urllib import parse, request
+
+import requests
+import torch
+
+from ultralytics.utils import LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file
+
+# Define Ultralytics GitHub assets maintained at https://github.com/ultralytics/assets
+GITHUB_ASSETS_REPO = "ultralytics/assets"
+GITHUB_ASSETS_NAMES = (
+    [f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb", "-oiv7")]
+    + [f"yolo11{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")]
+    + [f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")]
+    + [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")]
+    + [f"yolov8{k}-world.pt" for k in "smlx"]
+    + [f"yolov8{k}-worldv2.pt" for k in "smlx"]
+    + [f"yolov9{k}.pt" for k in "tsmce"]
+    + [f"yolov10{k}.pt" for k in "nsmblx"]
+    + [f"yolo_nas_{k}.pt" for k in "sml"]
+    + [f"sam_{k}.pt" for k in "bl"]
+    + [f"FastSAM-{k}.pt" for k in "sx"]
+    + [f"rtdetr-{k}.pt" for k in "lx"]
+    + ["mobile_sam.pt"]
+    + ["calibration_image_sample_data_20x128x128x3_float32.npy.zip"]
+)
+GITHUB_ASSETS_STEMS = [Path(k).stem for k in GITHUB_ASSETS_NAMES]
+
+
+def is_url(url, check=False):
+    """
+    Validates if the given string is a URL and optionally checks if the URL exists online.
+
+    Args:
+        url (str): The string to be validated as a URL.
+        check (bool, optional): If True, performs an additional check to see if the URL exists online.
+            Defaults to False.
+
+    Returns:
+        (bool): True if the string is a valid URL. If 'check' is True, the URL must also be reachable online.
+            Returns False otherwise.
+
+    Example:
+        ```python
+        valid = is_url("https://www.example.com")
+        ```
+    """
+    try:
+        url = str(url)
+        result = parse.urlparse(url)
+        assert all([result.scheme, result.netloc])  # check if is url
+        if check:
+            with request.urlopen(url) as response:
+                return response.getcode() == 200  # check if exists online
+        return True
+    except Exception:
+        return False
+
+
+def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")):
+    """
+    Deletes all ".DS_store" files under a specified directory.
+
+    Args:
+        path (str, optional): The directory path where the ".DS_store" files should be deleted.
+        files_to_delete (tuple): The files to be deleted.
+
+    Example:
+        ```python
+        from ultralytics.utils.downloads import delete_dsstore
+
+        delete_dsstore("path/to/dir")
+        ```
+
+    Note:
+        ".DS_store" files are created by the Apple operating system and contain metadata about folders and files. They
+        are hidden system files and can cause issues when transferring files between different operating systems.
+    """
+    for file in files_to_delete:
+        matches = list(Path(path).rglob(file))
+        LOGGER.info(f"Deleting {file} files: {matches}")
+        for f in matches:
+            f.unlink()
+
+
+def zip_directory(directory, compress=True, exclude=(".DS_Store", "__MACOSX"), progress=True):
+    """
+    Zips the contents of a directory, excluding files containing strings in the exclude list. The resulting zip file is
+    named after the directory and placed alongside it.
+
+    Args:
+        directory (str | Path): The path to the directory to be zipped.
+        compress (bool): Whether to compress the files while zipping. Default is True.
+        exclude (tuple, optional): A tuple of filename strings to be excluded. Defaults to ('.DS_Store', '__MACOSX').
+        progress (bool, optional): Whether to display a progress bar. Defaults to True.
+
+    Returns:
+        (Path): The path to the resulting zip file.
+
+    Example:
+        ```python
+        from ultralytics.utils.downloads import zip_directory
+
+        file = zip_directory("path/to/dir")
+        ```
+    """
+    from zipfile import ZIP_DEFLATED, ZIP_STORED, ZipFile
+
+    delete_dsstore(directory)
+    directory = Path(directory)
+    if not directory.is_dir():
+        raise FileNotFoundError(f"Directory '{directory}' does not exist.")
+
+    # Zip files with progress bar
+    files_to_zip = [f for f in directory.rglob("*") if f.is_file() and all(x not in f.name for x in exclude)]
+    zip_file = directory.with_suffix(".zip")
+    compression = ZIP_DEFLATED if compress else ZIP_STORED
+    with ZipFile(zip_file, "w", compression) as f:
+        for file in TQDM(files_to_zip, desc=f"Zipping {directory} to {zip_file}...", unit="file", disable=not progress):
+            f.write(file, file.relative_to(directory))
+
+    return zip_file  # return path to zip file
+
+
+def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=False, progress=True):
+    """
+    Unzips a *.zip file to the specified path, excluding files containing strings in the exclude list.
+
+    If the zipfile does not contain a single top-level directory, the function will create a new
+    directory with the same name as the zipfile (without the extension) to extract its contents.
+    If a path is not provided, the function will use the parent directory of the zipfile as the default path.
+
+    Args:
+        file (str | Path): The path to the zipfile to be extracted.
+        path (str, optional): The path to extract the zipfile to. Defaults to None.
+        exclude (tuple, optional): A tuple of filename strings to be excluded. Defaults to ('.DS_Store', '__MACOSX').
+        exist_ok (bool, optional): Whether to overwrite existing contents if they exist. Defaults to False.
+        progress (bool, optional): Whether to display a progress bar. Defaults to True.
+
+    Raises:
+        BadZipFile: If the provided file does not exist or is not a valid zipfile.
+
+    Returns:
+        (Path): The path to the directory where the zipfile was extracted.
+
+    Example:
+        ```python
+        from ultralytics.utils.downloads import unzip_file
+
+        dir = unzip_file("path/to/file.zip")
+        ```
+    """
+    from zipfile import BadZipFile, ZipFile, is_zipfile
+
+    if not (Path(file).exists() and is_zipfile(file)):
+        raise BadZipFile(f"File '{file}' does not exist or is a bad zip file.")
+    if path is None:
+        path = Path(file).parent  # default path
+
+    # Unzip the file contents
+    with ZipFile(file) as zipObj:
+        files = [f for f in zipObj.namelist() if all(x not in f for x in exclude)]
+        top_level_dirs = {Path(f).parts[0] for f in files}
+
+        # Decide to unzip directly or unzip into a directory
+        unzip_as_dir = len(top_level_dirs) == 1  # (len(files) > 1 and not files[0].endswith("/"))
+        if unzip_as_dir:
+            # Zip has 1 top-level directory
+            extract_path = path  # i.e. ../datasets
+            path = Path(path) / list(top_level_dirs)[0]  # i.e. extract coco8/ dir to ../datasets/
+        else:
+            # Zip has multiple files at top level
+            path = extract_path = Path(path) / Path(file).stem  # i.e. extract multiple files to ../datasets/coco8/
+
+        # Check if destination directory already exists and contains files
+        if path.exists() and any(path.iterdir()) and not exist_ok:
+            # If it exists and is not empty, return the path without unzipping
+            LOGGER.warning(f"WARNING ⚠️ Skipping {file} unzip as destination directory {path} is not empty.")
+            return path
+
+        for f in TQDM(files, desc=f"Unzipping {file} to {Path(path).resolve()}...", unit="file", disable=not progress):
+            # Ensure the file is within the extract_path to avoid path traversal security vulnerability
+            if ".." in Path(f).parts:
+                LOGGER.warning(f"Potentially insecure file path: {f}, skipping extraction.")
+                continue
+            zipObj.extract(f, extract_path)
+
+    return path  # return unzip dir
+
+
+def check_disk_space(url="https://ultralytics.com/assets/coco8.zip", path=Path.cwd(), sf=1.5, hard=True):
+    """
+    Check if there is sufficient disk space to download and store a file.
+
+    Args:
+        url (str, optional): The URL to the file. Defaults to 'https://ultralytics.com/assets/coco8.zip'.
+        path (str | Path, optional): The path or drive to check the available free space on.
+        sf (float, optional): Safety factor, the multiplier for the required free space. Defaults to 1.5.
+        hard (bool, optional): Whether to throw an error or not on insufficient disk space. Defaults to True.
+
+    Returns:
+        (bool): True if there is sufficient disk space, False otherwise.
+    """
+    try:
+        r = requests.head(url)  # response
+        assert r.status_code < 400, f"URL error for {url}: {r.status_code} {r.reason}"  # check response
+    except Exception:
+        return True  # requests issue, default to True
+
+    # Check file size
+    gib = 1 << 30  # bytes per GiB
+    data = int(r.headers.get("Content-Length", 0)) / gib  # file size (GiB)
+    total, used, free = (x / gib for x in shutil.disk_usage(path))  # disk usage (GiB)
+
+    if data * sf < free:
+        return True  # sufficient space
+
+    # Insufficient space
+    text = (
+        f"WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required. "
+        f"Please free {data * sf - free:.1f} GB of additional disk space and try again."
+    )
+    if hard:
+        raise MemoryError(text)
+    LOGGER.warning(text)
+    return False
+
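+# Minimal usage sketch (illustrative): soft-fail check before fetching a large dataset archive.
+# >>> if not check_disk_space("https://ultralytics.com/assets/coco8.zip", hard=False):
+# ...     LOGGER.warning("Insufficient disk space, skipping download")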
+
+def get_google_drive_file_info(link):
+    """
+    Retrieves the direct download link and filename for a shareable Google Drive file link.
+
+    Args:
+        link (str): The shareable link of the Google Drive file.
+
+    Returns:
+        (str): Direct download URL for the Google Drive file.
+        (str): Original filename of the Google Drive file. If filename extraction fails, returns None.
+
+    Example:
+        ```python
+        from ultralytics.utils.downloads import get_google_drive_file_info
+
+        link = "https://drive.google.com/file/d/1cqT-cJgANNrhIHCrEufUYhQ4RqiWG_lJ/view?usp=drive_link"
+        url, filename = get_google_drive_file_info(link)
+        ```
+    """
+    file_id = link.split("/d/")[1].split("/view")[0]
+    drive_url = f"https://drive.google.com/uc?export=download&id={file_id}"
+    filename = None
+
+    # Start session
+    with requests.Session() as session:
+        response = session.get(drive_url, stream=True)
+        if "quota exceeded" in str(response.content.lower()):
+            raise ConnectionError(
+                emojis(
+                    f"❌  Google Drive file download quota exceeded. "
+                    f"Please try again later or download this file manually at {link}."
+                )
+            )
+        for k, v in response.cookies.items():
+            if k.startswith("download_warning"):
+                drive_url += f"&confirm={v}"  # v is token
+        if cd := response.headers.get("content-disposition"):
+            filename = re.findall('filename="(.+)"', cd)[0]
+    return drive_url, filename
+
+
+def safe_download(
+    url,
+    file=None,
+    dir=None,
+    unzip=True,
+    delete=False,
+    curl=False,
+    retry=3,
+    min_bytes=1e0,
+    exist_ok=False,
+    progress=True,
+):
+    """
+    Downloads files from a URL, with options for retrying, unzipping, and deleting the downloaded file.
+
+    Args:
+        url (str): The URL of the file to be downloaded.
+        file (str, optional): The filename of the downloaded file.
+            If not provided, the file will be saved with the same name as the URL.
+        dir (str, optional): The directory to save the downloaded file.
+            If not provided, the file will be saved in the current working directory.
+        unzip (bool, optional): Whether to unzip the downloaded file. Default: True.
+        delete (bool, optional): Whether to delete the downloaded file after unzipping. Default: False.
+        curl (bool, optional): Whether to use curl command line tool for downloading. Default: False.
+        retry (int, optional): The number of times to retry the download in case of failure. Default: 3.
+        min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be considered
+            a successful download. Default: 1E0.
+        exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False.
+        progress (bool, optional): Whether to display a progress bar during the download. Default: True.
+
+    Example:
+        ```python
+        from ultralytics.utils.downloads import safe_download
+
+        link = "https://ultralytics.com/assets/bus.jpg"
+        path = safe_download(link)
+        ```
+    """
+    gdrive = url.startswith("https://drive.google.com/")  # check if the URL is a Google Drive link
+    if gdrive:
+        url, file = get_google_drive_file_info(url)
+
+    f = Path(dir or ".") / (file or url2file(url))  # URL converted to filename
+    if "://" not in str(url) and Path(url).is_file():  # URL exists ('://' check required in Windows Python<3.10)
+        f = Path(url)  # filename
+    elif not f.is_file():  # URL and file do not exist
+        uri = (url if gdrive else clean_url(url)).replace(  # cleaned and aliased url
+            "https://github.com/ultralytics/assets/releases/download/v0.0.0/",
+            "https://ultralytics.com/assets/",  # assets alias
+        )
+        desc = f"Downloading {uri} to '{f}'"
+        LOGGER.info(f"{desc}...")
+        f.parent.mkdir(parents=True, exist_ok=True)  # make directory if missing
+        check_disk_space(url, path=f.parent)
+        for i in range(retry + 1):
+            try:
+                if curl or i > 0:  # curl download with retry, continue
+                    s = "sS" * (not progress)  # silent
+                    r = subprocess.run(["curl", "-#", f"-{s}L", url, "-o", f, "--retry", "3", "-C", "-"]).returncode
+                    assert r == 0, f"Curl return value {r}"
+                else:  # urllib download
+                    method = "torch"
+                    if method == "torch":
+                        torch.hub.download_url_to_file(url, f, progress=progress)
+                    else:
+                        with request.urlopen(url) as response, TQDM(
+                            total=int(response.getheader("Content-Length", 0)),
+                            desc=desc,
+                            disable=not progress,
+                            unit="B",
+                            unit_scale=True,
+                            unit_divisor=1024,
+                        ) as pbar:
+                            with open(f, "wb") as f_opened:
+                                for data in response:
+                                    f_opened.write(data)
+                                    pbar.update(len(data))
+
+                if f.exists():
+                    if f.stat().st_size > min_bytes:
+                        break  # success
+                    f.unlink()  # remove partial downloads
+            except Exception as e:
+                if i == 0 and not is_online():
+                    raise ConnectionError(emojis(f"❌  Download failure for {uri}. Environment is not online.")) from e
+                elif i >= retry:
+                    raise ConnectionError(emojis(f"❌  Download failure for {uri}. Retry limit reached.")) from e
+                LOGGER.warning(f"⚠️ Download failure, retrying {i + 1}/{retry} {uri}...")
+
+    if unzip and f.exists() and f.suffix in {"", ".zip", ".tar", ".gz"}:
+        from zipfile import is_zipfile
+
+        unzip_dir = (dir or f.parent).resolve()  # unzip to dir if provided else unzip in place
+        if is_zipfile(f):
+            unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress)  # unzip
+        elif f.suffix in {".tar", ".gz"}:
+            LOGGER.info(f"Unzipping {f} to {unzip_dir}...")
+            subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True)
+        if delete:
+            f.unlink()  # remove zip
+        return unzip_dir
+
+
+def get_github_assets(repo="ultralytics/assets", version="latest", retry=False):
+    """
+    Retrieve the specified version's tag and assets from a GitHub repository. If the version is not specified, the
+    function fetches the latest release assets.
+
+    Args:
+        repo (str, optional): The GitHub repository in the format 'owner/repo'. Defaults to 'ultralytics/assets'.
+        version (str, optional): The release version to fetch assets from. Defaults to 'latest'.
+        retry (bool, optional): Flag to retry the request in case of a failure. Defaults to False.
+
+    Returns:
+        (tuple): A tuple containing the release tag and a list of asset names.
+
+    Example:
+        ```python
+        tag, assets = get_github_assets(repo="ultralytics/assets", version="latest")
+        ```
+    """
+    if version != "latest":
+        version = f"tags/{version}"  # i.e. tags/v6.2
+    url = f"https://api.github.com/repos/{repo}/releases/{version}"
+    r = requests.get(url)  # github api
+    if r.status_code != 200 and r.reason != "rate limit exceeded" and retry:  # failed and not 403 rate limit exceeded
+        r = requests.get(url)  # try again
+    if r.status_code != 200:
+        LOGGER.warning(f"⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}")
+        return "", []
+    data = r.json()
+    return data["tag_name"], [x["name"] for x in data["assets"]]  # tag, assets i.e. ['yolov8n.pt', 'yolov8s.pt', ...]
+
+
+def attempt_download_asset(file, repo="ultralytics/assets", release="v8.3.0", **kwargs):
+    """
+    Attempt to download a file from GitHub release assets if it is not found locally. The function checks for the file
+    locally first, then tries to download it from the specified GitHub repository release.
+
+    Args:
+        file (str | Path): The filename or file path to be downloaded.
+        repo (str, optional): The GitHub repository in the format 'owner/repo'. Defaults to 'ultralytics/assets'.
+        release (str, optional): The specific release version to be downloaded. Defaults to 'v8.3.0'.
+        **kwargs (any): Additional keyword arguments for the download process.
+
+    Returns:
+        (str): The path to the downloaded file.
+
+    Example:
+        ```python
+        file_path = attempt_download_asset("yolo11n.pt", repo="ultralytics/assets", release="latest")
+        ```
+    """
+    from ultralytics.utils import SETTINGS  # scoped for circular import
+
+    if "v13" in str(file):
+        repo = "iMoonLab/yolov13"
+        release = "yolov13"
+
+    # YOLOv3/5u updates
+    file = str(file)
+    file = checks.check_yolov5u_filename(file)
+    file = Path(file.strip().replace("'", ""))
+    if file.exists():
+        return str(file)
+    elif (SETTINGS["weights_dir"] / file).exists():
+        return str(SETTINGS["weights_dir"] / file)
+    else:
+        # URL specified
+        name = Path(parse.unquote(str(file))).name  # decode '%2F' to '/' etc.
+        download_url = f"https://github.com/{repo}/releases/download"
+        if str(file).startswith(("http:/", "https:/")):  # download
+            url = str(file).replace(":/", "://")  # Pathlib turns :// -> :/
+            file = url2file(name)  # parse authentication https://url.com/file.txt?auth...
+            if Path(file).is_file():
+                LOGGER.info(f"Found {clean_url(url)} locally at {file}")  # file already exists
+            else:
+                safe_download(url=url, file=file, min_bytes=1e5, **kwargs)
+
+        elif repo == GITHUB_ASSETS_REPO and name in GITHUB_ASSETS_NAMES:
+            safe_download(url=f"{download_url}/{release}/{name}", file=file, min_bytes=1e5, **kwargs)
+
+        else:
+            tag, assets = get_github_assets(repo, release)
+            if not assets:
+                tag, assets = get_github_assets(repo)  # latest release
+            if name in assets:
+                safe_download(url=f"{download_url}/{tag}/{name}", file=file, min_bytes=1e5, **kwargs)
+
+        return str(file)
+
+
+def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3, exist_ok=False):
+    """
+    Downloads files from specified URLs to a given directory. Supports concurrent downloads if multiple threads are
+    specified.
+
+    Args:
+        url (str | list): The URL or list of URLs of the files to be downloaded.
+        dir (Path, optional): The directory where the files will be saved. Defaults to the current working directory.
+        unzip (bool, optional): Flag to unzip the files after downloading. Defaults to True.
+        delete (bool, optional): Flag to delete the zip files after extraction. Defaults to False.
+        curl (bool, optional): Flag to use curl for downloading. Defaults to False.
+        threads (int, optional): Number of threads to use for concurrent downloads. Defaults to 1.
+        retry (int, optional): Number of retries in case of download failure. Defaults to 3.
+        exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False.
+
+    Example:
+        ```python
+        download("https://ultralytics.com/assets/example.zip", dir="path/to/dir", unzip=True)
+        ```
+    """
+    dir = Path(dir)
+    dir.mkdir(parents=True, exist_ok=True)  # make directory
+    if threads > 1:
+        with ThreadPool(threads) as pool:
+            pool.map(
+                lambda x: safe_download(
+                    url=x[0],
+                    dir=x[1],
+                    unzip=unzip,
+                    delete=delete,
+                    curl=curl,
+                    retry=retry,
+                    exist_ok=exist_ok,
+                    progress=threads <= 1,
+                ),
+                zip(url, repeat(dir)),
+            )
+            pool.close()
+            pool.join()
+    else:
+        for u in [url] if isinstance(url, (str, Path)) else url:
+            safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry, exist_ok=exist_ok)
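+
+
+# Minimal usage sketch (illustrative URLs): concurrent download of several archives using 3 threads.
+# >>> urls = ["https://ultralytics.com/assets/coco8.zip", "https://ultralytics.com/assets/coco8-seg.zip"]
+# >>> download(urls, dir="datasets", threads=3)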

+ 22 - 0
ultralytics/utils/errors.py

@@ -0,0 +1,22 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from ultralytics.utils import emojis
+
+
+class HUBModelError(Exception):
+    """
+    Custom exception class for handling errors related to model fetching in Ultralytics YOLO.
+
+    This exception is raised when a requested model is not found or cannot be retrieved.
+    The message is also processed to include emojis for better user experience.
+
+    Attributes:
+        message (str): The error message displayed when the exception is raised.
+
+    Note:
+        The message is automatically processed through the 'emojis' function from the 'ultralytics.utils' package.
+    """
+
+    def __init__(self, message="Model not found. Please check model URL and try again."):
+        """Create an exception for when a model is not found."""
+        super().__init__(emojis(message))
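+
+
+# Minimal usage sketch (illustrative): raising the error with a custom message.
+# >>> raise HUBModelError("Model 'missing.pt' not found. Please check the model URL and try again.")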

+ 222 - 0
ultralytics/utils/files.py

@@ -0,0 +1,222 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import contextlib
+import glob
+import os
+import shutil
+import tempfile
+from contextlib import contextmanager
+from datetime import datetime
+from pathlib import Path
+
+
+class WorkingDirectory(contextlib.ContextDecorator):
+    """
+    A context manager and decorator for temporarily changing the working directory.
+
+    This class allows for the temporary change of the working directory using a context manager or decorator.
+    It ensures that the original working directory is restored after the context or decorated function completes.
+
+    Attributes:
+        dir (Path): The new directory to switch to.
+        cwd (Path): The original current working directory before the switch.
+
+    Methods:
+        __enter__: Changes the current directory to the specified directory.
+        __exit__: Restores the original working directory on context exit.
+
+    Examples:
+        Using as a context manager:
+        >>> with WorkingDirectory('/path/to/new/dir'):
+        ...     # Perform operations in the new directory
+        ...     pass
+
+        Using as a decorator:
+        >>> @WorkingDirectory('/path/to/new/dir')
+        ... def some_function():
+        ...     # Perform operations in the new directory
+        ...     pass
+    """
+
+    def __init__(self, new_dir):
+        """Sets the working directory to 'new_dir' upon instantiation for use with context managers or decorators."""
+        self.dir = new_dir  # new dir
+        self.cwd = Path.cwd().resolve()  # current dir
+
+    def __enter__(self):
+        """Changes the current working directory to the specified directory upon entering the context."""
+        os.chdir(self.dir)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):  # noqa
+        """Restores the original working directory when exiting the context."""
+        os.chdir(self.cwd)
+
+
+@contextmanager
+def spaces_in_path(path):
+    """
+    Context manager to handle paths with spaces in their names. If a path contains spaces, it replaces them with
+    underscores, copies the file/directory to the new path, executes the context code block, then copies the
+    file/directory back to its original location.
+
+    Args:
+        path (str | Path): The original path that may contain spaces.
+
+    Yields:
+        (Path): Temporary path with spaces replaced by underscores if spaces were present, otherwise the original path.
+
+    Examples:
+        Use the context manager to handle paths with spaces:
+        >>> from ultralytics.utils.files import spaces_in_path
+        >>> with spaces_in_path('/path/with spaces') as new_path:
+        ...     # Your code here
+        ...     pass
+    """
+    # If path has spaces, replace them with underscores
+    if " " in str(path):
+        string = isinstance(path, str)  # input type
+        path = Path(path)
+
+        # Create a temporary directory and construct the new path
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            tmp_path = Path(tmp_dir) / path.name.replace(" ", "_")
+
+            # Copy file/directory
+            if path.is_dir():
+                # tmp_path.mkdir(parents=True, exist_ok=True)
+                shutil.copytree(path, tmp_path)
+            elif path.is_file():
+                tmp_path.parent.mkdir(parents=True, exist_ok=True)
+                shutil.copy2(path, tmp_path)
+
+            try:
+                # Yield the temporary path
+                yield str(tmp_path) if string else tmp_path
+
+            finally:
+                # Copy file/directory back
+                if tmp_path.is_dir():
+                    shutil.copytree(tmp_path, path, dirs_exist_ok=True)
+                elif tmp_path.is_file():
+                    shutil.copy2(tmp_path, path)  # Copy back the file
+
+    else:
+        # If there are no spaces, just yield the original path
+        yield path
+
+
+def increment_path(path, exist_ok=False, sep="", mkdir=False):
+    """
+    Increments a file or directory path, i.e., runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
+
+    If the path exists and `exist_ok` is not True, the path will be incremented by appending a number and `sep` to
+    the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the
+    number will be appended directly to the end of the path. If `mkdir` is set to True, the path will be created as a
+    directory if it does not already exist.
+
+    Args:
+        path (str | pathlib.Path): Path to increment.
+        exist_ok (bool): If True, the path will not be incremented and returned as-is.
+        sep (str): Separator to use between the path and the incrementation number.
+        mkdir (bool): Create a directory if it does not exist.
+
+    Returns:
+        (pathlib.Path): Incremented path.
+
+    Examples:
+        Increment a directory path:
+        >>> from pathlib import Path
+        >>> path = Path("runs/exp")
+        >>> new_path = increment_path(path)
+        >>> print(new_path)
+        runs/exp2
+
+        Increment a file path:
+        >>> path = Path("runs/exp/results.txt")
+        >>> new_path = increment_path(path)
+        >>> print(new_path)
+        runs/exp/results2.txt
+    """
+    path = Path(path)  # os-agnostic
+    if path.exists() and not exist_ok:
+        path, suffix = (path.with_suffix(""), path.suffix) if path.is_file() else (path, "")
+
+        # Method 1
+        for n in range(2, 9999):
+            p = f"{path}{sep}{n}{suffix}"  # increment path
+            if not os.path.exists(p):
+                break
+        path = Path(p)
+
+    if mkdir:
+        path.mkdir(parents=True, exist_ok=True)  # make directory
+
+    return path
+
+
+def file_age(path=__file__):
+    """Return days since the last modification of the specified file."""
+    dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)  # delta
+    return dt.days  # + dt.seconds / 86400  # fractional days
+
+
+def file_date(path=__file__):
+    """Returns the file modification date in 'YYYY-M-D' format."""
+    t = datetime.fromtimestamp(Path(path).stat().st_mtime)
+    return f"{t.year}-{t.month}-{t.day}"
+
+
+def file_size(path):
+    """Returns the size of a file or directory in megabytes (MB)."""
+    if isinstance(path, (str, Path)):
+        mb = 1 << 20  # bytes to MiB (1024 ** 2)
+        path = Path(path)
+        if path.is_file():
+            return path.stat().st_size / mb
+        elif path.is_dir():
+            return sum(f.stat().st_size for f in path.glob("**/*") if f.is_file()) / mb
+    return 0.0
+
+
+def get_latest_run(search_dir="."):
+    """Returns the path to the most recent 'last.pt' file in the specified directory for resuming training."""
+    last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True)
+    return max(last_list, key=os.path.getctime) if last_list else ""
+
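+# Minimal usage sketch (illustrative): resume from the most recent checkpoint when one exists.
+# >>> last = get_latest_run("runs/detect")
+# >>> if last:
+# ...     print(f"Resuming {last} ({file_size(last):.1f} MB, modified {file_age(last)} days ago)")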
+
+def update_models(model_names=("yolo11n.pt",), source_dir=Path("."), update_names=False):
+    """
+    Updates and re-saves specified YOLO models in an 'updated_models' subdirectory.
+
+    Args:
+        model_names (Tuple[str, ...]): Model filenames to update.
+        source_dir (Path): Directory containing models and target subdirectory.
+        update_names (bool): Update model names from a data YAML.
+
+    Examples:
+        Update specified YOLO models and save them in 'updated_models' subdirectory:
+        >>> from ultralytics.utils.files import update_models
+        >>> model_names = ("yolo11n.pt", "yolov8s.pt")
+        >>> update_models(model_names, source_dir=Path("/models"), update_names=True)
+    """
+    from ultralytics import YOLO
+    from ultralytics.nn.autobackend import default_class_names
+
+    target_dir = source_dir / "updated_models"
+    target_dir.mkdir(parents=True, exist_ok=True)  # Ensure target directory exists
+
+    for model_name in model_names:
+        model_path = source_dir / model_name
+        print(f"Loading model from {model_path}")
+
+        # Load model
+        model = YOLO(model_path)
+        model.half()
+        if update_names:  # update model names from a dataset YAML
+            model.model.names = default_class_names("coco8.yaml")
+
+        # Define new save path
+        save_path = target_dir / model_name
+
+        # Save model using model.save()
+        print(f"Re-saving {model_name} model to {save_path}")
+        model.save(save_path)

+ 429 - 0
ultralytics/utils/instance.py

@@ -0,0 +1,429 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from collections import abc
+from itertools import repeat
+from numbers import Number
+from typing import List
+
+import numpy as np
+
+from .ops import ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh
+
+
+def _ntuple(n):
+    """From PyTorch internals."""
+
+    def parse(x):
+        """Parse bounding boxes format between XYWH and LTWH."""
+        return x if isinstance(x, abc.Iterable) else tuple(repeat(x, n))
+
+    return parse
+
+
+to_2tuple = _ntuple(2)
+to_4tuple = _ntuple(4)
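+# e.g. to_4tuple(2) -> (2, 2, 2, 2); iterable inputs such as (1, 2, 3, 4) are returned unchanged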
+
+# `xyxy` means left-top and right-bottom corners
+# `xywh` means center x, center y, width, height (YOLO format)
+# `ltwh` means left-top corner, width, height (COCO format)
+_formats = ["xyxy", "xywh", "ltwh"]
+
+__all__ = ("Bboxes", "Instances")  # tuple or list
+
+
+class Bboxes:
+    """
+    A class for handling bounding boxes.
+
+    The class supports various bounding box formats like 'xyxy', 'xywh', and 'ltwh'.
+    Bounding box data should be provided in numpy arrays.
+
+    Attributes:
+        bboxes (numpy.ndarray): The bounding boxes stored in a 2D numpy array.
+        format (str): The format of the bounding boxes ('xyxy', 'xywh', or 'ltwh').
+
+    Note:
+        This class does not handle normalization or denormalization of bounding boxes.
+    """
+
+    def __init__(self, bboxes, format="xyxy") -> None:
+        """Initializes the Bboxes class with bounding box data in a specified format."""
+        assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
+        bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
+        assert bboxes.ndim == 2
+        assert bboxes.shape[1] == 4
+        self.bboxes = bboxes
+        self.format = format
+        # self.normalized = normalized
+
+    def convert(self, format):
+        """Converts bounding box format from one type to another."""
+        assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
+        if self.format == format:
+            return
+        elif self.format == "xyxy":
+            func = xyxy2xywh if format == "xywh" else xyxy2ltwh
+        elif self.format == "xywh":
+            func = xywh2xyxy if format == "xyxy" else xywh2ltwh
+        else:
+            func = ltwh2xyxy if format == "xyxy" else ltwh2xywh
+        self.bboxes = func(self.bboxes)
+        self.format = format
+
+    def areas(self):
+        """Return box areas."""
+        return (
+            (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])  # format xyxy
+            if self.format == "xyxy"
+            else self.bboxes[:, 3] * self.bboxes[:, 2]  # format xywh or ltwh
+        )
+
+    # def denormalize(self, w, h):
+    #    if not self.normalized:
+    #         return
+    #     assert (self.bboxes <= 1.0).all()
+    #     self.bboxes[:, 0::2] *= w
+    #     self.bboxes[:, 1::2] *= h
+    #     self.normalized = False
+    #
+    # def normalize(self, w, h):
+    #     if self.normalized:
+    #         return
+    #     assert (self.bboxes > 1.0).any()
+    #     self.bboxes[:, 0::2] /= w
+    #     self.bboxes[:, 1::2] /= h
+    #     self.normalized = True
+
+    def mul(self, scale):
+        """
+        Multiply bounding box coordinates by scale factor(s).
+
+        Args:
+            scale (int | tuple | list): Scale factor(s) for four coordinates.
+                If int, the same scale is applied to all coordinates.
+        """
+        if isinstance(scale, Number):
+            scale = to_4tuple(scale)
+        assert isinstance(scale, (tuple, list))
+        assert len(scale) == 4
+        self.bboxes[:, 0] *= scale[0]
+        self.bboxes[:, 1] *= scale[1]
+        self.bboxes[:, 2] *= scale[2]
+        self.bboxes[:, 3] *= scale[3]
+
+    def add(self, offset):
+        """
+        Add offset to bounding box coordinates.
+
+        Args:
+            offset (int | tuple | list): Offset(s) for four coordinates.
+                If int, the same offset is applied to all coordinates.
+        """
+        if isinstance(offset, Number):
+            offset = to_4tuple(offset)
+        assert isinstance(offset, (tuple, list))
+        assert len(offset) == 4
+        self.bboxes[:, 0] += offset[0]
+        self.bboxes[:, 1] += offset[1]
+        self.bboxes[:, 2] += offset[2]
+        self.bboxes[:, 3] += offset[3]
+
+    def __len__(self):
+        """Return the number of boxes."""
+        return len(self.bboxes)
+
+    @classmethod
+    def concatenate(cls, boxes_list: List["Bboxes"], axis=0) -> "Bboxes":
+        """
+        Concatenate a list of Bboxes objects into a single Bboxes object.
+
+        Args:
+            boxes_list (List[Bboxes]): A list of Bboxes objects to concatenate.
+            axis (int, optional): The axis along which to concatenate the bounding boxes.
+                                   Defaults to 0.
+
+        Returns:
+            Bboxes: A new Bboxes object containing the concatenated bounding boxes.
+
+        Note:
+            The input should be a list or tuple of Bboxes objects.
+        """
+        assert isinstance(boxes_list, (list, tuple))
+        if not boxes_list:
+            return cls(np.empty(0))
+        assert all(isinstance(box, Bboxes) for box in boxes_list)
+
+        if len(boxes_list) == 1:
+            return boxes_list[0]
+        return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))
+
+    def __getitem__(self, index) -> "Bboxes":
+        """
+        Retrieve a specific bounding box or a set of bounding boxes using indexing.
+
+        Args:
+            index (int, slice, or np.ndarray): The index, slice, or boolean array to select
+                                               the desired bounding boxes.
+
+        Returns:
+            Bboxes: A new Bboxes object containing the selected bounding boxes.
+
+        Raises:
+            AssertionError: If the indexed bounding boxes do not form a 2-dimensional matrix.
+
+        Note:
+            When using boolean indexing, make sure to provide a boolean array with the same
+            length as the number of bounding boxes.
+        """
+        if isinstance(index, int):
+            return Bboxes(self.bboxes[index].reshape(1, -1))
+        b = self.bboxes[index]
+        assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!"
+        return Bboxes(b)
+
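+# Minimal usage sketch (illustrative values): converting formats and scaling boxes.
+# >>> boxes = Bboxes(np.array([[10.0, 10.0, 50.0, 80.0]]), format="xyxy")
+# >>> boxes.convert("xywh")  # now (cx, cy, w, h) = (30, 45, 40, 70)
+# >>> boxes.mul(2)  # scale all four coordinates by a factor of 2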
+
+class Instances:
+    """
+    Container for bounding boxes, segments, and keypoints of detected objects in an image.
+
+    Attributes:
+        _bboxes (Bboxes): Internal object for handling bounding box operations.
+        keypoints (ndarray): Keypoints in (x, y, visible) format with shape [N, 17, 3]. Default is None.
+        normalized (bool): Flag indicating whether the bounding box coordinates are normalized.
+        segments (ndarray): Segments array with shape [N, 1000, 2] after resampling.
+
+    Args:
+        bboxes (ndarray): An array of bounding boxes with shape [N, 4].
+        segments (list | ndarray, optional): A list or array of object segments. Default is None.
+        keypoints (ndarray, optional): An array of keypoints with shape [N, 17, 3]. Default is None.
+        bbox_format (str, optional): The format of bounding boxes ('xywh' or 'xyxy'). Default is 'xywh'.
+        normalized (bool, optional): Whether the bounding box coordinates are normalized. Default is True.
+
+    Examples:
+        ```python
+        # Create an Instances object
+        instances = Instances(
+            bboxes=np.array([[10, 10, 30, 30], [20, 20, 40, 40]]),
+            segments=[np.array([[5, 5], [10, 10]]), np.array([[15, 15], [20, 20]])],
+            keypoints=np.array([[[5, 5, 1], [10, 10, 1]], [[15, 15, 1], [20, 20, 1]]]),
+        )
+        ```
+
+    Note:
+        The bounding box format is either 'xywh' or 'xyxy', and is determined by the `bbox_format` argument.
+        This class does not perform input validation, and it assumes the inputs are well-formed.
+    """
+
+    def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
+        """
+        Initialize the object with bounding boxes, segments, and keypoints.
+
+        Args:
+            bboxes (np.ndarray): Bounding boxes, shape [N, 4].
+            segments (list | np.ndarray, optional): Segmentation masks. Defaults to None.
+            keypoints (np.ndarray, optional): Keypoints, shape [N, 17, 3] and format (x, y, visible). Defaults to None.
+            bbox_format (str, optional): Format of bboxes. Defaults to "xywh".
+            normalized (bool, optional): Whether the coordinates are normalized. Defaults to True.
+        """
+        self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)
+        self.keypoints = keypoints
+        self.normalized = normalized
+        self.segments = segments
+
+    def convert_bbox(self, format):
+        """Convert bounding box format."""
+        self._bboxes.convert(format=format)
+
+    @property
+    def bbox_areas(self):
+        """Calculate the area of bounding boxes."""
+        return self._bboxes.areas()
+
+    def scale(self, scale_w, scale_h, bbox_only=False):
+        """Similar to denormalize func but without normalized sign."""
+        self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
+        if bbox_only:
+            return
+        self.segments[..., 0] *= scale_w
+        self.segments[..., 1] *= scale_h
+        if self.keypoints is not None:
+            self.keypoints[..., 0] *= scale_w
+            self.keypoints[..., 1] *= scale_h
+
+    def denormalize(self, w, h):
+        """Denormalizes boxes, segments, and keypoints from normalized coordinates."""
+        if not self.normalized:
+            return
+        self._bboxes.mul(scale=(w, h, w, h))
+        self.segments[..., 0] *= w
+        self.segments[..., 1] *= h
+        if self.keypoints is not None:
+            self.keypoints[..., 0] *= w
+            self.keypoints[..., 1] *= h
+        self.normalized = False
+
+    def normalize(self, w, h):
+        """Normalize bounding boxes, segments, and keypoints to image dimensions."""
+        if self.normalized:
+            return
+        self._bboxes.mul(scale=(1 / w, 1 / h, 1 / w, 1 / h))
+        self.segments[..., 0] /= w
+        self.segments[..., 1] /= h
+        if self.keypoints is not None:
+            self.keypoints[..., 0] /= w
+            self.keypoints[..., 1] /= h
+        self.normalized = True
+
+    def add_padding(self, padw, padh):
+        """Handle rect and mosaic situation."""
+        assert not self.normalized, "you should add padding with absolute coordinates."
+        self._bboxes.add(offset=(padw, padh, padw, padh))
+        self.segments[..., 0] += padw
+        self.segments[..., 1] += padh
+        if self.keypoints is not None:
+            self.keypoints[..., 0] += padw
+            self.keypoints[..., 1] += padh
+
+    def __getitem__(self, index) -> "Instances":
+        """
+        Retrieve a specific instance or a set of instances using indexing.
+
+        Args:
+            index (int, slice, or np.ndarray): The index, slice, or boolean array to select
+                                               the desired instances.
+
+        Returns:
+            Instances: A new Instances object containing the selected bounding boxes,
+                       segments, and keypoints if present.
+
+        Note:
+            When using boolean indexing, make sure to provide a boolean array with the same
+            length as the number of instances.
+        """
+        segments = self.segments[index] if len(self.segments) else self.segments
+        keypoints = self.keypoints[index] if self.keypoints is not None else None
+        bboxes = self.bboxes[index]
+        bbox_format = self._bboxes.format
+        return Instances(
+            bboxes=bboxes,
+            segments=segments,
+            keypoints=keypoints,
+            bbox_format=bbox_format,
+            normalized=self.normalized,
+        )
+
+    def flipud(self, h):
+        """Flips the coordinates of bounding boxes, segments, and keypoints vertically."""
+        if self._bboxes.format == "xyxy":
+            y1 = self.bboxes[:, 1].copy()
+            y2 = self.bboxes[:, 3].copy()
+            self.bboxes[:, 1] = h - y2
+            self.bboxes[:, 3] = h - y1
+        else:
+            self.bboxes[:, 1] = h - self.bboxes[:, 1]
+        self.segments[..., 1] = h - self.segments[..., 1]
+        if self.keypoints is not None:
+            self.keypoints[..., 1] = h - self.keypoints[..., 1]
+
+    def fliplr(self, w):
+        """Reverses the order of the bounding boxes and segments horizontally."""
+        if self._bboxes.format == "xyxy":
+            x1 = self.bboxes[:, 0].copy()
+            x2 = self.bboxes[:, 2].copy()
+            self.bboxes[:, 0] = w - x2
+            self.bboxes[:, 2] = w - x1
+        else:
+            self.bboxes[:, 0] = w - self.bboxes[:, 0]
+        self.segments[..., 0] = w - self.segments[..., 0]
+        if self.keypoints is not None:
+            self.keypoints[..., 0] = w - self.keypoints[..., 0]
+
+    def clip(self, w, h):
+        """Clips bounding boxes, segments, and keypoints values to stay within image boundaries."""
+        ori_format = self._bboxes.format
+        self.convert_bbox(format="xyxy")
+        self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
+        self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
+        if ori_format != "xyxy":
+            self.convert_bbox(format=ori_format)
+        self.segments[..., 0] = self.segments[..., 0].clip(0, w)
+        self.segments[..., 1] = self.segments[..., 1].clip(0, h)
+        if self.keypoints is not None:
+            self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w)
+            self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)
+
+    def remove_zero_area_boxes(self):
+        """Remove zero-area boxes, i.e. after clipping some boxes may have zero width or height."""
+        good = self.bbox_areas > 0
+        if not all(good):
+            self._bboxes = self._bboxes[good]
+            if len(self.segments):
+                self.segments = self.segments[good]
+            if self.keypoints is not None:
+                self.keypoints = self.keypoints[good]
+        return good
+
+    def update(self, bboxes, segments=None, keypoints=None):
+        """Updates instance variables."""
+        self._bboxes = Bboxes(bboxes, format=self._bboxes.format)
+        if segments is not None:
+            self.segments = segments
+        if keypoints is not None:
+            self.keypoints = keypoints
+
+    def __len__(self):
+        """Return the length of the instance list."""
+        return len(self.bboxes)
+
+    @classmethod
+    def concatenate(cls, instances_list: List["Instances"], axis=0) -> "Instances":
+        """
+        Concatenates a list of Instances objects into a single Instances object.
+
+        Args:
+            instances_list (List[Instances]): A list of Instances objects to concatenate.
+            axis (int, optional): The axis along which the arrays will be concatenated. Defaults to 0.
+
+        Returns:
+            Instances: A new Instances object containing the concatenated bounding boxes,
+                       segments, and keypoints if present.
+
+        Note:
+            The `Instances` objects in the list should have the same properties, such as
+            the format of the bounding boxes, whether keypoints are present, and if the
+            coordinates are normalized.
+        """
+        assert isinstance(instances_list, (list, tuple))
+        if not instances_list:
+            return cls(np.empty(0))
+        assert all(isinstance(instance, Instances) for instance in instances_list)
+
+        if len(instances_list) == 1:
+            return instances_list[0]
+
+        use_keypoint = instances_list[0].keypoints is not None
+        bbox_format = instances_list[0]._bboxes.format
+        normalized = instances_list[0].normalized
+
+        cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)
+        seg_len = [b.segments.shape[1] for b in instances_list]
+        if len(set(seg_len)) > 1:  # resample segments if there's different length
+            max_len = max(seg_len)
+            cat_segments = np.concatenate(
+                [
+                    resample_segments(list(b.segments), max_len)
+                    if len(b.segments)
+                    else np.zeros((0, max_len, 2), dtype=np.float32)  # re-generating empty segments
+                    for b in instances_list
+                ],
+                axis=axis,
+            )
+        else:
+            cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis)
+        cat_keypoints = np.concatenate([b.keypoints for b in instances_list], axis=axis) if use_keypoint else None
+        return cls(cat_boxes, cat_segments, cat_keypoints, bbox_format, normalized)
+
+    @property
+    def bboxes(self):
+        """Return bounding boxes."""
+        return self._bboxes.bboxes
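+
+
+# Minimal usage sketch (illustrative values): a typical augmentation flow on normalized labels.
+# >>> inst = Instances(np.array([[0.5, 0.5, 0.2, 0.4]]), segments=np.zeros((1, 1000, 2)), normalized=True)
+# >>> inst.denormalize(w=640, h=480)  # convert to pixel coordinates
+# >>> inst.fliplr(w=640)  # horizontal flip
+# >>> inst.clip(w=640, h=480)  # keep coordinates inside the image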

+ 743 - 0
ultralytics/utils/loss.py

@@ -0,0 +1,743 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ultralytics.utils.metrics import OKS_SIGMA
+from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
+from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
+from ultralytics.utils.torch_utils import autocast
+
+from .metrics import bbox_iou, probiou
+from .tal import bbox2dist
+
+
+class VarifocalLoss(nn.Module):
+    """
+    Varifocal loss by Zhang et al.
+
+    https://arxiv.org/abs/2008.13367.
+    """
+
+    def __init__(self):
+        """Initialize the VarifocalLoss class."""
+        super().__init__()
+
+    @staticmethod
+    def forward(pred_score, gt_score, label, alpha=0.75, gamma=2.0):
+        """Computes varfocal loss."""
+        weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
+        with autocast(enabled=False):
+            loss = (
+                (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight)
+                .mean(1)
+                .sum()
+            )
+        return loss
+
+
+class FocalLoss(nn.Module):
+    """Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
+
+    def __init__(self):
+        """Initializer for FocalLoss class with no parameters."""
+        super().__init__()
+
+    @staticmethod
+    def forward(pred, label, gamma=1.5, alpha=0.25):
+        """Calculates and updates confusion matrix for object detection/classification tasks."""
+        loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none")
+        # p_t = torch.exp(-loss)
+        # loss *= self.alpha * (1.000001 - p_t) ** self.gamma  # non-zero power for gradient stability
+
+        # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
+        pred_prob = pred.sigmoid()  # prob from logits
+        p_t = label * pred_prob + (1 - label) * (1 - pred_prob)
+        modulating_factor = (1.0 - p_t) ** gamma
+        loss *= modulating_factor
+        if alpha > 0:
+            alpha_factor = label * alpha + (1 - label) * (1 - alpha)
+            loss *= alpha_factor
+        return loss.mean(1).sum()
+
+
+class DFLoss(nn.Module):
+    """Criterion class for computing DFL losses during training."""
+
+    def __init__(self, reg_max=16) -> None:
+        """Initialize the DFL module."""
+        super().__init__()
+        self.reg_max = reg_max
+
+    def __call__(self, pred_dist, target):
+        """
+        Return sum of left and right DFL losses.
+
+        Distribution Focal Loss (DFL) proposed in Generalized Focal Loss
+        https://ieeexplore.ieee.org/document/9792391
+        """
+        target = target.clamp_(0, self.reg_max - 1 - 0.01)
+        tl = target.long()  # target left
+        tr = tl + 1  # target right
+        wl = tr - target  # weight left
+        wr = 1 - wl  # weight right
+        return (
+            F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
+            + F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
+        ).mean(-1, keepdim=True)
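+
+
+# Worked example (illustrative): a continuous target of 2.7 is split between the adjacent integer
+# bins 2 and 3 with weights 0.3 and 0.7, so DFL = 0.3 * CE(pred, 2) + 0.7 * CE(pred, 3).
+#     >>> dfl = DFLoss(reg_max=16)
+#     >>> pred_dist = torch.randn(32, 16)  # (4 * n_anchors, reg_max) logits over distance bins
+#     >>> target = torch.rand(8, 4) * 15  # (n_anchors, 4) continuous ltrb distances
+#     >>> loss = dfl(pred_dist, target)  # (n_anchors, 1) after averaging the four box sides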
+
+
+class BboxLoss(nn.Module):
+    """Criterion class for computing training losses during training."""
+
+    def __init__(self, reg_max=16):
+        """Initialize the BboxLoss module with regularization maximum and DFL settings."""
+        super().__init__()
+        self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None
+
+    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
+        """IoU loss."""
+        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
+        iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
+        loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
+
+        # DFL loss
+        if self.dfl_loss:
+            target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1)
+            loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
+            loss_dfl = loss_dfl.sum() / target_scores_sum
+        else:
+            loss_dfl = torch.tensor(0.0).to(pred_dist.device)
+
+        return loss_iou, loss_dfl
+
+
+class RotatedBboxLoss(BboxLoss):
+    """Criterion class for computing training losses during training."""
+
+    def __init__(self, reg_max):
+        """Initialize the BboxLoss module with regularization maximum and DFL settings."""
+        super().__init__(reg_max)
+
+    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
+        """IoU loss."""
+        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
+        iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
+        loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
+
+        # DFL loss
+        if self.dfl_loss:
+            target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.dfl_loss.reg_max - 1)
+            loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
+            loss_dfl = loss_dfl.sum() / target_scores_sum
+        else:
+            loss_dfl = torch.tensor(0.0).to(pred_dist.device)
+
+        return loss_iou, loss_dfl
+
+
+class KeypointLoss(nn.Module):
+    """Criterion class for computing training losses."""
+
+    def __init__(self, sigmas) -> None:
+        """Initialize the KeypointLoss class."""
+        super().__init__()
+        self.sigmas = sigmas
+
+    def forward(self, pred_kpts, gt_kpts, kpt_mask, area):
+        """Calculates keypoint loss factor and Euclidean distance loss for predicted and actual keypoints."""
+        d = (pred_kpts[..., 0] - gt_kpts[..., 0]).pow(2) + (pred_kpts[..., 1] - gt_kpts[..., 1]).pow(2)
+        kpt_loss_factor = kpt_mask.shape[1] / (torch.sum(kpt_mask != 0, dim=1) + 1e-9)
+        # e = d / (2 * (area * self.sigmas) ** 2 + 1e-9)  # from formula
+        e = d / ((2 * self.sigmas).pow(2) * (area + 1e-9) * 2)  # from cocoeval
+        return (kpt_loss_factor.view(-1, 1) * ((1 - torch.exp(-e)) * kpt_mask)).mean()
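+
+
+# Illustrative OKS-style usage (shapes assumed; the training loop supplies the real tensors):
+#     >>> kpt_criterion = KeypointLoss(sigmas=torch.from_numpy(OKS_SIGMA).float())
+#     >>> pred_kpts, gt_kpts = torch.rand(4, 17, 3), torch.rand(4, 17, 3)
+#     >>> kpt_mask = (gt_kpts[..., 2] != 0).float()  # only visible keypoints contribute
+#     >>> area = torch.rand(4, 1)  # per-object box areas
+#     >>> loss = kpt_criterion(pred_kpts, gt_kpts, kpt_mask, area)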
+
+
+class v8DetectionLoss:
+    """Criterion class for computing training losses."""
+
+    def __init__(self, model, tal_topk=10):  # model must be de-paralleled
+        """Initializes v8DetectionLoss with the model, defining model-related properties and BCE loss function."""
+        device = next(model.parameters()).device  # get model device
+        h = model.args  # hyperparameters
+
+        m = model.model[-1]  # Detect() module
+        self.bce = nn.BCEWithLogitsLoss(reduction="none")
+        self.hyp = h
+        self.stride = m.stride  # model strides
+        self.nc = m.nc  # number of classes
+        self.no = m.nc + m.reg_max * 4
+        self.reg_max = m.reg_max
+        self.device = device
+
+        self.use_dfl = m.reg_max > 1
+
+        self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
+        self.bbox_loss = BboxLoss(m.reg_max).to(device)
+        self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
+
+    def preprocess(self, targets, batch_size, scale_tensor):
+        """Preprocesses the target counts and matches with the input batch size to output a tensor."""
+        nl, ne = targets.shape
+        if nl == 0:
+            out = torch.zeros(batch_size, 0, ne - 1, device=self.device)
+        else:
+            i = targets[:, 0]  # image index
+            _, counts = i.unique(return_counts=True)
+            counts = counts.to(dtype=torch.int32)
+            out = torch.zeros(batch_size, counts.max(), ne - 1, device=self.device)
+            for j in range(batch_size):
+                matches = i == j
+                if n := matches.sum():
+                    out[j, :n] = targets[matches, 1:]
+            out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
+        return out
+
+    def bbox_decode(self, anchor_points, pred_dist):
+        """Decode predicted object bounding box coordinates from anchor points and distribution."""
+        if self.use_dfl:
+            b, a, c = pred_dist.shape  # batch, anchors, channels
+            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
+            # pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))
+            # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
+        return dist2bbox(pred_dist, anchor_points, xywh=False)
+
+    def __call__(self, preds, batch):
+        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+        loss = torch.zeros(3, device=self.device)  # box, cls, dfl
+        feats = preds[1] if isinstance(preds, tuple) else preds
+        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
+            (self.reg_max * 4, self.nc), 1
+        )
+
+        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
+        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
+
+        dtype = pred_scores.dtype
+        batch_size = pred_scores.shape[0]
+        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
+        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+
+        # Targets
+        targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
+        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
+        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
+
+        # Pboxes
+        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
+        # dfl_conf = pred_distri.view(batch_size, -1, 4, self.reg_max).detach().softmax(-1)
+        # dfl_conf = (dfl_conf.amax(-1).mean(-1) + dfl_conf.amax(-1).amin(-1)) / 2
+
+        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
+            # pred_scores.detach().sigmoid() * 0.8 + dfl_conf.unsqueeze(-1) * 0.2,
+            pred_scores.detach().sigmoid(),
+            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
+            anchor_points * stride_tensor,
+            gt_labels,
+            gt_bboxes,
+            mask_gt,
+        )
+
+        target_scores_sum = max(target_scores.sum(), 1)
+
+        # Cls loss
+        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
+        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
+
+        # Bbox loss
+        if fg_mask.sum():
+            target_bboxes /= stride_tensor
+            loss[0], loss[2] = self.bbox_loss(
+                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
+            )
+
+        loss[0] *= self.hyp.box  # box gain
+        loss[1] *= self.hyp.cls  # cls gain
+        loss[2] *= self.hyp.dfl  # dfl gain
+
+        return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)
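+
+
+# Construction sketch (assumes `model` is a de-paralleled YOLO detection model and `batch` is a
+# dataloader dict with "batch_idx", "cls" and "bboxes" keys):
+#     >>> criterion = v8DetectionLoss(model)
+#     >>> total_loss, loss_items = criterion(preds, batch)
+#     >>> # total_loss is scaled by batch size; loss_items holds detached (box, cls, dfl) values for logging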
+
+
+class v8SegmentationLoss(v8DetectionLoss):
+    """Criterion class for computing training losses."""
+
+    def __init__(self, model):  # model must be de-paralleled
+        """Initializes the v8SegmentationLoss class, taking a de-paralleled model as argument."""
+        super().__init__(model)
+        self.overlap = model.args.overlap_mask
+
+    def __call__(self, preds, batch):
+        """Calculate and return the loss for the YOLO model."""
+        loss = torch.zeros(4, device=self.device)  # box, seg, cls, dfl
+        feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
+        batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width
+        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
+            (self.reg_max * 4, self.nc), 1
+        )
+
+        # B, grids, ..
+        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
+        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
+        pred_masks = pred_masks.permute(0, 2, 1).contiguous()
+
+        dtype = pred_scores.dtype
+        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
+        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+
+        # Targets
+        try:
+            batch_idx = batch["batch_idx"].view(-1, 1)
+            targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
+            targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+            gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
+            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
+        except RuntimeError as e:
+            raise TypeError(
+                "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
+                "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
+                "i.e. 'yolo train model=yolov8n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
+                "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
+                "as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help."
+            ) from e
+
+        # Pboxes
+        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
+
+        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
+            pred_scores.detach().sigmoid(),
+            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
+            anchor_points * stride_tensor,
+            gt_labels,
+            gt_bboxes,
+            mask_gt,
+        )
+
+        target_scores_sum = max(target_scores.sum(), 1)
+
+        # Cls loss
+        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
+        loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
+
+        if fg_mask.sum():
+            # Bbox loss
+            loss[0], loss[3] = self.bbox_loss(
+                pred_distri,
+                pred_bboxes,
+                anchor_points,
+                target_bboxes / stride_tensor,
+                target_scores,
+                target_scores_sum,
+                fg_mask,
+            )
+            # Masks loss
+            masks = batch["masks"].to(self.device).float()
+            if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
+                masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
+
+            loss[1] = self.calculate_segmentation_loss(
+                fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
+            )
+
+        # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
+        else:
+            loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss
+
+        loss[0] *= self.hyp.box  # box gain
+        loss[1] *= self.hyp.box  # seg gain
+        loss[2] *= self.hyp.cls  # cls gain
+        loss[3] *= self.hyp.dfl  # dfl gain
+
+        return loss.sum() * batch_size, loss.detach()  # loss(box, seg, cls, dfl)
+
+    @staticmethod
+    def single_mask_loss(
+        gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Compute the instance segmentation loss for a single image.
+
+        Args:
+            gt_mask (torch.Tensor): Ground truth mask of shape (n, H, W), where n is the number of objects.
+            pred (torch.Tensor): Predicted mask coefficients of shape (n, 32).
+            proto (torch.Tensor): Prototype masks of shape (32, H, W).
+            xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (n, 4).
+            area (torch.Tensor): Area of each ground truth bounding box of shape (n,).
+
+        Returns:
+            (torch.Tensor): The calculated mask loss for a single image.
+
+        Notes:
+            The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the
+            predicted masks from the prototype masks and predicted mask coefficients.
+        """
+        pred_mask = torch.einsum("in,nhw->ihw", pred, proto)  # (n, 32) @ (32, 80, 80) -> (n, 80, 80)
+        loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
+        return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum()
+
+    def calculate_segmentation_loss(
+        self,
+        fg_mask: torch.Tensor,
+        masks: torch.Tensor,
+        target_gt_idx: torch.Tensor,
+        target_bboxes: torch.Tensor,
+        batch_idx: torch.Tensor,
+        proto: torch.Tensor,
+        pred_masks: torch.Tensor,
+        imgsz: torch.Tensor,
+        overlap: bool,
+    ) -> torch.Tensor:
+        """
+        Calculate the loss for instance segmentation.
+
+        Args:
+            fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.
+            masks (torch.Tensor): Ground truth masks of shape (N_labels_in_batch, H, W) if `overlap` is False, otherwise (BS, H, W).
+            target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors).
+            target_bboxes (torch.Tensor): Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4).
+            batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1).
+            proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
+            pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).
+            imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
+            overlap (bool): Whether the masks in `masks` tensor overlap.
+
+        Returns:
+            (torch.Tensor): The calculated loss for instance segmentation.
+
+        Notes:
+            The batch loss can be computed for improved speed at higher memory usage.
+            For example, pred_mask can be computed as follows:
+                pred_mask = torch.einsum('in,nhw->ihw', pred, proto)  # (i, 32) @ (32, 160, 160) -> (i, 160, 160)
+        """
+        _, _, mask_h, mask_w = proto.shape
+        loss = 0
+
+        # Normalize to 0-1
+        target_bboxes_normalized = target_bboxes / imgsz[[1, 0, 1, 0]]
+
+        # Areas of target bboxes
+        marea = xyxy2xywh(target_bboxes_normalized)[..., 2:].prod(2)
+
+        # Normalize to mask size
+        mxyxy = target_bboxes_normalized * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=proto.device)
+
+        for i, single_i in enumerate(zip(fg_mask, target_gt_idx, pred_masks, proto, mxyxy, marea, masks)):
+            fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i, masks_i = single_i
+            if fg_mask_i.any():
+                mask_idx = target_gt_idx_i[fg_mask_i]
+                if overlap:
+                    gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1)
+                    gt_mask = gt_mask.float()
+                else:
+                    gt_mask = masks[batch_idx.view(-1) == i][mask_idx]
+
+                loss += self.single_mask_loss(
+                    gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i], marea_i[fg_mask_i]
+                )
+
+            # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
+            else:
+                loss += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss
+
+        return loss / fg_mask.sum()
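+
+
+# Mask composition sketch (illustrative shapes): each predicted instance mask is a linear combination
+# of the 32 prototype masks, built with einsum and then cropped to its box before the BCE loss.
+#     >>> coeffs, proto = torch.rand(5, 32), torch.rand(32, 160, 160)
+#     >>> pred_mask = torch.einsum("in,nhw->ihw", coeffs, proto)  # (5, 160, 160)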
+
+
+class v8PoseLoss(v8DetectionLoss):
+    """Criterion class for computing training losses."""
+
+    def __init__(self, model):  # model must be de-paralleled
+        """Initializes v8PoseLoss with model, sets keypoint variables and declares a keypoint loss instance."""
+        super().__init__(model)
+        self.kpt_shape = model.model[-1].kpt_shape
+        self.bce_pose = nn.BCEWithLogitsLoss()
+        is_pose = self.kpt_shape == [17, 3]
+        nkpt = self.kpt_shape[0]  # number of keypoints
+        sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
+        self.keypoint_loss = KeypointLoss(sigmas=sigmas)
+
+    def __call__(self, preds, batch):
+        """Calculate the total loss and detach it."""
+        loss = torch.zeros(5, device=self.device)  # box, kpt_location, kpt_visibility, cls, dfl
+        feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
+        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
+            (self.reg_max * 4, self.nc), 1
+        )
+
+        # B, grids, ..
+        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
+        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
+        pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()
+
+        dtype = pred_scores.dtype
+        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
+        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+
+        # Targets
+        batch_size = pred_scores.shape[0]
+        batch_idx = batch["batch_idx"].view(-1, 1)
+        targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
+        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
+        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
+
+        # Pboxes
+        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
+        pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape))  # (b, h*w, 17, 3)
+
+        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
+            pred_scores.detach().sigmoid(),
+            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
+            anchor_points * stride_tensor,
+            gt_labels,
+            gt_bboxes,
+            mask_gt,
+        )
+
+        target_scores_sum = max(target_scores.sum(), 1)
+
+        # Cls loss
+        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
+        loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
+
+        # Bbox loss
+        if fg_mask.sum():
+            target_bboxes /= stride_tensor
+            loss[0], loss[4] = self.bbox_loss(
+                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
+            )
+            keypoints = batch["keypoints"].to(self.device).float().clone()
+            keypoints[..., 0] *= imgsz[1]
+            keypoints[..., 1] *= imgsz[0]
+
+            loss[1], loss[2] = self.calculate_keypoints_loss(
+                fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
+            )
+
+        loss[0] *= self.hyp.box  # box gain
+        loss[1] *= self.hyp.pose  # pose gain
+        loss[2] *= self.hyp.kobj  # kobj gain
+        loss[3] *= self.hyp.cls  # cls gain
+        loss[4] *= self.hyp.dfl  # dfl gain
+
+        return loss.sum() * batch_size, loss.detach()  # loss(box, kpt_location, kpt_visibility, cls, dfl)
+
+    @staticmethod
+    def kpts_decode(anchor_points, pred_kpts):
+        """Decodes predicted keypoints to image coordinates."""
+        y = pred_kpts.clone()
+        y[..., :2] *= 2.0
+        y[..., 0] += anchor_points[:, [0]] - 0.5
+        y[..., 1] += anchor_points[:, [1]] - 0.5
+        return y
+
+    def calculate_keypoints_loss(
+        self, masks, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
+    ):
+        """
+        Calculate the keypoints loss for the model.
+
+        This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
+        based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
+        a binary classification loss that classifies whether a keypoint is present or not.
+
+        Args:
+            masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
+            target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
+            keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
+            batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
+            stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
+            target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
+            pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
+
+        Returns:
+            kpts_loss (torch.Tensor): The keypoints loss.
+            kpts_obj_loss (torch.Tensor): The keypoints object loss.
+        """
+        batch_idx = batch_idx.flatten()
+        batch_size = len(masks)
+
+        # Find the maximum number of keypoints in a single image
+        max_kpts = torch.unique(batch_idx, return_counts=True)[1].max()
+
+        # Create a tensor to hold batched keypoints
+        batched_keypoints = torch.zeros(
+            (batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]), device=keypoints.device
+        )
+
+        # TODO: any idea how to vectorize this?
+        # Fill batched_keypoints with keypoints based on batch_idx
+        for i in range(batch_size):
+            keypoints_i = keypoints[batch_idx == i]
+            batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i
+
+        # Expand dimensions of target_gt_idx to match the shape of batched_keypoints
+        target_gt_idx_expanded = target_gt_idx.unsqueeze(-1).unsqueeze(-1)
+
+        # Use target_gt_idx_expanded to select keypoints from batched_keypoints
+        selected_keypoints = batched_keypoints.gather(
+            1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
+        )
+
+        # Divide coordinates by stride
+        selected_keypoints /= stride_tensor.view(1, -1, 1, 1)
+
+        kpts_loss = 0
+        kpts_obj_loss = 0
+
+        if masks.any():
+            gt_kpt = selected_keypoints[masks]
+            area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
+            pred_kpt = pred_kpts[masks]
+            kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)
+            kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area)  # pose loss
+
+            if pred_kpt.shape[-1] == 3:
+                kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float())  # keypoint obj loss
+
+        return kpts_loss, kpts_obj_loss
+
+
+class v8ClassificationLoss:
+    """Criterion class for computing training losses."""
+
+    def __call__(self, preds, batch):
+        """Compute the classification loss between predictions and true labels."""
+        preds = preds[1] if isinstance(preds, (list, tuple)) else preds
+        loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
+        loss_items = loss.detach()
+        return loss, loss_items
+
+
+class v8OBBLoss(v8DetectionLoss):
+    """Calculates losses for object detection, classification, and box distribution in rotated YOLO models."""
+
+    def __init__(self, model):
+        """Initializes v8OBBLoss with model, assigner, and rotated bbox loss; note model must be de-paralleled."""
+        super().__init__(model)
+        self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
+        self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)
+
+    def preprocess(self, targets, batch_size, scale_tensor):
+        """Preprocesses the target counts and matches with the input batch size to output a tensor."""
+        if targets.shape[0] == 0:
+            out = torch.zeros(batch_size, 0, 6, device=self.device)
+        else:
+            i = targets[:, 0]  # image index
+            _, counts = i.unique(return_counts=True)
+            counts = counts.to(dtype=torch.int32)
+            out = torch.zeros(batch_size, counts.max(), 6, device=self.device)
+            for j in range(batch_size):
+                matches = i == j
+                if n := matches.sum():
+                    bboxes = targets[matches, 2:]
+                    bboxes[..., :4].mul_(scale_tensor)
+                    out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
+        return out
+
+    def __call__(self, preds, batch):
+        """Calculate and return the loss for the YOLO model."""
+        loss = torch.zeros(3, device=self.device)  # box, cls, dfl
+        feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
+        batch_size = pred_angle.shape[0]  # batch size
+        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
+            (self.reg_max * 4, self.nc), 1
+        )
+
+        # b, grids, ..
+        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
+        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
+        pred_angle = pred_angle.permute(0, 2, 1).contiguous()
+
+        dtype = pred_scores.dtype
+        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
+        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+
+        # targets
+        try:
+            batch_idx = batch["batch_idx"].view(-1, 1)
+            targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
+            rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item()
+            targets = targets[(rw >= 2) & (rh >= 2)]  # filter rboxes of tiny size to stabilize training
+            targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+            gt_labels, gt_bboxes = targets.split((1, 5), 2)  # cls, xywhr
+            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
+        except RuntimeError as e:
+            raise TypeError(
+                "ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
+                "This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
+                "i.e. 'yolo train model=yolov8n-obb.pt data=dota8.yaml'.\nVerify your dataset is a "
+                "correctly formatted 'OBB' dataset using 'data=dota8.yaml' "
+                "as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
+            ) from e
+
+        # Pboxes
+        pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle)  # xywhr, (b, h*w, 5)
+
+        bboxes_for_assigner = pred_bboxes.clone().detach()
+        # Only the first four elements need to be scaled
+        bboxes_for_assigner[..., :4] *= stride_tensor
+        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
+            pred_scores.detach().sigmoid(),
+            bboxes_for_assigner.type(gt_bboxes.dtype),
+            anchor_points * stride_tensor,
+            gt_labels,
+            gt_bboxes,
+            mask_gt,
+        )
+
+        target_scores_sum = max(target_scores.sum(), 1)
+
+        # Cls loss
+        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
+        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
+
+        # Bbox loss
+        if fg_mask.sum():
+            target_bboxes[..., :4] /= stride_tensor
+            loss[0], loss[2] = self.bbox_loss(
+                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
+            )
+        else:
+            loss[0] += (pred_angle * 0).sum()
+
+        loss[0] *= self.hyp.box  # box gain
+        loss[1] *= self.hyp.cls  # cls gain
+        loss[2] *= self.hyp.dfl  # dfl gain
+
+        return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)
+
+    def bbox_decode(self, anchor_points, pred_dist, pred_angle):
+        """
+        Decode predicted object bounding box coordinates from anchor points and distribution.
+
+        Args:
+            anchor_points (torch.Tensor): Anchor points, (h*w, 2).
+            pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
+            pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).
+
+        Returns:
+            (torch.Tensor): Predicted rotated bounding boxes with angles, (bs, h*w, 5).
+        """
+        if self.use_dfl:
+            b, a, c = pred_dist.shape  # batch, anchors, channels
+            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
+        return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)
+
+
+class E2EDetectLoss:
+    """Criterion class for computing training losses."""
+
+    def __init__(self, model):
+        """Initialize E2EDetectLoss with one-to-many and one-to-one detection losses using the provided model."""
+        self.one2many = v8DetectionLoss(model, tal_topk=10)
+        self.one2one = v8DetectionLoss(model, tal_topk=1)
+
+    def __call__(self, preds, batch):
+        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+        preds = preds[1] if isinstance(preds, tuple) else preds
+        one2many = preds["one2many"]
+        loss_one2many = self.one2many(one2many, batch)
+        one2one = preds["one2one"]
+        loss_one2one = self.one2one(one2one, batch)
+        return loss_one2many[0] + loss_one2one[0], loss_one2many[1] + loss_one2one[1]
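+
+
+# Usage sketch (assumes an end-to-end head whose predictions dict exposes "one2many" and "one2one" branches):
+#     >>> criterion = E2EDetectLoss(model)
+#     >>> total_loss, loss_items = criterion(preds, batch)  # sums the (box, cls, dfl) losses of both branches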

+ 1308 - 0
ultralytics/utils/metrics.py

@@ -0,0 +1,1308 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+"""Model validation metrics."""
+
+import math
+import warnings
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+from ultralytics.utils import LOGGER, SimpleClass, TryExcept, plt_settings
+
+OKS_SIGMA = (
+    np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89])
+    / 10.0
+)
+
+
+def bbox_ioa(box1, box2, iou=False, eps=1e-7):
+    """
+    Calculate the intersection over box2 area given box1 and box2. Boxes are in x1y1x2y2 format.
+
+    Args:
+        box1 (np.ndarray): A numpy array of shape (n, 4) representing n bounding boxes.
+        box2 (np.ndarray): A numpy array of shape (m, 4) representing m bounding boxes.
+        iou (bool): Calculate the standard IoU if True else return inter_area/box2_area.
+        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
+
+    Returns:
+        (np.ndarray): A numpy array of shape (n, m) representing the intersection over box2 area.
+    """
+    # Get the coordinates of bounding boxes
+    b1_x1, b1_y1, b1_x2, b1_y2 = box1.T
+    b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
+
+    # Intersection area
+    inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * (
+        np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)
+    ).clip(0)
+
+    # Box2 area
+    area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
+    if iou:
+        box1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
+        area = area + box1_area[:, None] - inter_area
+
+    # Intersection over box2 area
+    return inter_area / (area + eps)
+
+
+def box_iou(box1, box2, eps=1e-7):
+    """
+    Calculate intersection-over-union (IoU) of boxes. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
+    Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py.
+
+    Args:
+        box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.
+        box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes.
+        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
+
+    Returns:
+        (torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
+    """
+    # NOTE: Need .float() to get accurate iou values
+    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
+    (a1, a2), (b1, b2) = box1.float().unsqueeze(1).chunk(2, 2), box2.float().unsqueeze(0).chunk(2, 2)
+    inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp_(0).prod(2)
+
+    # IoU = inter / (area1 + area2 - inter)
+    return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)
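+
+
+# Example (illustrative): one box pair whose intersection area is 25.
+#     >>> box1 = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
+#     >>> box2 = torch.tensor([[5.0, 5.0, 15.0, 15.0]])
+#     >>> box_iou(box1, box2)  # tensor([[0.1429]]) = 25 / (100 + 100 - 25)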
+
+
+def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
+    """
+    Calculates the Intersection over Union (IoU) between bounding boxes.
+
+    This function supports various shapes for `box1` and `box2` as long as the last dimension is 4.
+    For instance, you may pass tensors shaped like (4,), (N, 4), (B, N, 4), or (B, N, 1, 4).
+    Internally, the code will split the last dimension into (x, y, w, h) if `xywh=True`,
+    or (x1, y1, x2, y2) if `xywh=False`.
+
+    Args:
+        box1 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
+        box2 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
+        xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
+                               (x1, y1, x2, y2) format. Defaults to True.
+        GIoU (bool, optional): If True, calculate Generalized IoU. Defaults to False.
+        DIoU (bool, optional): If True, calculate Distance IoU. Defaults to False.
+        CIoU (bool, optional): If True, calculate Complete IoU. Defaults to False.
+        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
+
+    Returns:
+        (torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags.
+    """
+    # Get the coordinates of bounding boxes
+    if xywh:  # transform from xywh to xyxy
+        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
+        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
+        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
+        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
+    else:  # x1, y1, x2, y2 = box1
+        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
+        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
+        w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
+        w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
+
+    # Intersection area
+    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * (
+        b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)
+    ).clamp_(0)
+
+    # Union Area
+    union = w1 * h1 + w2 * h2 - inter + eps
+
+    # IoU
+    iou = inter / union
+    if CIoU or DIoU or GIoU:
+        cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
+        ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
+        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
+            c2 = cw.pow(2) + ch.pow(2) + eps  # convex diagonal squared
+            rho2 = (
+                (b2_x1 + b2_x2 - b1_x1 - b1_x2).pow(2) + (b2_y1 + b2_y2 - b1_y1 - b1_y2).pow(2)
+            ) / 4  # center dist**2
+            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
+                v = (4 / math.pi**2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2)
+                with torch.no_grad():
+                    alpha = v / (v - iou + (1 + eps))
+                return iou - (rho2 / c2 + v * alpha)  # CIoU
+            return iou - rho2 / c2  # DIoU
+        c_area = cw * ch + eps  # convex area
+        return iou - (c_area - union) / c_area  # GIoU https://arxiv.org/pdf/1902.09630.pdf
+    return iou  # IoU
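+
+
+# Example (illustrative): the flags select the IoU variant for the same pair of xyxy boxes.
+#     >>> b1 = torch.tensor([0.0, 0.0, 10.0, 10.0])
+#     >>> b2 = torch.tensor([2.0, 2.0, 12.0, 12.0])
+#     >>> bbox_iou(b1, b2, xywh=False)  # plain IoU
+#     >>> bbox_iou(b1, b2, xywh=False, CIoU=True)  # adds center-distance and aspect-ratio penalties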
+
+
+def mask_iou(mask1, mask2, eps=1e-7):
+    """
+    Calculate masks IoU.
+
+    Args:
+        mask1 (torch.Tensor): A tensor of shape (N, n) where N is the number of ground truth objects and n is the
+                        product of image width and height.
+        mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the
+                        product of image width and height.
+        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
+
+    Returns:
+        (torch.Tensor): A tensor of shape (N, M) representing masks IoU.
+    """
+    intersection = torch.matmul(mask1, mask2.T).clamp_(0)
+    union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection  # (area1 + area2) - intersection
+    return intersection / (union + eps)
+
+
+def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):
+    """
+    Calculate Object Keypoint Similarity (OKS).
+
+    Args:
+        kpt1 (torch.Tensor): A tensor of shape (N, 17, 3) representing ground truth keypoints.
+        kpt2 (torch.Tensor): A tensor of shape (M, 17, 3) representing predicted keypoints.
+        area (torch.Tensor): A tensor of shape (N,) representing areas from ground truth.
+        sigma (list): A list containing 17 values representing keypoint scales.
+        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
+
+    Returns:
+        (torch.Tensor): A tensor of shape (N, M) representing keypoint similarities.
+    """
+    d = (kpt1[:, None, :, 0] - kpt2[..., 0]).pow(2) + (kpt1[:, None, :, 1] - kpt2[..., 1]).pow(2)  # (N, M, 17)
+    sigma = torch.tensor(sigma, device=kpt1.device, dtype=kpt1.dtype)  # (17, )
+    kpt_mask = kpt1[..., 2] != 0  # (N, 17)
+    e = d / ((2 * sigma).pow(2) * (area[:, None, None] + eps) * 2)  # from cocoeval
+    # e = d / ((area[None, :, None] + eps) * sigma) ** 2 / 2  # from formula
+    return ((-e).exp() * kpt_mask[:, None]).sum(-1) / (kpt_mask.sum(-1)[:, None] + eps)
+
+
+def _get_covariance_matrix(boxes):
+    """
+    Generate covariance matrices from oriented bounding boxes (OBBs).
+
+    Args:
+        boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format.
+
+    Returns:
+        (torch.Tensor): Covariance matrices corresponding to original rotated bounding boxes.
+    """
+    # Gaussian bounding boxes, ignore the center points (the first two columns) because they are not needed here.
+    gbbs = torch.cat((boxes[:, 2:4].pow(2) / 12, boxes[:, 4:]), dim=-1)
+    a, b, c = gbbs.split(1, dim=-1)
+    cos = c.cos()
+    sin = c.sin()
+    cos2 = cos.pow(2)
+    sin2 = sin.pow(2)
+    return a * cos2 + b * sin2, a * sin2 + b * cos2, (a - b) * cos * sin
+
+
+def probiou(obb1, obb2, CIoU=False, eps=1e-7):
+    """
+    Calculate probabilistic IoU between oriented bounding boxes.
+
+    Implements the algorithm from https://arxiv.org/pdf/2106.06072v1.pdf.
+
+    Args:
+        obb1 (torch.Tensor): Ground truth OBBs, shape (N, 5), format xywhr.
+        obb2 (torch.Tensor): Predicted OBBs, shape (N, 5), format xywhr.
+        CIoU (bool, optional): If True, calculate CIoU. Defaults to False.
+        eps (float, optional): Small value to avoid division by zero. Defaults to 1e-7.
+
+    Returns:
+        (torch.Tensor): OBB similarities, shape (N,).
+
+    Note:
+        OBB format: [center_x, center_y, width, height, rotation_angle].
+        If CIoU is True, returns CIoU instead of IoU.
+    """
+    x1, y1 = obb1[..., :2].split(1, dim=-1)
+    x2, y2 = obb2[..., :2].split(1, dim=-1)
+    a1, b1, c1 = _get_covariance_matrix(obb1)
+    a2, b2, c2 = _get_covariance_matrix(obb2)
+
+    t1 = (
+        ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)
+    ) * 0.25
+    t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5
+    t3 = (
+        ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2))
+        / (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps)
+        + eps
+    ).log() * 0.5
+    bd = (t1 + t2 + t3).clamp(eps, 100.0)
+    hd = (1.0 - (-bd).exp() + eps).sqrt()
+    iou = 1 - hd
+    if CIoU:  # only include the wh aspect ratio part
+        w1, h1 = obb1[..., 2:4].split(1, dim=-1)
+        w2, h2 = obb2[..., 2:4].split(1, dim=-1)
+        v = (4 / math.pi**2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2)
+        with torch.no_grad():
+            alpha = v / (v - iou + (1 + eps))
+        return iou - v * alpha  # CIoU
+    return iou
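+
+
+# Example (xywhr = center x, center y, width, height, rotation in radians; values are illustrative):
+#     >>> obb1 = torch.tensor([[50.0, 50.0, 20.0, 10.0, 0.0]])
+#     >>> obb2 = torch.tensor([[50.0, 50.0, 20.0, 10.0, 0.3]])
+#     >>> probiou(obb1, obb2)  # Gaussian-overlap similarity, close to 1 for well-aligned boxes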
+
+
+def batch_probiou(obb1, obb2, eps=1e-7):
+    """
+    Calculate the prob IoU between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf.
+
+    Args:
+        obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
+        obb2 (torch.Tensor | np.ndarray): A tensor of shape (M, 5) representing predicted obbs, with xywhr format.
+        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
+
+    Returns:
+        (torch.Tensor): A tensor of shape (N, M) representing obb similarities.
+    """
+    obb1 = torch.from_numpy(obb1) if isinstance(obb1, np.ndarray) else obb1
+    obb2 = torch.from_numpy(obb2) if isinstance(obb2, np.ndarray) else obb2
+
+    x1, y1 = obb1[..., :2].split(1, dim=-1)
+    x2, y2 = (x.squeeze(-1)[None] for x in obb2[..., :2].split(1, dim=-1))
+    a1, b1, c1 = _get_covariance_matrix(obb1)
+    a2, b2, c2 = (x.squeeze(-1)[None] for x in _get_covariance_matrix(obb2))
+
+    t1 = (
+        ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)
+    ) * 0.25
+    t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5
+    t3 = (
+        ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2))
+        / (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps)
+        + eps
+    ).log() * 0.5
+    bd = (t1 + t2 + t3).clamp(eps, 100.0)
+    hd = (1.0 - (-bd).exp() + eps).sqrt()
+    return 1 - hd
+
+
+def smooth_bce(eps=0.1):
+    """
+    Computes smoothed positive and negative Binary Cross-Entropy targets.
+
+    This function calculates positive and negative label smoothing BCE targets based on a given epsilon value.
+    For implementation details, refer to https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441.
+
+    Args:
+        eps (float, optional): The epsilon value for label smoothing. Defaults to 0.1.
+
+    Returns:
+        (tuple): A tuple containing the positive and negative label smoothing BCE targets.
+    """
+    return 1.0 - 0.5 * eps, 0.5 * eps
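+
+
+# Worked example: smooth_bce(eps=0.1) returns (0.95, 0.05), i.e. positives target 0.95 and negatives 0.05.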
+
+
+class ConfusionMatrix:
+    """
+    A class for calculating and updating a confusion matrix for object detection and classification tasks.
+
+    Attributes:
+        task (str): The type of task, either 'detect' or 'classify'.
+        matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.
+        nc (int): The number of classes.
+        conf (float): The confidence threshold for detections.
+        iou_thres (float): The Intersection over Union threshold.
+    """
+
+    def __init__(self, nc, conf=0.25, iou_thres=0.45, task="detect"):
+        """Initialize attributes for the YOLO model."""
+        self.task = task
+        self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc))
+        self.nc = nc  # number of classes
+        self.conf = 0.25 if conf in {None, 0.001} else conf  # apply 0.25 if default val conf is passed
+        self.iou_thres = iou_thres
+
+    def process_cls_preds(self, preds, targets):
+        """
+        Update confusion matrix for classification task.
+
+        Args:
+            preds (Array[N, min(nc,5)]): Predicted class labels.
+            targets (Array[N, 1]): Ground truth class labels.
+        """
+        preds, targets = torch.cat(preds)[:, 0], torch.cat(targets)
+        for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):
+            self.matrix[p][t] += 1
+
+    def process_batch(self, detections, gt_bboxes, gt_cls):
+        """
+        Update confusion matrix for object detection task.
+
+        Args:
+            detections (Array[N, 6] | Array[N, 7]): Detected bounding boxes and their associated information.
+                                      Each row should contain (x1, y1, x2, y2, conf, class)
+                                      or with an additional element `angle` when it's obb.
+            gt_bboxes (Array[M, 4] | Array[M, 5]): Ground truth bounding boxes in xyxy/xyxyr format.
+            gt_cls (Array[M]): The class labels.
+        """
+        if gt_cls.shape[0] == 0:  # Check if labels is empty
+            if detections is not None:
+                detections = detections[detections[:, 4] > self.conf]
+                detection_classes = detections[:, 5].int()
+                for dc in detection_classes:
+                    self.matrix[dc, self.nc] += 1  # false positives
+            return
+        if detections is None:
+            gt_classes = gt_cls.int()
+            for gc in gt_classes:
+                self.matrix[self.nc, gc] += 1  # background FN
+            return
+
+        detections = detections[detections[:, 4] > self.conf]
+        gt_classes = gt_cls.int()
+        detection_classes = detections[:, 5].int()
+        is_obb = detections.shape[1] == 7 and gt_bboxes.shape[1] == 5  # with additional `angle` dimension
+        iou = (
+            batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
+            if is_obb
+            else box_iou(gt_bboxes, detections[:, :4])
+        )
+
+        x = torch.where(iou > self.iou_thres)
+        if x[0].shape[0]:
+            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
+            if x[0].shape[0] > 1:
+                matches = matches[matches[:, 2].argsort()[::-1]]
+                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
+                matches = matches[matches[:, 2].argsort()[::-1]]
+                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
+        else:
+            matches = np.zeros((0, 3))
+
+        n = matches.shape[0] > 0
+        m0, m1, _ = matches.transpose().astype(int)
+        for i, gc in enumerate(gt_classes):
+            j = m0 == i
+            if n and sum(j) == 1:
+                self.matrix[detection_classes[m1[j]], gc] += 1  # correct
+            else:
+                self.matrix[self.nc, gc] += 1  # true background
+
+        for i, dc in enumerate(detection_classes):
+            if not any(m1 == i):
+                self.matrix[dc, self.nc] += 1  # predicted background
+
+    def matrix(self):
+        """Returns the confusion matrix."""
+        # NOTE: the `self.matrix` array assigned in __init__ shadows this method on instances,
+        # so the matrix should be read as an attribute rather than called.
+        return self.matrix
+
+    def tp_fp(self):
+        """Returns true positives and false positives."""
+        tp = self.matrix.diagonal()  # true positives
+        fp = self.matrix.sum(1) - tp  # false positives
+        # fn = self.matrix.sum(0) - tp  # false negatives (missed detections)
+        return (tp[:-1], fp[:-1]) if self.task == "detect" else (tp, fp)  # remove background class if task=detect
+
+    @TryExcept("WARNING ⚠️ ConfusionMatrix plot failure")
+    @plt_settings()
+    def plot(self, normalize=True, save_dir="", names=(), on_plot=None):
+        """
+        Plot the confusion matrix using seaborn and save it to a file.
+
+        Args:
+            normalize (bool): Whether to normalize the confusion matrix.
+            save_dir (str): Directory where the plot will be saved.
+            names (tuple): Names of classes, used as labels on the plot.
+            on_plot (func): An optional callback to pass plots path and data when they are rendered.
+        """
+        import seaborn  # scope for faster 'import ultralytics'
+
+        array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1)  # normalize columns
+        array[array < 0.005] = np.nan  # don't annotate (would appear as 0.00)
+
+        fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
+        nc, nn = self.nc, len(names)  # number of classes, names
+        seaborn.set_theme(font_scale=1.0 if nc < 50 else 0.8)  # for label size
+        labels = (0 < nn < 99) and (nn == nc)  # apply names to ticklabels
+        ticklabels = (list(names) + ["background"]) if labels else "auto"
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")  # suppress empty matrix RuntimeWarning: All-NaN slice encountered
+            seaborn.heatmap(
+                array,
+                ax=ax,
+                annot=nc < 30,
+                annot_kws={"size": 8},
+                cmap="Blues",
+                fmt=".2f" if normalize else ".0f",
+                square=True,
+                vmin=0.0,
+                xticklabels=ticklabels,
+                yticklabels=ticklabels,
+            ).set_facecolor((1, 1, 1))
+        title = "Confusion Matrix" + " Normalized" * normalize
+        ax.set_xlabel("True")
+        ax.set_ylabel("Predicted")
+        ax.set_title(title)
+        plot_fname = Path(save_dir) / f"{title.lower().replace(' ', '_')}.png"
+        fig.savefig(plot_fname, dpi=250)
+        plt.close(fig)
+        if on_plot:
+            on_plot(plot_fname)
+
+    def print(self):
+        """Print the confusion matrix to the console."""
+        for i in range(self.nc + 1):
+            LOGGER.info(" ".join(map(str, self.matrix[i])))
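+
+
+# Usage sketch (illustrative; `detections` rows follow the (x1, y1, x2, y2, conf, cls) layout documented above,
+# and `class_names` is assumed to be a tuple of nc class names):
+#     >>> cm = ConfusionMatrix(nc=80)
+#     >>> cm.process_batch(detections, gt_bboxes, gt_cls)  # accumulate one image at a time
+#     >>> tp, fp = cm.tp_fp()
+#     >>> cm.plot(normalize=True, save_dir="runs/val", names=class_names)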
+
+
+def smooth(y, f=0.05):
+    """Box filter of fraction f."""
+    nf = round(len(y) * f * 2) // 2 + 1  # number of filter elements (must be odd)
+    p = np.ones(nf // 2)  # ones padding
+    yp = np.concatenate((p * y[0], y, p * y[-1]), 0)  # y padded
+    return np.convolve(yp, np.ones(nf) / nf, mode="valid")  # y-smoothed
+
+
+@plt_settings()
+def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names={}, on_plot=None):
+    """Plots a precision-recall curve."""
+    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
+    py = np.stack(py, axis=1)
+
+    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
+        for i, y in enumerate(py.T):
+            ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}")  # plot(recall, precision)
+    else:
+        ax.plot(px, py, linewidth=1, color="grey")  # plot(recall, precision)
+
+    ax.plot(px, py.mean(1), linewidth=3, color="blue", label=f"all classes {ap[:, 0].mean():.3f} mAP@0.5")
+    ax.set_xlabel("Recall")
+    ax.set_ylabel("Precision")
+    ax.set_xlim(0, 1)
+    ax.set_ylim(0, 1)
+    ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
+    ax.set_title("Precision-Recall Curve")
+    fig.savefig(save_dir, dpi=250)
+    plt.close(fig)
+    if on_plot:
+        on_plot(save_dir)
+
+
+@plt_settings()
+def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names={}, xlabel="Confidence", ylabel="Metric", on_plot=None):
+    """Plots a metric-confidence curve."""
+    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
+
+    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
+        for i, y in enumerate(py):
+            ax.plot(px, y, linewidth=1, label=f"{names[i]}")  # plot(confidence, metric)
+    else:
+        ax.plot(px, py.T, linewidth=1, color="grey")  # plot(confidence, metric)
+
+    y = smooth(py.mean(0), 0.05)
+    ax.plot(px, y, linewidth=3, color="blue", label=f"all classes {y.max():.2f} at {px[y.argmax()]:.3f}")
+    ax.set_xlabel(xlabel)
+    ax.set_ylabel(ylabel)
+    ax.set_xlim(0, 1)
+    ax.set_ylim(0, 1)
+    ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
+    ax.set_title(f"{ylabel}-Confidence Curve")
+    fig.savefig(save_dir, dpi=250)
+    plt.close(fig)
+    if on_plot:
+        on_plot(save_dir)
+
+
+def compute_ap(recall, precision):
+    """
+    Compute the average precision (AP) given the recall and precision curves.
+
+    Args:
+        recall (list): The recall curve.
+        precision (list): The precision curve.
+
+    Returns:
+        (float): Average precision.
+        (np.ndarray): Precision envelope curve.
+        (np.ndarray): Modified recall curve with sentinel values added at the beginning and end.
+    """
+    # Append sentinel values to beginning and end
+    mrec = np.concatenate(([0.0], recall, [1.0]))
+    mpre = np.concatenate(([1.0], precision, [0.0]))
+
+    # Compute the precision envelope
+    mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))
+
+    # Integrate area under curve
+    method = "interp"  # methods: 'continuous', 'interp'
+    if method == "interp":
+        x = np.linspace(0, 1, 101)  # 101-point interp (COCO)
+        ap = np.trapz(np.interp(x, mrec, mpre), x)  # integrate
+    else:  # 'continuous'
+        i = np.where(mrec[1:] != mrec[:-1])[0]  # points where x-axis (recall) changes
+        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])  # area under curve
+
+    return ap, mpre, mrec
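+
+
+# Worked example: for recall [0.5, 1.0] and precision [1.0, 0.5], the precision envelope stays
+# monotonically non-increasing and the 101-point interpolated AP evaluates to roughly 0.87.
+#     >>> ap, mpre, mrec = compute_ap([0.5, 1.0], [1.0, 0.5])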
+
+
+def ap_per_class(
+    tp, conf, pred_cls, target_cls, plot=False, on_plot=None, save_dir=Path(), names={}, eps=1e-16, prefix=""
+):
+    """
+    Computes the average precision per class for object detection evaluation.
+
+    Args:
+        tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False).
+        conf (np.ndarray): Array of confidence scores of the detections.
+        pred_cls (np.ndarray): Array of predicted classes of the detections.
+        target_cls (np.ndarray): Array of true classes of the detections.
+        plot (bool, optional): Whether to plot PR curves or not. Defaults to False.
+        on_plot (func, optional): A callback to pass plots path and data when they are rendered. Defaults to None.
+        save_dir (Path, optional): Directory to save the PR curves. Defaults to an empty path.
+        names (dict, optional): Dict of class names to plot PR curves. Defaults to an empty dict.
+        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-16.
+        prefix (str, optional): A prefix string for saving the plot files. Defaults to an empty string.
+
+    Returns:
+        tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class. Shape: (nc,).
+        fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class. Shape: (nc,).
+        p (np.ndarray): Precision values at threshold given by max F1 metric for each class. Shape: (nc,).
+        r (np.ndarray): Recall values at threshold given by max F1 metric for each class. Shape: (nc,).
+        f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class. Shape: (nc,).
+        ap (np.ndarray): Average precision for each class at different IoU thresholds. Shape: (nc, 10).
+        unique_classes (np.ndarray): An array of unique classes that have data. Shape: (nc,).
+        p_curve (np.ndarray): Precision curves for each class. Shape: (nc, 1000).
+        r_curve (np.ndarray): Recall curves for each class. Shape: (nc, 1000).
+        f1_curve (np.ndarray): F1-score curves for each class. Shape: (nc, 1000).
+        x (np.ndarray): X-axis values for the curves. Shape: (1000,).
+        prec_values (np.ndarray): Precision values at mAP@0.5 for each class. Shape: (nc, 1000).
+    """
+    # Sort by objectness
+    i = np.argsort(-conf)
+    tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
+
+    # Find unique classes
+    unique_classes, nt = np.unique(target_cls, return_counts=True)
+    nc = unique_classes.shape[0]  # number of classes
+
+    # Create Precision-Recall curve and compute AP for each class
+    x, prec_values = np.linspace(0, 1, 1000), []
+
+    # Average precision, precision and recall curves
+    ap, p_curve, r_curve = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
+    for ci, c in enumerate(unique_classes):
+        i = pred_cls == c
+        n_l = nt[ci]  # number of labels
+        n_p = i.sum()  # number of predictions
+        if n_p == 0 or n_l == 0:
+            continue
+
+        # Accumulate FPs and TPs
+        fpc = (1 - tp[i]).cumsum(0)
+        tpc = tp[i].cumsum(0)
+
+        # Recall
+        recall = tpc / (n_l + eps)  # recall curve
+        r_curve[ci] = np.interp(-x, -conf[i], recall[:, 0], left=0)  # negative x, xp because xp decreases
+
+        # Precision
+        precision = tpc / (tpc + fpc)  # precision curve
+        p_curve[ci] = np.interp(-x, -conf[i], precision[:, 0], left=1)  # p at pr_score
+
+        # AP from recall-precision curve
+        for j in range(tp.shape[1]):
+            ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
+            if j == 0:
+                prec_values.append(np.interp(x, mrec, mpre))  # precision at mAP@0.5
+
+    prec_values = np.array(prec_values)  # (nc, 1000)
+
+    # Compute F1 (harmonic mean of precision and recall)
+    f1_curve = 2 * p_curve * r_curve / (p_curve + r_curve + eps)
+    names = [v for k, v in names.items() if k in unique_classes]  # list: only classes that have data
+    names = dict(enumerate(names))  # to dict
+    if plot:
+        plot_pr_curve(x, prec_values, ap, save_dir / f"{prefix}PR_curve.png", names, on_plot=on_plot)
+        plot_mc_curve(x, f1_curve, save_dir / f"{prefix}F1_curve.png", names, ylabel="F1", on_plot=on_plot)
+        plot_mc_curve(x, p_curve, save_dir / f"{prefix}P_curve.png", names, ylabel="Precision", on_plot=on_plot)
+        plot_mc_curve(x, r_curve, save_dir / f"{prefix}R_curve.png", names, ylabel="Recall", on_plot=on_plot)
+
+    i = smooth(f1_curve.mean(0), 0.1).argmax()  # max F1 index
+    p, r, f1 = p_curve[:, i], r_curve[:, i], f1_curve[:, i]  # max-F1 precision, recall, F1 values
+    tp = (r * nt).round()  # true positives
+    fp = (tp / (p + eps) - tp).round()  # false positives
+    return tp, fp, p, r, f1, ap, unique_classes.astype(int), p_curve, r_curve, f1_curve, x, prec_values
+
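+# Illustrative usage sketch (toy arrays; in practice the validator supplies these over the
+# whole dataset). tp has one column per IoU threshold, a single 0.5 column is used here for
+# brevity, and `names` maps class index to label.
+#   >>> import numpy as np
+#   >>> out = ap_per_class(
+#   ...     tp=np.array([[True], [False], [True]]),
+#   ...     conf=np.array([0.9, 0.8, 0.7]),
+#   ...     pred_cls=np.array([0, 0, 0]),
+#   ...     target_cls=np.array([0, 0]),
+#   ...     names={0: "person"},
+#   ... )
+#   >>> tp_c, fp_c, p, r, f1, ap = out[:6]  # per-class arrays; ap has shape (nc, n_iou)
+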
+
+class Metric(SimpleClass):
+    """
+    Class for computing evaluation metrics for the YOLOv8 model.
+
+    Attributes:
+        p (list): Precision for each class. Shape: (nc,).
+        r (list): Recall for each class. Shape: (nc,).
+        f1 (list): F1 score for each class. Shape: (nc,).
+        all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10).
+        ap_class_index (list): Index of class for each AP score. Shape: (nc,).
+        nc (int): Number of classes.
+
+    Methods:
+        ap50(): AP at IoU threshold of 0.5 for all classes. Returns: List of AP scores. Shape: (nc,) or [].
+        ap(): AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: List of AP scores. Shape: (nc,) or [].
+        mp(): Mean precision of all classes. Returns: Float.
+        mr(): Mean recall of all classes. Returns: Float.
+        map50(): Mean AP at IoU threshold of 0.5 for all classes. Returns: Float.
+        map75(): Mean AP at IoU threshold of 0.75 for all classes. Returns: Float.
+        map(): Mean AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: Float.
+        mean_results(): Mean of results, returns mp, mr, map50, map75, map.
+        class_result(i): Class-aware result, returns p[i], r[i], ap50[i], ap75[i], ap[i].
+        maps(): mAP of each class. Returns: Array of mAP scores, shape: (nc,).
+        fitness(): Model fitness as a weighted combination of metrics. Returns: Float.
+        update(results): Update metric attributes with new evaluation results.
+    """
+
+    def __init__(self) -> None:
+        """Initializes a Metric instance for computing evaluation metrics for the YOLOv8 model."""
+        self.p = []  # (nc, )
+        self.r = []  # (nc, )
+        self.f1 = []  # (nc, )
+        self.all_ap = []  # (nc, 10)
+        self.ap_class_index = []  # (nc, )
+        self.nc = 0
+
+    @property
+    def ap50(self):
+        """
+        Returns the Average Precision (AP) at an IoU threshold of 0.5 for all classes.
+
+        Returns:
+            (np.ndarray, list): Array of shape (nc,) with AP50 values per class, or an empty list if not available.
+        """
+        return self.all_ap[:, 0] if len(self.all_ap) else []
+
+    @property
+    def ap75(self):
+        """
+        Returns the Average Precision (AP) at an IoU threshold of 0.75 for all classes.
+
+        Returns:
+            (np.ndarray, list): Array of shape (nc,) with AP75 values per class, or an empty list if not available.
+        """
+        return self.all_ap[:, 5] if len(self.all_ap) else []
+
+    @property
+    def ap(self):
+        """
+        Returns the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.
+
+        Returns:
+            (np.ndarray, list): Array of shape (nc,) with AP50-95 values per class, or an empty list if not available.
+        """
+        return self.all_ap.mean(1) if len(self.all_ap) else []
+
+    @property
+    def mp(self):
+        """
+        Returns the Mean Precision of all classes.
+
+        Returns:
+            (float): The mean precision of all classes.
+        """
+        return self.p.mean() if len(self.p) else 0.0
+
+    @property
+    def mr(self):
+        """
+        Returns the Mean Recall of all classes.
+
+        Returns:
+            (float): The mean recall of all classes.
+        """
+        return self.r.mean() if len(self.r) else 0.0
+
+    @property
+    def map50(self):
+        """
+        Returns the mean Average Precision (mAP) at an IoU threshold of 0.5.
+
+        Returns:
+            (float): The mAP at an IoU threshold of 0.5.
+        """
+        return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0
+
+    @property
+    def map75(self):
+        """
+        Returns the mean Average Precision (mAP) at an IoU threshold of 0.75.
+
+        Returns:
+            (float): The mAP at an IoU threshold of 0.75.
+        """
+        return self.all_ap[:, 5].mean() if len(self.all_ap) else 0.0
+
+    @property
+    def map(self):
+        """
+        Returns the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
+
+        Returns:
+            (float): The mAP over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
+        """
+        return self.all_ap.mean() if len(self.all_ap) else 0.0
+
+    def mean_results(self):
+        """Mean of results, return mp, mr, map50, map."""
+        return [self.mp, self.mr, self.map50, self.map75,self.map]
+
+    def class_result(self, i):
+        """Class-aware result, return p[i], r[i], ap50[i], ap[i]."""
+        return self.p[i], self.r[i], self.ap50[i], self.ap75[i], self.ap[i]
+
+    @property
+    def maps(self):
+        """MAP of each class."""
+        maps = np.zeros(self.nc) + self.map
+        for i, c in enumerate(self.ap_class_index):
+            maps[c] = self.ap[i]
+        return maps
+
+    def fitness(self):
+        """Model fitness as a weighted combination of metrics."""
+        w = [0.0, 0.0, 0.0, 0.0, 1.0]  # weights for [P, R, mAP@0.5, mAP@0.75, mAP@0.5:0.95]
+        return (np.array(self.mean_results()) * w).sum()
+
+    def update(self, results):
+        """
+        Updates the evaluation metrics of the model with a new set of results.
+
+        Args:
+            results (tuple): A tuple containing the following evaluation metrics:
+                - p (list): Precision for each class. Shape: (nc,).
+                - r (list): Recall for each class. Shape: (nc,).
+                - f1 (list): F1 score for each class. Shape: (nc,).
+                - all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10).
+                - ap_class_index (list): Index of class for each AP score. Shape: (nc,).
+
+        Side Effects:
+            Updates the class attributes `self.p`, `self.r`, `self.f1`, `self.all_ap`, and `self.ap_class_index` based
+            on the values provided in the `results` tuple.
+        """
+        (
+            self.p,
+            self.r,
+            self.f1,
+            self.all_ap,
+            self.ap_class_index,
+            self.p_curve,
+            self.r_curve,
+            self.f1_curve,
+            self.px,
+            self.prec_values,
+        ) = results
+
+    @property
+    def curves(self):
+        """Returns a list of curves for accessing specific metrics curves."""
+        return []
+
+    @property
+    def curves_results(self):
+        """Returns a list of curves for accessing specific metrics curves."""
+        return [
+            [self.px, self.prec_values, "Recall", "Precision"],
+            [self.px, self.f1_curve, "Confidence", "F1"],
+            [self.px, self.p_curve, "Confidence", "Precision"],
+            [self.px, self.r_curve, "Confidence", "Recall"],
+        ]
+
+
+class DetMetrics(SimpleClass):
+    """
+    Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP) of an
+    object detection model.
+
+    Args:
+        save_dir (Path): A path to the directory where the output plots will be saved. Defaults to current directory.
+        plot (bool): A flag that indicates whether to plot precision-recall curves for each class. Defaults to False.
+        on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
+        names (dict of str): A dict of strings that represents the names of the classes. Defaults to an empty dict.
+
+    Attributes:
+        save_dir (Path): A path to the directory where the output plots will be saved.
+        plot (bool): A flag that indicates whether to plot the precision-recall curves for each class.
+        on_plot (func): An optional callback to pass plots path and data when they are rendered.
+        names (dict of str): A dict of strings that represents the names of the classes.
+        box (Metric): An instance of the Metric class for storing the results of the detection metrics.
+        speed (dict): A dictionary for storing the execution time of different parts of the detection process.
+
+    Methods:
+        process(tp, conf, pred_cls, target_cls): Updates the metric results with the latest batch of predictions.
+        keys: Returns a list of keys for accessing the computed detection metrics.
+        mean_results: Returns a list of mean values for the computed detection metrics.
+        class_result(i): Returns a list of values for the computed detection metrics for a specific class.
+        maps: Returns a dictionary of mean average precision (mAP) values for different IoU thresholds.
+        fitness: Computes the fitness score based on the computed detection metrics.
+        ap_class_index: Returns a list of class indices sorted by their average precision (AP) values.
+        results_dict: Returns a dictionary that maps detection metric keys to their computed values.
+        curves: TODO
+        curves_results: TODO
+    """
+
+    def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names={}) -> None:
+        """Initialize a DetMetrics instance with a save directory, plot flag, callback function, and class names."""
+        self.save_dir = save_dir
+        self.plot = plot
+        self.on_plot = on_plot
+        self.names = names
+        self.box = Metric()
+        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
+        self.task = "detect"
+
+    def process(self, tp, conf, pred_cls, target_cls):
+        """Process predicted results for object detection and update metrics."""
+        results = ap_per_class(
+            tp,
+            conf,
+            pred_cls,
+            target_cls,
+            plot=self.plot,
+            save_dir=self.save_dir,
+            names=self.names,
+            on_plot=self.on_plot,
+        )[2:]
+        self.box.nc = len(self.names)
+        self.box.update(results)
+
+    @property
+    def keys(self):
+        """Returns a list of keys for accessing specific metrics."""
+        return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP75(B)", "metrics/mAP50-95(B)"]
+
+    def mean_results(self):
+        """Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
+        return self.box.mean_results()
+
+    def class_result(self, i):
+        """Return the result of evaluating the performance of an object detection model on a specific class."""
+        return self.box.class_result(i)
+
+    @property
+    def maps(self):
+        """Returns mean Average Precision (mAP) scores per class."""
+        return self.box.maps
+
+    @property
+    def fitness(self):
+        """Returns the fitness of box object."""
+        return self.box.fitness()
+
+    @property
+    def ap_class_index(self):
+        """Returns the average precision index per class."""
+        return self.box.ap_class_index
+
+    @property
+    def results_dict(self):
+        """Returns dictionary of computed performance metrics and statistics."""
+        return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
+
+    @property
+    def curves(self):
+        """Returns a list of curves for accessing specific metrics curves."""
+        return ["Precision-Recall(B)", "F1-Confidence(B)", "Precision-Confidence(B)", "Recall-Confidence(B)"]
+
+    @property
+    def curves_results(self):
+        """Returns dictionary of computed performance metrics and statistics."""
+        return self.box.curves_results
+
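+# Illustrative usage sketch: DetMetrics wraps ap_per_class() and exposes the aggregated
+# values through results_dict. The arrays below are toy inputs; every detection is a true
+# positive, so precision and recall at the max-F1 point come out as 1.0.
+#   >>> import numpy as np
+#   >>> dm = DetMetrics(names={0: "person"})
+#   >>> dm.process(
+#   ...     tp=np.ones((5, 10), dtype=bool),  # 5 detections x 10 IoU thresholds
+#   ...     conf=np.linspace(0.9, 0.5, 5),
+#   ...     pred_cls=np.zeros(5),
+#   ...     target_cls=np.zeros(5),
+#   ... )
+#   >>> dm.results_dict["metrics/precision(B)"], dm.results_dict["metrics/recall(B)"]  # (1.0, 1.0)
+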
+
+class SegmentMetrics(SimpleClass):
+    """
+    Calculates and aggregates detection and segmentation metrics over a given set of classes.
+
+    Args:
+        save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory.
+        plot (bool): Whether to save the detection and segmentation plots. Default is False.
+        on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
+        names (dict): Dictionary of class names.
+
+    Attributes:
+        save_dir (Path): Path to the directory where the output plots should be saved.
+        plot (bool): Whether to save the detection and segmentation plots.
+        on_plot (func): An optional callback to pass plots path and data when they are rendered.
+        names (list): List of class names.
+        box (Metric): An instance of the Metric class to calculate box detection metrics.
+        seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
+        speed (dict): Dictionary to store the time taken in different phases of inference.
+
+    Methods:
+        process(tp, tp_m, conf, pred_cls, target_cls): Processes metrics over the given set of predictions.
+        mean_results(): Returns the mean of the detection and segmentation metrics over all the classes.
+        class_result(i): Returns the detection and segmentation metrics of class `i`.
+        maps: Returns the mean Average Precision (mAP) scores for IoU thresholds ranging from 0.50 to 0.95.
+        fitness: Returns the fitness scores, which are a single weighted combination of metrics.
+        ap_class_index: Returns the list of indices of classes used to compute Average Precision (AP).
+        results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score.
+    """
+
+    def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
+        """Initialize a SegmentMetrics instance with a save directory, plot flag, callback function, and class names."""
+        self.save_dir = save_dir
+        self.plot = plot
+        self.on_plot = on_plot
+        self.names = names
+        self.box = Metric()
+        self.seg = Metric()
+        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
+        self.task = "segment"
+
+    def process(self, tp, tp_m, conf, pred_cls, target_cls):
+        """
+        Processes the detection and segmentation metrics over the given set of predictions.
+
+        Args:
+            tp (list): List of True Positive boxes.
+            tp_m (list): List of True Positive masks.
+            conf (list): List of confidence scores.
+            pred_cls (list): List of predicted classes.
+            target_cls (list): List of target classes.
+        """
+        results_mask = ap_per_class(
+            tp_m,
+            conf,
+            pred_cls,
+            target_cls,
+            plot=self.plot,
+            on_plot=self.on_plot,
+            save_dir=self.save_dir,
+            names=self.names,
+            prefix="Mask",
+        )[2:]
+        self.seg.nc = len(self.names)
+        self.seg.update(results_mask)
+        results_box = ap_per_class(
+            tp,
+            conf,
+            pred_cls,
+            target_cls,
+            plot=self.plot,
+            on_plot=self.on_plot,
+            save_dir=self.save_dir,
+            names=self.names,
+            prefix="Box",
+        )[2:]
+        self.box.nc = len(self.names)
+        self.box.update(results_box)
+
+    @property
+    def keys(self):
+        """Returns a list of keys for accessing metrics."""
+        return [
+            "metrics/precision(B)",
+            "metrics/recall(B)",
+            "metrics/mAP50(B)",
+            "metrics/mAP75(B)",
+            "metrics/mAP50-95(B)",
+            "metrics/precision(M)",
+            "metrics/recall(M)",
+            "metrics/mAP50(M)",
+            "metrics/mAP75(M)",
+            "metrics/mAP50-95(M)",
+        ]
+
+    def mean_results(self):
+        """Return the mean metrics for bounding box and segmentation results."""
+        return self.box.mean_results() + self.seg.mean_results()
+
+    def class_result(self, i):
+        """Returns classification results for a specified class index."""
+        return self.box.class_result(i) + self.seg.class_result(i)
+
+    @property
+    def maps(self):
+        """Returns mAP scores for object detection and semantic segmentation models."""
+        return self.box.maps + self.seg.maps
+
+    @property
+    def fitness(self):
+        """Get the fitness score for both segmentation and bounding box models."""
+        return self.seg.fitness() + self.box.fitness()
+
+    @property
+    def ap_class_index(self):
+        """Boxes and masks have the same ap_class_index."""
+        return self.box.ap_class_index
+
+    @property
+    def results_dict(self):
+        """Returns results of object detection model for evaluation."""
+        return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
+
+    @property
+    def curves(self):
+        """Returns a list of curves for accessing specific metrics curves."""
+        return [
+            "Precision-Recall(B)",
+            "F1-Confidence(B)",
+            "Precision-Confidence(B)",
+            "Recall-Confidence(B)",
+            "Precision-Recall(M)",
+            "F1-Confidence(M)",
+            "Precision-Confidence(M)",
+            "Recall-Confidence(M)",
+        ]
+
+    @property
+    def curves_results(self):
+        """Returns dictionary of computed performance metrics and statistics."""
+        return self.box.curves_results + self.seg.curves_results
+
+
+class PoseMetrics(SegmentMetrics):
+    """
+    Calculates and aggregates detection and pose metrics over a given set of classes.
+
+    Args:
+        save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory.
+        plot (bool): Whether to save the detection and pose plots. Default is False.
+        on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
+        names (dict): Dictionary of class names.
+
+    Attributes:
+        save_dir (Path): Path to the directory where the output plots should be saved.
+        plot (bool): Whether to save the detection and pose plots.
+        on_plot (func): An optional callback to pass plots path and data when they are rendered.
+        names (dict): Dictionary of class names.
+        box (Metric): An instance of the Metric class to calculate box detection metrics.
+        pose (Metric): An instance of the Metric class to calculate pose keypoint metrics.
+        speed (dict): Dictionary to store the time taken in different phases of inference.
+
+    Methods:
+        process(tp, tp_p, conf, pred_cls, target_cls): Processes metrics over the given set of predictions.
+        mean_results(): Returns the mean of the detection and pose metrics over all the classes.
+        class_result(i): Returns the detection and pose metrics of class `i`.
+        maps: Returns the mean Average Precision (mAP) scores for IoU thresholds ranging from 0.50 to 0.95.
+        fitness: Returns the fitness scores, which are a single weighted combination of metrics.
+        ap_class_index: Returns the list of indices of classes used to compute Average Precision (AP).
+        results_dict: Returns the dictionary containing all the detection and pose metrics and fitness score.
+    """
+
+    def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
+        """Initialize the PoseMetrics class with directory path, class names, and plotting options."""
+        super().__init__(save_dir, plot, on_plot, names)
+        self.save_dir = save_dir
+        self.plot = plot
+        self.on_plot = on_plot
+        self.names = names
+        self.box = Metric()
+        self.pose = Metric()
+        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
+        self.task = "pose"
+
+    def process(self, tp, tp_p, conf, pred_cls, target_cls):
+        """
+        Processes the detection and pose metrics over the given set of predictions.
+
+        Args:
+            tp (list): List of True Positive boxes.
+            tp_p (list): List of True Positive keypoints.
+            conf (list): List of confidence scores.
+            pred_cls (list): List of predicted classes.
+            target_cls (list): List of target classes.
+        """
+        results_pose = ap_per_class(
+            tp_p,
+            conf,
+            pred_cls,
+            target_cls,
+            plot=self.plot,
+            on_plot=self.on_plot,
+            save_dir=self.save_dir,
+            names=self.names,
+            prefix="Pose",
+        )[2:]
+        self.pose.nc = len(self.names)
+        self.pose.update(results_pose)
+        results_box = ap_per_class(
+            tp,
+            conf,
+            pred_cls,
+            target_cls,
+            plot=self.plot,
+            on_plot=self.on_plot,
+            save_dir=self.save_dir,
+            names=self.names,
+            prefix="Box",
+        )[2:]
+        self.box.nc = len(self.names)
+        self.box.update(results_box)
+
+    @property
+    def keys(self):
+        """Returns list of evaluation metric keys."""
+        return [
+            "metrics/precision(B)",
+            "metrics/recall(B)",
+            "metrics/mAP50(B)",
+            "metrics/mAP75(B)",
+            "metrics/mAP50-95(B)",
+            "metrics/precision(P)",
+            "metrics/recall(P)",
+            "metrics/mAP50(P)",
+            "metrics/mAP75(P)",
+            "metrics/mAP50-95(P)",
+        ]
+
+    def mean_results(self):
+        """Return the mean results of box and pose."""
+        return self.box.mean_results() + self.pose.mean_results()
+
+    def class_result(self, i):
+        """Return the class-wise detection results for a specific class i."""
+        return self.box.class_result(i) + self.pose.class_result(i)
+
+    @property
+    def maps(self):
+        """Returns the mean average precision (mAP) per class for both box and pose detections."""
+        return self.box.maps + self.pose.maps
+
+    @property
+    def fitness(self):
+        """Computes classification metrics and speed using the `targets` and `pred` inputs."""
+        return self.pose.fitness() + self.box.fitness()
+
+    @property
+    def curves(self):
+        """Returns a list of curves for accessing specific metrics curves."""
+        return [
+            "Precision-Recall(B)",
+            "F1-Confidence(B)",
+            "Precision-Confidence(B)",
+            "Recall-Confidence(B)",
+            "Precision-Recall(P)",
+            "F1-Confidence(P)",
+            "Precision-Confidence(P)",
+            "Recall-Confidence(P)",
+        ]
+
+    @property
+    def curves_results(self):
+        """Returns dictionary of computed performance metrics and statistics."""
+        return self.box.curves_results + self.pose.curves_results
+
+
+class ClassifyMetrics(SimpleClass):
+    """
+    Class for computing classification metrics including top-1 and top-5 accuracy.
+
+    Attributes:
+        top1 (float): The top-1 accuracy.
+        top5 (float): The top-5 accuracy.
+        speed (Dict[str, float]): A dictionary containing the time taken for each step in the pipeline.
+        fitness (float): The fitness of the model, which is equal to top-5 accuracy.
+        results_dict (Dict[str, Union[float, str]]): A dictionary containing the classification metrics and fitness.
+        keys (List[str]): A list of keys for the results_dict.
+
+    Methods:
+        process(targets, pred): Processes the targets and predictions to compute classification metrics.
+    """
+
+    def __init__(self) -> None:
+        """Initialize a ClassifyMetrics instance."""
+        self.top1 = 0
+        self.top5 = 0
+        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
+        self.task = "classify"
+
+    def process(self, targets, pred):
+        """Target classes and predicted classes."""
+        pred, targets = torch.cat(pred), torch.cat(targets)
+        correct = (targets[:, None] == pred).float()
+        acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1)  # (top1, top5) accuracy
+        self.top1, self.top5 = acc.mean(0).tolist()
+
+    @property
+    def fitness(self):
+        """Returns mean of top-1 and top-5 accuracies as fitness score."""
+        return (self.top1 + self.top5) / 2
+
+    @property
+    def results_dict(self):
+        """Returns a dictionary with model's performance metrics and fitness score."""
+        return dict(zip(self.keys + ["fitness"], [self.top1, self.top5, self.fitness]))
+
+    @property
+    def keys(self):
+        """Returns a list of keys for the results_dict property."""
+        return ["metrics/accuracy_top1", "metrics/accuracy_top5"]
+
+    @property
+    def curves(self):
+        """Returns a list of curves for accessing specific metrics curves."""
+        return []
+
+    @property
+    def curves_results(self):
+        """Returns a list of curves for accessing specific metrics curves."""
+        return []
+
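+# Illustrative usage sketch: `targets` and `pred` are lists of per-batch tensors; each row
+# of `pred` holds the top-5 predicted class indices for one sample (toy values).
+#   >>> import torch
+#   >>> cm = ClassifyMetrics()
+#   >>> cm.process(targets=[torch.tensor([0, 1])],
+#   ...            pred=[torch.tensor([[0, 2, 3, 4, 5], [2, 1, 0, 3, 4]])])
+#   >>> cm.top1, cm.top5  # (0.5, 1.0): sample 0 is right at top-1, sample 1 only within top-5
+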
+
+class OBBMetrics(SimpleClass):
+    """Metrics for evaluating oriented bounding box (OBB) detection, see https://arxiv.org/pdf/2106.06072.pdf."""
+
+    def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
+        """Initialize an OBBMetrics instance with directory, plotting, callback, and class names."""
+        self.save_dir = save_dir
+        self.plot = plot
+        self.on_plot = on_plot
+        self.names = names
+        self.box = Metric()
+        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
+
+    def process(self, tp, conf, pred_cls, target_cls):
+        """Process predicted results for object detection and update metrics."""
+        results = ap_per_class(
+            tp,
+            conf,
+            pred_cls,
+            target_cls,
+            plot=self.plot,
+            save_dir=self.save_dir,
+            names=self.names,
+            on_plot=self.on_plot,
+        )[2:]
+        self.box.nc = len(self.names)
+        self.box.update(results)
+
+    @property
+    def keys(self):
+        """Returns a list of keys for accessing specific metrics."""
+        return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
+
+    def mean_results(self):
+        """Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
+        return self.box.mean_results()
+
+    def class_result(self, i):
+        """Return the result of evaluating the performance of an object detection model on a specific class."""
+        return self.box.class_result(i)
+
+    @property
+    def maps(self):
+        """Returns mean Average Precision (mAP) scores per class."""
+        return self.box.maps
+
+    @property
+    def fitness(self):
+        """Returns the fitness of box object."""
+        return self.box.fitness()
+
+    @property
+    def ap_class_index(self):
+        """Returns the average precision index per class."""
+        return self.box.ap_class_index
+
+    @property
+    def results_dict(self):
+        """Returns dictionary of computed performance metrics and statistics."""
+        return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
+
+    @property
+    def curves(self):
+        """Returns a list of curves for accessing specific metrics curves."""
+        return []
+
+    @property
+    def curves_results(self):
+        """Returns a list of curves for accessing specific metrics curves."""
+        return []

+ 854 - 0
ultralytics/utils/ops.py

@@ -0,0 +1,854 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import contextlib
+import math
+import re
+import time
+
+import cv2
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from ultralytics.utils import LOGGER
+from ultralytics.utils.metrics import batch_probiou
+
+
+class Profile(contextlib.ContextDecorator):
+    """
+    YOLOv8 Profile class. Use as a decorator with @Profile() or as a context manager with 'with Profile():'.
+
+    Example:
+        ```python
+        from ultralytics.utils.ops import Profile
+
+        with Profile(device=device) as dt:
+            pass  # slow operation here
+
+        print(dt)  # prints "Elapsed time is 9.5367431640625e-07 s"
+        ```
+    """
+
+    def __init__(self, t=0.0, device: torch.device = None):
+        """
+        Initialize the Profile class.
+
+        Args:
+            t (float): Initial time. Defaults to 0.0.
+            device (torch.device): Devices used for model inference. Defaults to None (cpu).
+        """
+        self.t = t
+        self.device = device
+        self.cuda = bool(device and str(device).startswith("cuda"))
+
+    def __enter__(self):
+        """Start timing."""
+        self.start = self.time()
+        return self
+
+    def __exit__(self, type, value, traceback):  # noqa
+        """Stop timing."""
+        self.dt = self.time() - self.start  # delta-time
+        self.t += self.dt  # accumulate dt
+
+    def __str__(self):
+        """Returns a human-readable string representing the accumulated elapsed time in the profiler."""
+        return f"Elapsed time is {self.t} s"
+
+    def time(self):
+        """Get current time."""
+        if self.cuda:
+            torch.cuda.synchronize(self.device)
+        return time.time()
+
+
+def segment2box(segment, width=640, height=640):
+    """
+    Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy).
+
+    Args:
+        segment (torch.Tensor): the segment label
+        width (int): the width of the image. Defaults to 640
+        height (int): The height of the image. Defaults to 640
+
+    Returns:
+        (np.ndarray): the minimum and maximum x and y values of the segment.
+    """
+    x, y = segment.T  # segment xy
+    # if any 3 of the 4 sides are outside the image, clip coordinates first, https://github.com/ultralytics/ultralytics/pull/18294
+    if np.array([x.min() < 0, y.min() < 0, x.max() > width, y.max() > height]).sum() >= 3:
+        x = x.clip(0, width)
+        y = y.clip(0, height)
+    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
+    x = x[inside]
+    y = y[inside]
+    return (
+        np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype)
+        if any(x)
+        else np.zeros(4, dtype=segment.dtype)
+    )  # xyxy
+
+
+def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False):
+    """
+    Rescales bounding boxes (in the format of xyxy by default) from the shape of the image they were originally
+    specified in (img1_shape) to the shape of a different image (img0_shape).
+
+    Args:
+        img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
+        boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
+        img0_shape (tuple): the shape of the target image, in the format of (height, width).
+        ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
+            calculated based on the size difference between the two images.
+        padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
+            rescaling.
+        xywh (bool): The box format is xywh or not, default=False.
+
+    Returns:
+        boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
+    """
+    if ratio_pad is None:  # calculate from img0_shape
+        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain  = old / new
+        pad = (
+            round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1),
+            round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1),
+        )  # wh padding
+    else:
+        gain = ratio_pad[0][0]
+        pad = ratio_pad[1]
+
+    if padding:
+        boxes[..., 0] -= pad[0]  # x padding
+        boxes[..., 1] -= pad[1]  # y padding
+        if not xywh:
+            boxes[..., 2] -= pad[0]  # x padding
+            boxes[..., 3] -= pad[1]  # y padding
+    boxes[..., :4] /= gain
+    return clip_boxes(boxes, img0_shape)
+
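+# Illustrative usage sketch (made-up coordinates): map a box predicted on a 640x640
+# letterboxed input back to the original 480x640 frame. With gain 1.0 the letterbox adds
+# 80 px of vertical padding, which is subtracted before clipping.
+#   >>> import torch
+#   >>> b = torch.tensor([[100.0, 120.0, 300.0, 360.0]])   # xyxy in 640x640 space
+#   >>> scale_boxes((640, 640), b.clone(), (480, 640))     # -> [[100., 40., 300., 280.]]
+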
+
+def make_divisible(x, divisor):
+    """
+    Returns the nearest number that is divisible by the given divisor.
+
+    Args:
+        x (int): The number to make divisible.
+        divisor (int | torch.Tensor): The divisor.
+
+    Returns:
+        (int): The nearest number divisible by the divisor.
+    """
+    if isinstance(divisor, torch.Tensor):
+        divisor = int(divisor.max())  # to int
+    return math.ceil(x / divisor) * divisor
+
+
+def nms_rotated(boxes, scores, threshold=0.45):
+    """
+    NMS for oriented bounding boxes using probiou and fast-nms.
+
+    Args:
+        boxes (torch.Tensor): Rotated bounding boxes, shape (N, 5), format xywhr.
+        scores (torch.Tensor): Confidence scores, shape (N,).
+        threshold (float, optional): IoU threshold. Defaults to 0.45.
+
+    Returns:
+        (torch.Tensor): Indices of boxes to keep after NMS.
+    """
+    if len(boxes) == 0:
+        return np.empty((0,), dtype=np.int8)
+    sorted_idx = torch.argsort(scores, descending=True)
+    boxes = boxes[sorted_idx]
+    ious = batch_probiou(boxes, boxes).triu_(diagonal=1)
+    pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
+    return sorted_idx[pick]
+
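+# Illustrative usage sketch (toy xywhr boxes): the second box is a near-duplicate of the
+# first, so its probiou exceeds the threshold and only indices 0 and 2 are kept.
+#   >>> import torch
+#   >>> rb = torch.tensor([[50.0, 50.0, 20.0, 10.0, 0.0],
+#   ...                    [51.0, 50.0, 20.0, 10.0, 0.0],
+#   ...                    [200.0, 200.0, 20.0, 10.0, 0.5]])
+#   >>> nms_rotated(rb, torch.tensor([0.9, 0.8, 0.7]))      # tensor([0, 2])
+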
+
+def non_max_suppression(
+    prediction,
+    conf_thres=0.25,
+    iou_thres=0.45,
+    classes=None,
+    agnostic=False,
+    multi_label=False,
+    labels=(),
+    max_det=300,
+    nc=0,  # number of classes (optional)
+    max_time_img=0.05,
+    max_nms=30000,
+    max_wh=7680,
+    in_place=True,
+    rotated=False,
+):
+    """
+    Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
+
+    Args:
+        prediction (torch.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes)
+            containing the predicted boxes, classes, and masks. The tensor should be in the format
+            output by a model, such as YOLO.
+        conf_thres (float): The confidence threshold below which boxes will be filtered out.
+            Valid values are between 0.0 and 1.0.
+        iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS.
+            Valid values are between 0.0 and 1.0.
+        classes (List[int]): A list of class indices to consider. If None, all classes will be considered.
+        agnostic (bool): If True, the model is agnostic to the number of classes, and all
+            classes will be considered as one.
+        multi_label (bool): If True, each box may have multiple labels.
+        labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner
+            list contains the apriori labels for a given image. The list should be in the format
+            output by a dataloader, with each label being a tuple of (class_index, x1, y1, x2, y2).
+        max_det (int): The maximum number of boxes to keep after NMS.
+        nc (int, optional): The number of classes output by the model. Any indices after this will be considered masks.
+        max_time_img (float): The maximum time (seconds) for processing one image.
+        max_nms (int): The maximum number of boxes into torchvision.ops.nms().
+        max_wh (int): The maximum box width and height in pixels.
+        in_place (bool): If True, the input prediction tensor will be modified in place.
+        rotated (bool): If Oriented Bounding Boxes (OBB) are being passed for NMS.
+
+    Returns:
+        (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
+            shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
+            (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
+    """
+    import torchvision  # scope for faster 'import ultralytics'
+
+    # Checks
+    assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
+    assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
+    if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation mode, output = (inference_out, loss_out)
+        prediction = prediction[0]  # select only inference output
+    if classes is not None:
+        classes = torch.tensor(classes, device=prediction.device)
+
+    if prediction.shape[-1] == 6:  # end-to-end model (BNC, i.e. 1,300,6)
+        output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction]
+        if classes is not None:
+            output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output]
+        return output
+
+    bs = prediction.shape[0]  # batch size (BCN, i.e. 1,84,6300)
+    nc = nc or (prediction.shape[1] - 4)  # number of classes
+    nm = prediction.shape[1] - nc - 4  # number of masks
+    mi = 4 + nc  # mask start index
+    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates
+
+    # Settings
+    # min_wh = 2  # (pixels) minimum box width and height
+    time_limit = 2.0 + max_time_img * bs  # seconds to quit after
+    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
+
+    prediction = prediction.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84)
+    if not rotated:
+        if in_place:
+            prediction[..., :4] = xywh2xyxy(prediction[..., :4])  # xywh to xyxy
+        else:
+            prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1)  # xywh to xyxy
+
+    t = time.time()
+    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
+    for xi, x in enumerate(prediction):  # image index, image inference
+        # Apply constraints
+        # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
+        x = x[xc[xi]]  # confidence
+
+        # Cat apriori labels if autolabelling
+        if labels and len(labels[xi]) and not rotated:
+            lb = labels[xi]
+            v = torch.zeros((len(lb), nc + nm + 4), device=x.device)
+            v[:, :4] = xywh2xyxy(lb[:, 1:5])  # box
+            v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
+            x = torch.cat((x, v), 0)
+
+        # If none remain process next image
+        if not x.shape[0]:
+            continue
+
+        # Detections matrix nx6 (xyxy, conf, cls)
+        box, cls, mask = x.split((4, nc, nm), 1)
+
+        if multi_label:
+            i, j = torch.where(cls > conf_thres)
+            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
+        else:  # best class only
+            conf, j = cls.max(1, keepdim=True)
+            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
+
+        # Filter by class
+        if classes is not None:
+            x = x[(x[:, 5:6] == classes).any(1)]
+
+        # Check shape
+        n = x.shape[0]  # number of boxes
+        if not n:  # no boxes
+            continue
+        if n > max_nms:  # excess boxes
+            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes
+
+        # Batched NMS
+        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
+        scores = x[:, 4]  # scores
+        if rotated:
+            boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1)  # xywhr
+            i = nms_rotated(boxes, scores, iou_thres)
+        else:
+            boxes = x[:, :4] + c  # boxes (offset by class)
+            i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
+        i = i[:max_det]  # limit detections
+
+        # # Experimental
+        # merge = False  # use merge-NMS
+        # if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
+        #     # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
+        #     from .metrics import box_iou
+        #     iou = box_iou(boxes[i], boxes) > iou_thres  # IoU matrix
+        #     weights = iou * scores[None]  # box weights
+        #     x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
+        #     redundant = True  # require redundant detections
+        #     if redundant:
+        #         i = i[iou.sum(1) > 1]  # require redundancy
+
+        output[xi] = x[i]
+        if (time.time() - t) > time_limit:
+            LOGGER.warning(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded")
+            break  # time limit exceeded
+
+    return output
+
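+# Illustrative usage sketch: run NMS on a raw detection head output of shape
+# (batch, 4 + nc, num_boxes). The tensor below is random, so the call only demonstrates
+# the output structure, not meaningful detections.
+#   >>> import torch
+#   >>> raw = torch.rand(1, 84, 100)   # 1 image, 80 classes, 100 candidate boxes
+#   >>> dets = non_max_suppression(raw, conf_thres=0.25, iou_thres=0.45)
+#   >>> len(dets), dets[0].shape[1]    # (1, 6): columns are x1, y1, x2, y2, conf, cls
+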
+
+def clip_boxes(boxes, shape):
+    """
+    Takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape.
+
+    Args:
+        boxes (torch.Tensor): The bounding boxes to clip.
+        shape (tuple): The shape of the image.
+
+    Returns:
+        (torch.Tensor | numpy.ndarray): The clipped boxes.
+    """
+    if isinstance(boxes, torch.Tensor):  # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
+        boxes[..., 0] = boxes[..., 0].clamp(0, shape[1])  # x1
+        boxes[..., 1] = boxes[..., 1].clamp(0, shape[0])  # y1
+        boxes[..., 2] = boxes[..., 2].clamp(0, shape[1])  # x2
+        boxes[..., 3] = boxes[..., 3].clamp(0, shape[0])  # y2
+    else:  # np.array (faster grouped)
+        boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1])  # x1, x2
+        boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0])  # y1, y2
+    return boxes
+
+
+def clip_coords(coords, shape):
+    """
+    Clip line coordinates to the image boundaries.
+
+    Args:
+        coords (torch.Tensor | numpy.ndarray): A list of line coordinates.
+        shape (tuple): A tuple of integers representing the size of the image in the format (height, width).
+
+    Returns:
+        (torch.Tensor | numpy.ndarray): Clipped coordinates
+    """
+    if isinstance(coords, torch.Tensor):  # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
+        coords[..., 0] = coords[..., 0].clamp(0, shape[1])  # x
+        coords[..., 1] = coords[..., 1].clamp(0, shape[0])  # y
+    else:  # np.array (faster grouped)
+        coords[..., 0] = coords[..., 0].clip(0, shape[1])  # x
+        coords[..., 1] = coords[..., 1].clip(0, shape[0])  # y
+    return coords
+
+
+def scale_image(masks, im0_shape, ratio_pad=None):
+    """
+    Takes a mask, and resizes it to the original image size.
+
+    Args:
+        masks (np.ndarray): Resized and padded masks/images, [h, w, num]/[h, w, 3].
+        im0_shape (tuple): The original image shape.
+        ratio_pad (tuple): The ratio of the padding to the original image.
+
+    Returns:
+        masks (np.ndarray): The masks that are being returned with shape [h, w, num].
+    """
+    # Rescale coordinates (xyxy) from im1_shape to im0_shape
+    im1_shape = masks.shape
+    if im1_shape[:2] == im0_shape[:2]:
+        return masks
+    if ratio_pad is None:  # calculate from im0_shape
+        gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1])  # gain  = old / new
+        pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2  # wh padding
+    else:
+        # gain = ratio_pad[0][0]
+        pad = ratio_pad[1]
+    top, left = int(pad[1]), int(pad[0])  # y, x
+    bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0])
+
+    if len(masks.shape) < 2:
+        raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}')
+    masks = masks[top:bottom, left:right]
+    masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]))
+    if len(masks.shape) == 2:
+        masks = masks[:, :, None]
+
+    return masks
+
+
+def xyxy2xywh(x):
+    """
+    Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format where (x1, y1) is the
+    top-left corner and (x2, y2) is the bottom-right corner.
+
+    Args:
+        x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
+
+    Returns:
+        y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
+    """
+    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
+    y = empty_like(x)  # faster than clone/copy
+    y[..., 0] = (x[..., 0] + x[..., 2]) / 2  # x center
+    y[..., 1] = (x[..., 1] + x[..., 3]) / 2  # y center
+    y[..., 2] = x[..., 2] - x[..., 0]  # width
+    y[..., 3] = x[..., 3] - x[..., 1]  # height
+    return y
+
+
+def xywh2xyxy(x):
+    """
+    Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the
+    top-left corner and (x2, y2) is the bottom-right corner. Note: ops per 2 channels faster than per channel.
+
+    Args:
+        x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x, y, width, height) format.
+
+    Returns:
+        y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
+    """
+    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
+    y = empty_like(x)  # faster than clone/copy
+    xy = x[..., :2]  # centers
+    wh = x[..., 2:] / 2  # half width-height
+    y[..., :2] = xy - wh  # top left xy
+    y[..., 2:] = xy + wh  # bottom right xy
+    return y
+
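+# Illustrative usage sketch: a 10x20 box centred at (50, 40) in xywh format converts to
+# corner form and round-trips back through xyxy2xywh (toy values).
+#   >>> import torch
+#   >>> xywh2xyxy(torch.tensor([[50.0, 40.0, 10.0, 20.0]]))   # -> [[45., 30., 55., 50.]]
+#   >>> xyxy2xywh(torch.tensor([[45.0, 30.0, 55.0, 50.0]]))   # -> [[50., 40., 10., 20.]]
+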
+
+def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
+    """
+    Convert normalized bounding box coordinates to pixel coordinates.
+
+    Args:
+        x (np.ndarray | torch.Tensor): The bounding box coordinates.
+        w (int): Width of the image. Defaults to 640
+        h (int): Height of the image. Defaults to 640
+        padw (int): Padding width. Defaults to 0
+        padh (int): Padding height. Defaults to 0
+    Returns:
+        y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
+            x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
+    """
+    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
+    y = empty_like(x)  # faster than clone/copy
+    y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw  # top left x
+    y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh  # top left y
+    y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw  # bottom right x
+    y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh  # bottom right y
+    return y
+
+
+def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
+    """
+    Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y,
+    width and height are normalized to image dimensions.
+
+    Args:
+        x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
+        w (int): The width of the image. Defaults to 640
+        h (int): The height of the image. Defaults to 640
+        clip (bool): If True, the boxes will be clipped to the image boundaries. Defaults to False
+        eps (float): The minimum value of the box's width and height. Defaults to 0.0
+
+    Returns:
+        y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format
+    """
+    if clip:
+        x = clip_boxes(x, (h - eps, w - eps))
+    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
+    y = empty_like(x)  # faster than clone/copy
+    y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w  # x center
+    y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h  # y center
+    y[..., 2] = (x[..., 2] - x[..., 0]) / w  # width
+    y[..., 3] = (x[..., 3] - x[..., 1]) / h  # height
+    return y
+
+
+def xywh2ltwh(x):
+    """
+    Convert the bounding box format from [x, y, w, h] to [x1, y1, w, h], where x1, y1 are the top-left coordinates.
+
+    Args:
+        x (np.ndarray | torch.Tensor): The input tensor with the bounding box coordinates in the xywh format
+
+    Returns:
+        y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format
+    """
+    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
+    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
+    return y
+
+
+def xyxy2ltwh(x):
+    """
+    Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right.
+
+    Args:
+        x (np.ndarray | torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format
+
+    Returns:
+        y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format.
+    """
+    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+    y[..., 2] = x[..., 2] - x[..., 0]  # width
+    y[..., 3] = x[..., 3] - x[..., 1]  # height
+    return y
+
+
+def ltwh2xywh(x):
+    """
+    Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.
+
+    Args:
+        x (torch.Tensor): the input tensor
+
+    Returns:
+        y (np.ndarray | torch.Tensor): The bounding box coordinates in the xywh format.
+    """
+    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+    y[..., 0] = x[..., 0] + x[..., 2] / 2  # center x
+    y[..., 1] = x[..., 1] + x[..., 3] / 2  # center y
+    return y
+
+
+def xyxyxyxy2xywhr(x):
+    """
+    Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation]. Rotation values are
+    returned in radians from 0 to pi/2.
+
+    Args:
+        x (numpy.ndarray | torch.Tensor): Input box corners [xy1, xy2, xy3, xy4] of shape (n, 8).
+
+    Returns:
+        (numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format of shape (n, 5).
+    """
+    is_torch = isinstance(x, torch.Tensor)
+    points = x.cpu().numpy() if is_torch else x
+    points = points.reshape(len(x), -1, 2)
+    rboxes = []
+    for pts in points:
+        # NOTE: Use cv2.minAreaRect to get accurate xywhr,
+        # especially some objects are cut off by augmentations in dataloader.
+        (cx, cy), (w, h), angle = cv2.minAreaRect(pts)
+        rboxes.append([cx, cy, w, h, angle / 180 * np.pi])
+    return torch.tensor(rboxes, device=x.device, dtype=x.dtype) if is_torch else np.asarray(rboxes)
+
+
+def xywhr2xyxyxyxy(x):
+    """
+    Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4]. Rotation values should
+    be in radians from 0 to pi/2.
+
+    Args:
+        x (numpy.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format of shape (n, 5) or (b, n, 5).
+
+    Returns:
+        (numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 4, 2) or (b, n, 4, 2).
+    """
+    cos, sin, cat, stack = (
+        (torch.cos, torch.sin, torch.cat, torch.stack)
+        if isinstance(x, torch.Tensor)
+        else (np.cos, np.sin, np.concatenate, np.stack)
+    )
+
+    ctr = x[..., :2]
+    w, h, angle = (x[..., i : i + 1] for i in range(2, 5))
+    cos_value, sin_value = cos(angle), sin(angle)
+    vec1 = [w / 2 * cos_value, w / 2 * sin_value]
+    vec2 = [-h / 2 * sin_value, h / 2 * cos_value]
+    vec1 = cat(vec1, -1)
+    vec2 = cat(vec2, -1)
+    pt1 = ctr + vec1 + vec2
+    pt2 = ctr + vec1 - vec2
+    pt3 = ctr - vec1 - vec2
+    pt4 = ctr - vec1 + vec2
+    return stack([pt1, pt2, pt3, pt4], -2)
+
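+# Illustrative usage sketch: an axis-aligned OBB (rotation 0) centred at the origin with
+# w=4, h=2 decomposes into corners (+-2, +-1); the output shape is (n, 4, 2).
+#   >>> import torch
+#   >>> xywhr2xyxyxyxy(torch.tensor([[0.0, 0.0, 4.0, 2.0, 0.0]])).shape   # torch.Size([1, 4, 2])
+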
+
+def ltwh2xyxy(x):
+    """
+    It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.
+
+    Args:
+        x (np.ndarray | torch.Tensor): the input image
+
+    Returns:
+        y (np.ndarray | torch.Tensor): the xyxy coordinates of the bounding boxes.
+    """
+    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+    y[..., 2] = x[..., 2] + x[..., 0]  # bottom right x = x1 + w
+    y[..., 3] = x[..., 3] + x[..., 1]  # bottom right y = y1 + h
+    return y
+
+
+def segments2boxes(segments):
+    """
+    It converts segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh).
+
+    Args:
+        segments (list): list of segments, each segment is a list of points, each point is a list of x, y coordinates
+
+    Returns:
+        (np.ndarray): the xywh coordinates of the bounding boxes.
+    """
+    boxes = []
+    for s in segments:
+        x, y = s.T  # segment xy
+        boxes.append([x.min(), y.min(), x.max(), y.max()])  # cls, xyxy
+    return xyxy2xywh(np.array(boxes))  # cls, xywh
+
+
+def resample_segments(segments, n=1000):
+    """
+    Resample each segment in a list to n points, returning a list of (n, 2) arrays.
+
+    Args:
+        segments (list): a list of (n,2) arrays, where n is the number of points in the segment.
+        n (int): number of points to resample the segment to. Defaults to 1000
+
+    Returns:
+        segments (list): the resampled segments.
+    """
+    for i, s in enumerate(segments):
+        if len(s) == n:
+            continue
+        s = np.concatenate((s, s[0:1, :]), axis=0)
+        x = np.linspace(0, len(s) - 1, n - len(s) if len(s) < n else n)
+        xp = np.arange(len(s))
+        x = np.insert(x, np.searchsorted(x, xp), xp) if len(s) < n else x
+        segments[i] = (
+            np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)], dtype=np.float32).reshape(2, -1).T
+        )  # segment xy
+    return segments
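+
+# Minimal usage sketch: up-sample a coarse 3-point segment to a fixed number of points. The input
+# list is modified in place, mirroring the implementation above.
+_coarse = [np.array([[0.0, 0.0], [10.0, 0.0], [10.0, 10.0]], dtype=np.float32)]
+_resampled = resample_segments(_coarse, n=100)
+assert _resampled[0].shape == (100, 2)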
+
+
+def crop_mask(masks, boxes):
+    """
+    Crop masks to their bounding boxes, zeroing out all pixels outside each box.
+
+    Args:
+        masks (torch.Tensor): [n, h, w] tensor of masks
+        boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in xyxy format, in mask-pixel units
+
+    Returns:
+        (torch.Tensor): The masks are being cropped to the bounding box.
+    """
+    _, h, w = masks.shape
+    x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # x1 shape(n,1,1)
+    r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]  # x coordinates, shape(1,1,w)
+    c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]  # y coordinates, shape(1,h,1)
+
+    return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
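+
+# Minimal usage sketch: pixels outside each box are zeroed, pixels inside are kept.
+_demo_masks = torch.ones(1, 8, 8)                        # [n, h, w]
+_demo_boxes_xyxy = torch.tensor([[2.0, 2.0, 6.0, 6.0]])  # [n, 4] in mask-pixel xyxy coordinates
+assert int(crop_mask(_demo_masks, _demo_boxes_xyxy).sum()) == 16  # only the 4x4 interior survives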
+
+
+def process_mask(protos, masks_in, bboxes, shape, upsample=False):
+    """
+    Apply masks to bounding boxes using the output of the mask head.
+
+    Args:
+        protos (torch.Tensor): A tensor of shape [mask_dim, mask_h, mask_w].
+        masks_in (torch.Tensor): A tensor of shape [n, mask_dim], where n is the number of masks after NMS.
+        bboxes (torch.Tensor): A tensor of shape [n, 4], where n is the number of masks after NMS.
+        shape (tuple): A tuple of integers representing the size of the input image in the format (h, w).
+        upsample (bool): A flag to indicate whether to upsample the mask to the original image size. Default is False.
+
+    Returns:
+        (torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w
+            are the height and width of the input image. The mask is applied to the bounding boxes.
+    """
+    c, mh, mw = protos.shape  # CHW
+    ih, iw = shape
+    masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)  # CHW
+    width_ratio = mw / iw
+    height_ratio = mh / ih
+
+    downsampled_bboxes = bboxes.clone()
+    downsampled_bboxes[:, 0] *= width_ratio
+    downsampled_bboxes[:, 2] *= width_ratio
+    downsampled_bboxes[:, 3] *= height_ratio
+    downsampled_bboxes[:, 1] *= height_ratio
+
+    masks = crop_mask(masks, downsampled_bboxes)  # CHW
+    if upsample:
+        masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0]  # CHW
+    return masks.gt_(0.0)
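+
+# Hedged usage sketch: run the prototype-mask assembly on random tensors just to show the expected
+# shapes. The 32 prototype channels and 160x160 prototype resolution are assumptions that match
+# typical YOLO segmentation heads; they are not required by this function.
+_protos = torch.randn(32, 160, 160)  # [mask_dim, mask_h, mask_w]
+_coefs = torch.randn(2, 32)          # [n, mask_dim] mask coefficients after NMS
+_boxes = torch.tensor([[10.0, 10.0, 200.0, 200.0], [50.0, 60.0, 300.0, 400.0]])
+_m = process_mask(_protos, _coefs, _boxes, shape=(640, 640), upsample=True)  # shape (2, 640, 640), values 0/1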
+
+
+def process_mask_native(protos, masks_in, bboxes, shape):
+    """
+    Take the output of the mask head, upsample the masks to the input image shape, and crop them to the boxes.
+
+    Args:
+        protos (torch.Tensor): [mask_dim, mask_h, mask_w]
+        masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms.
+        bboxes (torch.Tensor): [n, 4], n is number of masks after nms.
+        shape (tuple): The size of the input image (h,w).
+
+    Returns:
+        masks (torch.Tensor): The returned masks with dimensions [h, w, n].
+    """
+    c, mh, mw = protos.shape  # CHW
+    masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)
+    masks = scale_masks(masks[None], shape)[0]  # CHW
+    masks = crop_mask(masks, bboxes)  # CHW
+    return masks.gt_(0.0)
+
+
+def scale_masks(masks, shape, padding=True):
+    """
+    Rescale segment masks to shape.
+
+    Args:
+        masks (torch.Tensor): (N, C, H, W).
+        shape (tuple): Height and width.
+        padding (bool): If True, assumes the boxes are based on an image augmented with YOLO-style letterboxing. If
+            False, performs regular rescaling.
+    """
+    mh, mw = masks.shape[2:]
+    gain = min(mh / shape[0], mw / shape[1])  # gain  = old / new
+    pad = [mw - shape[1] * gain, mh - shape[0] * gain]  # wh padding
+    if padding:
+        pad[0] /= 2
+        pad[1] /= 2
+    top, left = (int(pad[1]), int(pad[0])) if padding else (0, 0)  # y, x
+    bottom, right = (int(mh - pad[1]), int(mw - pad[0]))
+    masks = masks[..., top:bottom, left:right]
+
+    masks = F.interpolate(masks, shape, mode="bilinear", align_corners=False)  # NCHW
+    return masks
+
+
+def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True):
+    """
+    Rescale segment coordinates (xy) from img1_shape to img0_shape.
+
+    Args:
+        img1_shape (tuple): The shape of the image that the coords are from.
+        coords (torch.Tensor): the coords to be scaled of shape n,2.
+        img0_shape (tuple): the shape of the image that the segmentation is being applied to.
+        ratio_pad (tuple): the ratio of the image size to the padded image size.
+        normalize (bool): If True, the coordinates will be normalized to the range [0, 1]. Defaults to False.
+        padding (bool): If True, assumes the boxes are based on an image augmented with YOLO-style letterboxing. If
+            False, performs regular rescaling.
+
+    Returns:
+        coords (torch.Tensor): The scaled coordinates.
+    """
+    if ratio_pad is None:  # calculate from img0_shape
+        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain  = old / new
+        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
+    else:
+        gain = ratio_pad[0][0]
+        pad = ratio_pad[1]
+
+    if padding:
+        coords[..., 0] -= pad[0]  # x padding
+        coords[..., 1] -= pad[1]  # y padding
+    coords[..., 0] /= gain
+    coords[..., 1] /= gain
+    coords = clip_coords(coords, img0_shape)
+    if normalize:
+        coords[..., 0] /= img0_shape[1]  # width
+        coords[..., 1] /= img0_shape[0]  # height
+    return coords
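+
+# Minimal usage sketch: map a point predicted on a 640x640 letterboxed image back to the original
+# 480x640 frame. With gain 1.0 the only correction is the 80 px vertical letterbox padding.
+_pt = torch.tensor([[320.0, 320.0]])
+_scaled = scale_coords((640, 640), _pt.clone(), (480, 640))  # tensor([[320., 240.]])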
+
+
+def regularize_rboxes(rboxes):
+    """
+    Regularize rotated boxes in range [0, pi/2].
+
+    Args:
+        rboxes (torch.Tensor): Input boxes of shape(N, 5) in xywhr format.
+
+    Returns:
+        (torch.Tensor): The regularized boxes.
+    """
+    x, y, w, h, t = rboxes.unbind(dim=-1)
+    # Swap edge and angle if h >= w
+    w_ = torch.where(w > h, w, h)
+    h_ = torch.where(w > h, h, w)
+    t = torch.where(w > h, t, t + math.pi / 2) % math.pi
+    return torch.stack([x, y, w_, h_, t], dim=-1)  # regularized boxes
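+
+# Minimal usage sketch: a (w=2, h=5) box at angle 0 is re-expressed with w >= h and the angle
+# shifted by pi/2, describing the same rectangle.
+_rbox = torch.tensor([[10.0, 10.0, 2.0, 5.0, 0.0]])
+_reg = regularize_rboxes(_rbox)  # approximately tensor([[10., 10., 5., 2., 1.5708]])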
+
+
+def masks2segments(masks, strategy="all"):
+    """
+    It takes a list of masks(n,h,w) and returns a list of segments(n,xy).
+
+    Args:
+        masks (torch.Tensor): the output of the model, which is a tensor of shape (batch_size, 160, 160)
+        strategy (str): 'all' or 'largest'. Defaults to all
+
+    Returns:
+        segments (List): list of segment masks
+    """
+    from ultralytics.data.converter import merge_multi_segment
+
+    segments = []
+    for x in masks.int().cpu().numpy().astype("uint8"):
+        c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
+        if c:
+            if strategy == "all":  # merge and concatenate all segments
+                c = (
+                    np.concatenate(merge_multi_segment([x.reshape(-1, 2) for x in c]))
+                    if len(c) > 1
+                    else c[0].reshape(-1, 2)
+                )
+            elif strategy == "largest":  # select largest segment
+                c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
+        else:
+            c = np.zeros((0, 2))  # no segments found
+        segments.append(c.astype("float32"))
+    return segments
+
+
+def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray:
+    """
+    Convert a batch of FP32 torch tensors (0.0-1.0) to a NumPy uint8 array (0-255), changing from BCHW to BHWC layout.
+
+    Args:
+        batch (torch.Tensor): Input tensor batch of shape (Batch, Channels, Height, Width) and dtype torch.float32.
+
+    Returns:
+        (np.ndarray): Output NumPy array batch of shape (Batch, Height, Width, Channels) and dtype uint8.
+    """
+    return (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
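+
+# Minimal usage sketch: a BCHW float batch in [0, 1] becomes a BHWC uint8 batch in [0, 255].
+_batch = torch.rand(2, 3, 4, 4)
+assert convert_torch2numpy_batch(_batch).shape == (2, 4, 4, 3)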
+
+
+def clean_str(s):
+    """
+    Cleans a string by replacing special characters with '_' character.
+
+    Args:
+        s (str): a string needing special characters replaced
+
+    Returns:
+        (str): a string with special characters replaced by an underscore _
+    """
+    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
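+
+# Minimal usage sketch: characters in the pattern above collapse to underscores, everything else is kept.
+assert clean_str("rtsp://user:pass@host?stream=1") == "rtsp_//user_pass_host_stream_1"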
+
+
+def empty_like(x):
+    """Creates empty torch.Tensor or np.ndarray with same shape as input and float32 dtype."""
+    return (
+        torch.empty_like(x, dtype=torch.float32) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=np.float32)
+    )

+ 104 - 0
ultralytics/utils/patches.py

@@ -0,0 +1,104 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+"""Monkey patches to update/extend functionality of existing functions."""
+
+import time
+from pathlib import Path
+
+import cv2
+import numpy as np
+import torch
+
+# OpenCV Multilanguage-friendly functions ------------------------------------------------------------------------------
+_imshow = cv2.imshow  # copy to avoid recursion errors
+
+
+def imread(filename: str, flags: int = cv2.IMREAD_COLOR):
+    """
+    Read an image from a file.
+
+    Args:
+        filename (str): Path to the file to read.
+        flags (int, optional): Flag that can take values of cv2.IMREAD_*. Defaults to cv2.IMREAD_COLOR.
+
+    Returns:
+        (np.ndarray): The read image.
+    """
+    return cv2.imdecode(np.fromfile(filename, np.uint8), flags)
+
+
+def imwrite(filename: str, img: np.ndarray, params=None):
+    """
+    Write an image to a file.
+
+    Args:
+        filename (str): Path to the file to write.
+        img (np.ndarray): Image to write.
+        params (list of ints, optional): Additional parameters. See OpenCV documentation.
+
+    Returns:
+        (bool): True if the file was written, False otherwise.
+    """
+    try:
+        cv2.imencode(Path(filename).suffix, img, params)[1].tofile(filename)
+        return True
+    except Exception:
+        return False
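+
+# Hedged usage sketch: these wrappers round-trip through np.fromfile/np.tofile so that paths with
+# non-ASCII characters work where the stock cv2.imread/cv2.imwrite can fail (e.g. on Windows).
+# The temporary file name is only illustrative.
+import os
+import tempfile
+
+_demo_path = os.path.join(tempfile.gettempdir(), "ünïcödé_demo.png")
+assert imwrite(_demo_path, np.zeros((4, 4, 3), dtype=np.uint8))
+assert imread(_demo_path).shape == (4, 4, 3)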
+
+
+def imshow(winname: str, mat: np.ndarray):
+    """
+    Displays an image in the specified window.
+
+    Args:
+        winname (str): Name of the window.
+        mat (np.ndarray): Image to be shown.
+    """
+    _imshow(winname.encode("unicode_escape").decode(), mat)
+
+
+# PyTorch functions ----------------------------------------------------------------------------------------------------
+_torch_load = torch.load  # copy to avoid recursion errors
+_torch_save = torch.save
+
+
+def torch_load(*args, **kwargs):
+    """
+    Load a PyTorch model with updated arguments to avoid warnings.
+
+    This function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings.
+
+    Args:
+        *args (Any): Variable length argument list to pass to torch.load.
+        **kwargs (Any): Arbitrary keyword arguments to pass to torch.load.
+
+    Returns:
+        (Any): The loaded PyTorch object.
+
+    Note:
+        For PyTorch versions 1.13.0 and above, this function automatically sets 'weights_only=False'
+        if the argument is not provided, to avoid deprecation warnings.
+    """
+    from ultralytics.utils.torch_utils import TORCH_1_13
+
+    if TORCH_1_13 and "weights_only" not in kwargs:
+        kwargs["weights_only"] = False
+
+    return _torch_load(*args, **kwargs)
+
+
+def torch_save(*args, **kwargs):
+    """
+    Save an object with torch.save, adding robustness with 3 retries and exponential backoff in case of save failure
+    (e.g. while a device is still flushing or an antivirus scan holds the file).
+
+    Args:
+        *args (tuple): Positional arguments to pass to torch.save.
+        **kwargs (Any): Keyword arguments to pass to torch.save.
+    """
+    for i in range(4):  # 3 retries
+        try:
+            return _torch_save(*args, **kwargs)
+        except RuntimeError as e:  # unable to save, possibly waiting for device to flush or antivirus scan
+            if i == 3:
+                raise e
+            time.sleep((2**i) / 2)  # exponential backoff: 0.5s, 1.0s, 2.0s
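+
+# Hedged usage sketch: the wrappers above are drop-in replacements for torch.save / torch.load;
+# the retry loop and the weights_only default are transparent to the caller. The checkpoint file
+# name is only illustrative.
+import os
+import tempfile
+
+_ckpt_path = os.path.join(tempfile.gettempdir(), "demo_ckpt.pt")
+torch_save({"step": 1, "weights": torch.zeros(2, 2)}, _ckpt_path)
+assert torch_load(_ckpt_path, map_location="cpu")["step"] == 1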

+ 1378 - 0
ultralytics/utils/plotting.py

@@ -0,0 +1,1378 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import math
+import warnings
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Union
+
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from PIL import Image, ImageDraw, ImageFont
+from PIL import __version__ as pil_version
+
+from ultralytics.utils import IS_COLAB, IS_KAGGLE, LOGGER, TryExcept, ops, plt_settings, threaded
+from ultralytics.utils.checks import check_font, check_version, is_ascii
+from ultralytics.utils.files import increment_path
+
+
+class Colors:
+    """
+    Ultralytics color palette https://docs.ultralytics.com/reference/utils/plotting/#ultralytics.utils.plotting.Colors.
+
+    This class provides methods to work with the Ultralytics color palette, including converting hex color codes to
+    RGB values.
+
+    Attributes:
+        palette (list of tuple): List of RGB color values.
+        n (int): The number of colors in the palette.
+        pose_palette (np.ndarray): A specific color palette array with dtype np.uint8.
+
+    ## Ultralytics Color Palette
+
+    | Index | Color                                                             | HEX       | RGB               |
+    |-------|-------------------------------------------------------------------|-----------|-------------------|
+    | 0     | <i class="fa-solid fa-square fa-2xl" style="color: #042aff;"></i> | `#042aff` | (4, 42, 255)      |
+    | 1     | <i class="fa-solid fa-square fa-2xl" style="color: #0bdbeb;"></i> | `#0bdbeb` | (11, 219, 235)    |
+    | 2     | <i class="fa-solid fa-square fa-2xl" style="color: #f3f3f3;"></i> | `#f3f3f3` | (243, 243, 243)   |
+    | 3     | <i class="fa-solid fa-square fa-2xl" style="color: #00dfb7;"></i> | `#00dfb7` | (0, 223, 183)     |
+    | 4     | <i class="fa-solid fa-square fa-2xl" style="color: #111f68;"></i> | `#111f68` | (17, 31, 104)     |
+    | 5     | <i class="fa-solid fa-square fa-2xl" style="color: #ff6fdd;"></i> | `#ff6fdd` | (255, 111, 221)   |
+    | 6     | <i class="fa-solid fa-square fa-2xl" style="color: #ff444f;"></i> | `#ff444f` | (255, 68, 79)     |
+    | 7     | <i class="fa-solid fa-square fa-2xl" style="color: #cced00;"></i> | `#cced00` | (204, 237, 0)     |
+    | 8     | <i class="fa-solid fa-square fa-2xl" style="color: #00f344;"></i> | `#00f344` | (0, 243, 68)      |
+    | 9     | <i class="fa-solid fa-square fa-2xl" style="color: #bd00ff;"></i> | `#bd00ff` | (189, 0, 255)     |
+    | 10    | <i class="fa-solid fa-square fa-2xl" style="color: #00b4ff;"></i> | `#00b4ff` | (0, 180, 255)     |
+    | 11    | <i class="fa-solid fa-square fa-2xl" style="color: #dd00ba;"></i> | `#dd00ba` | (221, 0, 186)     |
+    | 12    | <i class="fa-solid fa-square fa-2xl" style="color: #00ffff;"></i> | `#00ffff` | (0, 255, 255)     |
+    | 13    | <i class="fa-solid fa-square fa-2xl" style="color: #26c000;"></i> | `#26c000` | (38, 192, 0)      |
+    | 14    | <i class="fa-solid fa-square fa-2xl" style="color: #01ffb3;"></i> | `#01ffb3` | (1, 255, 179)     |
+    | 15    | <i class="fa-solid fa-square fa-2xl" style="color: #7d24ff;"></i> | `#7d24ff` | (125, 36, 255)    |
+    | 16    | <i class="fa-solid fa-square fa-2xl" style="color: #7b0068;"></i> | `#7b0068` | (123, 0, 104)     |
+    | 17    | <i class="fa-solid fa-square fa-2xl" style="color: #ff1b6c;"></i> | `#ff1b6c` | (255, 27, 108)    |
+    | 18    | <i class="fa-solid fa-square fa-2xl" style="color: #fc6d2f;"></i> | `#fc6d2f` | (252, 109, 47)    |
+    | 19    | <i class="fa-solid fa-square fa-2xl" style="color: #a2ff0b;"></i> | `#a2ff0b` | (162, 255, 11)    |
+
+    ## Pose Color Palette
+
+    | Index | Color                                                             | HEX       | RGB               |
+    |-------|-------------------------------------------------------------------|-----------|-------------------|
+    | 0     | <i class="fa-solid fa-square fa-2xl" style="color: #ff8000;"></i> | `#ff8000` | (255, 128, 0)     |
+    | 1     | <i class="fa-solid fa-square fa-2xl" style="color: #ff9933;"></i> | `#ff9933` | (255, 153, 51)    |
+    | 2     | <i class="fa-solid fa-square fa-2xl" style="color: #ffb266;"></i> | `#ffb266` | (255, 178, 102)   |
+    | 3     | <i class="fa-solid fa-square fa-2xl" style="color: #e6e600;"></i> | `#e6e600` | (230, 230, 0)     |
+    | 4     | <i class="fa-solid fa-square fa-2xl" style="color: #ff99ff;"></i> | `#ff99ff` | (255, 153, 255)   |
+    | 5     | <i class="fa-solid fa-square fa-2xl" style="color: #99ccff;"></i> | `#99ccff` | (153, 204, 255)   |
+    | 6     | <i class="fa-solid fa-square fa-2xl" style="color: #ff66ff;"></i> | `#ff66ff` | (255, 102, 255)   |
+    | 7     | <i class="fa-solid fa-square fa-2xl" style="color: #ff33ff;"></i> | `#ff33ff` | (255, 51, 255)    |
+    | 8     | <i class="fa-solid fa-square fa-2xl" style="color: #66b2ff;"></i> | `#66b2ff` | (102, 178, 255)   |
+    | 9     | <i class="fa-solid fa-square fa-2xl" style="color: #3399ff;"></i> | `#3399ff` | (51, 153, 255)    |
+    | 10    | <i class="fa-solid fa-square fa-2xl" style="color: #ff9999;"></i> | `#ff9999` | (255, 153, 153)   |
+    | 11    | <i class="fa-solid fa-square fa-2xl" style="color: #ff6666;"></i> | `#ff6666` | (255, 102, 102)   |
+    | 12    | <i class="fa-solid fa-square fa-2xl" style="color: #ff3333;"></i> | `#ff3333` | (255, 51, 51)     |
+    | 13    | <i class="fa-solid fa-square fa-2xl" style="color: #99ff99;"></i> | `#99ff99` | (153, 255, 153)   |
+    | 14    | <i class="fa-solid fa-square fa-2xl" style="color: #66ff66;"></i> | `#66ff66` | (102, 255, 102)   |
+    | 15    | <i class="fa-solid fa-square fa-2xl" style="color: #33ff33;"></i> | `#33ff33` | (51, 255, 51)     |
+    | 16    | <i class="fa-solid fa-square fa-2xl" style="color: #00ff00;"></i> | `#00ff00` | (0, 255, 0)       |
+    | 17    | <i class="fa-solid fa-square fa-2xl" style="color: #0000ff;"></i> | `#0000ff` | (0, 0, 255)       |
+    | 18    | <i class="fa-solid fa-square fa-2xl" style="color: #ff0000;"></i> | `#ff0000` | (255, 0, 0)       |
+    | 19    | <i class="fa-solid fa-square fa-2xl" style="color: #ffffff;"></i> | `#ffffff` | (255, 255, 255)   |
+
+    !!! note "Ultralytics Brand Colors"
+
+        For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand). Please use the official Ultralytics colors for all marketing materials.
+    """
+
+    def __init__(self):
+        """Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values()."""
+        hexs = (
+            "042AFF",
+            "0BDBEB",
+            "F3F3F3",
+            "00DFB7",
+            "111F68",
+            "FF6FDD",
+            "FF444F",
+            "CCED00",
+            "00F344",
+            "BD00FF",
+            "00B4FF",
+            "DD00BA",
+            "00FFFF",
+            "26C000",
+            "01FFB3",
+            "7D24FF",
+            "7B0068",
+            "FF1B6C",
+            "FC6D2F",
+            "A2FF0B",
+        )
+        self.palette = [self.hex2rgb(f"#{c}") for c in hexs]
+        self.n = len(self.palette)
+        self.pose_palette = np.array(
+            [
+                [255, 128, 0],
+                [255, 153, 51],
+                [255, 178, 102],
+                [230, 230, 0],
+                [255, 153, 255],
+                [153, 204, 255],
+                [255, 102, 255],
+                [255, 51, 255],
+                [102, 178, 255],
+                [51, 153, 255],
+                [255, 153, 153],
+                [255, 102, 102],
+                [255, 51, 51],
+                [153, 255, 153],
+                [102, 255, 102],
+                [51, 255, 51],
+                [0, 255, 0],
+                [0, 0, 255],
+                [255, 0, 0],
+                [255, 255, 255],
+            ],
+            dtype=np.uint8,
+        )
+
+    def __call__(self, i, bgr=False):
+        """Converts hex color codes to RGB values."""
+        c = self.palette[int(i) % self.n]
+        return (c[2], c[1], c[0]) if bgr else c
+
+    @staticmethod
+    def hex2rgb(h):
+        """Converts hex color codes to RGB values (i.e. default PIL order)."""
+        return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
+
+
+colors = Colors()  # create instance for 'from utils.plots import colors'
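+
+# Minimal usage sketch: index into the palette (indices wrap modulo 20); pass bgr=True for OpenCV calls.
+assert colors(0) == (4, 42, 255)            # RGB
+assert colors(0, bgr=True) == (255, 42, 4)  # BGR for cv2 drawing
+assert colors(20) == colors(0)              # wraps around the 20-color palette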
+
+
+class Annotator:
+    """
+    Ultralytics Annotator for drawing on train/val mosaics and JPGs and for prediction annotations.
+
+    Attributes:
+        im (Image.Image or numpy array): The image to annotate.
+        pil (bool): Whether to use PIL or cv2 for drawing annotations.
+        font (ImageFont.truetype or ImageFont.load_default): Font used for text annotations.
+        lw (float): Line width for drawing.
+        skeleton (List[List[int]]): Skeleton structure for keypoints.
+        limb_color (List[int]): Color palette for limbs.
+        kpt_color (List[int]): Color palette for keypoints.
+    """
+
+    def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"):
+        """Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
+        non_ascii = not is_ascii(example)  # non-latin labels, i.e. asian, arabic, cyrillic
+        input_is_pil = isinstance(im, Image.Image)
+        self.pil = pil or non_ascii or input_is_pil
+        self.lw = line_width or max(round(sum(im.size if input_is_pil else im.shape) / 2 * 0.003), 2)
+        if self.pil:  # use PIL
+            self.im = im if input_is_pil else Image.fromarray(im)
+            self.draw = ImageDraw.Draw(self.im)
+            try:
+                font = check_font("Arial.Unicode.ttf" if non_ascii else font)
+                size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)
+                self.font = ImageFont.truetype(str(font), size)
+            except Exception:
+                self.font = ImageFont.load_default()
+            # Deprecation fix for w, h = getsize(string) -> _, _, w, h = getbbox(string)
+            if check_version(pil_version, "9.2.0"):
+                self.font.getsize = lambda x: self.font.getbbox(x)[2:4]  # text width, height
+        else:  # use cv2
+            assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images."
+            self.im = im if im.flags.writeable else im.copy()
+            self.tf = max(self.lw - 1, 1)  # font thickness
+            self.sf = self.lw / 3  # font scale
+        # Pose
+        self.skeleton = [
+            [16, 14],
+            [14, 12],
+            [17, 15],
+            [15, 13],
+            [12, 13],
+            [6, 12],
+            [7, 13],
+            [6, 7],
+            [6, 8],
+            [7, 9],
+            [8, 10],
+            [9, 11],
+            [2, 3],
+            [1, 2],
+            [1, 3],
+            [2, 4],
+            [3, 5],
+            [4, 6],
+            [5, 7],
+        ]
+
+        self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]
+        self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]
+        self.dark_colors = {
+            (235, 219, 11),
+            (243, 243, 243),
+            (183, 223, 0),
+            (221, 111, 255),
+            (0, 237, 204),
+            (68, 243, 0),
+            (255, 255, 0),
+            (179, 255, 1),
+            (11, 255, 162),
+        }
+        self.light_colors = {
+            (255, 42, 4),
+            (79, 68, 255),
+            (255, 0, 189),
+            (255, 180, 0),
+            (186, 0, 221),
+            (0, 192, 38),
+            (255, 36, 125),
+            (104, 0, 123),
+            (108, 27, 255),
+            (47, 109, 252),
+            (104, 31, 17),
+        }
+
+    def get_txt_color(self, color=(128, 128, 128), txt_color=(255, 255, 255)):
+        """
+        Assign text color based on background color.
+
+        Args:
+            color (tuple, optional): The background color of the rectangle for text (B, G, R).
+            txt_color (tuple, optional): The color of the text (R, G, B).
+
+        Returns:
+            txt_color (tuple): Text color for label
+        """
+        if color in self.dark_colors:
+            return 104, 31, 17
+        elif color in self.light_colors:
+            return 255, 255, 255
+        else:
+            return txt_color
+
+    def circle_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), margin=2):
+        """
+        Draws a label with a background circle centered within a given bounding box.
+
+        Args:
+            box (tuple): The bounding box coordinates (x1, y1, x2, y2).
+            label (str): The text label to be displayed.
+            color (tuple, optional): The background color of the rectangle (B, G, R).
+            txt_color (tuple, optional): The color of the text (R, G, B).
+            margin (int, optional): The margin between the text and the rectangle border.
+        """
+        # If the label has more than 3 characters, truncate it so it fits inside the circle
+        if len(label) > 3:
+            print(
+                f"Label length is {len(label)}; only the first 3 characters will be used for circle annotation."
+            )
+            label = label[:3]
+
+        # Calculate the center of the box
+        x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
+        # Get the text size
+        text_size = cv2.getTextSize(str(label), cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.15, self.tf)[0]
+        # Calculate the required radius to fit the text with the margin
+        required_radius = int(((text_size[0] ** 2 + text_size[1] ** 2) ** 0.5) / 2) + margin
+        # Draw the circle with the required radius
+        cv2.circle(self.im, (x_center, y_center), required_radius, color, -1)
+        # Calculate the position for the text
+        text_x = x_center - text_size[0] // 2
+        text_y = y_center + text_size[1] // 2
+        # Draw the text
+        cv2.putText(
+            self.im,
+            str(label),
+            (text_x, text_y),
+            cv2.FONT_HERSHEY_SIMPLEX,
+            self.sf - 0.15,
+            self.get_txt_color(color, txt_color),
+            self.tf,
+            lineType=cv2.LINE_AA,
+        )
+
+    def text_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), margin=5):
+        """
+        Draws a label with a background rectangle centered within a given bounding box.
+
+        Args:
+            box (tuple): The bounding box coordinates (x1, y1, x2, y2).
+            label (str): The text label to be displayed.
+            color (tuple, optional): The background color of the rectangle (B, G, R).
+            txt_color (tuple, optional): The color of the text (R, G, B).
+            margin (int, optional): The margin between the text and the rectangle border.
+        """
+        # Calculate the center of the bounding box
+        x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
+        # Get the size of the text
+        text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.1, self.tf)[0]
+        # Calculate the top-left corner of the text (to center it)
+        text_x = x_center - text_size[0] // 2
+        text_y = y_center + text_size[1] // 2
+        # Calculate the coordinates of the background rectangle
+        rect_x1 = text_x - margin
+        rect_y1 = text_y - text_size[1] - margin
+        rect_x2 = text_x + text_size[0] + margin
+        rect_y2 = text_y + margin
+        # Draw the background rectangle
+        cv2.rectangle(self.im, (rect_x1, rect_y1), (rect_x2, rect_y2), color, -1)
+        # Draw the text on top of the rectangle
+        cv2.putText(
+            self.im,
+            label,
+            (text_x, text_y),
+            cv2.FONT_HERSHEY_SIMPLEX,
+            self.sf - 0.1,
+            self.get_txt_color(color, txt_color),
+            self.tf,
+            lineType=cv2.LINE_AA,
+        )
+
+    def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False):
+        """
+        Draws a bounding box to image with label.
+
+        Args:
+            box (tuple): The bounding box coordinates (x1, y1, x2, y2).
+            label (str): The text label to be displayed.
+            color (tuple, optional): The background color of the rectangle (B, G, R).
+            txt_color (tuple, optional): The color of the text (R, G, B).
+            rotated (bool, optional): Variable used to check if task is OBB
+        """
+        txt_color = self.get_txt_color(color, txt_color)
+        if isinstance(box, torch.Tensor):
+            box = box.tolist()
+        if self.pil or not is_ascii(label):
+            if rotated:
+                p1 = box[0]
+                self.draw.polygon([tuple(b) for b in box], width=self.lw, outline=color)  # PIL requires tuple box
+            else:
+                p1 = (box[0], box[1])
+                self.draw.rectangle(box, width=self.lw, outline=color)  # box
+            if label:
+                w, h = self.font.getsize(label)  # text width, height
+                outside = p1[1] >= h  # label fits outside box
+                if p1[0] > self.im.size[0] - w:  # size is (w, h), check if label extend beyond right side of image
+                    p1 = self.im.size[0] - w, p1[1]
+                self.draw.rectangle(
+                    (p1[0], p1[1] - h if outside else p1[1], p1[0] + w + 1, p1[1] + 1 if outside else p1[1] + h + 1),
+                    fill=color,
+                )
+                # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls')  # for PIL>8.0
+                self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font)
+        else:  # cv2
+            if rotated:
+                p1 = [int(b) for b in box[0]]
+                cv2.polylines(self.im, [np.asarray(box, dtype=int)], True, color, self.lw)  # cv2 requires nparray box
+            else:
+                p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
+                cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
+            if label:
+                w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0]  # text width, height
+                h += 3  # add pixels to pad text
+                outside = p1[1] >= h  # label fits outside box
+                if p1[0] > self.im.shape[1] - w:  # shape is (h, w), check if label extend beyond right side of image
+                    p1 = self.im.shape[1] - w, p1[1]
+                p2 = p1[0] + w, p1[1] - h if outside else p1[1] + h
+                cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA)  # filled
+                cv2.putText(
+                    self.im,
+                    label,
+                    (p1[0], p1[1] - 2 if outside else p1[1] + h - 1),
+                    0,
+                    self.sf,
+                    txt_color,
+                    thickness=self.tf,
+                    lineType=cv2.LINE_AA,
+                )
+
+    def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
+        """
+        Plot masks on image.
+
+        Args:
+            masks (tensor): Predicted masks on cuda, shape: [n, h, w]
+            colors (List[List[Int]]): Colors for predicted masks, [[r, g, b] * n]
+            im_gpu (tensor): Image is in cuda, shape: [3, h, w], range: [0, 1]
+            alpha (float): Mask transparency: 0.0 fully transparent, 1.0 opaque
+            retina_masks (bool): Whether to use high resolution masks or not. Defaults to False.
+        """
+        if self.pil:
+            # Convert to numpy first
+            self.im = np.asarray(self.im).copy()
+        if len(masks) == 0:
+            self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
+        if im_gpu.device != masks.device:
+            im_gpu = im_gpu.to(masks.device)
+        colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0  # shape(n,3)
+        colors = colors[:, None, None]  # shape(n,1,1,3)
+        masks = masks.unsqueeze(3)  # shape(n,h,w,1)
+        masks_color = masks * (colors * alpha)  # shape(n,h,w,3)
+
+        inv_alpha_masks = (1 - masks * alpha).cumprod(0)  # shape(n,h,w,1)
+        mcs = masks_color.max(dim=0).values  # shape(h,w,3)
+
+        im_gpu = im_gpu.flip(dims=[0])  # flip channel
+        im_gpu = im_gpu.permute(1, 2, 0).contiguous()  # shape(h,w,3)
+        im_gpu = im_gpu * inv_alpha_masks[-1] + mcs
+        im_mask = im_gpu * 255
+        im_mask_np = im_mask.byte().cpu().numpy()
+        self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape)
+        if self.pil:
+            # Convert im back to PIL and update draw
+            self.fromarray(self.im)
+
+    def kpts(self, kpts, shape=(640, 640), radius=None, kpt_line=True, conf_thres=0.25, kpt_color=None):
+        """
+        Plot keypoints on the image.
+
+        Args:
+            kpts (torch.Tensor): Keypoints, shape [17, 3] (x, y, confidence).
+            shape (tuple, optional): Image shape (h, w). Defaults to (640, 640).
+            radius (int, optional): Keypoint radius. Defaults to None, in which case the line width is used.
+            kpt_line (bool, optional): Draw lines between keypoints. Defaults to True.
+            conf_thres (float, optional): Confidence threshold. Defaults to 0.25.
+            kpt_color (tuple, optional): Keypoint color (B, G, R). Defaults to None.
+
+        Note:
+            - `kpt_line=True` currently only supports human pose plotting.
+            - Modifies self.im in-place.
+            - If self.pil is True, converts image to numpy array and back to PIL.
+        """
+        radius = radius if radius is not None else self.lw
+        if self.pil:
+            # Convert to numpy first
+            self.im = np.asarray(self.im).copy()
+        nkpt, ndim = kpts.shape
+        is_pose = nkpt == 17 and ndim in {2, 3}
+        kpt_line &= is_pose  # `kpt_line=True` for now only supports human pose plotting
+        for i, k in enumerate(kpts):
+            color_k = kpt_color or (self.kpt_color[i].tolist() if is_pose else colors(i))
+            x_coord, y_coord = k[0], k[1]
+            if x_coord % shape[1] != 0 and y_coord % shape[0] != 0:
+                if len(k) == 3:
+                    conf = k[2]
+                    if conf < conf_thres:
+                        continue
+                cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, color_k, -1, lineType=cv2.LINE_AA)
+
+        if kpt_line:
+            ndim = kpts.shape[-1]
+            for i, sk in enumerate(self.skeleton):
+                pos1 = (int(kpts[(sk[0] - 1), 0]), int(kpts[(sk[0] - 1), 1]))
+                pos2 = (int(kpts[(sk[1] - 1), 0]), int(kpts[(sk[1] - 1), 1]))
+                if ndim == 3:
+                    conf1 = kpts[(sk[0] - 1), 2]
+                    conf2 = kpts[(sk[1] - 1), 2]
+                    if conf1 < conf_thres or conf2 < conf_thres:
+                        continue
+                if pos1[0] % shape[1] == 0 or pos1[1] % shape[0] == 0 or pos1[0] < 0 or pos1[1] < 0:
+                    continue
+                if pos2[0] % shape[1] == 0 or pos2[1] % shape[0] == 0 or pos2[0] < 0 or pos2[1] < 0:
+                    continue
+                cv2.line(
+                    self.im,
+                    pos1,
+                    pos2,
+                    kpt_color or self.limb_color[i].tolist(),
+                    thickness=int(np.ceil(self.lw / 2)),
+                    lineType=cv2.LINE_AA,
+                )
+        if self.pil:
+            # Convert im back to PIL and update draw
+            self.fromarray(self.im)
+
+    def rectangle(self, xy, fill=None, outline=None, width=1):
+        """Add rectangle to image (PIL-only)."""
+        self.draw.rectangle(xy, fill, outline, width)
+
+    def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_style=False):
+        """Adds text to an image using PIL or cv2."""
+        if anchor == "bottom":  # start y from font bottom
+            w, h = self.font.getsize(text)  # text width, height
+            xy[1] += 1 - h
+        if self.pil:
+            if box_style:
+                w, h = self.font.getsize(text)
+                self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color)
+                # Using `txt_color` for background and draw fg with white color
+                txt_color = (255, 255, 255)
+            if "\n" in text:
+                lines = text.split("\n")
+                _, h = self.font.getsize(text)
+                for line in lines:
+                    self.draw.text(xy, line, fill=txt_color, font=self.font)
+                    xy[1] += h
+            else:
+                self.draw.text(xy, text, fill=txt_color, font=self.font)
+        else:
+            if box_style:
+                w, h = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]  # text width, height
+                h += 3  # add pixels to pad text
+                outside = xy[1] >= h  # label fits outside box
+                p2 = xy[0] + w, xy[1] - h if outside else xy[1] + h
+                cv2.rectangle(self.im, xy, p2, txt_color, -1, cv2.LINE_AA)  # filled
+                # Using `txt_color` for background and draw fg with white color
+                txt_color = (255, 255, 255)
+            cv2.putText(self.im, text, xy, 0, self.sf, txt_color, thickness=self.tf, lineType=cv2.LINE_AA)
+
+    def fromarray(self, im):
+        """Update self.im from a numpy array."""
+        self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
+        self.draw = ImageDraw.Draw(self.im)
+
+    def result(self):
+        """Return annotated image as array."""
+        return np.asarray(self.im)
+
+    def show(self, title=None):
+        """Show the annotated image."""
+        im = Image.fromarray(np.asarray(self.im)[..., ::-1])  # convert BGR numpy array to RGB PIL Image
+        if IS_COLAB or IS_KAGGLE:  # can not use IS_JUPYTER as will run for all ipython environments
+            try:
+                display(im)  # noqa - display() function only available in ipython environments
+            except ImportError as e:
+                LOGGER.warning(f"Unable to display image in Jupyter notebooks: {e}")
+        else:
+            im.show(title=title)
+
+    def save(self, filename="image.jpg"):
+        """Save the annotated image to 'filename'."""
+        cv2.imwrite(filename, np.asarray(self.im))
+
+    @staticmethod
+    def get_bbox_dimension(bbox=None):
+        """
+        Calculate the width, height, and area of a bounding box.
+
+        Args:
+            bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
+
+        Returns:
+            width (float): Width of the bounding box.
+            height (float): Height of the bounding box.
+            area (float): Area enclosed by the bounding box.
+        """
+        x_min, y_min, x_max, y_max = bbox
+        width = x_max - x_min
+        height = y_max - y_min
+        return width, height, width * height
+
+    def draw_region(self, reg_pts=None, color=(0, 255, 0), thickness=5):
+        """
+        Draw region line.
+
+        Args:
+            reg_pts (list): Region Points (for line 2 points, for region 4 points)
+            color (tuple): Region Color value
+            thickness (int): Region area thickness value
+        """
+        cv2.polylines(self.im, [np.array(reg_pts, dtype=np.int32)], isClosed=True, color=color, thickness=thickness)
+
+        # Draw small circles at the corner points
+        for point in reg_pts:
+            cv2.circle(self.im, (point[0], point[1]), thickness * 2, color, -1)  # -1 fills the circle
+
+    def draw_centroid_and_tracks(self, track, color=(255, 0, 255), track_thickness=2):
+        """
+        Draw centroid point and track trails.
+
+        Args:
+            track (list): object tracking points for trails display
+            color (tuple): tracks line color
+            track_thickness (int): track line thickness value
+        """
+        points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
+        cv2.polylines(self.im, [points], isClosed=False, color=color, thickness=track_thickness)
+        cv2.circle(self.im, (int(track[-1][0]), int(track[-1][1])), track_thickness * 2, color, -1)
+
+    def queue_counts_display(self, label, points=None, region_color=(255, 255, 255), txt_color=(0, 0, 0)):
+        """
+        Displays queue counts on an image centered at the points with customizable font size and colors.
+
+        Args:
+            label (str): Queue counts label.
+            points (tuple): Region points for center point calculation to display text.
+            region_color (tuple): RGB queue region color.
+            txt_color (tuple): RGB text display color.
+        """
+        x_values = [point[0] for point in points]
+        y_values = [point[1] for point in points]
+        center_x = sum(x_values) // len(points)
+        center_y = sum(y_values) // len(points)
+
+        text_size = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0]
+        text_width = text_size[0]
+        text_height = text_size[1]
+
+        rect_width = text_width + 20
+        rect_height = text_height + 20
+        rect_top_left = (center_x - rect_width // 2, center_y - rect_height // 2)
+        rect_bottom_right = (center_x + rect_width // 2, center_y + rect_height // 2)
+        cv2.rectangle(self.im, rect_top_left, rect_bottom_right, region_color, -1)
+
+        text_x = center_x - text_width // 2
+        text_y = center_y + text_height // 2
+
+        # Draw text
+        cv2.putText(
+            self.im,
+            label,
+            (text_x, text_y),
+            0,
+            fontScale=self.sf,
+            color=txt_color,
+            thickness=self.tf,
+            lineType=cv2.LINE_AA,
+        )
+
+    def display_objects_labels(self, im0, text, txt_color, bg_color, x_center, y_center, margin):
+        """
+        Display the bounding boxes labels in parking management app.
+
+        Args:
+            im0 (ndarray): Inference image.
+            text (str): Object/class name.
+            txt_color (tuple): Display color for text foreground.
+            bg_color (tuple): Display color for text background.
+            x_center (float): The x position center point for bounding box.
+            y_center (float): The y position center point for bounding box.
+            margin (int): The gap between text and rectangle for better display.
+        """
+        text_size = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]
+        text_x = x_center - text_size[0] // 2
+        text_y = y_center + text_size[1] // 2
+
+        rect_x1 = text_x - margin
+        rect_y1 = text_y - text_size[1] - margin
+        rect_x2 = text_x + text_size[0] + margin
+        rect_y2 = text_y + margin
+        cv2.rectangle(im0, (rect_x1, rect_y1), (rect_x2, rect_y2), bg_color, -1)
+        cv2.putText(im0, text, (text_x, text_y), 0, self.sf, txt_color, self.tf, lineType=cv2.LINE_AA)
+
+    def display_analytics(self, im0, text, txt_color, bg_color, margin):
+        """
+        Display the overall statistics for parking lots.
+
+        Args:
+            im0 (ndarray): Inference image.
+            text (dict): Labels dictionary.
+            txt_color (tuple): Display color for text foreground.
+            bg_color (tuple): Display color for text background.
+            margin (int): Gap between text and rectangle for better display.
+        """
+        horizontal_gap = int(im0.shape[1] * 0.02)
+        vertical_gap = int(im0.shape[0] * 0.01)
+        text_y_offset = 0
+        for label, value in text.items():
+            txt = f"{label}: {value}"
+            text_size = cv2.getTextSize(txt, 0, self.sf, self.tf)[0]
+            if text_size[0] < 5 or text_size[1] < 5:
+                text_size = (5, 5)
+            text_x = im0.shape[1] - text_size[0] - margin * 2 - horizontal_gap
+            text_y = text_y_offset + text_size[1] + margin * 2 + vertical_gap
+            rect_x1 = text_x - margin * 2
+            rect_y1 = text_y - text_size[1] - margin * 2
+            rect_x2 = text_x + text_size[0] + margin * 2
+            rect_y2 = text_y + margin * 2
+            cv2.rectangle(im0, (rect_x1, rect_y1), (rect_x2, rect_y2), bg_color, -1)
+            cv2.putText(im0, txt, (text_x, text_y), 0, self.sf, txt_color, self.tf, lineType=cv2.LINE_AA)
+            text_y_offset = rect_y2
+
+    @staticmethod
+    def estimate_pose_angle(a, b, c):
+        """
+        Calculate the pose angle for object.
+
+        Args:
+            a (float): The value of pose point a.
+            b (float): The value of pose point b.
+            c (float): The value of pose point c.
+
+        Returns:
+            angle (float): Angle between the three points, in degrees.
+        """
+        a, b, c = np.array(a), np.array(b), np.array(c)
+        radians = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0])
+        angle = np.abs(radians * 180.0 / np.pi)
+        if angle > 180.0:
+            angle = 360 - angle
+        return angle
+
+    def draw_specific_points(self, keypoints, indices=None, radius=2, conf_thres=0.25):
+        """
+        Draw specific keypoints for gym steps counting.
+
+        Args:
+            keypoints (list): Keypoints data to be plotted.
+            indices (list, optional): Keypoint indices to be plotted. Defaults to [2, 5, 7].
+            radius (int, optional): Keypoint radius. Defaults to 2.
+            conf_thres (float, optional): Confidence threshold for keypoints. Defaults to 0.25.
+
+        Returns:
+            (numpy.ndarray): Image with drawn keypoints.
+
+        Note:
+            Keypoint format: [x, y] or [x, y, confidence].
+            Modifies self.im in-place.
+        """
+        indices = indices or [2, 5, 7]
+        points = [(int(k[0]), int(k[1])) for i, k in enumerate(keypoints) if i in indices and k[2] >= conf_thres]
+
+        # Draw lines between consecutive points
+        for start, end in zip(points[:-1], points[1:]):
+            cv2.line(self.im, start, end, (0, 255, 0), 2, lineType=cv2.LINE_AA)
+
+        # Draw circles for keypoints
+        for pt in points:
+            cv2.circle(self.im, pt, radius, (0, 0, 255), -1, lineType=cv2.LINE_AA)
+
+        return self.im
+
+    def plot_workout_information(self, display_text, position, color=(104, 31, 17), txt_color=(255, 255, 255)):
+        """
+        Draw text with a background on the image.
+
+        Args:
+            display_text (str): The text to be displayed.
+            position (tuple): Coordinates (x, y) on the image where the text will be placed.
+            color (tuple, optional): Text background color
+            txt_color (tuple, optional): Text foreground color
+        """
+        (text_width, text_height), _ = cv2.getTextSize(display_text, 0, self.sf, self.tf)
+
+        # Draw background rectangle
+        cv2.rectangle(
+            self.im,
+            (position[0], position[1] - text_height - 5),
+            (position[0] + text_width + 10, position[1] - text_height - 5 + text_height + 10 + self.tf),
+            color,
+            -1,
+        )
+        # Draw text
+        cv2.putText(self.im, display_text, position, 0, self.sf, txt_color, self.tf)
+
+        return text_height
+
+    def plot_angle_and_count_and_stage(
+        self, angle_text, count_text, stage_text, center_kpt, color=(104, 31, 17), txt_color=(255, 255, 255)
+    ):
+        """
+        Plot the pose angle, count value, and step stage.
+
+        Args:
+            angle_text (str): Angle value for workout monitoring
+            count_text (str): Counts value for workout monitoring
+            stage_text (str): Stage decision for workout monitoring
+            center_kpt (list): Centroid pose index for workout monitoring
+            color (tuple, optional): Text background color
+            txt_color (tuple, optional): Text foreground color
+        """
+        # Format text
+        angle_text, count_text, stage_text = f" {angle_text:.2f}", f"Steps : {count_text}", f" {stage_text}"
+
+        # Draw angle, count and stage text
+        angle_height = self.plot_workout_information(
+            angle_text, (int(center_kpt[0]), int(center_kpt[1])), color, txt_color
+        )
+        count_height = self.plot_workout_information(
+            count_text, (int(center_kpt[0]), int(center_kpt[1]) + angle_height + 20), color, txt_color
+        )
+        self.plot_workout_information(
+            stage_text, (int(center_kpt[0]), int(center_kpt[1]) + angle_height + count_height + 40), color, txt_color
+        )
+
+    def seg_bbox(self, mask, mask_color=(255, 0, 255), label=None, txt_color=(255, 255, 255)):
+        """
+        Draw the contour of a segmented object with an optional label.
+
+        Args:
+            mask (np.ndarray): A 2D array of shape (N, 2) containing the contour points of the segmented object.
+            mask_color (tuple): RGB color for the contour and label background.
+            label (str, optional): Text label for the object. If None, no label is drawn.
+            txt_color (tuple): RGB color for the label text.
+        """
+        if mask.size == 0:  # no masks to plot
+            return
+
+        cv2.polylines(self.im, [np.int32([mask])], isClosed=True, color=mask_color, thickness=2)
+
+        if label:
+            text_size, _ = cv2.getTextSize(label, 0, self.sf, self.tf)
+            cv2.rectangle(
+                self.im,
+                (int(mask[0][0]) - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10),
+                (int(mask[0][0]) + text_size[0] // 2 + 10, int(mask[0][1] + 10)),
+                mask_color,
+                -1,
+            )
+            cv2.putText(
+                self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1])), 0, self.sf, txt_color, self.tf
+            )
+
+    def sweep_annotator(self, line_x=0, line_y=0, label=None, color=(221, 0, 186), txt_color=(255, 255, 255)):
+        """
+        Function for drawing a sweep annotation line and an optional label.
+
+        Args:
+            line_x (int): The x-coordinate of the sweep line.
+            line_y (int): The y-coordinate limit of the sweep line.
+            label (str, optional): Text label to be drawn in center of sweep line. If None, no label is drawn.
+            color (tuple): RGB color for the line and label background.
+            txt_color (tuple): RGB color for the label text.
+        """
+        # Draw the sweep line
+        cv2.line(self.im, (line_x, 0), (line_x, line_y), color, self.tf * 2)
+
+        # Draw label, if provided
+        if label:
+            (text_width, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf, self.tf)
+            cv2.rectangle(
+                self.im,
+                (line_x - text_width // 2 - 10, line_y // 2 - text_height // 2 - 10),
+                (line_x + text_width // 2 + 10, line_y // 2 + text_height // 2 + 10),
+                color,
+                -1,
+            )
+            cv2.putText(
+                self.im,
+                label,
+                (line_x - text_width // 2, line_y // 2 + text_height // 2),
+                cv2.FONT_HERSHEY_SIMPLEX,
+                self.sf,
+                txt_color,
+                self.tf,
+            )
+
+    def plot_distance_and_line(
+        self, pixels_distance, centroids, line_color=(104, 31, 17), centroid_color=(255, 0, 255)
+    ):
+        """
+        Plot the distance and line on frame.
+
+        Args:
+            pixels_distance (float): Pixels distance between two bbox centroids.
+            centroids (list): Bounding box centroids data.
+            line_color (tuple, optional): Distance line color.
+            centroid_color (tuple, optional): Bounding box centroid color.
+        """
+        # Get the text size
+        text = f"Pixels Distance: {pixels_distance:.2f}"
+        (text_width_m, text_height_m), _ = cv2.getTextSize(text, 0, self.sf, self.tf)
+
+        # Define corners with 10-pixel margin and draw rectangle
+        cv2.rectangle(self.im, (15, 25), (15 + text_width_m + 20, 25 + text_height_m + 20), line_color, -1)
+
+        # Calculate the position for the text with a 10-pixel margin and draw text
+        text_position = (25, 25 + text_height_m + 10)
+        cv2.putText(
+            self.im,
+            text,
+            text_position,
+            0,
+            self.sf,
+            (255, 255, 255),
+            self.tf,
+            cv2.LINE_AA,
+        )
+
+        cv2.line(self.im, centroids[0], centroids[1], line_color, 3)
+        cv2.circle(self.im, centroids[0], 6, centroid_color, -1)
+        cv2.circle(self.im, centroids[1], 6, centroid_color, -1)
+
+    def visioneye(self, box, center_point, color=(235, 219, 11), pin_color=(255, 0, 255)):
+        """
+        Draw a line from a fixed 'vision eye' point to the object's bounding-box centroid.
+
+        Args:
+            box (list): Bounding box coordinates
+            center_point (tuple): center point for vision eye view
+            color (tuple): object centroid and line color value
+            pin_color (tuple): visioneye point color value
+        """
+        center_bbox = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
+        cv2.circle(self.im, center_point, self.tf * 2, pin_color, -1)
+        cv2.circle(self.im, center_bbox, self.tf * 2, color, -1)
+        cv2.line(self.im, center_point, center_bbox, color, self.tf)
+
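+# Hedged usage sketch for the Annotator class above: draw one labelled box on a blank BGR frame via
+# the cv2 backend. All values are arbitrary and only demonstrate the call pattern.
+_frame = np.full((480, 640, 3), 114, dtype=np.uint8)      # contiguous HWC BGR image
+_ann = Annotator(_frame, line_width=2, example="person")  # ASCII example string keeps the cv2 path
+_ann.box_label((50, 60, 200, 220), label="person 0.91", color=colors(0, bgr=True))
+_annotated = _ann.result()                                # numpy array with the box and label drawn
+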
+
+@TryExcept()  # known issue https://github.com/ultralytics/yolov5/issues/5395
+@plt_settings()
+def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
+    """Plot training labels including class histograms and box statistics."""
+    import pandas  # scope for faster 'import ultralytics'
+    import seaborn  # scope for faster 'import ultralytics'
+
+    # Filter matplotlib>=3.7.2 warning and Seaborn use_inf and is_categorical FutureWarnings
+    warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight")
+    warnings.filterwarnings("ignore", category=FutureWarning)
+
+    # Plot dataset labels
+    LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
+    nc = int(cls.max() + 1)  # number of classes
+    boxes = boxes[:1000000]  # limit to 1M boxes
+    x = pandas.DataFrame(boxes, columns=["x", "y", "width", "height"])
+
+    # Seaborn correlogram
+    seaborn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
+    plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200)
+    plt.close()
+
+    # Matplotlib labels
+    ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
+    y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
+    for i in range(nc):
+        y[2].patches[i].set_color([x / 255 for x in colors(i)])
+    ax[0].set_ylabel("instances")
+    if 0 < len(names) < 30:
+        ax[0].set_xticks(range(len(names)))
+        ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
+    else:
+        ax[0].set_xlabel("classes")
+    seaborn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9)
+    seaborn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9)
+
+    # Rectangles
+    boxes[:, 0:2] = 0.5  # center
+    boxes = ops.xywh2xyxy(boxes) * 1000
+    img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)
+    for cls, box in zip(cls[:500], boxes[:500]):
+        ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls))  # plot
+    ax[1].imshow(img)
+    ax[1].axis("off")
+
+    for a in [0, 1, 2, 3]:
+        for s in ["top", "right", "left", "bottom"]:
+            ax[a].spines[s].set_visible(False)
+
+    fname = save_dir / "labels.jpg"
+    plt.savefig(fname, dpi=200)
+    plt.close()
+    if on_plot:
+        on_plot(fname)
+
+
+def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False, BGR=False, save=True):
+    """
+    Save an image crop as {file}, with the crop size scaled by {gain} and padded by {pad} pixels. Save and/or return the crop.
+
+    This function takes a bounding box and an image, and then saves a cropped portion of the image according
+    to the bounding box. Optionally, the crop can be squared, and the function allows for gain and padding
+    adjustments to the bounding box.
+
+    Args:
+        xyxy (torch.Tensor or list): A tensor or list representing the bounding box in xyxy format.
+        im (numpy.ndarray): The input image.
+        file (Path, optional): The path where the cropped image will be saved. Defaults to 'im.jpg'.
+        gain (float, optional): A multiplicative factor to increase the size of the bounding box. Defaults to 1.02.
+        pad (int, optional): The number of pixels to add to the width and height of the bounding box. Defaults to 10.
+        square (bool, optional): If True, the bounding box will be transformed into a square. Defaults to False.
+        BGR (bool, optional): If True, the image will be saved in BGR format, otherwise in RGB. Defaults to False.
+        save (bool, optional): If True, the cropped image will be saved to disk. Defaults to True.
+
+    Returns:
+        (numpy.ndarray): The cropped image.
+
+    Example:
+        ```python
+        from ultralytics.utils.plotting import save_one_box
+
+        xyxy = [50, 50, 150, 150]
+        im = cv2.imread("image.jpg")
+        cropped_im = save_one_box(xyxy, im, file="cropped.jpg", square=True)
+        ```
+    """
+    if not isinstance(xyxy, torch.Tensor):  # may be list
+        xyxy = torch.stack(xyxy)
+    b = ops.xyxy2xywh(xyxy.view(-1, 4))  # boxes
+    if square:
+        b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # make rectangle a square using the longer side
+    b[:, 2:] = b[:, 2:] * gain + pad  # box wh * gain + pad
+    xyxy = ops.xywh2xyxy(b).long()
+    xyxy = ops.clip_boxes(xyxy, im.shape)
+    crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR else -1)]
+    if save:
+        file.parent.mkdir(parents=True, exist_ok=True)  # make directory
+        f = str(increment_path(file).with_suffix(".jpg"))
+        # cv2.imwrite(f, crop)  # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
+        Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0)  # save RGB
+    return crop
+
+
+@threaded
+def plot_images(
+    images: Union[torch.Tensor, np.ndarray],
+    batch_idx: Union[torch.Tensor, np.ndarray],
+    cls: Union[torch.Tensor, np.ndarray],
+    bboxes: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.float32),
+    confs: Optional[Union[torch.Tensor, np.ndarray]] = None,
+    masks: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.uint8),
+    kpts: Union[torch.Tensor, np.ndarray] = np.zeros((0, 51), dtype=np.float32),
+    paths: Optional[List[str]] = None,
+    fname: str = "images.jpg",
+    names: Optional[Dict[int, str]] = None,
+    on_plot: Optional[Callable] = None,
+    max_size: int = 1920,
+    max_subplots: int = 16,
+    save: bool = True,
+    conf_thres: float = 0.25,
+) -> Optional[np.ndarray]:
+    """
+    Plot image grid with labels, bounding boxes, masks, and keypoints.
+
+    Args:
+        images: Batch of images to plot. Shape: (batch_size, channels, height, width).
+        batch_idx: Batch indices for each detection. Shape: (num_detections,).
+        cls: Class labels for each detection. Shape: (num_detections,).
+        bboxes: Bounding boxes for each detection. Shape: (num_detections, 4) or (num_detections, 5) for rotated boxes.
+        confs: Confidence scores for each detection. Shape: (num_detections,).
+        masks: Instance segmentation masks. Shape: (num_detections, height, width) or (1, height, width).
+        kpts: Keypoints for each detection. Shape: (num_detections, 51).
+        paths: List of file paths for each image in the batch.
+        fname: Output filename for the plotted image grid.
+        names: Dictionary mapping class indices to class names.
+        on_plot: Optional callback function to be called after saving the plot.
+        max_size: Maximum size of the output image grid.
+        max_subplots: Maximum number of subplots in the image grid.
+        save: Whether to save the plotted image grid to a file.
+        conf_thres: Confidence threshold for displaying detections.
+
+    Returns:
+        np.ndarray: Plotted image grid as a numpy array if save is False, None otherwise.
+
+    Note:
+        This function supports both tensor and numpy array inputs. It will automatically
+        convert tensor inputs to numpy arrays for processing.
+    """
+    if isinstance(images, torch.Tensor):
+        images = images.cpu().float().numpy()
+    if isinstance(cls, torch.Tensor):
+        cls = cls.cpu().numpy()
+    if isinstance(bboxes, torch.Tensor):
+        bboxes = bboxes.cpu().numpy()
+    if isinstance(masks, torch.Tensor):
+        masks = masks.cpu().numpy().astype(int)
+    if isinstance(kpts, torch.Tensor):
+        kpts = kpts.cpu().numpy()
+    if isinstance(batch_idx, torch.Tensor):
+        batch_idx = batch_idx.cpu().numpy()
+
+    bs, _, h, w = images.shape  # batch size, _, height, width
+    bs = min(bs, max_subplots)  # limit plot images
+    ns = np.ceil(bs**0.5)  # number of subplots (square)
+    if np.max(images[0]) <= 1:
+        images *= 255  # de-normalise (optional)
+
+    # Build Image
+    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
+    for i in range(bs):
+        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
+        mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0)
+
+    # Resize (optional)
+    scale = max_size / ns / max(h, w)
+    if scale < 1:
+        h = math.ceil(scale * h)
+        w = math.ceil(scale * w)
+        mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
+
+    # Annotate
+    fs = int((h + w) * ns * 0.01)  # font size
+    annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
+    for i in range(bs):
+        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
+        annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2)  # borders
+        if paths:
+            annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220))  # filenames
+        if len(cls) > 0:
+            idx = batch_idx == i
+            classes = cls[idx].astype("int")
+            labels = confs is None
+
+            if len(bboxes):
+                boxes = bboxes[idx]
+                conf = confs[idx] if confs is not None else None  # check for confidence presence (label vs pred)
+                if len(boxes):
+                    if boxes[:, :4].max() <= 1.1:  # if normalized with tolerance 0.1
+                        boxes[..., [0, 2]] *= w  # scale to pixels
+                        boxes[..., [1, 3]] *= h
+                    elif scale < 1:  # absolute coords need scale if image scales
+                        boxes[..., :4] *= scale
+                boxes[..., 0] += x
+                boxes[..., 1] += y
+                is_obb = boxes.shape[-1] == 5  # xywhr
+                boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
+                for j, box in enumerate(boxes.astype(np.int64).tolist()):
+                    c = classes[j]
+                    color = colors(c)
+                    c = names.get(c, c) if names else c
+                    if labels or conf[j] > conf_thres:
+                        label = f"{c}" if labels else f"{c} {conf[j]:.1f}"
+                        annotator.box_label(box, label, color=color, rotated=is_obb)
+
+            elif len(classes):
+                for c in classes:
+                    color = colors(c)
+                    c = names.get(c, c) if names else c
+                    annotator.text((x, y), f"{c}", txt_color=color, box_style=True)
+
+            # Plot keypoints
+            if len(kpts):
+                kpts_ = kpts[idx].copy()
+                if len(kpts_):
+                    if kpts_[..., 0].max() <= 1.01 or kpts_[..., 1].max() <= 1.01:  # if normalized with tolerance .01
+                        kpts_[..., 0] *= w  # scale to pixels
+                        kpts_[..., 1] *= h
+                    elif scale < 1:  # absolute coords need scale if image scales
+                        kpts_ *= scale
+                kpts_[..., 0] += x
+                kpts_[..., 1] += y
+                for j in range(len(kpts_)):
+                    if labels or conf[j] > conf_thres:
+                        annotator.kpts(kpts_[j], conf_thres=conf_thres)
+
+            # Plot masks
+            if len(masks):
+                if idx.shape[0] == masks.shape[0]:  # overlap_masks=False
+                    image_masks = masks[idx]
+                else:  # overlap_masks=True
+                    image_masks = masks[[i]]  # (1, 640, 640)
+                    nl = idx.sum()
+                    index = np.arange(nl).reshape((nl, 1, 1)) + 1
+                    image_masks = np.repeat(image_masks, nl, axis=0)
+                    image_masks = np.where(image_masks == index, 1.0, 0.0)
+
+                im = np.asarray(annotator.im).copy()
+                for j in range(len(image_masks)):
+                    if labels or conf[j] > conf_thres:
+                        color = colors(classes[j])
+                        mh, mw = image_masks[j].shape
+                        if mh != h or mw != w:
+                            mask = image_masks[j].astype(np.uint8)
+                            mask = cv2.resize(mask, (w, h))
+                            mask = mask.astype(bool)
+                        else:
+                            mask = image_masks[j].astype(bool)
+                        try:
+                            im[y : y + h, x : x + w, :][mask] = (
+                                im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6
+                            )
+                        except Exception:
+                            pass
+                annotator.fromarray(im)
+    if not save:
+        return np.asarray(annotator.im)
+    annotator.im.save(fname)  # save
+    if on_plot:
+        on_plot(fname)
+
+
+@plt_settings()
+def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):
+    """
+    Plot training results from a results CSV file. The function supports various types of data including segmentation,
+    pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
+
+    Args:
+        file (str, optional): Path to the CSV file containing the training results. Defaults to 'path/to/results.csv'.
+        dir (str, optional): Directory where the CSV file is located if 'file' is not provided. Defaults to ''.
+        segment (bool, optional): Flag to indicate if the data is for segmentation. Defaults to False.
+        pose (bool, optional): Flag to indicate if the data is for pose estimation. Defaults to False.
+        classify (bool, optional): Flag to indicate if the data is for classification. Defaults to False.
+        on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument.
+            Defaults to None.
+
+    Example:
+        ```python
+        from ultralytics.utils.plotting import plot_results
+
+        plot_results("path/to/results.csv", segment=True)
+        ```
+    """
+    import pandas as pd  # scope for faster 'import ultralytics'
+    from scipy.ndimage import gaussian_filter1d
+
+    save_dir = Path(file).parent if file else Path(dir)
+    if classify:
+        fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)
+        index = [2, 5, 3, 4]
+    elif segment:
+        fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
+        index = [2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 8, 9, 12, 13]
+    elif pose:
+        fig, ax = plt.subplots(2, 9, figsize=(21, 6), tight_layout=True)
+        index = [2, 3, 4, 5, 6, 7, 8, 11, 12, 15, 16, 17, 18, 19, 9, 10, 13, 14]
+    else:
+        fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
+        index = [2, 3, 4, 5, 6, 9, 10, 11, 7, 8]
+    ax = ax.ravel()
+    files = list(save_dir.glob("results*.csv"))
+    assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
+    for f in files:
+        try:
+            data = pd.read_csv(f)
+            s = [x.strip() for x in data.columns]
+            x = data.values[:, 0]
+            for i, j in enumerate(index):
+                y = data.values[:, j].astype("float")
+                # y[y == 0] = np.nan  # don't show zero values
+                ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8)  # actual results
+                ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2)  # smoothing line
+                ax[i].set_title(s[j], fontsize=12)
+                # if j in {8, 9, 10}:  # share train and val loss y axes
+                #     ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
+        except Exception as e:
+            LOGGER.warning(f"WARNING: Plotting error for {f}: {e}")
+    ax[1].legend()
+    fname = save_dir / "results.png"
+    fig.savefig(fname, dpi=200)
+    plt.close()
+    if on_plot:
+        on_plot(fname)
+
+
+def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"):
+    """
+    Plots a scatter plot with points colored based on a 2D histogram.
+
+    Args:
+        v (array-like): Values for the x-axis.
+        f (array-like): Values for the y-axis.
+        bins (int, optional): Number of bins for the histogram. Defaults to 20.
+        cmap (str, optional): Colormap for the scatter plot. Defaults to 'viridis'.
+        alpha (float, optional): Alpha for the scatter plot. Defaults to 0.8.
+        edgecolors (str, optional): Edge colors for the scatter plot. Defaults to 'none'.
+
+    Examples:
+        >>> v = np.random.rand(100)
+        >>> f = np.random.rand(100)
+        >>> plt_color_scatter(v, f)
+    """
+    # Calculate 2D histogram and corresponding colors
+    hist, xedges, yedges = np.histogram2d(v, f, bins=bins)
+    colors = [
+        hist[
+            min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1),
+            min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1),
+        ]
+        for i in range(len(v))
+    ]
+
+    # Scatter plot
+    plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)
+
+
+def plot_tune_results(csv_file="tune_results.csv"):
+    """
+    Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key
+    in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.
+
+    Args:
+        csv_file (str, optional): Path to the CSV file containing the tuning results. Defaults to 'tune_results.csv'.
+
+    Examples:
+        >>> plot_tune_results("path/to/tune_results.csv")
+    """
+    import pandas as pd  # scope for faster 'import ultralytics'
+    from scipy.ndimage import gaussian_filter1d
+
+    def _save_one_file(file):
+        """Save one matplotlib plot to 'file'."""
+        plt.savefig(file, dpi=200)
+        plt.close()
+        LOGGER.info(f"Saved {file}")
+
+    # Scatter plots for each hyperparameter
+    csv_file = Path(csv_file)
+    data = pd.read_csv(csv_file)
+    num_metrics_columns = 1
+    keys = [x.strip() for x in data.columns][num_metrics_columns:]
+    x = data.values
+    fitness = x[:, 0]  # fitness
+    j = np.argmax(fitness)  # max fitness index
+    n = math.ceil(len(keys) ** 0.5)  # columns and rows in plot
+    plt.figure(figsize=(10, 10), tight_layout=True)
+    for i, k in enumerate(keys):
+        v = x[:, i + num_metrics_columns]
+        mu = v[j]  # best single result
+        plt.subplot(n, n, i + 1)
+        plt_color_scatter(v, fitness, cmap="viridis", alpha=0.8, edgecolors="none")
+        plt.plot(mu, fitness.max(), "k+", markersize=15)
+        plt.title(f"{k} = {mu:.3g}", fontdict={"size": 9})  # limit to 40 characters
+        plt.tick_params(axis="both", labelsize=8)  # Set axis label size to 8
+        if i % n != 0:
+            plt.yticks([])
+    _save_one_file(csv_file.with_name("tune_scatter_plots.png"))
+
+    # Fitness vs iteration
+    x = range(1, len(fitness) + 1)
+    plt.figure(figsize=(10, 6), tight_layout=True)
+    plt.plot(x, fitness, marker="o", linestyle="none", label="fitness")
+    plt.plot(x, gaussian_filter1d(fitness, sigma=3), ":", label="smoothed", linewidth=2)  # smoothing line
+    plt.title("Fitness vs Iteration")
+    plt.xlabel("Iteration")
+    plt.ylabel("Fitness")
+    plt.grid(True)
+    plt.legend()
+    _save_one_file(csv_file.with_name("tune_fitness.png"))
+
+
+def output_to_target(output, max_det=300):
+    """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
+    targets = []
+    for i, o in enumerate(output):
+        box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)
+        j = torch.full((conf.shape[0], 1), i)
+        targets.append(torch.cat((j, cls, ops.xyxy2xywh(box), conf), 1))
+    targets = torch.cat(targets, 0).numpy()
+    return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
+
+
+def output_to_rotated_target(output, max_det=300):
+    """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
+    targets = []
+    for i, o in enumerate(output):
+        box, conf, cls, angle = o[:max_det].cpu().split((4, 1, 1, 1), 1)
+        j = torch.full((conf.shape[0], 1), i)
+        targets.append(torch.cat((j, cls, box, angle, conf), 1))
+    targets = torch.cat(targets, 0).numpy()
+    return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
+
+
+def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")):
+    """
+    Visualize feature maps of a given model module during inference.
+
+    Args:
+        x (torch.Tensor): Features to be visualized.
+        module_type (str): Module type.
+        stage (int): Module stage within the model.
+        n (int, optional): Maximum number of feature maps to plot. Defaults to 32.
+        save_dir (Path, optional): Directory to save results. Defaults to Path('runs/detect/exp').
+    """
+    for m in {"Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder"}:  # all model heads
+        if m in module_type:
+            return
+    if isinstance(x, torch.Tensor):
+        _, channels, height, width = x.shape  # batch, channels, height, width
+        if height > 1 and width > 1:
+            f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png"  # filename
+
+            blocks = torch.chunk(x[0].cpu(), channels, dim=0)  # select batch index 0, block by channels
+            n = min(n, channels)  # number of plots
+            _, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)  # ceil(n/8) rows x 8 cols
+            ax = ax.ravel()
+            plt.subplots_adjust(wspace=0.05, hspace=0.05)
+            for i in range(n):
+                ax[i].imshow(blocks[i].squeeze())  # cmap='gray'
+                ax[i].axis("off")
+
+            LOGGER.info(f"Saving {f}... ({n}/{channels})")
+            plt.savefig(f, dpi=300, bbox_inches="tight")
+            plt.close()
+            np.save(str(f.with_suffix(".npy")), x[0].cpu().numpy())  # npy save

+ 385 - 0
ultralytics/utils/tal.py

@@ -0,0 +1,385 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import torch
+import torch.nn as nn
+
+from . import LOGGER
+from .checks import check_version
+from .metrics import bbox_iou, probiou
+from .ops import xywhr2xyxyxyxy
+
+TORCH_1_10 = check_version(torch.__version__, "1.10.0")
+
+
+class TaskAlignedAssigner(nn.Module):
+    """
+    A task-aligned assigner for object detection.
+
+    This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both
+    classification and localization information.
+
+    Attributes:
+        topk (int): The number of top candidates to consider.
+        num_classes (int): The number of object classes.
+        alpha (float): The alpha parameter for the classification component of the task-aligned metric.
+        beta (float): The beta parameter for the localization component of the task-aligned metric.
+        eps (float): A small value to prevent division by zero.
+    """
+
+    def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
+        """Initialize a TaskAlignedAssigner object with customizable hyperparameters."""
+        super().__init__()
+        self.topk = topk
+        self.num_classes = num_classes
+        self.bg_idx = num_classes
+        self.alpha = alpha
+        self.beta = beta
+        self.eps = eps
+
+    @torch.no_grad()
+    def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
+        """
+        Compute the task-aligned assignment. Reference code is available at
+        https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py.
+
+        Args:
+            pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
+            pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
+            anc_points (Tensor): shape(num_total_anchors, 2)
+            gt_labels (Tensor): shape(bs, n_max_boxes, 1)
+            gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
+            mask_gt (Tensor): shape(bs, n_max_boxes, 1)
+
+        Returns:
+            target_labels (Tensor): shape(bs, num_total_anchors)
+            target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
+            target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
+            fg_mask (Tensor): shape(bs, num_total_anchors)
+            target_gt_idx (Tensor): shape(bs, num_total_anchors)
+        """
+        self.bs = pd_scores.shape[0]
+        self.n_max_boxes = gt_bboxes.shape[1]
+        device = gt_bboxes.device
+
+        if self.n_max_boxes == 0:
+            return (
+                torch.full_like(pd_scores[..., 0], self.bg_idx),
+                torch.zeros_like(pd_bboxes),
+                torch.zeros_like(pd_scores),
+                torch.zeros_like(pd_scores[..., 0]),
+                torch.zeros_like(pd_scores[..., 0]),
+            )
+
+        try:
+            return self._forward(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)
+        except torch.OutOfMemoryError:
+            # Move tensors to CPU, compute, then move back to original device
+            LOGGER.warning("WARNING: CUDA OutOfMemoryError in TaskAlignedAssigner, using CPU")
+            cpu_tensors = [t.cpu() for t in (pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)]
+            result = self._forward(*cpu_tensors)
+            return tuple(t.to(device) for t in result)
+
+    def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
+        """
+        Compute the task-aligned assignment. Reference code is available at
+        https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py.
+
+        Args:
+            pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
+            pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
+            anc_points (Tensor): shape(num_total_anchors, 2)
+            gt_labels (Tensor): shape(bs, n_max_boxes, 1)
+            gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
+            mask_gt (Tensor): shape(bs, n_max_boxes, 1)
+
+        Returns:
+            target_labels (Tensor): shape(bs, num_total_anchors)
+            target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
+            target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
+            fg_mask (Tensor): shape(bs, num_total_anchors)
+            target_gt_idx (Tensor): shape(bs, num_total_anchors)
+        """
+        mask_pos, align_metric, overlaps = self.get_pos_mask(
+            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
+        )
+
+        target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)
+
+        # Assigned target
+        target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)
+
+        # Normalize
+        align_metric *= mask_pos
+        pos_align_metrics = align_metric.amax(dim=-1, keepdim=True)  # b, max_num_obj
+        pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True)  # b, max_num_obj
+        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
+        target_scores = target_scores * norm_align_metric
+
+        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
+
+    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
+        """Get in_gts mask, (b, max_num_obj, h*w)."""
+        mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
+        # Get anchor_align metric, (b, max_num_obj, h*w)
+        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
+        # Get topk_metric mask, (b, max_num_obj, h*w)
+        mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool())
+        # Merge all mask to a final mask, (b, max_num_obj, h*w)
+        mask_pos = mask_topk * mask_in_gts * mask_gt
+
+        return mask_pos, align_metric, overlaps
+
+    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
+        """Compute alignment metric given predicted and ground truth bounding boxes."""
+        na = pd_bboxes.shape[-2]
+        mask_gt = mask_gt.bool()  # b, max_num_obj, h*w
+        overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
+        bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)
+
+        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
+        ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes)  # b, max_num_obj
+        ind[1] = gt_labels.squeeze(-1)  # b, max_num_obj
+        # Get the scores of each grid for each gt cls
+        bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt]  # b, max_num_obj, h*w
+
+        # (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
+        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
+        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
+        overlaps[mask_gt] = self.iou_calculation(gt_boxes, pd_boxes)
+
+        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
+        return align_metric, overlaps
+
+    def iou_calculation(self, gt_bboxes, pd_bboxes):
+        """IoU calculation for horizontal bounding boxes."""
+        return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
+
+    def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
+        """
+        Select the top-k candidates based on the given metrics.
+
+        Args:
+            metrics (Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
+                              max_num_obj is the maximum number of objects, and h*w represents the
+                              total number of anchor points.
+            largest (bool): If True, select the largest values; otherwise, select the smallest values.
+            topk_mask (Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where
+                                topk is the number of top candidates to consider. If not provided,
+                                the top-k values are automatically computed based on the given metrics.
+
+        Returns:
+            (Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
+        """
+        # (b, max_num_obj, topk)
+        topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
+        if topk_mask is None:
+            topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
+        # (b, max_num_obj, topk)
+        topk_idxs.masked_fill_(~topk_mask, 0)
+
+        # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
+        count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
+        ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
+        for k in range(self.topk):
+            # Expand topk_idxs for each value of k and add 1 at the specified positions
+            count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)
+        # count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device))
+        # Filter invalid bboxes
+        count_tensor.masked_fill_(count_tensor > 1, 0)
+
+        return count_tensor.to(metrics.dtype)
+
+    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
+        """
+        Compute target labels, target bounding boxes, and target scores for the positive anchor points.
+
+        Args:
+            gt_labels (Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
+                                batch size and max_num_obj is the maximum number of objects.
+            gt_bboxes (Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
+            target_gt_idx (Tensor): Indices of the assigned ground truth objects for positive
+                                    anchor points, with shape (b, h*w), where h*w is the total
+                                    number of anchor points.
+            fg_mask (Tensor): A boolean tensor of shape (b, h*w) indicating the positive
+                              (foreground) anchor points.
+
+        Returns:
+            (Tuple[Tensor, Tensor, Tensor]): A tuple containing the following tensors:
+                - target_labels (Tensor): Shape (b, h*w), containing the target labels for
+                                          positive anchor points.
+                - target_bboxes (Tensor): Shape (b, h*w, 4), containing the target bounding boxes
+                                          for positive anchor points.
+                - target_scores (Tensor): Shape (b, h*w, num_classes), containing the target scores
+                                          for positive anchor points, where num_classes is the number
+                                          of object classes.
+        """
+        # Assigned target labels, (b, 1)
+        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
+        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
+        target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)
+
+        # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
+        target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx]
+
+        # Assigned target scores
+        target_labels.clamp_(0)
+
+        # 10x faster than F.one_hot()
+        target_scores = torch.zeros(
+            (target_labels.shape[0], target_labels.shape[1], self.num_classes),
+            dtype=torch.int64,
+            device=target_labels.device,
+        )  # (b, h*w, 80)
+        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
+
+        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)
+        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
+
+        return target_labels, target_bboxes, target_scores
+
+    @staticmethod
+    def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
+        """
+        Select positive anchor centers within ground truth bounding boxes.
+
+        Args:
+            xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
+            gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4).
+            eps (float, optional): Small value for numerical stability. Defaults to 1e-9.
+
+        Returns:
+            (torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).
+
+        Note:
+            b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
+            Bounding box format: [x_min, y_min, x_max, y_max].
+        """
+        n_anchors = xy_centers.shape[0]
+        bs, n_boxes, _ = gt_bboxes.shape
+        lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)  # left-top, right-bottom
+        bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
+        # return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype)
+        return bbox_deltas.amin(3).gt_(eps)
+
+    @staticmethod
+    def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
+        """
+        Select anchor boxes with highest IoU when assigned to multiple ground truths.
+
+        Args:
+            mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w).
+            overlaps (torch.Tensor): IoU overlaps, shape (b, n_max_boxes, h*w).
+            n_max_boxes (int): Maximum number of ground truth boxes.
+
+        Returns:
+            target_gt_idx (torch.Tensor): Indices of assigned ground truths, shape (b, h*w).
+            fg_mask (torch.Tensor): Foreground mask, shape (b, h*w).
+            mask_pos (torch.Tensor): Updated positive mask, shape (b, n_max_boxes, h*w).
+
+        Note:
+            b: batch size, h: height, w: width.
+        """
+        # Convert (b, n_max_boxes, h*w) -> (b, h*w)
+        fg_mask = mask_pos.sum(-2)
+        if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
+            mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
+            max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)
+
+            is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
+            is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
+
+            mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
+            fg_mask = mask_pos.sum(-2)
+        # Find which gt (index) each grid cell serves
+        target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
+        return target_gt_idx, fg_mask, mask_pos
+
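
For orientation, here is a minimal numeric sketch of the alignment metric computed in `get_box_metrics` above (`align_metric = bbox_scores ** alpha * overlaps ** beta`); the score and IoU values below are made up purely for illustration:

```python
import torch

alpha, beta = 1.0, 6.0                   # TaskAlignedAssigner defaults
cls_score = torch.tensor([0.80, 0.95])   # predicted score for the gt class at two candidate anchors
iou = torch.tensor([0.90, 0.60])         # CIoU of the predicted boxes with the gt box
align_metric = cls_score**alpha * iou**beta
print(align_metric)  # tensor([0.4252, 0.0443]) -> the better-localised anchor ranks far higher
```

With `beta=6`, localization quality dominates the ranking, which is what pushes the top-k selection toward anchors whose predicted boxes already overlap the ground truth well.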
+
+class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
+    """Assigns ground-truth objects to rotated bounding boxes using a task-aligned metric."""
+
+    def iou_calculation(self, gt_bboxes, pd_bboxes):
+        """IoU calculation for rotated bounding boxes."""
+        return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)
+
+    @staticmethod
+    def select_candidates_in_gts(xy_centers, gt_bboxes):
+        """
+        Select the positive anchor center in gt for rotated bounding boxes.
+
+        Args:
+            xy_centers (Tensor): shape(h*w, 2)
+            gt_bboxes (Tensor): shape(b, n_boxes, 5)
+
+        Returns:
+            (Tensor): shape(b, n_boxes, h*w)
+        """
+        # (b, n_boxes, 5) --> (b, n_boxes, 4, 2)
+        corners = xywhr2xyxyxyxy(gt_bboxes)
+        # (b, n_boxes, 1, 2)
+        a, b, _, d = corners.split(1, dim=-2)
+        ab = b - a
+        ad = d - a
+
+        # (b, n_boxes, h*w, 2)
+        ap = xy_centers - a
+        norm_ab = (ab * ab).sum(dim=-1)
+        norm_ad = (ad * ad).sum(dim=-1)
+        ap_dot_ab = (ap * ab).sum(dim=-1)
+        ap_dot_ad = (ap * ad).sum(dim=-1)
+        return (ap_dot_ab >= 0) & (ap_dot_ab <= norm_ab) & (ap_dot_ad >= 0) & (ap_dot_ad <= norm_ad)  # is_in_box
+
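
The projection test in `RotatedTaskAlignedAssigner.select_candidates_in_gts` decides whether an anchor point lies inside a rotated box by projecting it onto the box edge vectors AB and AD. A standalone sketch of that test with hand-picked corners (illustration only, not calling the class):

```python
import torch

a = torch.tensor([0.0, 0.0])   # corner A of the box (axis-aligned here for simplicity)
b = torch.tensor([2.0, 0.0])   # corner B, so AB spans the width
d = torch.tensor([0.0, 1.0])   # corner D, so AD spans the height
ab, ad = b - a, d - a

for p in (torch.tensor([1.0, 0.5]), torch.tensor([3.0, 0.5])):
    ap = p - a
    in_ab = 0 <= float(ap @ ab) <= float(ab @ ab)  # projection onto AB within [0, |AB|^2]
    in_ad = 0 <= float(ap @ ad) <= float(ad @ ad)  # projection onto AD within [0, |AD|^2]
    print(p.tolist(), in_ab and in_ad)             # [1.0, 0.5] True, then [3.0, 0.5] False
```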
+
+def make_anchors(feats, strides, grid_cell_offset=0.5):
+    """Generate anchors from features."""
+    anchor_points, stride_tensor = [], []
+    assert feats is not None
+    dtype, device = feats[0].dtype, feats[0].device
+    for i, stride in enumerate(strides):
+        h, w = feats[i].shape[2:] if isinstance(feats, list) else (int(feats[i][0]), int(feats[i][1]))
+        sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # shift x
+        sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
+        sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
+        anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
+        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
+    return torch.cat(anchor_points), torch.cat(stride_tensor)
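
A quick sketch of what `make_anchors` produces for two tiny dummy feature maps (assumes the package is installed from this source so the module is importable as `ultralytics.utils.tal`):

```python
import torch
from ultralytics.utils.tal import make_anchors

feats = [torch.zeros(1, 64, 2, 2), torch.zeros(1, 64, 1, 1)]  # dummy feature maps for two strides
points, strides = make_anchors(feats, strides=[32, 64])
print(points.tolist())               # [[0.5, 0.5], [1.5, 0.5], [0.5, 1.5], [1.5, 1.5], [0.5, 0.5]]
print(strides.squeeze(-1).tolist())  # [32.0, 32.0, 32.0, 32.0, 64.0]
```

Anchor points are expressed in grid-cell units (the 0.5 offset places them at cell centers); multiplying by the matching stride recovers pixel coordinates.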
+
+
+def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
+    """Transform distance(ltrb) to box(xywh or xyxy)."""
+    lt, rb = distance.chunk(2, dim)
+    x1y1 = anchor_points - lt
+    x2y2 = anchor_points + rb
+    if xywh:
+        c_xy = (x1y1 + x2y2) / 2
+        wh = x2y2 - x1y1
+        return torch.cat((c_xy, wh), dim)  # xywh bbox
+    return torch.cat((x1y1, x2y2), dim)  # xyxy bbox
+
+
+def bbox2dist(anchor_points, bbox, reg_max):
+    """Transform bbox(xyxy) to dist(ltrb)."""
+    x1y1, x2y2 = bbox.chunk(2, -1)
+    return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01)  # dist (lt, rb)
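
`dist2bbox` and `bbox2dist` are inverses of each other (up to the `reg_max` clamp); a small round-trip sketch, assuming the module is importable as above:

```python
import torch
from ultralytics.utils.tal import bbox2dist, dist2bbox

anchor = torch.tensor([[4.5, 4.5]])
dist = torch.tensor([[2.0, 1.0, 3.0, 2.0]])      # left, top, right, bottom offsets
xyxy = dist2bbox(dist, anchor, xywh=False)       # tensor([[2.5, 3.5, 7.5, 6.5]])
dist_back = bbox2dist(anchor, xyxy, reg_max=16)  # tensor([[2., 1., 3., 2.]])
print(xyxy.tolist(), dist_back.tolist())
```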
+
+
+def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
+    """
+    Decode predicted rotated bounding box coordinates from anchor points and distribution.
+
+    Args:
+        pred_dist (torch.Tensor): Predicted rotated distance, shape (bs, h*w, 4).
+        pred_angle (torch.Tensor): Predicted angle, shape (bs, h*w, 1).
+        anchor_points (torch.Tensor): Anchor points, shape (h*w, 2).
+        dim (int, optional): Dimension along which to split. Defaults to -1.
+
+    Returns:
+        (torch.Tensor): Predicted rotated bounding boxes, shape (bs, h*w, 4).
+    """
+    lt, rb = pred_dist.split(2, dim=dim)
+    cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)
+    # (bs, h*w, 1)
+    xf, yf = ((rb - lt) / 2).split(1, dim=dim)
+    x, y = xf * cos - yf * sin, xf * sin + yf * cos
+    xy = torch.cat([x, y], dim=dim) + anchor_points
+    return torch.cat([xy, lt + rb], dim=dim)

+ 801 - 0
ultralytics/utils/torch_utils.py

@@ -0,0 +1,801 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import gc
+import math
+import os
+import random
+import time
+from contextlib import contextmanager
+from copy import deepcopy
+from datetime import datetime
+from pathlib import Path
+from typing import Union
+
+import numpy as np
+import thop
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ultralytics.utils import (
+    DEFAULT_CFG_DICT,
+    DEFAULT_CFG_KEYS,
+    LOGGER,
+    NUM_THREADS,
+    PYTHON_VERSION,
+    TORCHVISION_VERSION,
+    WINDOWS,
+    __version__,
+    colorstr,
+)
+from ultralytics.utils.checks import check_version
+
+# Version checks (all default to version>=min_version)
+TORCH_1_9 = check_version(torch.__version__, "1.9.0")
+TORCH_1_13 = check_version(torch.__version__, "1.13.0")
+TORCH_2_0 = check_version(torch.__version__, "2.0.0")
+TORCH_2_4 = check_version(torch.__version__, "2.4.0")
+TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0")
+TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0")
+TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0")
+TORCHVISION_0_18 = check_version(TORCHVISION_VERSION, "0.18.0")
+if WINDOWS and check_version(torch.__version__, "==2.4.0"):  # reject version 2.4.0 on Windows
+    LOGGER.warning(
+        "WARNING ⚠️ Known issue with torch==2.4.0 on Windows with CPU, recommend upgrading to torch>=2.4.1 to resolve "
+        "https://github.com/ultralytics/ultralytics/issues/15049"
+    )
+
+
+@contextmanager
+def torch_distributed_zero_first(local_rank: int):
+    """Ensures all processes in distributed training wait for the local master (rank 0) to complete a task first."""
+    initialized = dist.is_available() and dist.is_initialized()
+
+    if initialized and local_rank not in {-1, 0}:
+        dist.barrier(device_ids=[local_rank])
+    yield
+    if initialized and local_rank == 0:
+        dist.barrier(device_ids=[local_rank])
+
+
+def smart_inference_mode():
+    """Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator."""
+
+    def decorate(fn):
+        """Applies appropriate torch decorator for inference mode based on torch version."""
+        if TORCH_1_9 and torch.is_inference_mode_enabled():
+            return fn  # already in inference_mode, act as a pass-through
+        else:
+            return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)
+
+    return decorate
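
A minimal usage sketch for `smart_inference_mode` (the decorated function below is hypothetical):

```python
from ultralytics.utils.torch_utils import smart_inference_mode

@smart_inference_mode()
def run_inference(model, x):
    # Executes under torch.inference_mode() on torch>=1.9, torch.no_grad() otherwise
    return model(x)
```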
+
+
+def autocast(enabled: bool, device: str = "cuda"):
+    """
+    Get the appropriate autocast context manager based on PyTorch version and AMP setting.
+
+    This function returns a context manager for automatic mixed precision (AMP) training that is compatible with both
+    older and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions.
+
+    Args:
+        enabled (bool): Whether to enable automatic mixed precision.
+        device (str, optional): The device to use for autocast. Defaults to 'cuda'.
+
+    Returns:
+        (torch.amp.autocast): The appropriate autocast context manager.
+
+    Note:
+        - For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.
+        - For older versions, it uses `torch.cuda.amp.autocast`.
+
+    Example:
+        ```python
+        with autocast(enabled=True):
+            # Your mixed precision operations here
+            pass
+        ```
+    """
+    if TORCH_1_13:
+        return torch.amp.autocast(device, enabled=enabled)
+    else:
+        return torch.cuda.amp.autocast(enabled)
+
+
+def get_cpu_info():
+    """Return a string with system CPU information, i.e. 'Apple M2'."""
+    from ultralytics.utils import PERSISTENT_CACHE  # avoid circular import error
+
+    if "cpu_info" not in PERSISTENT_CACHE:
+        try:
+            import cpuinfo  # pip install py-cpuinfo
+
+            k = "brand_raw", "hardware_raw", "arch_string_raw"  # keys sorted by preference
+            info = cpuinfo.get_cpu_info()  # info dict
+            string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")
+            PERSISTENT_CACHE["cpu_info"] = string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")
+        except Exception:
+            pass
+    return PERSISTENT_CACHE.get("cpu_info", "unknown")
+
+
+def get_gpu_info(index):
+    """Return a string with system GPU information, i.e. 'Tesla T4, 15102MiB'."""
+    properties = torch.cuda.get_device_properties(index)
+    return f"{properties.name}, {properties.total_memory / (1 << 20):.0f}MiB"
+
+
+def select_device(device="", batch=0, newline=False, verbose=True):
+    """
+    Selects the appropriate PyTorch device based on the provided arguments.
+
+    The function takes a string specifying the device or a torch.device object and returns a torch.device object
+    representing the selected device. The function also validates the number of available devices and raises an
+    exception if the requested device(s) are not available.
+
+    Args:
+        device (str | torch.device, optional): Device string or torch.device object.
+            Options are 'None', 'cpu', or 'cuda', or '0' or '0,1,2,3'. Defaults to an empty string, which auto-selects
+            the first available GPU, or CPU if no GPU is available.
+        batch (int, optional): Batch size being used in your model. Defaults to 0.
+        newline (bool, optional): If True, adds a newline at the end of the log string. Defaults to False.
+        verbose (bool, optional): If True, logs the device information. Defaults to True.
+
+    Returns:
+        (torch.device): Selected device.
+
+    Raises:
+        ValueError: If the specified device is not available or if the batch size is not a multiple of the number of
+            devices when using multiple GPUs.
+
+    Examples:
+        >>> select_device("cuda:0")
+        device(type='cuda', index=0)
+
+        >>> select_device("cpu")
+        device(type='cpu')
+
+    Note:
+        Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use.
+    """
+    if isinstance(device, torch.device) or str(device).startswith("tpu"):
+        return device
+
+    s = f"Ultralytics {__version__} 🚀 Python-{PYTHON_VERSION} torch-{torch.__version__} "
+    device = str(device).lower()
+    for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ":
+        device = device.replace(remove, "")  # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
+    cpu = device == "cpu"
+    mps = device in {"mps", "mps:0"}  # Apple Metal Performance Shaders (MPS)
+    if cpu or mps:
+        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # force torch.cuda.is_available() = False
+    elif device:  # non-cpu device requested
+        if device == "cuda":
+            device = "0"
+        if "," in device:
+            device = ",".join([x for x in device.split(",") if x])  # remove sequential commas, i.e. "0,,1" -> "0,1"
+        visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+        os.environ["CUDA_VISIBLE_DEVICES"] = device  # set environment variable - must be before assert is_available()
+        if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(","))):
+            LOGGER.info(s)
+            install = (
+                "See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no "
+                "CUDA devices are seen by torch.\n"
+                if torch.cuda.device_count() == 0
+                else ""
+            )
+            raise ValueError(
+                f"Invalid CUDA 'device={device}' requested."
+                f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
+                f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
+                f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}"
+                f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}"
+                f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
+                f"{install}"
+            )
+
+    if not cpu and not mps and torch.cuda.is_available():  # prefer GPU if available
+        devices = device.split(",") if device else "0"  # i.e. "0,1" -> ["0", "1"]
+        n = len(devices)  # device count
+        if n > 1:  # multi-GPU
+            if batch < 1:
+                raise ValueError(
+                    "AutoBatch with batch<1 not supported for Multi-GPU training, "
+                    "please specify a valid batch size, i.e. batch=16."
+                )
+            if batch >= 0 and batch % n != 0:  # check batch_size is divisible by device_count
+                raise ValueError(
+                    f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
+                    f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
+                )
+        space = " " * (len(s) + 1)
+        for i, d in enumerate(devices):
+            s += f"{'' if i == 0 else space}CUDA:{d} ({get_gpu_info(i)})\n"  # bytes to MB
+        arg = "cuda:0"
+    elif mps and TORCH_2_0 and torch.backends.mps.is_available():
+        # Prefer MPS if available
+        s += f"MPS ({get_cpu_info()})\n"
+        arg = "mps"
+    else:  # revert to CPU
+        s += f"CPU ({get_cpu_info()})\n"
+        arg = "cpu"
+
+    if arg in {"cpu", "mps"}:
+        torch.set_num_threads(NUM_THREADS)  # reset OMP_NUM_THREADS for cpu training
+    if verbose:
+        LOGGER.info(s if newline else s.rstrip())
+    return torch.device(arg)
+
+
+def time_sync():
+    """PyTorch-accurate time."""
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    return time.time()
+
+
+def fuse_conv_and_bn(conv, bn):
+    """Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/."""
+    fusedconv = (
+        nn.Conv2d(
+            conv.in_channels,
+            conv.out_channels,
+            kernel_size=conv.kernel_size,
+            stride=conv.stride,
+            padding=conv.padding,
+            dilation=conv.dilation,
+            groups=conv.groups,
+            bias=True,
+        )
+        .requires_grad_(False)
+        .to(conv.weight.device)
+    )
+
+    # Prepare filters
+    w_conv = conv.weight.view(conv.out_channels, -1)
+    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
+    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
+
+    # Prepare spatial bias
+    b_conv = torch.zeros(conv.weight.shape[0], device=conv.weight.device) if conv.bias is None else conv.bias
+    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
+    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
+
+    return fusedconv
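
A quick numerical check that the fusion above is output-preserving; BatchNorm must be in eval mode since the fused weights are built from its running statistics (toy layer sizes chosen for illustration):

```python
import torch
import torch.nn as nn
from ultralytics.utils.torch_utils import fuse_conv_and_bn

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False).eval()
bn = nn.BatchNorm2d(8).eval()
x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    y_ref = bn(conv(x))
    y_fused = fuse_conv_and_bn(conv, bn)(x)
print(torch.allclose(y_ref, y_fused, atol=1e-5))  # True
```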
+
+
+def fuse_deconv_and_bn(deconv, bn):
+    """Fuse ConvTranspose2d() and BatchNorm2d() layers."""
+    fuseddconv = (
+        nn.ConvTranspose2d(
+            deconv.in_channels,
+            deconv.out_channels,
+            kernel_size=deconv.kernel_size,
+            stride=deconv.stride,
+            padding=deconv.padding,
+            output_padding=deconv.output_padding,
+            dilation=deconv.dilation,
+            groups=deconv.groups,
+            bias=True,
+        )
+        .requires_grad_(False)
+        .to(deconv.weight.device)
+    )
+
+    # Prepare filters
+    w_deconv = deconv.weight.view(deconv.out_channels, -1)
+    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
+    fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))
+
+    # Prepare spatial bias
+    b_conv = torch.zeros(deconv.weight.shape[1], device=deconv.weight.device) if deconv.bias is None else deconv.bias
+    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
+    fuseddconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
+
+    return fuseddconv
+
+
+def model_info(model, detailed=False, verbose=True, imgsz=640):
+    """Print and return detailed model information layer by layer."""
+    if not verbose:
+        return
+    n_p = get_num_params(model)  # number of parameters
+    n_g = get_num_gradients(model)  # number of gradients
+    n_l = len(list(model.modules()))  # number of layers
+    if detailed:
+        LOGGER.info(f"{'layer':>5}{'name':>40}{'gradient':>10}{'parameters':>12}{'shape':>20}{'mu':>10}{'sigma':>10}")
+        for i, (name, p) in enumerate(model.named_parameters()):
+            name = name.replace("module_list.", "")
+            LOGGER.info(
+                f"{i:>5g}{name:>40s}{p.requires_grad!r:>10}{p.numel():>12g}{str(list(p.shape)):>20s}"
+                f"{p.mean():>10.3g}{p.std():>10.3g}{str(p.dtype):>15s}"
+            )
+
+    flops = get_flops(model, imgsz)  # imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
+    fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""
+    fs = f", {flops:.1f} GFLOPs" if flops else ""
+    yaml_file = getattr(model, "yaml_file", "") or getattr(model, "yaml", {}).get("yaml_file", "")
+    model_name = Path(yaml_file).stem.replace("yolo", "YOLO") or "Model"
+    LOGGER.info(f"{model_name} summary{fused}: {n_l:,} layers, {n_p:,} parameters, {n_g:,} gradients{fs}")
+    return n_l, n_p, n_g, flops
+
+
+def get_num_params(model):
+    """Return the total number of parameters in a YOLO model."""
+    return sum(x.numel() for x in model.parameters())
+
+
+def get_num_gradients(model):
+    """Return the total number of parameters with gradients in a YOLO model."""
+    return sum(x.numel() for x in model.parameters() if x.requires_grad)
+
+
+def model_info_for_loggers(trainer):
+    """
+    Return model info dict with useful model information.
+
+    Example:
+        YOLOv8n info for loggers
+        ```python
+        results = {
+            "model/parameters": 3151904,
+            "model/GFLOPs": 8.746,
+            "model/speed_ONNX(ms)": 41.244,
+            "model/speed_TensorRT(ms)": 3.211,
+            "model/speed_PyTorch(ms)": 18.755,
+        }
+        ```
+    """
+    if trainer.args.profile:  # profile ONNX and TensorRT times
+        from ultralytics.utils.benchmarks import ProfileModels
+
+        results = ProfileModels([trainer.last], device=trainer.device).profile()[0]
+        results.pop("model/name")
+    else:  # only return PyTorch times from most recent validation
+        results = {
+            "model/parameters": get_num_params(trainer.model),
+            "model/GFLOPs": round(get_flops(trainer.model), 3),
+        }
+    results["model/speed_PyTorch(ms)"] = round(trainer.validator.speed["inference"], 3)
+    return results
+
+
+def get_flops(model, imgsz=640):
+    """Return a YOLO model's FLOPs."""
+    try:
+        model = de_parallel(model)
+        p = next(model.parameters())
+        if not isinstance(imgsz, list):
+            imgsz = [imgsz, imgsz]  # expand if int/float
+        try:
+            # Use stride size for input tensor
+            stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32  # max stride
+            im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
+            flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # stride GFLOPs
+            return flops * imgsz[0] / stride * imgsz[1] / stride  # imgsz GFLOPs
+        except Exception:
+            # Use actual image size for input tensor (i.e. required for RTDETR models)
+            im = torch.empty((1, p.shape[1], *imgsz), device=p.device)  # input image in BCHW format
+            return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # imgsz GFLOPs
+    except Exception:
+        return 0.0
+
+
+def get_flops_with_torch_profiler(model, imgsz=640):
+    """Compute model FLOPs (thop package alternative, but 2-10x slower unfortunately)."""
+    if not TORCH_2_0:  # torch profiler implemented in torch>=2.0
+        return 0.0
+    model = de_parallel(model)
+    p = next(model.parameters())
+    if not isinstance(imgsz, list):
+        imgsz = [imgsz, imgsz]  # expand if int/float
+    try:
+        # Use stride size for input tensor
+        stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2  # max stride
+        im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
+        with torch.profiler.profile(with_flops=True) as prof:
+            model(im)
+        flops = sum(x.flops for x in prof.key_averages()) / 1e9
+        flops = flops * imgsz[0] / stride * imgsz[1] / stride  # 640x640 GFLOPs
+    except Exception:
+        # Use actual image size for input tensor (i.e. required for RTDETR models)
+        im = torch.empty((1, p.shape[1], *imgsz), device=p.device)  # input image in BCHW format
+        with torch.profiler.profile(with_flops=True) as prof:
+            model(im)
+        flops = sum(x.flops for x in prof.key_averages()) / 1e9
+    return flops
+
+
+def initialize_weights(model):
+    """Initialize model weights to random values."""
+    for m in model.modules():
+        t = type(m)
+        if t is nn.Conv2d:
+            pass  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+        elif t is nn.BatchNorm2d:
+            m.eps = 1e-3
+            m.momentum = 0.03
+        elif t in {nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU}:
+            m.inplace = True
+
+
+def scale_img(img, ratio=1.0, same_shape=False, gs=32):
+    """Scales and pads an image tensor, optionally maintaining aspect ratio and padding to gs multiple."""
+    if ratio == 1.0:
+        return img
+    h, w = img.shape[2:]
+    s = (int(h * ratio), int(w * ratio))  # new size
+    img = F.interpolate(img, size=s, mode="bilinear", align_corners=False)  # resize
+    if not same_shape:  # pad/crop img
+        h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
+    return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # value = imagenet mean
+
+
+def copy_attr(a, b, include=(), exclude=()):
+    """Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes."""
+    for k, v in b.__dict__.items():
+        if (len(include) and k not in include) or k.startswith("_") or k in exclude:
+            continue
+        else:
+            setattr(a, k, v)
+
+
+def get_latest_opset():
+    """Return the second-most recent ONNX opset version supported by this version of PyTorch, adjusted for maturity."""
+    if TORCH_1_13:
+        # For PyTorch>=1.13, dynamically compute the latest opset minus one using 'symbolic_opset'
+        return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1
+    # Otherwise for PyTorch<=1.12 return the corresponding predefined opset
+    version = torch.onnx.producer_version.rsplit(".", 1)[0]  # e.g. '1.12'
+    return {"1.12": 15, "1.11": 14, "1.10": 13, "1.9": 12, "1.8": 12}.get(version, 12)
+
+
+def intersect_dicts(da, db, exclude=()):
+    """Returns a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values."""
+    return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
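+
+
+# intersect_dicts() usage sketch (illustrative only; `ckpt_sd` is a hypothetical checkpoint state_dict):
+#   csd = intersect_dicts(ckpt_sd, model.state_dict())  # keep only keys with matching shapes
+#   model.load_state_dict(csd, strict=False)  # load the intersecting weights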
+
+
+def is_parallel(model):
+    """Returns True if model is of type DP or DDP."""
+    return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))
+
+
+def de_parallel(model):
+    """De-parallelize a model: returns single-GPU model if model is of type DP or DDP."""
+    return model.module if is_parallel(model) else model
+
+
+def one_cycle(y1=0.0, y2=1.0, steps=100):
+    """Returns a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf."""
+    return lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + y1
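+
+
+# one_cycle() usage sketch (illustrative only; assumes `optimizer` and `epochs` are defined):
+#   lf = one_cycle(1.0, 0.01, epochs)  # cosine ramp from 1.0 down to 0.01 over `epochs` steps
+#   scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)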
+
+
+def init_seeds(seed=0, deterministic=False):
+    """Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html."""
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)  # for Multi-GPU, exception safe
+    # torch.backends.cudnn.benchmark = True  # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
+    if deterministic:
+        if TORCH_2_0:
+            torch.use_deterministic_algorithms(True, warn_only=True)  # warn if deterministic is not possible
+            torch.backends.cudnn.deterministic = True
+            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+            os.environ["PYTHONHASHSEED"] = str(seed)
+        else:
+            LOGGER.warning("WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.")
+    else:
+        torch.use_deterministic_algorithms(False)
+        torch.backends.cudnn.deterministic = False
+
+
+class ModelEMA:
+    """
+    Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models. Keeps a moving
+    average of everything in the model state_dict (parameters and buffers).
+
+    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+
+    To disable EMA set the `enabled` attribute to `False`.
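+
+    Example:
+        A minimal training-loop sketch (assumes `model`, `optimizer`, `dataloader` and a loss function are defined):
+        ```python
+        ema = ModelEMA(model)
+        for batch in dataloader:
+            loss = compute_loss(model(batch))  # hypothetical loss computation
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+            ema.update(model)  # update EMA weights after each optimizer step
+        ema.update_attr(model)  # finally copy non-parameter attributes (e.g. class names) to the EMA model
+        ```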
+    """
+
+    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
+        """Initialize EMA for 'model' with given arguments."""
+        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
+        self.updates = updates  # number of EMA updates
+        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
+        for p in self.ema.parameters():
+            p.requires_grad_(False)
+        self.enabled = True
+
+    def update(self, model):
+        """Update EMA parameters."""
+        if self.enabled:
+            self.updates += 1
+            d = self.decay(self.updates)
+
+            msd = de_parallel(model).state_dict()  # model state_dict
+            for k, v in self.ema.state_dict().items():
+                if v.dtype.is_floating_point:  # true for FP16 and FP32
+                    v *= d
+                    v += (1 - d) * msd[k].detach()
+                    # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype},  model {msd[k].dtype}'
+
+    def update_attr(self, model, include=(), exclude=("process_group", "reducer")):
+        """Updates attributes and saves stripped model with optimizer removed."""
+        if self.enabled:
+            copy_attr(self.ema, model, include, exclude)
+
+
+def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict = None) -> dict:
+    """
+    Strip optimizer from 'f' to finalize training, optionally save as 's'.
+
+    Args:
+        f (str | Path): file path to the model to strip the optimizer from. Default is 'best.pt'.
+        s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
+        updates (dict): a dictionary of updates to overlay onto the checkpoint before saving.
+
+    Returns:
+        (dict): The combined checkpoint dictionary.
+
+    Example:
+        ```python
+        from pathlib import Path
+        from ultralytics.utils.torch_utils import strip_optimizer
+
+        for f in Path("path/to/model/checkpoints").rglob("*.pt"):
+            strip_optimizer(f)
+        ```
+
+    Note:
+        Use `ultralytics.nn.torch_safe_load` for missing modules with `x = torch_safe_load(f)[0]`
+    """
+    try:
+        x = torch.load(f, map_location=torch.device("cpu"))
+        assert isinstance(x, dict), "checkpoint is not a Python dictionary"
+        assert "model" in x, "'model' missing from checkpoint"
+    except Exception as e:
+        LOGGER.warning(f"WARNING ⚠️ Skipping {f}, not a valid Ultralytics model: {e}")
+        return {}
+
+    metadata = {
+        "date": datetime.now().isoformat(),
+        "version": __version__,
+        "license": "AGPL-3.0 License (https://ultralytics.com/license)",
+        "docs": "https://docs.ultralytics.com",
+    }
+
+    # Update model
+    if x.get("ema"):
+        x["model"] = x["ema"]  # replace model with EMA
+    if hasattr(x["model"], "args"):
+        x["model"].args = dict(x["model"].args)  # convert from IterableSimpleNamespace to dict
+    if hasattr(x["model"], "criterion"):
+        x["model"].criterion = None  # strip loss criterion
+    x["model"].half()  # to FP16
+    for p in x["model"].parameters():
+        p.requires_grad = False
+
+    # Update other keys
+    args = {**DEFAULT_CFG_DICT, **x.get("train_args", {})}  # combine args
+    for k in "optimizer", "best_fitness", "ema", "updates":  # keys
+        x[k] = None
+    x["epoch"] = -1
+    x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # strip non-default keys
+    # x['model'].args = x['train_args']
+
+    # Save
+    combined = {**metadata, **x, **(updates or {})}  # combine dicts (prefer to the right)
+    torch.save(combined, s or f)
+    mb = os.path.getsize(s or f) / 1e6  # file size
+    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
+    return combined
+
+
+def convert_optimizer_state_dict_to_fp16(state_dict):
+    """
+    Converts the state_dict of a given optimizer to FP16, focusing on the 'state' key for tensor conversions.
+
+    This method aims to reduce storage size without altering 'param_groups' as they contain non-tensor data.
+    """
+    for state in state_dict["state"].values():
+        for k, v in state.items():
+            if k != "step" and isinstance(v, torch.Tensor) and v.dtype is torch.float32:
+                state[k] = v.half()
+
+    return state_dict
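+
+
+# convert_optimizer_state_dict_to_fp16() usage sketch (illustrative only; assumes a torch `optimizer` exists):
+#   ckpt = {"optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(optimizer.state_dict()))}
+#   torch.save(ckpt, "checkpoint.pt")  # smaller checkpoint with optimizer tensors stored in FP16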
+
+
+@contextmanager
+def cuda_memory_usage(device=None):
+    """
+    Monitor and manage CUDA memory usage.
+
+    This function checks if CUDA is available and, if so, empties the CUDA cache to free up unused memory.
+    It then yields a dictionary containing memory usage information, which can be updated by the caller.
+    Finally, it updates the dictionary with the amount of memory reserved by CUDA on the specified device.
+
+    Args:
+        device (torch.device, optional): The CUDA device to query memory usage for. Defaults to None.
+
+    Yields:
+        (dict): A dictionary with a key 'memory' initialized to 0, which will be updated with the reserved memory.
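+
+    Example:
+        A minimal usage sketch (assumes a CUDA `device` and that `model` and `im` are already defined):
+        ```python
+        with cuda_memory_usage(device) as cuda_info:
+            _ = model(im)
+        print(f"{cuda_info['memory'] / 1e9:.3f} GB reserved")
+        ```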
+    """
+    cuda_info = dict(memory=0)
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        try:
+            yield cuda_info
+        finally:
+            cuda_info["memory"] = torch.cuda.memory_reserved(device)
+    else:
+        yield cuda_info
+
+
+def profile(input, ops, n=10, device=None, max_num_obj=0):
+    """
+    Ultralytics speed, memory and FLOPs profiler.
+
+    Example:
+        ```python
+        from ultralytics.utils.torch_utils import profile
+
+        input = torch.randn(16, 3, 640, 640)
+        m1 = lambda x: x * torch.sigmoid(x)
+        m2 = nn.SiLU()
+        profile(input, [m1, m2], n=100)  # profile over 100 iterations
+        ```
+    """
+    results = []
+    if not isinstance(device, torch.device):
+        device = select_device(device)
+    LOGGER.info(
+        f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
+        f"{'input':>24s}{'output':>24s}"
+    )
+    gc.collect()  # attempt to free unused memory
+    torch.cuda.empty_cache()
+    for x in input if isinstance(input, list) else [input]:
+        x = x.to(device)
+        x.requires_grad = True
+        for m in ops if isinstance(ops, list) else [ops]:
+            m = m.to(device) if hasattr(m, "to") else m  # device
+            m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
+            tf, tb, t = 0, 0, [0, 0, 0]  # dt forward, backward
+            try:
+                flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2  # GFLOPs
+            except Exception:
+                flops = 0
+
+            try:
+                mem = 0
+                for _ in range(n):
+                    with cuda_memory_usage(device) as cuda_info:
+                        t[0] = time_sync()
+                        y = m(x)
+                        t[1] = time_sync()
+                        try:
+                            (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
+                            t[2] = time_sync()
+                        except Exception:  # no backward method
+                            # print(e)  # for debug
+                            t[2] = float("nan")
+                    mem += cuda_info["memory"] / 1e9  # (GB)
+                    tf += (t[1] - t[0]) * 1000 / n  # ms per op forward
+                    tb += (t[2] - t[1]) * 1000 / n  # ms per op backward
+                    if max_num_obj:  # simulate training with predictions per image grid (for AutoBatch)
+                        with cuda_memory_usage(device) as cuda_info:
+                            torch.randn(
+                                x.shape[0],
+                                max_num_obj,
+                                int(sum((x.shape[-1] / s) * (x.shape[-2] / s) for s in m.stride.tolist())),
+                                device=device,
+                                dtype=torch.float32,
+                            )
+                        mem += cuda_info["memory"] / 1e9  # (GB)
+                s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y))  # shapes
+                p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0  # parameters
+                LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}")
+                results.append([p, flops, mem, tf, tb, s_in, s_out])
+            except Exception as e:
+                LOGGER.info(e)
+                results.append(None)
+            finally:
+                gc.collect()  # attempt to free unused memory
+                torch.cuda.empty_cache()
+    return results
+
+
+class EarlyStopping:
+    """Early stopping class that stops training when a specified number of epochs have passed without improvement."""
+
+    def __init__(self, patience=50):
+        """
+        Initialize early stopping object.
+
+        Args:
+            patience (int, optional): Number of epochs to wait after fitness stops improving before stopping.
+        """
+        self.best_fitness = 0.0  # i.e. mAP
+        self.best_epoch = 0
+        self.patience = patience or float("inf")  # epochs to wait after fitness stops improving to stop
+        self.possible_stop = False  # possible stop may occur next epoch
+
+    def __call__(self, epoch, fitness):
+        """
+        Check whether to stop training.
+
+        Args:
+            epoch (int): Current epoch of training
+            fitness (float): Fitness value of current epoch
+
+        Returns:
+            (bool): True if training should stop, False otherwise
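+
+        Example:
+            A minimal usage sketch (assumes a hypothetical `validate()` helper returning a fitness score):
+            ```python
+            stopper = EarlyStopping(patience=30)
+            for epoch in range(100):
+                fitness = validate()  # hypothetical per-epoch validation
+                if stopper(epoch, fitness):
+                    break
+            ```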
+        """
+        if fitness is None:  # check if fitness=None (happens when val=False)
+            return False
+
+        if fitness >= self.best_fitness:  # >= 0 to allow for early zero-fitness stage of training
+            self.best_epoch = epoch
+            self.best_fitness = fitness
+        delta = epoch - self.best_epoch  # epochs without improvement
+        self.possible_stop = delta >= (self.patience - 1)  # possible stop may occur next epoch
+        stop = delta >= self.patience  # stop training if patience exceeded
+        if stop:
+            prefix = colorstr("EarlyStopping: ")
+            LOGGER.info(
+                f"{prefix}Training stopped early as no improvement observed in last {self.patience} epochs. "
+                f"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n"
+                f"To update EarlyStopping(patience={self.patience}) pass a new patience value, "
+                f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping."
+            )
+        return stop
+
+
+class FXModel(nn.Module):
+    """
+    A custom model class for torch.fx compatibility.
+
+    This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph manipulation.
+    It copies attributes from an existing model and explicitly sets the model attribute to ensure proper copying.
+
+    Args:
+        model (torch.nn.Module): The original model to wrap for torch.fx compatibility.
+    """
+
+    def __init__(self, model):
+        """
+        Initialize the FXModel.
+
+        Args:
+            model (torch.nn.Module): The original model to wrap for torch.fx compatibility.
+        """
+        super().__init__()
+        copy_attr(self, model)
+        # Explicitly set `model`, since `copy_attr` does not reliably copy it.
+        self.model = model.model
+
+    def forward(self, x):
+        """
+        Forward pass through the model.
+
+        This method performs the forward pass through the model, handling the dependencies between layers and saving intermediate outputs.
+
+        Args:
+            x (torch.Tensor): The input tensor to the model.
+
+        Returns:
+            (torch.Tensor): The output tensor from the model.
+        """
+        y = []  # outputs
+        for m in self.model:
+            if m.f != -1:  # if not from previous layer
+                # from earlier layers
+                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]
+            x = m(x)  # run
+            y.append(x)  # save output
+        return x

+ 93 - 0
ultralytics/utils/triton.py

@@ -0,0 +1,93 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from typing import List
+from urllib.parse import urlsplit
+
+import numpy as np
+
+
+class TritonRemoteModel:
+    """
+    Client for interacting with a remote Triton Inference Server model.
+
+    Attributes:
+        endpoint (str): The name of the model on the Triton server.
+        url (str): The URL of the Triton server.
+        triton_client: The Triton client (either HTTP or gRPC).
+        InferInput: The input class for the Triton client.
+        InferRequestedOutput: The output request class for the Triton client.
+        input_formats (List[str]): The data types of the model inputs.
+        np_input_formats (List[type]): The numpy data types of the model inputs.
+        input_names (List[str]): The names of the model inputs.
+        output_names (List[str]): The names of the model outputs.
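+
+    Example:
+        A minimal usage sketch (assumes a Triton server on localhost:8000 serving a model named "yolo11n"):
+        ```python
+        import numpy as np
+
+        model = TritonRemoteModel("http://localhost:8000/yolo11n")
+        outputs = model(np.zeros((1, 3, 640, 640), dtype=np.float32))  # dummy input; real shape depends on the model
+        ```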
+    """
+
+    def __init__(self, url: str, endpoint: str = "", scheme: str = ""):
+        """
+        Initialize the TritonRemoteModel.
+
+        Arguments may be provided individually or parsed from a collective 'url' argument of the form
+            <scheme>://<netloc>/<endpoint>/<task_name>
+
+        Args:
+            url (str): The URL of the Triton server.
+            endpoint (str): The name of the model on the Triton server.
+            scheme (str): The communication scheme ('http' or 'grpc').
+        """
+        if not endpoint and not scheme:  # Parse all args from URL string
+            splits = urlsplit(url)
+            endpoint = splits.path.strip("/").split("/")[0]
+            scheme = splits.scheme
+            url = splits.netloc
+
+        self.endpoint = endpoint
+        self.url = url
+
+        # Choose the Triton client based on the communication scheme
+        if scheme == "http":
+            import tritonclient.http as client  # noqa
+
+            self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
+            config = self.triton_client.get_model_config(endpoint)
+        else:
+            import tritonclient.grpc as client  # noqa
+
+            self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
+            config = self.triton_client.get_model_config(endpoint, as_json=True)["config"]
+
+        # Sort output names alphabetically, i.e. 'output0', 'output1', etc.
+        config["output"] = sorted(config["output"], key=lambda x: x.get("name"))
+
+        # Define model attributes
+        type_map = {"TYPE_FP32": np.float32, "TYPE_FP16": np.float16, "TYPE_UINT8": np.uint8}
+        self.InferRequestedOutput = client.InferRequestedOutput
+        self.InferInput = client.InferInput
+        self.input_formats = [x["data_type"] for x in config["input"]]
+        self.np_input_formats = [type_map[x] for x in self.input_formats]
+        self.input_names = [x["name"] for x in config["input"]]
+        self.output_names = [x["name"] for x in config["output"]]
+        self.metadata = eval(config.get("parameters", {}).get("metadata", {}).get("string_value", "None"))
+
+    def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
+        """
+        Call the model with the given inputs.
+
+        Args:
+            *inputs (np.ndarray): Input data to the model.
+
+        Returns:
+            (List[np.ndarray]): Model outputs.
+        """
+        infer_inputs = []
+        input_format = inputs[0].dtype
+        for i, x in enumerate(inputs):
+            if x.dtype != self.np_input_formats[i]:
+                x = x.astype(self.np_input_formats[i])
+            infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace("TYPE_", ""))
+            infer_input.set_data_from_numpy(x)
+            infer_inputs.append(infer_input)
+
+        infer_outputs = [self.InferRequestedOutput(output_name) for output_name in self.output_names]
+        outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs)
+
+        return [outputs.as_numpy(output_name).astype(input_format) for output_name in self.output_names]

+ 157 - 0
ultralytics/utils/tuner.py

@@ -0,0 +1,157 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_cfg, get_save_dir
+from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks
+
+
+def run_ray_tune(
+    model,
+    space: dict = None,
+    grace_period: int = 10,
+    gpu_per_trial: int = None,
+    max_samples: int = 10,
+    **train_args,
+):
+    """
+    Runs hyperparameter tuning using Ray Tune.
+
+    Args:
+        model (YOLO): Model to run the tuner on.
+        space (dict, optional): The hyperparameter search space. Defaults to None.
+        grace_period (int, optional): The grace period in epochs of the ASHA scheduler. Defaults to 10.
+        gpu_per_trial (int, optional): The number of GPUs to allocate per trial. Defaults to None.
+        max_samples (int, optional): The maximum number of trials to run. Defaults to 10.
+        train_args (dict, optional): Additional arguments to pass to the `train()` method. Defaults to {}.
+
+    Returns:
+        (dict): A dictionary containing the results of the hyperparameter search.
+
+    Example:
+        ```python
+        from ultralytics import YOLO
+
+        # Load a YOLO11n model
+        model = YOLO("yolo11n.pt")
+
+        # Start tuning hyperparameters for YOLO11n training on the COCO8 dataset
+        result_grid = model.tune(data="coco8.yaml", use_ray=True)
+        ```
+    """
+    LOGGER.info("💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune")
+    if train_args is None:
+        train_args = {}
+
+    try:
+        checks.check_requirements("ray[tune]")
+
+        import ray
+        from ray import tune
+        from ray.air import RunConfig
+        from ray.air.integrations.wandb import WandbLoggerCallback
+        from ray.tune.schedulers import ASHAScheduler
+    except ImportError:
+        raise ModuleNotFoundError('Ray Tune required but not found. To install run: pip install "ray[tune]"')
+
+    try:
+        import wandb
+
+        assert hasattr(wandb, "__version__")
+    except (ImportError, AssertionError):
+        wandb = False
+
+    checks.check_version(ray.__version__, ">=2.0.0", "ray")
+    default_space = {
+        # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
+        "lr0": tune.uniform(1e-5, 1e-1),
+        "lrf": tune.uniform(0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
+        "momentum": tune.uniform(0.6, 0.98),  # SGD momentum/Adam beta1
+        "weight_decay": tune.uniform(0.0, 0.001),  # optimizer weight decay 5e-4
+        "warmup_epochs": tune.uniform(0.0, 5.0),  # warmup epochs (fractions ok)
+        "warmup_momentum": tune.uniform(0.0, 0.95),  # warmup initial momentum
+        "box": tune.uniform(0.02, 0.2),  # box loss gain
+        "cls": tune.uniform(0.2, 4.0),  # cls loss gain (scale with pixels)
+        "hsv_h": tune.uniform(0.0, 0.1),  # image HSV-Hue augmentation (fraction)
+        "hsv_s": tune.uniform(0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
+        "hsv_v": tune.uniform(0.0, 0.9),  # image HSV-Value augmentation (fraction)
+        "degrees": tune.uniform(0.0, 45.0),  # image rotation (+/- deg)
+        "translate": tune.uniform(0.0, 0.9),  # image translation (+/- fraction)
+        "scale": tune.uniform(0.0, 0.9),  # image scale (+/- gain)
+        "shear": tune.uniform(0.0, 10.0),  # image shear (+/- deg)
+        "perspective": tune.uniform(0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
+        "flipud": tune.uniform(0.0, 1.0),  # image flip up-down (probability)
+        "fliplr": tune.uniform(0.0, 1.0),  # image flip left-right (probability)
+        "bgr": tune.uniform(0.0, 1.0),  # image channel BGR (probability)
+        "mosaic": tune.uniform(0.0, 1.0),  # image mixup (probability)
+        "mixup": tune.uniform(0.0, 1.0),  # image mixup (probability)
+        "copy_paste": tune.uniform(0.0, 1.0),  # segment copy-paste (probability)
+    }
+
+    # Put the model in ray store
+    task = model.task
+    model_in_store = ray.put(model)
+
+    def _tune(config):
+        """
+        Trains the YOLO model with the specified hyperparameters and additional arguments.
+
+        Args:
+            config (dict): A dictionary of hyperparameters to use for training.
+
+        Returns:
+            (dict): A dictionary of training results.
+        """
+        model_to_train = ray.get(model_in_store)  # get the model from ray store for tuning
+        model_to_train.reset_callbacks()
+        config.update(train_args)
+        results = model_to_train.train(**config)
+        return results.results_dict
+
+    # Get search space
+    if not space:
+        space = default_space
+        LOGGER.warning("WARNING ⚠️ search space not provided, using default search space.")
+
+    # Get dataset
+    data = train_args.get("data", TASK2DATA[task])
+    space["data"] = data
+    if "data" not in train_args:
+        LOGGER.warning(f'WARNING ⚠️ data not provided, using default "data={data}".')
+
+    # Define the trainable function with allocated resources
+    trainable_with_resources = tune.with_resources(_tune, {"cpu": NUM_THREADS, "gpu": gpu_per_trial or 0})
+
+    # Define the ASHA scheduler for hyperparameter search
+    asha_scheduler = ASHAScheduler(
+        time_attr="epoch",
+        metric=TASK2METRIC[task],
+        mode="max",
+        max_t=train_args.get("epochs") or DEFAULT_CFG_DICT["epochs"] or 100,
+        grace_period=grace_period,
+        reduction_factor=3,
+    )
+
+    # Define the callbacks for the hyperparameter search
+    tuner_callbacks = [WandbLoggerCallback(project="YOLOv8-tune")] if wandb else []
+
+    # Create the Ray Tune hyperparameter search tuner
+    tune_dir = get_save_dir(
+        get_cfg(DEFAULT_CFG, train_args), name=train_args.pop("name", "tune")
+    ).resolve()  # must be absolute dir
+    tune_dir.mkdir(parents=True, exist_ok=True)
+    tuner = tune.Tuner(
+        trainable_with_resources,
+        param_space=space,
+        tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),
+        run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir),
+    )
+
+    # Run the hyperparameter search
+    tuner.fit()
+
+    # Get the results of the hyperparameter search
+    results = tuner.get_results()
+
+    # Shut down Ray to clean up workers
+    ray.shutdown()
+
+    return results