show_feature.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import os
  2. import torch
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. from datetime import datetime
  6. def visualize_feature_map(feature_logits,
  7. save_root=r"/home/zhaoyinghan/py_ws/code/pokou/MultiVisionModels/models/line_detect/train_results/feature",
  8. max_saved=5):
  9. """
  10. Visualize 5 feature maps stacked together. Saves result inside a timestamp folder.
  11. Keeps only the last `max_saved` folders.
  12. Args:
  13. feature_logits: Tensor [B, 1, H, W]
  14. save_root: root folder to save images
  15. max_saved: maximum number of timestamp folders to keep
  16. """
  17. with torch.no_grad():
  18. os.makedirs(save_root, exist_ok=True)
  19. # -----------------------
  20. # timestamp folder
  21. # -----------------------
  22. timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  23. save_dir = os.path.join(save_root, timestamp)
  24. os.makedirs(save_dir, exist_ok=True)
  25. # -----------------------
  26. # cleanup old folders
  27. # -----------------------
  28. all_dirs = [d for d in os.listdir(save_root) if os.path.isdir(os.path.join(save_root, d))]
  29. all_dirs.sort() # °´Ãû×ÖÅÅÐò£¬Ê±¼ä´Á˳Ðò
  30. if len(all_dirs) > max_saved:
  31. for d in all_dirs[:-max_saved]:
  32. try:
  33. full_path = os.path.join(save_root, d)
  34. import shutil
  35. shutil.rmtree(full_path)
  36. print(f"[visualize_feature_map] Removed old folder: {full_path}")
  37. except Exception as e:
  38. print(f"[visualize_feature_map] Failed to remove {full_path}: {e}")
  39. # -----------------------
  40. # take first 5 feature maps
  41. # -----------------------
  42. B = feature_logits.size(0)
  43. num = min(5, B)
  44. maps = feature_logits[:num, 0, ...].detach().cpu() # shape (num, H, W)
  45. # -----------------------
  46. # stack & normalize
  47. # -----------------------
  48. stacked = maps.sum(dim=0).numpy()
  49. stacked = (stacked - stacked.min()) / (stacked.max() - stacked.min() + 1e-6)
  50. # -----------------------
  51. # save image
  52. # -----------------------
  53. save_path = os.path.join(save_dir, "stacked_feature.png")
  54. plt.figure(figsize=(6, 6))
  55. plt.imshow(stacked, cmap="viridis")
  56. plt.axis("off")
  57. plt.savefig(save_path, bbox_inches="tight", pad_inches=0)
  58. plt.close()
  59. print(f"[visualize_feature_map] Saved stacked feature ¡ú {save_path}")