@@ -859,13 +859,13 @@ class RoIHeads(nn.Module):
                 pos = torch.where(labels[img_id] > 0)[0]
                 line_pos=torch.where(labels[img_id] ==2)[0]
-                point_pos=torch.where(labels[img_id] ==1)[0]
+                # point_pos=torch.where(labels[img_id] ==1)[0]
                 line_proposals.append(proposals[img_id][line_pos])
-                point_proposals.append(proposals[img_id][point_pos])
+                # point_proposals.append(proposals[img_id][point_pos])
                 line_pos_matched_idxs.append(matched_idxs[img_id][line_pos])
-                point_pos_matched_idxs.append(matched_idxs[img_id][point_pos])
+                # point_pos_matched_idxs.append(matched_idxs[img_id][point_pos])
                 # pos_matched_idxs.append(matched_idxs[img_id][pos])
         else:
@@ -874,47 +874,33 @@ class RoIHeads(nn.Module):
             pos_matched_idxs = []
             num_images = len(proposals)
             line_proposals = []
-            point_proposals=[]
-            arc_proposals=[]
+
             line_pos_matched_idxs = []
-            point_pos_matched_idxs = []
             print(f'val num_images:{num_images}')
             if matched_idxs is None:
                 raise ValueError("if in trainning, matched_idxs should not be None")
             for img_id in range(num_images):
-                pos = torch.where(labels[img_id] > 0)[0]
-                # line_proposals.append(proposals[img_id][pos])
-                # pos_matched_idxs.append(matched_idxs[img_id][pos])
+                # pos = torch.where(labels[img_id] > 0)[0]
                 line_pos = torch.where(labels[img_id] == 2)[0]
-                point_pos = torch.where(labels[img_id] == 1)[0]
                 line_proposals.append(proposals[img_id][line_pos])
-                point_proposals.append(proposals[img_id][point_pos])
                 line_pos_matched_idxs.append(matched_idxs[img_id][line_pos])
-                point_pos_matched_idxs.append(matched_idxs[img_id][point_pos])
         else:
             pos_matched_idxs = None

         print(f'line_proposals:{len(line_proposals)}')
-        cs_features= features['0']
-
+        # cs_features= features['0']
+        cs_features = self.channel_compress(features['0'])
-        all_proposals=line_proposals+point_proposals
-        # print(f'point_proposals:{point_proposals}')
-        # print(f'all_proposals:{all_proposals}')
-        for p in point_proposals:
-            print(f'point_proposal:{p.shape}')
-
-        for ap in all_proposals:
-            print(f'ap_proposal:{ap.shape}')
-
-        filtered_proposals = [proposal for proposal in all_proposals if proposal.shape[0] > 0]
+        filtered_proposals = [proposal for proposal in line_proposals if proposal.shape[0] > 0]
         if len(filtered_proposals) > 0:
             filtered_proposals_tensor=torch.cat(filtered_proposals)
             print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
@@ -923,35 +909,17 @@ class RoIHeads(nn.Module):
             print(f'line_proposals_tensor:{line_proposals_tensor.shape}')
+        roi_features = features_align(cs_features, line_proposals, image_shapes)
-        point_proposals_tensor=torch.cat(point_proposals)
-        print(f'point_proposals_tensor:{point_proposals_tensor.shape}')
-
-        # line_features=None
-
-        feature_logits = self.line_predictor(cs_features)
-        print(f'feature_logits from line_predictor:{feature_logits.shape}')
-
-        point_features = features_align(feature_logits, point_proposals, image_shapes)
+        if roi_features is not None:
+            print(f'line_features from align:{roi_features.shape}')
+        feature_logits = self.line_head(roi_features)
+        print(f'feature_logits from line_head:{feature_logits.shape}')
-        line_features = features_align(feature_logits, line_proposals, image_shapes)
-
-        if line_features is not None:
-            print(f'line_features from align:{line_features.shape}')
-
-        if point_features is not None:
-            print(f'feature_logits features_align:{point_features.shape}')
-            # feature_logits=point_features
-
-        # line_logits = combine_features
-        # print(f'line_logits:{line_logits.shape}')
-
         loss_line = None
         loss_line_iou =None
-        loss_point = None
-
         if self.training:
@@ -959,11 +927,6 @@ class RoIHeads(nn.Module):
                 raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
             gt_lines = [t["lines"] for t in targets if "lines" in t]
-            gt_points = [t["points"] for t in targets if "points" in t]
-            #
-            # line_pos_matched_idxs = []
-            # point_pos_matched_idxs = []
-
@@ -972,29 +935,19 @@ class RoIHeads(nn.Module):
             img_size = h
             gt_lines_tensor=torch.zeros(0,0)
-            gt_points_tensor=torch.zeros(0,0)
             if len(gt_lines)>0:
                 gt_lines_tensor = torch.cat(gt_lines)
                 print(f'gt_lines_tensor:{gt_lines_tensor.shape}')
-            if len(gt_points)>0:
-                gt_points_tensor = torch.cat(gt_points)
-                print(f'gt_points_tensor:{gt_points_tensor.shape}')
-
-
             if gt_lines_tensor.shape[0]>0 :
                 print(f'start to lines_point_pair_loss')
                 loss_line = lines_point_pair_loss(
-                    line_features, line_proposals, gt_lines, line_pos_matched_idxs
+                    feature_logits, line_proposals, gt_lines, line_pos_matched_idxs
                 )
-                loss_line_iou = line_iou_loss(line_features, line_proposals, gt_lines, line_pos_matched_idxs, img_size)
+                loss_line_iou = line_iou_loss(feature_logits, line_proposals, gt_lines, line_pos_matched_idxs, img_size)
+
-            if gt_points_tensor.shape[0]>0 :
-                print(f'start to compute_point_loss ')
-                loss_point = compute_point_loss(
-                    point_features, point_proposals, gt_points, point_pos_matched_idxs
-                )
             if loss_line is None:
                 print(f'loss_line is None111')
@@ -1007,41 +960,27 @@ class RoIHeads(nn.Module):
             loss_line = {"loss_line": loss_line}
             loss_line_iou = {'loss_line_iou': loss_line_iou}
-            if loss_point is None:
-                loss_point = {"loss_point": torch.tensor(0.0,device=feature_logits.device)}
-            else:
-                loss_point = {"loss_point": loss_point}
-
         else:
             if targets is not None:
                 h, w = targets[0]["img_size"]
                 img_size = h
                 gt_lines = [t["lines"] for t in targets if "lines" in t]
-                gt_points = [t["points"] for t in targets if "points" in t]
                 gt_lines_tensor = torch.zeros(0, 0)
-                gt_points_tensor = torch.zeros(0, 0)
                 if len(gt_lines)>0:
                     gt_lines_tensor = torch.cat(gt_lines)
-                if len(gt_points)>0:
-                    gt_points_tensor = torch.cat(gt_points)
-                # line_pos_matched_idxs = []
-                # point_pos_matched_idxs = []
-                if gt_lines_tensor.shape[0] > 0 and line_features is not None:
+                if gt_lines_tensor.shape[0] > 0 and feature_logits is not None:
                     loss_line = lines_point_pair_loss(
-                        line_features, line_proposals, gt_lines, line_pos_matched_idxs
+                        feature_logits, line_proposals, gt_lines, line_pos_matched_idxs
                     )
                     print(f'compute_line_loss:{loss_line}')
-                    loss_line_iou = line_iou_loss(line_features , line_proposals, gt_lines, line_pos_matched_idxs,
+                    loss_line_iou = line_iou_loss(feature_logits , line_proposals, gt_lines, line_pos_matched_idxs,
                                                   img_size)
-                if gt_points_tensor.shape[0] > 0 and point_features is not None:
-                    loss_point = compute_point_loss(
-                        point_features, point_proposals, gt_points, point_pos_matched_idxs
-                    )
+
                 if loss_line is None:
                     print(f'loss_line is None')
@@ -1051,18 +990,10 @@ class RoIHeads(nn.Module):
                     print(f'loss_line_iou is None')
                     loss_line_iou=torch.tensor(0.0,device=cs_features.device)
-                # if loss_point is None:
-                #     print(f'loss_point is None')
-                #     loss_point=torch.tensor(0.0,device=cs_features.device)
                 loss_line = {"loss_line": loss_line}
                 loss_line_iou = {'loss_line_iou': loss_line_iou}
-                if loss_point is None:
-                    loss_point = {"loss_point": torch.tensor(0.0, device=feature_logits.device)}
-                else:
-                    loss_point = {"loss_point": loss_point}
-
             else:
@@ -1074,25 +1005,18 @@ class RoIHeads(nn.Module):
                         "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
                     )
-            if line_features is not None:
-                lines_probs, lines_scores = line_inference(line_features,line_proposals)
+            if feature_logits is not None:
+                lines_probs, lines_scores = line_inference(feature_logits,line_proposals)
                 for keypoint_prob, kps, r in zip(lines_probs, lines_scores, result):
                     r["lines"] = keypoint_prob
                     r["liness_scores"] = kps
-            if point_features is not None:
-                point_probs, points_scores=point_inference(point_features, point_proposals, )
-                for points, ps, r in zip(point_probs,points_scores, result):
-                    print(f'points_prob :{points.shape}')
-                    r["points"] = points
-                    r["points_scores"] = ps
         print(f'loss_line11111:{loss_line}')
         losses.update(loss_line)
         losses.update(loss_line_iou)
-        losses.update(loss_point)
         print(f'losses:{losses}')

         if self.has_mask():
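
Reading aid for this patch: a minimal, self-contained sketch of how the reworked line-only branch appears to be wired (compress channels, pool only the line proposals, run the line head, feed its logits to the losses). channel_compress, line_head and features_align are this project's own components whose real definitions are not shown in the diff, so the stand-ins below (a 1x1 conv, a tiny conv head, torchvision's roi_align) are assumptions that only mirror the call order above.

from typing import List

import torch
from torch import Tensor, nn
from torchvision.ops import roi_align


class LineBranchSketch(nn.Module):
    def __init__(self, in_channels: int = 256, mid_channels: int = 8, roi_size: int = 14):
        super().__init__()
        # stand-in for self.channel_compress: reduce the FPN '0' feature channels once per batch
        self.channel_compress = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
        # stand-in for self.line_head: per-proposal logits (e.g. two endpoint heatmaps)
        self.line_head = nn.Sequential(
            nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, 2, kernel_size=1),
        )
        self.roi_size = roi_size

    def forward(self, feature_map: Tensor, line_proposals: List[Tensor]) -> Tensor:
        # 1. channel compression replaces the old direct use of features['0']
        cs_features = self.channel_compress(feature_map)
        # 2. pool only the line proposals (the point branch is removed by this patch);
        #    roi_align stands in for the project's features_align helper
        roi_features = roi_align(cs_features, line_proposals,
                                 output_size=self.roi_size, spatial_scale=1.0)
        # 3. the resulting feature_logits go to lines_point_pair_loss / line_iou_loss
        #    in training and to line_inference at eval time
        return self.line_head(roi_features)


if __name__ == "__main__":
    branch = LineBranchSketch()
    feats = torch.randn(2, 256, 50, 50)
    proposals = [torch.tensor([[4.0, 4.0, 20.0, 20.0]]),
                 torch.tensor([[10.0, 8.0, 30.0, 25.0]])]
    print(branch(feats, proposals).shape)  # torch.Size([2, 2, 14, 14])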