admin 1 месяц назад
Родитель
Сommit
c01fe0262d
1 измененных файлов с 14 добавлено и 2 удалено
  1. 14 2
      models/line_detect/loi_heads.py

+ 14 - 2
models/line_detect/loi_heads.py

@@ -1448,11 +1448,23 @@ class RoIHeads(nn.Module):
                         if feature_logits is not None:
 
                             circles_probs, circles_scores = circle_inference(feature_logits, circle_proposals)
-                            for keypoint_prob, kps, r,f in zip(circles_probs, circles_scores, result,feature_logits):
+                            # print(f'circles_probs:{circles_probs.shape}, circles_scores:{circles_scores.shape}')
+                            proposals_per_image = [box.size(0) for box in  circle_proposals]
+                            print(f'circle_proposals_per_image:{proposals_per_image}')
+                            feature_logits_props=[]
+                            start_idx = 0
+                            for num_p in proposals_per_image:
+                                current_features = feature_logits[start_idx:start_idx + num_p]
+                                merged_feature = torch.sum(current_features, dim=0, keepdim=True)
+                                feature_logits_props.append(merged_feature)
+
+
+
+                            for keypoint_prob, kps, r,f in zip(circles_probs, circles_scores, result,feature_logits_props):
                                 r["circles"] = keypoint_prob
                                 r["circles_scores"] = kps
                                 print(f'circles feature map:{f.shape}')
-                                r["features"] = f
+                                r["features"] = f.squeeze(0)
 
                 print(f'loss_circle:{loss_circle}')
                 print(f'loss_circle_extra:{loss_circle_extra}')