ins_losses.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. import torch
  2. import torch.nn.functional as F
  3. from torch import nn
  4. class DiceLoss(nn.Module):
  5. def __init__(self, smooth=1.):
  6. super(DiceLoss, self).__init__()
  7. self.smooth = smooth
  8. def forward(self, logits, targets):
  9. probs = torch.sigmoid(logits)
  10. probs = probs.view(-1)
  11. targets = targets.view(-1).float()
  12. intersection = (probs * targets).sum()
  13. dice = (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)
  14. return 1. - dice
  15. bce_loss = nn.BCEWithLogitsLoss()
  16. dice_loss = DiceLoss()
  17. def combined_loss(preds, targets, alpha=0.5):
  18. bce = bce_loss(preds, targets)
  19. d = dice_loss(preds, targets)
  20. return alpha * bce + (1 - alpha) * d
  21. def align_masks(keypoints, rois, heatmap_size):
  22. print(f'rois:{rois.shape}')
  23. print(f'heatmap_size:{heatmap_size}')
  24. print(f'keypoints.shape:{keypoints.shape}')
  25. # batch_size, num_keypoints, _ = keypoints.shape
  26. t_h, t_w = keypoints.shape[-2:]
  27. scale=heatmap_size/t_w
  28. print(f'scale:{scale}')
  29. x = keypoints[..., 0]*scale
  30. y = keypoints[..., 1]*scale
  31. x = x.unsqueeze(1)
  32. y = y.unsqueeze(1)
  33. num_points=x.shape[2]
  34. print(f'num_points:{num_points}')
  35. mask_4d = keypoints.unsqueeze(1).float()
  36. resized_mask = F.interpolate(
  37. mask_4d,
  38. size = (heatmap_size, heatmap_size),
  39. mode = 'bilinear',
  40. align_corners = False
  41. ).squeeze(1) # [B,heatmap_size,heatmap_size]
  42. # plt.imshow(resized_mask[0].cpu())
  43. # plt.show()
  44. print(f'resized_mask:{resized_mask.shape}')
  45. return resized_mask
  46. def compute_ins_loss(feature_logits, proposals, gt_, pos_matched_idxs):
  47. print(f'compute_arc_loss:{feature_logits.shape}')
  48. N, K, H, W = feature_logits.shape
  49. len_proposals = len(proposals)
  50. empty_count = 0
  51. non_empty_count = 0
  52. for prop in proposals:
  53. if prop.shape[0] == 0:
  54. empty_count += 1
  55. else:
  56. non_empty_count += 1
  57. print(f"Empty proposals count: {empty_count}")
  58. print(f"Non-empty proposals count: {non_empty_count}")
  59. print(f'starte to compute_point_loss')
  60. print(f'compute_point_loss line_logits.shape:{feature_logits.shape},len_proposals:{len_proposals}')
  61. if H != W:
  62. raise ValueError(
  63. f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
  64. )
  65. discretization_size = H
  66. gs_heatmaps = []
  67. # print(f'point_matched_idxs:{point_matched_idxs}')
  68. print(f'gt_masks:{gt_[0].shape}')
  69. for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_, pos_matched_idxs):
  70. # [
  71. # (Tensor(38, 4), Tensor(1, 57, 2), Tensor(38, 1)),
  72. # (Tensor(65, 4), Tensor(1, 74, 2), Tensor(65, 1))
  73. # ]
  74. print(f'proposals_per_image:{proposals_per_image.shape}')
  75. kp = gt_kp_in_image[midx]
  76. t_h, t_w = kp.shape[-2:]
  77. print(f't_h:{t_h}, t_w:{t_w}')
  78. print(f'gt_kp_in_image:{gt_kp_in_image.shape}')
  79. if proposals_per_image.shape[0] > 0 and gt_kp_in_image.shape[0] > 0:
  80. gs_heatmaps_per_img = align_masks(kp, proposals_per_image, discretization_size)
  81. gs_heatmaps.append(gs_heatmaps_per_img)
  82. if len(gs_heatmaps)>0:
  83. gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
  84. print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{feature_logits.shape}')
  85. line_logits = feature_logits.squeeze(1)
  86. print(f'mask shape:{line_logits.shape}')
  87. # line_loss = F.binary_cross_entropy_with_logits(line_logits, gs_heatmaps)
  88. # line_loss = F.cross_entropy(line_logits, gs_heatmaps)
  89. line_loss=combined_loss(line_logits, gs_heatmaps)
  90. else:
  91. line_loss=100
  92. print("d")
  93. return line_loss