# test_train.py
  1. import torch
  2. from models.line_detect.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn
  3. from models.line_detect.trainer import Trainer
  4. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  5. if __name__ == '__main__':
  6. # model = LineNet('line_net.yaml')
  7. # model=linenet_resnet50_fpn()
  8. model=linenet_resnet18_fpn()
  9. # trainer = Trainer()
  10. # trainer.train_cfg(model,cfg='./train.yaml')
  11. model.train_by_cfg(cfg='train.yaml')