Просмотр исходного кода

add circle direct and length losses

admin 4 месяцев назад
Родитель
Сommit
721f34601a
1 измененных файлов с 19 добавлено и 1 удалено
  1. 19 1
      models/line_detect/loi_heads.py

+ 19 - 1
models/line_detect/loi_heads.py

@@ -14,7 +14,7 @@ from collections import OrderedDict
 
 from models.line_detect.heads.head_losses import point_inference, compute_point_loss, line_iou_loss, \
     lines_point_pair_loss, features_align, line_inference, compute_arc_loss, arc_inference, compute_circle_loss, \
-    circle_inference
+    circle_inference, compute_circle_extra_losses
 
 
 def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
@@ -1348,6 +1348,7 @@ class RoIHeads(nn.Module):
             feature_logits = self.circle_forward1(features, image_shapes, circle_proposals)
 
             loss_circle = None
+            loss_circle_extra=None
 
             if self.training:
 
@@ -1369,12 +1370,18 @@ class RoIHeads(nn.Module):
                     print(f'start to compute circle_loss')
 
                     loss_circle = compute_circle_loss(feature_logits, circle_proposals, gt_circles, circle_pos_matched_idxs)
+                    loss_circle_extra=compute_circle_extra_losses(feature_logits, circle_proposals, gt_circles, circle_pos_matched_idxs)
 
                 if loss_circle is None:
                     print(f'loss_circle is None111')
                     loss_circle = torch.tensor(0.0, device=device)
 
+                if loss_circle_extra is None:
+                    print(f'loss_circle_extra is None111')
+                    loss_circle_extra = torch.tensor(0.0, device=device)
+
                 loss_circle = {"loss_circle": loss_circle}
+                loss_circle_extra = {"loss_circle_extra": loss_circle_extra}
 
             else:
                 if targets is not None:
@@ -1393,16 +1400,25 @@ class RoIHeads(nn.Module):
                         loss_circle = compute_circle_loss(feature_logits, circle_proposals, gt_circles,
                                                         circle_pos_matched_idxs)
 
+                        loss_circle_extra = compute_circle_extra_losses(feature_logits, circle_proposals, gt_circles,
+                                                                         circle_pos_matched_idxs)
+
                     if loss_circle is None:
                         print(f'loss_circle is None111')
                         loss_circle = torch.tensor(0.0, device=device)
 
+                    if loss_circle_extra is None:
+                        print(f'loss_circle_extra is None111')
+                        loss_circle_extra = torch.tensor(0.0, device=device)
+
                     loss_circle = {"loss_circle": loss_circle}
+                    loss_circle_extra = {"loss_circle_extra": loss_circle_extra}
 
 
 
                 else:
                     loss_circle = {}
+                    loss_circle_extra = {}
                     if feature_logits is None or circle_proposals is None:
                         raise ValueError(
                             "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
@@ -1416,7 +1432,9 @@ class RoIHeads(nn.Module):
                             r["circles_scores"] = kps
 
             print(f'loss_circle:{loss_circle}')
+            print(f'loss_circle_extra:{loss_circle_extra}')
             losses.update(loss_circle)
+            losses.update(loss_circle_extra)
             print(f'losses:{losses}')