head.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. """Model head modules."""
  3. import copy
  4. import math
  5. import torch
  6. import torch.nn as nn
  7. from torch.nn.init import constant_, xavier_uniform_
  8. from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors
  9. from .block import DFL, BNContrastiveHead, ContrastiveHead, Proto
  10. from .conv import Conv, DWConv
  11. from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
  12. from .utils import bias_init_with_prob, linear_init
  13. __all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10Detect"
  14. class Detect(nn.Module):
  15. """YOLO Detect head for detection models."""
  16. dynamic = False # force grid reconstruction
  17. export = False # export mode
  18. format = None # export format
  19. end2end = False # end2end
  20. max_det = 300 # max_det
  21. shape = None
  22. anchors = torch.empty(0) # init
  23. strides = torch.empty(0) # init
  24. legacy = False # backward compatibility for v3/v5/v8/v9 models
  25. def __init__(self, nc=80, ch=()):
  26. """Initializes the YOLO detection layer with specified number of classes and channels."""
  27. super().__init__()
  28. self.nc = nc # number of classes
  29. self.nl = len(ch) # number of detection layers
  30. self.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
  31. self.no = nc + self.reg_max * 4 # number of outputs per anchor
  32. self.stride = torch.zeros(self.nl) # strides computed during build
  33. c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) # channels
  34. self.cv2 = nn.ModuleList(
  35. nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
  36. )
  37. self.cv3 = (
  38. nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
  39. if self.legacy
  40. else nn.ModuleList(
  41. nn.Sequential(
  42. nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
  43. nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
  44. nn.Conv2d(c3, self.nc, 1),
  45. )
  46. for x in ch
  47. )
  48. )
  49. self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
  50. if self.end2end:
  51. self.one2one_cv2 = copy.deepcopy(self.cv2)
  52. self.one2one_cv3 = copy.deepcopy(self.cv3)
  53. def forward(self, x):
  54. """Concatenates and returns predicted bounding boxes and class probabilities."""
  55. if self.end2end:
  56. return self.forward_end2end(x)
  57. for i in range(self.nl):
  58. x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
  59. if self.training: # Training path
  60. return x
  61. y = self._inference(x)
  62. return y if self.export else (y, x)
  63. def forward_end2end(self, x):
  64. """
  65. Performs forward pass of the v10Detect module.
  66. Args:
  67. x (tensor): Input tensor.
  68. Returns:
  69. (dict, tensor): If not in training mode, returns a dictionary containing the outputs of both one2many and one2one detections.
  70. If in training mode, returns a dictionary containing the outputs of one2many and one2one detections separately.
  71. """
  72. x_detach = [xi.detach() for xi in x]
  73. one2one = [
  74. torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl)
  75. ]
  76. for i in range(self.nl):
  77. x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
  78. if self.training: # Training path
  79. return {"one2many": x, "one2one": one2one}
  80. y = self._inference(one2one)
  81. y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)
  82. return y if self.export else (y, {"one2many": x, "one2one": one2one})
  83. def _inference(self, x):
  84. """Decode predicted bounding boxes and class probabilities based on multiple-level feature maps."""
  85. # Inference path
  86. shape = x[0].shape # BCHW
  87. x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
  88. if self.format != "imx" and (self.dynamic or self.shape != shape):
  89. self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
  90. self.shape = shape
  91. if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}: # avoid TF FlexSplitV ops
  92. box = x_cat[:, : self.reg_max * 4]
  93. cls = x_cat[:, self.reg_max * 4 :]
  94. else:
  95. box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
  96. if self.export and self.format in {"tflite", "edgetpu"}:
  97. # Precompute normalization factor to increase numerical stability
  98. # See https://github.com/ultralytics/ultralytics/issues/7371
  99. grid_h = shape[2]
  100. grid_w = shape[3]
  101. grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
  102. norm = self.strides / (self.stride[0] * grid_size)
  103. dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
  104. elif self.export and self.format == "imx":
  105. dbox = self.decode_bboxes(
  106. self.dfl(box) * self.strides, self.anchors.unsqueeze(0) * self.strides, xywh=False
  107. )
  108. return dbox.transpose(1, 2), cls.sigmoid().permute(0, 2, 1)
  109. else:
  110. dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
  111. return torch.cat((dbox, cls.sigmoid()), 1)
  112. def bias_init(self):
  113. """Initialize Detect() biases, WARNING: requires stride availability."""
  114. m = self # self.model[-1] # Detect() module
  115. # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
  116. # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
  117. for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
  118. a[-1].bias.data[:] = 1.0 # box
  119. b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
  120. if self.end2end:
  121. for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride): # from
  122. a[-1].bias.data[:] = 1.0 # box
  123. b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
  124. def decode_bboxes(self, bboxes, anchors, xywh=True):
  125. """Decode bounding boxes."""
  126. return dist2bbox(bboxes, anchors, xywh=xywh and (not self.end2end), dim=1)
  127. @staticmethod
  128. def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):
  129. """
  130. Post-processes YOLO model predictions.
  131. Args:
  132. preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension
  133. format [x, y, w, h, class_probs].
  134. max_det (int): Maximum detections per image.
  135. nc (int, optional): Number of classes. Default: 80.
  136. Returns:
  137. (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and last
  138. dimension format [x, y, w, h, max_class_prob, class_index].
  139. """
  140. batch_size, anchors, _ = preds.shape # i.e. shape(16,8400,84)
  141. boxes, scores = preds.split([4, nc], dim=-1)
  142. index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)
  143. boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))
  144. scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))
  145. scores, index = scores.flatten(1).topk(min(max_det, anchors))
  146. i = torch.arange(batch_size)[..., None] # batch indices
  147. return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)
  148. class Segment(Detect):
  149. """YOLO Segment head for segmentation models."""
  150. def __init__(self, nc=80, nm=32, npr=256, ch=()):
  151. """Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers."""
  152. super().__init__(nc, ch)
  153. self.nm = nm # number of masks
  154. self.npr = npr # number of protos
  155. self.proto = Proto(ch[0], self.npr, self.nm) # protos
  156. c4 = max(ch[0] // 4, self.nm)
  157. self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
  158. def forward(self, x):
  159. """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
  160. p = self.proto(x[0]) # mask protos
  161. bs = p.shape[0] # batch size
  162. mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
  163. x = Detect.forward(self, x)
  164. if self.training:
  165. return x, mc, p
  166. return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
  167. class OBB(Detect):
  168. """YOLO OBB detection head for detection with rotation models."""
  169. def __init__(self, nc=80, ne=1, ch=()):
  170. """Initialize OBB with number of classes `nc` and layer channels `ch`."""
  171. super().__init__(nc, ch)
  172. self.ne = ne # number of extra parameters
  173. c4 = max(ch[0] // 4, self.ne)
  174. self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch)
  175. def forward(self, x):
  176. """Concatenates and returns predicted bounding boxes and class probabilities."""
  177. bs = x[0].shape[0] # batch size
  178. angle = torch.cat([self.cv4[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2) # OBB theta logits
  179. # NOTE: set `angle` as an attribute so that `decode_bboxes` could use it.
  180. angle = (angle.sigmoid() - 0.25) * math.pi # [-pi/4, 3pi/4]
  181. # angle = angle.sigmoid() * math.pi / 2 # [0, pi/2]
  182. if not self.training:
  183. self.angle = angle
  184. x = Detect.forward(self, x)
  185. if self.training:
  186. return x, angle
  187. return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))
  188. def decode_bboxes(self, bboxes, anchors):
  189. """Decode rotated bounding boxes."""
  190. return dist2rbox(bboxes, self.angle, anchors, dim=1)
  191. class Pose(Detect):
  192. """YOLO Pose head for keypoints models."""
  193. def __init__(self, nc=80, kpt_shape=(17, 3), ch=()):
  194. """Initialize YOLO network with default parameters and Convolutional Layers."""
  195. super().__init__(nc, ch)
  196. self.kpt_shape = kpt_shape # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
  197. self.nk = kpt_shape[0] * kpt_shape[1] # number of keypoints total
  198. c4 = max(ch[0] // 4, self.nk)
  199. self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch)
  200. def forward(self, x):
  201. """Perform forward pass through YOLO model and return predictions."""
  202. bs = x[0].shape[0] # batch size
  203. kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w)
  204. x = Detect.forward(self, x)
  205. if self.training:
  206. return x, kpt
  207. pred_kpt = self.kpts_decode(bs, kpt)
  208. return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt))
  209. def kpts_decode(self, bs, kpts):
  210. """Decodes keypoints."""
  211. ndim = self.kpt_shape[1]
  212. if self.export:
  213. if self.format in {
  214. "tflite",
  215. "edgetpu",
  216. }: # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
  217. # Precompute normalization factor to increase numerical stability
  218. y = kpts.view(bs, *self.kpt_shape, -1)
  219. grid_h, grid_w = self.shape[2], self.shape[3]
  220. grid_size = torch.tensor([grid_w, grid_h], device=y.device).reshape(1, 2, 1)
  221. norm = self.strides / (self.stride[0] * grid_size)
  222. a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * norm
  223. else:
  224. # NCNN fix
  225. y = kpts.view(bs, *self.kpt_shape, -1)
  226. a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
  227. if ndim == 3:
  228. a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
  229. return a.view(bs, self.nk, -1)
  230. else:
  231. y = kpts.clone()
  232. if ndim == 3:
  233. y[:, 2::3] = y[:, 2::3].sigmoid() # sigmoid (WARNING: inplace .sigmoid_() Apple MPS bug)
  234. y[:, 0::ndim] = (y[:, 0::ndim] * 2.0 + (self.anchors[0] - 0.5)) * self.strides
  235. y[:, 1::ndim] = (y[:, 1::ndim] * 2.0 + (self.anchors[1] - 0.5)) * self.strides
  236. return y
  237. class Classify(nn.Module):
  238. """YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2)."""
  239. export = False # export mode
  240. def __init__(self, c1, c2, k=1, s=1, p=None, g=1):
  241. """Initializes YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape."""
  242. super().__init__()
  243. c_ = 1280 # efficientnet_b0 size
  244. self.conv = Conv(c1, c_, k, s, p, g)
  245. self.pool = nn.AdaptiveAvgPool2d(1) # to x(b,c_,1,1)
  246. self.drop = nn.Dropout(p=0.0, inplace=True)
  247. self.linear = nn.Linear(c_, c2) # to x(b,c2)
  248. def forward(self, x):
  249. """Performs a forward pass of the YOLO model on input image data."""
  250. if isinstance(x, list):
  251. x = torch.cat(x, 1)
  252. x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
  253. if self.training:
  254. return x
  255. y = x.softmax(1) # get final output
  256. return y if self.export else (y, x)
  257. class WorldDetect(Detect):
  258. """Head for integrating YOLO detection models with semantic understanding from text embeddings."""
  259. def __init__(self, nc=80, embed=512, with_bn=False, ch=()):
  260. """Initialize YOLO detection layer with nc classes and layer channels ch."""
  261. super().__init__(nc, ch)
  262. c3 = max(ch[0], min(self.nc, 100))
  263. self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for x in ch)
  264. self.cv4 = nn.ModuleList(BNContrastiveHead(embed) if with_bn else ContrastiveHead() for _ in ch)
  265. def forward(self, x, text):
  266. """Concatenates and returns predicted bounding boxes and class probabilities."""
  267. for i in range(self.nl):
  268. x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), text)), 1)
  269. if self.training:
  270. return x
  271. # Inference path
  272. shape = x[0].shape # BCHW
  273. x_cat = torch.cat([xi.view(shape[0], self.nc + self.reg_max * 4, -1) for xi in x], 2)
  274. if self.dynamic or self.shape != shape:
  275. self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
  276. self.shape = shape
  277. if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}: # avoid TF FlexSplitV ops
  278. box = x_cat[:, : self.reg_max * 4]
  279. cls = x_cat[:, self.reg_max * 4 :]
  280. else:
  281. box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
  282. if self.export and self.format in {"tflite", "edgetpu"}:
  283. # Precompute normalization factor to increase numerical stability
  284. # See https://github.com/ultralytics/ultralytics/issues/7371
  285. grid_h = shape[2]
  286. grid_w = shape[3]
  287. grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
  288. norm = self.strides / (self.stride[0] * grid_size)
  289. dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
  290. else:
  291. dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
  292. y = torch.cat((dbox, cls.sigmoid()), 1)
  293. return y if self.export else (y, x)
  294. def bias_init(self):
  295. """Initialize Detect() biases, WARNING: requires stride availability."""
  296. m = self # self.model[-1] # Detect() module
  297. # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
  298. # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
  299. for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
  300. a[-1].bias.data[:] = 1.0 # box
  301. # b[-1].bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
  302. class RTDETRDecoder(nn.Module):
  303. """
  304. Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection.
  305. This decoder module utilizes Transformer architecture along with deformable convolutions to predict bounding boxes
  306. and class labels for objects in an image. It integrates features from multiple layers and runs through a series of
  307. Transformer decoder layers to output the final predictions.
  308. """
  309. export = False # export mode
  310. def __init__(
  311. self,
  312. nc=80,
  313. ch=(512, 1024, 2048),
  314. hd=256, # hidden dim
  315. nq=300, # num queries
  316. ndp=4, # num decoder points
  317. nh=8, # num head
  318. ndl=6, # num decoder layers
  319. d_ffn=1024, # dim of feedforward
  320. dropout=0.0,
  321. act=nn.ReLU(),
  322. eval_idx=-1,
  323. # Training args
  324. nd=100, # num denoising
  325. label_noise_ratio=0.5,
  326. box_noise_scale=1.0,
  327. learnt_init_query=False,
  328. ):
  329. """
  330. Initializes the RTDETRDecoder module with the given parameters.
  331. Args:
  332. nc (int): Number of classes. Default is 80.
  333. ch (tuple): Channels in the backbone feature maps. Default is (512, 1024, 2048).
  334. hd (int): Dimension of hidden layers. Default is 256.
  335. nq (int): Number of query points. Default is 300.
  336. ndp (int): Number of decoder points. Default is 4.
  337. nh (int): Number of heads in multi-head attention. Default is 8.
  338. ndl (int): Number of decoder layers. Default is 6.
  339. d_ffn (int): Dimension of the feed-forward networks. Default is 1024.
  340. dropout (float): Dropout rate. Default is 0.
  341. act (nn.Module): Activation function. Default is nn.ReLU.
  342. eval_idx (int): Evaluation index. Default is -1.
  343. nd (int): Number of denoising. Default is 100.
  344. label_noise_ratio (float): Label noise ratio. Default is 0.5.
  345. box_noise_scale (float): Box noise scale. Default is 1.0.
  346. learnt_init_query (bool): Whether to learn initial query embeddings. Default is False.
  347. """
  348. super().__init__()
  349. self.hidden_dim = hd
  350. self.nhead = nh
  351. self.nl = len(ch) # num level
  352. self.nc = nc
  353. self.num_queries = nq
  354. self.num_decoder_layers = ndl
  355. # Backbone feature projection
  356. self.input_proj = nn.ModuleList(nn.Sequential(nn.Conv2d(x, hd, 1, bias=False), nn.BatchNorm2d(hd)) for x in ch)
  357. # NOTE: simplified version but it's not consistent with .pt weights.
  358. # self.input_proj = nn.ModuleList(Conv(x, hd, act=False) for x in ch)
  359. # Transformer module
  360. decoder_layer = DeformableTransformerDecoderLayer(hd, nh, d_ffn, dropout, act, self.nl, ndp)
  361. self.decoder = DeformableTransformerDecoder(hd, decoder_layer, ndl, eval_idx)
  362. # Denoising part
  363. self.denoising_class_embed = nn.Embedding(nc, hd)
  364. self.num_denoising = nd
  365. self.label_noise_ratio = label_noise_ratio
  366. self.box_noise_scale = box_noise_scale
  367. # Decoder embedding
  368. self.learnt_init_query = learnt_init_query
  369. if learnt_init_query:
  370. self.tgt_embed = nn.Embedding(nq, hd)
  371. self.query_pos_head = MLP(4, 2 * hd, hd, num_layers=2)
  372. # Encoder head
  373. self.enc_output = nn.Sequential(nn.Linear(hd, hd), nn.LayerNorm(hd))
  374. self.enc_score_head = nn.Linear(hd, nc)
  375. self.enc_bbox_head = MLP(hd, hd, 4, num_layers=3)
  376. # Decoder head
  377. self.dec_score_head = nn.ModuleList([nn.Linear(hd, nc) for _ in range(ndl)])
  378. self.dec_bbox_head = nn.ModuleList([MLP(hd, hd, 4, num_layers=3) for _ in range(ndl)])
  379. self._reset_parameters()
  380. def forward(self, x, batch=None):
  381. """Runs the forward pass of the module, returning bounding box and classification scores for the input."""
  382. from ultralytics.models.utils.ops import get_cdn_group
  383. # Input projection and embedding
  384. feats, shapes = self._get_encoder_input(x)
  385. # Prepare denoising training
  386. dn_embed, dn_bbox, attn_mask, dn_meta = get_cdn_group(
  387. batch,
  388. self.nc,
  389. self.num_queries,
  390. self.denoising_class_embed.weight,
  391. self.num_denoising,
  392. self.label_noise_ratio,
  393. self.box_noise_scale,
  394. self.training,
  395. )
  396. embed, refer_bbox, enc_bboxes, enc_scores = self._get_decoder_input(feats, shapes, dn_embed, dn_bbox)
  397. # Decoder
  398. dec_bboxes, dec_scores = self.decoder(
  399. embed,
  400. refer_bbox,
  401. feats,
  402. shapes,
  403. self.dec_bbox_head,
  404. self.dec_score_head,
  405. self.query_pos_head,
  406. attn_mask=attn_mask,
  407. )
  408. x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta
  409. if self.training:
  410. return x
  411. # (bs, 300, 4+nc)
  412. y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1)
  413. return y if self.export else (y, x)
  414. def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device="cpu", eps=1e-2):
  415. """Generates anchor bounding boxes for given shapes with specific grid size and validates them."""
  416. anchors = []
  417. for i, (h, w) in enumerate(shapes):
  418. sy = torch.arange(end=h, dtype=dtype, device=device)
  419. sx = torch.arange(end=w, dtype=dtype, device=device)
  420. grid_y, grid_x = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
  421. grid_xy = torch.stack([grid_x, grid_y], -1) # (h, w, 2)
  422. valid_WH = torch.tensor([w, h], dtype=dtype, device=device)
  423. grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH # (1, h, w, 2)
  424. wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0**i)
  425. anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4)) # (1, h*w, 4)
  426. anchors = torch.cat(anchors, 1) # (1, h*w*nl, 4)
  427. valid_mask = ((anchors > eps) & (anchors < 1 - eps)).all(-1, keepdim=True) # 1, h*w*nl, 1
  428. anchors = torch.log(anchors / (1 - anchors))
  429. anchors = anchors.masked_fill(~valid_mask, float("inf"))
  430. return anchors, valid_mask
  431. def _get_encoder_input(self, x):
  432. """Processes and returns encoder inputs by getting projection features from input and concatenating them."""
  433. # Get projection features
  434. x = [self.input_proj[i](feat) for i, feat in enumerate(x)]
  435. # Get encoder inputs
  436. feats = []
  437. shapes = []
  438. for feat in x:
  439. h, w = feat.shape[2:]
  440. # [b, c, h, w] -> [b, h*w, c]
  441. feats.append(feat.flatten(2).permute(0, 2, 1))
  442. # [nl, 2]
  443. shapes.append([h, w])
  444. # [b, h*w, c]
  445. feats = torch.cat(feats, 1)
  446. return feats, shapes
  447. def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None):
  448. """Generates and prepares the input required for the decoder from the provided features and shapes."""
  449. bs = feats.shape[0]
  450. # Prepare input for decoder
  451. anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)
  452. features = self.enc_output(valid_mask * feats) # bs, h*w, 256
  453. enc_outputs_scores = self.enc_score_head(features) # (bs, h*w, nc)
  454. # Query selection
  455. # (bs, num_queries)
  456. topk_ind = torch.topk(enc_outputs_scores.max(-1).values, self.num_queries, dim=1).indices.view(-1)
  457. # (bs, num_queries)
  458. batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1)
  459. # (bs, num_queries, 256)
  460. top_k_features = features[batch_ind, topk_ind].view(bs, self.num_queries, -1)
  461. # (bs, num_queries, 4)
  462. top_k_anchors = anchors[:, topk_ind].view(bs, self.num_queries, -1)
  463. # Dynamic anchors + static content
  464. refer_bbox = self.enc_bbox_head(top_k_features) + top_k_anchors
  465. enc_bboxes = refer_bbox.sigmoid()
  466. if dn_bbox is not None:
  467. refer_bbox = torch.cat([dn_bbox, refer_bbox], 1)
  468. enc_scores = enc_outputs_scores[batch_ind, topk_ind].view(bs, self.num_queries, -1)
  469. embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1) if self.learnt_init_query else top_k_features
  470. if self.training:
  471. refer_bbox = refer_bbox.detach()
  472. if not self.learnt_init_query:
  473. embeddings = embeddings.detach()
  474. if dn_embed is not None:
  475. embeddings = torch.cat([dn_embed, embeddings], 1)
  476. return embeddings, refer_bbox, enc_bboxes, enc_scores
  477. # TODO
  478. def _reset_parameters(self):
  479. """Initializes or resets the parameters of the model's various components with predefined weights and biases."""
  480. # Class and bbox head init
  481. bias_cls = bias_init_with_prob(0.01) / 80 * self.nc
  482. # NOTE: the weight initialization in `linear_init` would cause NaN when training with custom datasets.
  483. # linear_init(self.enc_score_head)
  484. constant_(self.enc_score_head.bias, bias_cls)
  485. constant_(self.enc_bbox_head.layers[-1].weight, 0.0)
  486. constant_(self.enc_bbox_head.layers[-1].bias, 0.0)
  487. for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):
  488. # linear_init(cls_)
  489. constant_(cls_.bias, bias_cls)
  490. constant_(reg_.layers[-1].weight, 0.0)
  491. constant_(reg_.layers[-1].bias, 0.0)
  492. linear_init(self.enc_output[0])
  493. xavier_uniform_(self.enc_output[0].weight)
  494. if self.learnt_init_query:
  495. xavier_uniform_(self.tgt_embed.weight)
  496. xavier_uniform_(self.query_pos_head.layers[0].weight)
  497. xavier_uniform_(self.query_pos_head.layers[1].weight)
  498. for layer in self.input_proj:
  499. xavier_uniform_(layer[0].weight)
  500. class v10Detect(Detect):
  501. """
  502. v10 Detection head from https://arxiv.org/pdf/2405.14458.
  503. Args:
  504. nc (int): Number of classes.
  505. ch (tuple): Tuple of channel sizes.
  506. Attributes:
  507. max_det (int): Maximum number of detections.
  508. Methods:
  509. __init__(self, nc=80, ch=()): Initializes the v10Detect object.
  510. forward(self, x): Performs forward pass of the v10Detect module.
  511. bias_init(self): Initializes biases of the Detect module.
  512. """
  513. end2end = True
  514. def __init__(self, nc=80, ch=()):
  515. """Initializes the v10Detect object with the specified number of classes and input channels."""
  516. super().__init__(nc, ch)
  517. c3 = max(ch[0], min(self.nc, 100)) # channels
  518. # Light cls head
  519. self.cv3 = nn.ModuleList(
  520. nn.Sequential(
  521. nn.Sequential(Conv(x, x, 3, g=x), Conv(x, c3, 1)),
  522. nn.Sequential(Conv(c3, c3, 3, g=c3), Conv(c3, c3, 1)),
  523. nn.Conv2d(c3, self.nc, 1),
  524. )
  525. for x in ch
  526. )
  527. self.one2one_cv3 = copy.deepcopy(self.cv3)