admin 1 miesiąc temu
rodzic
commit
797a449034
1 zmienionych plików z 11 dodań i 6 usunięć
  1. 11 6
      models/line_detect/loi_heads.py

+ 11 - 6
models/line_detect/loi_heads.py

@@ -76,6 +76,8 @@ def maskrcnn_inference(x, labels):
     """
     mask_prob = x.sigmoid()
 
+
+
     # select masks corresponding to the predicted classes
     num_masks = x.shape[0]
     boxes_per_image = [label.shape[0] for label in labels]
@@ -84,6 +86,8 @@ def maskrcnn_inference(x, labels):
     mask_prob = mask_prob[index, labels][:, None]
     mask_prob = mask_prob.split(boxes_per_image, dim=0)
 
+
+
     return mask_prob
 
 
@@ -1449,20 +1453,21 @@ class RoIHeads(nn.Module):
 
                         if feature_logits is not None:
 
-                            circles_probs, circles_scores,circle_points = arc_inference(feature_logits, circle_proposals,th=0)
+                            circles_probs, circles_scores, circle_points = arc_inference(feature_logits,
+                                                                                         circle_proposals, th=0)
                             # print(f'circles_probs:{circles_probs.shape}, circles_scores:{circles_scores.shape}')
-                            proposals_per_image = [box.size(0) for box in  circle_proposals]
+                            proposals_per_image = [box.size(0) for box in circle_proposals]
                             print(f'circle_proposals_per_image:{proposals_per_image}')
-                            feature_logits_props=[]
+                            feature_logits_props = []
                             start_idx = 0
                             for num_p in proposals_per_image:
                                 current_features = feature_logits[start_idx:start_idx + num_p]
                                 merged_feature = torch.sum(current_features, dim=0, keepdim=True)
                                 feature_logits_props.append(merged_feature)
+                                start_idx += num_p
 
-
-
-                            for keypoint_prob, kps, r,f in zip(circles_probs, circles_scores, result,feature_logits_props):
+                            for keypoint_prob, kps, r, f in zip(circles_probs, circles_scores, result,
+                                                                feature_logits_props):
                                 r["circles"] = keypoint_prob
                                 r["circles_scores"] = kps
                                 print(f'circles feature map:{f.shape}')