"""Entry-point script: build a LineNet model from a config file and train it.

Both configuration paths can be overridden on the command line; their
defaults are the paths this script has always used. Relative paths are
resolved against the current working directory, not this file's location.

Usage::

    python <this_script>.py [--model-cfg line_net.yaml] [--train-cfg ./train.yaml]
"""
import argparse

import torch

# NOTE(review): ``linenet_resnet50_fpn`` is not used in this script; the import
# is kept in case other code relies on it being re-exported from here.
from models.line_detect.line_net import linenet_resnet50_fpn, LineNet  # noqa: F401

DEFAULT_MODEL_CFG = 'line_net.yaml'
DEFAULT_TRAIN_CFG = './train.yaml'

# NOTE(review): ``device`` is computed but never passed to the model here;
# presumably LineNet.train selects its own device. Kept at module level so
# existing importers of this module still find it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _parse_args(argv=None):
    """Parse command-line options; ``argv=None`` means ``sys.argv[1:]``."""
    parser = argparse.ArgumentParser(description='Train a LineNet model.')
    parser.add_argument(
        '--model-cfg',
        default=DEFAULT_MODEL_CFG,
        help='Model definition YAML passed to LineNet (default: %(default)s).',
    )
    parser.add_argument(
        '--train-cfg',
        default=DEFAULT_TRAIN_CFG,
        help='Training configuration YAML passed to LineNet.train '
             '(default: %(default)s).',
    )
    return parser.parse_args(argv)


def main(argv=None):
    """Build a LineNet from the model config and run training.

    Args:
        argv: Optional list of command-line arguments; ``None`` reads
            ``sys.argv[1:]``.
    """
    args = _parse_args(argv)
    model = LineNet(args.model_cfg)
    # NOTE(review): this calls a project-defined ``train(cfg=...)``, not
    # ``torch.nn.Module.train(mode)``; confirm LineNet overrides ``train``.
    model.train(cfg=args.train_cfg)


if __name__ == '__main__':
    main()