import os
import re
import torch
import numpy as np
import matplotlib.pyplot as plt
from glob import glob


def _squeeze_to_2d(mask: torch.Tensor) -> torch.Tensor:
    """Drop leading dimensions (e.g. batch/channel) by taking index 0 until `mask` is at most 2-D."""
    while mask.dim() > 2:
        mask = mask[0]
    return mask


def _is_all_zero(t: torch.Tensor, zero_eps: float) -> bool:
    """Return True if every element of `t` is zero (|x| < zero_eps for floating-point dtypes)."""
    if t.dtype.is_floating_point:
        return bool(torch.all(torch.abs(t) < zero_eps))
    return bool(torch.all(t == 0))


def _next_file_index(out_dir: str, name: str, max_files: int) -> int:
    """Choose the numeric slot for the next set of files saved under `name` in `out_dir`.

    Slots are recognised from every artefact this module writes
    (`{name}_NNN.png/.txt/.npy` and `{name}_NNN_overlay.png`). Other prefixes are
    ignored, even ones that start with `name`. Slots are numbered sequentially
    until `max_files` is reached. After that, the slot following the most recently
    written one is reused, so the directory behaves as a ring buffer of
    `max_files` slots.

    Raises:
        ValueError: if `max_files` < 1.
    """
    if max_files < 1:
        raise ValueError(f"max_files must be >= 1, got {max_files}")

    pattern = re.compile(re.escape(name) + r"_([0-9]+)(?:_overlay)?\.(?:png|txt|npy)")
    newest_mtime = {}  # slot index -> newest mtime (ns) of any file in that slot
    with os.scandir(out_dir) as entries:
        for entry in entries:
            match = pattern.fullmatch(entry.name)
            if match is None or not entry.is_file():
                continue
            idx = int(match.group(1))
            mtime = entry.stat().st_mtime_ns
            newest_mtime[idx] = max(mtime, newest_mtime.get(idx, mtime))

    if not newest_mtime:
        return 0

    candidate = max(newest_mtime) + 1
    if candidate < max_files:
        return candidate

    # All slots used: continue after the most recently written slot.
    # Ties (coarse filesystem timestamps) are broken in favour of the higher index.
    last_written = max(newest_mtime, key=lambda i: (newest_mtime[i], i))
    return (last_written + 1) % max_files


def _save_overlay(img_np, mask_np, name, overlay_path, alpha):
    """Save `img_np` with the positive pixels of `mask_np` drawn on top in translucent red."""
    hit = mask_np > 0  # binarise regardless of dtype (avoids uint8 overflow for 0/255 masks)
    # RGBA layer: red where the mask is set, fully transparent elsewhere, so the
    # unmasked parts of the image are not dimmed by the overlay.
    layer = np.zeros((*hit.shape, 4), dtype=np.float32)
    layer[..., 0] = 1.0
    layer[..., 3] = hit * alpha

    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        ax.imshow(img_np)
        ax.imshow(layer)
        ax.set_title(f"{name} overlay")
        ax.axis("off")
        # Do not call plt.show(); only save to disk.
        fig.savefig(overlay_path, dpi=300)
    finally:
        plt.close(fig)


def save_full_mask(
    mask: torch.Tensor,
    name="mask",
    out_dir="output",
    save_png=True,
    save_txt=True,
    save_npy=False,
    max_files=100,
    save_all_zero=False,
    zero_eps=1e-6,
    # Newer parameters
    image=None,            # PIL.Image or numpy array (H, W, 3)
    show_on_image=False,   # whether to also save the mask overlaid on `image`
    alpha=0.4,             # opacity of the red mask overlay
):
    """Save a whole mask / feature map, optionally also as an overlay on the original image.

    Args:
        mask: Tensor to save. Leading dimensions are dropped by taking index 0
            until it is at most 2-D (e.g. (B, C, H, W) -> (H, W)).
        name: File-name prefix. Files are written as `{name}_NNN.png`,
            `{name}_NNN.txt`, `{name}_NNN.npy` and `{name}_NNN_overlay.png`.
        out_dir: Output directory; created if it does not exist.
        save_png: Save a grayscale rendering with a colorbar (300 dpi).
        save_txt: Save the values as text with format "%.3f".
        save_npy: Save the raw array with `np.save`.
        max_files: Number of NNN slots. Once all are used, slots are reused in
            rotation, starting after the most recently written one.
        save_all_zero: Whether to save at all when the mask is entirely zero.
        zero_eps: Absolute tolerance used for the all-zero test on float masks.
        image: Background image for the overlay (anything `np.asarray` accepts).
        show_on_image: Also save `{name}_NNN_overlay.png` (skipped if `image` is
            None or the mask is all zero).
        alpha: Opacity of the red overlay on masked pixels.

    Returns:
        None.

    Raises:
        ValueError: if `max_files` < 1 (only checked when something is saved).
    """
    mask = _squeeze_to_2d(mask)

    # Handle the all-zero case
    all_zero = _is_all_zero(mask, zero_eps)
    if all_zero:
        print(f"? {name} 全 0")
        if not save_all_zero:
            print("→ 未保存（save_all_zero=False）")
            return
        print("→ 按设置继续保存（save_all_zero=True）")

    os.makedirs(out_dir, exist_ok=True)

    mask_cpu = mask.detach().cpu()
    if mask_cpu.dtype == torch.bfloat16:
        # NumPy has no bfloat16, so .numpy() would raise; upcast first.
        mask_cpu = mask_cpu.float()
    mask_np = mask_cpu.numpy()

    # Assign the file index shared by every file written in this call
    file_idx_str = f"{_next_file_index(out_dir, name, max_files):03d}"

    # (1) Save the mask itself
    if save_png:
        png_path = os.path.join(out_dir, f"{name}_{file_idx_str}.png")
        fig, ax = plt.subplots()
        try:
            im = ax.imshow(mask_np, cmap="gray")
            fig.colorbar(im, ax=ax)
            ax.set_title(name)
            fig.savefig(png_path, dpi=300)
        finally:
            plt.close(fig)
        print(f"?? saved PNG -> {png_path}")

    if save_txt:
        txt_path = os.path.join(out_dir, f"{name}_{file_idx_str}.txt")
        np.savetxt(txt_path, mask_np, fmt="%.3f")
        print(f"?? saved TXT -> {txt_path}")

    if save_npy:
        npy_path = os.path.join(out_dir, f"{name}_{file_idx_str}.npy")
        np.save(npy_path, mask_np)
        print(f"?? saved NPY -> {npy_path}")

    # (2) Save the overlay (written to disk only, never displayed)
    if not show_on_image:
        return
    if image is None:
        print("? show_on_image=True 但未传入 image，跳过叠加保存")
        return
    if all_zero:
        # An all-zero mask would produce an empty overlay, so skip it
        print(f"? {name} 全 0 → 不生成叠加 overlay")
        return

    overlay_path = os.path.join(out_dir, f"{name}_{file_idx_str}_overlay.png")
    # np.asarray handles both PIL images and numpy arrays
    _save_overlay(np.asarray(image), mask_np, name, overlay_path, alpha)
    print(f"?? saved overlay -> {overlay_path}")