# train_demo.py

  1. import torch
  2. import os
  3. from models.line_detect.line_detect import linedetect_newresnet18fpn, linedetect_resnet50_fpn, linedetect_resnet18_fpn, \
  4. linedetect_newresnet50fpn, linedetect_maxvitfpn, linedetect_high_maxvitfpn, linedetect_swin_transformer_fpn, \
  5. linedetect_newresnet101fpn, linedetect_newresnet152fpn, linedetect_efficientnet
  6. from models.line_net.trainer import Trainer
  7. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  8. if __name__ == '__main__':
  9. os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
  10. # model = LineNet('line_net.yaml')
  11. # model=linedetect_resnet50_fpn()
  12. # model = linedetect_resnet50_fpn()
  13. # model=get_line_net_convnext_fpn(num_classes=2).to(device)
  14. # model=linenet_newresnet50fpn()
  15. # model = lineDetect_resnet18_fpn()
  16. # model=linedetect_resnet18_fpn()
  17. # model=linedetect_newresnet18fpn(num_points=4)
  18. # model=linedetect_newresnet50fpn(num_points=4)
  19. # model = linedetect_newresnet101fpn(num_points=4)
  20. # model = linedetect_newresnet152fpn(num_points=4)
  21. # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250711_114046/weights/best_val.pth')
  22. # model=linedetect_maxvitfpn()
  23. model=linedetect_efficientnet(name='efficientnet_v2_l')
  24. # model=linedetect_high_maxvitfpn()
  25. # model=linedetect_swin_transformer_fpn(type='t')
  26. model.start_train(cfg='train.yaml')