@@ -13,7 +13,7 @@ import libs.vision_libs.models.detection._utils as det_utils
 from collections import OrderedDict

 from models.line_detect.heads.head_losses import point_inference, compute_point_loss, line_iou_loss, \
-    lines_point_pair_loss, features_align
+    lines_point_pair_loss, features_align, line_inference


 def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
@@ -902,12 +902,6 @@ class RoIHeads(nn.Module):

             print(f'line_proposals:{len(line_proposals)}')

-            # line_features = self.line_roi_pool(features, line_proposals, image_shapes)
-
-            # print(f'line_features from line_roi_pool:{line_features.shape}')
-            #(b,256,512,512)
-            # cs_features = self.channel_compress(features['0'])
-            #(b.8,512,512)
             cs_features= features['0']

@@ -933,45 +927,74 @@ class RoIHeads(nn.Module):

             point_proposals_tensor=torch.cat(point_proposals)
             print(f'point_proposals_tensor:{point_proposals_tensor.shape}')

-            line_features=None
+            # line_features=None

             feature_logits = self.line_predictor(cs_features)
             print(f'feature_logits from line_predictor:{feature_logits.shape}')

             point_features = features_align(feature_logits, point_proposals, image_shapes)
-            print(f'feature_logits features_align:{point_features.shape}')
-            feature_logits=point_features
+
+
+
+            line_features = features_align(feature_logits, line_proposals, image_shapes)
+
+            if line_features is not None:
+                print(f'line_features from align:{line_features.shape}')
+
+            if point_features is not None:
+                print(f'feature_logits features_align:{point_features.shape}')
+            # feature_logits=point_features

             # line_logits = combine_features
             # print(f'line_logits:{line_logits.shape}')

-            loss_line = {}
-            loss_line_iou = {}
-            loss_point = {}
+            loss_line = None
+            loss_line_iou = None
+            loss_point = None
+
+
             if self.training:
                 if targets is None or pos_matched_idxs is None:
                     raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")

-                gt_lines = [t["lines"] for t in targets]
-                gt_points = [t["points"] for t in targets]
+                gt_lines = [t["lines"] for t in targets if "lines" in t]
+                gt_points = [t["points"] for t in targets if "points" in t]
+                #
+                # line_pos_matched_idxs = []
+                # point_pos_matched_idxs = []
+
+
+
+
                 print(f'gt_lines:{gt_lines[0].shape}')
                 h, w = targets[0]["img_size"]
                 img_size = h

-                gt_lines_tensor=torch.cat(gt_lines)
-                gt_points_tensor = torch.cat(gt_points)
-                print(f'gt_lines_tensor:{gt_lines_tensor.shape}')
-                print(f'gt_points_tensor:{gt_points_tensor.shape}')
-                if gt_lines_tensor.shape[0]>0 and line_features is not None:
+                gt_lines_tensor = torch.zeros(0, 0)
+                gt_points_tensor = torch.zeros(0, 0)
+                if len(gt_lines) > 0:
+                    gt_lines_tensor = torch.cat(gt_lines)
+                    print(f'gt_lines_tensor:{gt_lines_tensor.shape}')
+
+                if len(gt_points) > 0:
+                    gt_points_tensor = torch.cat(gt_points)
+                    print(f'gt_points_tensor:{gt_points_tensor.shape}')
+
+
+
+
+                if gt_lines_tensor.shape[0] > 0:
+                    print(f'start to lines_point_pair_loss')
                     loss_line = lines_point_pair_loss(
-                        feature_logits, line_proposals, gt_lines, line_pos_matched_idxs
+                        line_features, line_proposals, gt_lines, line_pos_matched_idxs
                     )
-                    loss_line_iou = line_iou_loss(feature_logits, line_proposals, gt_lines, line_pos_matched_idxs, img_size)
+                    loss_line_iou = line_iou_loss(line_features, line_proposals, gt_lines, line_pos_matched_idxs, img_size)

-                if gt_points_tensor.shape[0]>0 and point_features is not None:
+                if gt_points_tensor.shape[0] > 0:
+                    print(f'start to compute_point_loss')
                     loss_point = compute_point_loss(
-                        feature_logits, point_proposals, gt_points, point_pos_matched_idxs
+                        point_features, point_proposals, gt_points, point_pos_matched_idxs
                     )

                 if not loss_line:
@@ -982,7 +1005,11 @@ class RoIHeads(nn.Module):

                 loss_line = {"loss_line": loss_line}
                 loss_line_iou = {'loss_line_iou': loss_line_iou}
-                loss_point = {"loss_point": loss_point}
+
+                if loss_point is None:
+                    loss_point = {"loss_point": torch.tensor(0.0, device=feature_logits.device)}
+                else:
+                    loss_point = {"loss_point": loss_point}

             else:
                 if targets is not None:
@@ -994,18 +1021,20 @@ class RoIHeads(nn.Module):

                     gt_lines_tensor = torch.cat(gt_lines)
                     gt_points_tensor = torch.cat(gt_points)

+                    line_pos_matched_idxs = []
+                    point_pos_matched_idxs = []

                     if gt_lines_tensor.shape[0] > 0 and line_features is not None:
                         loss_line = lines_point_pair_loss(
-                            feature_logits, line_proposals, gt_lines, line_pos_matched_idxs
+                            line_features, line_proposals, gt_lines, line_pos_matched_idxs
                         )
-                        loss_line_iou = line_iou_loss(feature_logits, line_proposals, gt_lines, line_pos_matched_idxs,
+                        loss_line_iou = line_iou_loss(line_features, line_proposals, gt_lines, line_pos_matched_idxs,
                                                       img_size)

                     if gt_points_tensor.shape[0] > 0 and point_features is not None:
                         loss_point = compute_point_loss(
-                            feature_logits, point_proposals, gt_points, point_pos_matched_idxs
+                            point_features, point_proposals, gt_points, point_pos_matched_idxs
                         )

                     if not loss_line :
@@ -1019,7 +1048,11 @@ class RoIHeads(nn.Module):


                     loss_line = {"loss_line": loss_line}
                     loss_line_iou = {'loss_line_iou': loss_line_iou}
-                    loss_point={"loss_point":loss_point}
+
+                    if loss_point is None:
+                        loss_point = {"loss_point": torch.tensor(0.0, device=feature_logits.device)}
+                    else:
+                        loss_point = {"loss_point": loss_point}

@@ -1029,14 +1062,14 @@ class RoIHeads(nn.Module):
                         "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
                     )


-                # if line_features is not None:
-                #     lines_probs, lines_scores = line_inference(combine_features,line_proposals)
-                #     for keypoint_prob, kps, r in zip(lines_probs, lines_scores, result):
-                #         r["lines"] = keypoint_prob
-                #         r["liness_scores"] = kps
+                if line_features is not None:
+                    lines_probs, lines_scores = line_inference(line_features, line_proposals)
+                    for keypoint_prob, kps, r in zip(lines_probs, lines_scores, result):
+                        r["lines"] = keypoint_prob
+                        r["liness_scores"] = kps

                 if point_features is not None:
-                    point_probs, points_scores=point_inference(feature_logits, point_proposals, )
+                    point_probs, points_scores = point_inference(point_features, point_proposals)
                     for points, ps, r in zip(point_probs,points_scores, result):
                         print(f'points_prob :{points.shape}')
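
Note: the training-branch changes above keep the returned loss dictionary's keys stable by substituting a zero tensor for "loss_point" when a batch carries no point annotations. The snippet below is a minimal sketch of that guard, not part of the patch; the helper name, the `device` parameter, and the standalone form are illustrative, while `compute_point_loss` is the head-loss function already imported at the top of the file.

# Sketch only: mirrors the empty-ground-truth guard introduced in the diff.
import torch
from models.line_detect.heads.head_losses import compute_point_loss

def point_loss_or_zero(point_features, point_proposals, gt_points, matched_idxs, device):
    # No point annotations in the batch: return a zero loss so the
    # "loss_point" key is always present in the loss dict.
    if len(gt_points) == 0 or torch.cat(gt_points).shape[0] == 0:
        return {"loss_point": torch.tensor(0.0, device=device)}
    # Otherwise compute the real loss with the repository's head-loss function.
    return {"loss_point": compute_point_loss(point_features, point_proposals, gt_points, matched_idxs)}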