|
@@ -1128,7 +1128,12 @@ class RoIHeads(nn.Module):
|
|
|
# print(f'keypoint_features from roi_pool:{wirepoint_features.shape}')
|
|
|
outputs, wirepoint_features = self.wirepoint_head(wirepoint_features)
|
|
|
|
|
|
+
|
|
|
+
|
|
|
outputs = merge_features(outputs, wirepoint_proposals)
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
wirepoint_features = merge_features(wirepoint_features, wirepoint_proposals)
|
|
|
|
|
|
print(f'outpust:{outputs.shape}')
|
|
@@ -1204,6 +1209,7 @@ class RoIHeads(nn.Module):
|
|
|
|
|
|
def merge_features(features, proposals):
|
|
|
print(f'features:{features.shape}')
|
|
|
+ print(f'proposals:{len(proposals)}')
|
|
|
def diagnose_input(features, proposals):
|
|
|
"""诊断输入数据"""
|
|
|
print("Input Diagnostics:")
|
|
@@ -1225,11 +1231,19 @@ def merge_features(features, proposals):
|
|
|
f"Proposals count ({proposals_count}) must match features batch size ({features_size})"
|
|
|
)
|
|
|
|
|
|
- def safe_max_reduction(features_per_img):
|
|
|
+ def safe_max_reduction(features_per_img,proposals):
|
|
|
+
|
|
|
+ print(f'proposal:{proposals.shape},features_per_img:{features_per_img.shape}')
|
|
|
"""安全的最大值压缩"""
|
|
|
if features_per_img.numel() == 0:
|
|
|
return torch.zeros_like(features_per_img).unsqueeze(0)
|
|
|
|
|
|
+ for feature_map,roi in zip(features_per_img,proposals):
|
|
|
+ print(f'feature_map:{feature_map.shape},roi:{roi}')
|
|
|
+ roi_off_x=roi[0]
|
|
|
+ roi_off_y=roi[1]
|
|
|
+
|
|
|
+
|
|
|
try:
|
|
|
# 沿着第0维求最大值,保持维度
|
|
|
max_features, _ = torch.max(features_per_img, dim=0, keepdim=True)
|
|
@@ -1257,8 +1271,10 @@ def merge_features(features, proposals):
|
|
|
|
|
|
# 每张图像特征压缩
|
|
|
features_imgs = []
|
|
|
- for features_per_img in split_features:
|
|
|
- compressed_features = safe_max_reduction(features_per_img)
|
|
|
+
|
|
|
+ print(f'split_features:{len(split_features)}')
|
|
|
+ for features_per_img,proposal in zip(split_features,proposals):
|
|
|
+ compressed_features = safe_max_reduction(features_per_img,proposal)
|
|
|
features_imgs.append(compressed_features)
|
|
|
|
|
|
# 合并特征
|