# train_demo.py (728 B)
  1. import torch
  2. from models.line_detect.line_detect import linedetect_newresnet18fpn
  3. from models.line_net.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn, linenet_newresnet50fpn, \
  4. get_line_net_convnext_fpn, linenet_newresnet18fpn
  5. from models.line_net.trainer import Trainer
  6. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  7. if __name__ == '__main__':
  8. # model = LineNet('line_net.yaml')
  9. # model=linenet_resnet50_fpn()
  10. # model = linenet_resnet18_fpn()
  11. # model=get_line_net_convnext_fpn(num_classes=2).to(device)
  12. # model=linenet_newresnet50fpn()
  13. # model = lineDetect_resnet18_fpn()
  14. model=linedetect_newresnet18fpn()
  15. model.start_train(cfg='train.yaml')