| 12345678910111213141516171819 |
- import torch
- from models.line_net.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn, linenet_newresnet50fpn, \
- get_line_net_convnext_fpn, linenet_newresnet18fpn
- from models.line_net.trainer import Trainer
# Use the CUDA device when one is available; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def main():
    """Build the LineNet variant under test and start training from ``train.yaml``."""
    # Backbone currently selected for the experiment. The other builders
    # imported at the top of this file were used as drop-in alternatives:
    #   LineNet('line_net.yaml')
    #   linenet_resnet50_fpn()
    #   linenet_resnet18_fpn()
    #   linenet_newresnet50fpn()
    #   get_line_net_convnext_fpn(num_classes=2).to(device)
    model = linenet_newresnet18fpn()

    # Optional step kept for reference (presumably restores a saved
    # checkpoint before training -- confirm against the model class):
    #   model.load_best_model('train_results/20250622_140412/weights/best_val.pth')

    # Alternative launch path through the standalone Trainer, also disabled:
    #   Trainer().train_cfg(model, cfg='./train.yaml')
    model.start_train(cfg='train.yaml')


if __name__ == '__main__':
    main()
|