# test.py
  1. import torch
  2. from models.line_detect.line_detect import linedetect_newresnet18fpn, linedetect_resnet50_fpn, linedetect_resnet18_fpn
  3. from models.line_net.trainer import Trainer
  4. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  5. if __name__ == '__main__':
  6. # model = LineNet('line_net.yaml')
  7. # model=linenet_resnet50_fpn()
  8. # model = linedetect_resnet50_fpn()
  9. # model=get_line_net_convnext_fpn(num_classes=2).to(device)
  10. # model=linenet_newresnet50fpn()
  11. # model = lineDetect_resnet18_fpn()
  12. # model=linedetect_resnet18_fpn()
  13. model=linedetect_newresnet18fpn(num_points=3)
  14. model.eval()
  15. input=torch.zeros((3,3,512,512))
  16. out=model(input)
  17. # model.start_train(cfg='train.yaml')