| 
					
				 | 
			
			
				@@ -1128,7 +1128,12 @@ class RoIHeads(nn.Module): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             # print(f'keypoint_features from roi_pool:{wirepoint_features.shape}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             outputs, wirepoint_features = self.wirepoint_head(wirepoint_features) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             outputs = merge_features(outputs, wirepoint_proposals) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             wirepoint_features = merge_features(wirepoint_features, wirepoint_proposals) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             print(f'outpust:{outputs.shape}') 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -1204,6 +1209,7 @@ class RoIHeads(nn.Module): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 def merge_features(features, proposals): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     print(f'features:{features.shape}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print(f'proposals:{len(proposals)}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def diagnose_input(features, proposals): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """诊断输入数据""" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         print("Input Diagnostics:") 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -1225,11 +1231,19 @@ def merge_features(features, proposals): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 f"Proposals count ({proposals_count}) must match features batch size ({features_size})" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def safe_max_reduction(features_per_img): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def safe_max_reduction(features_per_img,proposals): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        print(f'proposal:{proposals.shape},features_per_img:{features_per_img.shape}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """安全的最大值压缩""" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if features_per_img.numel() == 0: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             return torch.zeros_like(features_per_img).unsqueeze(0) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for feature_map,roi in zip(features_per_img,proposals): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            print(f'feature_map:{feature_map.shape},roi:{roi}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            roi_off_x=roi[0] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            roi_off_y=roi[1] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         try: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             # 沿着第0维求最大值,保持维度 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             max_features, _ = torch.max(features_per_img, dim=0, keepdim=True) 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -1257,8 +1271,10 @@ def merge_features(features, proposals): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # 每张图像特征压缩 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         features_imgs = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        for features_per_img in split_features: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            compressed_features = safe_max_reduction(features_per_img) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        print(f'split_features:{len(split_features)}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for features_per_img,proposal in zip(split_features,proposals): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            compressed_features = safe_max_reduction(features_per_img,proposal) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             features_imgs.append(compressed_features) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # 合并特征 
			 |