exporter.py 68 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. """
  3. Export a YOLO PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit.
  4. Format | `format=argument` | Model
  5. --- | --- | ---
  6. PyTorch | - | yolo11n.pt
  7. TorchScript | `torchscript` | yolo11n.torchscript
  8. ONNX | `onnx` | yolo11n.onnx
  9. OpenVINO | `openvino` | yolo11n_openvino_model/
  10. TensorRT | `engine` | yolo11n.engine
  11. CoreML | `coreml` | yolo11n.mlpackage
  12. TensorFlow SavedModel | `saved_model` | yolo11n_saved_model/
  13. TensorFlow GraphDef | `pb` | yolo11n.pb
  14. TensorFlow Lite | `tflite` | yolo11n.tflite
  15. TensorFlow Edge TPU | `edgetpu` | yolo11n_edgetpu.tflite
  16. TensorFlow.js | `tfjs` | yolo11n_web_model/
  17. PaddlePaddle | `paddle` | yolo11n_paddle_model/
  18. MNN | `mnn` | yolo11n.mnn
  19. NCNN | `ncnn` | yolo11n_ncnn_model/
  20. IMX | `imx` | yolo11n_imx_model/
  21. Requirements:
  22. $ pip install "ultralytics[export]"
  23. Python:
  24. from ultralytics import YOLO
  25. model = YOLO('yolo11n.pt')
  26. results = model.export(format='onnx')
  27. CLI:
  28. $ yolo mode=export model=yolo11n.pt format=onnx
  29. Inference:
  30. $ yolo predict model=yolo11n.pt # PyTorch
  31. yolo11n.torchscript # TorchScript
  32. yolo11n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
  33. yolo11n_openvino_model # OpenVINO
  34. yolo11n.engine # TensorRT
  35. yolo11n.mlpackage # CoreML (macOS-only)
  36. yolo11n_saved_model # TensorFlow SavedModel
  37. yolo11n.pb # TensorFlow GraphDef
  38. yolo11n.tflite # TensorFlow Lite
  39. yolo11n_edgetpu.tflite # TensorFlow Edge TPU
  40. yolo11n_paddle_model # PaddlePaddle
  41. yolo11n.mnn # MNN
  42. yolo11n_ncnn_model # NCNN
  43. yolo11n_imx_model # IMX
  44. TensorFlow.js:
  45. $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
  46. $ npm install
  47. $ ln -s ../../yolo11n_web_model public/yolo11n_web_model
  48. $ npm start
  49. """
  50. import gc
  51. import json
  52. import os
  53. import shutil
  54. import subprocess
  55. import time
  56. import warnings
  57. from copy import deepcopy
  58. from datetime import datetime
  59. from pathlib import Path
  60. import numpy as np
  61. import torch
  62. from ultralytics.cfg import TASK2DATA, get_cfg
  63. from ultralytics.data import build_dataloader
  64. from ultralytics.data.dataset import YOLODataset
  65. from ultralytics.data.utils import check_cls_dataset, check_det_dataset
  66. from ultralytics.nn.autobackend import check_class_names, default_class_names
  67. from ultralytics.nn.modules import C2f, Classify, Detect, RTDETRDecoder
  68. from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel
  69. from ultralytics.utils import (
  70. ARM64,
  71. DEFAULT_CFG,
  72. IS_JETSON,
  73. LINUX,
  74. LOGGER,
  75. MACOS,
  76. PYTHON_VERSION,
  77. ROOT,
  78. WINDOWS,
  79. __version__,
  80. callbacks,
  81. colorstr,
  82. get_default_args,
  83. yaml_save,
  84. )
  85. from ultralytics.utils.checks import (
  86. check_imgsz,
  87. check_is_path_safe,
  88. check_requirements,
  89. check_version,
  90. is_sudo_available,
  91. )
  92. from ultralytics.utils.downloads import attempt_download_asset, get_github_assets, safe_download
  93. from ultralytics.utils.files import file_size, spaces_in_path
  94. from ultralytics.utils.ops import Profile
  95. from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_device
  def export_formats():
      """
      Describe every export format supported by Ultralytics YOLO.

      Returns:
          (dict): Column-oriented table. Keys are "Format", "Argument", "Suffix", "CPU", "GPU" and "Arguments";
              each value is a tuple holding that column's entry for every format, in table order.
      """
      columns = ("Format", "Argument", "Suffix", "CPU", "GPU", "Arguments")
      # Each row: display name, `format=` value, output file/dir suffix, CPU support, GPU support,
      # and the export arguments accepted by that format.
      rows = (
          ("PyTorch", "-", ".pt", True, True, []),
          ("TorchScript", "torchscript", ".torchscript", True, True, ["batch", "optimize"]),
          ("ONNX", "onnx", ".onnx", True, True, ["batch", "dynamic", "half", "opset", "simplify"]),
          ("OpenVINO", "openvino", "_openvino_model", True, False, ["batch", "dynamic", "half", "int8"]),
          ("TensorRT", "engine", ".engine", False, True, ["batch", "dynamic", "half", "int8", "simplify"]),
          ("CoreML", "coreml", ".mlpackage", True, False, ["batch", "half", "int8", "nms"]),
          ("TensorFlow SavedModel", "saved_model", "_saved_model", True, True, ["batch", "int8", "keras"]),
          ("TensorFlow GraphDef", "pb", ".pb", True, True, ["batch"]),
          ("TensorFlow Lite", "tflite", ".tflite", True, False, ["batch", "half", "int8"]),
          ("TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False, []),
          ("TensorFlow.js", "tfjs", "_web_model", True, False, ["batch", "half", "int8"]),
          ("PaddlePaddle", "paddle", "_paddle_model", True, True, ["batch"]),
          ("MNN", "mnn", ".mnn", True, True, ["batch", "half", "int8"]),
          ("NCNN", "ncnn", "_ncnn_model", True, True, ["batch", "half"]),
          ("IMX", "imx", "_imx_model", True, True, ["int8"]),
      )
      return {name: tuple(row[col] for row in rows) for col, name in enumerate(columns)}
  def validate_args(format, passed_args, valid_args):
      """
      Validate that export arguments changed from their defaults are supported by the chosen format.

      Only the arguments "half", "int8", "dynamic", "keras", "nms" and "batch" are checked. An argument counts as
      used when its value in `passed_args` differs from the exporter defaults (DEFAULT_CFG with batch=1,
      data=None, device=None).

      Args:
          format (str): The export format, used in error messages.
          passed_args (Namespace): The arguments used during export.
          valid_args (list | tuple): Names of the arguments supported by `format`.

      Raises:
          AssertionError: If `valid_args` is None, or if a checked argument differs from its default but is not
              listed in `valid_args`. Raised explicitly (not via `assert`) so validation survives `python -O`.
      """
      if valid_args is None:
          raise AssertionError(f"ERROR ❌️ valid arguments for '{format}' not listed.")
      # Only check valid usage of these args
      export_args = ["half", "int8", "dynamic", "keras", "nms", "batch"]
      custom = {"batch": 1, "data": None, "device": None}  # exporter defaults
      default_args = get_cfg(DEFAULT_CFG, custom)
      for arg in export_args:
          not_default = getattr(passed_args, arg, None) != getattr(default_args, arg, None)
          if not_default and arg not in valid_args:
              raise AssertionError(f"ERROR ❌️ argument '{arg}' is not supported for format='{format}'")
  def gd_outputs(gd):
      """
      Return the output tensor names of a TensorFlow GraphDef.

      A node is an output when no other node consumes it. Input references may carry an output index
      ("name:1") or mark a control dependency ("^name"). Both forms are normalized to the bare node name, so
      such consumers are recognized. Nodes whose names start with "NoOp" are never reported.

      Args:
          gd: GraphDef-like object whose `node` entries expose `name` and `input`.

      Returns:
          (list[str]): Sorted tensor names of the form "<node>:0".
      """
      names = set()
      consumed = set()
      for node in gd.node:  # tensorflow.core.framework.node_def_pb2.NodeDef
          names.add(node.name)
          for ref in node.input:
              # "^foo" -> "foo" (control input), "foo:1" -> "foo" (non-default output index)
              consumed.add(ref.lstrip("^").split(":")[0])
      return sorted(f"{x}:0" for x in names - consumed if not x.startswith("NoOp"))
  def try_export(inner_func):
      """
      Decorate an exporter method so its outcome and duration are logged, i.e. @try_export.

      The decorated function must declare a `prefix` parameter with a default value. That default (not a value
      passed at call time) is used as the log prefix. On success the exported path and its size are logged and
      the `(f, model)` result is returned unchanged. On failure the error is logged and the original exception
      is re-raised with its traceback intact.

      Args:
          inner_func (Callable): Export method returning a `(file, model)` tuple.

      Returns:
          (Callable): Wrapper that preserves `inner_func`'s name and docstring.
      """
      from functools import wraps

      inner_args = get_default_args(inner_func)

      @wraps(inner_func)
      def outer_func(*args, **kwargs):
          """Export a model."""
          prefix = inner_args["prefix"]
          try:
              with Profile() as dt:
                  f, model = inner_func(*args, **kwargs)
              LOGGER.info(f"{prefix} export success ✅ {dt.t:.1f}s, saved as '{f}' ({file_size(f):.1f} MB)")
              return f, model
          except Exception as e:
              LOGGER.error(f"{prefix} export failure ❌ {dt.t:.1f}s: {e}")
              raise

      return outer_func
  157. class Exporter:
  158. """
  159. A class for exporting a model.
  160. Attributes:
  161. args (SimpleNamespace): Configuration for the exporter.
  162. callbacks (list, optional): List of callback functions. Defaults to None.
  163. """
  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
      """
      Initialize the Exporter.

      Args:
          cfg (str, optional): Configuration passed to `get_cfg`. Defaults to DEFAULT_CFG.
          overrides (dict, optional): Configuration overrides. Defaults to None.
          _callbacks (dict, optional): Dictionary of callback functions. Defaults to None, in which case
              `callbacks.get_default_callbacks()` is used.
      """
      self.args = get_cfg(cfg, overrides)
      # Fix attempt for protobuf<3.20.x errors. This must run before the integration (TensorBoard) callbacks
      # are added below. The set matches every alias that `__call__` resolves to format='coreml'.
      if self.args.format.lower() in {"mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"}:
          os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
      self.callbacks = _callbacks or callbacks.get_default_callbacks()
      callbacks.add_integration_callbacks(self)
  def __call__(self, model=None) -> str:
      """
      Export `model` to the format requested in `self.args.format`, running export callbacks around it.

      Format resolution:
      - The requested format is lower-cased.
      - Aliases are resolved: "tensorrt"/"trt" -> "engine"; "mlmodel"/"mlpackage"/"mlprogram"/"apple"/"ios"
        -> "coreml".
      - An unknown name is replaced by the closest valid format when one is at least 60% similar.

      The resolved name is what the export heads receive as `m.format`.

      Args:
          model (torch.nn.Module): The PyTorch model to export.

      Returns:
          (str | list): Path of the last exported file/directory, or an empty list if nothing was exported.

      Raises:
          ValueError: If the format is unknown and has no close match.
      """
      self.run_callbacks("on_export_start")
      t = time.time()
      fmt = self.args.format.lower()  # to lowercase
      if fmt in {"tensorrt", "trt"}:  # 'engine' aliases
          fmt = "engine"
      if fmt in {"mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"}:  # 'coreml' aliases
          fmt = "coreml"
      fmts_dict = export_formats()
      fmts = tuple(fmts_dict["Argument"][1:])  # available export formats
      if fmt not in fmts:
          import difflib

          # Get the closest match if format is invalid
          matches = difflib.get_close_matches(fmt, fmts, n=1, cutoff=0.6)  # 60% similarity required to match
          if not matches:
              raise ValueError(f"Invalid export format='{fmt}'. Valid formats are {fmts}")
          LOGGER.warning(f"WARNING ⚠️ Invalid export format='{fmt}', updating to format='{matches[0]}'")
          fmt = matches[0]
      flags = [x == fmt for x in fmts]
      if sum(flags) != 1:
          raise ValueError(f"Invalid export format='{fmt}'. Valid formats are {fmts}")
      (
          jit,
          onnx,
          xml,
          engine,
          coreml,
          saved_model,
          pb,
          tflite,
          edgetpu,
          tfjs,
          paddle,
          mnn,
          ncnn,
          imx,
      ) = flags  # export booleans
      is_tf_format = any((saved_model, pb, tflite, edgetpu, tfjs))

      # Device
      dla = None
      if fmt == "engine" and self.args.device is None:
          LOGGER.warning("WARNING ⚠️ TensorRT requires GPU export, automatically assigning device=0")
          self.args.device = "0"
      if fmt == "engine" and "dla" in str(self.args.device):
          requested = str(self.args.device)  # convert int/list to str first
          dla = requested.split(":")[-1]
          self.args.device = "0"  # update device to "0"
          # Report the user's original value; self.args.device has already been overwritten above
          assert dla in {"0", "1"}, f"Expected self.args.device='dla:0' or 'dla:1', but got {requested}."
      self.device = select_device("cpu" if self.args.device is None else self.args.device)

      # Argument compatibility checks
      fmt_keys = fmts_dict["Arguments"][flags.index(True) + 1]
      validate_args(fmt, self.args, fmt_keys)
      if imx and not self.args.int8:
          LOGGER.warning("WARNING ⚠️ IMX only supports int8 export, setting int8=True.")
          self.args.int8 = True
      if not hasattr(model, "names"):
          model.names = default_class_names()
      model.names = check_class_names(model.names)
      if self.args.half and self.args.int8:
          LOGGER.warning("WARNING ⚠️ half=True and int8=True are mutually exclusive, setting half=False.")
          self.args.half = False
      if self.args.half and onnx and self.device.type == "cpu":
          LOGGER.warning("WARNING ⚠️ half=True only compatible with GPU export, i.e. use device=0")
          self.args.half = False
          assert not self.args.dynamic, "half=True not compatible with dynamic=True, i.e. use only one."
      self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2)  # check image size
      if self.args.int8 and engine:
          self.args.dynamic = True  # enforce dynamic to export TensorRT INT8
      if self.args.optimize:
          assert not ncnn, "optimize=True not compatible with format='ncnn', i.e. use optimize=False"
          assert self.device.type == "cpu", "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
      if self.args.int8 and tflite:
          assert not getattr(model, "end2end", False), "TFLite INT8 export not supported for end2end models."
      if edgetpu:
          if not LINUX:
              raise SystemError("Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler")
          elif self.args.batch != 1:  # see github.com/ultralytics/ultralytics/pull/13420
              LOGGER.warning("WARNING ⚠️ Edge TPU export requires batch size 1, setting batch=1.")
              self.args.batch = 1
      if isinstance(model, WorldModel):
          LOGGER.warning(
              "WARNING ⚠️ YOLOWorld (original version) export is not supported to any format.\n"
              "WARNING ⚠️ YOLOWorldv2 models (i.e. 'yolov8s-worldv2.pt') only support export to "
              "(torchscript, onnx, openvino, engine, coreml) formats. "
              "See https://docs.ultralytics.com/models/yolo-world for details."
          )
          model.clip_model = None  # openvino int8 export error: https://github.com/ultralytics/ultralytics/pull/18445
      if self.args.int8 and not self.args.data:
          self.args.data = DEFAULT_CFG.data or TASK2DATA[getattr(model, "task", "detect")]  # assign default data
          LOGGER.warning(
              "WARNING ⚠️ INT8 export requires a missing 'data' arg for calibration. "
              f"Using default 'data={self.args.data}'."
          )

      # Input
      im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
      file = Path(
          getattr(model, "pt_path", None) or getattr(model, "yaml_file", None) or model.yaml.get("yaml_file", "")
      )
      if file.suffix in {".yaml", ".yml"}:
          file = Path(file.name)

      # Update model (work on a copy so the caller's model is not modified)
      model = deepcopy(model).to(self.device)
      for p in model.parameters():
          p.requires_grad = False
      model.eval()
      model.float()
      model = model.fuse()

      if imx:
          from ultralytics.utils.torch_utils import FXModel

          model = FXModel(model)
      for m in model.modules():
          if isinstance(m, Classify):
              m.export = True
          if isinstance(m, (Detect, RTDETRDecoder)):  # includes all Detect subclasses like Segment, Pose, OBB
              m.dynamic = self.args.dynamic
              m.export = True
              # Pass the resolved name, not the raw user spelling (e.g. 'TFLite', 'trt' or a corrected typo)
              m.format = fmt
              m.max_det = self.args.max_det
          elif isinstance(m, C2f) and not is_tf_format:
              # EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
              m.forward = m.forward_split
          if isinstance(m, Detect) and imx:
              from ultralytics.utils.tal import make_anchors

              m.anchors, m.strides = (
                  x.transpose(0, 1)
                  for x in make_anchors(
                      torch.cat([s / m.stride.unsqueeze(-1) for s in self.imgsz], dim=1), m.stride, 0.5
                  )
              )

      y = None
      for _ in range(2):
          y = model(im)  # dry runs
      if self.args.half and onnx and self.device.type != "cpu":
          im, model = im.half(), model.half()  # to FP16

      # Filter warnings
      warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)  # suppress TracerWarning
      warnings.filterwarnings("ignore", category=UserWarning)  # suppress shape prim::Constant missing ONNX warning
      warnings.filterwarnings("ignore", category=DeprecationWarning)  # suppress CoreML np.bool deprecation warning

      # Assign
      self.im = im
      self.model = model
      self.file = file
      self.output_shape = (
          tuple(y.shape)
          if isinstance(y, torch.Tensor)
          else tuple(tuple(x.shape if isinstance(x, torch.Tensor) else []) for x in y)
      )
      self.pretty_name = Path(self.model.yaml.get("yaml_file", self.file)).stem.replace("yolo", "YOLO")
      data = model.args["data"] if hasattr(model, "args") and isinstance(model.args, dict) else ""
      description = f"Ultralytics {self.pretty_name} model {f'trained on {data}' if data else ''}"
      self.metadata = {
          "description": description,
          "author": "Ultralytics",
          "date": datetime.now().isoformat(),
          "version": __version__,
          "license": "AGPL-3.0 License (https://ultralytics.com/license)",
          "docs": "https://docs.ultralytics.com",
          "stride": int(max(model.stride)),
          "task": model.task,
          "batch": self.args.batch,
          "imgsz": self.imgsz,
          "names": model.names,
          "args": {k: v for k, v in self.args if k in fmt_keys},
      }  # model metadata
      if model.task == "pose":
          self.metadata["kpt_shape"] = model.model[-1].kpt_shape

      LOGGER.info(
          f"\n{colorstr('PyTorch:')} starting from '{file}' with input shape {tuple(im.shape)} BCHW and "
          f"output shape(s) {self.output_shape} ({file_size(file):.1f} MB)"
      )

      # Exports
      f = [""] * len(fmts)  # exported filenames
      if jit or ncnn:  # TorchScript (NCNN is converted from TorchScript)
          f[0], _ = self.export_torchscript()
      if engine:  # TensorRT required before ONNX
          f[1], _ = self.export_engine(dla=dla)
      if onnx:  # ONNX
          f[2], _ = self.export_onnx()
      if xml:  # OpenVINO
          f[3], _ = self.export_openvino()
      if coreml:  # CoreML
          f[4], _ = self.export_coreml()
      if is_tf_format:  # TensorFlow formats
          self.args.int8 |= edgetpu
          f[5], keras_model = self.export_saved_model()
          if pb or tfjs:  # pb prerequisite to tfjs
              f[6], _ = self.export_pb(keras_model=keras_model)
          if tflite:
              f[7], _ = self.export_tflite(keras_model=keras_model, nms=False, agnostic_nms=self.args.agnostic_nms)
          if edgetpu:
              f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f"{self.file.stem}_full_integer_quant.tflite")
          if tfjs:
              f[9], _ = self.export_tfjs()
      if paddle:  # PaddlePaddle
          f[10], _ = self.export_paddle()
      if mnn:  # MNN
          f[11], _ = self.export_mnn()
      if ncnn:  # NCNN
          f[12], _ = self.export_ncnn()
      if imx:
          f[13], _ = self.export_imx()

      # Finish
      f = [str(x) for x in f if x]  # filter out '' and None
      if any(f):
          f = str(Path(f[-1]))
          square = self.imgsz[0] == self.imgsz[1]
          s = (
              ""
              if square
              else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not "
              f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
          )
          imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(" ", "")
          predict_data = f"data={data}" if model.task == "segment" and fmt == "pb" else ""
          q = "int8" if self.args.int8 else "half" if self.args.half else ""  # quantization
          LOGGER.info(
              f"\nExport complete ({time.time() - t:.1f}s)"
              f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
              f"\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {q} {predict_data}"
              f"\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={data} {q} {s}"
              f"\nVisualize: https://netron.app"
          )

      self.run_callbacks("on_export_end")
      return f  # path of the last exported file/dir, or [] if nothing was exported
  def get_int8_calibration_dataloader(self, prefix=""):
      """
      Build a dataloader of calibration images for INT8 quantization.

      Images come from the `split` (default "val") of the dataset in `self.args.data`. The dataset is checked
      with `check_cls_dataset` for classify models and with `check_det_dataset` otherwise. TensorRT exports
      calibrate with twice the export batch size.

      Args:
          prefix (str): Log prefix.

      Returns:
          Dataloader produced by `build_dataloader` with `workers=0`.

      Raises:
          ValueError: If the dataset contains fewer images than `self.args.batch`.
      """
      LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
      data = (check_cls_dataset if self.model.task == "classify" else check_det_dataset)(self.args.data)
      # TensorRT INT8 calibration should use 2x batch size. `self.args.format` holds the user's spelling
      # (e.g. 'TensorRT' or 'trt'), so accept the same aliases `__call__` maps to 'engine'.
      is_engine = self.args.format.lower() in {"engine", "tensorrt", "trt"}
      batch = self.args.batch * (2 if is_engine else 1)
      dataset = YOLODataset(
          data[self.args.split or "val"],
          data=data,
          task=self.model.task,
          imgsz=self.imgsz[0],
          augment=False,
          batch_size=batch,
      )
      n = len(dataset)
      if n < self.args.batch:
          raise ValueError(
              f"The calibration dataset ({n} images) must have at least as many images as the batch size "
              f"('batch={self.args.batch}')."
          )
      elif n < 300:
          LOGGER.warning(f"{prefix} WARNING ⚠️ >300 images recommended for INT8 calibration, found {n} images.")
      return build_dataloader(dataset, batch=batch, workers=0)  # required for batch loading
  @try_export
  def export_torchscript(self, prefix=colorstr("TorchScript:")):
      """
      Export the model to TorchScript by tracing it on `self.im`.

      The export metadata is embedded as a JSON 'config.txt' extra file. With `optimize=True` the traced
      module is passed through `optimize_for_mobile` and saved for the PyTorch Lite interpreter.

      Args:
          prefix (str): Log prefix.

      Returns:
          (tuple): Path of the '.torchscript' file and None.
      """
      LOGGER.info(f"\n{prefix} starting export with torch {torch.__version__}...")
      out_path = self.file.with_suffix(".torchscript")
      traced = torch.jit.trace(self.model, self.im, strict=False)
      extras = {"config.txt": json.dumps(self.metadata)}  # torch._C.ExtraFilesMap()

      if not self.args.optimize:
          traced.save(str(out_path), _extra_files=extras)
          return out_path, None

      # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
      LOGGER.info(f"{prefix} optimizing for mobile...")
      from torch.utils.mobile_optimizer import optimize_for_mobile

      optimize_for_mobile(traced)._save_for_lite_interpreter(str(out_path), _extra_files=extras)
      return out_path, None
  @try_export
  def export_onnx(self, prefix=colorstr("ONNX:")):
      """
      Export the model to ONNX, optionally slimming it with onnxslim, and embed the export metadata.

      With `dynamic=True`, the following axes are marked dynamic, and the model and example input are moved to
      CPU for the export:
      - batch/height/width of the input;
      - batch/anchors of "output0" for detection and segmentation models;
      - batch/mask dimensions of "output1" for segmentation models.

      Args:
          prefix (str): Log prefix.

      Returns:
          (tuple): Path of the '.onnx' file (str) and the loaded, metadata-annotated ONNX ModelProto.
      """
      reqs = ["onnx>=1.12.0"]
      if self.args.simplify:
          runtime = "onnxruntime-gpu" if torch.cuda.is_available() else "onnxruntime"
          reqs.extend(["onnxslim", runtime])
      check_requirements(reqs)
      import onnx  # noqa

      opset = self.args.opset or get_latest_opset()
      LOGGER.info(f"\n{prefix} starting export with onnx {onnx.__version__} opset {opset}...")
      f = str(self.file.with_suffix(".onnx"))
      is_seg = isinstance(self.model, SegmentationModel)
      output_names = ["output0", "output1"] if is_seg else ["output0"]

      dynamic_axes = None
      if self.args.dynamic:
          dynamic_axes = {"images": {0: "batch", 2: "height", 3: "width"}}  # shape(1,3,640,640)
          if is_seg:
              dynamic_axes["output0"] = {0: "batch", 2: "anchors"}  # shape(1, 116, 8400)
              dynamic_axes["output1"] = {0: "batch", 2: "mask_height", 3: "mask_width"}  # shape(1,32,160,160)
          elif isinstance(self.model, DetectionModel):
              dynamic_axes["output0"] = {0: "batch", 2: "anchors"}  # shape(1, 84, 8400)

      # dynamic=True only compatible with cpu
      export_model = self.model.cpu() if dynamic_axes else self.model
      example_input = self.im.cpu() if dynamic_axes else self.im
      torch.onnx.export(
          export_model,
          example_input,
          f,
          verbose=False,
          opset_version=opset,
          do_constant_folding=True,  # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
          input_names=["images"],
          output_names=output_names,
          dynamic_axes=dynamic_axes,
      )

      model_onnx = onnx.load(f)  # load onnx model

      # Simplify (best effort: failures are logged and the unsimplified model is kept)
      if self.args.simplify:
          try:
              import onnxslim

              LOGGER.info(f"{prefix} slimming with onnxslim {onnxslim.__version__}...")
              model_onnx = onnxslim.slim(model_onnx)
          except Exception as e:
              LOGGER.warning(f"{prefix} simplifier failure: {e}")

      # Metadata: every value is stored as its str() representation
      for key, value in self.metadata.items():
          prop = model_onnx.metadata_props.add()
          prop.key, prop.value = key, str(value)

      onnx.save(model_onnx, f)
      return f, model_onnx
  @try_export
  def export_openvino(self, prefix=colorstr("OpenVINO:")):
      """
      Export the model to OpenVINO IR, optionally INT8-quantized with NNCF.

      Output location:
      - '<stem>_openvino_model/' in the source file's directory, or '<stem>_int8_openvino_model/' when
        `int8=True`.
      - The directory holds '<stem>.xml' plus 'metadata.yaml'.
      - The name is built from the file's parent and stem only, so parent directories that happen to contain
        the file suffix are never rewritten.

      Args:
          prefix (str): Log prefix.

      Returns:
          (tuple): Output directory path (str ending with `os.sep`) and None.
      """
      check_requirements("openvino>=2024.5.0")
      import openvino as ov

      LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...")
      assert TORCH_1_13, f"OpenVINO export requires torch>=1.13.0 but torch=={torch.__version__} is installed"
      ov_model = ov.convert_model(
          self.model,
          input=None if self.args.dynamic else [self.im.shape],
          example_input=self.im,
      )

      def output_dir(tag):
          """Return '<parent>/<stem><tag>' plus a trailing separator, built from the source file's stem."""
          return str(self.file.parent / f"{self.file.stem}{tag}") + os.sep

      def serialize(ov_model, file):
          """Set RT info, serialize and save metadata YAML."""
          ov_model.set_rt_info("YOLO", ["model_info", "model_type"])
          ov_model.set_rt_info(True, ["model_info", "reverse_input_channels"])
          ov_model.set_rt_info(114, ["model_info", "pad_value"])
          ov_model.set_rt_info([255.0], ["model_info", "scale_values"])
          ov_model.set_rt_info(self.args.iou, ["model_info", "iou_threshold"])
          ov_model.set_rt_info([v.replace(" ", "_") for v in self.model.names.values()], ["model_info", "labels"])
          if self.model.task != "classify":
              ov_model.set_rt_info("fit_to_window_letterbox", ["model_info", "resize_type"])
          ov.runtime.save_model(ov_model, file, compress_to_fp16=self.args.half)
          yaml_save(Path(file).parent / "metadata.yaml", self.metadata)  # add metadata.yaml

      if self.args.int8:
          fq = output_dir("_int8_openvino_model")
          fq_ov = str(Path(fq) / self.file.with_suffix(".xml").name)
          check_requirements("nncf>=2.14.0")
          import nncf

          def transform_fn(data_item) -> np.ndarray:
              """Quantization transform function."""
              data_item: torch.Tensor = data_item["img"] if isinstance(data_item, dict) else data_item
              assert data_item.dtype == torch.uint8, "Input image must be uint8 for the quantization preprocessing"
              im = data_item.numpy().astype(np.float32) / 255.0  # uint8 to fp16/32 and 0 - 255 to 0.0 - 1.0
              return np.expand_dims(im, 0) if im.ndim == 3 else im

          # Generate calibration data for integer quantization
          ignored_scope = None
          if isinstance(self.model.model[-1], Detect):
              # Includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
              head_module_name = ".".join(list(self.model.named_modules())[-1][0].split(".")[:2])
              ignored_scope = nncf.IgnoredScope(  # ignore operations
                  patterns=[
                      f".*{head_module_name}/.*/Add",
                      f".*{head_module_name}/.*/Sub*",
                      f".*{head_module_name}/.*/Mul*",
                      f".*{head_module_name}/.*/Div*",
                      f".*{head_module_name}\\.dfl.*",
                  ],
                  types=["Sigmoid"],
              )

          quantized_ov_model = nncf.quantize(
              model=ov_model,
              calibration_dataset=nncf.Dataset(self.get_int8_calibration_dataloader(prefix), transform_fn),
              preset=nncf.QuantizationPreset.MIXED,
              ignored_scope=ignored_scope,
          )
          serialize(quantized_ov_model, fq_ov)
          return fq, None

      f = output_dir("_openvino_model")
      f_ov = str(Path(f) / self.file.with_suffix(".xml").name)
      serialize(ov_model, f_ov)
      return f, None
  @try_export
  def export_paddle(self, prefix=colorstr("PaddlePaddle:")):
      """
      Export the model to PaddlePaddle format with X2Paddle (trace mode).

      Output goes to '<stem>_paddle_model/' in the source file's directory, with a 'metadata.yaml'. The name is
      built from the file's parent and stem only, so parent directories that happen to contain the file suffix
      are never rewritten.

      Args:
          prefix (str): Log prefix.

      Returns:
          (tuple): Output directory path (str ending with `os.sep`) and None.
      """
      check_requirements(("paddlepaddle-gpu" if torch.cuda.is_available() else "paddlepaddle", "x2paddle"))
      import x2paddle  # noqa
      from x2paddle.convert import pytorch2paddle  # noqa

      LOGGER.info(f"\n{prefix} starting export with X2Paddle {x2paddle.__version__}...")
      f = str(self.file.parent / f"{self.file.stem}_paddle_model") + os.sep
      pytorch2paddle(module=self.model, save_dir=f, jit_type="trace", input_examples=[self.im])  # export
      yaml_save(Path(f) / "metadata.yaml", self.metadata)  # add metadata.yaml
      return f, None
  @try_export
  def export_mnn(self, prefix=colorstr("MNN:")):
      """
      Export the model to MNN (https://github.com/alibaba/MNN) by converting an intermediate ONNX export.

      The export metadata is passed to `mnnconvert` as JSON via `--bizCode`. `int8=True` adds 8-bit weight
      quantization and `half=True` adds FP16 conversion. A '.__convert_external_data.bin' scratch file left
      next to the source file is removed afterwards.

      Args:
          prefix (str): Log prefix.

      Returns:
          (tuple): Path of the '.mnn' file (str) and None.
      """
      onnx_path, _ = self.export_onnx()  # get onnx model first
      check_requirements("MNN>=2.9.6")
      import MNN  # noqa
      from MNN.tools import mnnconvert

      # Setup and checks
      LOGGER.info(f"\n{prefix} starting export with MNN {MNN.version()}...")
      assert Path(onnx_path).exists(), f"failed to export ONNX file: {onnx_path}"
      mnn_path = str(self.file.with_suffix(".mnn"))  # MNN model file

      # Leading "" presumably stands in for argv[0] of the converter CLI — TODO confirm against MNN docs
      cli = ["", "-f", "ONNX", "--modelFile", onnx_path, "--MNNModel", mnn_path, "--bizCode", json.dumps(self.metadata)]
      if self.args.int8:
          cli += ["--weightQuantBits", "8"]
      if self.args.half:
          cli += ["--fp16"]
      mnnconvert.convert(cli)

      # remove scratch file for model convert optimize
      scratch = self.file.parent / ".__convert_external_data.bin"
      if scratch.exists():
          scratch.unlink()
      return mnn_path, None
  @try_export
  def export_ncnn(self, prefix=colorstr("NCNN:")):
      """
      Export the model to NCNN format using PNNX (https://github.com/pnnx/pnnx).

      The TorchScript file next to `self.file` (``.torchscript`` suffix) is converted by the PNNX binary. If that
      binary is found neither in the current working directory nor in ROOT, the latest release for this platform
      is downloaded from GitHub and moved into ROOT.

      Args:
          prefix (str): Colored log prefix.

      Returns:
          (tuple): Output directory path as str and None.
      """
      check_requirements("ncnn")
      import ncnn  # noqa

      LOGGER.info(f"\n{prefix} starting export with NCNN {ncnn.__version__}...")
      f = Path(str(self.file).replace(self.file.suffix, f"_ncnn_model{os.sep}"))
      f_ts = self.file.with_suffix(".torchscript")

      name = Path("pnnx.exe" if WINDOWS else "pnnx")  # PNNX filename
      pnnx = name if name.is_file() else (ROOT / name)
      if not pnnx.is_file():
          LOGGER.warning(
              f"{prefix} WARNING ⚠️ PNNX not found. Attempting to download binary file from "
              "https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory "
              f"or in {ROOT}. See PNNX repo for full installation instructions."
          )
          system = "macos" if MACOS else "windows" if WINDOWS else "linux-aarch64" if ARM64 else "linux"
          try:
              release, assets = get_github_assets(repo="pnnx/pnnx")
              asset = [x for x in assets if f"{system}.zip" in x][0]
              assert isinstance(asset, str), "Unable to retrieve PNNX repo assets"  # i.e. pnnx-20240410-macos.zip
              LOGGER.info(f"{prefix} successfully found latest PNNX asset file {asset}")
          except Exception as e:
              # Fall back to a pinned release if the GitHub API lookup fails for any reason
              release = "20240410"
              asset = f"pnnx-{release}-{system}.zip"
              LOGGER.warning(f"{prefix} WARNING ⚠️ PNNX GitHub assets not found: {e}, using default {asset}")
          unzip_dir = safe_download(f"https://github.com/pnnx/pnnx/releases/download/{release}/{asset}", delete=True)
          if check_is_path_safe(Path.cwd(), unzip_dir):  # avoid path traversal security vulnerability
              shutil.move(src=unzip_dir / name, dst=pnnx)  # move binary to ROOT
              # rwxr-xr-x: executable by everyone but writable only by the owner. A world-writable binary
              # would let any local user replace a program that this process later executes.
              pnnx.chmod(0o755)
              shutil.rmtree(unzip_dir)  # delete unzip dir

      ncnn_args = [
          f"ncnnparam={f / 'model.ncnn.param'}",
          f"ncnnbin={f / 'model.ncnn.bin'}",
          f"ncnnpy={f / 'model_ncnn.py'}",
      ]
      pnnx_args = [
          f"pnnxparam={f / 'model.pnnx.param'}",
          f"pnnxbin={f / 'model.pnnx.bin'}",
          f"pnnxpy={f / 'model_pnnx.py'}",
          f"pnnxonnx={f / 'model.pnnx.onnx'}",
      ]
      cmd = [
          str(pnnx),
          str(f_ts),
          *ncnn_args,
          *pnnx_args,
          f"fp16={int(self.args.half)}",
          f"device={self.device.type}",
          f'inputshape="{[self.args.batch, 3, *self.imgsz]}"',
      ]
      f.mkdir(exist_ok=True)  # make ncnn_model directory
      LOGGER.info(f"{prefix} running '{' '.join(cmd)}'")
      subprocess.run(cmd, check=True)

      # Remove debug files (relative names resolve against the current working directory) and the
      # intermediate PNNX outputs listed in pnnx_args
      pnnx_files = [x.split("=")[-1] for x in pnnx_args]
      for f_debug in ("debug.bin", "debug.param", "debug2.bin", "debug2.param", *pnnx_files):
          Path(f_debug).unlink(missing_ok=True)

      yaml_save(f / "metadata.yaml", self.metadata)  # add metadata.yaml
      return str(f), None
  @try_export
  def export_coreml(self, prefix=colorstr("CoreML:")):
      """
      Export the model to Apple CoreML.

      Writes a legacy *.mlmodel (NeuralNetwork backend, coremltools 6.x) when format='mlmodel'. Otherwise it writes
      an *.mlpackage (ML Program backend, coremltools>=7). With nms=True, Detect models are wrapped in
      IOSDetectModel and pipelined with an NMS stage by `_pipeline_coreml`. Classify models instead receive a
      ClassifierConfig built from the class names.

      Args:
          prefix (str): Colored log prefix.

      Returns:
          (tuple): Saved file path (Path) and the coremltools model object.
      """
      legacy = self.args.format.lower() == "mlmodel"  # legacy *.mlmodel export format requested
      check_requirements("coremltools>=6.0,<=6.2" if legacy else "coremltools>=7.0")
      import coremltools as ct  # noqa

      LOGGER.info(f"\n{prefix} starting export with coremltools {ct.__version__}...")
      assert not WINDOWS, "CoreML export is not supported on Windows, please run on macOS or Linux."
      assert self.args.batch == 1, "CoreML batch sizes > 1 are not supported. Please retry at 'batch=1'."
      f = self.file.with_suffix(".mlmodel" if legacy else ".mlpackage")
      if f.is_dir():
          shutil.rmtree(f)  # an *.mlpackage is a directory; clear any previous export
      if self.args.nms and getattr(self.model, "end2end", False):
          LOGGER.warning(f"{prefix} WARNING ⚠️ 'nms=True' is not available for end2end models. Forcing 'nms=False'.")
          self.args.nms = False

      # Choose the module to trace and any classifier configuration, depending on the task
      task = self.model.task
      use_nms = self.args.nms
      model = self.model
      classifier_config = None
      if task == "classify":
          if use_nms:
              classifier_config = ct.ClassifierConfig(list(self.model.names.values()))
      elif task == "detect":
          if use_nms:
              model = IOSDetectModel(self.model, self.im)
      elif use_nms:
          LOGGER.warning(f"{prefix} WARNING ⚠️ 'nms=True' is only available for Detect models like 'yolov8n.pt'.")
          # TODO CoreML Segment and Pose model pipelining

      # Trace to TorchScript, then convert; the image input is scaled by 1/255 with zero bias
      ts = torch.jit.trace(model.eval(), self.im, strict=False)
      image_input = ct.ImageType("image", shape=self.im.shape, scale=1 / 255, bias=[0.0, 0.0, 0.0])
      ct_model = ct.convert(
          ts,
          inputs=[image_input],
          classifier_config=classifier_config,
          convert_to="neuralnetwork" if legacy else "mlprogram",
      )

      # Optional weight quantization
      if self.args.int8:
          bits, mode = 8, "kmeans"
      elif self.args.half:
          bits, mode = 16, "linear"
      else:
          bits, mode = 32, None
      if bits < 32:
          if mode == "kmeans":
              check_requirements("scikit-learn")  # scikit-learn package required for k-means quantization
          if legacy:
              ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
          elif bits == 8:  # mlprogram already quantized to FP16, so only INT8 needs an extra step
              import coremltools.optimize.coreml as cto

              palettizer = cto.OpPalettizerConfig(mode="kmeans", nbits=bits, weight_threshold=512)
              ct_model = cto.palettize_weights(ct_model, config=cto.OptimizationConfig(global_config=palettizer))

      if use_nms and task == "detect":
          if legacy:
              # coremltools<=6.2 NMS export requires Python<3.11
              check_version(PYTHON_VERSION, "<3.11", name="Python ", hard=True)
              weights_dir = None
          else:
              ct_model.save(str(f))  # save first, otherwise weights_dir does not exist
              weights_dir = str(f / "Data/com.apple.CoreML/weights")
          ct_model = self._pipeline_coreml(ct_model, weights_dir=weights_dir)

      # Known fields go to dedicated model attributes; pop() removes them from self.metadata (no copy is taken)
      # so that only the remaining keys end up in user_defined_metadata
      meta = self.metadata
      for attr, key in (
          ("short_description", "description"),
          ("author", "author"),
          ("license", "license"),
          ("version", "version"),
      ):
          setattr(ct_model, attr, meta.pop(key))
      ct_model.user_defined_metadata.update({k: str(v) for k, v in meta.items()})

      try:
          ct_model.save(str(f))  # save *.mlpackage
      except Exception as e:
          LOGGER.warning(
              f"{prefix} WARNING ⚠️ CoreML export to *.mlpackage failed ({e}), reverting to *.mlmodel export. "
              f"Known coremltools Python 3.11 and Windows bugs https://github.com/apple/coremltools/issues/1928."
          )
          f = f.with_suffix(".mlmodel")
          ct_model.save(str(f))
      return f, ct_model
  @try_export
  def export_engine(self, dla=None, prefix=colorstr("TensorRT:")):
      """
      Export the model to a serialized NVIDIA TensorRT engine (https://developer.nvidia.com/tensorrt).

      An ONNX file is exported first and parsed by TensorRT. The *.engine file starts with a 4-byte little-endian
      signed length, followed by that many bytes of JSON metadata, followed by the serialized engine.

      Args:
          dla (int | str | None): DLA core to target (NVIDIA Jetson only; must be convertible to int). None
              builds for the GPU.
          prefix (str): Colored log prefix.

      Returns:
          (tuple): Engine file path (Path) and None.

      Raises:
          ValueError: If DLA is requested on a non-Jetson device or without half/int8 precision.
          RuntimeError: If the ONNX file cannot be parsed or TensorRT fails to build the engine.
      """
      assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. use 'device=0'"
      f_onnx, _ = self.export_onnx()  # run before TRT import https://github.com/ultralytics/ultralytics/issues/7016

      try:
          import tensorrt as trt  # noqa
      except ImportError:
          if LINUX:
              check_requirements("tensorrt>7.0.0,!=10.1.0")
          import tensorrt as trt  # noqa
      check_version(trt.__version__, ">=7.0.0", hard=True)
      check_version(trt.__version__, "!=10.1.0", msg="https://github.com/ultralytics/ultralytics/pull/14239")

      # Setup and checks
      LOGGER.info(f"\n{prefix} starting export with TensorRT {trt.__version__}...")
      is_trt10 = int(trt.__version__.split(".")[0]) >= 10  # is TensorRT >= 10
      assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}"
      f = self.file.with_suffix(".engine")  # TensorRT engine file
      logger = trt.Logger(trt.Logger.INFO)
      if self.args.verbose:
          logger.min_severity = trt.Logger.Severity.VERBOSE

      # Engine builder
      builder = trt.Builder(logger)
      config = builder.create_builder_config()
      # self.args.workspace is given in GB; TensorRT expects bytes
      workspace = int(self.args.workspace * (1 << 30)) if self.args.workspace is not None else 0
      if is_trt10 and workspace > 0:
          config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace)
      elif workspace > 0:  # TensorRT versions 7, 8
          config.max_workspace_size = workspace
      flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
      network = builder.create_network(flag)
      half = builder.platform_has_fast_fp16 and self.args.half
      int8 = builder.platform_has_fast_int8 and self.args.int8

      # Optionally switch to DLA if enabled
      if dla is not None:
          if not IS_JETSON:
              raise ValueError("DLA is only available on NVIDIA Jetson devices")
          LOGGER.info(f"{prefix} enabling DLA on core {dla}...")
          if not self.args.half and not self.args.int8:
              raise ValueError(
                  "DLA requires either 'half=True' (FP16) or 'int8=True' (INT8) to be enabled. Please enable one of them and try again."
              )
          config.default_device_type = trt.DeviceType.DLA
          config.DLA_core = int(dla)
          config.set_flag(trt.BuilderFlag.GPU_FALLBACK)

      # Read ONNX file
      parser = trt.OnnxParser(network, logger)
      if not parser.parse_from_file(f_onnx):
          raise RuntimeError(f"failed to load ONNX file: {f_onnx}")

      # Network inputs
      inputs = [network.get_input(i) for i in range(network.num_inputs)]
      outputs = [network.get_output(i) for i in range(network.num_outputs)]
      for inp in inputs:
          LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
      for out in outputs:
          LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')

      profile = None  # optimization profile; only created for dynamic-shape engines
      if self.args.dynamic:
          shape = self.im.shape
          if shape[0] <= 1:
              LOGGER.warning(f"{prefix} WARNING ⚠️ 'dynamic=True' model requires max batch size, i.e. 'batch=16'")
          profile = builder.create_optimization_profile()
          min_shape = (1, shape[1], 32, 32)  # minimum input shape
          # Scale the spatial max dims by the workspace size in GB (never below 1x). Using the byte count here
          # would produce dimensions that overflow TensorRT's int32 dims.
          scale = max(1, self.args.workspace or 1)
          max_shape = (*shape[:2], *(int(scale * d) for d in shape[2:]))  # max input shape
          for inp in inputs:
              profile.set_shape(inp.name, min=min_shape, opt=shape, max=max_shape)
          config.add_optimization_profile(profile)

      LOGGER.info(f"{prefix} building {'INT8' if int8 else 'FP' + ('16' if half else '32')} engine as {f}")
      if int8:
          config.set_flag(trt.BuilderFlag.INT8)
          if profile is not None:  # a calibration profile is only required for networks with dynamic dims
              config.set_calibration_profile(profile)
          config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED

          class EngineCalibrator(trt.IInt8Calibrator):
              """INT8 entropy calibrator that feeds image batches from a dataloader to TensorRT."""

              def __init__(
                  self,
                  dataset,  # ultralytics.data.build.InfiniteDataLoader
                  batch: int,
                  cache: str = "",
              ) -> None:
                  trt.IInt8Calibrator.__init__(self)
                  self.dataset = dataset
                  self.data_iter = iter(dataset)
                  self.algo = trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2
                  self.batch = batch
                  self.cache = Path(cache)
                  self.device_batch = None  # keeps the most recent GPU batch alive while TensorRT reads it

              def get_algorithm(self) -> trt.CalibrationAlgoType:
                  """Get the calibration algorithm to use."""
                  return self.algo

              def get_batch_size(self) -> int:
                  """Get the batch size to use for calibration."""
                  return self.batch or 1

              def get_batch(self, names) -> list:
                  """Get the next batch to use for calibration, as a list of device memory pointers."""
                  try:
                      im0s = next(self.data_iter)["img"] / 255.0
                      im0s = im0s.to("cuda") if im0s.device.type == "cpu" else im0s
                      # TensorRT reads the raw pointer after this method returns, so hold a reference. Otherwise
                      # the tensor's memory could be released to the allocator and reused in the meantime.
                      self.device_batch = im0s
                      return [int(im0s.data_ptr())]
                  except StopIteration:
                      # Return [] or None, signal to TensorRT there is no calibration data remaining
                      return None

              def read_calibration_cache(self) -> bytes:
                  """Use existing cache instead of calibrating again, otherwise, implicitly return None."""
                  if self.cache.exists() and self.cache.suffix == ".cache":
                      return self.cache.read_bytes()

              def write_calibration_cache(self, cache) -> None:
                  """Write calibration cache to disk."""
                  _ = self.cache.write_bytes(cache)

          # Load dataset w/ builder (for batching) and calibrate
          config.int8_calibrator = EngineCalibrator(
              dataset=self.get_int8_calibration_dataloader(prefix),
              batch=2 * self.args.batch,  # TensorRT INT8 calibration should use 2x batch size
              cache=str(self.file.with_suffix(".cache")),
          )
      elif half:
          config.set_flag(trt.BuilderFlag.FP16)

      # Free CUDA memory
      del self.model
      gc.collect()
      torch.cuda.empty_cache()

      # Write file
      build = builder.build_serialized_network if is_trt10 else builder.build_engine
      built = build(network, config)
      if built is None:  # TensorRT signals build failure by returning None
          raise RuntimeError(f"TensorRT engine build failed for {f_onnx}, see the TensorRT log for details")
      with built as engine, open(f, "wb") as t:
          # Metadata: json.dumps escapes non-ASCII by default, so the character count equals the byte count
          meta = json.dumps(self.metadata)
          t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
          t.write(meta.encode())
          # Model
          t.write(engine if is_trt10 else engine.serialize())
      return f, None
  @try_export
  def export_saved_model(self, prefix=colorstr("TensorFlow SavedModel:")):
      """
      Export the model to a TensorFlow SavedModel directory, plus TFLite files, via ONNX and onnx2tf.

      When `int8` is set, INT8 TFLite models are also produced. If `data` is given, calibration images come from
      the INT8 calibration dataloader. Every remaining *.tflite file in the output directory gets TFLite
      metadata embedded.

      Args:
          prefix (str): Colored log prefix.

      Returns:
          (tuple): SavedModel directory path (str) and the Keras model returned by onnx2tf.
      """
      cuda = torch.cuda.is_available()
      try:
          import tensorflow as tf  # noqa
      except ImportError:
          suffix = "-macos" if MACOS else "-aarch64" if ARM64 else "" if cuda else "-cpu"
          version = ">=2.0.0"
          check_requirements(f"tensorflow{suffix}{version}")
          import tensorflow as tf  # noqa
      check_requirements(
          (
              "keras",  # required by 'onnx2tf' package
              "tf_keras",  # required by 'onnx2tf' package
              "sng4onnx>=1.0.1",  # required by 'onnx2tf' package
              "onnx_graphsurgeon>=0.3.26",  # required by 'onnx2tf' package
              "onnx>=1.12.0",
              "onnx2tf>1.17.5,<=1.26.3",
              "onnxslim>=0.1.31",
              "tflite_support<=0.4.3" if IS_JETSON else "tflite_support",  # fix ImportError 'GLIBCXX_3.4.29'
              "flatbuffers>=23.5.26,<100",  # update old 'flatbuffers' included inside tensorflow package
              "onnxruntime-gpu" if cuda else "onnxruntime",
          ),
          cmds="--extra-index-url https://pypi.ngc.nvidia.com",  # onnx_graphsurgeon only on NVIDIA
      )

      LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
      check_version(
          tf.__version__,
          ">=2.0.0",
          name="tensorflow",
          verbose=True,
          msg="https://github.com/ultralytics/ultralytics/issues/5161",
      )
      import onnx2tf

      f = Path(str(self.file).replace(self.file.suffix, "_saved_model"))
      if f.is_dir():
          shutil.rmtree(f)  # delete output folder

      # Pre-download calibration file to fix https://github.com/PINTO0309/onnx2tf/issues/545
      onnx2tf_file = Path("calibration_image_sample_data_20x128x128x3_float32.npy")
      if not onnx2tf_file.exists():
          attempt_download_asset(f"{onnx2tf_file}.zip", unzip=True, delete=True)

      # Export to ONNX (simplified here; onnx2tf is told below not to run onnxsim again)
      self.args.simplify = True
      f_onnx, _ = self.export_onnx()

      # Export to TF
      np_data = None
      if self.args.int8:
          tmp_file = f / "tmp_tflite_int8_calibration_images.npy"  # int8 calibration images file
          if self.args.data:
              f.mkdir()
              images = [batch["img"] for batch in self.get_int8_calibration_dataloader(prefix)]
              images = torch.nn.functional.interpolate(torch.cat(images, 0).float(), size=self.imgsz).permute(
                  0, 2, 3, 1
              )
              np.save(str(tmp_file), images.numpy().astype(np.float32))  # BHWC
              np_data = [["images", tmp_file, [[[[0, 0, 0]]]], [[[[255, 255, 255]]]]]]

      LOGGER.info(f"{prefix} starting TFLite export with onnx2tf {onnx2tf.__version__}...")
      keras_model = onnx2tf.convert(
          input_onnx_file_path=f_onnx,
          output_folder_path=str(f),
          not_use_onnxsim=True,
          verbosity="error",  # note INT8-FP16 activation bug https://github.com/ultralytics/ultralytics/issues/15873
          output_integer_quantized_tflite=self.args.int8,
          quant_type="per-tensor",  # "per-tensor" (faster) or "per-channel" (slower but more accurate)
          custom_input_op_name_np_data_path=np_data,
          disable_group_convolution=True,  # for end-to-end model compatibility
          enable_batchmatmul_unfold=True,  # for end-to-end model compatibility
      )
      yaml_save(f / "metadata.yaml", self.metadata)  # add metadata.yaml

      # Remove/rename TFLite models (listings are materialized first because files are renamed/deleted)
      if self.args.int8:
          tmp_file.unlink(missing_ok=True)
          for file in list(f.rglob("*_dynamic_range_quant.tflite")):
              file.rename(file.with_name(file.stem.replace("_dynamic_range_quant", "_int8") + file.suffix))
          for file in list(f.rglob("*_integer_quant_with_int16_act.tflite")):
              file.unlink()  # delete extra fp16 activation TFLite files

      # Add TFLite metadata; test and delete the *file*, never the output directory `f`
      for file in list(f.rglob("*.tflite")):
          if "quant_with_int16_act.tflite" in str(file):
              file.unlink()
          else:
              self._add_tflite_metadata(file)
      return str(f), keras_model  # or keras_model = tf.saved_model.load(f, tags=None, options=None)
  @try_export
  def export_pb(self, keras_model, prefix=colorstr("TensorFlow GraphDef:")):
      """
      Freeze a Keras model into a binary TensorFlow GraphDef *.pb (https://github.com/leimao/Frozen_Graph_TensorFlow).

      Args:
          keras_model: Keras model to freeze (presumably the one produced by `export_saved_model`). Its first
              input's shape and dtype define the traced signature.
          prefix (str): Colored log prefix.

      Returns:
          (tuple): Path of the written *.pb file and None.
      """
      import tensorflow as tf  # noqa
      from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2  # noqa

      LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
      f = self.file.with_suffix(".pb")
      first_input = keras_model.inputs[0]
      signature = tf.TensorSpec(first_input.shape, first_input.dtype)
      concrete = tf.function(lambda x: keras_model(x)).get_concrete_function(signature)  # full model
      frozen_func = convert_variables_to_constants_v2(concrete)  # inline variables as constants
      frozen_func.graph.as_graph_def()
      tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
      return f, None
  @try_export
  def export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr("TensorFlow Lite:")):
      """
      Return the path of the TensorFlow Lite file expected inside the *_saved_model directory.

      No conversion is done here. `keras_model`, `nms` and `agnostic_nms` are accepted but unused. The file name
      encodes the requested precision (int8, float16 or float32); inputs and outputs stay fp32 for the
      quantized variants.

      Args:
          keras_model: Unused.
          nms: Unused.
          agnostic_nms: Unused.
          prefix (str): Colored log prefix.

      Returns:
          (tuple): TFLite file path (str) and None.
      """
      # BUG https://github.com/ultralytics/ultralytics/issues/13436
      import tensorflow as tf  # noqa

      LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
      saved_model = Path(str(self.file).replace(self.file.suffix, "_saved_model"))
      if self.args.int8:
          precision = "int8"
      elif self.args.half:
          precision = "float16"
      else:
          precision = "float32"
      tflite_path = saved_model / f"{self.file.stem}_{precision}.tflite"
      return str(tflite_path), None
  @try_export
  def export_edgetpu(self, tflite_model="", prefix=colorstr("Edge TPU:")):
      """
      Compile a TFLite model for the Coral Edge TPU (https://coral.ai/docs/edgetpu/models-intro/).

      Linux only. If `edgetpu_compiler --version` fails, the compiler is installed through apt. The `sudo `
      prefix is stripped from each install step when sudo is unavailable.

      Args:
          tflite_model (str | Path): Path to the *.tflite model to compile.
          prefix (str): Colored log prefix.

      Returns:
          (tuple): Path of the *_edgetpu.tflite output (str) and None.
      """
      LOGGER.warning(f"{prefix} WARNING ⚠️ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185")
      version_cmd = "edgetpu_compiler --version"
      help_url = "https://coral.ai/docs/edgetpu/compiler/"
      assert LINUX, f"export only supported on Linux. See {help_url}"

      probe = subprocess.run(version_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
      if probe.returncode != 0:
          LOGGER.info(f"\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}")
          install_steps = (
              "curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -",
              'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | '
              "sudo tee /etc/apt/sources.list.d/coral-edgetpu.list",
              "sudo apt-get update",
              "sudo apt-get install edgetpu-compiler",
          )
          for step in install_steps:
              subprocess.run(step if is_sudo_available() else step.replace("sudo ", ""), shell=True, check=True)
      ver = subprocess.run(version_cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]

      LOGGER.info(f"\n{prefix} starting export with Edge TPU compiler {ver}...")
      f = str(tflite_model).replace(".tflite", "_edgetpu.tflite")  # Edge TPU model
      compile_cmd = " ".join(
          (
              "edgetpu_compiler",
              f'--out_dir "{Path(f).parent}"',
              "--show_operations",
              "--search_delegate",
              "--delegate_search_step 30",
              "--timeout_sec 180",
              f'"{tflite_model}"',
          )
      )
      LOGGER.info(f"{prefix} running '{compile_cmd}'")
      subprocess.run(compile_cmd, shell=True)
      self._add_tflite_metadata(f)
      return f, None
  @try_export
  def export_tfjs(self, prefix=colorstr("TensorFlow.js:")):
      """
      Convert the frozen *.pb graph next to `self.file` into a TensorFlow.js web model directory.

      The output node names are read from the GraphDef. FP16 or UINT8 weight quantization is applied when `half`
      or `int8` is set.

      Args:
          prefix (str): Colored log prefix.

      Returns:
          (tuple): Output *_web_model directory path (str) and None.
      """
      check_requirements("tensorflowjs")
      if ARM64:
          # Fix error: `np.object` was a deprecated alias for the builtin `object` when exporting to TF.js on ARM64
          check_requirements("numpy==1.23.5")
      import tensorflow as tf
      import tensorflowjs as tfjs  # noqa

      LOGGER.info(f"\n{prefix} starting export with tensorflowjs {tfjs.__version__}...")
      f = str(self.file).replace(self.file.suffix, "_web_model")  # js dir
      f_pb = str(self.file.with_suffix(".pb"))  # *.pb path

      graph_def = tf.Graph().as_graph_def()  # empty GraphDef to parse into
      with open(f_pb, "rb") as fh:
          graph_def.ParseFromString(fh.read())
      outputs = ",".join(gd_outputs(graph_def))
      LOGGER.info(f"\n{prefix} output node names: {outputs}")

      if self.args.half:
          quantization = "--quantize_float16"
      elif self.args.int8:
          quantization = "--quantize_uint8"
      else:
          quantization = ""
      # The converter cannot handle spaces in paths, so run it on space-free stand-in paths
      with spaces_in_path(f_pb) as fpb_, spaces_in_path(f) as f_:
          cmd = (
              "tensorflowjs_converter "
              f'--input_format=tf_frozen_model {quantization} --output_node_names={outputs} "{fpb_}" "{f_}"'
          )
          LOGGER.info(f"{prefix} running '{cmd}'")
          subprocess.run(cmd, shell=True)

      if " " in f:
          LOGGER.warning(f"{prefix} WARNING ⚠️ your model may not work correctly with spaces in path '{f}'.")

      yaml_save(Path(f) / "metadata.yaml", self.metadata)  # add metadata.yaml
      return f, None
  @try_export
  def export_imx(self, prefix=colorstr("IMX:")):
      """
      Export a YOLOv8n detection model for the Sony IMX500 (via Model Compression Toolkit and imxconv-pt).

      The model is quantized with MCT: post-training quantization, or gradient-based PTQ when `gptq` is True. It
      is then wrapped with sony_custom_layers' multiclass NMS and exported to ONNX with Ultralytics metadata. The
      ONNX file is compiled with `imxconv-pt`. A labels.txt with one class name per line is written to the
      output directory.

      Args:
          prefix (str): Colored log prefix.

      Returns:
          (tuple): Output *_imx_model directory (Path) and None.

      Raises:
          ValueError: If the model is end2end, or its string representation contains no "C2f" module.
      """
      gptq = False  # True selects gradient-based PTQ instead of plain post-training quantization
      assert LINUX, (
          "export only supported on Linux. See https://developer.aitrios.sony-semicon.com/en/raspberrypi-ai-camera/documentation/imx500-converter"
      )
      if getattr(self.model, "end2end", False):
          raise ValueError("IMX export is not supported for end2end models.")
      if "C2f" not in self.model.__str__():
          raise ValueError("IMX export is only supported for YOLOv8n detection models")
      check_requirements(("model-compression-toolkit==2.1.1", "sony-custom-layers==0.2.0", "tensorflow==2.12.0"))
      check_requirements("imx500-converter[pt]==3.14.3")  # Separate requirements for imx500-converter

      import model_compression_toolkit as mct
      import onnx
      from sony_custom_layers.pytorch.object_detection.nms import multiclass_nms

      # Java 17 is required for imx500-converter. Install OpenJDK 17 when java is missing, when `java --version`
      # exits non-zero (e.g. Java 8, which does not support that flag), or when it reports another version.
      try:
          out = subprocess.run(["java", "--version"], check=True, capture_output=True)
          if "openjdk 17" not in str(out.stdout):
              raise FileNotFoundError
      except (FileNotFoundError, subprocess.CalledProcessError):
          c = ["apt", "install", "-y", "openjdk-17-jdk", "openjdk-17-jre"]
          if is_sudo_available():
              c.insert(0, "sudo")
          subprocess.run(c, check=True)

      def representative_dataset_gen(dataloader=self.get_int8_calibration_dataloader(prefix)):
          """Yield calibration batches divided by 255. The default dataloader is built once, when this is defined."""
          for batch in dataloader:
              img = batch["img"]
              img = img / 255.0
              yield [img]

      tpc = mct.get_target_platform_capabilities(
          fw_name="pytorch", target_platform_name="imx500", target_platform_version="v1"
      )
      config = mct.core.CoreConfig(
          mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=10),
          quantization_config=mct.core.QuantizationConfig(concat_threshold_update=True),
      )
      # NOTE(review): weights memory budget for the target; the constant's origin/units are not shown here
      resource_utilization = mct.core.ResourceUtilization(weights_memory=3146176 * 0.76)

      quant_model = (
          mct.gptq.pytorch_gradient_post_training_quantization(  # Perform Gradient-Based Post Training Quantization
              model=self.model,
              representative_data_gen=representative_dataset_gen,
              target_resource_utilization=resource_utilization,
              gptq_config=mct.gptq.get_pytorch_gptq_config(n_epochs=1000, use_hessian_based_weights=False),
              core_config=config,
              target_platform_capabilities=tpc,
          )[0]
          if gptq
          else mct.ptq.pytorch_post_training_quantization(  # Perform post training quantization
              in_module=self.model,
              representative_data_gen=representative_dataset_gen,
              target_resource_utilization=resource_utilization,
              core_config=config,
              target_platform_capabilities=tpc,
          )[0]
      )

      class NMSWrapper(torch.nn.Module):
          """Run a model and apply multiclass NMS to its first two outputs (treated as boxes and scores)."""

          def __init__(
              self,
              model: torch.nn.Module,
              score_threshold: float = 0.001,
              iou_threshold: float = 0.7,
              max_detections: int = 300,
          ):
              """
              Wrapping PyTorch Module with multiclass_nms layer from sony_custom_layers.

              Args:
                  model (nn.Module): Model instance.
                  score_threshold (float): Score threshold for non-maximum suppression.
                  iou_threshold (float): Intersection over union threshold for non-maximum suppression.
                  max_detections (int): The number of detections to return.
              """
              super().__init__()
              self.model = model
              self.score_threshold = score_threshold
              self.iou_threshold = iou_threshold
              self.max_detections = max_detections

          def forward(self, images):
              """Run the wrapped model, then multiclass NMS on outputs[0] (boxes) and outputs[1] (scores)."""
              # model inference
              outputs = self.model(images)

              boxes = outputs[0]
              scores = outputs[1]
              nms = multiclass_nms(
                  boxes=boxes,
                  scores=scores,
                  score_threshold=self.score_threshold,
                  iou_threshold=self.iou_threshold,
                  max_detections=self.max_detections,
              )
              return nms

      quant_model = NMSWrapper(
          model=quant_model,
          score_threshold=self.args.conf or 0.001,
          iou_threshold=self.args.iou,
          max_detections=self.args.max_det,
      ).to(self.device)

      f = Path(str(self.file).replace(self.file.suffix, "_imx_model"))
      f.mkdir(exist_ok=True)
      # Build the ONNX file name from the stem only. Joining the full model path onto `f` would discard `f` for
      # absolute paths, or point into a non-existent subdirectory for nested relative paths.
      onnx_model = f / f"{self.file.stem}_imx.onnx"
      mct.exporter.pytorch_export_model(
          model=quant_model, save_model_path=onnx_model, repr_dataset=representative_dataset_gen
      )

      model_onnx = onnx.load(onnx_model)  # load onnx model
      for k, v in self.metadata.items():
          meta = model_onnx.metadata_props.add()
          meta.key, meta.value = k, str(v)
      onnx.save(model_onnx, onnx_model)

      subprocess.run(
          ["imxconv-pt", "-i", str(onnx_model), "-o", str(f), "--no-input-persistency", "--overwrite-output"],
          check=True,
      )

      # Needed for imx models.
      with open(f / "labels.txt", "w") as file:
          file.writelines([f"{name}\n" for _, name in self.model.names.items()])

      return f, None
  def _add_tflite_metadata(self, file):
      """
      Embed model metadata into a *.tflite file in place (https://www.tensorflow.org/lite/models/convert/metadata).

      Sets name/version/author/license from self.metadata and describes the input image tensor and the output
      tensor(s); segment models get a second output. str(self.metadata) is attached as an associated file through
      a temporary `temp_meta.txt` next to `file`. That temporary file is always removed, even if population fails.

      Args:
          file (str | Path): Path to the *.tflite model to modify.
      """
      import flatbuffers

      try:
          # TFLite Support bug https://github.com/tensorflow/tflite-support/issues/954#issuecomment-2108570845
          from tensorflow_lite_support.metadata import metadata_schema_py_generated as schema  # noqa
          from tensorflow_lite_support.metadata.python import metadata  # noqa
      except ImportError:  # ARM64 systems may not have the 'tensorflow_lite_support' package available
          from tflite_support import metadata  # noqa
          from tflite_support import metadata_schema_py_generated as schema  # noqa

      # Create model info
      model_meta = schema.ModelMetadataT()
      model_meta.name = self.metadata["description"]
      model_meta.version = self.metadata["version"]
      model_meta.author = self.metadata["author"]
      model_meta.license = self.metadata["license"]

      tmp_file = Path(file).parent / "temp_meta.txt"
      try:
          # Label file (explicit UTF-8: class names in the metadata may contain non-ASCII characters)
          tmp_file.write_text(str(self.metadata), encoding="utf-8")

          label_file = schema.AssociatedFileT()
          label_file.name = tmp_file.name
          label_file.type = schema.AssociatedFileType.TENSOR_AXIS_LABELS

          # Create input info
          input_meta = schema.TensorMetadataT()
          input_meta.name = "image"
          input_meta.description = "Input image to be detected."
          input_meta.content = schema.ContentT()
          input_meta.content.contentProperties = schema.ImagePropertiesT()
          input_meta.content.contentProperties.colorSpace = schema.ColorSpaceType.RGB
          input_meta.content.contentPropertiesType = schema.ContentProperties.ImageProperties

          # Create output info
          output1 = schema.TensorMetadataT()
          output1.name = "output"
          output1.description = "Coordinates of detected objects, class labels, and confidence score"
          output1.associatedFiles = [label_file]
          if self.model.task == "segment":
              output2 = schema.TensorMetadataT()
              output2.name = "output"
              output2.description = "Mask protos"
              output2.associatedFiles = [label_file]

          # Create subgraph info
          subgraph = schema.SubGraphMetadataT()
          subgraph.inputTensorMetadata = [input_meta]
          subgraph.outputTensorMetadata = [output1, output2] if self.model.task == "segment" else [output1]
          model_meta.subgraphMetadata = [subgraph]

          b = flatbuffers.Builder(0)
          b.Finish(model_meta.Pack(b), metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
          metadata_buf = b.Output()

          populator = metadata.MetadataPopulator.with_model_file(str(file))
          populator.load_metadata_buffer(metadata_buf)
          populator.load_associated_files([str(tmp_file)])
          populator.populate()
      finally:
          tmp_file.unlink(missing_ok=True)  # never leave the scratch label file next to the model
    def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr("CoreML Pipeline:")):
        """
        Chain a CoreML detection model with a CoreML NonMaximumSuppression model into one pipeline.

        The decoder ``model`` must expose exactly two outputs (class confidences first, box coordinates
        second) and an ``"image"`` input. Missing output shapes are written into its spec, an NMS model is
        built from a raw protobuf spec, and both are joined with ``ct.models.pipeline.Pipeline``.

        Args:
            model (coremltools.models.MLModel): Converted detection model whose outputs feed the NMS stage.
            weights_dir (optional): Forwarded to each ``ct.models.MLModel`` rebuilt from the decoder/pipeline
                spec; presumably the weights directory needed by ML program models — TODO confirm.
            prefix (str): Prefix for log messages.

        Returns:
            (coremltools.models.MLModel): Pipeline with inputs ``image``, ``iouThreshold`` and
                ``confidenceThreshold`` and outputs ``confidence`` and ``coordinates``. It is not written to
                disk here.

        Raises:
            AssertionError: If ``len(self.metadata["names"])`` differs from the class dimension of the first
                decoder output (the check is skipped when Python runs with ``-O``).
        """
        import coremltools as ct  # noqa

        LOGGER.info(f"{prefix} starting pipeline with coremltools {ct.__version__}...")
        _, _, h, w = list(self.im.shape)  # BCHW

        # Output shapes
        spec = model.get_spec()
        out0, out1 = iter(spec.description.output)  # exactly two outputs expected: confidence, coordinates
        if MACOS:
            # Run one dummy prediction to read the real output shapes (model.predict needs macOS).
            from PIL import Image

            img = Image.new("RGB", (w, h))  # w=192, h=320
            out = model.predict({"image": img})
            out0_shape = out[out0.name].shape  # (3780, 80) — example values
            out1_shape = out[out1.name].shape  # (3780, 4) — example values
        else:  # linux and windows can not run model.predict(), get sizes from PyTorch model output y
            # NOTE(review): assumes self.output_shape is (batch, 4 + nc, anchors) — TODO confirm.
            out0_shape = self.output_shape[2], self.output_shape[1] - 4  # (3780, 80)
            out1_shape = self.output_shape[2], 4  # (3780, 4)

        # Checks
        names = self.metadata["names"]
        nx, ny = spec.description.input[0].type.imageType.width, spec.description.input[0].type.imageType.height
        _, nc = out0_shape  # number of anchors, number of classes
        # NOTE: an assert is stripped under `python -O`, so this check is best-effort only.
        assert len(names) == nc, f"{len(names)} names found for nc={nc}"  # check

        # Define output shapes (missing)
        # Write the concrete shapes into the decoder output descriptions before rebuilding the model from spec.
        out0.type.multiArrayType.shape[:] = out0_shape  # (3780, 80)
        out1.type.multiArrayType.shape[:] = out1_shape  # (3780, 4)

        # Model from spec
        model = ct.models.MLModel(spec, weights_dir=weights_dir)

        # 3. Create NMS protobuf
        nms_spec = ct.proto.Model_pb2.Model()
        nms_spec.specificationVersion = 5
        # The NMS model's inputs and outputs all start as copies of the decoder's two output descriptions.
        for i in range(2):
            decoder_output = model._spec.description.output[i].SerializeToString()
            nms_spec.description.input.add()
            nms_spec.description.input[i].ParseFromString(decoder_output)
            nms_spec.description.output.add()
            nms_spec.description.output[i].ParseFromString(decoder_output)

        # Only the outputs are renamed; the inputs keep the decoder output names (presumably so the
        # pipeline can wire decoder -> NMS by name).
        nms_spec.description.output[0].name = "confidence"
        nms_spec.description.output[1].name = "coordinates"

        # Give each NMS output a flexible first dimension (lower bound 0, upper bound -1, presumably
        # "unbounded" in the CoreML spec). The second dimension is fixed to nc (confidence) or 4
        # (coordinates), and the fixed shape copied from the decoder is cleared.
        output_sizes = [nc, 4]
        for i in range(2):
            ma_type = nms_spec.description.output[i].type.multiArrayType
            ma_type.shapeRange.sizeRanges.add()
            ma_type.shapeRange.sizeRanges[0].lowerBound = 0
            ma_type.shapeRange.sizeRanges[0].upperBound = -1
            ma_type.shapeRange.sizeRanges.add()
            ma_type.shapeRange.sizeRanges[1].lowerBound = output_sizes[i]
            ma_type.shapeRange.sizeRanges[1].upperBound = output_sizes[i]
            del ma_type.shape[:]

        nms = nms_spec.nonMaximumSuppression
        nms.confidenceInputFeatureName = out0.name  # 1x507x80
        nms.coordinatesInputFeatureName = out1.name  # 1x507x4
        nms.confidenceOutputFeatureName = "confidence"
        nms.coordinatesOutputFeatureName = "coordinates"
        # Optional pipeline inputs that let callers override the thresholds at inference time.
        nms.iouThresholdInputFeatureName = "iouThreshold"
        nms.confidenceThresholdInputFeatureName = "confidenceThreshold"
        # Default thresholds; also recorded in the user-defined metadata and input descriptions below.
        nms.iouThreshold = 0.45
        nms.confidenceThreshold = 0.25
        nms.pickTop.perClass = True  # per-class suppression (CoreML NMS PickTop.perClass)
        # NOTE(review): label order follows dict order of `names`; assumed to match class index order.
        nms.stringClassLabels.vector.extend(names.values())
        nms_model = ct.models.MLModel(nms_spec)

        # 4. Pipeline models together
        pipeline = ct.models.pipeline.Pipeline(
            input_features=[
                ("image", ct.models.datatypes.Array(3, ny, nx)),
                ("iouThreshold", ct.models.datatypes.Double()),
                ("confidenceThreshold", ct.models.datatypes.Double()),
            ],
            output_features=["confidence", "coordinates"],
        )
        pipeline.add_model(model)
        pipeline.add_model(nms_model)

        # Correct datatypes
        # Overwrite the placeholder feature descriptions above with the real ones: the decoder's image
        # input and the NMS model's flexible-shape outputs.
        pipeline.spec.description.input[0].ParseFromString(model._spec.description.input[0].SerializeToString())
        pipeline.spec.description.output[0].ParseFromString(nms_model._spec.description.output[0].SerializeToString())
        pipeline.spec.description.output[1].ParseFromString(nms_model._spec.description.output[1].SerializeToString())

        # Update metadata
        pipeline.spec.specificationVersion = 5
        pipeline.spec.description.metadata.userDefined.update(
            {"IoU threshold": str(nms.iouThreshold), "Confidence threshold": str(nms.confidenceThreshold)}
        )

        # Build the final pipeline MLModel and attach feature descriptions (nothing is written to disk here)
        model = ct.models.MLModel(pipeline.spec, weights_dir=weights_dir)
        model.input_description["image"] = "Input image"
        model.input_description["iouThreshold"] = f"(optional) IoU threshold override (default: {nms.iouThreshold})"
        model.input_description["confidenceThreshold"] = (
            f"(optional) Confidence threshold override (default: {nms.confidenceThreshold})"
        )
        model.output_description["confidence"] = 'Boxes × Class confidence (see user-defined metadata "classes")'
        model.output_description["coordinates"] = "Boxes × [x, y, width, height] (relative to image size)"
        LOGGER.info(f"{prefix} pipeline success")
        return model
  1271. def add_callback(self, event: str, callback):
  1272. """Appends the given callback."""
  1273. self.callbacks[event].append(callback)
  1274. def run_callbacks(self, event: str):
  1275. """Execute all callbacks for a given event."""
  1276. for callback in self.callbacks.get(event, []):
  1277. callback(self)
  1278. class IOSDetectModel(torch.nn.Module):
  1279. """Wrap an Ultralytics YOLO model for Apple iOS CoreML export."""
  1280. def __init__(self, model, im):
  1281. """Initialize the IOSDetectModel class with a YOLO model and example image."""
  1282. super().__init__()
  1283. _, _, h, w = im.shape # batch, channel, height, width
  1284. self.model = model
  1285. self.nc = len(model.names) # number of classes
  1286. if w == h:
  1287. self.normalize = 1.0 / w # scalar
  1288. else:
  1289. self.normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h]) # broadcast (slower, smaller)
  1290. def forward(self, x):
  1291. """Normalize predictions of object detection model with input size-dependent factors."""
  1292. xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
  1293. return cls, xywh * self.normalize # confidence (3780, 80), coordinates (3780, 4)