test_engine.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. import sys
  3. from unittest import mock
  4. from tests import MODEL
  5. from ultralytics import YOLO
  6. from ultralytics.cfg import get_cfg
  7. from ultralytics.engine.exporter import Exporter
  8. from ultralytics.models.yolo import classify, detect, segment
  9. from ultralytics.utils import ASSETS, DEFAULT_CFG, WEIGHTS_DIR
  10. def test_func(*args): # noqa
  11. """Test function callback for evaluating YOLO model performance metrics."""
  12. print("callback test passed")
  13. def test_export():
  14. """Tests the model exporting function by adding a callback and asserting its execution."""
  15. exporter = Exporter()
  16. exporter.add_callback("on_export_start", test_func)
  17. assert test_func in exporter.callbacks["on_export_start"], "callback test failed"
  18. f = exporter(model=YOLO("yolo11n.yaml").model)
  19. YOLO(f)(ASSETS) # exported model inference
  20. def test_detect():
  21. """Test YOLO object detection training, validation, and prediction functionality."""
  22. overrides = {"data": "coco8.yaml", "model": "yolo11n.yaml", "imgsz": 32, "epochs": 1, "save": False}
  23. cfg = get_cfg(DEFAULT_CFG)
  24. cfg.data = "coco8.yaml"
  25. cfg.imgsz = 32
  26. # Trainer
  27. trainer = detect.DetectionTrainer(overrides=overrides)
  28. trainer.add_callback("on_train_start", test_func)
  29. assert test_func in trainer.callbacks["on_train_start"], "callback test failed"
  30. trainer.train()
  31. # Validator
  32. val = detect.DetectionValidator(args=cfg)
  33. val.add_callback("on_val_start", test_func)
  34. assert test_func in val.callbacks["on_val_start"], "callback test failed"
  35. val(model=trainer.best) # validate best.pt
  36. # Predictor
  37. pred = detect.DetectionPredictor(overrides={"imgsz": [64, 64]})
  38. pred.add_callback("on_predict_start", test_func)
  39. assert test_func in pred.callbacks["on_predict_start"], "callback test failed"
  40. # Confirm there is no issue with sys.argv being empty.
  41. with mock.patch.object(sys, "argv", []):
  42. result = pred(source=ASSETS, model=MODEL)
  43. assert len(result), "predictor test failed"
  44. overrides["resume"] = trainer.last
  45. trainer = detect.DetectionTrainer(overrides=overrides)
  46. try:
  47. trainer.train()
  48. except Exception as e:
  49. print(f"Expected exception caught: {e}")
  50. return
  51. Exception("Resume test failed!")
  52. def test_segment():
  53. """Tests image segmentation training, validation, and prediction pipelines using YOLO models."""
  54. overrides = {"data": "coco8-seg.yaml", "model": "yolo11n-seg.yaml", "imgsz": 32, "epochs": 1, "save": False}
  55. cfg = get_cfg(DEFAULT_CFG)
  56. cfg.data = "coco8-seg.yaml"
  57. cfg.imgsz = 32
  58. # YOLO(CFG_SEG).train(**overrides) # works
  59. # Trainer
  60. trainer = segment.SegmentationTrainer(overrides=overrides)
  61. trainer.add_callback("on_train_start", test_func)
  62. assert test_func in trainer.callbacks["on_train_start"], "callback test failed"
  63. trainer.train()
  64. # Validator
  65. val = segment.SegmentationValidator(args=cfg)
  66. val.add_callback("on_val_start", test_func)
  67. assert test_func in val.callbacks["on_val_start"], "callback test failed"
  68. val(model=trainer.best) # validate best.pt
  69. # Predictor
  70. pred = segment.SegmentationPredictor(overrides={"imgsz": [64, 64]})
  71. pred.add_callback("on_predict_start", test_func)
  72. assert test_func in pred.callbacks["on_predict_start"], "callback test failed"
  73. result = pred(source=ASSETS, model=WEIGHTS_DIR / "yolo11n-seg.pt")
  74. assert len(result), "predictor test failed"
  75. # Test resume
  76. overrides["resume"] = trainer.last
  77. trainer = segment.SegmentationTrainer(overrides=overrides)
  78. try:
  79. trainer.train()
  80. except Exception as e:
  81. print(f"Expected exception caught: {e}")
  82. return
  83. Exception("Resume test failed!")
  84. def test_classify():
  85. """Test image classification including training, validation, and prediction phases."""
  86. overrides = {"data": "imagenet10", "model": "yolo11n-cls.yaml", "imgsz": 32, "epochs": 1, "save": False}
  87. cfg = get_cfg(DEFAULT_CFG)
  88. cfg.data = "imagenet10"
  89. cfg.imgsz = 32
  90. # YOLO(CFG_SEG).train(**overrides) # works
  91. # Trainer
  92. trainer = classify.ClassificationTrainer(overrides=overrides)
  93. trainer.add_callback("on_train_start", test_func)
  94. assert test_func in trainer.callbacks["on_train_start"], "callback test failed"
  95. trainer.train()
  96. # Validator
  97. val = classify.ClassificationValidator(args=cfg)
  98. val.add_callback("on_val_start", test_func)
  99. assert test_func in val.callbacks["on_val_start"], "callback test failed"
  100. val(model=trainer.best)
  101. # Predictor
  102. pred = classify.ClassificationPredictor(overrides={"imgsz": [64, 64]})
  103. pred.add_callback("on_predict_start", test_func)
  104. assert test_func in pred.callbacks["on_predict_start"], "callback test failed"
  105. result = pred(source=ASSETS, model=trainer.best)
  106. assert len(result), "predictor test failed"