|
|
@@ -13,7 +13,7 @@ import libs.vision_libs.models.detection._utils as det_utils
|
|
|
from collections import OrderedDict
|
|
|
|
|
|
from models.line_detect.heads.head_losses import point_inference, compute_point_loss, line_iou_loss, \
|
|
|
- lines_point_pair_loss, features_align, line_inference, compute_arc_loss, arc_inference, compute_circle_loss, \
|
|
|
+ lines_point_pair_loss, features_align, line_inference, compute_mask_loss, arc_inference, compute_circle_loss, \
|
|
|
circle_inference
|
|
|
|
|
|
|
|
|
@@ -135,6 +135,7 @@ def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs
|
|
|
|
|
|
|
|
|
|
|
|
+
|
|
|
def keypoints_to_heatmap(keypoints, rois, heatmap_size):
|
|
|
# type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
|
|
|
offset_x = rois[:, 0]
|
|
|
@@ -1253,7 +1254,7 @@ class RoIHeads(nn.Module):
|
|
|
# if gt_arcs_tensor.shape[0] > 0:
|
|
|
# print(f'start to compute point_loss')
|
|
|
if len(gt_arcs) > 0 and feature_logits is not None:
|
|
|
- loss_arc = compute_arc_loss(feature_logits, arc_proposals, gt_arcs, arc_pos_matched_idxs)
|
|
|
+ loss_arc = compute_mask_loss(feature_logits, arc_proposals, gt_arcs, arc_pos_matched_idxs)
|
|
|
|
|
|
if loss_arc is None:
|
|
|
print(f'loss_arc is None111')
|
|
|
@@ -1281,7 +1282,7 @@ class RoIHeads(nn.Module):
|
|
|
|
|
|
if len(gt_arcs) > 0 and feature_logits is not None:
|
|
|
print(f'start to compute arc_loss')
|
|
|
- loss_arc = compute_arc_loss(feature_logits, arc_proposals, gt_arcs, arc_pos_matched_idxs)
|
|
|
+ loss_arc = compute_mask_loss(feature_logits, arc_proposals, gt_arcs, arc_pos_matched_idxs)
|
|
|
|
|
|
|
|
|
if loss_arc is None:
|
|
|
@@ -1377,6 +1378,7 @@ class RoIHeads(nn.Module):
|
|
|
raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
|
|
|
|
|
|
gt_circles = [t["circle_masks"] for t in targets if "circle_masks" in t]
|
|
|
+ gt_labels = [t["labels"] for t in targets]
|
|
|
|
|
|
print(f'gt_circle:{gt_circles[0].shape}')
|
|
|
h, w = targets[0]["img_size"]
|
|
|
@@ -1390,7 +1392,7 @@ class RoIHeads(nn.Module):
|
|
|
if gt_circles_tensor.shape[0] > 0:
|
|
|
print(f'start to compute circle_loss')
|
|
|
|
|
|
- loss_circle = compute_arc_loss(feature_logits, circle_proposals, gt_circles, circle_pos_matched_idxs)
|
|
|
+ loss_circle = compute_mask_loss(feature_logits, circle_proposals, gt_circles, circle_pos_matched_idxs)
|
|
|
|
|
|
# loss_circle_extra=compute_circle_extra_losses(feature_logits, circle_proposals, gt_circles, circle_pos_matched_idxs)
|
|
|
|
|
|
@@ -1410,7 +1412,7 @@ class RoIHeads(nn.Module):
|
|
|
h, w = targets[0]["img_size"]
|
|
|
img_size = h
|
|
|
gt_circles = [t["circle_masks"] for t in targets if "circle_masks" in t]
|
|
|
-
|
|
|
+ gt_labels = [t["labels"] for t in targets]
|
|
|
gt_circles_tensor = torch.zeros(0, 0)
|
|
|
if len(gt_circles) > 0:
|
|
|
gt_circles_tensor = torch.cat(gt_circles)
|
|
|
@@ -1419,7 +1421,7 @@ class RoIHeads(nn.Module):
|
|
|
if gt_circles_tensor.shape[0] > 0:
|
|
|
print(f'start to compute circle_loss')
|
|
|
|
|
|
- loss_circle = maskrcnn_loss(feature_logits, circle_proposals, gt_circles,
|
|
|
+ loss_circle = compute_mask_loss(feature_logits, circle_proposals, gt_circles,
|
|
|
circle_pos_matched_idxs)
|
|
|
|
|
|
# loss_circle_extra = compute_circle_extra_losses(feature_logits, circle_proposals, gt_circles,circle_pos_matched_idxs)
|