// inference.h — YOLOv8 inference wrapper built on ONNX Runtime and OpenCV.
#pragma once

// Status value meaning "success" for the `char*`-returning methods of YOLO_V8.
// A non-null return presumably points at an error message — confirm in the implementation.
#define RET_OK nullptr

#ifdef _WIN32
#include <Windows.h>
#include <direct.h>
#include <io.h>
#endif

#include <cstdint>  // int64_t (TensorProcess)
#include <cstdio>
#include <ctime>    // clock_t (TensorProcess)
#include <string>
#include <vector>

#include <opencv2/opencv.hpp>

#include "onnxruntime_cxx_api.h"

#ifdef USE_CUDA
#include <cuda_fp16.h>
#endif
  16. enum MODEL_TYPE
  17. {
  18. //FLOAT32 MODEL
  19. YOLO_DETECT_V8 = 1,
  20. YOLO_POSE = 2,
  21. YOLO_CLS = 3,
  22. //FLOAT16 MODEL
  23. YOLO_DETECT_V8_HALF = 4,
  24. YOLO_POSE_V8_HALF = 5,
  25. YOLO_CLS_HALF = 6
  26. };
  27. typedef struct _DL_INIT_PARAM
  28. {
  29. std::string modelPath;
  30. MODEL_TYPE modelType = YOLO_DETECT_V8;
  31. std::vector<int> imgSize = { 640, 640 };
  32. float rectConfidenceThreshold = 0.6;
  33. float iouThreshold = 0.5;
  34. int keyPointsNum = 2;//Note:kpt number for pose
  35. bool cudaEnable = false;
  36. int logSeverityLevel = 3;
  37. int intraOpNumThreads = 1;
  38. } DL_INIT_PARAM;
  39. typedef struct _DL_RESULT
  40. {
  41. int classId;
  42. float confidence;
  43. cv::Rect box;
  44. std::vector<cv::Point2f> keyPoints;
  45. } DL_RESULT;
  46. class YOLO_V8
  47. {
  48. public:
  49. YOLO_V8();
  50. ~YOLO_V8();
  51. public:
  52. char* CreateSession(DL_INIT_PARAM& iParams);
  53. char* RunSession(cv::Mat& iImg, std::vector<DL_RESULT>& oResult);
  54. char* WarmUpSession();
  55. template<typename N>
  56. char* TensorProcess(clock_t& starttime_1, cv::Mat& iImg, N& blob, std::vector<int64_t>& inputNodeDims,
  57. std::vector<DL_RESULT>& oResult);
  58. char* PreProcess(cv::Mat& iImg, std::vector<int> iImgSize, cv::Mat& oImg);
  59. std::vector<std::string> classes{};
  60. private:
  61. Ort::Env env;
  62. Ort::Session* session;
  63. bool cudaEnable;
  64. Ort::RunOptions options;
  65. std::vector<const char*> inputNodeNames;
  66. std::vector<const char*> outputNodeNames;
  67. MODEL_TYPE modelType;
  68. std::vector<int> imgSize;
  69. float rectConfidenceThreshold;
  70. float iouThreshold;
  71. float resizeScales;//letterbox scale
  72. };