queue_management.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. from ultralytics.solutions.solutions import BaseSolution
  3. from ultralytics.utils.plotting import Annotator, colors
  4. class QueueManager(BaseSolution):
  5. """
  6. Manages queue counting in real-time video streams based on object tracks.
  7. This class extends BaseSolution to provide functionality for tracking and counting objects within a specified
  8. region in video frames.
  9. Attributes:
  10. counts (int): The current count of objects in the queue.
  11. rect_color (Tuple[int, int, int]): RGB color tuple for drawing the queue region rectangle.
  12. region_length (int): The number of points defining the queue region.
  13. annotator (Annotator): An instance of the Annotator class for drawing on frames.
  14. track_line (List[Tuple[int, int]]): List of track line coordinates.
  15. track_history (Dict[int, List[Tuple[int, int]]]): Dictionary storing tracking history for each object.
  16. Methods:
  17. initialize_region: Initializes the queue region.
  18. process_queue: Processes a single frame for queue management.
  19. extract_tracks: Extracts object tracks from the current frame.
  20. store_tracking_history: Stores the tracking history for an object.
  21. display_output: Displays the processed output.
  22. Examples:
  23. >>> cap = cv2.VideoCapture("Path/to/video/file.mp4")
  24. >>> queue_manager = QueueManager(region=[100, 100, 200, 200, 300, 300])
  25. >>> while cap.isOpened():
  26. >>> success, im0 = cap.read()
  27. >>> if not success:
  28. >>> break
  29. >>> out = queue.process_queue(im0)
  30. """
  31. def __init__(self, **kwargs):
  32. """Initializes the QueueManager with parameters for tracking and counting objects in a video stream."""
  33. super().__init__(**kwargs)
  34. self.initialize_region()
  35. self.counts = 0 # Queue counts Information
  36. self.rect_color = (255, 255, 255) # Rectangle color
  37. self.region_length = len(self.region) # Store region length for further usage
  38. def process_queue(self, im0):
  39. """
  40. Processes the queue management for a single frame of video.
  41. Args:
  42. im0 (numpy.ndarray): Input image for processing, typically a frame from a video stream.
  43. Returns:
  44. (numpy.ndarray): Processed image with annotations, bounding boxes, and queue counts.
  45. This method performs the following steps:
  46. 1. Resets the queue count for the current frame.
  47. 2. Initializes an Annotator object for drawing on the image.
  48. 3. Extracts tracks from the image.
  49. 4. Draws the counting region on the image.
  50. 5. For each detected object:
  51. - Draws bounding boxes and labels.
  52. - Stores tracking history.
  53. - Draws centroids and tracks.
  54. - Checks if the object is inside the counting region and updates the count.
  55. 6. Displays the queue count on the image.
  56. 7. Displays the processed output.
  57. Examples:
  58. >>> queue_manager = QueueManager()
  59. >>> frame = cv2.imread("frame.jpg")
  60. >>> processed_frame = queue_manager.process_queue(frame)
  61. """
  62. self.counts = 0 # Reset counts every frame
  63. self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator
  64. self.extract_tracks(im0) # Extract tracks
  65. self.annotator.draw_region(
  66. reg_pts=self.region, color=self.rect_color, thickness=self.line_width * 2
  67. ) # Draw region
  68. for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
  69. # Draw bounding box and counting region
  70. self.annotator.box_label(box, label=self.names[cls], color=colors(track_id, True))
  71. self.store_tracking_history(track_id, box) # Store track history
  72. # Draw tracks of objects
  73. self.annotator.draw_centroid_and_tracks(
  74. self.track_line, color=colors(int(track_id), True), track_thickness=self.line_width
  75. )
  76. # Cache frequently accessed attributes
  77. track_history = self.track_history.get(track_id, [])
  78. # store previous position of track and check if the object is inside the counting region
  79. prev_position = None
  80. if len(track_history) > 1:
  81. prev_position = track_history[-2]
  82. if self.region_length >= 3 and prev_position and self.r_s.contains(self.Point(self.track_line[-1])):
  83. self.counts += 1
  84. # Display queue counts
  85. self.annotator.queue_counts_display(
  86. f"Queue Counts : {str(self.counts)}",
  87. points=self.region,
  88. region_color=self.rect_color,
  89. txt_color=(104, 31, 17),
  90. )
  91. self.display_output(im0) # display output with base class function
  92. return im0 # return output image for more usage