|
|
@@ -411,8 +411,8 @@ def heatmaps_to_lines(maps, rois):
|
|
|
line_preds = torch.zeros((len(rois), 3, 2), dtype=torch.float32, device=maps.device)
|
|
|
line_end_scores = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
|
|
|
|
|
|
- line_maps=maps[:,1]
|
|
|
-
|
|
|
+ # line_maps=maps[:,1]
|
|
|
+ line_maps = maps.squeeze(1)
|
|
|
|
|
|
for i in range(len(rois)):
|
|
|
line_roi_map = line_maps[i].unsqueeze(0)
|
|
|
@@ -503,7 +503,8 @@ def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
|
|
|
# line_logits = line_logits.view(N * K, H * W)
|
|
|
# print(f'line_logits[valid]:{line_logits[valid].shape}')
|
|
|
print(f'loss1 line_logits:{line_logits.shape}')
|
|
|
- line_logits = line_logits[:,1,:,:]
|
|
|
+ # line_logits = line_logits[:,1,:,:]
|
|
|
+ line_logits = line_logits.squeeze(1)
|
|
|
print(f'loss2 line_logits:{line_logits.shape}')
|
|
|
|
|
|
# line_loss = F.cross_entropy(line_logits[valid], line_targets[valid])
|