|
@@ -525,6 +525,12 @@ class RoIHeads(nn.Module):
|
|
|
line_roi_pool=None,
|
|
line_roi_pool=None,
|
|
|
line_head=None,
|
|
line_head=None,
|
|
|
line_predictor=None,
|
|
line_predictor=None,
|
|
|
|
|
+
|
|
|
|
|
+ # point parameters
|
|
|
|
|
+ point_roi_pool=None,
|
|
|
|
|
+ point_head=None,
|
|
|
|
|
+ point_predictor=None,
|
|
|
|
|
+
|
|
|
# Mask
|
|
# Mask
|
|
|
mask_roi_pool=None,
|
|
mask_roi_pool=None,
|
|
|
mask_head=None,
|
|
mask_head=None,
|
|
@@ -532,6 +538,10 @@ class RoIHeads(nn.Module):
|
|
|
keypoint_roi_pool=None,
|
|
keypoint_roi_pool=None,
|
|
|
keypoint_head=None,
|
|
keypoint_head=None,
|
|
|
keypoint_predictor=None,
|
|
keypoint_predictor=None,
|
|
|
|
|
+
|
|
|
|
|
+ detect_point=True,
|
|
|
|
|
+ detect_line=True,
|
|
|
|
|
+ detect_arc=False,
|
|
|
):
|
|
):
|
|
|
super().__init__()
|
|
super().__init__()
|
|
|
|
|
|
|
@@ -557,6 +567,12 @@ class RoIHeads(nn.Module):
|
|
|
self.line_head = line_head
|
|
self.line_head = line_head
|
|
|
self.line_predictor = line_predictor
|
|
self.line_predictor = line_predictor
|
|
|
|
|
|
|
|
|
|
+ self.point_roi_pool = point_roi_pool
|
|
|
|
|
+ self.point_head = point_head
|
|
|
|
|
+ self.point_predictor = point_predictor
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
self.mask_roi_pool = mask_roi_pool
|
|
self.mask_roi_pool = mask_roi_pool
|
|
|
self.mask_head = mask_head
|
|
self.mask_head = mask_head
|
|
|
self.mask_predictor = mask_predictor
|
|
self.mask_predictor = mask_predictor
|
|
@@ -565,6 +581,10 @@ class RoIHeads(nn.Module):
|
|
|
self.keypoint_head = keypoint_head
|
|
self.keypoint_head = keypoint_head
|
|
|
self.keypoint_predictor = keypoint_predictor
|
|
self.keypoint_predictor = keypoint_predictor
|
|
|
|
|
|
|
|
|
|
+ self.detect_point =detect_point
|
|
|
|
|
+ self.detect_line =detect_line
|
|
|
|
|
+ self.detect_arc =detect_arc
|
|
|
|
|
+
|
|
|
self.channel_compress = nn.Sequential(
|
|
self.channel_compress = nn.Sequential(
|
|
|
nn.Conv2d(256, 8, kernel_size=1),
|
|
nn.Conv2d(256, 8, kernel_size=1),
|
|
|
nn.BatchNorm2d(8),
|
|
nn.BatchNorm2d(8),
|
|
@@ -598,6 +618,15 @@ class RoIHeads(nn.Module):
|
|
|
# return False
|
|
# return False
|
|
|
return True
|
|
return True
|
|
|
|
|
|
|
|
|
|
def has_point(self):
    """Report whether the point branch is configured well enough to run.

    Returns:
        bool: ``True`` only when both ``self.point_head`` and
        ``self.point_predictor`` are set, ``False`` otherwise.
    """
    if self.point_head is None:
        return False
    # point_forward1 calls self.point_predictor directly, so a missing
    # predictor must disable the branch instead of failing on a None call.
    if self.point_predictor is None:
        return False
    return True
|
|
|
|
|
+
|
|
|
def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
|
|
def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
|
|
|
# type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
|
|
# type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
|
|
|
matched_idxs = []
|
|
matched_idxs = []
|
|
@@ -831,7 +860,7 @@ class RoIHeads(nn.Module):
|
|
|
}
|
|
}
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- if self.has_line():
|
|
|
|
|
|
|
+ if self.has_line() and self.detect_line:
|
|
|
print(f'roi_heads forward has_line()!!!!')
|
|
print(f'roi_heads forward has_line()!!!!')
|
|
|
# print(f'labels:{labels}')
|
|
# print(f'labels:{labels}')
|
|
|
line_proposals = [p["boxes"] for p in result]
|
|
line_proposals = [p["boxes"] for p in result]
|
|
@@ -894,7 +923,7 @@ class RoIHeads(nn.Module):
|
|
|
else:
|
|
else:
|
|
|
pos_matched_idxs = None
|
|
pos_matched_idxs = None
|
|
|
|
|
|
|
|
- feature_logits = self.head_forward3(features, image_shapes, line_proposals)
|
|
|
|
|
|
|
+ feature_logits = self.line_forward3(features, image_shapes, line_proposals)
|
|
|
|
|
|
|
|
loss_line = None
|
|
loss_line = None
|
|
|
loss_line_iou =None
|
|
loss_line_iou =None
|
|
@@ -984,7 +1013,7 @@ class RoIHeads(nn.Module):
|
|
|
lines_probs, lines_scores = line_inference(feature_logits,line_proposals)
|
|
lines_probs, lines_scores = line_inference(feature_logits,line_proposals)
|
|
|
for keypoint_prob, kps, r in zip(lines_probs, lines_scores, result):
|
|
for keypoint_prob, kps, r in zip(lines_probs, lines_scores, result):
|
|
|
r["lines"] = keypoint_prob
|
|
r["lines"] = keypoint_prob
|
|
|
- r["liness_scores"] = kps
|
|
|
|
|
|
|
+ r["lines_scores"] = kps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -993,6 +1022,120 @@ class RoIHeads(nn.Module):
|
|
|
losses.update(loss_line)
|
|
losses.update(loss_line)
|
|
|
losses.update(loss_line_iou)
|
|
losses.update(loss_line_iou)
|
|
|
print(f'losses:{losses}')
|
|
print(f'losses:{losses}')
|
|
|
|
|
+ if self.has_point() and self.detect_point:
|
|
|
|
|
+ print(f'roi_heads forward has_point()!!!!')
|
|
|
|
|
+ # print(f'labels:{labels}')
|
|
|
|
|
+ point_proposals = [p["boxes"] for p in result]
|
|
|
|
|
+ print(f'boxes_proposals:{len(point_proposals)}')
|
|
|
|
|
+
|
|
|
|
|
+ # if line_proposals is None or len(line_proposals) == 0:
|
|
|
|
|
+ # # Return an empty feature tensor, or skip this part of the computation
|
|
|
|
|
+ # return torch.empty(0, C, H, W).to(features['0'].device)
|
|
|
|
|
+
|
|
|
|
|
+ if self.training:
|
|
|
|
|
+ # during training, only focus on positive boxes
|
|
|
|
|
+ num_images = len(proposals)
|
|
|
|
|
+ print(f'num_images:{num_images}')
|
|
|
|
|
+ point_proposals = []
|
|
|
|
|
+ point_pos_matched_idxs = []
|
|
|
|
|
+ if matched_idxs is None:
|
|
|
|
|
+ raise ValueError("if in trainning, matched_idxs should not be None")
|
|
|
|
|
+ for img_id in range(num_images):
|
|
|
|
|
+ point_pos=torch.where(labels[img_id] ==1)[0]
|
|
|
|
|
+ point_proposals.append(proposals[img_id][point_pos])
|
|
|
|
|
+ point_pos_matched_idxs.append(matched_idxs[img_id][point_pos])
|
|
|
|
|
+ else:
|
|
|
|
|
+ if targets is not None:
|
|
|
|
|
+
|
|
|
|
|
+ num_images = len(proposals)
|
|
|
|
|
+ point_proposals = []
|
|
|
|
|
+
|
|
|
|
|
+ point_pos_matched_idxs = []
|
|
|
|
|
+ print(f'val num_images:{num_images}')
|
|
|
|
|
+ if matched_idxs is None:
|
|
|
|
|
+ raise ValueError("if in trainning, matched_idxs should not be None")
|
|
|
|
|
+
|
|
|
|
|
+ for img_id in range(num_images):
|
|
|
|
|
+ point_pos = torch.where(labels[img_id] == 1)[0]
|
|
|
|
|
+ point_proposals.append(proposals[img_id][point_pos])
|
|
|
|
|
+ point_pos_matched_idxs.append(matched_idxs[img_id][point_pos])
|
|
|
|
|
+
|
|
|
|
|
+ else:
|
|
|
|
|
+ pos_matched_idxs = None
|
|
|
|
|
+
|
|
|
|
|
+ feature_logits = self.point_forward1(features, image_shapes, point_proposals)
|
|
|
|
|
+
|
|
|
|
|
+ loss_point=None
|
|
|
|
|
+
|
|
|
|
|
+ if self.training:
|
|
|
|
|
+
|
|
|
|
|
+ if targets is None or point_pos_matched_idxs is None:
|
|
|
|
|
+ raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
|
|
|
|
|
+
|
|
|
|
|
+ gt_points = [t["points"] for t in targets if "points" in t]
|
|
|
|
|
+
|
|
|
|
|
+ print(f'gt_points:{gt_points[0].shape}')
|
|
|
|
|
+ h, w = targets[0]["img_size"]
|
|
|
|
|
+ img_size = h
|
|
|
|
|
+
|
|
|
|
|
+ gt_points_tensor = torch.zeros(0, 0)
|
|
|
|
|
+ if len(gt_points) > 0:
|
|
|
|
|
+ gt_points_tensor = torch.cat(gt_points)
|
|
|
|
|
+ print(f'gt_points_tensor:{gt_points_tensor.shape}')
|
|
|
|
|
+
|
|
|
|
|
+ if gt_points_tensor.shape[0] > 0:
|
|
|
|
|
+ print(f'start to compute point_loss')
|
|
|
|
|
+
|
|
|
|
|
+ loss_point=compute_point_loss(feature_logits,point_proposals,gt_points,point_pos_matched_idxs)
|
|
|
|
|
+
|
|
|
|
|
+ if loss_point is None:
|
|
|
|
|
+ print(f'loss_point is None111')
|
|
|
|
|
+ loss_point = torch.tensor(0.0, device=device)
|
|
|
|
|
+
|
|
|
|
|
+ loss_point = {"loss_point": loss_point}
|
|
|
|
|
+
|
|
|
|
|
+ else:
|
|
|
|
|
+ if targets is not None:
|
|
|
|
|
+ h, w = targets[0]["img_size"]
|
|
|
|
|
+ img_size = h
|
|
|
|
|
+ gt_points = [t["points"] for t in targets if "points" in t]
|
|
|
|
|
+
|
|
|
|
|
+ gt_points_tensor = torch.zeros(0, 0)
|
|
|
|
|
+ if len(gt_points) > 0:
|
|
|
|
|
+ gt_points_tensor = torch.cat(gt_points)
|
|
|
|
|
+ print(f'gt_points_tensor:{gt_points_tensor.shape}')
|
|
|
|
|
+
|
|
|
|
|
+ if gt_points_tensor.shape[0] > 0:
|
|
|
|
|
+ print(f'start to compute point_loss')
|
|
|
|
|
+
|
|
|
|
|
+ loss_point = compute_point_loss(feature_logits, point_proposals, gt_points,
|
|
|
|
|
+ point_pos_matched_idxs, img_size)
|
|
|
|
|
+
|
|
|
|
|
+ if loss_point is None:
|
|
|
|
|
+ print(f'loss_point is None111')
|
|
|
|
|
+ loss_point = torch.tensor(0.0, device=device)
|
|
|
|
|
+
|
|
|
|
|
+ loss_point = {"loss_point": loss_point}
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+ else:
|
|
|
|
|
+ loss_point = {}
|
|
|
|
|
+ if feature_logits is None or point_proposals is None:
|
|
|
|
|
+ raise ValueError(
|
|
|
|
|
+ "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ if feature_logits is not None:
|
|
|
|
|
+
|
|
|
|
|
+ points_probs, points_scores = point_inference(feature_logits,point_proposals)
|
|
|
|
|
+ for keypoint_prob, kps, r in zip(points_probs, points_scores, result):
|
|
|
|
|
+ r["points"] = keypoint_prob
|
|
|
|
|
+ r["points_scores"] = kps
|
|
|
|
|
+
|
|
|
|
|
+ print(f'loss_point:{loss_point}')
|
|
|
|
|
+ losses.update(loss_point)
|
|
|
|
|
+ print(f'losses:{losses}')
|
|
|
|
|
|
|
|
if self.has_mask():
|
|
if self.has_mask():
|
|
|
mask_proposals = [p["boxes"] for p in result]
|
|
mask_proposals = [p["boxes"] for p in result]
|
|
@@ -1083,7 +1226,7 @@ class RoIHeads(nn.Module):
|
|
|
|
|
|
|
|
return result, losses
|
|
return result, losses
|
|
|
|
|
|
|
|
- def head_forward1(self, features, image_shapes, line_proposals):
|
|
|
|
|
|
|
+ def line_forward1(self, features, image_shapes, line_proposals):
|
|
|
print(f'line_proposals:{len(line_proposals)}')
|
|
print(f'line_proposals:{len(line_proposals)}')
|
|
|
# cs_features= features['0']
|
|
# cs_features= features['0']
|
|
|
print(f'features-0:{features['0'].shape}')
|
|
print(f'features-0:{features['0'].shape}')
|
|
@@ -1101,7 +1244,7 @@ class RoIHeads(nn.Module):
|
|
|
print(f'feature_logits from line_head:{feature_logits.shape}')
|
|
print(f'feature_logits from line_head:{feature_logits.shape}')
|
|
|
return feature_logits
|
|
return feature_logits
|
|
|
|
|
|
|
|
- def head_forward2(self, features, image_shapes, line_proposals):
|
|
|
|
|
|
|
+ def line_forward2(self, features, image_shapes, line_proposals):
|
|
|
print(f'line_proposals:{len(line_proposals)}')
|
|
print(f'line_proposals:{len(line_proposals)}')
|
|
|
# cs_features= features['0']
|
|
# cs_features= features['0']
|
|
|
print(f'features-0:{features['0'].shape}')
|
|
print(f'features-0:{features['0'].shape}')
|
|
@@ -1124,7 +1267,7 @@ class RoIHeads(nn.Module):
|
|
|
print(f'roi_features from align:{roi_features.shape}')
|
|
print(f'roi_features from align:{roi_features.shape}')
|
|
|
return roi_features
|
|
return roi_features
|
|
|
|
|
|
|
|
- def head_forward3(self, features, image_shapes, line_proposals):
|
|
|
|
|
|
|
+ def line_forward3(self, features, image_shapes, line_proposals):
|
|
|
print(f'line_proposals:{len(line_proposals)}')
|
|
print(f'line_proposals:{len(line_proposals)}')
|
|
|
# cs_features= features['0']
|
|
# cs_features= features['0']
|
|
|
print(f'features-0:{features['0'].shape}')
|
|
print(f'features-0:{features['0'].shape}')
|
|
@@ -1146,3 +1289,26 @@ class RoIHeads(nn.Module):
|
|
|
if roi_features is not None:
|
|
if roi_features is not None:
|
|
|
print(f'roi_features from align:{roi_features.shape}')
|
|
print(f'roi_features from align:{roi_features.shape}')
|
|
|
return roi_features
|
|
return roi_features
|
|
|
|
|
+
|
|
|
|
|
def point_forward1(self, features, image_shapes, proposals):
    """Run the point predictor on feature level '0' and align RoI features.

    Args:
        features: Mapping of feature maps; only ``features['0']`` is read.
        image_shapes: Forwarded unchanged to ``features_align``.
        proposals: Sequence of per-image proposal tensors. When at least one
            entry has a non-zero first dimension, entries whose first
            dimension is 0 are dropped before alignment.

    Returns:
        The value returned by ``features_align`` for the level-'0' feature
        map and the (possibly filtered) proposals.
    """
    print(f'point_proposals:{len(proposals)}')
    # Outer double quotes keep the nested '0' key valid on interpreters older
    # than Python 3.12; the printed text is unchanged.
    print(f"features-0:{features['0'].shape}")
    # The channel_compress step is disabled here; the raw level-'0' map is used.
    cs_features = features['0']
    filtered_proposals = [proposal for proposal in proposals if proposal.shape[0] > 0]

    if len(filtered_proposals) > 0:
        filtered_proposals_tensor = torch.cat(filtered_proposals)
        print(f'filtered_proposals_tensor:{filtered_proposals_tensor.shape}')
        # Continue with the non-empty proposals only.
        proposals = filtered_proposals
    point_proposals_tensor = torch.cat(proposals)
    print(f'point_proposals_tensor:{point_proposals_tensor.shape}')

    feature_logits = self.point_predictor(cs_features)
    print(f'feature_logits from line_head:{feature_logits.shape}')

    # NOTE(review): feature_logits is only printed; features_align receives
    # cs_features rather than the predictor output -- confirm this is intended.
    roi_features = features_align(cs_features, proposals, image_shapes)
    if roi_features is not None:
        print(f'roi_features from align:{roi_features.shape}')
    return roi_features
|