123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178 |
- # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
- from collections import defaultdict
- import cv2
- from ultralytics import YOLO
- from ultralytics.utils import ASSETS_URL, DEFAULT_CFG_DICT, DEFAULT_SOL_DICT, LOGGER
- from ultralytics.utils.checks import check_imshow, check_requirements
class BaseSolution:
    """
    A base class for managing Ultralytics Solutions.

    This class provides core functionality for various Ultralytics Solutions, including model loading, object tracking,
    and region initialization.

    Attributes:
        LineString (shapely.geometry.LineString): Class for creating line string geometries.
        Polygon (shapely.geometry.Polygon): Class for creating polygon geometries.
        Point (shapely.geometry.Point): Class for creating point geometries.
        prep (Callable): `shapely.prepared.prep`, used to build prepared geometries.
        CFG (Dict): Per-instance configuration built from the default solution and model config dicts, overridden by
            the keyword arguments passed to the constructor. The module-level default dicts are never modified.
        region (List[Tuple[int, int]]): List of coordinate tuples defining a region of interest.
        line_width (int): Width of lines used in visualizations (2 when not configured).
        model (ultralytics.YOLO): Loaded YOLO model instance.
        names (Dict[int, str]): Dictionary mapping class indices to class names.
        track_add_args (Dict): Extra keyword arguments forwarded to `model.track`.
        env_check (bool): Flag indicating whether the environment supports image display.
        track_history (collections.defaultdict): Dictionary to store tracking history for each object.
        annotator: Annotator placeholder; set to None here.
        tracks, track_data, boxes, clss, track_ids: Outputs of the most recent `extract_tracks` call.
        track_line (list | None): History list most recently updated by `store_tracking_history`.
        r_s: Region geometry (Polygon or LineString) built by `initialize_region`; None until then.

    Methods:
        extract_tracks: Apply object tracking and extract tracks from an input image.
        store_tracking_history: Store object tracking history for a given track ID and bounding box.
        initialize_region: Initialize the counting region and line segment based on configuration.
        display_output: Display the results of processing, including showing frames or saving results.

    Examples:
        >>> solution = BaseSolution(model="yolov8n.pt", region=[(0, 0), (100, 0), (100, 100), (0, 100)])
        >>> solution.initialize_region()
        >>> image = cv2.imread("image.jpg")
        >>> solution.extract_tracks(image)
        >>> solution.display_output(image)
    """

    def __init__(self, IS_CLI=False, **kwargs):
        """
        Initialize the `BaseSolution` class with configuration settings and the YOLO model for Ultralytics solutions.

        Args:
            IS_CLI (bool): Enables CLI mode. When set and no `source` is configured, a demo video is downloaded from
                the Ultralytics assets and used as the source.
            **kwargs (Any): Configuration overrides applied on top of the default solution and model settings.

        Notes:
            Overrides only affect this instance's `CFG`; the shared default config dicts are left untouched, so
            constructing one solution does not change the defaults seen by later solutions or other components.
        """
        check_requirements("shapely>=2.0.0")
        from shapely.geometry import LineString, Point, Polygon
        from shapely.prepared import prep

        self.LineString = LineString
        self.Polygon = Polygon
        self.Point = Point
        self.prep = prep
        self.annotator = None  # Initialize annotator
        self.tracks = None
        self.track_data = None
        self.boxes = []
        self.clss = []
        self.track_ids = []
        self.track_line = None
        self.r_s = None

        # Load config and apply overrides. Merge into fresh dicts instead of calling .update() on the shared
        # module-level defaults: mutating them would leak this instance's kwargs into every later instance.
        # Precedence: solution defaults < model defaults < kwargs.
        sol_cfg = {**DEFAULT_SOL_DICT, **kwargs}
        self.CFG = {**DEFAULT_SOL_DICT, **DEFAULT_CFG_DICT, **kwargs}
        LOGGER.info(f"Ultralytics Solutions: ✅ {sol_cfg}")

        self.region = self.CFG["region"]  # Store region data for other classes usage
        self.line_width = self.CFG["line_width"] if self.CFG["line_width"] is not None else 2

        # Load Model and store classes names
        if self.CFG["model"] is None:
            self.CFG["model"] = "yolo11n.pt"
        self.model = YOLO(self.CFG["model"])
        self.names = self.model.names

        self.track_add_args = {  # Tracker additional arguments for advance configuration
            k: self.CFG[k] for k in ["verbose", "iou", "conf", "device", "max_det", "half", "tracker"]
        }

        if IS_CLI and self.CFG["source"] is None:
            # str() so that a pathlib.Path model argument does not break the substring test
            is_pose = "-pose" in str(self.CFG["model"])
            d_s = "solution_ci_pose_demo.mp4" if is_pose else "solutions_ci_demo.mp4"
            LOGGER.warning(f"⚠️ WARNING: source not provided. using default source {ASSETS_URL}/{d_s}")
            from ultralytics.utils.downloads import safe_download

            safe_download(f"{ASSETS_URL}/{d_s}")  # download source from ultralytics assets
            self.CFG["source"] = d_s  # set default source

        # Initialize environment and region setup
        self.env_check = check_imshow(warn=True)
        self.track_history = defaultdict(list)

    def extract_tracks(self, im0):
        """
        Apply object tracking and extract tracks from an input image or frame.

        Runs `model.track` with `persist=True`, the configured `classes` filter and `track_add_args`. OBB results
        are used when present, otherwise regular boxes. On success, `boxes` (xyxy tensor on CPU), `clss` and
        `track_ids` (Python lists) are populated. Otherwise all three are reset to empty lists and a warning is
        logged; this happens on every call that yields no tracks.

        Args:
            im0 (ndarray): The input image or frame.

        Examples:
            >>> solution = BaseSolution()
            >>> frame = cv2.imread("path/to/image.jpg")
            >>> solution.extract_tracks(frame)
        """
        self.tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"], **self.track_add_args)

        # Extract tracks for OBB or object detection
        self.track_data = self.tracks[0].obb or self.tracks[0].boxes

        if self.track_data and self.track_data.id is not None:
            self.boxes = self.track_data.xyxy.cpu()
            self.clss = self.track_data.cls.cpu().tolist()
            self.track_ids = self.track_data.id.int().cpu().tolist()
        else:
            LOGGER.warning("WARNING ⚠️ no tracks found!")
            self.boxes, self.clss, self.track_ids = [], [], []

    def store_tracking_history(self, track_id, box):
        """
        Store the tracking history of an object.

        Appends the center point of the bounding box to the history list for `track_id` and keeps only the most
        recent 30 points. The updated list is also exposed as `self.track_line`.

        Args:
            track_id (int): The unique identifier for the tracked object.
            box (List[float]): The bounding box coordinates of the object in the format [x1, y1, x2, y2].

        Examples:
            >>> solution = BaseSolution()
            >>> solution.store_tracking_history(1, [100, 200, 300, 400])
        """
        # Store tracking history
        self.track_line = self.track_history[track_id]
        self.track_line.append(((box[0] + box[2]) / 2, (box[1] + box[3]) / 2))
        if len(self.track_line) > 30:
            self.track_line.pop(0)

    def initialize_region(self):
        """
        Initialize the counting region or line segment based on configuration settings.

        If no region is configured, a default 4-point rectangle is used. A region with three or more points becomes a
        Polygon; otherwise it becomes a LineString. The geometry is stored in `self.r_s`.
        """
        if self.region is None:
            self.region = [(20, 400), (1080, 400), (1080, 360), (20, 360)]
        self.r_s = (
            self.Polygon(self.region) if len(self.region) >= 3 else self.LineString(self.region)
        )  # region or line

    def display_output(self, im0):
        """
        Display the results of the processing, which could involve showing frames, printing counts, or saving results.

        Shows the processed frame in an OpenCV window named "Ultralytics Solutions" and polls the keyboard once.

        Args:
            im0 (numpy.ndarray): The input image or frame that has been processed and annotated.

        Examples:
            >>> solution = BaseSolution()
            >>> frame = cv2.imread("path/to/image.jpg")
            >>> solution.display_output(frame)

        Notes:
            - This method will only display output if the 'show' configuration is set to True and the environment
              supports image display.
            - Pressing the 'q' key closes the OpenCV windows. The keypress is not reported back to the caller, so a
              subsequent call will open the window again.
        """
        if self.CFG.get("show") and self.env_check:
            cv2.imshow("Ultralytics Solutions", im0)
            if cv2.waitKey(1) & 0xFF == ord("q"):
                cv2.destroyAllWindows()  # actually close the display, as documented
|