base_trainer.py 434 B

123456789101112131415161718192021
  1. from abc import ABC, abstractmethod
  2. class BaseTrainer(ABC):
  3. def __init__(self,
  4. model=None,
  5. dataset=None,
  6. device='cuda',
  7. **kwargs):
  8. self.model = model
  9. self.dataset = dataset
  10. self.device=device
  11. @abstractmethod
  12. def train_cfg(self,model,cfg):
  13. return
  14. @abstractmethod
  15. def train(self,model, **kwargs):
  16. return