import logging

import torch
import torch.nn.functional as F
from torch import nn

logger = logging.getLogger(__name__)


class DiceLoss(nn.Module):
    """Soft Dice loss over every element of a batch of logits.

    The logits are passed through a sigmoid, both inputs are flattened to
    1-D, and the loss is ``1 - (2 * |P * T| + smooth) / (|P| + |T| + smooth)``
    where ``|.|`` denotes the sum over all elements.

    Args:
        smooth: Constant added to numerator and denominator; it keeps the
            ratio defined when both ``probs`` and ``targets`` sum to zero.
    """

    def __init__(self, smooth=1.):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        """Return the scalar Dice loss for ``logits`` against ``targets``.

        Args:
            logits: Raw (pre-sigmoid) predictions.
            targets: Tensor with the same number of elements as ``logits``;
                it is cast to float before use.

        Returns:
            A 0-dim tensor equal to ``1 - dice``.
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # sliced tensors, are flattened instead of raising.
        probs = torch.sigmoid(logits).reshape(-1)
        flat_targets = targets.reshape(-1).float()

        intersection = (probs * flat_targets).sum()
        dice = (2. * intersection + self.smooth) / (
            probs.sum() + flat_targets.sum() + self.smooth
        )
        return 1. - dice


bce_loss = nn.BCEWithLogitsLoss()
dice_loss = DiceLoss()


def combined_loss(preds, targets, alpha=0.5):
    """Blend binary cross-entropy (on logits) with the Dice loss.

    Args:
        preds: Raw logits.
        targets: Target tensor of the same shape as ``preds``. Integer or
            bool targets are cast to ``preds.dtype``; floating-point targets
            are used as given.
        alpha: Weight of the BCE term; the Dice term is weighted ``1 - alpha``.

    Returns:
        ``alpha * bce + (1 - alpha) * dice`` as a 0-dim tensor.
    """
    # BCEWithLogitsLoss needs a floating-point target, while DiceLoss casts
    # internally. Cast once here so both terms accept the same inputs.
    if not targets.is_floating_point():
        targets = targets.to(dtype=preds.dtype)

    bce = bce_loss(preds, targets)
    dice = dice_loss(preds, targets)
    return alpha * bce + (1 - alpha) * dice


def align_masks(keypoints, rois, heatmap_size):
    """Resize a stack of masks to ``heatmap_size x heatmap_size``.

    Despite the parameter name, ``keypoints`` is treated as a stack of 2-D
    maps: a channel axis is inserted at dim 1 and the result is resized with
    bilinear interpolation, which requires the 4-D input this produces from a
    3-D tensor.

    Args:
        keypoints: Tensor of masks, expected to be 3-D (stack, height, width).
        rois: Unused by the computation; kept for interface compatibility.
            NOTE(review): no RoI cropping happens here, each full mask is
            resized as-is. Confirm that this is intended.
        heatmap_size: Side length of the square output.

    Returns:
        Float tensor of shape ``(keypoints.shape[0], heatmap_size, heatmap_size)``.
    """
    mask_4d = keypoints.unsqueeze(1).float()
    resized_mask = F.interpolate(
        mask_4d,
        size=(heatmap_size, heatmap_size),
        mode="bilinear",
        align_corners=False,
    ).squeeze(1)

    logger.debug(
        "align_masks: %s -> %s",
        tuple(keypoints.shape),
        tuple(resized_mask.shape),
    )
    return resized_mask


def compute_ins_loss(feature_logits, proposals, gt_, pos_matched_idxs, *, empty_loss=100.0):
    """Compute the combined BCE + Dice loss between logits and resized GT masks.

    For every image, the ground-truth entries selected by the matched indices
    are resized to the spatial size of ``feature_logits`` and concatenated
    along dim 0; the result is compared against ``feature_logits.squeeze(1)``.

    Args:
        feature_logits: 4-D tensor ``(N, K, H, W)`` with ``H == W``.
            NOTE(review): dim 1 is squeezed before the loss, so this
            presumably expects ``K == 1``. Confirm against the caller.
        proposals: Per-image tensors; only their first dimension (number of
            proposals) is inspected here.
        gt_: Per-image ground-truth tensors, indexed by the matching entry of
            ``pos_matched_idxs`` along dim 0.
        pos_matched_idxs: Per-image index tensors into the corresponding
            ``gt_`` entry.
        empty_loss: Value returned when no image contributes a target
            (defaults to the previously hard-coded 100).

    Returns:
        A 0-dim tensor. When no image has both proposals and ground truth the
        value is ``empty_loss``, still attached to ``feature_logits``' autograd
        graph so ``backward()`` works.

    Raises:
        ValueError: If ``feature_logits`` is not 4-D or ``H != W``.
    """
    _num_rois, _num_channels, height, width = feature_logits.shape
    if height != width:
        raise ValueError(
            "feature_logits height and width (last two elements of shape) "
            f"should be equal. Instead got H = {height} and W = {width}"
        )
    discretization_size = height

    targets = []
    for proposals_per_image, gt_in_image, midx in zip(proposals, gt_, pos_matched_idxs):
        # Guard BEFORE indexing: indexing an empty ground-truth tensor with
        # non-empty matches would raise instead of skipping the image.
        if proposals_per_image.shape[0] == 0 or gt_in_image.shape[0] == 0:
            continue
        matched_gt = gt_in_image[midx]
        targets.append(
            align_masks(matched_gt, proposals_per_image, discretization_size)
        )

    logger.debug(
        "compute_ins_loss: logits=%s, images=%d, images_with_targets=%d",
        tuple(feature_logits.shape),
        len(proposals),
        len(targets),
    )

    if not targets:
        # Multiply-by-zero keeps the result on the right device/dtype and in
        # the autograd graph while contributing no gradient.
        return feature_logits.sum() * 0 + empty_loss

    gs_heatmaps = torch.cat(targets, dim=0)
    line_logits = feature_logits.squeeze(1)
    return combined_loss(line_logits, gs_heatmaps)