# test_train.py

  1. import torch
  2. from models.line_detect.line_net import linenet_resnet50_fpn, LineNet
  3. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  4. if __name__ == '__main__':
  5. model = LineNet('line_net.yaml')
  6. model.train(cfg='./train.yaml')