Prechádzať zdrojové kódy

修复line val阶段报错bug

RenLiqiang 5 mesiacov pred
rodič
commit
08f665db7f
1 zmenil súbory, kde vykonal 6 pridanie a 2 odobranie
  1. 6 2
      models/line_detect/roi_heads.py

+ 6 - 2
models/line_detect/roi_heads.py

@@ -432,8 +432,8 @@ def heatmaps_to_keypoints(maps, rois):
 def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
     # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
     N, K, H, W = line_logits.shape
-    batch_size=len(proposals)
-    print(f'lines_point_pair_loss line_logits.shape:{line_logits.shape}')
+    len_proposals=len(proposals)
+    print(f'lines_point_pair_loss line_logits.shape:{line_logits.shape},len_proposals:{len_proposals}')
     if H != W:
         raise ValueError(
             f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
@@ -1023,6 +1023,7 @@ class RoIHeads(nn.Module):
             if self.training:
                 # during training, only focus on positive boxes
                 num_images = len(proposals)
+                print(f'num_images:{num_images}')
                 line_proposals = []
                 pos_matched_idxs = []
                 if matched_idxs is None:
@@ -1034,8 +1035,11 @@ class RoIHeads(nn.Module):
                     pos_matched_idxs.append(matched_idxs[img_id][pos])
             else:
                 if targets is not None:
+
                     pos_matched_idxs = []
                     num_images = len(proposals)
+                    line_proposals = []
+                    print(f'val num_images:{num_images}')
                     if matched_idxs is None:
                         raise ValueError("if in trainning, matched_idxs should not be None")