# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import argparse

import cv2.dnn
import numpy as np

from ultralytics.utils import ASSETS, yaml_load
from ultralytics.utils.checks import check_yaml

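# Load the class names from the coco8.yaml dataset config (the standard 80 COCO
# classes) and pick a random color per class for drawing boxes.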
CLASSES = yaml_load(check_yaml("coco8.yaml"))["names"]
colors = np.random.uniform(0, 255, size=(len(CLASSES), 3))


def draw_bounding_box(img, class_id, confidence, x, y, x_plus_w, y_plus_h):
    """
    Draw a single bounding box with a class label on the input image.

    Args:
        img (numpy.ndarray): The input image to draw the bounding box on.
        class_id (int): Class ID of the detected object.
        confidence (float): Confidence score of the detected object.
        x (int): X-coordinate of the top-left corner of the bounding box.
        y (int): Y-coordinate of the top-left corner of the bounding box.
        x_plus_w (int): X-coordinate of the bottom-right corner of the bounding box.
        y_plus_h (int): Y-coordinate of the bottom-right corner of the bounding box.
    """
    label = f"{CLASSES[class_id]} ({confidence:.2f})"
    color = colors[class_id]
    cv2.rectangle(img, (x, y), (x_plus_w, y_plus_h), color, 2)
    cv2.putText(img, label, (x - 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)


def main(onnx_model, input_image):
    """
    Load the ONNX model, perform inference, draw bounding boxes, and display the output image.

    Args:
        onnx_model (str): Path to the ONNX model.
        input_image (str): Path to the input image.

    Returns:
        list: List of dictionaries containing detection information such as class_id, class_name, confidence, etc.
    """
    # Load the ONNX model
    model: cv2.dnn.Net = cv2.dnn.readNetFromONNX(onnx_model)

    # Read the input image
    original_image: np.ndarray = cv2.imread(input_image)
    [height, width, _] = original_image.shape

    # Prepare a square image for inference
    length = max((height, width))
    image = np.zeros((length, length, 3), np.uint8)
    image[0:height, 0:width] = original_image

    # Calculate scale factor
    scale = length / 640
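    # The original image sits in the top-left corner of a length x length black
    # canvas, which is then resized to the 640x640 network input. Multiplying a
    # coordinate in 640-space by `scale` maps it back to original pixels, e.g. a
    # 1080-pixel padded side gives scale = 1080 / 640 = 1.6875.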

    # Preprocess the image and prepare blob for model
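    # blobFromImage scales pixel values by 1/255 into [0, 1], resizes to the
    # 640x640 network input, and swapRB=True converts OpenCV's BGR order to RGB.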
    blob = cv2.dnn.blobFromImage(image, scalefactor=1 / 255, size=(640, 640), swapRB=True)
    model.setInput(blob)

    # Perform inference
    outputs = model.forward()
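
    # For an 80-class YOLOv8 model the raw output is shaped (1, 84, 8400): 4 box
    # values (cx, cy, w, h) plus 80 class scores for each of 8400 candidate
    # predictions, so it is transposed below into one row per candidate.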
    # Prepare output array
    outputs = np.array([cv2.transpose(outputs[0])])
    rows = outputs.shape[1]

    boxes = []
    scores = []
    class_ids = []

    # Iterate through output to collect bounding boxes, confidence scores, and class IDs
    for i in range(rows):
        classes_scores = outputs[0][i][4:]
        (minScore, maxScore, minClassLoc, (x, maxClassIndex)) = cv2.minMaxLoc(classes_scores)
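        # minMaxLoc treats the 1-D score vector as a single-column matrix, so the
        # max location unpacks as (x, y) and its second element is the class index.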
        if maxScore >= 0.25:
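            # Convert the center-format box (cx, cy, w, h) to top-left (x, y, w, h),
            # the format expected by NMSBoxes and the drawing code below.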
            box = [
                outputs[0][i][0] - (0.5 * outputs[0][i][2]),
                outputs[0][i][1] - (0.5 * outputs[0][i][3]),
                outputs[0][i][2],
                outputs[0][i][3],
            ]
            boxes.append(box)
            scores.append(maxScore)
            class_ids.append(maxClassIndex)

    # Apply NMS (Non-maximum suppression)
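    # NMSBoxes(bboxes, scores, score_threshold, nms_threshold, eta): candidates below
    # a 0.25 score are dropped, overlapping boxes above a 0.45 IoU are suppressed,
    # and the indices of the surviving boxes are returned.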
    result_boxes = cv2.dnn.NMSBoxes(boxes, scores, 0.25, 0.45, 0.5)

    detections = []

    # Iterate through NMS results to draw bounding boxes and labels
    for i in range(len(result_boxes)):
        index = result_boxes[i]
        box = boxes[index]
        detection = {
            "class_id": class_ids[index],
            "class_name": CLASSES[class_ids[index]],
            "confidence": scores[index],
            "box": box,
            "scale": scale,
        }
        detections.append(detection)
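        # Boxes were computed in the 640x640 network-input space, so scale the
        # corners back to the original image resolution before drawing.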
        draw_bounding_box(
            original_image,
            class_ids[index],
            scores[index],
            round(box[0] * scale),
            round(box[1] * scale),
            round((box[0] + box[2]) * scale),
            round((box[1] + box[3]) * scale),
        )

    # Display the image with bounding boxes
    cv2.imshow("image", original_image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

    return detections

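
# Example usage (the ONNX file can be produced with the Ultralytics CLI, e.g.
# `yolo export model=yolov8n.pt format=onnx`):
#   python main.py --model yolov8n.onnx --img path/to/image.jpg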

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="yolov8n.onnx", help="Path to the ONNX model.")
    parser.add_argument("--img", default=str(ASSETS / "bus.jpg"), help="Path to input image.")
    args = parser.parse_args()
    main(args.model, args.img)