basetrack.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. """Module defines the base classes and structures for object tracking in YOLO."""
  3. from collections import OrderedDict
  4. import numpy as np
  5. class TrackState:
  6. """
  7. Enumeration class representing the possible states of an object being tracked.
  8. Attributes:
  9. New (int): State when the object is newly detected.
  10. Tracked (int): State when the object is successfully tracked in subsequent frames.
  11. Lost (int): State when the object is no longer tracked.
  12. Removed (int): State when the object is removed from tracking.
  13. Examples:
  14. >>> state = TrackState.New
  15. >>> if state == TrackState.New:
  16. >>> print("Object is newly detected.")
  17. """
  18. New = 0
  19. Tracked = 1
  20. Lost = 2
  21. Removed = 3
  22. class BaseTrack:
  23. """
  24. Base class for object tracking, providing foundational attributes and methods.
  25. Attributes:
  26. _count (int): Class-level counter for unique track IDs.
  27. track_id (int): Unique identifier for the track.
  28. is_activated (bool): Flag indicating whether the track is currently active.
  29. state (TrackState): Current state of the track.
  30. history (OrderedDict): Ordered history of the track's states.
  31. features (List): List of features extracted from the object for tracking.
  32. curr_feature (Any): The current feature of the object being tracked.
  33. score (float): The confidence score of the tracking.
  34. start_frame (int): The frame number where tracking started.
  35. frame_id (int): The most recent frame ID processed by the track.
  36. time_since_update (int): Frames passed since the last update.
  37. location (tuple): The location of the object in the context of multi-camera tracking.
  38. Methods:
  39. end_frame: Returns the ID of the last frame where the object was tracked.
  40. next_id: Increments and returns the next global track ID.
  41. activate: Abstract method to activate the track.
  42. predict: Abstract method to predict the next state of the track.
  43. update: Abstract method to update the track with new data.
  44. mark_lost: Marks the track as lost.
  45. mark_removed: Marks the track as removed.
  46. reset_id: Resets the global track ID counter.
  47. Examples:
  48. Initialize a new track and mark it as lost:
  49. >>> track = BaseTrack()
  50. >>> track.mark_lost()
  51. >>> print(track.state) # Output: 2 (TrackState.Lost)
  52. """
  53. _count = 0
  54. def __init__(self):
  55. """
  56. Initializes a new track with a unique ID and foundational tracking attributes.
  57. Examples:
  58. Initialize a new track
  59. >>> track = BaseTrack()
  60. >>> print(track.track_id)
  61. 0
  62. """
  63. self.track_id = 0
  64. self.is_activated = False
  65. self.state = TrackState.New
  66. self.history = OrderedDict()
  67. self.features = []
  68. self.curr_feature = None
  69. self.score = 0
  70. self.start_frame = 0
  71. self.frame_id = 0
  72. self.time_since_update = 0
  73. self.location = (np.inf, np.inf)
  74. @property
  75. def end_frame(self):
  76. """Returns the ID of the most recent frame where the object was tracked."""
  77. return self.frame_id
  78. @staticmethod
  79. def next_id():
  80. """Increment and return the next unique global track ID for object tracking."""
  81. BaseTrack._count += 1
  82. return BaseTrack._count
  83. def activate(self, *args):
  84. """Activates the track with provided arguments, initializing necessary attributes for tracking."""
  85. raise NotImplementedError
  86. def predict(self):
  87. """Predicts the next state of the track based on the current state and tracking model."""
  88. raise NotImplementedError
  89. def update(self, *args, **kwargs):
  90. """Updates the track with new observations and data, modifying its state and attributes accordingly."""
  91. raise NotImplementedError
  92. def mark_lost(self):
  93. """Marks the track as lost by updating its state to TrackState.Lost."""
  94. self.state = TrackState.Lost
  95. def mark_removed(self):
  96. """Marks the track as removed by setting its state to TrackState.Removed."""
  97. self.state = TrackState.Removed
  98. @staticmethod
  99. def reset_id():
  100. """Reset the global track ID counter to its initial value."""
  101. BaseTrack._count = 0