# show_mask.py: save mask tensors to disk as PNG / TXT / NPY for inspection.
import os
import re
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import torch
  6. def save_full_mask(mask: torch.Tensor, name="mask", out_dir="output",
  7. save_png=True, save_txt=True, save_npy=False,
  8. max_files=100,
  9. save_all_zero=False):
  10. if mask.dim() == 4:
  11. mask = mask[0, 0]
  12. elif mask.dim() == 3:
  13. mask = mask[0]
  14. if not save_all_zero and torch.all(mask == 0):
  15. print(f"? {name} È« 0£¬ÒÑÌø¹ý±£´æ")
  16. return
  17. os.makedirs(out_dir, exist_ok=True)
  18. mask_cpu = mask.detach().cpu()
  19. print(f"\nSaving full mask: {name}")
  20. print(f"shape = {tuple(mask.shape)}, dtype = {mask.dtype}, device = {mask.device}")
  21. pattern = os.path.join(out_dir, f"{name}_*.png")
  22. existing_files = sorted(glob(pattern))
  23. if existing_files:
  24. existing_idxs = [int(os.path.basename(f).split("_")[-1].split(".")[0]) for f in existing_files]
  25. next_idx = max(existing_idxs) + 1
  26. else:
  27. next_idx = 0
  28. if next_idx >= max_files:
  29. next_idx = next_idx % max_files
  30. file_idx_str = f"{next_idx:03d}"
  31. if save_png:
  32. png_path = os.path.join(out_dir, f"{name}_{file_idx_str}.png")
  33. plt.figure()
  34. plt.imshow(mask_cpu, cmap="gray")
  35. plt.colorbar()
  36. plt.title(name)
  37. plt.savefig(png_path, dpi=300)
  38. plt.close()
  39. print(f"? saved PNG -> {png_path}")
  40. if save_txt:
  41. txt_path = os.path.join(out_dir, f"{name}_{file_idx_str}.txt")
  42. np.savetxt(txt_path, mask_cpu.numpy(), fmt="%.3f")
  43. print(f"? saved TXT -> {txt_path}")
  44. if save_npy:
  45. npy_path = os.path.join(out_dir, f"{name}_{file_idx_str}.npy")
  46. np.save(npy_path, mask_cpu.numpy())
  47. print(f"? saved NPY -> {npy_path}")