# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import argparse

import cv2
import numpy as np
import onnxruntime as ort
import torch

from ultralytics.utils import ASSETS, yaml_load
from ultralytics.utils.checks import check_requirements, check_yaml
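
# Note: torch is imported only to query CUDA availability for the dependency check in the
# __main__ block below; inference itself runs entirely through onnxruntime.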


class RTDETR:
    """RTDETR object detection model class for handling inference and visualization."""

    def __init__(self, model_path, img_path, conf_thres=0.5, iou_thres=0.5):
        """
        Initializes the RTDETR object with the specified parameters.

        Args:
            model_path: Path to the ONNX model file.
            img_path: Path to the input image.
            conf_thres: Confidence threshold for object detection.
            iou_thres: IoU threshold for non-maximum suppression.
        """
        self.model_path = model_path
        self.img_path = img_path
        self.conf_thres = conf_thres
        # Stored for API completeness; RT-DETR is end-to-end, so no NMS is applied below
        self.iou_thres = iou_thres

        # Set up the ONNX Runtime session with CUDA and CPU execution providers
        self.session = ort.InferenceSession(model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
        self.model_input = self.session.get_inputs()
        # The model input is NCHW, so index 2 is the height and index 3 is the width
        self.input_height = self.model_input[0].shape[2]
        self.input_width = self.model_input[0].shape[3]

        # Load class names from the COCO dataset YAML file
        self.classes = yaml_load(check_yaml("coco8.yaml"))["names"]

        # Generate a color palette for drawing bounding boxes
        self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))
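
    # Note (assumption, not part of the original): ONNX exports with dynamic spatial axes
    # report shape[2] / shape[3] as symbolic names or None instead of integers, so a
    # defensive variant of the shape lookup above might fall back to a fixed size:
    #     h, w = self.model_input[0].shape[2:4]
    #     self.input_height = h if isinstance(h, int) else 640
    #     self.input_width = w if isinstance(w, int) else 640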

    def draw_detections(self, box, score, class_id):
        """
        Draws bounding boxes and labels on the input image based on the detected objects.

        Args:
            box: Detected bounding box.
            score: Corresponding detection score.
            class_id: Class ID for the detected object.

        Returns:
            None
        """
        # Extract the coordinates of the bounding box
        x1, y1, x2, y2 = box

        # Retrieve the color for the class ID
        color = self.color_palette[class_id]

        # Draw the bounding box on the image
        cv2.rectangle(self.img, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)

        # Create the label text with class name and score
        label = f"{self.classes[class_id]}: {score:.2f}"

        # Calculate the dimensions of the label text
        (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)

        # Place the label above the box unless it would run off the top of the image
        label_x = x1
        label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10

        # Draw a filled rectangle as the background for the label text
        cv2.rectangle(
            self.img,
            (int(label_x), int(label_y - label_height)),
            (int(label_x + label_width), int(label_y + label_height)),
            color,
            cv2.FILLED,
        )

        # Draw the label text on the image
        cv2.putText(
            self.img, label, (int(label_x), int(label_y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA
        )
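
    # Note (assumption, not part of the original): the palette rows above are float64
    # arrays; most OpenCV builds accept them as colors, but a stricter build may need a
    # plain tuple, e.g. color = tuple(map(int, self.color_palette[class_id])).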

    def preprocess(self):
        """
        Preprocesses the input image before performing inference.

        Returns:
            image_data: Preprocessed image data ready for inference.
        """
        # Read the input image using OpenCV
        self.img = cv2.imread(self.img_path)

        # Get the height and width of the input image
        self.img_height, self.img_width = self.img.shape[:2]

        # Convert the image color space from BGR to RGB
        img = cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB)

        # Resize the image to the model input shape; cv2.resize expects (width, height)
        img = cv2.resize(img, (self.input_width, self.input_height))

        # Normalize pixel values to [0, 1]
        image_data = img / 255.0

        # Transpose the image from HWC to CHW (channel-first) layout
        image_data = np.transpose(image_data, (2, 0, 1))

        # Add a batch dimension and cast to float32 to match the expected input shape
        image_data = np.expand_dims(image_data, axis=0).astype(np.float32)

        return image_data
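
    # Shape note: with the default 640x640 RT-DETR-L export, preprocess() returns a
    # (1, 3, 640, 640) float32 array. The plain cv2.resize stretches the image rather
    # than letterboxing it; postprocess() compensates by rescaling the normalized boxes
    # with the original image width and height.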

    def bbox_cxcywh_to_xyxy(self, boxes):
        """
        Converts bounding boxes from (center x, center y, width, height) format to (x_min, y_min, x_max, y_max) format.

        Args:
            boxes (numpy.ndarray): An array of shape (N, 4) where each row represents
                a bounding box in (cx, cy, w, h) format.

        Returns:
            numpy.ndarray: An array of shape (N, 4) where each row represents
                a bounding box in (x_min, y_min, x_max, y_max) format.
        """
        # Calculate half width and half height of the bounding boxes
        half_width = boxes[:, 2] / 2
        half_height = boxes[:, 3] / 2

        # Calculate the corner coordinates of the bounding boxes
        x_min = boxes[:, 0] - half_width
        y_min = boxes[:, 1] - half_height
        x_max = boxes[:, 0] + half_width
        y_max = boxes[:, 1] + half_height

        # Return the bounding boxes in (x_min, y_min, x_max, y_max) format
        return np.column_stack((x_min, y_min, x_max, y_max))
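
    # Worked example with hypothetical normalized values: (cx, cy, w, h) = (0.5, 0.5, 0.2, 0.4)
    # gives half_width = 0.1 and half_height = 0.2, so the converted corners are
    # (x_min, y_min, x_max, y_max) = (0.4, 0.3, 0.6, 0.7).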

    def postprocess(self, model_output):
        """
        Postprocesses the model output to extract detections and draw them on the input image.

        Args:
            model_output: Output of the model inference.

        Returns:
            np.ndarray: Annotated image with detections.
        """
        # Squeeze the batch dimension from the model output
        outputs = np.squeeze(model_output[0])

        # Split the output into bounding boxes and per-class scores
        boxes = outputs[:, :4]
        scores = outputs[:, 4:]

        # Get the best class label and its score for each detection
        labels = np.argmax(scores, axis=1)
        scores = np.max(scores, axis=1)

        # Apply the confidence threshold to filter out low-confidence detections
        mask = scores > self.conf_thres
        boxes, scores, labels = boxes[mask], scores[mask], labels[mask]

        # Convert bounding boxes to (x_min, y_min, x_max, y_max) format
        boxes = self.bbox_cxcywh_to_xyxy(boxes)

        # Scale the normalized bounding boxes to the original image dimensions
        boxes[:, 0::2] *= self.img_width
        boxes[:, 1::2] *= self.img_height

        # Draw detections on the image
        for box, score, label in zip(boxes, scores, labels):
            self.draw_detections(box, score, label)

        return self.img
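
    # Layout note: the Ultralytics RT-DETR export emits a fixed set of object queries,
    # e.g. an output of shape (1, 300, 4 + num_classes): the first four columns hold
    # normalized (cx, cy, w, h) boxes and the rest per-class scores. That is why a
    # confidence filter alone suffices here and no NMS step is needed.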

    def main(self):
        """
        Executes the detection on the input image using the ONNX model.

        Returns:
            np.ndarray: Output image with annotations.
        """
        # Preprocess the image for model input
        image_data = self.preprocess()

        # Run the model inference
        model_output = self.session.run(None, {self.model_input[0].name: image_data})

        # Process and return the model output
        return self.postprocess(model_output)
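
    # Note: passing None as the first argument to session.run() requests every model
    # output; this export has a single output, so postprocess() reads model_output[0].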


if __name__ == "__main__":
    # Set up argument parser for command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="rtdetr-l.onnx", help="Path to the ONNX model file.")
    parser.add_argument("--img", type=str, default=str(ASSETS / "bus.jpg"), help="Path to the input image.")
    parser.add_argument("--conf-thres", type=float, default=0.5, help="Confidence threshold for object detection.")
    parser.add_argument("--iou-thres", type=float, default=0.5, help="IoU threshold for non-maximum suppression.")
    args = parser.parse_args()

    # Check that the appropriate ONNX Runtime package is installed
    check_requirements("onnxruntime-gpu" if torch.cuda.is_available() else "onnxruntime")

    # Create the detector instance with the specified parameters
    detection = RTDETR(args.model, args.img, args.conf_thres, args.iou_thres)

    # Perform detection and get the output image
    output_image = detection.main()

    # Display the annotated output image and wait for a key press before closing
    cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
    cv2.imshow("Output", output_image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
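
# Example usage (a sketch; assumes a model exported with the Ultralytics CLI):
#   yolo export model=rtdetr-l.pt format=onnx
#   python main.py --model rtdetr-l.onnx --img bus.jpg --conf-thres 0.5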