# base_trainer.py
  1. from abc import ABC, abstractmethod
  2. class BaseTrainer(ABC):
  3. def __init__(self,
  4. model=None,
  5. dataset=None,
  6. device='cuda',
  7. **kwargs):
  8. #
  9. self.model = model
  10. self.dataset = dataset
  11. self.device=device
  12. # @abstractmethod
  13. # def train_cfg(self,model,cfg):
  14. # return
  15. # @abstractmethod
  16. # def train(self,model, **kwargs):
  17. # return