| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145 |
- import os
- import torch
- import numpy as np
- import matplotlib.pyplot as plt
- from glob import glob
def _existing_mask_slots(out_dir, name):
    """Return ``{index: newest mtime in ns}`` for files already saved under ``name``.

    Only ``{name}_{idx}.png``, ``{name}_{idx}.txt``, ``{name}_{idx}.npy`` and
    ``{name}_{idx}_overlay.png`` are recognised. Anything else is ignored,
    including files of batch slices such as ``{name}_0_{idx}.png`` and names
    whose index part is not purely numeric.
    """
    import re

    # re.escape: `name` may contain characters that are special in regexes/globs.
    pattern = re.compile(
        rf"{re.escape(name)}_([0-9]+)(?:_overlay\.png|\.png|\.txt|\.npy)"
    )
    slots = {}
    with os.scandir(out_dir) as entries:
        for entry in entries:
            match = pattern.fullmatch(entry.name)
            if match is None or not entry.is_file():
                continue
            idx = int(match.group(1))
            mtime = entry.stat().st_mtime_ns
            # One index can own several files (png/txt/npy/overlay); keep the newest.
            slots[idx] = max(mtime, slots.get(idx, mtime))
    return slots


def _next_file_index(slots, max_files):
    """Choose the index for the next save, cycling through ``max_files`` slots.

    While unused slots remain above the highest index in use, the next index
    is ``highest + 1``. Once index ``max_files - 1`` is in use, the slot after
    the most recently written one is reused, so slots are overwritten in ring
    order instead of always overwriting index 0.
    """
    in_range = {idx: t for idx, t in slots.items() if idx < max_files}
    if not in_range:
        return 0
    candidate = max(in_range) + 1
    if candidate < max_files:
        return candidate
    newest_time = max(in_range.values())
    newest = {idx for idx, t in in_range.items() if t == newest_time}
    # Coarse filesystem clocks can give several consecutive writes the same
    # timestamp; the last write of such a run is the slot whose successor is
    # not also part of the run.
    for idx in sorted(newest):
        successor = (idx + 1) % max_files
        if successor not in newest:
            return successor
    return 0


