| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869 |
- import os
- import torch
- import numpy as np
- import matplotlib.pyplot as plt
- from datetime import datetime
# Folder-name format of one visualization run; also used to recognise which
# sub-directories of save_root were created by visualize_feature_map.
_RUN_DIR_FORMAT = "%Y%m%d_%H%M%S"
_RUN_DIR_NAME_LEN = 15  # len("YYYYmmdd_HHMMSS")


def _is_run_dir_name(name):
    """Return True if *name* has the exact timestamp shape of a run folder."""
    # strptime alone accepts un-padded fields, so pin the total width as well.
    if len(name) != _RUN_DIR_NAME_LEN:
        return False
    try:
        datetime.strptime(name, _RUN_DIR_FORMAT)
    except ValueError:
        return False
    return True


def _prune_old_run_dirs(save_root, max_saved, protect=None):
    """Remove the oldest timestamp folders in *save_root*, keeping *max_saved*.

    Only directories whose name matches the timestamp format are considered,
    so unrelated folders living under *save_root* are never deleted.
    *protect* (a path) is skipped even if it would otherwise be pruned.
    Removal is best-effort: a failure is reported and the loop continues.
    """
    import shutil  # local import: not part of the file's top-level imports

    # Values below 1 would make the [:-max_saved] slice meaningless
    # (0 -> empty slice, negative -> arbitrary prefix), so treat them as
    # "pruning disabled".
    if max_saved is None or max_saved < 1:
        return

    # Fixed-width timestamps sort chronologically when sorted by name.
    run_dirs = sorted(
        name for name in os.listdir(save_root)
        if _is_run_dir_name(name)
        and os.path.isdir(os.path.join(save_root, name))
    )
    protected = os.path.abspath(protect) if protect is not None else None

    for name in run_dirs[:-max_saved]:
        full_path = os.path.join(save_root, name)
        if protected is not None and os.path.abspath(full_path) == protected:
            continue
        try:
            shutil.rmtree(full_path)
        except OSError as e:
            print(f"[visualize_feature_map] Failed to remove {full_path}: {e}")
        else:
            print(f"[visualize_feature_map] Removed old folder: {full_path}")


def visualize_feature_map(feature_logits,
                          save_root=r"/home/zhaoyinghan/py_ws/code/pokou/MultiVisionModels/models/line_detect/train_results/feature",
                          max_saved=5,
                          num_maps=5):
    """
    Save a single image made of the first feature maps of a batch, summed.

    Channel 0 of the first `num_maps` batch entries is summed element-wise,
    min-max normalized to [0, 1] and written as `stacked_feature.png` inside a
    new `<save_root>/<YYYYmmdd_HHMMSS>` folder. Afterwards only the newest
    `max_saved` timestamp folders are kept.

    Note: folder names have one-second resolution, so two calls within the
    same second write to the same folder and the later image overwrites the
    earlier one.

    Args:
        feature_logits: Tensor [B, 1, H, W] (only channel 0 is read).
        save_root: root folder to save images; created if missing.
        max_saved: maximum number of timestamp folders to keep. Values below
            1 disable the cleanup. Folders under `save_root` that are not
            named like a timestamp are never touched.
        num_maps: how many batch entries to sum (capped at B).

    Returns:
        None. The saved path is printed.
    """
    # Do the tensor work first so that a bad input fails before any folder is
    # created or any old result is deleted.
    with torch.no_grad():
        count = max(0, min(num_maps, feature_logits.size(0)))
        maps = feature_logits[:count, 0, ...].detach()  # shape (count, H, W)
        # bfloat16 has no NumPy equivalent and float16 sums can overflow;
        # upcast those before reducing.
        if maps.dtype in (torch.float16, torch.bfloat16):
            maps = maps.float()
        stacked = maps.cpu().sum(dim=0).numpy()

    # Min-max normalize; the epsilon keeps a constant map from dividing by 0.
    lowest = stacked.min()
    highest = stacked.max()
    stacked = (stacked - lowest) / (highest - lowest + 1e-6)

    # One folder per call, named by the current local time.
    timestamp = datetime.now().strftime(_RUN_DIR_FORMAT)
    save_dir = os.path.join(save_root, timestamp)
    os.makedirs(save_dir, exist_ok=True)  # also creates save_root
    save_path = os.path.join(save_dir, "stacked_feature.png")

    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        ax.imshow(stacked, cmap="viridis")
        ax.axis("off")
        fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
    finally:
        # Close this specific figure even when drawing/saving fails, so
        # repeated calls during training do not accumulate open figures.
        plt.close(fig)
    # The arrow is written as an escape so it survives any source encoding.
    print(f"[visualize_feature_map] Saved stacked feature \u2192 {save_path}")

    # Prune only after a successful save; never remove the folder just
    # written, even if a stray folder with a later timestamp exists.
    _prune_old_run_dirs(save_root, max_saved, protect=save_dir)
|