train_demo.py 715 B

123456789101112131415161718
  1. import torch
  2. from models.line_detect.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn, linenet_newresnet50fpn, \
  3. get_line_net_convnext_fpn
  4. from models.line_detect.trainer import Trainer
  5. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  6. if __name__ == '__main__':
  7. # model = LineNet('line_net.yaml')
  8. # model=linenet_resnet50_fpn()
  9. # model = linenet_resnet18_fpn()
  10. # model=get_line_net_convnext_fpn(num_classes=2).to(device)
  11. model=linenet_newresnet50fpn()
  12. model.load_best_model('train_results/20250622_140412/weights/best_val.pth')
  13. # trainer = Trainer()
  14. # trainer.train_cfg(model,cfg='./train.yaml')
  15. model.start_train(cfg='train.yaml')