123456789101112131415161718192021 |
- from abc import ABC, abstractmethod
class BaseTrainer(ABC):
    """Abstract base class for trainer objects.

    Keeps the model, dataset and device passed to the constructor as
    instance attributes and declares two abstract hooks, ``train_cfg`` and
    ``train``, that a concrete subclass has to override before it can be
    instantiated.
    """

    def __init__(self, model=None, dataset=None, device='cuda', **kwargs):
        """Store the constructor arguments on the instance.

        Args:
            model: Stored as ``self.model``. Defaults to ``None``.
            dataset: Stored as ``self.dataset``. Defaults to ``None``.
            device: Stored as ``self.device``. Defaults to ``'cuda'``.
            **kwargs: Accepted so callers may pass extra keywords; this base
                class does not read them.
        """
        # Assigned left to right, i.e. model, then dataset, then device.
        self.model, self.dataset, self.device = model, dataset, device

    @abstractmethod
    def train_cfg(self, model, cfg):
        """Abstract hook taking a model and a ``cfg`` object.

        What the hook does with its arguments is defined by the subclass.
        The base implementation does nothing and returns ``None``.
        """
        return None

    @abstractmethod
    def train(self, model, **kwargs):
        """Abstract hook taking a model plus arbitrary keyword options.

        What the hook does with its arguments is defined by the subclass.
        The base implementation does nothing and returns ``None``.
        """
        return None
|