# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""
Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc.

Usage - sources:
    $ yolo mode=predict model=yolov8n.pt source=0                               # webcam
                                                img.jpg                         # image
                                                vid.mp4                         # video
                                                screen                          # screenshot
                                                path/                           # directory
                                                list.txt                        # list of images
                                                list.streams                    # list of streams
                                                'path/*.jpg'                    # glob
                                                'https://youtu.be/LNwODJXcvt4'  # YouTube
                                                'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP, TCP stream

Usage - formats:
    $ yolo mode=predict model=yolov8n.pt                 # PyTorch
                              yolov8n.torchscript        # TorchScript
                              yolov8n.onnx               # ONNX Runtime or OpenCV DNN with dnn=True
                              yolov8n_openvino_model     # OpenVINO
                              yolov8n.engine             # TensorRT
                              yolov8n.mlpackage          # CoreML (macOS-only)
                              yolov8n_saved_model        # TensorFlow SavedModel
                              yolov8n.pb                 # TensorFlow GraphDef
                              yolov8n.tflite             # TensorFlow Lite
                              yolov8n_edgetpu.tflite     # TensorFlow Edge TPU
                              yolov8n_paddle_model       # PaddlePaddle
                              yolov8n.mnn                # MNN
                              yolov8n_ncnn_model         # NCNN
"""

import platform
import re
import threading
from pathlib import Path

import cv2
import numpy as np
import torch

from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data import load_inference_source
from ultralytics.data.augment import LetterBox, classify_transforms
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops
from ultralytics.utils.checks import check_imgsz, check_imshow
from ultralytics.utils.files import increment_path
from ultralytics.utils.torch_utils import select_device, smart_inference_mode

STREAM_WARNING = """
WARNING ⚠️ inference results will accumulate in RAM unless `stream=True` is passed, causing potential out-of-memory
errors for large sources or long-running streams and videos. See https://docs.ultralytics.com/modes/predict/ for help.

Example:
    results = model(source=..., stream=True)  # generator of Results objects
    for r in results:
        boxes = r.boxes  # Boxes object for bbox outputs
        masks = r.masks  # Masks object for segment masks outputs
        probs = r.probs  # Class probabilities for classification outputs
"""


class BasePredictor:
    """
    A base class for creating predictors.

    Attributes:
        args (SimpleNamespace): Configuration for the predictor.
        save_dir (Path): Directory to save results.
        done_warmup (bool): Whether the predictor has finished setup.
        model (nn.Module): Model used for prediction.
        data (dict): Data configuration.
        device (torch.device): Device used for prediction.
        dataset (Dataset): Dataset used for prediction.
        vid_writer (dict): Dictionary of {save_path: video_writer, ...} writers for saving video output.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initializes the BasePredictor class.

        Args:
            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
            overrides (dict, optional): Configuration overrides. Defaults to None.
            _callbacks (dict, optional): Dictionary of callback functions. Defaults to None.
        """
        self.args = get_cfg(cfg, overrides)
        self.save_dir = get_save_dir(self.args)
        if self.args.conf is None:
            self.args.conf = 0.25  # default conf=0.25
        self.done_warmup = False
        if self.args.show:
            self.args.show = check_imshow(warn=True)

        # Usable if setup is done
        self.model = None
        self.data = self.args.data  # data_dict
        self.imgsz = None
        self.device = None
        self.dataset = None
        self.vid_writer = {}  # dict of {save_path: video_writer, ...}
        self.plotted_img = None
        self.source_type = None
        self.seen = 0
        self.windows = []
        self.batch = None
        self.results = None
        self.transforms = None
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        self.txt_path = None
        self._lock = threading.Lock()  # for automatic thread-safe inference
        callbacks.add_integration_callbacks(self)

    def preprocess(self, im):
        """
        Prepares input image before inference.

        Args:
            im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.

        Returns:
            (torch.Tensor): Preprocessed image tensor in BCHW format on the target device.
        """
        not_tensor = not isinstance(im, torch.Tensor)
        if not_tensor:
            im = np.stack(self.pre_transform(im))
            im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
            im = np.ascontiguousarray(im)  # contiguous
            im = torch.from_numpy(im)

        im = im.to(self.device)
        im = im.half() if self.model.fp16 else im.float()  # uint8 to fp16/32
        if not_tensor:
            im /= 255  # 0 - 255 to 0.0 - 1.0
        return im
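
    # Shape-contract sketch for preprocess (hypothetical inputs): a list of BGR uint8
    # frames becomes one normalized BCHW tensor on the model device, e.g.
    #     frames = [np.zeros((480, 640, 3), dtype=np.uint8)] * 2    # [(h, w, 3) x 2]
    #     batch = predictor.preprocess(frames)  # torch.Tensor, shape (2, 3, H, W), 0.0-1.0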

    def inference(self, im, *args, **kwargs):
        """Runs inference on a given image using the specified model and arguments."""
        visualize = (
            increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
            if self.args.visualize and (not self.source_type.tensor)
            else False
        )
        return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)

    def pre_transform(self, im):
        """
        Pre-transform input image before inference.

        Args:
            im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.

        Returns:
            (list): A list of transformed images.
        """
        same_shapes = len({x.shape for x in im}) == 1
        letterbox = LetterBox(
            self.imgsz,
            auto=same_shapes and (self.model.pt or (getattr(self.model, "dynamic", False) and not self.model.imx)),
            stride=self.model.stride,
        )
        return [letterbox(image=x) for x in im]

    def postprocess(self, preds, img, orig_imgs):
        """Post-processes predictions for an image and returns them."""
        return preds

    def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
        """Performs inference on an image or stream."""
        self.stream = stream
        if stream:
            return self.stream_inference(source, model, *args, **kwargs)
        else:
            return list(self.stream_inference(source, model, *args, **kwargs))  # merge list of Result into one
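
    # Call-pattern sketch (illustrative source name): stream=False materializes every
    # Results object in RAM, while stream=True yields them one at a time:
    #     results = predictor("vid.mp4")               # list, grows with video length
    #     for r in predictor("vid.mp4", stream=True):  # generator, constant memory
    #         r.boxes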

    def predict_cli(self, source=None, model=None):
        """
        Method used for Command Line Interface (CLI) prediction.

        This function is designed to run predictions using the CLI. It sets up the source and model, then processes
        the inputs in a streaming manner. This method ensures that no outputs accumulate in memory by consuming the
        generator without storing results.

        Note:
            Do not modify this function or remove the generator. The generator ensures that no outputs are
            accumulated in memory, which is critical for preventing memory issues during long-running predictions.
        """
        gen = self.stream_inference(source, model)
        for _ in gen:  # sourcery skip: remove-empty-nested-block, noqa
            pass

    def setup_source(self, source):
        """Sets up source and inference mode."""
        self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2)  # check image size
        self.transforms = (
            getattr(
                self.model.model,
                "transforms",
                classify_transforms(self.imgsz[0], crop_fraction=self.args.crop_fraction),
            )
            if self.args.task == "classify"
            else None
        )
        self.dataset = load_inference_source(
            source=source,
            batch=self.args.batch,
            vid_stride=self.args.vid_stride,
            buffer=self.args.stream_buffer,
        )
        self.source_type = self.dataset.source_type
        if not getattr(self, "stream", True) and (
            self.source_type.stream
            or self.source_type.screenshot
            or len(self.dataset) > 1000  # many images
            or any(getattr(self.dataset, "video_flag", [False]))
        ):  # videos
            LOGGER.warning(STREAM_WARNING)
        self.vid_writer = {}
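
    # Illustrative call (hypothetical glob; setup_model must have run first, since
    # setup_source reads self.model.stride):
    #     predictor.setup_source("path/*.jpg")  # populates self.dataset and self.source_type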

    @smart_inference_mode()
    def stream_inference(self, source=None, model=None, *args, **kwargs):
        """Streams real-time inference on camera feed and saves results to file."""
        if self.args.verbose:
            LOGGER.info("")

        # Setup model
        if not self.model:
            self.setup_model(model)

        with self._lock:  # for thread-safe inference
            # Setup source every time predict is called
            self.setup_source(source if source is not None else self.args.source)

            # Check if save_dir/ label file exists
            if self.args.save or self.args.save_txt:
                (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)

            # Warmup model
            if not self.done_warmup:
                self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
                self.done_warmup = True

            self.seen, self.windows, self.batch = 0, [], None
            profilers = (
                ops.Profile(device=self.device),
                ops.Profile(device=self.device),
                ops.Profile(device=self.device),
            )
            self.run_callbacks("on_predict_start")
            for self.batch in self.dataset:
                self.run_callbacks("on_predict_batch_start")
                paths, im0s, s = self.batch

                # Preprocess
                with profilers[0]:
                    im = self.preprocess(im0s)

                # Inference
                with profilers[1]:
                    preds = self.inference(im, *args, **kwargs)
                    if self.args.embed:
                        yield from [preds] if isinstance(preds, torch.Tensor) else preds  # yield embedding tensors
                        continue

                # Postprocess
                with profilers[2]:
                    self.results = self.postprocess(preds, im, im0s)
                self.run_callbacks("on_predict_postprocess_end")

                # Visualize, save, write results
                n = len(im0s)
                for i in range(n):
                    self.seen += 1
                    self.results[i].speed = {
                        "preprocess": profilers[0].dt * 1e3 / n,
                        "inference": profilers[1].dt * 1e3 / n,
                        "postprocess": profilers[2].dt * 1e3 / n,
                    }
                    if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
                        s[i] += self.write_results(i, Path(paths[i]), im, s)

                # Print batch results
                if self.args.verbose:
                    LOGGER.info("\n".join(s))

                self.run_callbacks("on_predict_batch_end")
                yield from self.results

        # Release assets
        for v in self.vid_writer.values():
            if isinstance(v, cv2.VideoWriter):
                v.release()

        # Print final results
        if self.args.verbose and self.seen:
            t = tuple(x.t / self.seen * 1e3 for x in profilers)  # speeds per image
            LOGGER.info(
                f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape "
                f"{(min(self.args.batch, self.seen), 3, *im.shape[2:])}" % t
            )
        if self.args.save or self.args.save_txt or self.args.save_crop:
            nl = len(list(self.save_dir.glob("labels/*.txt")))  # number of labels
            s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ""
            LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
        self.run_callbacks("on_predict_end")
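
    # Consumption sketch (illustrative URL): stream_inference is a generator, so no
    # work happens until it is iterated:
    #     for r in predictor.stream_inference(source="rtsp://example.com/media.mp4"):
    #         ...  # handle each Results object as it is produced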

    def setup_model(self, model, verbose=True):
        """Initialize YOLO model with given parameters and set it to evaluation mode."""
        self.model = AutoBackend(
            weights=model or self.args.model,
            device=select_device(self.args.device, verbose=verbose),
            dnn=self.args.dnn,
            data=self.args.data,
            fp16=self.args.half,
            batch=self.args.batch,
            fuse=True,
            verbose=verbose,
        )
        self.device = self.model.device  # update device
        self.args.half = self.model.fp16  # update half
        self.model.eval()
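
    # Usage sketch (assumes a local 'yolov8n.pt'): setup_model accepts a weights path
    # or a preloaded model, and AutoBackend resolves the format automatically:
    #     predictor.setup_model("yolov8n.pt")  # populates self.model and self.device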

    def write_results(self, i, p, im, s):
        """Write inference results to a file or directory."""
        string = ""  # print string
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        if self.source_type.stream or self.source_type.from_img or self.source_type.tensor:  # batch_size >= 1
            string += f"{i}: "
            frame = self.dataset.count
        else:
            match = re.search(r"frame (\d+)/", s[i])
            frame = int(match[1]) if match else None  # None if frame undetermined

        self.txt_path = self.save_dir / "labels" / (p.stem + ("" if self.dataset.mode == "image" else f"_{frame}"))
        string += "{:g}x{:g} ".format(*im.shape[2:])
        result = self.results[i]
        result.save_dir = self.save_dir.__str__()  # used in other locations
        string += f"{result.verbose()}{result.speed['inference']:.1f}ms"

        # Add predictions to image
        if self.args.save or self.args.show:
            self.plotted_img = result.plot(
                line_width=self.args.line_width,
                boxes=self.args.show_boxes,
                conf=self.args.show_conf,
                labels=self.args.show_labels,
                im_gpu=None if self.args.retina_masks else im[i],
            )

        # Save results
        if self.args.save_txt:
            result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf)
        if self.args.save_crop:
            result.save_crop(save_dir=self.save_dir / "crops", file_name=self.txt_path.stem)
        if self.args.show:
            self.show(str(p))
        if self.args.save:
            self.save_predicted_images(str(self.save_dir / p.name), frame)

        return string

    def save_predicted_images(self, save_path="", frame=0):
        """Save video predictions as mp4 at specified path."""
        im = self.plotted_img

        # Save videos and streams
        if self.dataset.mode in {"stream", "video"}:
            fps = self.dataset.fps if self.dataset.mode == "video" else 30
            frames_path = f"{save_path.split('.', 1)[0]}_frames/"
            if save_path not in self.vid_writer:  # new video
                if self.args.save_frames:
                    Path(frames_path).mkdir(parents=True, exist_ok=True)
                suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG")
                self.vid_writer[save_path] = cv2.VideoWriter(
                    filename=str(Path(save_path).with_suffix(suffix)),
                    fourcc=cv2.VideoWriter_fourcc(*fourcc),
                    fps=fps,  # integer required, floats produce error in MP4 codec
                    frameSize=(im.shape[1], im.shape[0]),  # (width, height)
                )

            # Save video
            self.vid_writer[save_path].write(im)
            if self.args.save_frames:
                cv2.imwrite(f"{frames_path}{frame}.jpg", im)

        # Save images
        else:
            cv2.imwrite(str(Path(save_path).with_suffix(".jpg")), im)  # save to JPG for best support

    def show(self, p=""):
        """Display an image in a window using the OpenCV imshow function."""
        im = self.plotted_img
        if platform.system() == "Linux" and p not in self.windows:
            self.windows.append(p)
            cv2.namedWindow(p, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
            cv2.resizeWindow(p, im.shape[1], im.shape[0])  # (width, height)
        cv2.imshow(p, im)
        cv2.waitKey(300 if self.dataset.mode == "image" else 1)  # 300 ms pause for images, 1 ms for video/stream

    def run_callbacks(self, event: str):
        """Runs all registered callbacks for a specific event."""
        for callback in self.callbacks.get(event, []):
            callback(self)

    def add_callback(self, event: str, func):
        """Add callback."""
        self.callbacks[event].append(func)