@@ -13,7 +13,7 @@ import libs.vision_libs.models.detection._utils as det_utils

from collections import OrderedDict

from models.line_detect.heads.head_losses import point_inference, compute_point_loss, line_iou_loss, \
- lines_point_pair_loss, features_align, line_inference
+ lines_point_pair_loss, features_align, line_inference, compute_arc_loss, arc_inference


def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
@@ -531,6 +531,11 @@ class RoIHeads(nn.Module):
point_head=None,
point_predictor=None,
+ # arc parameters
+ arc_roi_pool=None,
+ arc_head=None,
+ arc_predictor=None,
+
# Mask
mask_roi_pool=None,
mask_head=None,
@@ -571,6 +576,10 @@ class RoIHeads(nn.Module):
self.point_head = point_head
self.point_predictor = point_predictor
+ self.arc_roi_pool = arc_roi_pool
+ self.arc_head = arc_head
+ self.arc_predictor = arc_predictor
+
self.mask_roi_pool = mask_roi_pool
@@ -627,6 +636,15 @@ class RoIHeads(nn.Module):
# return False
return True

+ def has_arc(self):
+ # if self.arc_roi_pool is None:
+ # return False
+ if self.arc_head is None:
+ return False
+ # if self.arc_predictor is None:
+ # return False
+ return True
+
def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
# type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
matched_idxs = []
@@ -1137,6 +1155,128 @@ class RoIHeads(nn.Module):
losses.update(loss_point)
print(f'losses:{losses}')
+
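+ # Arc branch: mirrors the point/line branches; runs only when an arc head is configured and detect_arc is enabled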
+ if self.has_arc() and self.detect_arc:
+ print(f'roi_heads forward has_arc()!!!!')
+ # print(f'labels:{labels}')
+ arc_proposals = [p["boxes"] for p in result]
+ print(f'arc_proposals:{len(arc_proposals)}')
+
+ # if line_proposals is None or len(line_proposals) == 0:
+ # # return empty features, or skip this part of the computation
+ # return torch.empty(0, C, H, W).to(features['0'].device)
+
+ if self.training:
+ # during training, only focus on positive boxes
+ num_images = len(proposals)
+ print(f'num_images:{num_images}')
+ arc_proposals = []
+ arc_pos_matched_idxs = []
+ if matched_idxs is None:
+ raise ValueError("if in training, matched_idxs should not be None")
+ for img_id in range(num_images):
+ # keep only proposals whose matched label is the arc class (label == 3)
+ arc_pos = torch.where(labels[img_id] == 3)[0]
+ arc_proposals.append(proposals[img_id][arc_pos])
+ arc_pos_matched_idxs.append(matched_idxs[img_id][arc_pos])
+ else:
+ if targets is not None:
+
+ num_images = len(proposals)
+ arc_proposals = []
+
+ arc_pos_matched_idxs = []
+ print(f'val num_images:{num_images}')
+ if matched_idxs is None:
+ raise ValueError("matched_idxs should not be None when targets are provided")
+
+ for img_id in range(num_images):
+ arc_pos = torch.where(labels[img_id] == 3)[0]
+ arc_proposals.append(proposals[img_id][arc_pos])
+ arc_pos_matched_idxs.append(matched_idxs[img_id][arc_pos])
+
+ else:
+ arc_pos_matched_idxs = None
+
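+ # Run the arc branch: arc_predictor produces full-feature-map logits which arc_forward1 crops to each arc proposal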
+ feature_logits = self.arc_forward1(features, image_shapes, arc_proposals)
+
+ loss_arc = None
+
+ if self.training:
+
+ if targets is None or arc_pos_matched_idxs is None:
+ raise ValueError("both targets and arc_pos_matched_idxs should not be None when in training mode")
+
+ gt_arcs = [t["arc_mask"] for t in targets if "arc_mask" in t]
+
+ print(f'gt_arcs:{gt_arcs[0].shape}')
+ h, w = targets[0]["img_size"]
+ img_size = h
+
+ gt_arcs_tensor = torch.zeros(0, 0)
+ if len(gt_arcs) > 0:
+ gt_arcs_tensor = torch.cat(gt_arcs)
+ print(f'gt_arcs_tensor:{gt_arcs_tensor.shape}')
+
+ if gt_arcs_tensor.shape[0] > 0:
+ print(f'start to compute arc_loss')
+
+ loss_arc = compute_arc_loss(feature_logits, arc_proposals, gt_arcs, arc_pos_matched_idxs)
+
+ if loss_arc is None:
+ print(f'loss_arc is None')
+ loss_arc = torch.tensor(0.0, device=device)
+
+ loss_arc = {"loss_arc": loss_arc}
+
+ else:
+ if targets is not None:
+ h, w = targets[0]["img_size"]
+ img_size = h
+ gt_arcs = [t["arc_mask"] for t in targets if "arc_mask" in t]
+
+ print(f'gt_arcs:{gt_arcs[0].shape}')
+ h, w = targets[0]["img_size"]
+ img_size = h
+
+ gt_arcs_tensor = torch.zeros(0, 0)
+ if len(gt_arcs) > 0:
+ gt_arcs_tensor = torch.cat(gt_arcs)
+ print(f'gt_arcs_tensor:{gt_arcs_tensor.shape}')
+
+ if gt_arcs_tensor.shape[0] > 0:
+ print(f'start to compute arc_loss')
+
+ loss_arc = compute_arc_loss(feature_logits, arc_proposals, gt_arcs, arc_pos_matched_idxs)
+
+ if loss_arc is None:
+ print(f'loss_arc is None')
+ loss_arc = torch.tensor(0.0, device=device)
+
+ loss_arc = {"loss_arc": loss_arc}
+
+ else:
+ loss_arc = {}
+ if feature_logits is None or arc_proposals is None:
+ raise ValueError(
+ "both feature_logits and arc_proposals should not be None when not in training mode"
+ )
+
+ if feature_logits is not None:
+
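+ # Inference path: decode per-ROI arc probabilities and scores, then attach them to each detection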
+ arcs_probs, arcs_scores = arc_inference(feature_logits, arc_proposals)
+ for arc_prob, arc_score, r in zip(arcs_probs, arcs_scores, result):
+ r["arcs"] = arc_prob
+ r["arcs_scores"] = arc_score
+
+ print(f'loss_arc:{loss_arc}')
+ losses.update(loss_arc)
+ print(f'losses:{losses}')
+
if self.has_mask():
mask_proposals = [p["boxes"] for p in result]
if self.training:
@@ -1312,3 +1452,27 @@
if roi_features is not None:
print(f'roi_features from align:{roi_features.shape}')
return roi_features
+
+
+ def arc_forward1(self, features, image_shapes, proposals):
+ print(f'arc_proposals:{len(proposals)}')
+ # cs_features= features['0']
+ print(f"features-0:{features['0'].shape}")
+ # cs_features = self.channel_compress(features['0'])
+ cs_features = features['0']
+ # filtered_proposals = [proposal for proposal in proposals if proposal.shape[0] > 0]
+ #
+ # if len(filtered_proposals) > 0:
+ # filtered_proposals_tensor = torch.cat(filtered_proposals)
+ # print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
+ # proposals=filtered_proposals
+ # point_proposals_tensor = torch.cat(proposals)
+ # print(f'point_proposals_tensor:{point_proposals_tensor.shape}')
+
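+ # Predict arc logits on the level-'0' feature map, then crop them per proposal with features_align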
+ feature_logits = self.arc_predictor(cs_features)
+ print(f'feature_logits from arc_predictor:{feature_logits.shape}')
+
+ roi_features = features_align(feature_logits, proposals, image_shapes)
+ if roi_features is not None:
+ print(f'roi_features from align:{roi_features.shape}')
+ return roi_features
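
Note: the sketch below is illustrative only. It shows one possible shape for the two helpers this patch assumes from models.line_detect.heads.head_losses (compute_arc_loss and arc_inference), with signatures inferred from their call sites above; the project's real implementations may differ.

# Minimal sketch, not the project's implementation: assumed interfaces for the arc helpers,
# inferred from compute_arc_loss(feature_logits, arc_proposals, gt_arcs, arc_pos_matched_idxs)
# and arc_inference(feature_logits, arc_proposals) as called in RoIHeads.forward.
from typing import List, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor


def compute_arc_loss(feature_logits: Tensor,          # per-ROI arc logits, e.g. (N, C, H, W)
                     proposals: List[Tensor],         # per-image arc proposals, each (Ni, 4)
                     gt_arcs: List[Tensor],           # per-image ground-truth arc masks
                     pos_matched_idxs: List[Tensor],  # per-image proposal-to-target matches
                     ) -> Tensor:
    # Placeholder loss: a real implementation would rasterize/crop each ground-truth arc
    # to its matched ROI before comparing; this only illustrates the expected contract
    # (a scalar tensor is returned, and an empty batch yields a zero loss).
    if feature_logits.numel() == 0:
        return feature_logits.sum() * 0.0
    target = torch.zeros_like(feature_logits)
    return F.binary_cross_entropy_with_logits(feature_logits, target)


def arc_inference(feature_logits: Tensor,
                  proposals: List[Tensor],
                  ) -> Tuple[List[Tensor], List[Tensor]]:
    # Placeholder inference: turn per-ROI logits into probabilities and a crude per-ROI
    # score, then split both back into per-image lists so they can be zipped with `result`.
    probs = feature_logits.sigmoid()
    scores = probs.flatten(start_dim=1).amax(dim=1)
    boxes_per_image = [len(p) for p in proposals]
    return list(probs.split(boxes_per_image)), list(scores.split(boxes_per_image))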