train.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. from copy import copy
  3. import torch
  4. from ultralytics.data import ClassificationDataset, build_dataloader
  5. from ultralytics.engine.trainer import BaseTrainer
  6. from ultralytics.models import yolo
  7. from ultralytics.nn.tasks import ClassificationModel
  8. from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
  9. from ultralytics.utils.plotting import plot_images, plot_results
  10. from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
  11. class ClassificationTrainer(BaseTrainer):
  12. """
  13. A class extending the BaseTrainer class for training based on a classification model.
  14. Notes:
  15. - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
  16. Example:
  17. ```python
  18. from ultralytics.models.yolo.classify import ClassificationTrainer
  19. args = dict(model="yolov8n-cls.pt", data="imagenet10", epochs=3)
  20. trainer = ClassificationTrainer(overrides=args)
  21. trainer.train()
  22. ```
  23. """
  24. def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
  25. """Initialize a ClassificationTrainer object with optional configuration overrides and callbacks."""
  26. if overrides is None:
  27. overrides = {}
  28. overrides["task"] = "classify"
  29. if overrides.get("imgsz") is None:
  30. overrides["imgsz"] = 224
  31. super().__init__(cfg, overrides, _callbacks)
  32. def set_model_attributes(self):
  33. """Set the YOLO model's class names from the loaded dataset."""
  34. self.model.names = self.data["names"]
  35. def get_model(self, cfg=None, weights=None, verbose=True):
  36. """Returns a modified PyTorch model configured for training YOLO."""
  37. model = ClassificationModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
  38. if weights:
  39. model.load(weights)
  40. for m in model.modules():
  41. if not self.args.pretrained and hasattr(m, "reset_parameters"):
  42. m.reset_parameters()
  43. if isinstance(m, torch.nn.Dropout) and self.args.dropout:
  44. m.p = self.args.dropout # set dropout
  45. for p in model.parameters():
  46. p.requires_grad = True # for training
  47. return model
  48. def setup_model(self):
  49. """Load, create or download model for any task."""
  50. import torchvision # scope for faster 'import ultralytics'
  51. if str(self.model) in torchvision.models.__dict__:
  52. self.model = torchvision.models.__dict__[self.model](
  53. weights="IMAGENET1K_V1" if self.args.pretrained else None
  54. )
  55. ckpt = None
  56. else:
  57. ckpt = super().setup_model()
  58. ClassificationModel.reshape_outputs(self.model, self.data["nc"])
  59. return ckpt
  60. def build_dataset(self, img_path, mode="train", batch=None):
  61. """Creates a ClassificationDataset instance given an image path, and mode (train/test etc.)."""
  62. return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
  63. def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
  64. """Returns PyTorch DataLoader with transforms to preprocess images for inference."""
  65. with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
  66. dataset = self.build_dataset(dataset_path, mode)
  67. loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)
  68. # Attach inference transforms
  69. if mode != "train":
  70. if is_parallel(self.model):
  71. self.model.module.transforms = loader.dataset.torch_transforms
  72. else:
  73. self.model.transforms = loader.dataset.torch_transforms
  74. return loader
  75. def preprocess_batch(self, batch):
  76. """Preprocesses a batch of images and classes."""
  77. batch["img"] = batch["img"].to(self.device)
  78. batch["cls"] = batch["cls"].to(self.device)
  79. return batch
  80. def progress_string(self):
  81. """Returns a formatted string showing training progress."""
  82. return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
  83. "Epoch",
  84. "GPU_mem",
  85. *self.loss_names,
  86. "Instances",
  87. "Size",
  88. )
  89. def get_validator(self):
  90. """Returns an instance of ClassificationValidator for validation."""
  91. self.loss_names = ["loss"]
  92. return yolo.classify.ClassificationValidator(
  93. self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks
  94. )
  95. def label_loss_items(self, loss_items=None, prefix="train"):
  96. """
  97. Returns a loss dict with labelled training loss items tensor.
  98. Not needed for classification but necessary for segmentation & detection
  99. """
  100. keys = [f"{prefix}/{x}" for x in self.loss_names]
  101. if loss_items is None:
  102. return keys
  103. loss_items = [round(float(loss_items), 5)]
  104. return dict(zip(keys, loss_items))
  105. def plot_metrics(self):
  106. """Plots metrics from a CSV file."""
  107. plot_results(file=self.csv, classify=True, on_plot=self.on_plot) # save results.png
  108. def final_eval(self):
  109. """Evaluate trained model and save validation results."""
  110. for f in self.last, self.best:
  111. if f.exists():
  112. strip_optimizer(f) # strip optimizers
  113. if f is self.best:
  114. LOGGER.info(f"\nValidating {f}...")
  115. self.validator.args.data = self.args.data
  116. self.validator.args.plots = self.args.plots
  117. self.metrics = self.validator(model=f)
  118. self.metrics.pop("fitness", None)
  119. self.run_callbacks("on_fit_epoch_end")
  120. def plot_training_samples(self, batch, ni):
  121. """Plots training samples with their annotations."""
  122. plot_images(
  123. images=batch["img"],
  124. batch_idx=torch.arange(len(batch["img"])),
  125. cls=batch["cls"].view(-1), # warning: use .view(), not .squeeze() for Classify models
  126. fname=self.save_dir / f"train_batch{ni}.jpg",
  127. on_plot=self.on_plot,
  128. )