object_counter.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. from ultralytics.solutions.solutions import BaseSolution
  3. from ultralytics.utils.plotting import Annotator, colors
  4. class ObjectCounter(BaseSolution):
  5. """
  6. A class to manage the counting of objects in a real-time video stream based on their tracks.
  7. This class extends the BaseSolution class and provides functionality for counting objects moving in and out of a
  8. specified region in a video stream. It supports both polygonal and linear regions for counting.
  9. Attributes:
  10. in_count (int): Counter for objects moving inward.
  11. out_count (int): Counter for objects moving outward.
  12. counted_ids (List[int]): List of IDs of objects that have been counted.
  13. classwise_counts (Dict[str, Dict[str, int]]): Dictionary for counts, categorized by object class.
  14. region_initialized (bool): Flag indicating whether the counting region has been initialized.
  15. show_in (bool): Flag to control display of inward count.
  16. show_out (bool): Flag to control display of outward count.
  17. Methods:
  18. count_objects: Counts objects within a polygonal or linear region.
  19. store_classwise_counts: Initializes class-wise counts if not already present.
  20. display_counts: Displays object counts on the frame.
  21. count: Processes input data (frames or object tracks) and updates counts.
  22. Examples:
  23. >>> counter = ObjectCounter()
  24. >>> frame = cv2.imread("frame.jpg")
  25. >>> processed_frame = counter.count(frame)
  26. >>> print(f"Inward count: {counter.in_count}, Outward count: {counter.out_count}")
  27. """
  28. def __init__(self, **kwargs):
  29. """Initializes the ObjectCounter class for real-time object counting in video streams."""
  30. super().__init__(**kwargs)
  31. self.in_count = 0 # Counter for objects moving inward
  32. self.out_count = 0 # Counter for objects moving outward
  33. self.counted_ids = [] # List of IDs of objects that have been counted
  34. self.classwise_counts = {} # Dictionary for counts, categorized by object class
  35. self.region_initialized = False # Bool variable for region initialization
  36. self.show_in = self.CFG["show_in"]
  37. self.show_out = self.CFG["show_out"]
  38. def count_objects(self, current_centroid, track_id, prev_position, cls):
  39. """
  40. Counts objects within a polygonal or linear region based on their tracks.
  41. Args:
  42. current_centroid (Tuple[float, float]): Current centroid values in the current frame.
  43. track_id (int): Unique identifier for the tracked object.
  44. prev_position (Tuple[float, float]): Last frame position coordinates (x, y) of the track.
  45. cls (int): Class index for classwise count updates.
  46. Examples:
  47. >>> counter = ObjectCounter()
  48. >>> track_line = {1: [100, 200], 2: [110, 210], 3: [120, 220]}
  49. >>> box = [130, 230, 150, 250]
  50. >>> track_id = 1
  51. >>> prev_position = (120, 220)
  52. >>> cls = 0
  53. >>> counter.count_objects(current_centroid, track_id, prev_position, cls)
  54. """
  55. if prev_position is None or track_id in self.counted_ids:
  56. return
  57. if len(self.region) == 2: # Linear region (defined as a line segment)
  58. line = self.LineString(self.region) # Check if the line intersects the trajectory of the object
  59. if line.intersects(self.LineString([prev_position, current_centroid])):
  60. # Determine orientation of the region (vertical or horizontal)
  61. if abs(self.region[0][0] - self.region[1][0]) < abs(self.region[0][1] - self.region[1][1]):
  62. # Vertical region: Compare x-coordinates to determine direction
  63. if current_centroid[0] > prev_position[0]: # Moving right
  64. self.in_count += 1
  65. self.classwise_counts[self.names[cls]]["IN"] += 1
  66. else: # Moving left
  67. self.out_count += 1
  68. self.classwise_counts[self.names[cls]]["OUT"] += 1
  69. # Horizontal region: Compare y-coordinates to determine direction
  70. elif current_centroid[1] > prev_position[1]: # Moving downward
  71. self.in_count += 1
  72. self.classwise_counts[self.names[cls]]["IN"] += 1
  73. else: # Moving upward
  74. self.out_count += 1
  75. self.classwise_counts[self.names[cls]]["OUT"] += 1
  76. self.counted_ids.append(track_id)
  77. elif len(self.region) > 2: # Polygonal region
  78. polygon = self.Polygon(self.region)
  79. if polygon.contains(self.Point(current_centroid)):
  80. # Determine motion direction for vertical or horizontal polygons
  81. region_width = max(p[0] for p in self.region) - min(p[0] for p in self.region)
  82. region_height = max(p[1] for p in self.region) - min(p[1] for p in self.region)
  83. if (
  84. region_width < region_height
  85. and current_centroid[0] > prev_position[0]
  86. or region_width >= region_height
  87. and current_centroid[1] > prev_position[1]
  88. ): # Moving right
  89. self.in_count += 1
  90. self.classwise_counts[self.names[cls]]["IN"] += 1
  91. else: # Moving left
  92. self.out_count += 1
  93. self.classwise_counts[self.names[cls]]["OUT"] += 1
  94. self.counted_ids.append(track_id)
  95. def store_classwise_counts(self, cls):
  96. """
  97. Initialize class-wise counts for a specific object class if not already present.
  98. Args:
  99. cls (int): Class index for classwise count updates.
  100. This method ensures that the 'classwise_counts' dictionary contains an entry for the specified class,
  101. initializing 'IN' and 'OUT' counts to zero if the class is not already present.
  102. Examples:
  103. >>> counter = ObjectCounter()
  104. >>> counter.store_classwise_counts(0) # Initialize counts for class index 0
  105. >>> print(counter.classwise_counts)
  106. {'person': {'IN': 0, 'OUT': 0}}
  107. """
  108. if self.names[cls] not in self.classwise_counts:
  109. self.classwise_counts[self.names[cls]] = {"IN": 0, "OUT": 0}
  110. def display_counts(self, im0):
  111. """
  112. Displays object counts on the input image or frame.
  113. Args:
  114. im0 (numpy.ndarray): The input image or frame to display counts on.
  115. Examples:
  116. >>> counter = ObjectCounter()
  117. >>> frame = cv2.imread("image.jpg")
  118. >>> counter.display_counts(frame)
  119. """
  120. labels_dict = {
  121. str.capitalize(key): f"{'IN ' + str(value['IN']) if self.show_in else ''} "
  122. f"{'OUT ' + str(value['OUT']) if self.show_out else ''}".strip()
  123. for key, value in self.classwise_counts.items()
  124. if value["IN"] != 0 or value["OUT"] != 0
  125. }
  126. if labels_dict:
  127. self.annotator.display_analytics(im0, labels_dict, (104, 31, 17), (255, 255, 255), 10)
  128. def count(self, im0):
  129. """
  130. Processes input data (frames or object tracks) and updates object counts.
  131. This method initializes the counting region, extracts tracks, draws bounding boxes and regions, updates
  132. object counts, and displays the results on the input image.
  133. Args:
  134. im0 (numpy.ndarray): The input image or frame to be processed.
  135. Returns:
  136. (numpy.ndarray): The processed image with annotations and count information.
  137. Examples:
  138. >>> counter = ObjectCounter()
  139. >>> frame = cv2.imread("path/to/image.jpg")
  140. >>> processed_frame = counter.count(frame)
  141. """
  142. if not self.region_initialized:
  143. self.initialize_region()
  144. self.region_initialized = True
  145. self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator
  146. self.extract_tracks(im0) # Extract tracks
  147. self.annotator.draw_region(
  148. reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2
  149. ) # Draw region
  150. # Iterate over bounding boxes, track ids and classes index
  151. for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
  152. # Draw bounding box and counting region
  153. self.annotator.box_label(box, label=self.names[cls], color=colors(cls, True))
  154. self.store_tracking_history(track_id, box) # Store track history
  155. self.store_classwise_counts(cls) # store classwise counts in dict
  156. # Draw tracks of objects
  157. self.annotator.draw_centroid_and_tracks(
  158. self.track_line, color=colors(int(cls), True), track_thickness=self.line_width
  159. )
  160. current_centroid = ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2)
  161. # store previous position of track for object counting
  162. prev_position = None
  163. if len(self.track_history[track_id]) > 1:
  164. prev_position = self.track_history[track_id][-2]
  165. self.count_objects(current_centroid, track_id, prev_position, cls) # Perform object counting
  166. self.display_counts(im0) # Display the counts on the frame
  167. self.display_output(im0) # display output with base class function
  168. return im0 # return output image for more usage