import torch
import torch.nn.functional as F
from torch import nn


class DiceLoss(nn.Module):
    """Soft Dice loss on flattened sigmoid probabilities; `smooth` avoids division by zero."""

    def __init__(self, smooth=1.):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)
        probs = probs.view(-1)
        targets = targets.view(-1).float()
        intersection = (probs * targets).sum()
        dice = (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)
        return 1. - dice


bce_loss = nn.BCEWithLogitsLoss()
dice_loss = DiceLoss()


def combined_loss(preds, targets, alpha=0.5):
    """Weighted sum of BCE-with-logits and Dice loss; alpha balances the two terms."""
    bce = bce_loss(preds, targets)
    d = dice_loss(preds, targets)
    return alpha * bce + (1 - alpha) * d
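

# Hedged usage sketch (not in the original code): a minimal sanity check of
# combined_loss on dummy data. The shapes are assumptions for illustration only.
def _demo_combined_loss():
    logits = torch.randn(4, 28, 28)                  # predicted mask logits
    targets = (torch.rand(4, 28, 28) > 0.5).float()  # binary ground-truth masks
    return combined_loss(logits, targets)            # scalar loss tensor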


def align_masks(keypoints, rois, heatmap_size):
    """Resize the matched GT masks to (heatmap_size, heatmap_size) so they can be
    compared directly with the predicted logits."""
    print(f'rois:{rois.shape}')
    print(f'heatmap_size:{heatmap_size}')
    print(f'keypoints.shape:{keypoints.shape}')
    # batch_size, num_keypoints, _ = keypoints.shape
    t_h, t_w = keypoints.shape[-2:]
    scale = heatmap_size / t_w
    print(f'scale:{scale}')
    # The scale/x/y/num_points values below are computed only for debugging;
    # they do not affect the returned mask.
    x = keypoints[..., 0] * scale
    y = keypoints[..., 1] * scale
    x = x.unsqueeze(1)
    y = y.unsqueeze(1)
    num_points = x.shape[2]
    print(f'num_points:{num_points}')
    mask_4d = keypoints.unsqueeze(1).float()
    resized_mask = F.interpolate(
        mask_4d,
        size=(heatmap_size, heatmap_size),
        mode='bilinear',
        align_corners=False
    ).squeeze(1)  # [B, heatmap_size, heatmap_size]
    # plt.imshow(resized_mask[0].cpu())
    # plt.show()
    print(f'resized_mask:{resized_mask.shape}')
    return resized_mask
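

# Hedged usage sketch (not in the original code): align_masks standalone on dummy
# inputs. The mask resolution and box values are assumptions for illustration.
def _demo_align_masks():
    gt_masks = (torch.rand(3, 56, 56) > 0.5).float()  # 3 matched GT masks
    rois = torch.rand(3, 4) * 200.                    # 3 proposal boxes (only logged)
    return align_masks(gt_masks, rois, 28)            # -> [3, 28, 28]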


def compute_ins_loss(feature_logits, proposals, gt_, pos_matched_idxs):
    """For each image, gather the GT entries matched to its positive proposals,
    resize them to the logit resolution with align_masks, and score the
    predicted logits against them with the combined BCE/Dice loss."""
    print(f'compute_ins_loss:{feature_logits.shape}')
    N, K, H, W = feature_logits.shape
    len_proposals = len(proposals)
    empty_count = 0
    non_empty_count = 0
    for prop in proposals:
        if prop.shape[0] == 0:
            empty_count += 1
        else:
            non_empty_count += 1
    print(f"Empty proposals count: {empty_count}")
    print(f"Non-empty proposals count: {non_empty_count}")
    print('start computing point loss')
    print(f'compute_point_loss line_logits.shape:{feature_logits.shape}, len_proposals:{len_proposals}')
    if H != W:
        raise ValueError(
            f"line_logits height and width (last two elements of shape) should be equal. "
            f"Instead got H = {H} and W = {W}"
        )
    discretization_size = H
    gs_heatmaps = []
    # print(f'point_matched_idxs:{point_matched_idxs}')
    print(f'gt_masks:{gt_[0].shape}')
    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_, pos_matched_idxs):
        # Example per-image shapes (proposals, gt, matched idxs), e.g.:
        #   (Tensor(38, 4), Tensor(1, 57, 2), Tensor(38, 1)),
        #   (Tensor(65, 4), Tensor(1, 74, 2), Tensor(65, 1))
        print(f'proposals_per_image:{proposals_per_image.shape}')
        kp = gt_kp_in_image[midx]
        t_h, t_w = kp.shape[-2:]
        print(f't_h:{t_h}, t_w:{t_w}')
        print(f'gt_kp_in_image:{gt_kp_in_image.shape}')
        if proposals_per_image.shape[0] > 0 and gt_kp_in_image.shape[0] > 0:
            gs_heatmaps_per_img = align_masks(kp, proposals_per_image, discretization_size)
            gs_heatmaps.append(gs_heatmaps_per_img)
    if len(gs_heatmaps) > 0:
        gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
        print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{feature_logits.shape}')
        line_logits = feature_logits.squeeze(1)  # K is expected to be 1
        print(f'mask shape:{line_logits.shape}')
        # line_loss = F.binary_cross_entropy_with_logits(line_logits, gs_heatmaps)
        # line_loss = F.cross_entropy(line_logits, gs_heatmaps)
        line_loss = combined_loss(line_logits, gs_heatmaps)
    else:
        # No image produced target heatmaps; return a constant fallback loss.
        line_loss = torch.tensor(100., device=feature_logits.device)
        print('no target heatmaps produced; returning fallback loss')
    return line_loss
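

# Hedged usage sketch (not in the original code): wiring the pieces together on
# dummy inputs. All shapes, box values, and matched indices are assumptions,
# chosen so the 5 logit maps line up with 3 + 2 positive proposals over 2 images.
def _demo_compute_ins_loss():
    feature_logits = torch.randn(5, 1, 28, 28)        # [N, K=1, H, W] predicted logits
    proposals = [torch.rand(3, 4) * 200.,             # 3 positive proposals, image 0
                 torch.rand(2, 4) * 200.]             # 2 positive proposals, image 1
    gt_ = [(torch.rand(4, 56, 56) > 0.5).float(),     # 4 GT masks, image 0
           (torch.rand(6, 56, 56) > 0.5).float()]     # 6 GT masks, image 1
    pos_matched_idxs = [torch.tensor([0, 2, 1]),      # GT index matched to each proposal
                        torch.tensor([3, 5])]
    return compute_ins_loss(feature_logits, proposals, gt_, pos_matched_idxs)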