import os
import re
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import torch

# Last file index chosen per (absolute out_dir, name) in this process. It is
# consulted only once all `max_files` slots exist on disk, so rotation moves on
# to the next slot instead of overwriting slot 0 forever.
_LAST_INDEX = {}


def _existing_indices(out_dir, name):
    """Return the numeric suffixes of files previously saved for *name*.

    Matches exactly ``{name}_NNN.png``, ``{name}_NNN.txt``, ``{name}_NNN.npy``
    and ``{name}_NNN_overlay.png``. Files of other masks whose name merely
    starts with *name* (e.g. batch slices ``{name}_0_000.png``) are ignored.
    """
    pattern = re.compile(rf"^{re.escape(name)}_(\d+)(?:_overlay)?\.(?:png|txt|npy)$")
    try:
        entries = os.listdir(out_dir)
    except FileNotFoundError:
        return []
    return [int(m.group(1)) for m in map(pattern.match, entries) if m]


def _next_file_index(out_dir, name, max_files):
    """Choose the numeric suffix for the next save of *name* in *out_dir*.

    While fewer than *max_files* slots are used, this is ``max(existing) + 1``
    (0 for a fresh directory). Once the ring is full, the slot after the one
    this process chose last is used, so all slots rotate. A fresh process
    restarts at ``candidate % max_files``. ``None`` or a non-positive
    *max_files* disables the cap.
    """
    used = _existing_indices(out_dir, name)
    candidate = max(used) + 1 if used else 0
    key = (os.path.abspath(out_dir), name)
    if max_files is None or max_files <= 0 or candidate < max_files:
        idx = candidate
    elif key in _LAST_INDEX:
        idx = (_LAST_INDEX[key] + 1) % max_files
    else:
        idx = candidate % max_files
    _LAST_INDEX[key] = idx
    return idx


def _is_all_zero(t, zero_eps):
    """True if every element of *t* is zero (within *zero_eps* for floating dtypes)."""
    if t.dtype.is_floating_point:
        return bool(torch.all(torch.abs(t) < zero_eps))
    return bool(torch.all(t == 0))


def _to_numpy(t):
    """Return a CPU NumPy copy/view of *t*; bfloat16 is upcast because NumPy lacks it."""
    t = t.detach().cpu()
    if t.dtype == torch.bfloat16:
        t = t.float()
    return t.numpy()


def _save_mask_png(mask_np, name, path):
    """Render *mask_np* in grayscale with a colorbar and title, and save it to *path*."""
    fig, ax = plt.subplots()
    try:
        im = ax.imshow(mask_np, cmap="gray")
        fig.colorbar(im, ax=ax)
        ax.set_title(name)
        fig.savefig(path, dpi=300)
    finally:
        # Always release the figure, even if rendering/saving fails.
        plt.close(fig)


def _save_overlay(img_np, mask_np, name, path, alpha):
    """Draw pixels where *mask_np* > 0 in red over *img_np* and save to *path*.

    The overlay is RGBA with zero opacity off-mask, so unmasked pixels keep
    their original colours. It is stretched to the image's extent so that a
    mask of a different resolution still covers the whole image.
    """
    on = np.asarray(mask_np) > 0  # binarise first: avoids uint8 overflow of `* 255`
    opacity = min(max(float(alpha), 0.0), 1.0)
    overlay = np.zeros((*on.shape, 4), dtype=np.float32)
    overlay[..., 0] = 1.0  # red
    overlay[..., 3] = on * opacity
    h, w = img_np.shape[:2]
    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        ax.imshow(img_np)
        ax.imshow(overlay, extent=(-0.5, w - 0.5, h - 0.5, -0.5))
        ax.set_title(f"{name} overlay")
        ax.axis("off")
        fig.savefig(path, dpi=300)
    finally:
        plt.close(fig)


def save_full_mask(
    mask: torch.Tensor,
    name="mask",
    out_dir="output",
    save_png=True,
    save_txt=True,
    save_npy=False,
    max_files=100,
    save_all_zero=False,
    zero_eps=1e-6,
    image=None,          # PIL.Image, numpy array (H, W[, C]) or CPU tensor
    show_on_image=False,  # whether to also save an overlay on `image`
    alpha=0.4,           # overlay opacity in [0, 1]
    force_save=False,    # save even if the mask is all zeros
) -> None:
    """Dump a mask tensor to disk as PNG / TXT / NPY, optionally with an image overlay.

    Files are named ``{name}_{NNN}.{png,txt,npy}`` (and ``{name}_{NNN}_overlay.png``).
    ``NNN`` continues from the highest existing index for *name* across all of
    these file types. At most *max_files* indices are used, after which slots
    are reused in rotation.

    Args:
        mask: 2-D ``(H, W)`` mask. A 3-D ``(N, H, W)`` mask is saved slice by
            slice under ``{name}_{i}``. For more than 3 dims, all leading dims
            are flattened and only the first ``(H, W)`` slice is saved.
        name: Base file name / plot title.
        out_dir: Output directory (created if missing).
        save_png, save_txt, save_npy: Which formats to write.
        max_files: Number of index slots per *name*. ``None`` or <= 0 means unlimited.
        save_all_zero: Also save masks that are entirely zero.
        zero_eps: Tolerance used to treat floating-point values as zero.
        image: Background image for the overlay.
        show_on_image: Write the overlay PNG when *image* is given.
        alpha: Opacity of the red overlay on masked pixels.
        force_save: Save even all-zero masks (including their overlay).
    """
    if mask.dim() == 3:  # (N, H, W): save each slice separately
        for i in range(mask.shape[0]):
            save_full_mask(
                mask[i], name=f"{name}_{i}", out_dir=out_dir,
                save_png=save_png, save_txt=save_txt, save_npy=save_npy,
                max_files=max_files, save_all_zero=save_all_zero,
                zero_eps=zero_eps, image=image, show_on_image=show_on_image,
                alpha=alpha, force_save=force_save,
            )
        return
    if mask.dim() > 3:
        # reshape (not view) so non-contiguous tensors work too; keep first slice.
        mask = mask.reshape(-1, mask.shape[-2], mask.shape[-1])[0]

    mask_cpu = mask.detach().cpu()

    all_zero = _is_all_zero(mask_cpu, zero_eps)
    if all_zero:
        if force_save:
            print(f"{name} all zeros -> force saving")
        elif not save_all_zero:
            print(f"{name} all zeros -> Not saved")
            return

    os.makedirs(out_dir, exist_ok=True)
    file_idx_str = f"{_next_file_index(out_dir, name, max_files):03d}"
    stem = os.path.join(out_dir, f"{name}_{file_idx_str}")
    mask_np = _to_numpy(mask_cpu)

    if save_png:
        png_path = f"{stem}.png"
        _save_mask_png(mask_np, name, png_path)
        print(f"Saved PNG -> {png_path}")

    if save_txt:
        txt_path = f"{stem}.txt"
        np.savetxt(txt_path, mask_np, fmt="%.3f")
        print(f"Saved TXT -> {txt_path}")

    if save_npy:
        npy_path = f"{stem}.npy"
        np.save(npy_path, mask_np)
        print(f"Saved NPY -> {npy_path}")

    if not show_on_image:
        return
    if image is None:
        print("show_on_image=True but image not provided, skipping overlay")
        return
    if all_zero and not force_save:
        print(f"{name} all zeros -> skip overlay")
        return

    overlay_path = f"{stem}_overlay.png"
    # np.asarray handles PIL images, numpy arrays and CPU tensors alike.
    _save_overlay(np.asarray(image), mask_np, name, overlay_path, alpha)
    print(f"Saved overlay -> {overlay_path}")