|
|
@@ -295,7 +295,50 @@ def heatmaps_to_keypoints(maps, rois):
|
|
|
end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
|
|
|
|
|
|
return xy_preds.permute(0, 2, 1), end_scores
|
|
|
+def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
|
|
|
+ # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
|
|
|
+ N, K, H, W = line_logits.shape
|
|
|
+ if H != W:
|
|
|
+ raise ValueError(
|
|
|
+ f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
|
|
|
+ )
|
|
|
+ discretization_size = H
|
|
|
+ heatmaps = []
|
|
|
+ valid = []
|
|
|
+ for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_lines, line_matched_idxs):
|
|
|
+ kp = gt_kp_in_image[midx]
|
|
|
+ heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
|
|
|
+ heatmaps.append(heatmaps_per_image.view(-1))
|
|
|
+ valid.append(valid_per_image.view(-1))
|
|
|
+
|
|
|
+ line_targets = torch.cat(heatmaps, dim=0)
|
|
|
+ valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
|
|
|
+ valid = torch.where(valid)[0]
|
|
|
+
|
|
|
+ # torch.mean (in binary_cross_entropy_with_logits) doesn't
|
|
|
+ # accept empty tensors, so handle it sepaartely
|
|
|
+ if line_targets.numel() == 0 or len(valid) == 0:
|
|
|
+ return line_logits.sum() * 0
|
|
|
+
|
|
|
+ line_logits = line_logits.view(N * K, H * W)
|
|
|
+
|
|
|
+ line_loss = F.cross_entropy(line_logits[valid], line_targets[valid])
|
|
|
+ return line_loss
|
|
|
+
|
|
|
+def line_inference(x, boxes):
|
|
|
+ # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
|
|
|
+ kp_probs = []
|
|
|
+ kp_scores = []
|
|
|
+
|
|
|
+ boxes_per_image = [box.size(0) for box in boxes]
|
|
|
+ x2 = x.split(boxes_per_image, dim=0)
|
|
|
|
|
|
+ for xx, bb in zip(x2, boxes):
|
|
|
+ kp_prob, scores = heatmaps_to_keypoints(xx, bb)
|
|
|
+ kp_probs.append(kp_prob)
|
|
|
+ kp_scores.append(scores)
|
|
|
+
|
|
|
+ return kp_probs, kp_scores
|
|
|
|
|
|
def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
|
|
|
# type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
|
|
|
@@ -774,9 +817,12 @@ class RoIHeads(nn.Module):
|
|
|
if self.training:
|
|
|
proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
|
|
|
else:
|
|
|
- labels = None
|
|
|
- regression_targets = None
|
|
|
- matched_idxs = None
|
|
|
+ if targets is not None:
|
|
|
+ proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
|
|
|
+ else:
|
|
|
+ labels = None
|
|
|
+ regression_targets = None
|
|
|
+ matched_idxs = None
|
|
|
|
|
|
box_features = self.box_roi_pool(features, proposals, image_shapes)
|
|
|
box_features = self.box_head(box_features)
|
|
|
@@ -784,6 +830,7 @@ class RoIHeads(nn.Module):
|
|
|
|
|
|
result: List[Dict[str, torch.Tensor]] = []
|
|
|
losses = {}
|
|
|
+
|
|
|
if self.training:
|
|
|
if labels is None:
|
|
|
raise ValueError("labels cannot be None")
|
|
|
@@ -793,8 +840,12 @@ class RoIHeads(nn.Module):
|
|
|
loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
|
|
|
losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
|
|
|
else:
|
|
|
- print(f'boxes postprocess')
|
|
|
- boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
|
|
|
+ if targets is not None:
|
|
|
+ loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
|
|
|
+ losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
|
|
|
+
|
|
|
+ boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals,
|
|
|
+ image_shapes)
|
|
|
num_images = len(boxes)
|
|
|
for i in range(num_images):
|
|
|
result.append(
|
|
|
@@ -809,6 +860,8 @@ class RoIHeads(nn.Module):
|
|
|
if self.has_line():
|
|
|
print(f'roi_heads forward has_line()!!!!')
|
|
|
line_proposals = [p["boxes"] for p in result]
|
|
|
+ print(f'line_proposals:{len(line_proposals)}')
|
|
|
+
|
|
|
if self.training:
|
|
|
# during training, only focus on positive boxes
|
|
|
num_images = len(proposals)
|
|
|
@@ -822,33 +875,54 @@ class RoIHeads(nn.Module):
|
|
|
line_proposals.append(proposals[img_id][pos])
|
|
|
pos_matched_idxs.append(matched_idxs[img_id][pos])
|
|
|
else:
|
|
|
- pos_matched_idxs = None
|
|
|
-
|
|
|
- line_features = self.keypoint_roi_pool(features, line_proposals, image_shapes)
|
|
|
- line_features = self.keypoint_head(line_features)
|
|
|
- line_logits = self.keypoint_predictor(line_features)
|
|
|
-
|
|
|
- loss_keypoint = {}
|
|
|
+ if targets is not None:
|
|
|
+ pos_matched_idxs = []
|
|
|
+ num_images = len(proposals)
|
|
|
+ if matched_idxs is None:
|
|
|
+ raise ValueError("if in trainning, matched_idxs should not be None")
|
|
|
+
|
|
|
+ for img_id in range(num_images):
|
|
|
+ pos = torch.where(labels[img_id] > 0)[0]
|
|
|
+ line_proposals.append(proposals[img_id][pos])
|
|
|
+ pos_matched_idxs.append(matched_idxs[img_id][pos])
|
|
|
+ else:
|
|
|
+ pos_matched_idxs = None
|
|
|
+
|
|
|
+ line_features = self.line_roi_pool(features, line_proposals, image_shapes)
|
|
|
+ line_features = self.line_head(line_features)
|
|
|
+ line_logits = self.line_predictor(line_features)
|
|
|
+
|
|
|
+ loss_line = {}
|
|
|
if self.training:
|
|
|
if targets is None or pos_matched_idxs is None:
|
|
|
raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
|
|
|
|
|
|
- gt_keypoints = [t["keypoints"] for t in targets]
|
|
|
- rcnn_loss_keypoint = keypointrcnn_loss(
|
|
|
- line_logits, line_proposals, gt_keypoints, pos_matched_idxs
|
|
|
+ gt_lines = [t["lines"] for t in targets]
|
|
|
+ rcnn_loss_line = lines_point_pair_loss(
|
|
|
+ line_logits, line_proposals, gt_lines, pos_matched_idxs
|
|
|
)
|
|
|
- loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
|
|
|
+ loss_line = {"loss_line": rcnn_loss_line}
|
|
|
else:
|
|
|
- if line_logits is None or line_proposals is None:
|
|
|
- raise ValueError(
|
|
|
- "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
|
|
|
+ if targets is not None:
|
|
|
+ gt_lines = [t["lines"] for t in targets]
|
|
|
+ rcnn_loss_lines = lines_point_pair_loss(
|
|
|
+ line_logits, line_proposals, gt_lines, pos_matched_idxs
|
|
|
)
|
|
|
+ loss_line = {"loss_line": rcnn_loss_lines}
|
|
|
+ else:
|
|
|
+ if line_logits is None or line_proposals is None:
|
|
|
+ raise ValueError(
|
|
|
+ "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
|
|
|
+ )
|
|
|
+
|
|
|
+ lines_probs, kp_scores = line_inference(line_logits, line_proposals)
|
|
|
+ for keypoint_prob, kps, r in zip(lines_probs, kp_scores, result):
|
|
|
+ r["lines"] = keypoint_prob
|
|
|
+ r["liness_scores"] = kps
|
|
|
+ losses.update(loss_line)
|
|
|
+
|
|
|
+
|
|
|
|
|
|
- keypoints_probs, kp_scores = keypointrcnn_inference(line_logits, line_proposals)
|
|
|
- for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
|
|
|
- r["keypoints"] = keypoint_prob
|
|
|
- r["keypoints_scores"] = kps
|
|
|
- losses.update(loss_keypoint)
|
|
|
if self.has_mask():
|
|
|
mask_proposals = [p["boxes"] for p in result]
|
|
|
if self.training:
|
|
|
@@ -909,9 +983,9 @@ class RoIHeads(nn.Module):
|
|
|
else:
|
|
|
pos_matched_idxs = None
|
|
|
|
|
|
- keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
|
|
|
- keypoint_features = self.keypoint_head(keypoint_features)
|
|
|
- keypoint_logits = self.keypoint_predictor(keypoint_features)
|
|
|
+ keypoint_features = self.line_roi_pool(features, keypoint_proposals, image_shapes)
|
|
|
+ keypoint_features = self.line_head(keypoint_features)
|
|
|
+ keypoint_logits = self.line_predictor(keypoint_features)
|
|
|
|
|
|
loss_keypoint = {}
|
|
|
if self.training:
|