inference.cc 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. #include "inference.h"
  2. #include <memory>
  3. #include <opencv2/dnn.hpp>
  4. #include <random>
  5. namespace yolo {
  6. // Constructor to initialize the model with default input shape
  7. Inference::Inference(const std::string &model_path, const float &model_confidence_threshold, const float &model_NMS_threshold) {
  8. model_input_shape_ = cv::Size(640, 640); // Set the default size for models with dynamic shapes to prevent errors.
  9. model_confidence_threshold_ = model_confidence_threshold;
  10. model_NMS_threshold_ = model_NMS_threshold;
  11. InitializeModel(model_path);
  12. }
  13. // Constructor to initialize the model with specified input shape
  14. Inference::Inference(const std::string &model_path, const cv::Size model_input_shape, const float &model_confidence_threshold, const float &model_NMS_threshold) {
  15. model_input_shape_ = model_input_shape;
  16. model_confidence_threshold_ = model_confidence_threshold;
  17. model_NMS_threshold_ = model_NMS_threshold;
  18. InitializeModel(model_path);
  19. }
  20. void Inference::InitializeModel(const std::string &model_path) {
  21. ov::Core core; // OpenVINO core object
  22. std::shared_ptr<ov::Model> model = core.read_model(model_path); // Read the model from file
  23. // If the model has dynamic shapes, reshape it to the specified input shape
  24. if (model->is_dynamic()) {
  25. model->reshape({1, 3, static_cast<long int>(model_input_shape_.height), static_cast<long int>(model_input_shape_.width)});
  26. }
  27. // Preprocessing setup for the model
  28. ov::preprocess::PrePostProcessor ppp = ov::preprocess::PrePostProcessor(model);
  29. ppp.input().tensor().set_element_type(ov::element::u8).set_layout("NHWC").set_color_format(ov::preprocess::ColorFormat::BGR);
  30. ppp.input().preprocess().convert_element_type(ov::element::f32).convert_color(ov::preprocess::ColorFormat::RGB).scale({255, 255, 255});
  31. ppp.input().model().set_layout("NCHW");
  32. ppp.output().tensor().set_element_type(ov::element::f32);
  33. model = ppp.build(); // Build the preprocessed model
  34. // Compile the model for inference
  35. compiled_model_ = core.compile_model(model, "AUTO");
  36. inference_request_ = compiled_model_.create_infer_request(); // Create inference request
  37. short width, height;
  38. // Get input shape from the model
  39. const std::vector<ov::Output<ov::Node>> inputs = model->inputs();
  40. const ov::Shape input_shape = inputs[0].get_shape();
  41. height = input_shape[1];
  42. width = input_shape[2];
  43. model_input_shape_ = cv::Size2f(width, height);
  44. // Get output shape from the model
  45. const std::vector<ov::Output<ov::Node>> outputs = model->outputs();
  46. const ov::Shape output_shape = outputs[0].get_shape();
  47. height = output_shape[1];
  48. width = output_shape[2];
  49. model_output_shape_ = cv::Size(width, height);
  50. }
  51. // Method to run inference on an input frame
  52. void Inference::RunInference(cv::Mat &frame) {
  53. Preprocessing(frame); // Preprocess the input frame
  54. inference_request_.infer(); // Run inference
  55. PostProcessing(frame); // Postprocess the inference results
  56. }
  57. // Method to preprocess the input frame
  58. void Inference::Preprocessing(const cv::Mat &frame) {
  59. cv::Mat resized_frame;
  60. cv::resize(frame, resized_frame, model_input_shape_, 0, 0, cv::INTER_AREA); // Resize the frame to match the model input shape
  61. // Calculate scaling factor
  62. scale_factor_.x = static_cast<float>(frame.cols / model_input_shape_.width);
  63. scale_factor_.y = static_cast<float>(frame.rows / model_input_shape_.height);
  64. float *input_data = (float *)resized_frame.data; // Get pointer to resized frame data
  65. const ov::Tensor input_tensor = ov::Tensor(compiled_model_.input().get_element_type(), compiled_model_.input().get_shape(), input_data); // Create input tensor
  66. inference_request_.set_input_tensor(input_tensor); // Set input tensor for inference
  67. }
  68. // Method to postprocess the inference results
  69. void Inference::PostProcessing(cv::Mat &frame) {
  70. std::vector<int> class_list;
  71. std::vector<float> confidence_list;
  72. std::vector<cv::Rect> box_list;
  73. // Get the output tensor from the inference request
  74. const float *detections = inference_request_.get_output_tensor().data<const float>();
  75. const cv::Mat detection_outputs(model_output_shape_, CV_32F, (float *)detections); // Create OpenCV matrix from output tensor
  76. // Iterate over detections and collect class IDs, confidence scores, and bounding boxes
  77. for (int i = 0; i < detection_outputs.cols; ++i) {
  78. const cv::Mat classes_scores = detection_outputs.col(i).rowRange(4, detection_outputs.rows);
  79. cv::Point class_id;
  80. double score;
  81. cv::minMaxLoc(classes_scores, nullptr, &score, nullptr, &class_id); // Find the class with the highest score
  82. // Check if the detection meets the confidence threshold
  83. if (score > model_confidence_threshold_) {
  84. class_list.push_back(class_id.y);
  85. confidence_list.push_back(score);
  86. const float x = detection_outputs.at<float>(0, i);
  87. const float y = detection_outputs.at<float>(1, i);
  88. const float w = detection_outputs.at<float>(2, i);
  89. const float h = detection_outputs.at<float>(3, i);
  90. cv::Rect box;
  91. box.x = static_cast<int>(x);
  92. box.y = static_cast<int>(y);
  93. box.width = static_cast<int>(w);
  94. box.height = static_cast<int>(h);
  95. box_list.push_back(box);
  96. }
  97. }
  98. // Apply Non-Maximum Suppression (NMS) to filter overlapping bounding boxes
  99. std::vector<int> NMS_result;
  100. cv::dnn::NMSBoxes(box_list, confidence_list, model_confidence_threshold_, model_NMS_threshold_, NMS_result);
  101. // Collect final detections after NMS
  102. for (int i = 0; i < NMS_result.size(); ++i) {
  103. Detection result;
  104. const unsigned short id = NMS_result[i];
  105. result.class_id = class_list[id];
  106. result.confidence = confidence_list[id];
  107. result.box = GetBoundingBox(box_list[id]);
  108. DrawDetectedObject(frame, result);
  109. }
  110. }
  111. // Method to get the bounding box in the correct scale
  112. cv::Rect Inference::GetBoundingBox(const cv::Rect &src) const {
  113. cv::Rect box = src;
  114. box.x = (box.x - box.width / 2) * scale_factor_.x;
  115. box.y = (box.y - box.height / 2) * scale_factor_.y;
  116. box.width *= scale_factor_.x;
  117. box.height *= scale_factor_.y;
  118. return box;
  119. }
  120. void Inference::DrawDetectedObject(cv::Mat &frame, const Detection &detection) const {
  121. const cv::Rect &box = detection.box;
  122. const float &confidence = detection.confidence;
  123. const int &class_id = detection.class_id;
  124. // Generate a random color for the bounding box
  125. std::random_device rd;
  126. std::mt19937 gen(rd());
  127. std::uniform_int_distribution<int> dis(120, 255);
  128. const cv::Scalar &color = cv::Scalar(dis(gen), dis(gen), dis(gen));
  129. // Draw the bounding box around the detected object
  130. cv::rectangle(frame, cv::Point(box.x, box.y), cv::Point(box.x + box.width, box.y + box.height), color, 3);
  131. // Prepare the class label and confidence text
  132. std::string classString = classes_[class_id] + std::to_string(confidence).substr(0, 4);
  133. // Get the size of the text box
  134. cv::Size textSize = cv::getTextSize(classString, cv::FONT_HERSHEY_DUPLEX, 0.75, 2, 0);
  135. cv::Rect textBox(box.x, box.y - 40, textSize.width + 10, textSize.height + 20);
  136. // Draw the text box
  137. cv::rectangle(frame, textBox, color, cv::FILLED);
  138. // Put the class label and confidence text above the bounding box
  139. cv::putText(frame, classString, cv::Point(box.x + 5, box.y - 10), cv::FONT_HERSHEY_DUPLEX, 0.75, cv::Scalar(0, 0, 0), 2, 0);
  140. }
  141. } // namespace yolo