# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import gc
import math
import os
import random
import time
from contextlib import contextmanager
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Union

import numpy as np
import thop
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from ultralytics.utils import (
    DEFAULT_CFG_DICT,
    DEFAULT_CFG_KEYS,
    LOGGER,
    NUM_THREADS,
    PYTHON_VERSION,
    TORCHVISION_VERSION,
    WINDOWS,
    __version__,
    colorstr,
)
from ultralytics.utils.checks import check_version

# Version checks (all default to version>=min_version)
TORCH_1_9 = check_version(torch.__version__, "1.9.0")
TORCH_1_13 = check_version(torch.__version__, "1.13.0")
TORCH_2_0 = check_version(torch.__version__, "2.0.0")
TORCH_2_4 = check_version(torch.__version__, "2.4.0")
TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0")
TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0")
TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0")
TORCHVISION_0_18 = check_version(TORCHVISION_VERSION, "0.18.0")

if WINDOWS and check_version(torch.__version__, "==2.4.0"):  # reject version 2.4.0 on Windows
    LOGGER.warning(
        "WARNING ⚠️ Known issue with torch==2.4.0 on Windows with CPU, recommend upgrading to torch>=2.4.1 to resolve "
        "https://github.com/ultralytics/ultralytics/issues/15049"
    )

@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """Ensures all processes in distributed training wait for the local master (rank 0) to complete a task first."""
    initialized = dist.is_available() and dist.is_initialized()
    if initialized and local_rank not in {-1, 0}:
        dist.barrier(device_ids=[local_rank])
    yield
    if initialized and local_rank == 0:
        dist.barrier(device_ids=[local_rank])
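
# Usage sketch (illustrative, not part of the original module; `local_rank`, `cfg` and `build_dataset` are
# hypothetical names): rank 0 runs the body first, other ranks wait at the barrier and then reuse its cache.
#     with torch_distributed_zero_first(local_rank):
#         dataset = build_dataset(cfg)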

def smart_inference_mode():
    """Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator."""

    def decorate(fn):
        """Applies appropriate torch decorator for inference mode based on torch version."""
        if TORCH_1_9 and torch.is_inference_mode_enabled():
            return fn  # already in inference_mode, act as a pass-through
        else:
            return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)

    return decorate
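
# Usage sketch (illustrative): applied as a decorator, the wrapped call runs under torch.inference_mode() on
# torch>=1.9 and torch.no_grad() on older versions; `model` and `im` are hypothetical.
#     @smart_inference_mode()
#     def predict(model, im):
#         return model(im)  # gradients disabled, no autograd state saved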

def autocast(enabled: bool, device: str = "cuda"):
    """
    Get the appropriate autocast context manager based on PyTorch version and AMP setting.

    This function returns a context manager for automatic mixed precision (AMP) training that is compatible with both
    older and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions.

    Args:
        enabled (bool): Whether to enable automatic mixed precision.
        device (str, optional): The device to use for autocast. Defaults to 'cuda'.

    Returns:
        (torch.amp.autocast): The appropriate autocast context manager.

    Note:
        - For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.
        - For older versions, it uses `torch.cuda.amp.autocast`.

    Example:
        ```python
        with autocast(enabled=True):
            # Your mixed precision operations here
            pass
        ```
    """
    if TORCH_1_13:
        return torch.amp.autocast(device, enabled=enabled)
    else:
        return torch.cuda.amp.autocast(enabled)

def get_cpu_info():
    """Return a string with system CPU information, i.e. 'Apple M2'."""
    from ultralytics.utils import PERSISTENT_CACHE  # avoid circular import error

    if "cpu_info" not in PERSISTENT_CACHE:
        try:
            import cpuinfo  # pip install py-cpuinfo

            k = "brand_raw", "hardware_raw", "arch_string_raw"  # keys sorted by preference
            info = cpuinfo.get_cpu_info()  # info dict
            string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")
            PERSISTENT_CACHE["cpu_info"] = string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")
        except Exception:
            pass
    return PERSISTENT_CACHE.get("cpu_info", "unknown")

def get_gpu_info(index):
    """Return a string with system GPU information, i.e. 'Tesla T4, 15102MiB'."""
    properties = torch.cuda.get_device_properties(index)
    return f"{properties.name}, {properties.total_memory / (1 << 20):.0f}MiB"

def select_device(device="", batch=0, newline=False, verbose=True):
    """
    Selects the appropriate PyTorch device based on the provided arguments.

    The function takes a string specifying the device or a torch.device object and returns a torch.device object
    representing the selected device. The function also validates the number of available devices and raises an
    exception if the requested device(s) are not available.

    Args:
        device (str | torch.device, optional): Device string or torch.device object.
            Options are 'None', 'cpu', or 'cuda', or '0' or '0,1,2,3'. Defaults to an empty string, which
            auto-selects the first available GPU, or CPU if no GPU is available.
        batch (int, optional): Batch size being used in your model. Defaults to 0.
        newline (bool, optional): If True, adds a newline at the end of the log string. Defaults to False.
        verbose (bool, optional): If True, logs the device information. Defaults to True.

    Returns:
        (torch.device): Selected device.

    Raises:
        ValueError: If the specified device is not available or if the batch size is not a multiple of the number of
            devices when using multiple GPUs.

    Examples:
        >>> select_device("cuda:0")
        device(type='cuda', index=0)
        >>> select_device("cpu")
        device(type='cpu')

    Note:
        Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use.
    """
    if isinstance(device, torch.device) or str(device).startswith("tpu"):
        return device

    s = f"Ultralytics {__version__} 🚀 Python-{PYTHON_VERSION} torch-{torch.__version__} "
    device = str(device).lower()
    for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ":
        device = device.replace(remove, "")  # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
    cpu = device == "cpu"
    mps = device in {"mps", "mps:0"}  # Apple Metal Performance Shaders (MPS)
    if cpu or mps:
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        if device == "cuda":
            device = "0"
        if "," in device:
            device = ",".join([x for x in device.split(",") if x])  # remove sequential commas, i.e. "0,,1" -> "0,1"
        visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        os.environ["CUDA_VISIBLE_DEVICES"] = device  # set environment variable - must be before assert is_available()
        if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(","))):
            LOGGER.info(s)
            install = (
                "See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no "
                "CUDA devices are seen by torch.\n"
                if torch.cuda.device_count() == 0
                else ""
            )
            raise ValueError(
                f"Invalid CUDA 'device={device}' requested."
                f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
                f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
                f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}"
                f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}"
                f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
                f"{install}"
            )

    if not cpu and not mps and torch.cuda.is_available():  # prefer GPU if available
        devices = device.split(",") if device else "0"  # i.e. "0,1" -> ["0", "1"]
        n = len(devices)  # device count
        if n > 1:  # multi-GPU
            if batch < 1:
                raise ValueError(
                    "AutoBatch with batch<1 not supported for Multi-GPU training, "
                    "please specify a valid batch size, i.e. batch=16."
                )
            if batch >= 0 and batch % n != 0:  # check batch_size is divisible by device_count
                raise ValueError(
                    f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
                    f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
                )
        space = " " * (len(s) + 1)
        for i, d in enumerate(devices):
            s += f"{'' if i == 0 else space}CUDA:{d} ({get_gpu_info(i)})\n"
        arg = "cuda:0"
    elif mps and TORCH_2_0 and torch.backends.mps.is_available():
        # Prefer MPS if available
        s += f"MPS ({get_cpu_info()})\n"
        arg = "mps"
    else:  # revert to CPU
        s += f"CPU ({get_cpu_info()})\n"
        arg = "cpu"

    if arg in {"cpu", "mps"}:
        torch.set_num_threads(NUM_THREADS)  # reset OMP_NUM_THREADS for cpu training
    if verbose:
        LOGGER.info(s if newline else s.rstrip())
    return torch.device(arg)

def time_sync():
    """PyTorch-accurate time."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()

def fuse_conv_and_bn(conv, bn):
    """Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/."""
    fusedconv = (
        nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            groups=conv.groups,
            bias=True,
        )
        .requires_grad_(False)
        .to(conv.weight.device)
    )

    # Prepare filters
    w_conv = conv.weight.view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

    # Prepare spatial bias
    b_conv = torch.zeros(conv.weight.shape[0], device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fusedconv
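
# Quick numerical check (a minimal sketch with arbitrary shapes, assuming float32 tolerances): the fused layer
# should reproduce conv -> bn in eval mode.
#     conv, bn = nn.Conv2d(8, 16, 3, padding=1), nn.BatchNorm2d(16)
#     bn(conv(torch.randn(4, 8, 32, 32)))  # one training-mode pass to populate BN running stats
#     bn.eval()
#     fused = fuse_conv_and_bn(conv, bn)
#     x = torch.randn(1, 8, 32, 32)
#     assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)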

def fuse_deconv_and_bn(deconv, bn):
    """Fuse ConvTranspose2d() and BatchNorm2d() layers."""
    fuseddconv = (
        nn.ConvTranspose2d(
            deconv.in_channels,
            deconv.out_channels,
            kernel_size=deconv.kernel_size,
            stride=deconv.stride,
            padding=deconv.padding,
            output_padding=deconv.output_padding,
            dilation=deconv.dilation,
            groups=deconv.groups,
            bias=True,
        )
        .requires_grad_(False)
        .to(deconv.weight.device)
    )

    # Prepare filters
    w_deconv = deconv.weight.view(deconv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))

    # Prepare spatial bias
    b_conv = torch.zeros(deconv.weight.shape[1], device=deconv.weight.device) if deconv.bias is None else deconv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fuseddconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fuseddconv

def model_info(model, detailed=False, verbose=True, imgsz=640):
    """Print and return detailed model information layer by layer."""
    if not verbose:
        return
    n_p = get_num_params(model)  # number of parameters
    n_g = get_num_gradients(model)  # number of gradients
    n_l = len(list(model.modules()))  # number of layers
    if detailed:
        LOGGER.info(
            f"{'layer':>5}{'name':>40}{'gradient':>10}{'parameters':>12}{'shape':>20}{'mu':>10}{'sigma':>10}"
            f"{'dtype':>15}"
        )
        for i, (name, p) in enumerate(model.named_parameters()):
            name = name.replace("module_list.", "")
            LOGGER.info(
                f"{i:>5g}{name:>40s}{p.requires_grad!r:>10}{p.numel():>12g}{str(list(p.shape)):>20s}"
                f"{p.mean():>10.3g}{p.std():>10.3g}{str(p.dtype):>15s}"
            )

    flops = get_flops(model, imgsz)  # imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
    fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""
    fs = f", {flops:.1f} GFLOPs" if flops else ""
    yaml_file = getattr(model, "yaml_file", "") or getattr(model, "yaml", {}).get("yaml_file", "")
    model_name = Path(yaml_file).stem.replace("yolo", "YOLO") or "Model"
    LOGGER.info(f"{model_name} summary{fused}: {n_l:,} layers, {n_p:,} parameters, {n_g:,} gradients{fs}")
    return n_l, n_p, n_g, flops

def get_num_params(model):
    """Return the total number of parameters in a YOLO model."""
    return sum(x.numel() for x in model.parameters())


def get_num_gradients(model):
    """Return the total number of parameters with gradients in a YOLO model."""
    return sum(x.numel() for x in model.parameters() if x.requires_grad)

def model_info_for_loggers(trainer):
    """
    Return model info dict with useful model information.

    Example:
        YOLOv8n info for loggers
        ```python
        results = {
            "model/parameters": 3151904,
            "model/GFLOPs": 8.746,
            "model/speed_ONNX(ms)": 41.244,
            "model/speed_TensorRT(ms)": 3.211,
            "model/speed_PyTorch(ms)": 18.755,
        }
        ```
    """
    if trainer.args.profile:  # profile ONNX and TensorRT times
        from ultralytics.utils.benchmarks import ProfileModels

        results = ProfileModels([trainer.last], device=trainer.device).profile()[0]
        results.pop("model/name")
    else:  # only return PyTorch times from most recent validation
        results = {
            "model/parameters": get_num_params(trainer.model),
            "model/GFLOPs": round(get_flops(trainer.model), 3),
        }
    results["model/speed_PyTorch(ms)"] = round(trainer.validator.speed["inference"], 3)
    return results

def get_flops(model, imgsz=640):
    """Return a YOLO model's FLOPs."""
    try:
        model = de_parallel(model)
        p = next(model.parameters())
        if not isinstance(imgsz, list):
            imgsz = [imgsz, imgsz]  # expand if int/float
        try:
            # Use stride size for input tensor
            stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32  # max stride
            im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
            flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # stride GFLOPs
            return flops * imgsz[0] / stride * imgsz[1] / stride  # imgsz GFLOPs
        except Exception:
            # Use actual image size for input tensor (i.e. required for RTDETR models)
            im = torch.empty((1, p.shape[1], *imgsz), device=p.device)  # input image in BCHW format
            return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # imgsz GFLOPs
    except Exception:
        return 0.0

def get_flops_with_torch_profiler(model, imgsz=640):
    """Compute model FLOPs (thop package alternative, but 2-10x slower unfortunately)."""
    if not TORCH_2_0:  # torch profiler implemented in torch>=2.0
        return 0.0
    model = de_parallel(model)
    p = next(model.parameters())
    if not isinstance(imgsz, list):
        imgsz = [imgsz, imgsz]  # expand if int/float
    try:
        # Use stride size for input tensor
        stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2  # max stride
        im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
        with torch.profiler.profile(with_flops=True) as prof:
            model(im)
        flops = sum(x.flops for x in prof.key_averages()) / 1e9
        flops = flops * imgsz[0] / stride * imgsz[1] / stride  # imgsz GFLOPs
    except Exception:
        # Use actual image size for input tensor (i.e. required for RTDETR models)
        im = torch.empty((1, p.shape[1], *imgsz), device=p.device)  # input image in BCHW format
        with torch.profiler.profile(with_flops=True) as prof:
            model(im)
        flops = sum(x.flops for x in prof.key_averages()) / 1e9
    return flops

def initialize_weights(model):
    """Initialize model weights to random values."""
    for m in model.modules():
        t = type(m)
        if t is nn.Conv2d:
            pass  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif t is nn.BatchNorm2d:
            m.eps = 1e-3
            m.momentum = 0.03
        elif t in {nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU}:
            m.inplace = True

def scale_img(img, ratio=1.0, same_shape=False, gs=32):
    """Scales and pads an image tensor, optionally maintaining aspect ratio and padding to gs multiple."""
    if ratio == 1.0:
        return img
    h, w = img.shape[2:]
    s = (int(h * ratio), int(w * ratio))  # new size
    img = F.interpolate(img, size=s, mode="bilinear", align_corners=False)  # resize
    if not same_shape:  # pad/crop img
        h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
    return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # value = imagenet mean

def copy_attr(a, b, include=(), exclude=()):
    """Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes."""
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith("_") or k in exclude:
            continue
        else:
            setattr(a, k, v)

def get_latest_opset():
    """Return the second-most recent ONNX opset version supported by this version of PyTorch, adjusted for maturity."""
    if TORCH_1_13:
        # For PyTorch>=1.13, dynamically compute the latest opset minus one using 'symbolic_opset'
        return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1
    # Otherwise for PyTorch<=1.12 return the corresponding predefined opset
    version = torch.onnx.producer_version.rsplit(".", 1)[0]  # i.e. '1.12'
    return {"1.12": 15, "1.11": 14, "1.10": 13, "1.9": 12, "1.8": 12}.get(version, 12)
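
# Usage sketch (illustrative; `model`, `im` and the output path are assumptions):
#     torch.onnx.export(model, im, "model.onnx", opset_version=get_latest_opset())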

def intersect_dicts(da, db, exclude=()):
    """Returns a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values."""
    return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
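
# Usage sketch (a minimal transfer-learning example; `ckpt` and `model` are assumptions): load only the
# checkpoint layers whose names and shapes match the new model.
#     csd = intersect_dicts(ckpt["model"].float().state_dict(), model.state_dict())
#     model.load_state_dict(csd, strict=False)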

def is_parallel(model):
    """Returns True if model is of type DP or DDP."""
    return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))


def de_parallel(model):
    """De-parallelize a model: returns single-GPU model if model is of type DP or DDP."""
    return model.module if is_parallel(model) else model

def one_cycle(y1=0.0, y2=1.0, steps=100):
    """Returns a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf."""
    return lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + y1
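
# Usage sketch (illustrative; `optimizer` and `epochs` are assumptions): pair with LambdaLR for a cosine
# learning-rate schedule that ramps from lr0 down to lr0 * 0.01.
#     lf = one_cycle(1.0, 0.01, steps=epochs)
#     scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)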

def init_seeds(seed=0, deterministic=False):
    """Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # for Multi-GPU, exception safe
    # torch.backends.cudnn.benchmark = True  # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
    if deterministic:
        if TORCH_2_0:
            torch.use_deterministic_algorithms(True, warn_only=True)  # warn if deterministic is not possible
            torch.backends.cudnn.deterministic = True
            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
            os.environ["PYTHONHASHSEED"] = str(seed)
        else:
            LOGGER.warning("WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.")
    else:
        torch.use_deterministic_algorithms(False)
        torch.backends.cudnn.deterministic = False

class ModelEMA:
    """
    Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models.

    Keeps a moving average of everything in the model state_dict (parameters and buffers).
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

    To disable EMA set the `enabled` attribute to `False`.
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        """Initialize EMA for 'model' with given arguments."""
        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
        for p in self.ema.parameters():
            p.requires_grad_(False)
        self.enabled = True

    def update(self, model):
        """Update EMA parameters."""
        if self.enabled:
            self.updates += 1
            d = self.decay(self.updates)
            msd = de_parallel(model).state_dict()  # model state_dict
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:  # true for FP16 and FP32
                    v *= d
                    v += (1 - d) * msd[k].detach()
                    # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}'

    def update_attr(self, model, include=(), exclude=("process_group", "reducer")):
        """Updates attributes and saves stripped model with optimizer removed."""
        if self.enabled:
            copy_attr(self.ema, model, include, exclude)
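
# Usage sketch (a minimal training loop; `model`, `optimizer` and `loader` are assumptions):
#     ema = ModelEMA(model)
#     for im, targets in loader:
#         loss = model(im, targets)
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()
#         ema.update(model)  # update the moving average after each optimizer step
#     # validate and deploy with ema.ema, the smoothed copy of the model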

def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict = None) -> dict:
    """
    Strip optimizer from 'f' to finalize training, optionally save as 's'.

    Args:
        f (str | Path): file path to model to strip the optimizer from. Default is 'best.pt'.
        s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
        updates (dict): a dictionary of updates to overlay onto the checkpoint before saving.

    Returns:
        (dict): The combined checkpoint dictionary.

    Example:
        ```python
        from pathlib import Path

        from ultralytics.utils.torch_utils import strip_optimizer

        for f in Path("path/to/model/checkpoints").rglob("*.pt"):
            strip_optimizer(f)
        ```

    Note:
        Use `ultralytics.nn.torch_safe_load` for missing modules with `x = torch_safe_load(f)[0]`
    """
    try:
        x = torch.load(f, map_location=torch.device("cpu"))
        assert isinstance(x, dict), "checkpoint is not a Python dictionary"
        assert "model" in x, "'model' missing from checkpoint"
    except Exception as e:
        LOGGER.warning(f"WARNING ⚠️ Skipping {f}, not a valid Ultralytics model: {e}")
        return {}

    metadata = {
        "date": datetime.now().isoformat(),
        "version": __version__,
        "license": "AGPL-3.0 License (https://ultralytics.com/license)",
        "docs": "https://docs.ultralytics.com",
    }

    # Update model
    if x.get("ema"):
        x["model"] = x["ema"]  # replace model with EMA
    if hasattr(x["model"], "args"):
        x["model"].args = dict(x["model"].args)  # convert from IterableSimpleNamespace to dict
    if hasattr(x["model"], "criterion"):
        x["model"].criterion = None  # strip loss criterion
    x["model"].half()  # to FP16
    for p in x["model"].parameters():
        p.requires_grad = False

    # Update other keys
    args = {**DEFAULT_CFG_DICT, **x.get("train_args", {})}  # combine args
    for k in "optimizer", "best_fitness", "ema", "updates":  # keys
        x[k] = None
    x["epoch"] = -1
    x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # strip non-default keys
    # x['model'].args = x['train_args']

    # Save
    combined = {**metadata, **x, **(updates or {})}
    torch.save(combined, s or f)  # combine dicts (prefer to the right)
    mb = os.path.getsize(s or f) / 1e6  # file size
    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
    return combined

def convert_optimizer_state_dict_to_fp16(state_dict):
    """
    Converts the state_dict of a given optimizer to FP16, focusing on the 'state' key for tensor conversions.

    This function aims to reduce storage size without altering 'param_groups' as they contain non-tensor data.
    """
    for state in state_dict["state"].values():
        for k, v in state.items():
            if k != "step" and isinstance(v, torch.Tensor) and v.dtype is torch.float32:
                state[k] = v.half()

    return state_dict
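
# Usage sketch (illustrative; the checkpoint layout is an assumption): shrink an optimizer state_dict before
# saving it inside a checkpoint.
#     ckpt = {"optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(optimizer.state_dict()))}
#     torch.save(ckpt, "last.pt")  # 'param_groups' and 'step' entries are left untouched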

@contextmanager
def cuda_memory_usage(device=None):
    """
    Monitor and manage CUDA memory usage.

    This function checks if CUDA is available and, if so, empties the CUDA cache to free up unused memory.
    It then yields a dictionary containing memory usage information, which can be updated by the caller.
    Finally, it updates the dictionary with the amount of memory reserved by CUDA on the specified device.

    Args:
        device (torch.device, optional): The CUDA device to query memory usage for. Defaults to None.

    Yields:
        (dict): A dictionary with a key 'memory' initialized to 0, which will be updated with the reserved memory.
    """
    cuda_info = dict(memory=0)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        try:
            yield cuda_info
        finally:
            cuda_info["memory"] = torch.cuda.memory_reserved(device)
    else:
        yield cuda_info
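
# Usage sketch (illustrative; `model`, `im` and `device` are assumptions): measure memory reserved by a forward pass.
#     with cuda_memory_usage(device) as info:
#         y = model(im)
#     LOGGER.info(f"{info['memory'] / 1e9:.3f} GB reserved")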

def profile(input, ops, n=10, device=None, max_num_obj=0):
    """
    Ultralytics speed, memory and FLOPs profiler.

    Example:
        ```python
        from ultralytics.utils.torch_utils import profile

        input = torch.randn(16, 3, 640, 640)
        m1 = lambda x: x * torch.sigmoid(x)
        m2 = nn.SiLU()
        profile(input, [m1, m2], n=100)  # profile over 100 iterations
        ```
    """
    results = []
    if not isinstance(device, torch.device):
        device = select_device(device)
    LOGGER.info(
        f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
        f"{'input':>24s}{'output':>24s}"
    )
    gc.collect()  # attempt to free unused memory
    torch.cuda.empty_cache()
    for x in input if isinstance(input, list) else [input]:
        x = x.to(device)
        x.requires_grad = True
        for m in ops if isinstance(ops, list) else [ops]:
            m = m.to(device) if hasattr(m, "to") else m  # device
            m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
            tf, tb, t = 0, 0, [0, 0, 0]  # dt forward, backward
            try:
                flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2  # GFLOPs
            except Exception:
                flops = 0
            try:
                mem = 0
                for _ in range(n):
                    with cuda_memory_usage(device) as cuda_info:
                        t[0] = time_sync()
                        y = m(x)
                        t[1] = time_sync()
                        try:
                            (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
                            t[2] = time_sync()
                        except Exception:  # no backward method
                            # print(e)  # for debug
                            t[2] = float("nan")
                    mem += cuda_info["memory"] / 1e9  # (GB)
                    tf += (t[1] - t[0]) * 1000 / n  # ms per op forward
                    tb += (t[2] - t[1]) * 1000 / n  # ms per op backward
                    if max_num_obj:  # simulate training with predictions per image grid (for AutoBatch)
                        with cuda_memory_usage(device) as cuda_info:
                            torch.randn(
                                x.shape[0],
                                max_num_obj,
                                int(sum((x.shape[-1] / s) * (x.shape[-2] / s) for s in m.stride.tolist())),
                                device=device,
                                dtype=torch.float32,
                            )
                        mem += cuda_info["memory"] / 1e9  # (GB)
                s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y))  # shapes
                p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0  # parameters
                LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}")
                results.append([p, flops, mem, tf, tb, s_in, s_out])
            except Exception as e:
                LOGGER.info(e)
                results.append(None)
            finally:
                gc.collect()  # attempt to free unused memory
                torch.cuda.empty_cache()
    return results

class EarlyStopping:
    """Early stopping class that stops training when a specified number of epochs have passed without improvement."""

    def __init__(self, patience=50):
        """
        Initialize early stopping object.

        Args:
            patience (int, optional): Number of epochs to wait after fitness stops improving before stopping.
        """
        self.best_fitness = 0.0  # i.e. mAP
        self.best_epoch = 0
        self.patience = patience or float("inf")  # epochs to wait after fitness stops improving to stop
        self.possible_stop = False  # possible stop may occur next epoch

    def __call__(self, epoch, fitness):
        """
        Check whether to stop training.

        Args:
            epoch (int): Current epoch of training
            fitness (float): Fitness value of current epoch

        Returns:
            (bool): True if training should stop, False otherwise
        """
        if fitness is None:  # check if fitness=None (happens when val=False)
            return False

        if fitness >= self.best_fitness:  # >= 0 to allow for early zero-fitness stage of training
            self.best_epoch = epoch
            self.best_fitness = fitness
        delta = epoch - self.best_epoch  # epochs without improvement
        self.possible_stop = delta >= (self.patience - 1)  # possible stop may occur next epoch
        stop = delta >= self.patience  # stop training if patience exceeded
        if stop:
            prefix = colorstr("EarlyStopping: ")
            LOGGER.info(
                f"{prefix}Training stopped early as no improvement observed in last {self.patience} epochs. "
                f"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n"
                f"To update EarlyStopping(patience={self.patience}) pass a new patience value, "
                f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping."
            )
        return stop
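
# Usage sketch (a minimal loop; `validate` returning a fitness float, `model` and `epochs` are assumptions):
#     stopper = EarlyStopping(patience=50)
#     for epoch in range(epochs):
#         fitness = validate(model)
#         if stopper(epoch, fitness):
#             break  # patience exceeded, stop training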

class FXModel(nn.Module):
    """
    A custom model class for torch.fx compatibility.

    This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph
    manipulation. It copies attributes from an existing model and explicitly sets the model attribute to ensure proper
    copying.

    Args:
        model (torch.nn.Module): The original model to wrap for torch.fx compatibility.
    """

    def __init__(self, model):
        """
        Initialize the FXModel.

        Args:
            model (torch.nn.Module): The original model to wrap for torch.fx compatibility.
        """
        super().__init__()
        copy_attr(self, model)
        # Explicitly set `model`: copy_attr reads b.__dict__ and skips private keys, so submodules registered in
        # `_modules` (including `model`) are not copied.
        self.model = model.model

    def forward(self, x):
        """
        Forward pass through the model.

        This method performs the forward pass through the model, handling the dependencies between layers and saving
        intermediate outputs.

        Args:
            x (torch.Tensor): The input tensor to the model.

        Returns:
            (torch.Tensor): The output tensor from the model.
        """
        y = []  # outputs
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            x = m(x)  # run
            y.append(x)  # save output
        return x
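
# Usage sketch (illustrative; assumes `model` is a YOLO-style module whose `.model` attribute is a layer list,
# and that its submodules are fx-traceable):
#     fx_model = FXModel(model)
#     graph_module = torch.fx.symbolic_trace(fx_model)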