comet.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops
  3. from ultralytics.utils.metrics import ClassifyMetrics, DetMetrics, OBBMetrics, PoseMetrics, SegmentMetrics
  4. try:
  5. assert not TESTS_RUNNING # do not log pytest
  6. assert SETTINGS["comet"] is True # verify integration is enabled
  7. import comet_ml
  8. assert hasattr(comet_ml, "__version__") # verify package is not directory
  9. import os
  10. from pathlib import Path
  11. # Ensures certain logging functions only run for supported tasks
  12. COMET_SUPPORTED_TASKS = ["detect"]
  13. # Names of plots created by Ultralytics that are logged to Comet
  14. CONFUSION_MATRIX_PLOT_NAMES = "confusion_matrix", "confusion_matrix_normalized"
  15. EVALUATION_PLOT_NAMES = "F1_curve", "P_curve", "R_curve", "PR_curve"
  16. LABEL_PLOT_NAMES = "labels", "labels_correlogram"
  17. SEGMENT_METRICS_PLOT_PREFIX = "Box", "Mask"
  18. POSE_METRICS_PLOT_PREFIX = "Box", "Pose"
  19. _comet_image_prediction_count = 0
  20. except (ImportError, AssertionError):
  21. comet_ml = None
  22. def _get_comet_mode():
  23. """Returns the mode of comet set in the environment variables, defaults to 'online' if not set."""
  24. return os.getenv("COMET_MODE", "online")
  25. def _get_comet_model_name():
  26. """Returns the model name for Comet from the environment variable COMET_MODEL_NAME or defaults to 'Ultralytics'."""
  27. return os.getenv("COMET_MODEL_NAME", "Ultralytics")
  28. def _get_eval_batch_logging_interval():
  29. """Get the evaluation batch logging interval from environment variable or use default value 1."""
  30. return int(os.getenv("COMET_EVAL_BATCH_LOGGING_INTERVAL", 1))
  31. def _get_max_image_predictions_to_log():
  32. """Get the maximum number of image predictions to log from the environment variables."""
  33. return int(os.getenv("COMET_MAX_IMAGE_PREDICTIONS", 100))
  34. def _scale_confidence_score(score):
  35. """Scales the given confidence score by a factor specified in an environment variable."""
  36. scale = float(os.getenv("COMET_MAX_CONFIDENCE_SCORE", 100.0))
  37. return score * scale
  38. def _should_log_confusion_matrix():
  39. """Determines if the confusion matrix should be logged based on the environment variable settings."""
  40. return os.getenv("COMET_EVAL_LOG_CONFUSION_MATRIX", "false").lower() == "true"
  41. def _should_log_image_predictions():
  42. """Determines whether to log image predictions based on a specified environment variable."""
  43. return os.getenv("COMET_EVAL_LOG_IMAGE_PREDICTIONS", "true").lower() == "true"
  44. def _get_experiment_type(mode, project_name):
  45. """Return an experiment based on mode and project name."""
  46. if mode == "offline":
  47. return comet_ml.OfflineExperiment(project_name=project_name)
  48. return comet_ml.Experiment(project_name=project_name)
  49. def _create_experiment(args):
  50. """Ensures that the experiment object is only created in a single process during distributed training."""
  51. if RANK not in {-1, 0}:
  52. return
  53. try:
  54. comet_mode = _get_comet_mode()
  55. _project_name = os.getenv("COMET_PROJECT_NAME", args.project)
  56. experiment = _get_experiment_type(comet_mode, _project_name)
  57. experiment.log_parameters(vars(args))
  58. experiment.log_others(
  59. {
  60. "eval_batch_logging_interval": _get_eval_batch_logging_interval(),
  61. "log_confusion_matrix_on_eval": _should_log_confusion_matrix(),
  62. "log_image_predictions": _should_log_image_predictions(),
  63. "max_image_predictions": _get_max_image_predictions_to_log(),
  64. }
  65. )
  66. experiment.log_other("Created from", "ultralytics")
  67. except Exception as e:
  68. LOGGER.warning(f"WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}")
  69. def _fetch_trainer_metadata(trainer):
  70. """Returns metadata for YOLO training including epoch and asset saving status."""
  71. curr_epoch = trainer.epoch + 1
  72. train_num_steps_per_epoch = len(trainer.train_loader.dataset) // trainer.batch_size
  73. curr_step = curr_epoch * train_num_steps_per_epoch
  74. final_epoch = curr_epoch == trainer.epochs
  75. save = trainer.args.save
  76. save_period = trainer.args.save_period
  77. save_interval = curr_epoch % save_period == 0
  78. save_assets = save and save_period > 0 and save_interval and not final_epoch
  79. return dict(curr_epoch=curr_epoch, curr_step=curr_step, save_assets=save_assets, final_epoch=final_epoch)
  80. def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad):
  81. """
  82. YOLO resizes images during training and the label values are normalized based on this resized shape.
  83. This function rescales the bounding box labels to the original image shape.
  84. """
  85. resized_image_height, resized_image_width = resized_image_shape
  86. # Convert normalized xywh format predictions to xyxy in resized scale format
  87. box = ops.xywhn2xyxy(box, h=resized_image_height, w=resized_image_width)
  88. # Scale box predictions from resized image scale back to original image scale
  89. box = ops.scale_boxes(resized_image_shape, box, original_image_shape, ratio_pad)
  90. # Convert bounding box format from xyxy to xywh for Comet logging
  91. box = ops.xyxy2xywh(box)
  92. # Adjust xy center to correspond top-left corner
  93. box[:2] -= box[2:] / 2
  94. box = box.tolist()
  95. return box
  96. def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None):
  97. """Format ground truth annotations for detection."""
  98. indices = batch["batch_idx"] == img_idx
  99. bboxes = batch["bboxes"][indices]
  100. if len(bboxes) == 0:
  101. LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes labels")
  102. return None
  103. cls_labels = batch["cls"][indices].squeeze(1).tolist()
  104. if class_name_map:
  105. cls_labels = [str(class_name_map[label]) for label in cls_labels]
  106. original_image_shape = batch["ori_shape"][img_idx]
  107. resized_image_shape = batch["resized_shape"][img_idx]
  108. ratio_pad = batch["ratio_pad"][img_idx]
  109. data = []
  110. for box, label in zip(bboxes, cls_labels):
  111. box = _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad)
  112. data.append(
  113. {
  114. "boxes": [box],
  115. "label": f"gt_{label}",
  116. "score": _scale_confidence_score(1.0),
  117. }
  118. )
  119. return {"name": "ground_truth", "data": data}
  120. def _format_prediction_annotations_for_detection(image_path, metadata, class_label_map=None):
  121. """Format YOLO predictions for object detection visualization."""
  122. stem = image_path.stem
  123. image_id = int(stem) if stem.isnumeric() else stem
  124. predictions = metadata.get(image_id)
  125. if not predictions:
  126. LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes predictions")
  127. return None
  128. data = []
  129. for prediction in predictions:
  130. boxes = prediction["bbox"]
  131. score = _scale_confidence_score(prediction["score"])
  132. cls_label = prediction["category_id"]
  133. if class_label_map:
  134. cls_label = str(class_label_map[cls_label])
  135. data.append({"boxes": [boxes], "label": cls_label, "score": score})
  136. return {"name": "prediction", "data": data}
  137. def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, class_label_map):
  138. """Join the ground truth and prediction annotations if they exist."""
  139. ground_truth_annotations = _format_ground_truth_annotations_for_detection(
  140. img_idx, image_path, batch, class_label_map
  141. )
  142. prediction_annotations = _format_prediction_annotations_for_detection(
  143. image_path, prediction_metadata_map, class_label_map
  144. )
  145. annotations = [
  146. annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None
  147. ]
  148. return [annotations] if annotations else None
  149. def _create_prediction_metadata_map(model_predictions):
  150. """Create metadata map for model predictions by groupings them based on image ID."""
  151. pred_metadata_map = {}
  152. for prediction in model_predictions:
  153. pred_metadata_map.setdefault(prediction["image_id"], [])
  154. pred_metadata_map[prediction["image_id"]].append(prediction)
  155. return pred_metadata_map
  156. def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch):
  157. """Log the confusion matrix to Comet experiment."""
  158. conf_mat = trainer.validator.confusion_matrix.matrix
  159. names = list(trainer.data["names"].values()) + ["background"]
  160. experiment.log_confusion_matrix(
  161. matrix=conf_mat, labels=names, max_categories=len(names), epoch=curr_epoch, step=curr_step
  162. )
  163. def _log_images(experiment, image_paths, curr_step, annotations=None):
  164. """Logs images to the experiment with optional annotations."""
  165. if annotations:
  166. for image_path, annotation in zip(image_paths, annotations):
  167. experiment.log_image(image_path, name=image_path.stem, step=curr_step, annotations=annotation)
  168. else:
  169. for image_path in image_paths:
  170. experiment.log_image(image_path, name=image_path.stem, step=curr_step)
  171. def _log_image_predictions(experiment, validator, curr_step):
  172. """Logs predicted boxes for a single image during training."""
  173. global _comet_image_prediction_count
  174. task = validator.args.task
  175. if task not in COMET_SUPPORTED_TASKS:
  176. return
  177. jdict = validator.jdict
  178. if not jdict:
  179. return
  180. predictions_metadata_map = _create_prediction_metadata_map(jdict)
  181. dataloader = validator.dataloader
  182. class_label_map = validator.names
  183. batch_logging_interval = _get_eval_batch_logging_interval()
  184. max_image_predictions = _get_max_image_predictions_to_log()
  185. for batch_idx, batch in enumerate(dataloader):
  186. if (batch_idx + 1) % batch_logging_interval != 0:
  187. continue
  188. image_paths = batch["im_file"]
  189. for img_idx, image_path in enumerate(image_paths):
  190. if _comet_image_prediction_count >= max_image_predictions:
  191. return
  192. image_path = Path(image_path)
  193. annotations = _fetch_annotations(
  194. img_idx,
  195. image_path,
  196. batch,
  197. predictions_metadata_map,
  198. class_label_map,
  199. )
  200. _log_images(
  201. experiment,
  202. [image_path],
  203. curr_step,
  204. annotations=annotations,
  205. )
  206. _comet_image_prediction_count += 1
  207. def _log_plots(experiment, trainer):
  208. """Logs evaluation plots and label plots for the experiment."""
  209. plot_filenames = None
  210. if isinstance(trainer.validator.metrics, SegmentMetrics) and trainer.validator.metrics.task == "segment":
  211. plot_filenames = [
  212. trainer.save_dir / f"{prefix}{plots}.png"
  213. for plots in EVALUATION_PLOT_NAMES
  214. for prefix in SEGMENT_METRICS_PLOT_PREFIX
  215. ]
  216. elif isinstance(trainer.validator.metrics, PoseMetrics):
  217. plot_filenames = [
  218. trainer.save_dir / f"{prefix}{plots}.png"
  219. for plots in EVALUATION_PLOT_NAMES
  220. for prefix in POSE_METRICS_PLOT_PREFIX
  221. ]
  222. elif isinstance(trainer.validator.metrics, (DetMetrics, OBBMetrics)):
  223. plot_filenames = [trainer.save_dir / f"{plots}.png" for plots in EVALUATION_PLOT_NAMES]
  224. if plot_filenames is not None:
  225. _log_images(experiment, plot_filenames, None)
  226. confusion_matrix_filenames = [trainer.save_dir / f"{plots}.png" for plots in CONFUSION_MATRIX_PLOT_NAMES]
  227. _log_images(experiment, confusion_matrix_filenames, None)
  228. if not isinstance(trainer.validator.metrics, ClassifyMetrics):
  229. label_plot_filenames = [trainer.save_dir / f"{labels}.jpg" for labels in LABEL_PLOT_NAMES]
  230. _log_images(experiment, label_plot_filenames, None)
  231. def _log_model(experiment, trainer):
  232. """Log the best-trained model to Comet.ml."""
  233. model_name = _get_comet_model_name()
  234. experiment.log_model(model_name, file_or_folder=str(trainer.best), file_name="best.pt", overwrite=True)
  235. def on_pretrain_routine_start(trainer):
  236. """Creates or resumes a CometML experiment at the start of a YOLO pre-training routine."""
  237. experiment = comet_ml.get_global_experiment()
  238. is_alive = getattr(experiment, "alive", False)
  239. if not experiment or not is_alive:
  240. _create_experiment(trainer.args)
  241. def on_train_epoch_end(trainer):
  242. """Log metrics and save batch images at the end of training epochs."""
  243. experiment = comet_ml.get_global_experiment()
  244. if not experiment:
  245. return
  246. metadata = _fetch_trainer_metadata(trainer)
  247. curr_epoch = metadata["curr_epoch"]
  248. curr_step = metadata["curr_step"]
  249. experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix="train"), step=curr_step, epoch=curr_epoch)
  250. def on_fit_epoch_end(trainer):
  251. """Logs model assets at the end of each epoch."""
  252. experiment = comet_ml.get_global_experiment()
  253. if not experiment:
  254. return
  255. metadata = _fetch_trainer_metadata(trainer)
  256. curr_epoch = metadata["curr_epoch"]
  257. curr_step = metadata["curr_step"]
  258. save_assets = metadata["save_assets"]
  259. experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch)
  260. experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch)
  261. if curr_epoch == 1:
  262. from ultralytics.utils.torch_utils import model_info_for_loggers
  263. experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch)
  264. if not save_assets:
  265. return
  266. _log_model(experiment, trainer)
  267. if _should_log_confusion_matrix():
  268. _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)
  269. if _should_log_image_predictions():
  270. _log_image_predictions(experiment, trainer.validator, curr_step)
  271. def on_train_end(trainer):
  272. """Perform operations at the end of training."""
  273. experiment = comet_ml.get_global_experiment()
  274. if not experiment:
  275. return
  276. metadata = _fetch_trainer_metadata(trainer)
  277. curr_epoch = metadata["curr_epoch"]
  278. curr_step = metadata["curr_step"]
  279. plots = trainer.args.plots
  280. _log_model(experiment, trainer)
  281. if plots:
  282. _log_plots(experiment, trainer)
  283. _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)
  284. _log_image_predictions(experiment, trainer.validator, curr_step)
  285. _log_images(experiment, trainer.save_dir.glob("train_batch*.jpg"), curr_step)
  286. _log_images(experiment, trainer.save_dir.glob("val_batch*.jpg"), curr_step)
  287. experiment.end()
  288. global _comet_image_prediction_count
  289. _comet_image_prediction_count = 0
  290. callbacks = (
  291. {
  292. "on_pretrain_routine_start": on_pretrain_routine_start,
  293. "on_train_epoch_end": on_train_epoch_end,
  294. "on_fit_epoch_end": on_fit_epoch_end,
  295. "on_train_end": on_train_end,
  296. }
  297. if comet_ml
  298. else {}
  299. )