bot_sort.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. from collections import deque
  3. import numpy as np
  4. from .basetrack import TrackState
  5. from .byte_tracker import BYTETracker, STrack
  6. from .utils import matching
  7. from .utils.gmc import GMC
  8. from .utils.kalman_filter import KalmanFilterXYWH
  9. class BOTrack(STrack):
  10. """
  11. An extended version of the STrack class for YOLOv8, adding object tracking features.
  12. This class extends the STrack class to include additional functionalities for object tracking, such as feature
  13. smoothing, Kalman filter prediction, and reactivation of tracks.
  14. Attributes:
  15. shared_kalman (KalmanFilterXYWH): A shared Kalman filter for all instances of BOTrack.
  16. smooth_feat (np.ndarray): Smoothed feature vector.
  17. curr_feat (np.ndarray): Current feature vector.
  18. features (deque): A deque to store feature vectors with a maximum length defined by `feat_history`.
  19. alpha (float): Smoothing factor for the exponential moving average of features.
  20. mean (np.ndarray): The mean state of the Kalman filter.
  21. covariance (np.ndarray): The covariance matrix of the Kalman filter.
  22. Methods:
  23. update_features(feat): Update features vector and smooth it using exponential moving average.
  24. predict(): Predicts the mean and covariance using Kalman filter.
  25. re_activate(new_track, frame_id, new_id): Reactivates a track with updated features and optionally new ID.
  26. update(new_track, frame_id): Update the YOLOv8 instance with new track and frame ID.
  27. tlwh: Property that gets the current position in tlwh format `(top left x, top left y, width, height)`.
  28. multi_predict(stracks): Predicts the mean and covariance of multiple object tracks using shared Kalman filter.
  29. convert_coords(tlwh): Converts tlwh bounding box coordinates to xywh format.
  30. tlwh_to_xywh(tlwh): Convert bounding box to xywh format `(center x, center y, width, height)`.
  31. Examples:
  32. Create a BOTrack instance and update its features
  33. >>> bo_track = BOTrack(tlwh=[100, 50, 80, 40], score=0.9, cls=1, feat=np.random.rand(128))
  34. >>> bo_track.predict()
  35. >>> new_track = BOTrack(tlwh=[110, 60, 80, 40], score=0.85, cls=1, feat=np.random.rand(128))
  36. >>> bo_track.update(new_track, frame_id=2)
  37. """
  38. shared_kalman = KalmanFilterXYWH()
  39. def __init__(self, tlwh, score, cls, feat=None, feat_history=50):
  40. """
  41. Initialize a BOTrack object with temporal parameters, such as feature history, alpha, and current features.
  42. Args:
  43. tlwh (np.ndarray): Bounding box coordinates in tlwh format (top left x, top left y, width, height).
  44. score (float): Confidence score of the detection.
  45. cls (int): Class ID of the detected object.
  46. feat (np.ndarray | None): Feature vector associated with the detection.
  47. feat_history (int): Maximum length of the feature history deque.
  48. Examples:
  49. Initialize a BOTrack object with bounding box, score, class ID, and feature vector
  50. >>> tlwh = np.array([100, 50, 80, 120])
  51. >>> score = 0.9
  52. >>> cls = 1
  53. >>> feat = np.random.rand(128)
  54. >>> bo_track = BOTrack(tlwh, score, cls, feat)
  55. """
  56. super().__init__(tlwh, score, cls)
  57. self.smooth_feat = None
  58. self.curr_feat = None
  59. if feat is not None:
  60. self.update_features(feat)
  61. self.features = deque([], maxlen=feat_history)
  62. self.alpha = 0.9
  63. def update_features(self, feat):
  64. """Update the feature vector and apply exponential moving average smoothing."""
  65. feat /= np.linalg.norm(feat)
  66. self.curr_feat = feat
  67. if self.smooth_feat is None:
  68. self.smooth_feat = feat
  69. else:
  70. self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat
  71. self.features.append(feat)
  72. self.smooth_feat /= np.linalg.norm(self.smooth_feat)
  73. def predict(self):
  74. """Predicts the object's future state using the Kalman filter to update its mean and covariance."""
  75. mean_state = self.mean.copy()
  76. if self.state != TrackState.Tracked:
  77. mean_state[6] = 0
  78. mean_state[7] = 0
  79. self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
  80. def re_activate(self, new_track, frame_id, new_id=False):
  81. """Reactivates a track with updated features and optionally assigns a new ID."""
  82. if new_track.curr_feat is not None:
  83. self.update_features(new_track.curr_feat)
  84. super().re_activate(new_track, frame_id, new_id)
  85. def update(self, new_track, frame_id):
  86. """Updates the YOLOv8 instance with new track information and the current frame ID."""
  87. if new_track.curr_feat is not None:
  88. self.update_features(new_track.curr_feat)
  89. super().update(new_track, frame_id)
  90. @property
  91. def tlwh(self):
  92. """Returns the current bounding box position in `(top left x, top left y, width, height)` format."""
  93. if self.mean is None:
  94. return self._tlwh.copy()
  95. ret = self.mean[:4].copy()
  96. ret[:2] -= ret[2:] / 2
  97. return ret
  98. @staticmethod
  99. def multi_predict(stracks):
  100. """Predicts the mean and covariance for multiple object tracks using a shared Kalman filter."""
  101. if len(stracks) <= 0:
  102. return
  103. multi_mean = np.asarray([st.mean.copy() for st in stracks])
  104. multi_covariance = np.asarray([st.covariance for st in stracks])
  105. for i, st in enumerate(stracks):
  106. if st.state != TrackState.Tracked:
  107. multi_mean[i][6] = 0
  108. multi_mean[i][7] = 0
  109. multi_mean, multi_covariance = BOTrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
  110. for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
  111. stracks[i].mean = mean
  112. stracks[i].covariance = cov
  113. def convert_coords(self, tlwh):
  114. """Converts tlwh bounding box coordinates to xywh format."""
  115. return self.tlwh_to_xywh(tlwh)
  116. @staticmethod
  117. def tlwh_to_xywh(tlwh):
  118. """Convert bounding box from tlwh (top-left-width-height) to xywh (center-x-center-y-width-height) format."""
  119. ret = np.asarray(tlwh).copy()
  120. ret[:2] += ret[2:] / 2
  121. return ret
  122. class BOTSORT(BYTETracker):
  123. """
  124. An extended version of the BYTETracker class for YOLOv8, designed for object tracking with ReID and GMC algorithm.
  125. Attributes:
  126. proximity_thresh (float): Threshold for spatial proximity (IoU) between tracks and detections.
  127. appearance_thresh (float): Threshold for appearance similarity (ReID embeddings) between tracks and detections.
  128. encoder (Any): Object to handle ReID embeddings, set to None if ReID is not enabled.
  129. gmc (GMC): An instance of the GMC algorithm for data association.
  130. args (Any): Parsed command-line arguments containing tracking parameters.
  131. Methods:
  132. get_kalmanfilter(): Returns an instance of KalmanFilterXYWH for object tracking.
  133. init_track(dets, scores, cls, img): Initialize track with detections, scores, and classes.
  134. get_dists(tracks, detections): Get distances between tracks and detections using IoU and (optionally) ReID.
  135. multi_predict(tracks): Predict and track multiple objects with YOLOv8 model.
  136. Examples:
  137. Initialize BOTSORT and process detections
  138. >>> bot_sort = BOTSORT(args, frame_rate=30)
  139. >>> bot_sort.init_track(dets, scores, cls, img)
  140. >>> bot_sort.multi_predict(tracks)
  141. Note:
  142. The class is designed to work with the YOLOv8 object detection model and supports ReID only if enabled via args.
  143. """
  144. def __init__(self, args, frame_rate=30):
  145. """
  146. Initialize YOLOv8 object with ReID module and GMC algorithm.
  147. Args:
  148. args (object): Parsed command-line arguments containing tracking parameters.
  149. frame_rate (int): Frame rate of the video being processed.
  150. Examples:
  151. Initialize BOTSORT with command-line arguments and a specified frame rate:
  152. >>> args = parse_args()
  153. >>> bot_sort = BOTSORT(args, frame_rate=30)
  154. """
  155. super().__init__(args, frame_rate)
  156. # ReID module
  157. self.proximity_thresh = args.proximity_thresh
  158. self.appearance_thresh = args.appearance_thresh
  159. if args.with_reid:
  160. # Haven't supported BoT-SORT(reid) yet
  161. self.encoder = None
  162. self.gmc = GMC(method=args.gmc_method)
  163. def get_kalmanfilter(self):
  164. """Returns an instance of KalmanFilterXYWH for predicting and updating object states in the tracking process."""
  165. return KalmanFilterXYWH()
  166. def init_track(self, dets, scores, cls, img=None):
  167. """Initialize object tracks using detection bounding boxes, scores, class labels, and optional ReID features."""
  168. if len(dets) == 0:
  169. return []
  170. if self.args.with_reid and self.encoder is not None:
  171. features_keep = self.encoder.inference(img, dets)
  172. return [BOTrack(xyxy, s, c, f) for (xyxy, s, c, f) in zip(dets, scores, cls, features_keep)] # detections
  173. else:
  174. return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] # detections
  175. def get_dists(self, tracks, detections):
  176. """Calculates distances between tracks and detections using IoU and optionally ReID embeddings."""
  177. dists = matching.iou_distance(tracks, detections)
  178. dists_mask = dists > self.proximity_thresh
  179. if self.args.fuse_score:
  180. dists = matching.fuse_score(dists, detections)
  181. if self.args.with_reid and self.encoder is not None:
  182. emb_dists = matching.embedding_distance(tracks, detections) / 2.0
  183. emb_dists[emb_dists > self.appearance_thresh] = 1.0
  184. emb_dists[dists_mask] = 1.0
  185. dists = np.minimum(dists, emb_dists)
  186. return dists
  187. def multi_predict(self, tracks):
  188. """Predicts the mean and covariance of multiple object tracks using a shared Kalman filter."""
  189. BOTrack.multi_predict(tracks)
  190. def reset(self):
  191. """Resets the BOTSORT tracker to its initial state, clearing all tracked objects and internal states."""
  192. super().reset()
  193. self.gmc.reset_params()