|
@@ -518,16 +518,21 @@ def compute_arc_loss(feature_logits, proposals, gt_, pos_matched_idxs):
|
|
|
print(f'proposals_per_image:{proposals_per_image.shape}')
|
|
print(f'proposals_per_image:{proposals_per_image.shape}')
|
|
|
kp = gt_kp_in_image[midx]
|
|
kp = gt_kp_in_image[midx]
|
|
|
# print(f'gt_kp_in_image:{gt_kp_in_image}')
|
|
# print(f'gt_kp_in_image:{gt_kp_in_image}')
|
|
|
- gs_heatmaps_per_img = arc_points_to_heatmap(kp, proposals_per_image, discretization_size)
|
|
|
|
|
- gs_heatmaps.append(gs_heatmaps_per_img)
|
|
|
|
|
|
|
+ if proposals_per_image.shape[0] > 0 and gt_kp_in_image.shape[0] > 0:
|
|
|
|
|
+ gs_heatmaps_per_img = arc_points_to_heatmap(kp, proposals_per_image, discretization_size)
|
|
|
|
|
+ gs_heatmaps.append(gs_heatmaps_per_img)
|
|
|
|
|
|
|
|
- gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
|
|
|
|
|
- print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{feature_logits.squeeze(1).shape}')
|
|
|
|
|
|
|
+ if len(gs_heatmaps)>0:
|
|
|
|
|
+ gs_heatmaps = torch.cat(gs_heatmaps, dim=0)
|
|
|
|
|
+ print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{feature_logits.squeeze(1).shape}')
|
|
|
|
|
|
|
|
- line_logits = feature_logits[:, 0]
|
|
|
|
|
- print(f'single_point_logits:{line_logits.shape}')
|
|
|
|
|
|
|
+ line_logits = feature_logits[:, 0]
|
|
|
|
|
+ print(f'single_point_logits:{line_logits.shape}')
|
|
|
|
|
|
|
|
- line_loss = F.cross_entropy(line_logits, gs_heatmaps)
|
|
|
|
|
|
|
+ line_loss = F.cross_entropy(line_logits, gs_heatmaps)
|
|
|
|
|
+
|
|
|
|
|
+ else:
|
|
|
|
|
+ line_loss=100
|
|
|
|
|
|
|
|
return line_loss
|
|
return line_loss
|
|
|
|
|
|
|
@@ -760,6 +765,7 @@ def arc_inference(x, arc_boxes):
|
|
|
print(f'arc_boxes:{len(arc_boxes)}')
|
|
print(f'arc_boxes:{len(arc_boxes)}')
|
|
|
|
|
|
|
|
boxes_per_image = [box.size(0) for box in arc_boxes]
|
|
boxes_per_image = [box.size(0) for box in arc_boxes]
|
|
|
|
|
+ print(f'arc boxes_per_image:{boxes_per_image}')
|
|
|
x2 = x.split(boxes_per_image, dim=0)
|
|
x2 = x.split(boxes_per_image, dim=0)
|
|
|
|
|
|
|
|
for xx, bb in zip(x2, arc_boxes):
|
|
for xx, bb in zip(x2, arc_boxes):
|