solutions.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. from collections import defaultdict
  3. import cv2
  4. from ultralytics import YOLO
  5. from ultralytics.utils import ASSETS_URL, DEFAULT_CFG_DICT, DEFAULT_SOL_DICT, LOGGER
  6. from ultralytics.utils.checks import check_imshow, check_requirements
  7. class BaseSolution:
  8. """
  9. A base class for managing Ultralytics Solutions.
  10. This class provides core functionality for various Ultralytics Solutions, including model loading, object tracking,
  11. and region initialization.
  12. Attributes:
  13. LineString (shapely.geometry.LineString): Class for creating line string geometries.
  14. Polygon (shapely.geometry.Polygon): Class for creating polygon geometries.
  15. Point (shapely.geometry.Point): Class for creating point geometries.
  16. CFG (Dict): Configuration dictionary loaded from a YAML file and updated with kwargs.
  17. region (List[Tuple[int, int]]): List of coordinate tuples defining a region of interest.
  18. line_width (int): Width of lines used in visualizations.
  19. model (ultralytics.YOLO): Loaded YOLO model instance.
  20. names (Dict[int, str]): Dictionary mapping class indices to class names.
  21. env_check (bool): Flag indicating whether the environment supports image display.
  22. track_history (collections.defaultdict): Dictionary to store tracking history for each object.
  23. Methods:
  24. extract_tracks: Apply object tracking and extract tracks from an input image.
  25. store_tracking_history: Store object tracking history for a given track ID and bounding box.
  26. initialize_region: Initialize the counting region and line segment based on configuration.
  27. display_output: Display the results of processing, including showing frames or saving results.
  28. Examples:
  29. >>> solution = BaseSolution(model="yolov8n.pt", region=[(0, 0), (100, 0), (100, 100), (0, 100)])
  30. >>> solution.initialize_region()
  31. >>> image = cv2.imread("image.jpg")
  32. >>> solution.extract_tracks(image)
  33. >>> solution.display_output(image)
  34. """
  35. def __init__(self, IS_CLI=False, **kwargs):
  36. """
  37. Initializes the `BaseSolution` class with configuration settings and the YOLO model for Ultralytics solutions.
  38. IS_CLI (optional): Enables CLI mode if set.
  39. """
  40. check_requirements("shapely>=2.0.0")
  41. from shapely.geometry import LineString, Point, Polygon
  42. from shapely.prepared import prep
  43. self.LineString = LineString
  44. self.Polygon = Polygon
  45. self.Point = Point
  46. self.prep = prep
  47. self.annotator = None # Initialize annotator
  48. self.tracks = None
  49. self.track_data = None
  50. self.boxes = []
  51. self.clss = []
  52. self.track_ids = []
  53. self.track_line = None
  54. self.r_s = None
  55. # Load config and update with args
  56. DEFAULT_SOL_DICT.update(kwargs)
  57. DEFAULT_CFG_DICT.update(kwargs)
  58. self.CFG = {**DEFAULT_SOL_DICT, **DEFAULT_CFG_DICT}
  59. LOGGER.info(f"Ultralytics Solutions: ✅ {DEFAULT_SOL_DICT}")
  60. self.region = self.CFG["region"] # Store region data for other classes usage
  61. self.line_width = (
  62. self.CFG["line_width"] if self.CFG["line_width"] is not None else 2
  63. ) # Store line_width for usage
  64. # Load Model and store classes names
  65. if self.CFG["model"] is None:
  66. self.CFG["model"] = "yolo11n.pt"
  67. self.model = YOLO(self.CFG["model"])
  68. self.names = self.model.names
  69. self.track_add_args = { # Tracker additional arguments for advance configuration
  70. k: self.CFG[k] for k in ["verbose", "iou", "conf", "device", "max_det", "half", "tracker"]
  71. }
  72. if IS_CLI and self.CFG["source"] is None:
  73. d_s = "solutions_ci_demo.mp4" if "-pose" not in self.CFG["model"] else "solution_ci_pose_demo.mp4"
  74. LOGGER.warning(f"⚠️ WARNING: source not provided. using default source {ASSETS_URL}/{d_s}")
  75. from ultralytics.utils.downloads import safe_download
  76. safe_download(f"{ASSETS_URL}/{d_s}") # download source from ultralytics assets
  77. self.CFG["source"] = d_s # set default source
  78. # Initialize environment and region setup
  79. self.env_check = check_imshow(warn=True)
  80. self.track_history = defaultdict(list)
  81. def extract_tracks(self, im0):
  82. """
  83. Applies object tracking and extracts tracks from an input image or frame.
  84. Args:
  85. im0 (ndarray): The input image or frame.
  86. Examples:
  87. >>> solution = BaseSolution()
  88. >>> frame = cv2.imread("path/to/image.jpg")
  89. >>> solution.extract_tracks(frame)
  90. """
  91. self.tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"], **self.track_add_args)
  92. # Extract tracks for OBB or object detection
  93. self.track_data = self.tracks[0].obb or self.tracks[0].boxes
  94. if self.track_data and self.track_data.id is not None:
  95. self.boxes = self.track_data.xyxy.cpu()
  96. self.clss = self.track_data.cls.cpu().tolist()
  97. self.track_ids = self.track_data.id.int().cpu().tolist()
  98. else:
  99. LOGGER.warning("WARNING ⚠️ no tracks found!")
  100. self.boxes, self.clss, self.track_ids = [], [], []
  101. def store_tracking_history(self, track_id, box):
  102. """
  103. Stores the tracking history of an object.
  104. This method updates the tracking history for a given object by appending the center point of its
  105. bounding box to the track line. It maintains a maximum of 30 points in the tracking history.
  106. Args:
  107. track_id (int): The unique identifier for the tracked object.
  108. box (List[float]): The bounding box coordinates of the object in the format [x1, y1, x2, y2].
  109. Examples:
  110. >>> solution = BaseSolution()
  111. >>> solution.store_tracking_history(1, [100, 200, 300, 400])
  112. """
  113. # Store tracking history
  114. self.track_line = self.track_history[track_id]
  115. self.track_line.append(((box[0] + box[2]) / 2, (box[1] + box[3]) / 2))
  116. if len(self.track_line) > 30:
  117. self.track_line.pop(0)
  118. def initialize_region(self):
  119. """Initialize the counting region and line segment based on configuration settings."""
  120. if self.region is None:
  121. self.region = [(20, 400), (1080, 400), (1080, 360), (20, 360)]
  122. self.r_s = (
  123. self.Polygon(self.region) if len(self.region) >= 3 else self.LineString(self.region)
  124. ) # region or line
  125. def display_output(self, im0):
  126. """
  127. Display the results of the processing, which could involve showing frames, printing counts, or saving results.
  128. This method is responsible for visualizing the output of the object detection and tracking process. It displays
  129. the processed frame with annotations, and allows for user interaction to close the display.
  130. Args:
  131. im0 (numpy.ndarray): The input image or frame that has been processed and annotated.
  132. Examples:
  133. >>> solution = BaseSolution()
  134. >>> frame = cv2.imread("path/to/image.jpg")
  135. >>> solution.display_output(frame)
  136. Notes:
  137. - This method will only display output if the 'show' configuration is set to True and the environment
  138. supports image display.
  139. - The display can be closed by pressing the 'q' key.
  140. """
  141. if self.CFG.get("show") and self.env_check:
  142. cv2.imshow("Ultralytics Solutions", im0)
  143. if cv2.waitKey(1) & 0xFF == ord("q"):
  144. return