model.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. """
  3. YOLO-NAS model interface.
  4. Example:
  5. ```python
  6. from ultralytics import NAS
  7. model = NAS("yolo_nas_s")
  8. results = model.predict("ultralytics/assets/bus.jpg")
  9. ```
  10. """
  11. from pathlib import Path
  12. import torch
  13. from ultralytics.engine.model import Model
  14. from ultralytics.utils.downloads import attempt_download_asset
  15. from ultralytics.utils.torch_utils import model_info
  16. from .predict import NASPredictor
  17. from .val import NASValidator
  18. class NAS(Model):
  19. """
  20. YOLO NAS model for object detection.
  21. This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine.
  22. It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.
  23. Example:
  24. ```python
  25. from ultralytics import NAS
  26. model = NAS("yolo_nas_s")
  27. results = model.predict("ultralytics/assets/bus.jpg")
  28. ```
  29. Attributes:
  30. model (str): Path to the pre-trained model or model name. Defaults to 'yolo_nas_s.pt'.
  31. Note:
  32. YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.
  33. """
  34. def __init__(self, model="yolo_nas_s.pt") -> None:
  35. """Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model."""
  36. assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models."
  37. super().__init__(model, task="detect")
  38. def _load(self, weights: str, task=None) -> None:
  39. """Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided."""
  40. import super_gradients
  41. suffix = Path(weights).suffix
  42. if suffix == ".pt":
  43. self.model = torch.load(attempt_download_asset(weights))
  44. elif suffix == "":
  45. self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")
  46. # Override the forward method to ignore additional arguments
  47. def new_forward(x, *args, **kwargs):
  48. """Ignore additional __call__ arguments."""
  49. return self.model._original_forward(x)
  50. self.model._original_forward = self.model.forward
  51. self.model.forward = new_forward
  52. # Standardize model
  53. self.model.fuse = lambda verbose=True: self.model
  54. self.model.stride = torch.tensor([32])
  55. self.model.names = dict(enumerate(self.model._class_names))
  56. self.model.is_fused = lambda: False # for info()
  57. self.model.yaml = {} # for info()
  58. self.model.pt_path = weights # for export()
  59. self.model.task = "detect" # for export()
  60. def info(self, detailed=False, verbose=True):
  61. """
  62. Logs model info.
  63. Args:
  64. detailed (bool): Show detailed information about model.
  65. verbose (bool): Controls verbosity.
  66. """
  67. return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
  68. @property
  69. def task_map(self):
  70. """Returns a dictionary mapping tasks to respective predictor and validator classes."""
  71. return {"detect": {"predictor": NASPredictor, "validator": NASValidator}}