model.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. from pathlib import Path
  3. from ultralytics.engine.model import Model
  4. from ultralytics.models import yolo
  5. from ultralytics.nn.tasks import ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel
  6. from ultralytics.utils import ROOT, yaml_load
  7. class YOLO(Model):
  8. """YOLO (You Only Look Once) object detection model."""
  9. def __init__(self, model="yolo11n.pt", task=None, verbose=False):
  10. """Initialize YOLO model, switching to YOLOWorld if model filename contains '-world'."""
  11. path = Path(model)
  12. if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}: # if YOLOWorld PyTorch model
  13. new_instance = YOLOWorld(path, verbose=verbose)
  14. self.__class__ = type(new_instance)
  15. self.__dict__ = new_instance.__dict__
  16. else:
  17. # Continue with default YOLO initialization
  18. super().__init__(model=model, task=task, verbose=verbose)
  19. @property
  20. def task_map(self):
  21. """Map head to model, trainer, validator, and predictor classes."""
  22. return {
  23. "classify": {
  24. "model": ClassificationModel,
  25. "trainer": yolo.classify.ClassificationTrainer,
  26. "validator": yolo.classify.ClassificationValidator,
  27. "predictor": yolo.classify.ClassificationPredictor,
  28. },
  29. "detect": {
  30. "model": DetectionModel,
  31. "trainer": yolo.detect.DetectionTrainer,
  32. "validator": yolo.detect.DetectionValidator,
  33. "predictor": yolo.detect.DetectionPredictor,
  34. },
  35. "segment": {
  36. "model": SegmentationModel,
  37. "trainer": yolo.segment.SegmentationTrainer,
  38. "validator": yolo.segment.SegmentationValidator,
  39. "predictor": yolo.segment.SegmentationPredictor,
  40. },
  41. "pose": {
  42. "model": PoseModel,
  43. "trainer": yolo.pose.PoseTrainer,
  44. "validator": yolo.pose.PoseValidator,
  45. "predictor": yolo.pose.PosePredictor,
  46. },
  47. "obb": {
  48. "model": OBBModel,
  49. "trainer": yolo.obb.OBBTrainer,
  50. "validator": yolo.obb.OBBValidator,
  51. "predictor": yolo.obb.OBBPredictor,
  52. },
  53. }
  54. class YOLOWorld(Model):
  55. """YOLO-World object detection model."""
  56. def __init__(self, model="yolov8s-world.pt", verbose=False) -> None:
  57. """
  58. Initialize YOLOv8-World model with a pre-trained model file.
  59. Loads a YOLOv8-World model for object detection. If no custom class names are provided, it assigns default
  60. COCO class names.
  61. Args:
  62. model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
  63. verbose (bool): If True, prints additional information during initialization.
  64. """
  65. super().__init__(model=model, task="detect", verbose=verbose)
  66. # Assign default COCO class names when there are no custom names
  67. if not hasattr(self.model, "names"):
  68. self.model.names = yaml_load(ROOT / "cfg/datasets/coco8.yaml").get("names")
  69. @property
  70. def task_map(self):
  71. """Map head to model, validator, and predictor classes."""
  72. return {
  73. "detect": {
  74. "model": WorldModel,
  75. "validator": yolo.detect.DetectionValidator,
  76. "predictor": yolo.detect.DetectionPredictor,
  77. "trainer": yolo.world.WorldTrainer,
  78. }
  79. }
  80. def set_classes(self, classes):
  81. """
  82. Set classes.
  83. Args:
  84. classes (List(str)): A list of categories i.e. ["person"].
  85. """
  86. self.model.set_classes(classes)
  87. # Remove background if it's given
  88. background = " "
  89. if background in classes:
  90. classes.remove(background)
  91. self.model.names = classes
  92. # Reset method class names
  93. # self.predictor = None # reset predictor otherwise old names remain
  94. if self.predictor:
  95. self.predictor.model.names = classes