show_mask.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. import os
  2. import torch
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. from glob import glob
  6. def save_full_mask(
  7. mask: torch.Tensor,
  8. name="mask",
  9. out_dir="output",
  10. save_png=True,
  11. save_txt=True,
  12. save_npy=False,
  13. max_files=100,
  14. save_all_zero=False,
  15. zero_eps=1e-6,
  16. # New parameters
  17. image=None, # PIL.Image or numpy array (H,W,3)
  18. show_on_image=False, # Whether to save overlay on original image
  19. alpha=0.4, # Overlay transparency
  20. force_save=False # Force save even if mask is all zeros
  21. ):
  22. """
  23. Save a full mask tensor safely, optionally with overlay on an image.
  24. Supports batch masks (shape: [N,H,W]) and saves each slice separately.
  25. Force save if force_save=True even if the mask is all zeros.
  26. """
  27. # ---------------- Convert batch masks ----------------
  28. if mask.dim() == 3: # (N, H, W)
  29. for idx in range(mask.shape[0]):
  30. save_full_mask(
  31. mask[idx],
  32. name=f"{name}_{idx}",
  33. out_dir=out_dir,
  34. save_png=save_png,
  35. save_txt=save_txt,
  36. save_npy=save_npy,
  37. max_files=max_files,
  38. save_all_zero=save_all_zero,
  39. zero_eps=zero_eps,
  40. image=image,
  41. show_on_image=show_on_image,
  42. alpha=alpha,
  43. force_save=force_save
  44. )
  45. return
  46. elif mask.dim() > 3:
  47. # Flatten batch/channel dimensions safely, take first image
  48. mask = mask.view(-1, mask.shape[-2], mask.shape[-1])[0]
  49. mask_cpu = mask.detach().cpu()
  50. # ---------------- Check if mask is all zeros ----------------
  51. def is_all_zero(t: torch.Tensor):
  52. if t.dtype.is_floating_point:
  53. return torch.all(torch.abs(t) < zero_eps)
  54. else:
  55. return torch.all(t == 0)
  56. all_zero = is_all_zero(mask_cpu)
  57. if all_zero and not save_all_zero and not force_save:
  58. print(f"{name} all zeros ¡ú Not saved")
  59. return
  60. elif all_zero and force_save:
  61. print(f"{name} all zeros ¡ú force saving")
  62. # ---------------- Create output directory ----------------
  63. os.makedirs(out_dir, exist_ok=True)
  64. # ---------------- Determine file index ----------------
  65. pattern = os.path.join(out_dir, f"{name}_*.png")
  66. existing_files = sorted(glob(pattern))
  67. if existing_files:
  68. existing_idx = []
  69. for f in existing_files:
  70. try:
  71. existing_idx.append(int(os.path.basename(f).split("_")[-1].split(".")[0]))
  72. except:
  73. pass
  74. next_idx = max(existing_idx) + 1
  75. else:
  76. next_idx = 0
  77. if next_idx >= max_files:
  78. next_idx = next_idx % max_files
  79. file_idx_str = f"{next_idx:03d}"
  80. # ---------------- Save mask itself ----------------
  81. if save_png:
  82. png_path = os.path.join(out_dir, f"{name}_{file_idx_str}.png")
  83. plt.figure()
  84. plt.imshow(mask_cpu, cmap="gray")
  85. plt.colorbar()
  86. plt.title(name)
  87. plt.savefig(png_path, dpi=300)
  88. plt.close()
  89. print(f"Saved PNG -> {png_path}")
  90. if save_txt:
  91. txt_path = os.path.join(out_dir, f"{name}_{file_idx_str}.txt")
  92. np.savetxt(txt_path, mask_cpu.numpy(), fmt="%.3f")
  93. print(f"Saved TXT -> {txt_path}")
  94. if save_npy:
  95. npy_path = os.path.join(out_dir, f"{name}_{file_idx_str}.npy")
  96. np.save(npy_path, mask_cpu.numpy())
  97. print(f"Saved NPY -> {npy_path}")
  98. # ---------------- Save overlay if requested ----------------
  99. if show_on_image:
  100. if image is None:
  101. print("show_on_image=True but image not provided, skipping overlay")
  102. return
  103. if all_zero and not force_save:
  104. print(f"{name} all zeros ¡ú skip overlay")
  105. return
  106. # Convert image to numpy array
  107. if hasattr(image, "size"): # PIL.Image
  108. img_np = np.array(image)
  109. else: # numpy array
  110. img_np = image
  111. mask_np = mask_cpu.numpy()
  112. if mask_np.dtype != np.uint8:
  113. mask_np = (mask_np > 0).astype(np.uint8)
  114. # Red overlay
  115. colored_mask = np.zeros((*mask_np.shape, 3), dtype=np.uint8)
  116. colored_mask[..., 0] = mask_np * 255 # Red channel
  117. overlay_path = os.path.join(out_dir, f"{name}_{file_idx_str}_overlay.png")
  118. plt.figure(figsize=(8, 8))
  119. plt.imshow(img_np)
  120. plt.imshow(colored_mask, alpha=alpha)
  121. plt.title(f"{name} overlay")
  122. plt.axis('off')
  123. plt.savefig(overlay_path, dpi=300)
  124. plt.close()
  125. print(f"Saved overlay -> {overlay_path}")