import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from glob import glob

# Extensions written by save_full_mask; one save writes all enabled formats
# under the same slot index, so every extension counts when picking a slot.
_MASK_FILE_EXTS = (".png", ".txt", ".npy")


def _existing_slot_mtimes(out_dir: str, name: str) -> dict:
    """Map each slot index used by ``{name}_<digits>.<ext>`` files in *out_dir*
    to the newest modification time (ns) among that slot's files.

    Only stems that are exactly ``{name}_`` followed by decimal digits are
    counted, so e.g. ``mask_logits_000.png`` is not treated as a slot of
    ``mask`` and non-numeric suffixes are ignored instead of crashing ``int()``.
    """
    prefix = f"{name}_"
    slots = {}
    with os.scandir(out_dir) as entries:
        for entry in entries:
            stem, ext = os.path.splitext(entry.name)
            if ext not in _MASK_FILE_EXTS or not stem.startswith(prefix):
                continue
            digits = stem[len(prefix):]
            # isdecimal (not isdigit) exactly matches what int() accepts.
            if not digits.isdecimal():
                continue
            idx = int(digits)
            mtime = entry.stat().st_mtime_ns
            slots[idx] = max(mtime, slots.get(idx, mtime))
    return slots


def _next_slot_index(out_dir: str, name: str, max_files: int) -> int:
    """Choose the slot index for the next save of *name*, cycling through
    ``max_files`` slots.

    While there is room, continue after the highest existing index. Once the
    range ``[0, max_files)`` is exhausted, reuse a missing slot if there is
    one, otherwise overwrite the least recently written slot (ties -> lowest
    index), so old files are replaced in rotation instead of always ``000``.

    Raises:
        ValueError: if ``max_files`` is not positive.
    """
    if max_files <= 0:
        raise ValueError(f"max_files must be positive, got {max_files}")
    slots = _existing_slot_mtimes(out_dir, name)
    if not slots:
        return 0
    candidate = max(slots) + 1
    if candidate < max_files:
        return candidate
    # Missing slots get -1 so they sort before any real modification time.
    return min(range(max_files), key=lambda i: (slots.get(i, -1), i))


def save_full_mask(mask: torch.Tensor, name: str = "mask", out_dir: str = "output",
                   save_png: bool = True, save_txt: bool = True, save_npy: bool = False,
                   max_files: int = 100, save_all_zero: bool = False) -> None:
    """Dump a 2-D view of *mask* to ``out_dir`` as PNG / TXT / NPY for inspection.

    A 4-D input is reduced to ``mask[0, 0]`` and a 3-D input to ``mask[0]``;
    other ranks are used as given (PNG output needs the result to be 2-D).
    Files are named ``{name}_{NNN}.{png,txt,npy}``, where ``NNN`` cycles
    through ``max_files`` slots shared by all three formats.

    Args:
        mask: Tensor to save; it is detached and copied to CPU first.
        name: File-name prefix and PNG title.
        out_dir: Output directory (created if missing).
        save_png: Write a grayscale image with a colorbar at 300 dpi.
        save_txt: Write values as text with 3 decimals.
        save_npy: Write the raw array with ``np.save``.
        max_files: Number of rotating slots kept per *name*; must be positive.
        save_all_zero: When False, an all-zero mask is skipped (nothing written).

    Raises:
        ValueError: if ``max_files`` is not positive.
    """
    if mask.dim() == 4:
        mask = mask[0, 0]  # first sample, first channel
    elif mask.dim() == 3:
        mask = mask[0]  # first leading slice

    if not save_all_zero and torch.all(mask == 0):
        print(f"? {name} 全 0，已跳过保存")
        return

    os.makedirs(out_dir, exist_ok=True)
    mask_cpu = mask.detach().cpu()
    print(f"\nSaving full mask: {name}")
    print(f"shape = {tuple(mask.shape)}, dtype = {mask.dtype}, device = {mask.device}")

    file_idx_str = f"{_next_slot_index(out_dir, name, max_files):03d}"

    # NumPy has no bfloat16, so .numpy() (and imshow) would raise; upcast first.
    if mask_cpu.dtype == torch.bfloat16:
        mask_cpu = mask_cpu.float()
    mask_np = mask_cpu.numpy()

    if save_png:
        png_path = os.path.join(out_dir, f"{name}_{file_idx_str}.png")
        fig, ax = plt.subplots()
        try:
            image = ax.imshow(mask_np, cmap="gray")
            fig.colorbar(image, ax=ax)
            ax.set_title(name)
            fig.savefig(png_path, dpi=300)
        finally:
            # Always release the figure, even if rendering/saving fails.
            plt.close(fig)
        print(f"? saved PNG -> {png_path}")

    if save_txt:
        txt_path = os.path.join(out_dir, f"{name}_{file_idx_str}.txt")
        np.savetxt(txt_path, mask_np, fmt="%.3f")
        print(f"? saved TXT -> {txt_path}")

    if save_npy:
        npy_path = os.path.join(out_dir, f"{name}_{file_idx_str}.npy")
        np.save(npy_path, mask_np)
        print(f"? saved NPY -> {npy_path}")