|
|
@@ -2,6 +2,30 @@ import torch
|
|
|
from matplotlib import pyplot as plt
|
|
|
|
|
|
import torch.nn.functional as F
|
|
|
+from torch import nn
|
|
|
+
|
|
|
+
|
|
|
+class DiceLoss(nn.Module):
|
|
|
+ def __init__(self, smooth=1.):
|
|
|
+ super(DiceLoss, self).__init__()
|
|
|
+ self.smooth = smooth
|
|
|
+
|
|
|
+ def forward(self, logits, targets):
|
|
|
+ probs = torch.sigmoid(logits)
|
|
|
+ probs = probs.view(-1)
|
|
|
+ targets = targets.view(-1).float()
|
|
|
+
|
|
|
+ intersection = (probs * targets).sum()
|
|
|
+ dice = (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)
|
|
|
+ return 1. - dice
|
|
|
+bce_loss = nn.BCEWithLogitsLoss()
|
|
|
+dice_loss = DiceLoss()
|
|
|
+
|
|
|
+
|
|
|
+def combined_loss(preds, targets, alpha=0.5):
|
|
|
+ bce = bce_loss(preds, targets)
|
|
|
+ d = dice_loss(preds, targets)
|
|
|
+ return alpha * bce + (1 - alpha) * d
|
|
|
|
|
|
def features_align(features, proposals, img_size):
|
|
|
print(f'lines_features_align features:{features.shape},proposals:{len(proposals)}')
|
|
|
@@ -528,13 +552,15 @@ def compute_arc_loss(feature_logits, proposals, gt_, pos_matched_idxs):
|
|
|
|
|
|
if len(gs_heatmaps)>0:
|
|
|
gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
|
|
|
- print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{feature_logits.squeeze(1).shape}')
|
|
|
+ print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{feature_logits.shape}')
|
|
|
|
|
|
- line_logits = feature_logits[:, 0]
|
|
|
+ line_logits = feature_logits.squeeze(1)
|
|
|
print(f'single_point_logits:{line_logits.shape}')
|
|
|
|
|
|
- line_loss = F.binary_cross_entropy_with_logits(line_logits, gs_heatmaps)
|
|
|
+ # line_loss = F.binary_cross_entropy_with_logits(line_logits, gs_heatmaps)
|
|
|
+
|
|
|
# line_loss = F.cross_entropy(line_logits, gs_heatmaps)
|
|
|
+ line_loss=combined_loss(line_logits, gs_heatmaps)
|
|
|
|
|
|
else:
|
|
|
line_loss=100
|