loss.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from ultralytics.utils.loss import FocalLoss, VarifocalLoss
  6. from ultralytics.utils.metrics import bbox_iou
  7. from .ops import HungarianMatcher
  8. class DETRLoss(nn.Module):
  9. """
  10. DETR (DEtection TRansformer) Loss class. This class calculates and returns the different loss components for the
  11. DETR object detection model. It computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary
  12. losses.
  13. Attributes:
  14. nc (int): The number of classes.
  15. loss_gain (dict): Coefficients for different loss components.
  16. aux_loss (bool): Whether to compute auxiliary losses.
  17. use_fl (bool): Use FocalLoss or not.
  18. use_vfl (bool): Use VarifocalLoss or not.
  19. use_uni_match (bool): Whether to use a fixed layer to assign labels for the auxiliary branch.
  20. uni_match_ind (int): The fixed indices of a layer to use if `use_uni_match` is True.
  21. matcher (HungarianMatcher): Object to compute matching cost and indices.
  22. fl (FocalLoss or None): Focal Loss object if `use_fl` is True, otherwise None.
  23. vfl (VarifocalLoss or None): Varifocal Loss object if `use_vfl` is True, otherwise None.
  24. device (torch.device): Device on which tensors are stored.
  25. """
  26. def __init__(
  27. self, nc=80, loss_gain=None, aux_loss=True, use_fl=True, use_vfl=False, use_uni_match=False, uni_match_ind=0
  28. ):
  29. """
  30. Initialize DETR loss function with customizable components and gains.
  31. Uses default loss_gain if not provided. Initializes HungarianMatcher with
  32. preset cost gains. Supports auxiliary losses and various loss types.
  33. Args:
  34. nc (int): Number of classes.
  35. loss_gain (dict): Coefficients for different loss components.
  36. aux_loss (bool): Use auxiliary losses from each decoder layer.
  37. use_fl (bool): Use FocalLoss.
  38. use_vfl (bool): Use VarifocalLoss.
  39. use_uni_match (bool): Use fixed layer for auxiliary branch label assignment.
  40. uni_match_ind (int): Index of fixed layer for uni_match.
  41. """
  42. super().__init__()
  43. if loss_gain is None:
  44. loss_gain = {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1}
  45. self.nc = nc
  46. self.matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
  47. self.loss_gain = loss_gain
  48. self.aux_loss = aux_loss
  49. self.fl = FocalLoss() if use_fl else None
  50. self.vfl = VarifocalLoss() if use_vfl else None
  51. self.use_uni_match = use_uni_match
  52. self.uni_match_ind = uni_match_ind
  53. self.device = None
  54. def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=""):
  55. """Computes the classification loss based on predictions, target values, and ground truth scores."""
  56. # Logits: [b, query, num_classes], gt_class: list[[n, 1]]
  57. name_class = f"loss_class{postfix}"
  58. bs, nq = pred_scores.shape[:2]
  59. # one_hot = F.one_hot(targets, self.nc + 1)[..., :-1] # (bs, num_queries, num_classes)
  60. one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)
  61. one_hot.scatter_(2, targets.unsqueeze(-1), 1)
  62. one_hot = one_hot[..., :-1]
  63. gt_scores = gt_scores.view(bs, nq, 1) * one_hot
  64. if self.fl:
  65. if num_gts and self.vfl:
  66. loss_cls = self.vfl(pred_scores, gt_scores, one_hot)
  67. else:
  68. loss_cls = self.fl(pred_scores, one_hot.float())
  69. loss_cls /= max(num_gts, 1) / nq
  70. else:
  71. loss_cls = nn.BCEWithLogitsLoss(reduction="none")(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss
  72. return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}
  73. def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""):
  74. """Computes bounding box and GIoU losses for predicted and ground truth bounding boxes."""
  75. # Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
  76. name_bbox = f"loss_bbox{postfix}"
  77. name_giou = f"loss_giou{postfix}"
  78. loss = {}
  79. if len(gt_bboxes) == 0:
  80. loss[name_bbox] = torch.tensor(0.0, device=self.device)
  81. loss[name_giou] = torch.tensor(0.0, device=self.device)
  82. return loss
  83. loss[name_bbox] = self.loss_gain["bbox"] * F.l1_loss(pred_bboxes, gt_bboxes, reduction="sum") / len(gt_bboxes)
  84. loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
  85. loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
  86. loss[name_giou] = self.loss_gain["giou"] * loss[name_giou]
  87. return {k: v.squeeze() for k, v in loss.items()}
  88. # This function is for future RT-DETR Segment models
  89. # def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''):
  90. # # masks: [b, query, h, w], gt_mask: list[[n, H, W]]
  91. # name_mask = f'loss_mask{postfix}'
  92. # name_dice = f'loss_dice{postfix}'
  93. #
  94. # loss = {}
  95. # if sum(len(a) for a in gt_mask) == 0:
  96. # loss[name_mask] = torch.tensor(0., device=self.device)
  97. # loss[name_dice] = torch.tensor(0., device=self.device)
  98. # return loss
  99. #
  100. # num_gts = len(gt_mask)
  101. # src_masks, target_masks = self._get_assigned_bboxes(masks, gt_mask, match_indices)
  102. # src_masks = F.interpolate(src_masks.unsqueeze(0), size=target_masks.shape[-2:], mode='bilinear')[0]
  103. # # TODO: torch does not have `sigmoid_focal_loss`, but it's not urgent since we don't use mask branch for now.
  104. # loss[name_mask] = self.loss_gain['mask'] * F.sigmoid_focal_loss(src_masks, target_masks,
  105. # torch.tensor([num_gts], dtype=torch.float32))
  106. # loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)
  107. # return loss
  108. # This function is for future RT-DETR Segment models
  109. # @staticmethod
  110. # def _dice_loss(inputs, targets, num_gts):
  111. # inputs = F.sigmoid(inputs).flatten(1)
  112. # targets = targets.flatten(1)
  113. # numerator = 2 * (inputs * targets).sum(1)
  114. # denominator = inputs.sum(-1) + targets.sum(-1)
  115. # loss = 1 - (numerator + 1) / (denominator + 1)
  116. # return loss.sum() / num_gts
  117. def _get_loss_aux(
  118. self,
  119. pred_bboxes,
  120. pred_scores,
  121. gt_bboxes,
  122. gt_cls,
  123. gt_groups,
  124. match_indices=None,
  125. postfix="",
  126. masks=None,
  127. gt_mask=None,
  128. ):
  129. """Get auxiliary losses."""
  130. # NOTE: loss class, bbox, giou, mask, dice
  131. loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
  132. if match_indices is None and self.use_uni_match:
  133. match_indices = self.matcher(
  134. pred_bboxes[self.uni_match_ind],
  135. pred_scores[self.uni_match_ind],
  136. gt_bboxes,
  137. gt_cls,
  138. gt_groups,
  139. masks=masks[self.uni_match_ind] if masks is not None else None,
  140. gt_mask=gt_mask,
  141. )
  142. for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):
  143. aux_masks = masks[i] if masks is not None else None
  144. loss_ = self._get_loss(
  145. aux_bboxes,
  146. aux_scores,
  147. gt_bboxes,
  148. gt_cls,
  149. gt_groups,
  150. masks=aux_masks,
  151. gt_mask=gt_mask,
  152. postfix=postfix,
  153. match_indices=match_indices,
  154. )
  155. loss[0] += loss_[f"loss_class{postfix}"]
  156. loss[1] += loss_[f"loss_bbox{postfix}"]
  157. loss[2] += loss_[f"loss_giou{postfix}"]
  158. # if masks is not None and gt_mask is not None:
  159. # loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix)
  160. # loss[3] += loss_[f'loss_mask{postfix}']
  161. # loss[4] += loss_[f'loss_dice{postfix}']
  162. loss = {
  163. f"loss_class_aux{postfix}": loss[0],
  164. f"loss_bbox_aux{postfix}": loss[1],
  165. f"loss_giou_aux{postfix}": loss[2],
  166. }
  167. # if masks is not None and gt_mask is not None:
  168. # loss[f'loss_mask_aux{postfix}'] = loss[3]
  169. # loss[f'loss_dice_aux{postfix}'] = loss[4]
  170. return loss
  171. @staticmethod
  172. def _get_index(match_indices):
  173. """Returns batch indices, source indices, and destination indices from provided match indices."""
  174. batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
  175. src_idx = torch.cat([src for (src, _) in match_indices])
  176. dst_idx = torch.cat([dst for (_, dst) in match_indices])
  177. return (batch_idx, src_idx), dst_idx
  178. def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
  179. """Assigns predicted bounding boxes to ground truth bounding boxes based on the match indices."""
  180. pred_assigned = torch.cat(
  181. [
  182. t[i] if len(i) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
  183. for t, (i, _) in zip(pred_bboxes, match_indices)
  184. ]
  185. )
  186. gt_assigned = torch.cat(
  187. [
  188. t[j] if len(j) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
  189. for t, (_, j) in zip(gt_bboxes, match_indices)
  190. ]
  191. )
  192. return pred_assigned, gt_assigned
  193. def _get_loss(
  194. self,
  195. pred_bboxes,
  196. pred_scores,
  197. gt_bboxes,
  198. gt_cls,
  199. gt_groups,
  200. masks=None,
  201. gt_mask=None,
  202. postfix="",
  203. match_indices=None,
  204. ):
  205. """Get losses."""
  206. if match_indices is None:
  207. match_indices = self.matcher(
  208. pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=masks, gt_mask=gt_mask
  209. )
  210. idx, gt_idx = self._get_index(match_indices)
  211. pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]
  212. bs, nq = pred_scores.shape[:2]
  213. targets = torch.full((bs, nq), self.nc, device=pred_scores.device, dtype=gt_cls.dtype)
  214. targets[idx] = gt_cls[gt_idx]
  215. gt_scores = torch.zeros([bs, nq], device=pred_scores.device)
  216. if len(gt_bboxes):
  217. gt_scores[idx] = bbox_iou(pred_bboxes.detach(), gt_bboxes, xywh=True).squeeze(-1)
  218. return {
  219. **self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix),
  220. **self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix),
  221. # **(self._get_loss_mask(masks, gt_mask, match_indices, postfix) if masks is not None and gt_mask is not None else {})
  222. }
  223. def forward(self, pred_bboxes, pred_scores, batch, postfix="", **kwargs):
  224. """
  225. Calculate loss for predicted bounding boxes and scores.
  226. Args:
  227. pred_bboxes (torch.Tensor): Predicted bounding boxes, shape [l, b, query, 4].
  228. pred_scores (torch.Tensor): Predicted class scores, shape [l, b, query, num_classes].
  229. batch (dict): Batch information containing:
  230. cls (torch.Tensor): Ground truth classes, shape [num_gts].
  231. bboxes (torch.Tensor): Ground truth bounding boxes, shape [num_gts, 4].
  232. gt_groups (List[int]): Number of ground truths for each image in the batch.
  233. postfix (str): Postfix for loss names.
  234. **kwargs (Any): Additional arguments, may include 'match_indices'.
  235. Returns:
  236. (dict): Computed losses, including main and auxiliary (if enabled).
  237. Note:
  238. Uses last elements of pred_bboxes and pred_scores for main loss, and the rest for auxiliary losses if
  239. self.aux_loss is True.
  240. """
  241. self.device = pred_bboxes.device
  242. match_indices = kwargs.get("match_indices", None)
  243. gt_cls, gt_bboxes, gt_groups = batch["cls"], batch["bboxes"], batch["gt_groups"]
  244. total_loss = self._get_loss(
  245. pred_bboxes[-1], pred_scores[-1], gt_bboxes, gt_cls, gt_groups, postfix=postfix, match_indices=match_indices
  246. )
  247. if self.aux_loss:
  248. total_loss.update(
  249. self._get_loss_aux(
  250. pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, postfix
  251. )
  252. )
  253. return total_loss
  254. class RTDETRDetectionLoss(DETRLoss):
  255. """
  256. Real-Time DeepTracker (RT-DETR) Detection Loss class that extends the DETRLoss.
  257. This class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well as
  258. an additional denoising training loss when provided with denoising metadata.
  259. """
  260. def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None):
  261. """
  262. Forward pass to compute the detection loss.
  263. Args:
  264. preds (tuple): Predicted bounding boxes and scores.
  265. batch (dict): Batch data containing ground truth information.
  266. dn_bboxes (torch.Tensor, optional): Denoising bounding boxes. Default is None.
  267. dn_scores (torch.Tensor, optional): Denoising scores. Default is None.
  268. dn_meta (dict, optional): Metadata for denoising. Default is None.
  269. Returns:
  270. (dict): Dictionary containing the total loss and, if applicable, the denoising loss.
  271. """
  272. pred_bboxes, pred_scores = preds
  273. total_loss = super().forward(pred_bboxes, pred_scores, batch)
  274. # Check for denoising metadata to compute denoising training loss
  275. if dn_meta is not None:
  276. dn_pos_idx, dn_num_group = dn_meta["dn_pos_idx"], dn_meta["dn_num_group"]
  277. assert len(batch["gt_groups"]) == len(dn_pos_idx)
  278. # Get the match indices for denoising
  279. match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch["gt_groups"])
  280. # Compute the denoising training loss
  281. dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix="_dn", match_indices=match_indices)
  282. total_loss.update(dn_loss)
  283. else:
  284. # If no denoising metadata is provided, set denoising loss to zero
  285. total_loss.update({f"{k}_dn": torch.tensor(0.0, device=self.device) for k in total_loss.keys()})
  286. return total_loss
  287. @staticmethod
  288. def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups):
  289. """
  290. Get the match indices for denoising.
  291. Args:
  292. dn_pos_idx (List[torch.Tensor]): List of tensors containing positive indices for denoising.
  293. dn_num_group (int): Number of denoising groups.
  294. gt_groups (List[int]): List of integers representing the number of ground truths for each image.
  295. Returns:
  296. (List[tuple]): List of tuples containing matched indices for denoising.
  297. """
  298. dn_match_indices = []
  299. idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
  300. for i, num_gt in enumerate(gt_groups):
  301. if num_gt > 0:
  302. gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
  303. gt_idx = gt_idx.repeat(dn_num_group)
  304. assert len(dn_pos_idx[i]) == len(gt_idx), "Expected the same length, "
  305. f"but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively."
  306. dn_match_indices.append((dn_pos_idx[i], gt_idx))
  307. else:
  308. dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
  309. return dn_match_indices