"""Pair same-named image and label files and split them into train/val sets.

The script scans a source directory for files that share a stem (the file
name without its extension), pairs each image file with its ``.json`` label,
shuffles the pairs, and copies them into the following layout::

    output_dir/
    ├── images/
    │   ├── train/
    │   └── val/
    └── labels/
        ├── train/
        └── val/

where:
- ``images/train`` and ``images/val`` hold the image files (extensions are
  configurable via ``image_extensions``, e.g. ``.tiff``, ``.jpg``, ``.png``),
- ``labels/train`` and ``labels/val`` hold the matching ``.json`` label files.
"""

import argparse
import os
import random
import shutil
from pathlib import Path

_SPLITS = ('train', 'val')


def _normalize_extensions(extensions):
    """Return *extensions* as a lower-case set, each with a leading dot."""
    return {
        ext.lower() if ext.startswith('.') else '.' + ext.lower()
        for ext in extensions
    }


def _find_pairs(src_dir, image_exts, label_exts):
    """Return (image, label) path pairs that share a stem in *src_dir*.

    Stems lacking either an image or a label are reported and skipped.
    """
    groups = {}
    # Sort so pairing (and therefore a seeded shuffle) does not depend on the
    # order in which the filesystem lists entries; ignore sub-directories,
    # which cannot be copied with shutil.copy.
    for path in sorted(p for p in src_dir.glob('*') if p.is_file()):
        groups.setdefault(path.stem, []).append(path)

    pairs = []
    for name, group in groups.items():
        # Compare suffixes case-insensitively so e.g. "IMG_01.JPG" matches ".jpg".
        image_file = next((f for f in group if f.suffix.lower() in image_exts), None)
        label_file = next((f for f in group if f.suffix.lower() in label_exts), None)
        if image_file is not None and label_file is not None:
            pairs.append((image_file, label_file))
        else:
            print(f"⚠️ Skipping unpaired files for: {name}")
    return pairs


def _copy_pairs(pairs, image_dir, label_dir):
    """Copy each (image, label) pair into *image_dir* and *label_dir*."""
    for image_file, label_file in pairs:
        shutil.copy(image_file, image_dir / image_file.name)
        shutil.copy(label_file, label_dir / label_file.name)


def organize_data(
    src_dir,
    dest_dir,
    image_extensions=('.tiff',),
    label_extensions=('.json',),
    val_ratio=0.2,
    seed=None,
):
    """Pair image/label files from *src_dir* and copy them into a train/val layout.

    Args:
        src_dir: Directory containing images and labels side by side.
        dest_dir: Output root; ``images/{train,val}`` and
            ``labels/{train,val}`` are created under it.
        image_extensions: Image suffixes to accept (case-insensitive; a
            missing leading dot is added).
        label_extensions: Label suffixes to accept (same normalisation).
        val_ratio: Fraction of pairs assigned to the validation split, in [0, 1].
        seed: Optional seed for a reproducible split. When ``None`` the
            module-level ``random`` generator is used, as before.

    Returns:
        ``(train_samples, val_samples)``: lists of ``(image_path, label_path)``.

    Raises:
        ValueError: If *val_ratio* is outside [0, 1].
        NotADirectoryError: If *src_dir* is not an existing directory.
    """
    # Validate before touching the filesystem so bad input creates nothing.
    if not 0.0 <= val_ratio <= 1.0:
        raise ValueError(f"val_ratio must be in [0, 1], got {val_ratio!r}")
    src_dir = Path(src_dir)
    dest_dir = Path(dest_dir)
    if not src_dir.is_dir():
        raise NotADirectoryError(f"Source directory not found: {src_dir}")

    image_dir = dest_dir / 'images'
    label_dir = dest_dir / 'labels'
    for split in _SPLITS:
        (image_dir / split).mkdir(parents=True, exist_ok=True)
        (label_dir / split).mkdir(parents=True, exist_ok=True)

    paired_samples = _find_pairs(
        src_dir,
        _normalize_extensions(image_extensions),
        _normalize_extensions(label_extensions),
    )

    # Without a seed keep using the global RNG, so callers that seed via
    # random.seed() still control the split; with a seed use an isolated RNG.
    shuffle = random.shuffle if seed is None else random.Random(seed).shuffle
    shuffle(paired_samples)
    split_idx = int(len(paired_samples) * (1 - val_ratio))
    train_samples = paired_samples[:split_idx]
    val_samples = paired_samples[split_idx:]

    _copy_pairs(train_samples, image_dir / 'train', label_dir / 'train')
    _copy_pairs(val_samples, image_dir / 'val', label_dir / 'val')

    print(f"\n✅ Done! Processed {len(paired_samples)} pairs.")
    print(f"Train: {len(train_samples)}, Val: {len(val_samples)}")
    return train_samples, val_samples


def main(argv=None):
    """Command-line entry point; with no arguments it reproduces the original run."""
    parser = argparse.ArgumentParser(
        description="Pair image/label files and split them into train/val sets."
    )
    # Default source directory kept from the original script (editable / overridable).
    parser.add_argument(
        'source_dir', nargs='?', default=r"/home/zhaoyinghan/py_ws/data/circle/huayan"
    )
    parser.add_argument('--val-ratio', type=float, default=0.2)
    parser.add_argument('--seed', type=int, default=None)
    args = parser.parse_args(argv)

    # Write the dataset next to the source directory as "a_dataset". abspath
    # strips any trailing separator, so the output never lands inside source_dir.
    parent_dir = os.path.dirname(os.path.abspath(args.source_dir))
    output_dir = os.path.join(parent_dir, "a_dataset")

    # Extension lists, easy to extend with other formats later.
    image_exts = ['.tiff', '.jpg', '.png']
    label_exts = ['.json']
    organize_data(
        args.source_dir,
        output_dir,
        image_exts,
        label_exts,
        val_ratio=args.val_ratio,
        seed=args.seed,
    )


if __name__ == "__main__":
    main()