def save_full_mask(
    mask: torch.Tensor,
    name="mask",
    out_dir="output",
    save_png=True,
    save_txt=True,
    save_npy=False,
    max_files=100,
    save_all_zero=False,
    zero_eps=1e-6,
    image=None,  # PIL.Image or numpy array (H, W, 3)
    show_on_image=False,  # Whether to save an overlay on the original image
    alpha=0.4,  # Overlay transparency
    force_save=False,  # Save even if the mask is all zeros
    dpi=300,  # Resolution of every saved PNG
):
    """Save a mask as PNG/TXT/NPY, optionally with a red overlay on an image.

    A 3-D mask ``(N, H, W)`` is split and each slice is saved under the name
    ``f"{name}_{i}"``. A mask with more than 3 dims is flattened over its
    leading dims and only the first ``(H, W)`` slice is saved. Non-tensor
    inputs (e.g. NumPy arrays) are converted with ``torch.as_tensor``.

    Files are written to ``out_dir`` as ``{name}_{idx:03d}.png``, ``.txt``,
    ``.npy`` and ``{name}_{idx:03d}_overlay.png``. ``idx`` continues after the
    highest index already present for ``name`` (any of those file kinds).
    New files only use indices below ``max_files``; once those are in use the
    slot after the most recently written one is overwritten.

    Args:
        mask: Mask tensor or array-like.
        name: File-name prefix; also used as the plot title.
        out_dir: Output directory, created if missing.
        save_png: Save a grayscale rendering with a colorbar.
        save_txt: Save the values as text with format ``%.3f``.
        save_npy: Save the raw array with ``np.save``.
        max_files: Number of index slots used per name; must be >= 1.
        save_all_zero: Also save masks that are all zeros.
        zero_eps: A floating mask whose every ``|value| < zero_eps`` counts
            as all zeros; integer/bool masks must be exactly 0.
        image: Background for the overlay (PIL image or array).
        show_on_image: Also save the overlay PNG (requires ``image``).
        alpha: Overlay transparency.
        force_save: Save even when the mask is all zeros.
        dpi: Resolution of saved PNGs.

    Raises:
        ValueError: If ``max_files < 1`` (checked only when something is saved).
    """
    if not isinstance(mask, torch.Tensor):
        mask = torch.as_tensor(mask)

    # ---------------- Split batch masks ----------------
    if mask.dim() == 3:  # (N, H, W): every slice is saved under its own name
        for idx, mask_slice in enumerate(mask):
            save_full_mask(
                mask_slice,
                name=f"{name}_{idx}",
                out_dir=out_dir,
                save_png=save_png,
                save_txt=save_txt,
                save_npy=save_npy,
                max_files=max_files,
                save_all_zero=save_all_zero,
                zero_eps=zero_eps,
                image=image,
                show_on_image=show_on_image,
                alpha=alpha,
                force_save=force_save,
                dpi=dpi,
            )
        return
    if mask.dim() > 3:
        # Flatten leading dims and keep only the first (H, W) slice.
        # reshape (unlike view) also works on non-contiguous tensors.
        mask = mask.reshape(-1, mask.shape[-2], mask.shape[-1])[0]

    mask_cpu = mask.detach().cpu()
    if mask_cpu.dtype == torch.bfloat16:
        # NumPy has no bfloat16; widen losslessly so .numpy() works below.
        mask_cpu = mask_cpu.float()

    # ---------------- Check if mask is all zeros ----------------
    if mask_cpu.dtype.is_floating_point:
        all_zero = bool(torch.all(torch.abs(mask_cpu) < zero_eps))
    else:
        all_zero = bool(torch.all(mask_cpu == 0))
    if all_zero and not save_all_zero and not force_save:
        print(f"{name} all zeros -> Not saved")
        return
    if all_zero and force_save:
        print(f"{name} all zeros -> force saving")

    if max_files < 1:
        raise ValueError(f"max_files must be >= 1, got {max_files}")

    # ---------------- Create output directory and pick file index ----------------
    os.makedirs(out_dir, exist_ok=True)
    next_idx = _next_file_index(_existing_mask_slots(out_dir, name), max_files)
    file_idx_str = f"{next_idx:03d}"
    mask_np = mask_cpu.numpy()

    # ---------------- Save mask itself ----------------
    if save_png:
        png_path = os.path.join(out_dir, f"{name}_{file_idx_str}.png")
        fig, ax = plt.subplots()
        try:
            im = ax.imshow(mask_np, cmap="gray")
            fig.colorbar(im, ax=ax)
            ax.set_title(name)
            fig.savefig(png_path, dpi=dpi)
        finally:
            plt.close(fig)  # never leak figures, even if savefig fails
        print(f"Saved PNG -> {png_path}")
    if save_txt:
        txt_path = os.path.join(out_dir, f"{name}_{file_idx_str}.txt")
        np.savetxt(txt_path, mask_np, fmt="%.3f")
        print(f"Saved TXT -> {txt_path}")
    if save_npy:
        npy_path = os.path.join(out_dir, f"{name}_{file_idx_str}.npy")
        np.save(npy_path, mask_np)
        print(f"Saved NPY -> {npy_path}")

    # ---------------- Save overlay if requested ----------------
    if not show_on_image:
        return
    if image is None:
        print("show_on_image=True but image not provided, skipping overlay")
        return
    if all_zero and not force_save:
        print(f"{name} all zeros -> skip overlay")
        return

    # np.asarray covers both PIL images and NumPy arrays (both expose `.size`,
    # so a hasattr check cannot tell them apart).
    img_np = np.asarray(image)
    # Binarise before scaling: multiplying a raw uint8 mask (e.g. 0/255) by
    # 255 would wrap around in uint8 arithmetic.
    colored_mask = np.zeros((*mask_np.shape, 3), dtype=np.uint8)
    colored_mask[..., 0] = np.where(mask_np > 0, 255, 0)  # red channel
    overlay_path = os.path.join(out_dir, f"{name}_{file_idx_str}_overlay.png")
    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        ax.imshow(img_np)
        ax.imshow(colored_mask, alpha=alpha)
        ax.set_title(f"{name} overlay")
        ax.axis("off")
        fig.savefig(overlay_path, dpi=dpi)
    finally:
        plt.close(fig)
    print(f"Saved overlay -> {overlay_path}")
|