"""Launch training for a LineNet line-detection model.

Builds a LineNet with the selected backbone factory and starts training
from a YAML configuration file.
"""
import torch

from models.line_detect.line_net import (
    LineNet,
    linenet_resnet18_fpn,
    linenet_resnet50_fpn,
    linenet_resnet101_fpn_v2,
)
from models.line_detect.trainer import Trainer

# NOTE(review): `device` is not used anywhere in this script; presumably
# start_train() handles device placement itself -- TODO confirm.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Backbone factories selectable by name. This replaces switching models by
# commenting/uncommenting lines. LineNet and Trainer stay imported for the
# alternative setups (LineNet('line_net.yaml'), Trainer().train_cfg(...)).
MODEL_FACTORIES = {
    'resnet18': linenet_resnet18_fpn,
    'resnet50': linenet_resnet50_fpn,
    'resnet101_v2': linenet_resnet101_fpn_v2,
}


def main(cfg='train.yaml', backbone='resnet18'):
    """Build a LineNet model and start training it.

    Args:
        cfg: Path to the training configuration YAML passed to
            ``model.start_train``.
        backbone: Key into ``MODEL_FACTORIES`` selecting which model
            factory to call (no arguments are passed to the factory).

    Returns:
        The model instance after ``start_train`` returns.

    Raises:
        ValueError: If ``backbone`` is not a key of ``MODEL_FACTORIES``.
    """
    try:
        factory = MODEL_FACTORIES[backbone]
    except KeyError:
        known = ', '.join(sorted(MODEL_FACTORIES))
        raise ValueError(
            f'unknown backbone {backbone!r}; expected one of: {known}'
        ) from None

    model = factory()
    model.start_train(cfg=cfg)
    return model


if __name__ == '__main__':
    main()