|
|
@@ -136,11 +136,13 @@ def line_points_to_heatmap(keypoints, rois, heatmap_size):
|
|
|
all_roi_heatmap = []
|
|
|
for roi, heatmap in zip(rois, gs):
|
|
|
# print(f'heatmap:{heatmap.shape}')
|
|
|
+ # show_heatmap(heatmap,'target')
|
|
|
heatmap = heatmap.unsqueeze(0)
|
|
|
+
|
|
|
x1, y1, x2, y2 = map(int, roi)
|
|
|
roi_heatmap = torch.zeros_like(heatmap)
|
|
|
roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
|
|
|
- # show_heatmap(roi_heatmap,'roi_heatmap')
|
|
|
+ # show_heatmap(roi_heatmap[0],'roi_heatmap')
|
|
|
all_roi_heatmap.append(roi_heatmap)
|
|
|
|
|
|
all_roi_heatmap = torch.cat(all_roi_heatmap)
|
|
|
@@ -403,7 +405,7 @@ def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
|
|
|
# line_logits = line_logits.view(N * K, H * W)
|
|
|
# print(f'line_logits[valid]:{line_logits[valid].shape}')
|
|
|
print(f'loss1 line_logits:{line_logits.shape}')
|
|
|
- line_logits = line_logits[:,2,:,:]
|
|
|
+ line_logits = line_logits[:,1,:,:]
|
|
|
print(f'loss2 line_logits:{line_logits.shape}')
|
|
|
|
|
|
# line_loss = F.cross_entropy(line_logits[valid], line_targets[valid])
|