Bladeren bron

add dice loss

lstrlq 4 maanden geleden
bovenliggende
commit
3fd7a2412a
3 gewijzigde bestanden met toevoegingen van 38 en 9 verwijderingen
  1. 29 3
      models/line_detect/heads/head_losses.py
  2. 1 1
      models/line_detect/line_detect.py
  3. 8 5
      models/line_detect/loi_heads.py

+ 29 - 3
models/line_detect/heads/head_losses.py

@@ -2,6 +2,30 @@ import torch
 from matplotlib import pyplot as plt
 
 import torch.nn.functional as F
+from torch import nn
+
+
+class DiceLoss(nn.Module):
+    def __init__(self, smooth=1.):
+        super(DiceLoss, self).__init__()
+        self.smooth = smooth
+
+    def forward(self, logits, targets):
+        probs = torch.sigmoid(logits)
+        probs = probs.view(-1)
+        targets = targets.view(-1).float()
+
+        intersection = (probs * targets).sum()
+        dice = (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)
+        return 1. - dice
+bce_loss = nn.BCEWithLogitsLoss()
+dice_loss = DiceLoss()
+
+
+def combined_loss(preds, targets, alpha=0.5):
+    bce = bce_loss(preds, targets)
+    d = dice_loss(preds, targets)
+    return alpha * bce + (1 - alpha) * d
 
 def features_align(features, proposals, img_size):
     print(f'lines_features_align features:{features.shape},proposals:{len(proposals)}')
@@ -528,13 +552,15 @@ def compute_arc_loss(feature_logits, proposals, gt_, pos_matched_idxs):
 
     if len(gs_heatmaps)>0:
         gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
-        print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{feature_logits.squeeze(1).shape}')
+        print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{feature_logits.shape}')
 
-        line_logits = feature_logits[:, 0]
+        line_logits = feature_logits.squeeze(1)
         print(f'single_point_logits:{line_logits.shape}')
 
-        line_loss = F.binary_cross_entropy_with_logits(line_logits, gs_heatmaps)
+        # line_loss = F.binary_cross_entropy_with_logits(line_logits, gs_heatmaps)
+
         # line_loss = F.cross_entropy(line_logits, gs_heatmaps)
+        line_loss=combined_loss(line_logits, gs_heatmaps)
 
     else:
         line_loss=100

+ 1 - 1
models/line_detect/line_detect.py

@@ -194,7 +194,7 @@ class LineDetect(BaseDetectionNet):
             arc_head=ArcHeads(8,layers)
         if detect_arc and arc_predictor is None:
             layers = tuple(num_points for _ in range(8))
-            arc_predictor=ArcPredictor(in_channels=256)
+            arc_predictor=ArcPredictor(in_channels=256,out_channels=1)
 
 
 

+ 8 - 5
models/line_detect/loi_heads.py

@@ -1263,9 +1263,12 @@ class RoIHeads(nn.Module):
                 else:
                     loss_arc = {}
                     if feature_logits is None or arc_proposals is None:
-                        raise ValueError(
-                            "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
-                        )
+                        # raise ValueError(
+                        #     "both arc_feature_logits and arc_proposals should not be None when not in training mode"
+                        # )
+
+                        print(f'error :both arc_feature_logits and arc_proposals should not be None when not in training mode"')
+                        return None
 
                     if feature_logits is not None:
 
@@ -1460,7 +1463,7 @@ class RoIHeads(nn.Module):
 
 
     def arc_forward1(self, features, image_shapes, proposals):
-        print(f'point_proposals:{len(proposals)}')
+        print(f'arc_proposals:{len(proposals)}')
         # cs_features= features['0']
         # print(f'features-0:{features['0'].shape}')
         # cs_features = self.channel_compress(features['0'])
@@ -1475,7 +1478,7 @@ class RoIHeads(nn.Module):
         # print(f'point_proposals_tensor:{point_proposals_tensor.shape}')
 
         feature_logits = self.arc_predictor(cs_features)
-        print(f'feature_logits from line_head:{feature_logits.shape}')
+        print(f'feature_logits from arc_head:{feature_logits.shape}')
 
         roi_features = features_align(feature_logits, proposals, image_shapes)
         if roi_features is not None: