| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859 |
- import os
- import torch
- import numpy as np
- import matplotlib.pyplot as plt
- from glob import glob
# Next file index per (absolute out_dir, name). Re-deriving the index from the
# files on disk cannot tell where the rotation is once every slot exists, so
# within one process the counter is kept here and only seeded from disk.
_MASK_NEXT_INDEX = {}

# Extensions save_full_mask may write; all of them count as existing dumps.
_MASK_EXTENSIONS = (".png", ".txt", ".npy")


def _scan_next_mask_index(out_dir, name):
    """Return 1 + the highest index among existing ``{name}_<digits>`` dumps.

    Looks at ``.png``, ``.txt`` and ``.npy`` files in *out_dir* and returns 0
    when there are none. Files whose suffix after ``{name}_`` is not purely
    numeric (e.g. ``mask_final.png``, or ``mask_foo_001.png`` when
    ``name="mask"``) are ignored instead of crashing the parse.
    """
    prefix = f"{name}_"
    highest = -1
    for fname in os.listdir(out_dir):
        stem, ext = os.path.splitext(fname)
        if ext not in _MASK_EXTENSIONS or not stem.startswith(prefix):
            continue
        suffix = stem[len(prefix):]
        if suffix.isdecimal():
            highest = max(highest, int(suffix))
    return highest + 1


def _next_mask_index(out_dir, name, max_files):
    """Return the slot in ``[0, max_files)`` for the next dump of *name*.

    The first call for a given (out_dir, name) in this process continues after
    the highest index already on disk; later calls advance an in-process
    counter so the slots are reused round-robin once ``max_files`` is reached.
    """
    key = (os.path.abspath(out_dir), name)
    idx = _MASK_NEXT_INDEX.get(key)
    if idx is None:
        idx = _scan_next_mask_index(out_dir, name)
    idx %= max_files
    _MASK_NEXT_INDEX[key] = idx + 1
    return idx


def save_full_mask(mask: torch.Tensor, name="mask", out_dir="output",
                   save_png=True, save_txt=True, save_npy=False,
                   max_files=100,
                   save_all_zero=False):
    """Dump *mask* to numbered PNG / TXT / NPY files for inspection.

    A 4-D tensor is reduced to ``mask[0, 0]`` and a 3-D tensor to ``mask[0]``;
    other ranks are used as-is. Files are named
    ``{out_dir}/{name}_{idx:03d}.{png,txt,npy}``, where all formats written by
    one call share the same index.

    Args:
        mask: Tensor to dump. It is detached and copied to the CPU first.
        name: File-name prefix and PNG title.
        out_dir: Output directory. It is created if missing.
        save_png: Write a grayscale image with a colorbar at 300 dpi.
        save_txt: Write the values as text with three decimal places.
        save_npy: Write the raw array with ``np.save``.
        max_files: Number of index slots per *name*. After slot
            ``max_files - 1`` the index wraps to 0 and older files are
            overwritten.
        save_all_zero: When False, a mask whose elements are all zero is
            skipped and nothing is written.

    Raises:
        ValueError: If ``max_files`` is less than 1.
    """
    if max_files < 1:
        raise ValueError(f"max_files must be >= 1, got {max_files}")

    if mask.dim() == 4:
        mask = mask[0, 0]
    elif mask.dim() == 3:
        mask = mask[0]

    if not save_all_zero and torch.all(mask == 0):
        print(f"? {name} 全 0，已跳过保存")
        return

    os.makedirs(out_dir, exist_ok=True)
    mask_cpu = mask.detach().cpu()
    if mask_cpu.dtype == torch.bfloat16:
        # NumPy has no bfloat16, so .numpy() (and imshow) would raise.
        mask_cpu = mask_cpu.float()
    mask_np = mask_cpu.numpy()

    print(f"\nSaving full mask: {name}")
    print(f"shape = {tuple(mask.shape)}, dtype = {mask.dtype}, device = {mask.device}")

    file_idx_str = f"{_next_mask_index(out_dir, name, max_files):03d}"

    if save_png:
        png_path = os.path.join(out_dir, f"{name}_{file_idx_str}.png")
        fig, ax = plt.subplots()
        try:
            image = ax.imshow(mask_np, cmap="gray")
            fig.colorbar(image, ax=ax)
            ax.set_title(name)
            fig.savefig(png_path, dpi=300)
        finally:
            # Always release the figure, even if plotting/saving fails,
            # so repeated calls do not accumulate open figures.
            plt.close(fig)
        print(f"? saved PNG -> {png_path}")

    if save_txt:
        txt_path = os.path.join(out_dir, f"{name}_{file_idx_str}.txt")
        np.savetxt(txt_path, mask_np, fmt="%.3f")
        print(f"? saved TXT -> {txt_path}")

    if save_npy:
        npy_path = os.path.join(out_dir, f"{name}_{file_idx_str}.npy")
        np.save(npy_path, mask_np)
        print(f"? saved NPY -> {npy_path}")
|