12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394 |
- # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
- """
- YOLO-NAS model interface.
- Example:
- ```python
- from ultralytics import NAS
- model = NAS("yolo_nas_s")
- results = model.predict("ultralytics/assets/bus.jpg")
- ```
- """
- from pathlib import Path
- import torch
- from ultralytics.engine.model import Model
- from ultralytics.utils.downloads import attempt_download_asset
- from ultralytics.utils.torch_utils import model_info
- from .predict import NASPredictor
- from .val import NASValidator
- class NAS(Model):
- """
- YOLO NAS model for object detection.
- This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine.
- It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.
- Example:
- ```python
- from ultralytics import NAS
- model = NAS("yolo_nas_s")
- results = model.predict("ultralytics/assets/bus.jpg")
- ```
- Attributes:
- model (str): Path to the pre-trained model or model name. Defaults to 'yolo_nas_s.pt'.
- Note:
- YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.
- """
- def __init__(self, model="yolo_nas_s.pt") -> None:
- """Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model."""
- assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models."
- super().__init__(model, task="detect")
- def _load(self, weights: str, task=None) -> None:
- """Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided."""
- import super_gradients
- suffix = Path(weights).suffix
- if suffix == ".pt":
- self.model = torch.load(attempt_download_asset(weights))
- elif suffix == "":
- self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")
- # Override the forward method to ignore additional arguments
- def new_forward(x, *args, **kwargs):
- """Ignore additional __call__ arguments."""
- return self.model._original_forward(x)
- self.model._original_forward = self.model.forward
- self.model.forward = new_forward
- # Standardize model
- self.model.fuse = lambda verbose=True: self.model
- self.model.stride = torch.tensor([32])
- self.model.names = dict(enumerate(self.model._class_names))
- self.model.is_fused = lambda: False # for info()
- self.model.yaml = {} # for info()
- self.model.pt_path = weights # for export()
- self.model.task = "detect" # for export()
- def info(self, detailed=False, verbose=True):
- """
- Logs model info.
- Args:
- detailed (bool): Show detailed information about model.
- verbose (bool): Controls verbosity.
- """
- return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
- @property
- def task_map(self):
- """Returns a dictionary mapping tasks to respective predictor and validator classes."""
- return {"detect": {"predictor": NASPredictor, "validator": NASValidator}}
|