main.cpp 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  #include <cmath>
  #include <filesystem>
  #include <fstream>
  #include <iomanip>
  #include <iostream>
  #include <memory>
  #include <random>
  #include <sstream>
  #include <string>
  #include <system_error>
  #include <utility>
  #include <vector>

  #include "inference.h"
  7. void Detector(YOLO_V8*& p) {
  8. std::filesystem::path current_path = std::filesystem::current_path();
  9. std::filesystem::path imgs_path = current_path / "images";
  10. for (auto& i : std::filesystem::directory_iterator(imgs_path))
  11. {
  12. if (i.path().extension() == ".jpg" || i.path().extension() == ".png" || i.path().extension() == ".jpeg")
  13. {
  14. std::string img_path = i.path().string();
  15. cv::Mat img = cv::imread(img_path);
  16. std::vector<DL_RESULT> res;
  17. p->RunSession(img, res);
  18. for (auto& re : res)
  19. {
  20. cv::RNG rng(cv::getTickCount());
  21. cv::Scalar color(rng.uniform(0, 256), rng.uniform(0, 256), rng.uniform(0, 256));
  22. cv::rectangle(img, re.box, color, 3);
  23. float confidence = floor(100 * re.confidence) / 100;
  24. std::cout << std::fixed << std::setprecision(2);
  25. std::string label = p->classes[re.classId] + " " +
  26. std::to_string(confidence).substr(0, std::to_string(confidence).size() - 4);
  27. cv::rectangle(
  28. img,
  29. cv::Point(re.box.x, re.box.y - 25),
  30. cv::Point(re.box.x + label.length() * 15, re.box.y),
  31. color,
  32. cv::FILLED
  33. );
  34. cv::putText(
  35. img,
  36. label,
  37. cv::Point(re.box.x, re.box.y - 5),
  38. cv::FONT_HERSHEY_SIMPLEX,
  39. 0.75,
  40. cv::Scalar(0, 0, 0),
  41. 2
  42. );
  43. }
  44. std::cout << "Press any key to exit" << std::endl;
  45. cv::imshow("Result of Detection", img);
  46. cv::waitKey(0);
  47. cv::destroyAllWindows();
  48. }
  49. }
  50. }
  51. void Classifier(YOLO_V8*& p)
  52. {
  53. std::filesystem::path current_path = std::filesystem::current_path();
  54. std::filesystem::path imgs_path = current_path;// / "images"
  55. std::random_device rd;
  56. std::mt19937 gen(rd());
  57. std::uniform_int_distribution<int> dis(0, 255);
  58. for (auto& i : std::filesystem::directory_iterator(imgs_path))
  59. {
  60. if (i.path().extension() == ".jpg" || i.path().extension() == ".png")
  61. {
  62. std::string img_path = i.path().string();
  63. //std::cout << img_path << std::endl;
  64. cv::Mat img = cv::imread(img_path);
  65. std::vector<DL_RESULT> res;
  66. char* ret = p->RunSession(img, res);
  67. float positionY = 50;
  68. for (int i = 0; i < res.size(); i++)
  69. {
  70. int r = dis(gen);
  71. int g = dis(gen);
  72. int b = dis(gen);
  73. cv::putText(img, std::to_string(i) + ":", cv::Point(10, positionY), cv::FONT_HERSHEY_SIMPLEX, 1, cv::Scalar(b, g, r), 2);
  74. cv::putText(img, std::to_string(res.at(i).confidence), cv::Point(70, positionY), cv::FONT_HERSHEY_SIMPLEX, 1, cv::Scalar(b, g, r), 2);
  75. positionY += 50;
  76. }
  77. cv::imshow("TEST_CLS", img);
  78. cv::waitKey(0);
  79. cv::destroyAllWindows();
  80. //cv::imwrite("E:\\output\\" + std::to_string(k) + ".png", img);
  81. }
  82. }
  83. }
  84. int ReadCocoYaml(YOLO_V8*& p) {
  85. // Open the YAML file
  86. std::ifstream file("coco.yaml");
  87. if (!file.is_open())
  88. {
  89. std::cerr << "Failed to open file" << std::endl;
  90. return 1;
  91. }
  92. // Read the file line by line
  93. std::string line;
  94. std::vector<std::string> lines;
  95. while (std::getline(file, line))
  96. {
  97. lines.push_back(line);
  98. }
  99. // Find the start and end of the names section
  100. std::size_t start = 0;
  101. std::size_t end = 0;
  102. for (std::size_t i = 0; i < lines.size(); i++)
  103. {
  104. if (lines[i].find("names:") != std::string::npos)
  105. {
  106. start = i + 1;
  107. }
  108. else if (start > 0 && lines[i].find(':') == std::string::npos)
  109. {
  110. end = i;
  111. break;
  112. }
  113. }
  114. // Extract the names
  115. std::vector<std::string> names;
  116. for (std::size_t i = start; i < end; i++)
  117. {
  118. std::stringstream ss(lines[i]);
  119. std::string name;
  120. std::getline(ss, name, ':'); // Extract the number before the delimiter
  121. std::getline(ss, name); // Extract the string after the delimiter
  122. names.push_back(name);
  123. }
  124. p->classes = names;
  125. return 0;
  126. }
  127. void DetectTest()
  128. {
  129. YOLO_V8* yoloDetector = new YOLO_V8;
  130. ReadCocoYaml(yoloDetector);
  131. DL_INIT_PARAM params;
  132. params.rectConfidenceThreshold = 0.1;
  133. params.iouThreshold = 0.5;
  134. params.modelPath = "yolov8n.onnx";
  135. params.imgSize = { 640, 640 };
  136. #ifdef USE_CUDA
  137. params.cudaEnable = true;
  138. // GPU FP32 inference
  139. params.modelType = YOLO_DETECT_V8;
  140. // GPU FP16 inference
  141. //Note: change fp16 onnx model
  142. //params.modelType = YOLO_DETECT_V8_HALF;
  143. #else
  144. // CPU inference
  145. params.modelType = YOLO_DETECT_V8;
  146. params.cudaEnable = false;
  147. #endif
  148. yoloDetector->CreateSession(params);
  149. Detector(yoloDetector);
  150. }
  151. void ClsTest()
  152. {
  153. YOLO_V8* yoloDetector = new YOLO_V8;
  154. std::string model_path = "cls.onnx";
  155. ReadCocoYaml(yoloDetector);
  156. DL_INIT_PARAM params{ model_path, YOLO_CLS, {224, 224} };
  157. yoloDetector->CreateSession(params);
  158. Classifier(yoloDetector);
  159. }
  160. int main()
  161. {
  162. //DetectTest();
  163. ClsTest();
  164. }