|
@@ -14,7 +14,7 @@ from collections import OrderedDict
|
|
|
|
|
|
|
|
from models.line_detect.heads.head_losses import point_inference, compute_point_loss, line_iou_loss, \
|
|
from models.line_detect.heads.head_losses import point_inference, compute_point_loss, line_iou_loss, \
|
|
|
lines_point_pair_loss, features_align, line_inference, compute_arc_loss, arc_inference, compute_circle_loss, \
|
|
lines_point_pair_loss, features_align, line_inference, compute_arc_loss, arc_inference, compute_circle_loss, \
|
|
|
- circle_inference
|
|
|
|
|
|
|
+ circle_inference, compute_circle_extra_losses
|
|
|
|
|
|
|
|
|
|
|
|
|
def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
|
|
def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
|
|
@@ -1348,6 +1348,7 @@ class RoIHeads(nn.Module):
|
|
|
feature_logits = self.circle_forward1(features, image_shapes, circle_proposals)
|
|
feature_logits = self.circle_forward1(features, image_shapes, circle_proposals)
|
|
|
|
|
|
|
|
loss_circle = None
|
|
loss_circle = None
|
|
|
|
|
+ loss_circle_extra=None
|
|
|
|
|
|
|
|
if self.training:
|
|
if self.training:
|
|
|
|
|
|
|
@@ -1369,12 +1370,18 @@ class RoIHeads(nn.Module):
|
|
|
print(f'start to compute circle_loss')
|
|
print(f'start to compute circle_loss')
|
|
|
|
|
|
|
|
loss_circle = compute_circle_loss(feature_logits, circle_proposals, gt_circles, circle_pos_matched_idxs)
|
|
loss_circle = compute_circle_loss(feature_logits, circle_proposals, gt_circles, circle_pos_matched_idxs)
|
|
|
|
|
+ loss_circle_extra=compute_circle_extra_losses(feature_logits, circle_proposals, gt_circles, circle_pos_matched_idxs)
|
|
|
|
|
|
|
|
if loss_circle is None:
|
|
if loss_circle is None:
|
|
|
print(f'loss_circle is None111')
|
|
print(f'loss_circle is None111')
|
|
|
loss_circle = torch.tensor(0.0, device=device)
|
|
loss_circle = torch.tensor(0.0, device=device)
|
|
|
|
|
|
|
|
|
|
+ if loss_circle_extra is None:
|
|
|
|
|
+ print(f'loss_circle_extra is None111')
|
|
|
|
|
+ loss_circle_extra = torch.tensor(0.0, device=device)
|
|
|
|
|
+
|
|
|
loss_circle = {"loss_circle": loss_circle}
|
|
loss_circle = {"loss_circle": loss_circle}
|
|
|
|
|
+ loss_circle_extra = {"loss_circle_extra": loss_circle_extra}
|
|
|
|
|
|
|
|
else:
|
|
else:
|
|
|
if targets is not None:
|
|
if targets is not None:
|
|
@@ -1393,16 +1400,25 @@ class RoIHeads(nn.Module):
|
|
|
loss_circle = compute_circle_loss(feature_logits, circle_proposals, gt_circles,
|
|
loss_circle = compute_circle_loss(feature_logits, circle_proposals, gt_circles,
|
|
|
circle_pos_matched_idxs)
|
|
circle_pos_matched_idxs)
|
|
|
|
|
|
|
|
|
|
+ loss_circle_extra = compute_circle_extra_losses(feature_logits, circle_proposals, gt_circles,
|
|
|
|
|
+ circle_pos_matched_idxs)
|
|
|
|
|
+
|
|
|
if loss_circle is None:
|
|
if loss_circle is None:
|
|
|
print(f'loss_circle is None111')
|
|
print(f'loss_circle is None111')
|
|
|
loss_circle = torch.tensor(0.0, device=device)
|
|
loss_circle = torch.tensor(0.0, device=device)
|
|
|
|
|
|
|
|
|
|
+ if loss_circle_extra is None:
|
|
|
|
|
+ print(f'loss_circle_extra is None111')
|
|
|
|
|
+ loss_circle_extra = torch.tensor(0.0, device=device)
|
|
|
|
|
+
|
|
|
loss_circle = {"loss_circle": loss_circle}
|
|
loss_circle = {"loss_circle": loss_circle}
|
|
|
|
|
+ loss_circle_extra = {"loss_circle_extra": loss_circle_extra}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
else:
|
|
else:
|
|
|
loss_circle = {}
|
|
loss_circle = {}
|
|
|
|
|
+ loss_circle_extra = {}
|
|
|
if feature_logits is None or circle_proposals is None:
|
|
if feature_logits is None or circle_proposals is None:
|
|
|
raise ValueError(
|
|
raise ValueError(
|
|
|
"both keypoint_logits and keypoint_proposals should not be None when not in training mode"
|
|
"both keypoint_logits and keypoint_proposals should not be None when not in training mode"
|
|
@@ -1416,7 +1432,9 @@ class RoIHeads(nn.Module):
|
|
|
r["circles_scores"] = kps
|
|
r["circles_scores"] = kps
|
|
|
|
|
|
|
|
print(f'loss_circle:{loss_circle}')
|
|
print(f'loss_circle:{loss_circle}')
|
|
|
|
|
+ print(f'loss_circle_extra:{loss_circle_extra}')
|
|
|
losses.update(loss_circle)
|
|
losses.update(loss_circle)
|
|
|
|
|
+ losses.update(loss_circle_extra)
|
|
|
print(f'losses:{losses}')
|
|
print(f'losses:{losses}')
|
|
|
|
|
|
|
|
|
|
|