ops.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from scipy.optimize import linear_sum_assignment
  6. from ultralytics.utils.metrics import bbox_iou
  7. from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh
  8. class HungarianMatcher(nn.Module):
  9. """
  10. A module implementing the HungarianMatcher, which is a differentiable module to solve the assignment problem in an
  11. end-to-end fashion.
  12. HungarianMatcher performs optimal assignment over the predicted and ground truth bounding boxes using a cost
  13. function that considers classification scores, bounding box coordinates, and optionally, mask predictions.
  14. Attributes:
  15. cost_gain (dict): Dictionary of cost coefficients: 'class', 'bbox', 'giou', 'mask', and 'dice'.
  16. use_fl (bool): Indicates whether to use Focal Loss for the classification cost calculation.
  17. with_mask (bool): Indicates whether the model makes mask predictions.
  18. num_sample_points (int): The number of sample points used in mask cost calculation.
  19. alpha (float): The alpha factor in Focal Loss calculation.
  20. gamma (float): The gamma factor in Focal Loss calculation.
  21. Methods:
  22. forward(pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None): Computes the
  23. assignment between predictions and ground truths for a batch.
  24. _cost_mask(bs, num_gts, masks=None, gt_mask=None): Computes the mask cost and dice cost if masks are predicted.
  25. """
  26. def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
  27. """Initializes a HungarianMatcher module for optimal assignment of predicted and ground truth bounding boxes."""
  28. super().__init__()
  29. if cost_gain is None:
  30. cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}
  31. self.cost_gain = cost_gain
  32. self.use_fl = use_fl
  33. self.with_mask = with_mask
  34. self.num_sample_points = num_sample_points
  35. self.alpha = alpha
  36. self.gamma = gamma
  37. def forward(self, pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None):
  38. """
  39. Forward pass for HungarianMatcher. This function computes costs based on prediction and ground truth
  40. (classification cost, L1 cost between boxes and GIoU cost between boxes) and finds the optimal matching between
  41. predictions and ground truth based on these costs.
  42. Args:
  43. pred_bboxes (Tensor): Predicted bounding boxes with shape [batch_size, num_queries, 4].
  44. pred_scores (Tensor): Predicted scores with shape [batch_size, num_queries, num_classes].
  45. gt_cls (torch.Tensor): Ground truth classes with shape [num_gts, ].
  46. gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape [num_gts, 4].
  47. gt_groups (List[int]): List of length equal to batch size, containing the number of ground truths for
  48. each image.
  49. masks (Tensor, optional): Predicted masks with shape [batch_size, num_queries, height, width].
  50. Defaults to None.
  51. gt_mask (List[Tensor], optional): List of ground truth masks, each with shape [num_masks, Height, Width].
  52. Defaults to None.
  53. Returns:
  54. (List[Tuple[Tensor, Tensor]]): A list of size batch_size, each element is a tuple (index_i, index_j), where:
  55. - index_i is the tensor of indices of the selected predictions (in order)
  56. - index_j is the tensor of indices of the corresponding selected ground truth targets (in order)
  57. For each batch element, it holds:
  58. len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
  59. """
  60. bs, nq, nc = pred_scores.shape
  61. if sum(gt_groups) == 0:
  62. return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]
  63. # We flatten to compute the cost matrices in a batch
  64. # [batch_size * num_queries, num_classes]
  65. pred_scores = pred_scores.detach().view(-1, nc)
  66. pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)
  67. # [batch_size * num_queries, 4]
  68. pred_bboxes = pred_bboxes.detach().view(-1, 4)
  69. # Compute the classification cost
  70. pred_scores = pred_scores[:, gt_cls]
  71. if self.use_fl:
  72. neg_cost_class = (1 - self.alpha) * (pred_scores**self.gamma) * (-(1 - pred_scores + 1e-8).log())
  73. pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log())
  74. cost_class = pos_cost_class - neg_cost_class
  75. else:
  76. cost_class = -pred_scores
  77. # Compute the L1 cost between boxes
  78. cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1) # (bs*num_queries, num_gt)
  79. # Compute the GIoU cost between boxes, (bs*num_queries, num_gt)
  80. cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)
  81. # Final cost matrix
  82. C = (
  83. self.cost_gain["class"] * cost_class
  84. + self.cost_gain["bbox"] * cost_bbox
  85. + self.cost_gain["giou"] * cost_giou
  86. )
  87. # Compute the mask cost and dice cost
  88. if self.with_mask:
  89. C += self._cost_mask(bs, gt_groups, masks, gt_mask)
  90. # Set invalid values (NaNs and infinities) to 0 (fixes ValueError: matrix contains invalid numeric entries)
  91. C[C.isnan() | C.isinf()] = 0.0
  92. C = C.view(bs, nq, -1).cpu()
  93. indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
  94. gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0) # (idx for queries, idx for gt)
  95. return [
  96. (torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k])
  97. for k, (i, j) in enumerate(indices)
  98. ]
  99. # This function is for future RT-DETR Segment models
  100. # def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
  101. # assert masks is not None and gt_mask is not None, 'Make sure the input has `mask` and `gt_mask`'
  102. # # all masks share the same set of points for efficient matching
  103. # sample_points = torch.rand([bs, 1, self.num_sample_points, 2])
  104. # sample_points = 2.0 * sample_points - 1.0
  105. #
  106. # out_mask = F.grid_sample(masks.detach(), sample_points, align_corners=False).squeeze(-2)
  107. # out_mask = out_mask.flatten(0, 1)
  108. #
  109. # tgt_mask = torch.cat(gt_mask).unsqueeze(1)
  110. # sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
  111. # tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2])
  112. #
  113. # with torch.amp.autocast("cuda", enabled=False):
  114. # # binary cross entropy cost
  115. # pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
  116. # neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')
  117. # cost_mask = torch.matmul(pos_cost_mask, tgt_mask.T) + torch.matmul(neg_cost_mask, 1 - tgt_mask.T)
  118. # cost_mask /= self.num_sample_points
  119. #
  120. # # dice cost
  121. # out_mask = F.sigmoid(out_mask)
  122. # numerator = 2 * torch.matmul(out_mask, tgt_mask.T)
  123. # denominator = out_mask.sum(-1, keepdim=True) + tgt_mask.sum(-1).unsqueeze(0)
  124. # cost_dice = 1 - (numerator + 1) / (denominator + 1)
  125. #
  126. # C = self.cost_gain['mask'] * cost_mask + self.cost_gain['dice'] * cost_dice
  127. # return C
  128. def get_cdn_group(
  129. batch, num_classes, num_queries, class_embed, num_dn=100, cls_noise_ratio=0.5, box_noise_scale=1.0, training=False
  130. ):
  131. """
  132. Get contrastive denoising training group. This function creates a contrastive denoising training group with positive
  133. and negative samples from the ground truths (gt). It applies noise to the class labels and bounding box coordinates,
  134. and returns the modified labels, bounding boxes, attention mask and meta information.
  135. Args:
  136. batch (dict): A dict that includes 'gt_cls' (torch.Tensor with shape [num_gts, ]), 'gt_bboxes'
  137. (torch.Tensor with shape [num_gts, 4]), 'gt_groups' (List(int)) which is a list of batch size length
  138. indicating the number of gts of each image.
  139. num_classes (int): Number of classes.
  140. num_queries (int): Number of queries.
  141. class_embed (torch.Tensor): Embedding weights to map class labels to embedding space.
  142. num_dn (int, optional): Number of denoising. Defaults to 100.
  143. cls_noise_ratio (float, optional): Noise ratio for class labels. Defaults to 0.5.
  144. box_noise_scale (float, optional): Noise scale for bounding box coordinates. Defaults to 1.0.
  145. training (bool, optional): If it's in training mode. Defaults to False.
  146. Returns:
  147. (Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Dict]]): The modified class embeddings,
  148. bounding boxes, attention mask and meta information for denoising. If not in training mode or 'num_dn'
  149. is less than or equal to 0, the function returns None for all elements in the tuple.
  150. """
  151. if (not training) or num_dn <= 0:
  152. return None, None, None, None
  153. gt_groups = batch["gt_groups"]
  154. total_num = sum(gt_groups)
  155. max_nums = max(gt_groups)
  156. if max_nums == 0:
  157. return None, None, None, None
  158. num_group = num_dn // max_nums
  159. num_group = 1 if num_group == 0 else num_group
  160. # Pad gt to max_num of a batch
  161. bs = len(gt_groups)
  162. gt_cls = batch["cls"] # (bs*num, )
  163. gt_bbox = batch["bboxes"] # bs*num, 4
  164. b_idx = batch["batch_idx"]
  165. # Each group has positive and negative queries.
  166. dn_cls = gt_cls.repeat(2 * num_group) # (2*num_group*bs*num, )
  167. dn_bbox = gt_bbox.repeat(2 * num_group, 1) # 2*num_group*bs*num, 4
  168. dn_b_idx = b_idx.repeat(2 * num_group).view(-1) # (2*num_group*bs*num, )
  169. # Positive and negative mask
  170. # (bs*num*num_group, ), the second total_num*num_group part as negative samples
  171. neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num
  172. if cls_noise_ratio > 0:
  173. # Half of bbox prob
  174. mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
  175. idx = torch.nonzero(mask).squeeze(-1)
  176. # Randomly put a new one here
  177. new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
  178. dn_cls[idx] = new_label
  179. if box_noise_scale > 0:
  180. known_bbox = xywh2xyxy(dn_bbox)
  181. diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale # 2*num_group*bs*num, 4
  182. rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0
  183. rand_part = torch.rand_like(dn_bbox)
  184. rand_part[neg_idx] += 1.0
  185. rand_part *= rand_sign
  186. known_bbox += rand_part * diff
  187. known_bbox.clip_(min=0.0, max=1.0)
  188. dn_bbox = xyxy2xywh(known_bbox)
  189. dn_bbox = torch.logit(dn_bbox, eps=1e-6) # inverse sigmoid
  190. num_dn = int(max_nums * 2 * num_group) # total denoising queries
  191. # class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)])
  192. dn_cls_embed = class_embed[dn_cls] # bs*num * 2 * num_group, 256
  193. padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
  194. padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)
  195. map_indices = torch.cat([torch.tensor(range(num), dtype=torch.long) for num in gt_groups])
  196. pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0)
  197. map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)])
  198. padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed
  199. padding_bbox[(dn_b_idx, map_indices)] = dn_bbox
  200. tgt_size = num_dn + num_queries
  201. attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)
  202. # Match query cannot see the reconstruct
  203. attn_mask[num_dn:, :num_dn] = True
  204. # Reconstruct cannot see each other
  205. for i in range(num_group):
  206. if i == 0:
  207. attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
  208. if i == num_group - 1:
  209. attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * i * 2] = True
  210. else:
  211. attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
  212. attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * 2 * i] = True
  213. dn_meta = {
  214. "dn_pos_idx": [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
  215. "dn_num_group": num_group,
  216. "dn_num_split": [num_dn, num_queries],
  217. }
  218. return (
  219. padding_cls.to(class_embed.device),
  220. padding_bbox.to(class_embed.device),
  221. attn_mask.to(class_embed.device),
  222. dn_meta,
  223. )