import os
import shutil
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch


def _prune_old_folders(save_root, keep_name, max_saved):
    """Delete the oldest sub-folders of ``save_root``, never ``keep_name``.

    Sub-folders are ordered by name, which is chronological for the
    ``%Y%m%d_%H%M%S`` names produced by ``visualize_feature_map``.
    Removal is best-effort: a failure is reported and the loop continues.

    Args:
        save_root: directory whose sub-folders are pruned.
        keep_name: name of the sub-folder that must survive (the current run).
        max_saved: total number of sub-folders to keep, counting ``keep_name``.
    """
    older = sorted(
        name
        for name in os.listdir(save_root)
        if name != keep_name and os.path.isdir(os.path.join(save_root, name))
    )
    # The current folder occupies one of the `max_saved` slots. Computing the
    # count explicitly avoids the `[:-0]` slice pitfall (which selects nothing).
    keep_older = max(max_saved - 1, 0)
    excess = max(len(older) - keep_older, 0)

    for name in older[:excess]:
        full_path = os.path.join(save_root, name)
        try:
            shutil.rmtree(full_path)
        except OSError as e:
            print(f"[visualize_feature_map] Failed to remove {full_path}: {e}")
        else:
            print(f"[visualize_feature_map] Removed old folder: {full_path}")


def visualize_feature_map(
    feature_logits,
    save_root=r"/home/zhaoyinghan/py_ws/code/pokou/MultiVisionModels/models/line_detect/train_results/feature",
    max_saved=5,
    num_maps=5,
):
    """Save an image of the first few feature maps summed together.

    The first ``num_maps`` items of the batch (channel 0 only) are summed
    element-wise, min-max normalized to roughly [0, 1], and written as
    ``stacked_feature.png`` inside a new ``<save_root>/<timestamp>`` folder.
    Afterwards at most ``max_saved`` sub-folders of ``save_root`` remain; the
    folder created by this call is never removed.

    Args:
        feature_logits: Tensor [B, 1, H, W]
        save_root: root folder to save images
        max_saved: maximum number of timestamp folders to keep
        num_maps: how many items from the start of the batch to stack

    Returns:
        Path of the saved PNG file.
    """
    os.makedirs(save_root, exist_ok=True)

    # -----------------------
    # timestamp folder
    # -----------------------
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_dir = os.path.join(save_root, timestamp)
    os.makedirs(save_dir, exist_ok=True)

    # -----------------------
    # cleanup old folders (sorted by name, i.e. timestamp order)
    # -----------------------
    _prune_old_folders(save_root, keep_name=timestamp, max_saved=max_saved)

    # -----------------------
    # take the first `num_maps` feature maps, stack & normalize
    # -----------------------
    with torch.no_grad():
        count = max(min(num_maps, feature_logits.size(0)), 0)
        # Cast to float32: half-precision sums can overflow and bfloat16
        # tensors cannot be converted to NumPy arrays.
        maps = feature_logits[:count, 0, ...].detach().float().cpu()  # (count, H, W)
        stacked = maps.sum(dim=0).numpy()

    lowest = stacked.min()
    # The epsilon keeps a constant map from dividing by zero.
    stacked = (stacked - lowest) / (stacked.max() - lowest + 1e-6)

    # -----------------------
    # save image
    # -----------------------
    save_path = os.path.join(save_dir, "stacked_feature.png")
    fig = plt.figure(figsize=(6, 6))
    try:
        plt.imshow(stacked, cmap="viridis")
        plt.axis("off")
        plt.savefig(save_path, bbox_inches="tight", pad_inches=0)
    finally:
        # Close this specific figure even when drawing or saving fails, so
        # repeated calls during training do not accumulate open figures.
        plt.close(fig)

    print(f"[visualize_feature_map] Saved stacked feature → {save_path}")
    return save_path