瀏覽代碼

fix arc_loss

zyhhsss 4 月之前
父節點
當前提交
0fd48def10
共有 2 個文件被更改,包括 9 次插入2 次删除
  1. 1 1
      models/line_detect/heads/arc_heads.py
  2. 8 1
      models/line_detect/heads/head_losses.py

+ 1 - 1
models/line_detect/heads/arc_heads.py

@@ -17,7 +17,7 @@ class ArcHeads(nn.Sequential):
 
 
 class ArcPredictor(nn.Module):
-    def __init__(self, in_channels, out_channels=3 ):
+    def __init__(self, in_channels, out_channels=1 ):
         super().__init__()
         input_features = in_channels
         deconv_kernel = 4

+ 8 - 1
models/line_detect/heads/head_losses.py

@@ -515,6 +515,10 @@ def compute_arc_loss(feature_logits, proposals, gt_, pos_matched_idxs):
     gs_heatmaps = []
     # print(f'point_matched_idxs:{point_matched_idxs}')
     for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_, pos_matched_idxs):
+        # [
+        #   (Tensor(38, 4), Tensor(1, 57, 2), Tensor(38, 1)),
+        #   (Tensor(65, 4), Tensor(1, 74, 2), Tensor(65, 1))
+        # ]
         print(f'proposals_per_image:{proposals_per_image.shape}')
         kp = gt_kp_in_image[midx]
         # print(f'gt_kp_in_image:{gt_kp_in_image}')
@@ -529,11 +533,14 @@ def compute_arc_loss(feature_logits, proposals, gt_, pos_matched_idxs):
         line_logits = feature_logits[:, 0]
         print(f'single_point_logits:{line_logits.shape}')
 
-        line_loss = F.cross_entropy(line_logits, gs_heatmaps)
+        line_loss = F.binary_cross_entropy_with_logits(line_logits, gs_heatmaps)
+        # line_loss = F.cross_entropy(line_logits, gs_heatmaps)
 
     else:
         line_loss=100
 
+    print("d")
+
     return line_loss
 
 def arc_points_to_heatmap(keypoints, rois, heatmap_size):