tasks.py

# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import contextlib
import pickle
import re
import types
from copy import deepcopy
from pathlib import Path

try:
    import thop
except ImportError:
    thop = None  # FLOPs profiling in _profile_one_layer is skipped when thop is unavailable

import torch
import torch.nn as nn

from ultralytics.nn.modules import (
    AIFI,
    C1,
    C2,
    C2PSA,
    C3,
    C3TR,
    ELAN1,
    OBB,
    PSA,
    SPP,
    SPPELAN,
    SPPF,
    AConv,
    ADown,
    Bottleneck,
    BottleneckCSP,
    C2f,
    C2fAttn,
    C2fCIB,
    C2fPSA,
    C3Ghost,
    C3k2,
    C3x,
    CBFuse,
    CBLinear,
    Classify,
    Concat,
    Conv,
    Conv2,
    DSConv,
    ConvTranspose,
    Detect,
    DWConv,
    DWConvTranspose2d,
    DSC3k2,
    Focus,
    GhostBottleneck,
    GhostConv,
    HGBlock,
    HGStem,
    ImagePoolingAttn,
    Index,
    Pose,
    RepC3,
    RepConv,
    RepNCSPELAN4,
    RepVGGDW,
    ResNetLayer,
    RTDETRDecoder,
    SCDown,
    Segment,
    TorchVision,
    WorldDetect,
    v10Detect,
    A2C2f,
    HyperACE,
    DownsampleConv,
    FullPAD_Tunnel,
)
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
from ultralytics.utils.loss import (
    E2EDetectLoss,
    v8ClassificationLoss,
    v8DetectionLoss,
    v8OBBLoss,
    v8PoseLoss,
    v8SegmentationLoss,
)
from ultralytics.utils.ops import make_divisible
from ultralytics.utils.plotting import feature_visualization
from ultralytics.utils.torch_utils import (
    fuse_conv_and_bn,
    fuse_deconv_and_bn,
    initialize_weights,
    intersect_dicts,
    model_info,
    scale_img,
    time_sync,
)


class BaseModel(nn.Module):
    """The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family."""

    def forward(self, x, *args, **kwargs):
        """
        Perform forward pass of the model for either training or inference.

        If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference.

        Args:
            x (torch.Tensor | dict): Input tensor for inference, or dict with image tensor and labels for training.
            *args (Any): Variable length argument list.
            **kwargs (Any): Arbitrary keyword arguments.

        Returns:
            (torch.Tensor): Loss if x is a dict (training), or network predictions (inference).
        """
        if isinstance(x, dict):  # for cases of training and validating while training.
            return self.loss(x, *args, **kwargs)
        return self.predict(x, *args, **kwargs)

    def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
        """
        Perform a forward pass through the network.

        Args:
            x (torch.Tensor): The input tensor to the model.
            profile (bool): Print the computation time of each layer if True, defaults to False.
            visualize (bool): Save the feature maps of the model if True, defaults to False.
            augment (bool): Augment image during prediction, defaults to False.
            embed (list, optional): A list of feature vectors/embeddings to return.

        Returns:
            (torch.Tensor): The last output of the model.
        """
        if augment:
            return self._predict_augment(x)
        return self._predict_once(x, profile, visualize, embed)

    def _predict_once(self, x, profile=False, visualize=False, embed=None):
        """
        Perform a forward pass through the network.

        Args:
            x (torch.Tensor): The input tensor to the model.
            profile (bool): Print the computation time of each layer if True, defaults to False.
            visualize (bool): Save the feature maps of the model if True, defaults to False.
            embed (list, optional): A list of feature vectors/embeddings to return.

        Returns:
            (torch.Tensor): The last output of the model.
        """
        y, dt, embeddings = [], [], []  # outputs
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            x = m(x)  # run
            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
            if embed and m.i in embed:
                embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
                if m.i == max(embed):
                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
        return x

    def _predict_augment(self, x):
        """Perform augmentations on input image x and return augmented inference."""
        LOGGER.warning(
            f"WARNING ⚠️ {self.__class__.__name__} does not support 'augment=True' prediction. "
            f"Reverting to single-scale prediction."
        )
        return self._predict_once(x)

    def _profile_one_layer(self, m, x, dt):
        """
        Profile the computation time and FLOPs of a single layer of the model on a given input. Appends the results to
        the provided list.

        Args:
            m (nn.Module): The layer to be profiled.
            x (torch.Tensor): The input data to the layer.
            dt (list): A list to store the computation time of the layer.

        Returns:
            None
        """
        c = m == self.model[-1] and isinstance(x, list)  # is final layer list, copy input as inplace fix
        flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0  # GFLOPs
        t = time_sync()
        for _ in range(10):
            m(x.copy() if c else x)
        dt.append((time_sync() - t) * 100)
        if m == self.model[0]:
            LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
        LOGGER.info(f"{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f} {m.type}")
        if c:
            LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")

    def fuse(self, verbose=True):
        """
        Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer, in order to improve the
        computation efficiency.

        Returns:
            (nn.Module): The fused model is returned.
        """
        if not self.is_fused():
            for m in self.model.modules():
                if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, "bn"):
                    if isinstance(m, Conv2):
                        m.fuse_convs()
                    m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
                    delattr(m, "bn")  # remove batchnorm
                    m.forward = m.forward_fuse  # update forward
                if isinstance(m, ConvTranspose) and hasattr(m, "bn"):
                    m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)
                    delattr(m, "bn")  # remove batchnorm
                    m.forward = m.forward_fuse  # update forward
                if isinstance(m, RepConv):
                    m.fuse_convs()
                    m.forward = m.forward_fuse  # update forward
                if isinstance(m, RepVGGDW):
                    m.fuse()
                    m.forward = m.forward_fuse
            self.info(verbose=verbose)
        return self

    def is_fused(self, thresh=10):
        """
        Check if the model has less than a certain threshold of BatchNorm layers.

        Args:
            thresh (int, optional): The threshold number of BatchNorm layers. Default is 10.

        Returns:
            (bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
        """
        bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
        return sum(isinstance(v, bn) for v in self.modules()) < thresh  # True if < 'thresh' BatchNorm layers in model

    def info(self, detailed=False, verbose=True, imgsz=640):
        """
        Prints model information.

        Args:
            detailed (bool): if True, prints out detailed information about the model. Defaults to False.
            verbose (bool): if True, prints out the model information. Defaults to True.
            imgsz (int): the size of the image that the model will be trained on. Defaults to 640.
        """
        return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)

    def _apply(self, fn):
        """
        Applies a function to all the tensors in the model that are not parameters or registered buffers.

        Args:
            fn (function): the function to apply to the model

        Returns:
            (BaseModel): An updated BaseModel object.
        """
        self = super()._apply(fn)
        m = self.model[-1]  # Detect()
        if isinstance(m, Detect):  # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
            m.stride = fn(m.stride)
            m.anchors = fn(m.anchors)
            m.strides = fn(m.strides)
        return self

    def load(self, weights, verbose=True):
        """
        Load the weights into the model.

        Args:
            weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
            verbose (bool, optional): Whether to log the transfer progress. Defaults to True.
        """
        model = weights["model"] if isinstance(weights, dict) else weights  # torchvision models are not dicts
        csd = model.float().state_dict()  # checkpoint state_dict as FP32
        csd = intersect_dicts(csd, self.state_dict())  # intersect
        self.load_state_dict(csd, strict=False)  # load
        if verbose:
            LOGGER.info(f"Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights")

    def loss(self, batch, preds=None):
        """
        Compute loss.

        Args:
            batch (dict): Batch to compute loss on.
            preds (torch.Tensor | List[torch.Tensor]): Predictions.
        """
        if getattr(self, "criterion", None) is None:
            self.criterion = self.init_criterion()

        preds = self.forward(batch["img"]) if preds is None else preds
        return self.criterion(preds, batch)

    def init_criterion(self):
        """Initialize the loss criterion for the BaseModel."""
        raise NotImplementedError("compute_loss() needs to be implemented by task heads")
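
# Usage note (illustrative, not part of the original module): BaseModel.forward() dispatches on its input type —
# a dict (a training batch with "img" and labels) routes to loss(), while a plain tensor routes to predict().
# A minimal sketch, assuming the bundled "yolov8n.yaml" config resolves on this installation:
#
#     import torch
#     from ultralytics.nn.tasks import DetectionModel
#
#     model = DetectionModel("yolov8n.yaml", ch=3, nc=80, verbose=False)
#     preds = model(torch.zeros(1, 3, 640, 640))        # tensor in -> predictions out (predict path)
#     # batch = {"img": ..., "cls": ..., "bboxes": ..., "batch_idx": ...}
#     # total_loss, loss_items = model(batch)           # dict in -> loss via self.loss()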


class DetectionModel(BaseModel):
    """YOLOv8 detection model."""

    def __init__(self, cfg="yolov8n.yaml", ch=3, nc=None, verbose=True):  # model, input channels, number of classes
        """Initialize the YOLOv8 detection model with the given config and parameters."""
        super().__init__()
        self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict
        if self.yaml["backbone"][0][2] == "Silence":
            LOGGER.warning(
                "WARNING ⚠️ YOLOv9 `Silence` module is deprecated in favor of nn.Identity. "
                "Please delete local *.pt file and re-download the latest model checkpoint."
            )
            self.yaml["backbone"][0][2] = "nn.Identity"

        # Define model
        ch = self.yaml["ch"] = self.yaml.get("ch", ch)  # input channels
        if nc and nc != self.yaml["nc"]:
            LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
            self.yaml["nc"] = nc  # override YAML value
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)  # model, savelist
        self.names = {i: f"{i}" for i in range(self.yaml["nc"])}  # default names dict
        self.inplace = self.yaml.get("inplace", True)
        self.end2end = getattr(self.model[-1], "end2end", False)

        # Build strides
        m = self.model[-1]  # Detect()
        if isinstance(m, Detect):  # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
            s = 256  # 2x min stride
            m.inplace = self.inplace

            def _forward(x):
                """Performs a forward pass through the model, handling different Detect subclass types accordingly."""
                if self.end2end:
                    return self.forward(x)["one2many"]
                return self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)

            m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))])  # forward
            self.stride = m.stride
            m.bias_init()  # only run once
        else:
            self.stride = torch.Tensor([32])  # default stride for e.g. RTDETR

        # Init weights, biases
        initialize_weights(self)
        if verbose:
            self.info()
            LOGGER.info("")

    def _predict_augment(self, x):
        """Perform augmentations on input image x and return augmented inference and train outputs."""
        if getattr(self, "end2end", False) or self.__class__.__name__ != "DetectionModel":
            LOGGER.warning("WARNING ⚠️ Model does not support 'augment=True', reverting to single-scale prediction.")
            return self._predict_once(x)
        img_size = x.shape[-2:]  # height, width
        s = [1, 0.83, 0.67]  # scales
        f = [None, 3, None]  # flips (2-ud, 3-lr)
        y = []  # outputs
        for si, fi in zip(s, f):
            xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
            yi = super().predict(xi)[0]  # forward
            yi = self._descale_pred(yi, fi, si, img_size)
            y.append(yi)
        y = self._clip_augmented(y)  # clip augmented tails
        return torch.cat(y, -1), None  # augmented inference, train

    @staticmethod
    def _descale_pred(p, flips, scale, img_size, dim=1):
        """De-scale predictions following augmented inference (inverse operation)."""
        p[:, :4] /= scale  # de-scale
        x, y, wh, cls = p.split((1, 1, 2, p.shape[dim] - 4), dim)
        if flips == 2:
            y = img_size[0] - y  # de-flip ud
        elif flips == 3:
            x = img_size[1] - x  # de-flip lr
        return torch.cat((x, y, wh, cls), dim)

    def _clip_augmented(self, y):
        """Clip YOLO augmented inference tails."""
        nl = self.model[-1].nl  # number of detection layers (P3-P5)
        g = sum(4**x for x in range(nl))  # grid points
        e = 1  # exclude layer count
        i = (y[0].shape[-1] // g) * sum(4**x for x in range(e))  # indices
        y[0] = y[0][..., :-i]  # large
        i = (y[-1].shape[-1] // g) * sum(4 ** (nl - 1 - x) for x in range(e))  # indices
        y[-1] = y[-1][..., i:]  # small
        return y

    def init_criterion(self):
        """Initialize the loss criterion for the DetectionModel."""
        return E2EDetectLoss(self) if getattr(self, "end2end", False) else v8DetectionLoss(self)
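
# Stride-construction note (illustrative): at init, DetectionModel pushes a dummy 1x{ch}x256x256 tensor through the
# network and derives the per-level strides as 256 / feature_height. A small, self-contained sketch of the same idea:
#
#     import torch
#     s = 256
#     feats = [torch.zeros(1, 64, 32, 32), torch.zeros(1, 128, 16, 16), torch.zeros(1, 256, 8, 8)]  # P3, P4, P5
#     strides = torch.tensor([s / f.shape[-2] for f in feats])  # tensor([ 8., 16., 32.])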


class OBBModel(DetectionModel):
    """YOLOv8 Oriented Bounding Box (OBB) model."""

    def __init__(self, cfg="yolov8n-obb.yaml", ch=3, nc=None, verbose=True):
        """Initialize YOLOv8 OBB model with given config and parameters."""
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

    def init_criterion(self):
        """Initialize the loss criterion for the model."""
        return v8OBBLoss(self)


class SegmentationModel(DetectionModel):
    """YOLOv8 segmentation model."""

    def __init__(self, cfg="yolov8n-seg.yaml", ch=3, nc=None, verbose=True):
        """Initialize YOLOv8 segmentation model with given config and parameters."""
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

    def init_criterion(self):
        """Initialize the loss criterion for the SegmentationModel."""
        return v8SegmentationLoss(self)


class PoseModel(DetectionModel):
    """YOLOv8 pose model."""

    def __init__(self, cfg="yolov8n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
        """Initialize YOLOv8 Pose model."""
        if not isinstance(cfg, dict):
            cfg = yaml_model_load(cfg)  # load model YAML
        if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg["kpt_shape"]):
            LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}")
            cfg["kpt_shape"] = data_kpt_shape
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

    def init_criterion(self):
        """Initialize the loss criterion for the PoseModel."""
        return v8PoseLoss(self)


class ClassificationModel(BaseModel):
    """YOLOv8 classification model."""

    def __init__(self, cfg="yolov8n-cls.yaml", ch=3, nc=None, verbose=True):
        """Init ClassificationModel with YAML, channels, number of classes, verbose flag."""
        super().__init__()
        self._from_yaml(cfg, ch, nc, verbose)

    def _from_yaml(self, cfg, ch, nc, verbose):
        """Set YOLOv8 model configurations and define the model architecture."""
        self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict

        # Define model
        ch = self.yaml["ch"] = self.yaml.get("ch", ch)  # input channels
        if nc and nc != self.yaml["nc"]:
            LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
            self.yaml["nc"] = nc  # override YAML value
        elif not nc and not self.yaml.get("nc", None):
            raise ValueError("nc not specified. Must specify nc in model.yaml or function arguments.")
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)  # model, savelist
        self.stride = torch.Tensor([1])  # no stride constraints
        self.names = {i: f"{i}" for i in range(self.yaml["nc"])}  # default names dict
        self.info()

    @staticmethod
    def reshape_outputs(model, nc):
        """Update a TorchVision classification model to class count 'nc' if required."""
        name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1]  # last module
        if isinstance(m, Classify):  # YOLO Classify() head
            if m.linear.out_features != nc:
                m.linear = nn.Linear(m.linear.in_features, nc)
        elif isinstance(m, nn.Linear):  # ResNet, EfficientNet
            if m.out_features != nc:
                setattr(model, name, nn.Linear(m.in_features, nc))
        elif isinstance(m, nn.Sequential):
            types = [type(x) for x in m]
            if nn.Linear in types:
                i = len(types) - 1 - types[::-1].index(nn.Linear)  # last nn.Linear index
                if m[i].out_features != nc:
                    m[i] = nn.Linear(m[i].in_features, nc)
            elif nn.Conv2d in types:
                i = len(types) - 1 - types[::-1].index(nn.Conv2d)  # last nn.Conv2d index
                if m[i].out_channels != nc:
                    m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)

    def init_criterion(self):
        """Initialize the loss criterion for the ClassificationModel."""
        return v8ClassificationLoss()
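
# Illustrative use of ClassificationModel.reshape_outputs() (assumes torchvision is installed): swap the final
# nn.Linear of a torchvision classifier so its output size matches a new class count.
#
#     import torchvision
#     from ultralytics.nn.tasks import ClassificationModel
#
#     tv_model = torchvision.models.resnet18(weights=None)   # last child "fc" is an nn.Linear with 1000 outputs
#     ClassificationModel.reshape_outputs(tv_model, nc=10)   # tv_model.fc now has out_features == 10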


class RTDETRDetectionModel(DetectionModel):
    """
    RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class.

    This class is responsible for constructing the RTDETR architecture, defining loss functions, and facilitating both
    the training and inference processes. RTDETR is an object detection and tracking model that extends from the
    DetectionModel base class.

    Attributes:
        cfg (str): The configuration file path or preset string. Default is 'rtdetr-l.yaml'.
        ch (int): Number of input channels. Default is 3 (RGB).
        nc (int, optional): Number of classes for object detection. Default is None.
        verbose (bool): Specifies if summary statistics are shown during initialization. Default is True.

    Methods:
        init_criterion: Initializes the criterion used for loss calculation.
        loss: Computes and returns the loss during training.
        predict: Performs a forward pass through the network and returns the output.
    """

    def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True):
        """
        Initialize the RTDETRDetectionModel.

        Args:
            cfg (str): Configuration file name or path.
            ch (int): Number of input channels.
            nc (int, optional): Number of classes. Defaults to None.
            verbose (bool, optional): Print additional information during initialization. Defaults to True.
        """
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

    def init_criterion(self):
        """Initialize the loss criterion for the RTDETRDetectionModel."""
        from ultralytics.models.utils.loss import RTDETRDetectionLoss

        return RTDETRDetectionLoss(nc=self.nc, use_vfl=True)

    def loss(self, batch, preds=None):
        """
        Compute the loss for the given batch of data.

        Args:
            batch (dict): Dictionary containing image and label data.
            preds (torch.Tensor, optional): Precomputed model predictions. Defaults to None.

        Returns:
            (tuple): A tuple containing the total loss and main three losses in a tensor.
        """
        if not hasattr(self, "criterion"):
            self.criterion = self.init_criterion()

        img = batch["img"]
        # NOTE: preprocess gt_bbox and gt_labels to list.
        bs = len(img)
        batch_idx = batch["batch_idx"]
        gt_groups = [(batch_idx == i).sum().item() for i in range(bs)]
        targets = {
            "cls": batch["cls"].to(img.device, dtype=torch.long).view(-1),
            "bboxes": batch["bboxes"].to(device=img.device),
            "batch_idx": batch_idx.to(img.device, dtype=torch.long).view(-1),
            "gt_groups": gt_groups,
        }

        preds = self.predict(img, batch=targets) if preds is None else preds
        dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1]
        if dn_meta is None:
            dn_bboxes, dn_scores = None, None
        else:
            dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta["dn_num_split"], dim=2)
            dn_scores, dec_scores = torch.split(dec_scores, dn_meta["dn_num_split"], dim=2)

        dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes])  # (7, bs, 300, 4)
        dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores])

        loss = self.criterion(
            (dec_bboxes, dec_scores), targets, dn_bboxes=dn_bboxes, dn_scores=dn_scores, dn_meta=dn_meta
        )
        # NOTE: There are like 12 losses in RTDETR, backward with all losses but only show the main three losses.
        return sum(loss.values()), torch.as_tensor(
            [loss[k].detach() for k in ["loss_giou", "loss_class", "loss_bbox"]], device=img.device
        )

    def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
        """
        Perform a forward pass through the model.

        Args:
            x (torch.Tensor): The input tensor.
            profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
            visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
            batch (dict, optional): Ground truth data for evaluation. Defaults to None.
            augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
            embed (list, optional): A list of feature vectors/embeddings to return.

        Returns:
            (torch.Tensor): Model's output tensor.
        """
        y, dt, embeddings = [], [], []  # outputs
        for m in self.model[:-1]:  # except the head part
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            x = m(x)  # run
            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
            if embed and m.i in embed:
                embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
                if m.i == max(embed):
                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
        head = self.model[-1]
        x = head([y[j] for j in head.f], batch)  # head inference
        return x


class WorldModel(DetectionModel):
    """YOLOv8 World Model."""

    def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
        """Initialize YOLOv8 world model with given config and parameters."""
        self.txt_feats = torch.randn(1, nc or 80, 512)  # features placeholder
        self.clip_model = None  # CLIP model placeholder
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

    def set_classes(self, text, batch=80, cache_clip_model=True):
        """Set classes in advance so that the model can do offline inference without the CLIP model."""
        try:
            import clip
        except ImportError:
            check_requirements("git+https://github.com/ultralytics/CLIP.git")
            import clip

        if (
            not getattr(self, "clip_model", None) and cache_clip_model
        ):  # for backwards compatibility of models lacking clip_model attribute
            self.clip_model = clip.load("ViT-B/32")[0]
        model = self.clip_model if cache_clip_model else clip.load("ViT-B/32")[0]
        device = next(model.parameters()).device
        text_token = clip.tokenize(text).to(device)
        txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]
        txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)
        txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
        self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
        self.model[-1].nc = len(text)

    def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
        """
        Perform a forward pass through the model.

        Args:
            x (torch.Tensor): The input tensor.
            profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
            visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
            txt_feats (torch.Tensor): The text features, use it if it's given. Defaults to None.
            augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
            embed (list, optional): A list of feature vectors/embeddings to return.

        Returns:
            (torch.Tensor): Model's output tensor.
        """
        txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype)
        if len(txt_feats) != len(x):
            txt_feats = txt_feats.repeat(len(x), 1, 1)
        ori_txt_feats = txt_feats.clone()
        y, dt, embeddings = [], [], []  # outputs
        for m in self.model:  # all layers, including the WorldDetect head
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            if isinstance(m, C2fAttn):
                x = m(x, txt_feats)
            elif isinstance(m, WorldDetect):
                x = m(x, ori_txt_feats)
            elif isinstance(m, ImagePoolingAttn):
                txt_feats = m(x, txt_feats)
            else:
                x = m(x)  # run

            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
            if embed and m.i in embed:
                embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
                if m.i == max(embed):
                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
        return x

    def loss(self, batch, preds=None):
        """
        Compute loss.

        Args:
            batch (dict): Batch to compute loss on.
            preds (torch.Tensor | List[torch.Tensor]): Predictions.
        """
        if not hasattr(self, "criterion"):
            self.criterion = self.init_criterion()

        if preds is None:
            preds = self.forward(batch["img"], txt_feats=batch["txt_feats"])
        return self.criterion(preds, batch)


class Ensemble(nn.ModuleList):
    """Ensemble of models."""

    def __init__(self):
        """Initialize an ensemble of models."""
        super().__init__()

    def forward(self, x, augment=False, profile=False, visualize=False):
        """Function generates the YOLO network's final layer."""
        y = [module(x, augment, profile, visualize)[0] for module in self]
        # y = torch.stack(y).max(0)[0]  # max ensemble
        # y = torch.stack(y).mean(0)  # mean ensemble
        y = torch.cat(y, 2)  # nms ensemble, y shape(B, HW, C)
        return y, None  # inference, train output
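
# Ensemble note (illustrative): the default "nms ensemble" simply concatenates each member's output tensor along
# dim 2, so downstream NMS sees the union of every member's candidate boxes. A shape-only sketch of that operation:
#
#     import torch
#     a, b = torch.zeros(1, 84, 100), torch.zeros(1, 84, 150)  # two members' raw outputs
#     merged = torch.cat([a, b], 2)                            # -> torch.Size([1, 84, 250])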


# Functions ------------------------------------------------------------------------------------------------------------


@contextlib.contextmanager
def temporary_modules(modules=None, attributes=None):
    """
    Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).

    This function can be used to change the module paths during runtime. It's useful when refactoring code,
    where you've moved a module from one location to another, but you still want to support the old import
    paths for backwards compatibility.

    Args:
        modules (dict, optional): A dictionary mapping old module paths to new module paths.
        attributes (dict, optional): A dictionary mapping old module attributes to new module attributes.

    Example:
        ```python
        with temporary_modules({"old.module": "new.module"}, {"old.module.attribute": "new.module.attribute"}):
            import old.module  # this will now import new.module
            from old.module import attribute  # this will now import new.module.attribute
        ```

    Note:
        The changes are only in effect inside the context manager and are undone once the context manager exits.
        Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
        applications or libraries. Use this function with caution.
    """
    if modules is None:
        modules = {}
    if attributes is None:
        attributes = {}
    import sys
    from importlib import import_module

    try:
        # Set attributes in sys.modules under their old name
        for old, new in attributes.items():
            old_module, old_attr = old.rsplit(".", 1)
            new_module, new_attr = new.rsplit(".", 1)
            setattr(import_module(old_module), old_attr, getattr(import_module(new_module), new_attr))

        # Set modules in sys.modules under their old name
        for old, new in modules.items():
            sys.modules[old] = import_module(new)

        yield
    finally:
        # Remove the temporary module paths
        for old in modules:
            if old in sys.modules:
                del sys.modules[old]


class SafeClass:
    """A placeholder class to replace unknown classes during unpickling."""

    def __init__(self, *args, **kwargs):
        """Initialize SafeClass instance, ignoring all arguments."""
        pass

    def __call__(self, *args, **kwargs):
        """Run SafeClass instance, ignoring all arguments."""
        pass


class SafeUnpickler(pickle.Unpickler):
    """Custom Unpickler that replaces unknown classes with SafeClass."""

    def find_class(self, module, name):
        """Attempt to find a class, returning SafeClass if not among safe modules."""
        safe_modules = (
            "torch",
            "collections",
            "collections.abc",
            "builtins",
            "math",
            "numpy",
            # Add other modules considered safe
        )
        if module in safe_modules:
            return super().find_class(module, name)
        else:
            return SafeClass


def torch_safe_load(weight, safe_only=False):
    """
    Attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the
    error, logs a warning message, and attempts to install the missing module via the check_requirements() function.
    After installation, the function again attempts to load the model using torch.load().

    Args:
        weight (str): The file path of the PyTorch model.
        safe_only (bool): If True, replace unknown classes with SafeClass during loading.

    Example:
        ```python
        from ultralytics.nn.tasks import torch_safe_load

        ckpt, file = torch_safe_load("path/to/best.pt", safe_only=True)
        ```

    Returns:
        ckpt (dict): The loaded model checkpoint.
        file (str): The loaded filename.
    """
    from ultralytics.utils.downloads import attempt_download_asset

    check_suffix(file=weight, suffix=".pt")
    file = attempt_download_asset(weight)  # search online if missing locally
    try:
        with temporary_modules(
            modules={
                "ultralytics.yolo.utils": "ultralytics.utils",
                "ultralytics.yolo.v8": "ultralytics.models.yolo",
                "ultralytics.yolo.data": "ultralytics.data",
            },
            attributes={
                "ultralytics.nn.modules.block.Silence": "torch.nn.Identity",  # YOLOv9e
                "ultralytics.nn.tasks.YOLOv10DetectionModel": "ultralytics.nn.tasks.DetectionModel",  # YOLOv10
                "ultralytics.utils.loss.v10DetectLoss": "ultralytics.utils.loss.E2EDetectLoss",  # YOLOv10
            },
        ):
            if safe_only:
                # Load via custom pickle module
                safe_pickle = types.ModuleType("safe_pickle")
                safe_pickle.Unpickler = SafeUnpickler
                safe_pickle.load = lambda file_obj: SafeUnpickler(file_obj).load()
                with open(file, "rb") as f:
                    ckpt = torch.load(f, pickle_module=safe_pickle)
            else:
                ckpt = torch.load(file, map_location="cpu")

    except ModuleNotFoundError as e:  # e.name is missing module name
        if e.name == "models":
            raise TypeError(
                emojis(
                    f"ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained "
                    f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with "
                    f"YOLOv8 at https://github.com/ultralytics/ultralytics."
                    f"\nRecommended fixes are to train a new model using the latest 'ultralytics' package or to "
                    f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolov8n.pt'"
                )
            ) from e
        LOGGER.warning(
            f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in Ultralytics requirements."
            f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
            f"\nRecommended fixes are to train a new model using the latest 'ultralytics' package or to "
            f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolov8n.pt'"
        )
        check_requirements(e.name)  # install missing module
        ckpt = torch.load(file, map_location="cpu")

    if not isinstance(ckpt, dict):
        # File is likely a YOLO instance saved with i.e. torch.save(model, "saved_model.pt")
        LOGGER.warning(
            f"WARNING ⚠️ The file '{weight}' appears to be improperly saved or formatted. "
            f"For optimal results, use model.save('filename.pt') to correctly save YOLO models."
        )
        ckpt = {"model": ckpt.model}

    return ckpt, file
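
# SafeUnpickler sketch (illustrative): any class outside the allow-listed modules is swapped for SafeClass, so
# loading an untrusted pickle cannot instantiate arbitrary project classes. A minimal, self-contained check:
#
#     import io, pickle
#     payload = pickle.dumps({"ok": 1})                      # data built only from builtin types
#     restored = SafeUnpickler(io.BytesIO(payload)).load()   # -> {"ok": 1}; unknown classes would become SafeClass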


def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
    """Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a."""
    ensemble = Ensemble()
    for w in weights if isinstance(weights, list) else [weights]:
        ckpt, w = torch_safe_load(w)  # load ckpt
        args = {**DEFAULT_CFG_DICT, **ckpt["train_args"]} if "train_args" in ckpt else None  # combined args
        model = (ckpt.get("ema") or ckpt["model"]).to(device).float()  # FP32 model

        # Model compatibility updates
        model.args = args  # attach args to model
        model.pt_path = w  # attach *.pt file path to model
        model.task = guess_model_task(model)
        if not hasattr(model, "stride"):
            model.stride = torch.tensor([32.0])

        # Append
        ensemble.append(model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval())  # model in eval mode

    # Module updates
    for m in ensemble.modules():
        if hasattr(m, "inplace"):
            m.inplace = inplace
        elif isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
            m.recompute_scale_factor = None  # torch 1.11.0 compatibility

    # Return model
    if len(ensemble) == 1:
        return ensemble[-1]

    # Return ensemble
    LOGGER.info(f"Ensemble created with {weights}\n")
    for k in "names", "nc", "yaml":
        setattr(ensemble, k, getattr(ensemble[0], k))
    ensemble.stride = ensemble[int(torch.argmax(torch.tensor([m.stride.max() for m in ensemble])))].stride
    assert all(ensemble[0].nc == m.nc for m in ensemble), f"Models differ in class counts {[m.nc for m in ensemble]}"
    return ensemble
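
# Illustrative call (assumes the listed *.pt checkpoints are available locally or downloadable as Ultralytics
# assets): a single path returns the bare model, while a list returns an Ensemble sharing names/nc/stride.
#
#     model = attempt_load_weights("yolov8n.pt", device="cpu", fuse=True)       # single fused model in eval mode
#     ens = attempt_load_weights(["yolov8n.pt", "yolov8s.pt"], device="cpu")    # Ensemble of two models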


def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
    """Loads a single model's weights."""
    ckpt, weight = torch_safe_load(weight)  # load ckpt
    args = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))}  # combine model and default args, preferring model args
    model = (ckpt.get("ema") or ckpt["model"]).to(device).float()  # FP32 model

    # Model compatibility updates
    model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # attach args to model
    model.pt_path = weight  # attach *.pt file path to model
    model.task = guess_model_task(model)
    if not hasattr(model, "stride"):
        model.stride = torch.tensor([32.0])

    model = model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval()  # model in eval mode

    # Module updates
    for m in model.modules():
        if hasattr(m, "inplace"):
            m.inplace = inplace
        elif isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
            m.recompute_scale_factor = None  # torch 1.11.0 compatibility

    # Return model and ckpt
    return model, ckpt
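
# Illustrative call (assumes "yolov8n.pt" is available): unlike attempt_load_weights(), this also returns the raw
# checkpoint dict, and model.args is filtered down to the keys present in DEFAULT_CFG_KEYS.
#
#     model, ckpt = attempt_load_one_weight("yolov8n.pt", device="cpu")
#     print(model.task)            # e.g. "detect"
#     print("model" in ckpt)       # True; ckpt is the deserialized checkpoint dict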


def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
    """Parse a YOLO model.yaml dictionary into a PyTorch model."""
    import ast

    # Args
    legacy = True  # backward compatibility for v3/v5/v8/v9 models
    max_channels = float("inf")
    nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
    depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
    if scales:
        scale = d.get("scale")
        if not scale:
            scale = tuple(scales.keys())[0]
            LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.")
        depth, width, max_channels = scales[scale]

    if act:
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
        if verbose:
            LOGGER.info(f"{colorstr('activation:')} {act}")  # print

    if verbose:
        LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
    ch = [ch]
    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):  # from, number, module, args
        m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m]  # get module
        for j, a in enumerate(args):
            if isinstance(a, str):
                with contextlib.suppress(ValueError):
                    args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
        if m in {
            Classify,
            Conv,
            ConvTranspose,
            GhostConv,
            Bottleneck,
            GhostBottleneck,
            SPP,
            SPPF,
            C2fPSA,
            C2PSA,
            DWConv,
            Focus,
            BottleneckCSP,
            C1,
            C2,
            C2f,
            C3k2,
            RepNCSPELAN4,
            ELAN1,
            ADown,
            AConv,
            SPPELAN,
            C2fAttn,
            C3,
            C3TR,
            C3Ghost,
            nn.ConvTranspose2d,
            DWConvTranspose2d,
            C3x,
            RepC3,
            PSA,
            SCDown,
            C2fCIB,
            A2C2f,
            DSC3k2,
            DSConv,
        }:
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            if m is C2fAttn:
                args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)  # embed channels
                args[2] = int(
                    max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2]
                )  # num heads

            args = [c1, c2, *args[1:]]
            if m in {
                BottleneckCSP,
                C1,
                C2,
                C2f,
                C3k2,
                C2fAttn,
                C3,
                C3TR,
                C3Ghost,
                C3x,
                RepC3,
                C2fPSA,
                C2fCIB,
                C2PSA,
                A2C2f,
                DSC3k2,
            }:
                args.insert(2, n)  # number of repeats
                n = 1
            if m in {C3k2, DSC3k2}:  # for P/U sizes
                legacy = False
                if scale in "lx":
                    args[3] = True
            if m is A2C2f:
                legacy = False
                if scale in "lx":  # for L/X sizes
                    args.append(True)
                    args.append(1.5)
        elif m is AIFI:
            args = [ch[f], *args]
        elif m in {HGStem, HGBlock}:
            c1, cm, c2 = ch[f], args[0], args[1]
            args = [c1, cm, c2, *args[2:]]
            if m is HGBlock:
                args.insert(4, n)  # number of repeats
                n = 1
        elif m is ResNetLayer:
            c2 = args[1] if args[3] else args[1] * 4
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}:
            args.append([ch[x] for x in f])
            if m is Segment:
                args[2] = make_divisible(min(args[2], max_channels) * width, 8)
            if m in {Detect, Segment, Pose, OBB}:
                m.legacy = legacy
        elif m is RTDETRDecoder:  # special case, channels arg must be passed in index 1
            args.insert(1, [ch[x] for x in f])
        elif m in {CBLinear, TorchVision, Index}:
            c2 = args[0]
            c1 = ch[f]
            args = [c1, c2, *args[1:]]
        elif m is CBFuse:
            c2 = ch[f[-1]]
        elif m is HyperACE:
            legacy = False
            c1 = ch[f[1]]
            c2 = args[0]
            c2 = make_divisible(min(c2, max_channels) * width, 8)
            he = args[1]
            if scale in "n":
                he = int(args[1] * 0.5)
            elif scale in "x":
                he = int(args[1] * 1.5)
            args = [c1, c2, n, he, *args[2:]]
            n = 1
            if scale in "lx":  # for L/X sizes
                args.append(False)
        elif m is DownsampleConv:
            c1 = ch[f]
            c2 = c1 * 2
            args = [c1]
            if scale in "lx":  # for L/X sizes
                args.append(False)
                c2 = c1
        elif m is FullPAD_Tunnel:
            c2 = ch[f[0]]
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace("__main__.", "")  # module type
        m_.np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type = i, f, t  # attach index, 'from' index, type
        if verbose:
            LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m_.np:10.0f} {t:<45}{str(args):<30}")  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)
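
# parse_model() sketch (illustrative): the dict mirrors a model YAML — each row is [from, repeats, module, args].
# A tiny, hypothetical two-layer config just to show the plumbing (not a real Ultralytics model definition):
#
#     cfg = {
#         "nc": 80,
#         "backbone": [[-1, 1, "Conv", [16, 3, 2]], [-1, 1, "Conv", [32, 3, 2]]],
#         "head": [[-1, 1, "Classify", [80]]],
#     }
#     model, save = parse_model(cfg, ch=3, verbose=False)  # -> nn.Sequential of 3 modules, empty savelist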


def yaml_model_load(path):
    """Load a YOLOv8 model from a YAML file."""
    path = Path(path)
    if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
        new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
        LOGGER.warning(f"WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.")
        path = path.with_name(new_stem + path.suffix)

    unified_path = re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", str(path))  # i.e. yolov8x.yaml -> yolov8.yaml
    yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
    d = yaml_load(yaml_file)  # model dict
    d["scale"] = guess_model_scale(path)
    d["yaml_file"] = str(path)
    return d


def guess_model_scale(model_path):
    """
    Takes a path to a YOLO model's YAML file as input and extracts the size character of the model's scale. The function
    uses regular expression matching to find the pattern of the model scale in the YAML file name, which is denoted by
    n, s, m, l, or x. The function returns the size character of the model scale as a string.

    Args:
        model_path (str | Path): The path to the YOLO model's YAML file.

    Returns:
        (str): The size character of the model's scale, which can be n, s, m, l, or x.
    """
    try:
        return re.search(r"yolo[v]?\d+([nslmx])", Path(model_path).stem).group(1)  # noqa, returns n, s, m, l, or x
    except AttributeError:
        return ""
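
# Scale-guessing sketch (illustrative): the regex pulls the single size letter out of the filename stem, and
# yaml_model_load() strips that letter to find the shared "unified" YAML for the model family.
#
#     guess_model_scale("yolov8s.yaml")          # -> "s"
#     guess_model_scale("yolo11x.yaml")          # -> "x"
#     guess_model_scale("custom_model.yaml")     # -> "" (no match)
#     re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", "yolov8x.yaml")  # -> "yolov8.yaml"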


def guess_model_task(model):
    """
    Guess the task of a PyTorch model from its architecture or configuration.

    Args:
        model (nn.Module | dict): PyTorch model or model configuration in YAML format.

    Returns:
        (str): Task of the model ('detect', 'segment', 'classify', 'pose', 'obb').

    Raises:
        SyntaxError: If the task of the model could not be determined.
    """

    def cfg2task(cfg):
        """Guess from YAML dictionary."""
        m = cfg["head"][-1][-2].lower()  # output module name
        if m in {"classify", "classifier", "cls", "fc"}:
            return "classify"
        if "detect" in m:
            return "detect"
        if m == "segment":
            return "segment"
        if m == "pose":
            return "pose"
        if m == "obb":
            return "obb"

    # Guess from model cfg
    if isinstance(model, dict):
        with contextlib.suppress(Exception):
            return cfg2task(model)

    # Guess from PyTorch model
    if isinstance(model, nn.Module):  # PyTorch model
        for x in "model.args", "model.model.args", "model.model.model.args":
            with contextlib.suppress(Exception):
                return eval(x)["task"]
        for x in "model.yaml", "model.model.yaml", "model.model.model.yaml":
            with contextlib.suppress(Exception):
                return cfg2task(eval(x))
        for m in model.modules():
            if isinstance(m, Segment):
                return "segment"
            elif isinstance(m, Classify):
                return "classify"
            elif isinstance(m, Pose):
                return "pose"
            elif isinstance(m, OBB):
                return "obb"
            elif isinstance(m, (Detect, WorldDetect, v10Detect)):
                return "detect"

    # Guess from model filename
    if isinstance(model, (str, Path)):
        model = Path(model)
        if "-seg" in model.stem or "segment" in model.parts:
            return "segment"
        elif "-cls" in model.stem or "classify" in model.parts:
            return "classify"
        elif "-pose" in model.stem or "pose" in model.parts:
            return "pose"
        elif "-obb" in model.stem or "obb" in model.parts:
            return "obb"
        elif "detect" in model.parts:
            return "detect"

    # Unable to determine task from model
    LOGGER.warning(
        "WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
        "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify', 'pose' or 'obb'."
    )
    return "detect"  # assume detect
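
# Task-guessing sketch (illustrative): filenames and YAML dicts alone are enough for the fallback heuristics.
#
#     guess_model_task("yolov8n-seg.pt")                          # -> "segment" (from "-seg" in the stem)
#     guess_model_task("yolov8n-pose.yaml")                       # -> "pose"
#     guess_model_task({"head": [[-1, 1, "Classify", [80]]]})     # -> "classify" (from the head module name)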