distance_calculation.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. import math
  3. import cv2
  4. from ultralytics.solutions.solutions import BaseSolution
  5. from ultralytics.utils.plotting import Annotator, colors
  6. class DistanceCalculation(BaseSolution):
  7. """
  8. A class to calculate distance between two objects in a real-time video stream based on their tracks.
  9. This class extends BaseSolution to provide functionality for selecting objects and calculating the distance
  10. between them in a video stream using YOLO object detection and tracking.
  11. Attributes:
  12. left_mouse_count (int): Counter for left mouse button clicks.
  13. selected_boxes (Dict[int, List[float]]): Dictionary to store selected bounding boxes and their track IDs.
  14. annotator (Annotator): An instance of the Annotator class for drawing on the image.
  15. boxes (List[List[float]]): List of bounding boxes for detected objects.
  16. track_ids (List[int]): List of track IDs for detected objects.
  17. clss (List[int]): List of class indices for detected objects.
  18. names (List[str]): List of class names that the model can detect.
  19. centroids (List[List[int]]): List to store centroids of selected bounding boxes.
  20. Methods:
  21. mouse_event_for_distance: Handles mouse events for selecting objects in the video stream.
  22. calculate: Processes video frames and calculates the distance between selected objects.
  23. Examples:
  24. >>> distance_calc = DistanceCalculation()
  25. >>> frame = cv2.imread("frame.jpg")
  26. >>> processed_frame = distance_calc.calculate(frame)
  27. >>> cv2.imshow("Distance Calculation", processed_frame)
  28. >>> cv2.waitKey(0)
  29. """
  30. def __init__(self, **kwargs):
  31. """Initializes the DistanceCalculation class for measuring object distances in video streams."""
  32. super().__init__(**kwargs)
  33. # Mouse event information
  34. self.left_mouse_count = 0
  35. self.selected_boxes = {}
  36. self.centroids = [] # Initialize empty list to store centroids
  37. def mouse_event_for_distance(self, event, x, y, flags, param):
  38. """
  39. Handles mouse events to select regions in a real-time video stream for distance calculation.
  40. Args:
  41. event (int): Type of mouse event (e.g., cv2.EVENT_MOUSEMOVE, cv2.EVENT_LBUTTONDOWN).
  42. x (int): X-coordinate of the mouse pointer.
  43. y (int): Y-coordinate of the mouse pointer.
  44. flags (int): Flags associated with the event (e.g., cv2.EVENT_FLAG_CTRLKEY, cv2.EVENT_FLAG_SHIFTKEY).
  45. param (Dict): Additional parameters passed to the function.
  46. Examples:
  47. >>> # Assuming 'dc' is an instance of DistanceCalculation
  48. >>> cv2.setMouseCallback("window_name", dc.mouse_event_for_distance)
  49. """
  50. if event == cv2.EVENT_LBUTTONDOWN:
  51. self.left_mouse_count += 1
  52. if self.left_mouse_count <= 2:
  53. for box, track_id in zip(self.boxes, self.track_ids):
  54. if box[0] < x < box[2] and box[1] < y < box[3] and track_id not in self.selected_boxes:
  55. self.selected_boxes[track_id] = box
  56. elif event == cv2.EVENT_RBUTTONDOWN:
  57. self.selected_boxes = {}
  58. self.left_mouse_count = 0
  59. def calculate(self, im0):
  60. """
  61. Processes a video frame and calculates the distance between two selected bounding boxes.
  62. This method extracts tracks from the input frame, annotates bounding boxes, and calculates the distance
  63. between two user-selected objects if they have been chosen.
  64. Args:
  65. im0 (numpy.ndarray): The input image frame to process.
  66. Returns:
  67. (numpy.ndarray): The processed image frame with annotations and distance calculations.
  68. Examples:
  69. >>> import numpy as np
  70. >>> from ultralytics.solutions import DistanceCalculation
  71. >>> dc = DistanceCalculation()
  72. >>> frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
  73. >>> processed_frame = dc.calculate(frame)
  74. """
  75. self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator
  76. self.extract_tracks(im0) # Extract tracks
  77. # Iterate over bounding boxes, track ids and classes index
  78. for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
  79. self.annotator.box_label(box, color=colors(int(cls), True), label=self.names[int(cls)])
  80. if len(self.selected_boxes) == 2:
  81. for trk_id in self.selected_boxes.keys():
  82. if trk_id == track_id:
  83. self.selected_boxes[track_id] = box
  84. if len(self.selected_boxes) == 2:
  85. # Store user selected boxes in centroids list
  86. self.centroids.extend(
  87. [[int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2)] for box in self.selected_boxes.values()]
  88. )
  89. # Calculate pixels distance
  90. pixels_distance = math.sqrt(
  91. (self.centroids[0][0] - self.centroids[1][0]) ** 2 + (self.centroids[0][1] - self.centroids[1][1]) ** 2
  92. )
  93. self.annotator.plot_distance_and_line(pixels_distance, self.centroids)
  94. self.centroids = []
  95. self.display_output(im0) # display output with base class function
  96. cv2.setMouseCallback("Ultralytics Solutions", self.mouse_event_for_distance)
  97. return im0 # return output image for more usage