|
@@ -4,32 +4,61 @@ import numpy as np
|
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.pyplot as plt
|
|
|
from glob import glob
|
|
from glob import glob
|
|
|
|
|
|
|
|
-def save_full_mask(mask: torch.Tensor, name="mask", out_dir="output",
|
|
|
|
|
- save_png=True, save_txt=True, save_npy=False,
|
|
|
|
|
- max_files=100,
|
|
|
|
|
- save_all_zero=False):
|
|
|
|
|
-
|
|
|
|
|
|
|
+def save_full_mask(
|
|
|
|
|
+ mask: torch.Tensor,
|
|
|
|
|
+ name="mask",
|
|
|
|
|
+ out_dir="output",
|
|
|
|
|
+ save_png=True,
|
|
|
|
|
+ save_txt=True,
|
|
|
|
|
+ save_npy=False,
|
|
|
|
|
+ max_files=100,
|
|
|
|
|
+ save_all_zero=False,
|
|
|
|
|
+ zero_eps=1e-6,
|
|
|
|
|
+ # ÐÂÔö²ÎÊý
|
|
|
|
|
+ image=None, # PIL.Image »ò numpy(H,W,3)
|
|
|
|
|
+ show_on_image=False, # ÊÇ·ñ±£´æµþ¼ÓµÄͼ
|
|
|
|
|
+ alpha=0.4 # mask͸Ã÷¶È
|
|
|
|
|
+):
|
|
|
|
|
+ """
|
|
|
|
|
+ ±£´æÕûÕÅÌØÕ÷ͼ£¬¿ÉÑ¡ÔñÉú³Éµþ¼ÓÔͼºóµÄ¿ÉÊÓ»¯Îļþ
|
|
|
|
|
+ """
|
|
|
|
|
|
|
|
|
|
+ # È¥µô batch/channel ¶àÓàά¶È
|
|
|
if mask.dim() == 4:
|
|
if mask.dim() == 4:
|
|
|
mask = mask[0, 0]
|
|
mask = mask[0, 0]
|
|
|
elif mask.dim() == 3:
|
|
elif mask.dim() == 3:
|
|
|
mask = mask[0]
|
|
mask = mask[0]
|
|
|
|
|
|
|
|
- if not save_all_zero and torch.all(mask == 0):
|
|
|
|
|
- print(f"? {name} È« 0£¬ÒÑÌø¹ý±£´æ")
|
|
|
|
|
- return
|
|
|
|
|
|
|
+ # ---------------- È« 0 ÅжϺ¯Êý ----------------
|
|
|
|
|
+ def is_all_zero(t: torch.Tensor):
|
|
|
|
|
+ if t.dtype.is_floating_point:
|
|
|
|
|
+ return torch.all(torch.abs(t) < zero_eps)
|
|
|
|
|
+ else:
|
|
|
|
|
+ return torch.all(t == 0)
|
|
|
|
|
+
|
|
|
|
|
+ # È« 0 Çé¿ö´¦Àí
|
|
|
|
|
+ if is_all_zero(mask):
|
|
|
|
|
+ print(f"? {name} È« 0")
|
|
|
|
|
+ if not save_all_zero:
|
|
|
|
|
+ print(f"¡ú δ±£´æ£¨save_all_zero=False£©")
|
|
|
|
|
+ return
|
|
|
|
|
+ else:
|
|
|
|
|
+ print(f"¡ú °´ÉèÖüÌÐø±£´æ£¨save_all_zero=True£©")
|
|
|
|
|
|
|
|
os.makedirs(out_dir, exist_ok=True)
|
|
os.makedirs(out_dir, exist_ok=True)
|
|
|
mask_cpu = mask.detach().cpu()
|
|
mask_cpu = mask.detach().cpu()
|
|
|
|
|
|
|
|
- print(f"\nSaving full mask: {name}")
|
|
|
|
|
- print(f"shape = {tuple(mask.shape)}, dtype = {mask.dtype}, device = {mask.device}")
|
|
|
|
|
-
|
|
|
|
|
|
|
+ # ------------------- ·ÖÅäÐòºÅ -------------------
|
|
|
pattern = os.path.join(out_dir, f"{name}_*.png")
|
|
pattern = os.path.join(out_dir, f"{name}_*.png")
|
|
|
existing_files = sorted(glob(pattern))
|
|
existing_files = sorted(glob(pattern))
|
|
|
if existing_files:
|
|
if existing_files:
|
|
|
- existing_idxs = [int(os.path.basename(f).split("_")[-1].split(".")[0]) for f in existing_files]
|
|
|
|
|
- next_idx = max(existing_idxs) + 1
|
|
|
|
|
|
|
+ existing_idx = []
|
|
|
|
|
+ for f in existing_files:
|
|
|
|
|
+ try:
|
|
|
|
|
+ existing_idx.append(int(os.path.basename(f).split("_")[-1].split(".")[0]))
|
|
|
|
|
+ except:
|
|
|
|
|
+ pass
|
|
|
|
|
+ next_idx = max(existing_idx) + 1
|
|
|
else:
|
|
else:
|
|
|
next_idx = 0
|
|
next_idx = 0
|
|
|
|
|
|
|
@@ -38,6 +67,11 @@ def save_full_mask(mask: torch.Tensor, name="mask", out_dir="output",
|
|
|
|
|
|
|
|
file_idx_str = f"{next_idx:03d}"
|
|
file_idx_str = f"{next_idx:03d}"
|
|
|
|
|
|
|
|
|
|
+ # ==========================================================
|
|
|
|
|
+ # ¢Ù ±£´æ mask ±¾Éí
|
|
|
|
|
+ # ==========================================================
|
|
|
|
|
+
|
|
|
|
|
+ # ¡ª¡ª¡ª ±£´æ PNG ¡ª¡ª¡ª
|
|
|
if save_png:
|
|
if save_png:
|
|
|
png_path = os.path.join(out_dir, f"{name}_{file_idx_str}.png")
|
|
png_path = os.path.join(out_dir, f"{name}_{file_idx_str}.png")
|
|
|
plt.figure()
|
|
plt.figure()
|
|
@@ -46,14 +80,58 @@ def save_full_mask(mask: torch.Tensor, name="mask", out_dir="output",
|
|
|
plt.title(name)
|
|
plt.title(name)
|
|
|
plt.savefig(png_path, dpi=300)
|
|
plt.savefig(png_path, dpi=300)
|
|
|
plt.close()
|
|
plt.close()
|
|
|
- print(f"? saved PNG -> {png_path}")
|
|
|
|
|
|
|
+ print(f"?? saved PNG -> {png_path}")
|
|
|
|
|
|
|
|
|
|
+ # ¡ª¡ª¡ª ±£´æ TXT ¡ª¡ª¡ª
|
|
|
if save_txt:
|
|
if save_txt:
|
|
|
txt_path = os.path.join(out_dir, f"{name}_{file_idx_str}.txt")
|
|
txt_path = os.path.join(out_dir, f"{name}_{file_idx_str}.txt")
|
|
|
np.savetxt(txt_path, mask_cpu.numpy(), fmt="%.3f")
|
|
np.savetxt(txt_path, mask_cpu.numpy(), fmt="%.3f")
|
|
|
- print(f"? saved TXT -> {txt_path}")
|
|
|
|
|
|
|
+ print(f"?? saved TXT -> {txt_path}")
|
|
|
|
|
|
|
|
|
|
+ # ¡ª¡ª¡ª ±£´æ NPY ¡ª¡ª¡ª
|
|
|
if save_npy:
|
|
if save_npy:
|
|
|
npy_path = os.path.join(out_dir, f"{name}_{file_idx_str}.npy")
|
|
npy_path = os.path.join(out_dir, f"{name}_{file_idx_str}.npy")
|
|
|
np.save(npy_path, mask_cpu.numpy())
|
|
np.save(npy_path, mask_cpu.numpy())
|
|
|
- print(f"? saved NPY -> {npy_path}")
|
|
|
|
|
|
|
+ print(f"?? saved NPY -> {npy_path}")
|
|
|
|
|
+
|
|
|
|
|
+ # ==========================================================
|
|
|
|
|
+ # ¢Ú ±£´æµþ¼ÓЧ¹û£¨²»ÏÔʾ£¬Ö»±£´æ£©
|
|
|
|
|
+ # ==========================================================
|
|
|
|
|
+ if show_on_image:
|
|
|
|
|
+ if image is None:
|
|
|
|
|
+ print("? show_on_image=True µ«Î´´«Èë image£¬Ìø¹ýµþ¼Ó±£´æ")
|
|
|
|
|
+ return
|
|
|
|
|
+
|
|
|
|
|
+ # Èç¹ûÈ« 0£¬¾Í²»×öµþ¼Ó±£´æ
|
|
|
|
|
+ if is_all_zero(mask):
|
|
|
|
|
+ print(f"? {name} È« 0 ¡ú ²»Éú³Éµþ¼Ó overlay")
|
|
|
|
|
+ return
|
|
|
|
|
+
|
|
|
|
|
+ # image ¡ú numpy
|
|
|
|
|
+ if hasattr(image, "size"): # PIL.Image
|
|
|
|
|
+ img_np = np.array(image)
|
|
|
|
|
+ else: # numpy array
|
|
|
|
|
+ img_np = image
|
|
|
|
|
+
|
|
|
|
|
+ mask_np = mask_cpu.numpy()
|
|
|
|
|
+ if mask_np.dtype != np.uint8:
|
|
|
|
|
+ mask_np = (mask_np > 0).astype(np.uint8)
|
|
|
|
|
+
|
|
|
|
|
+ # ºìÉ«°ë͸Ã÷µþ¼Ó
|
|
|
|
|
+ colored_mask = np.zeros((*mask_np.shape, 3), dtype=np.uint8)
|
|
|
|
|
+ colored_mask[..., 0] = mask_np * 255 # R channel
|
|
|
|
|
+
|
|
|
|
|
+ overlay_path = os.path.join(out_dir, f"{name}_{file_idx_str}_overlay.png")
|
|
|
|
|
+
|
|
|
|
|
+ plt.figure(figsize=(8, 8))
|
|
|
|
|
+ plt.imshow(img_np)
|
|
|
|
|
+ plt.imshow(colored_mask, alpha=alpha)
|
|
|
|
|
+ plt.title(f"{name} overlay")
|
|
|
|
|
+ plt.title(f"{name} overlay")
|
|
|
|
|
+ plt.axis('off')
|
|
|
|
|
+
|
|
|
|
|
+ # ? ²»µ÷Óà plt.show()£¬Ö±½Ó±£´æ
|
|
|
|
|
+ plt.savefig(overlay_path, dpi=300)
|
|
|
|
|
+ plt.close()
|
|
|
|
|
+
|
|
|
|
|
+ print(f"?? saved overlay -> {overlay_path}")
|