|
|
@@ -983,7 +983,6 @@ class RoIHeads(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
-
|
|
|
if gt_lines_tensor.shape[0]>0 :
|
|
|
print(f'start to lines_point_pair_loss')
|
|
|
loss_line = lines_point_pair_loss(
|
|
|
@@ -997,10 +996,12 @@ class RoIHeads(nn.Module):
|
|
|
point_features, point_proposals, gt_points, point_pos_matched_idxs
|
|
|
)
|
|
|
|
|
|
- if not loss_line:
|
|
|
+ if loss_line is None:
|
|
|
+ print(f'loss_line is None111')
|
|
|
loss_line = torch.tensor(0.0, device=cs_features.device)
|
|
|
|
|
|
- if not loss_line_iou:
|
|
|
+ if loss_line_iou is None:
|
|
|
+ print(f'loss_line_iou is None111')
|
|
|
loss_line_iou = torch.tensor(0.0, device=cs_features.device)
|
|
|
|
|
|
loss_line = {"loss_line": loss_line}
|
|
|
@@ -1015,20 +1016,25 @@ class RoIHeads(nn.Module):
|
|
|
if targets is not None:
|
|
|
h, w = targets[0]["img_size"]
|
|
|
img_size = h
|
|
|
- gt_lines = [t["lines"] for t in targets]
|
|
|
- gt_points = [t["points"] for t in targets]
|
|
|
+ gt_lines = [t["lines"] for t in targets if "lines" in t]
|
|
|
+ gt_points = [t["points"] for t in targets if "points" in t]
|
|
|
|
|
|
- gt_lines_tensor = torch.cat(gt_lines)
|
|
|
- gt_points_tensor = torch.cat(gt_points)
|
|
|
+ gt_lines_tensor = torch.zeros(0, 0)
|
|
|
+ gt_points_tensor = torch.zeros(0, 0)
|
|
|
+ if len(gt_lines)>0:
|
|
|
+ gt_lines_tensor = torch.cat(gt_lines)
|
|
|
+ if len(gt_points)>0:
|
|
|
+ gt_points_tensor = torch.cat(gt_points)
|
|
|
|
|
|
- line_pos_matched_idxs = []
|
|
|
- point_pos_matched_idxs = []
|
|
|
+ # line_pos_matched_idxs = []
|
|
|
+ # point_pos_matched_idxs = []
|
|
|
|
|
|
|
|
|
if gt_lines_tensor.shape[0] > 0 and line_features is not None:
|
|
|
loss_line = lines_point_pair_loss(
|
|
|
line_features, line_proposals, gt_lines, line_pos_matched_idxs
|
|
|
)
|
|
|
+ print(f'compute_line_loss:{loss_line}')
|
|
|
loss_line_iou = line_iou_loss(line_features , line_proposals, gt_lines, line_pos_matched_idxs,
|
|
|
img_size)
|
|
|
|
|
|
@@ -1037,14 +1043,17 @@ class RoIHeads(nn.Module):
|
|
|
point_features, point_proposals, gt_points, point_pos_matched_idxs
|
|
|
)
|
|
|
|
|
|
- if not loss_line :
|
|
|
+ if loss_line is None:
|
|
|
+ print(f'loss_line is None')
|
|
|
loss_line=torch.tensor(0.0,device=cs_features.device)
|
|
|
|
|
|
- if not loss_line_iou :
|
|
|
+ if loss_line_iou is None:
|
|
|
+ print(f'loss_line_iou is None')
|
|
|
loss_line_iou=torch.tensor(0.0,device=cs_features.device)
|
|
|
|
|
|
- if not loss_point:
|
|
|
- loss_point=torch.tensor(0.0,device=cs_features.device)
|
|
|
+ # if loss_point is None:
|
|
|
+ # print(f'loss_point is None')
|
|
|
+ # loss_point=torch.tensor(0.0,device=cs_features.device)
|
|
|
|
|
|
loss_line = {"loss_line": loss_line}
|
|
|
loss_line_iou = {'loss_line_iou': loss_line_iou}
|
|
|
@@ -1057,6 +1066,9 @@ class RoIHeads(nn.Module):
|
|
|
|
|
|
|
|
|
else:
|
|
|
+ loss_point = {}
|
|
|
+ loss_line = {}
|
|
|
+ loss_line_iou = {}
|
|
|
if feature_logits is None or line_proposals is None:
|
|
|
raise ValueError(
|
|
|
"both keypoint_logits and keypoint_proposals should not be None when not in training mode"
|
|
|
@@ -1077,7 +1089,7 @@ class RoIHeads(nn.Module):
|
|
|
r["points_scores"] = ps
|
|
|
|
|
|
|
|
|
-
|
|
|
+ print(f'loss_line11111:{loss_line}')
|
|
|
losses.update(loss_line)
|
|
|
losses.update(loss_line_iou)
|
|
|
losses.update(loss_point)
|