raytune.py 727 B

12345678910111213141516171819202122232425262728
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.utils import SETTINGS

try:
    # Gate the integration on the user's Ultralytics settings; a failed assert is caught below.
    # NOTE(review): `assert` is stripped under `python -O`, which would bypass this settings gate.
    assert SETTINGS["raytune"] is True  # verify integration is enabled
    import ray
    from ray import tune
    from ray.air import session

except (ImportError, AssertionError):
    # Ray (or one of its submodules) is unavailable, or the integration is disabled.
    # `tune` is set to None so the module-level `callbacks` mapping in this file is left empty.
    # Note that `ray` / `session` may then be unbound, which is safe only because no callback is registered.
    tune = None
  10. def on_fit_epoch_end(trainer):
  11. """Sends training metrics to Ray Tune at end of each epoch."""
  12. if ray.train._internal.session.get_session(): # replacement for deprecated ray.tune.is_session_enabled()
  13. metrics = trainer.metrics
  14. session.report({**metrics, **{"epoch": trainer.epoch + 1}})
  15. callbacks = (
  16. {
  17. "on_fit_epoch_end": on_fit_epoch_end,
  18. }
  19. if tune
  20. else {}
  21. )