| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102 |
- """
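- This script pairs image files (`.tiff` by default) with the same-named `.json`
- label files found in the source directory,
- and organizes them into the target directory using the following structure:
- output_dir/
- ├── images/
- │   ├── train/
- │   └── val/
- └── labels/
-     ├── train/
-     └── val/
- Where:
- - images/train and images/val hold the image files (`.tiff` by default; the script entry point also passes `.jpg` and `.png`)
- - labels/train and labels/val hold the `.json` files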
- """
- import os
- import shutil
- import random
- from pathlib import Path
def organize_data(
    src_dir,
    dest_dir,
    image_extensions=('.tiff',),
    label_extensions=('.json',),
    val_ratio=0.2,
    seed=None,
):
    """Pair image and label files by stem and copy them into a train/val layout.

    Files directly inside ``src_dir`` are grouped by their stem. A group that
    contains both an image file and a label file forms one sample; every other
    group is reported and skipped. Samples are shuffled, split, and copied to
    ``dest_dir/images/{train,val}`` and ``dest_dir/labels/{train,val}``.

    Args:
        src_dir: Directory holding the raw image and label files.
        dest_dir: Root of the output tree; created if it does not exist.
        image_extensions: Suffixes (with leading dot) treated as images.
            Matched case-insensitively. A single string is accepted.
        label_extensions: Suffixes (with leading dot) treated as labels.
            Matched case-insensitively. A single string is accepted.
        val_ratio: Fraction of the samples placed in the ``val`` split.
        seed: Optional seed for a reproducible split. When ``None`` the
            module-level ``random`` state is used.

    Raises:
        ValueError: If ``val_ratio`` is not within ``[0, 1]``.
    """
    if not 0 <= val_ratio <= 1:
        raise ValueError(f"val_ratio must be between 0 and 1, got {val_ratio!r}")

    # A bare string would otherwise be iterated character by character.
    if isinstance(image_extensions, str):
        image_extensions = (image_extensions,)
    if isinstance(label_extensions, str):
        label_extensions = (label_extensions,)
    image_suffixes = {ext.lower() for ext in image_extensions}
    label_suffixes = {ext.lower() for ext in label_extensions}

    src_dir = Path(src_dir)
    dest_dir = Path(dest_dir)
    image_dir = dest_dir / 'images'
    label_dir = dest_dir / 'labels'

    # Create the output directory structure.
    for split in ('train', 'val'):
        (image_dir / split).mkdir(parents=True, exist_ok=True)
        (label_dir / split).mkdir(parents=True, exist_ok=True)

    # Group regular files by stem. Sorting makes both the grouping order and
    # the choice among several candidates with the same stem deterministic.
    name_to_files = {}
    for path in sorted(p for p in src_dir.glob('*') if p.is_file()):
        name_to_files.setdefault(path.stem, []).append(path)

    # Keep only the groups that contain both an image and a label.
    paired_samples = []
    for name, file_group in name_to_files.items():
        image_file = next(
            (f for f in file_group if f.suffix.lower() in image_suffixes), None
        )
        label_file = next(
            (f for f in file_group if f.suffix.lower() in label_suffixes), None
        )
        if image_file is not None and label_file is not None:
            paired_samples.append((image_file, label_file))
        else:
            print(f"⚠️ Skipping unpaired files for: {name}")

    # Shuffle and split.
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(paired_samples)
    # Round before truncating: 10 * (1 - 0.9) evaluates to 0.9999999999999998
    # in floating point and would otherwise truncate to 0 instead of 1.
    split_idx = int(round(len(paired_samples) * (1 - val_ratio), 6))
    train_samples = paired_samples[:split_idx]
    val_samples = paired_samples[split_idx:]

    # Copy every sample into its split directory.
    for split, samples in (('train', train_samples), ('val', val_samples)):
        for img, lbl in samples:
            shutil.copy(img, image_dir / split / img.name)
            shutil.copy(lbl, label_dir / split / lbl.name)

    print(f"\n✅ Done! Processed {len(paired_samples)} pairs.")
    print(f"Train: {len(train_samples)}, Val: {len(val_samples)}")
if __name__ == "__main__":
    # Input directory (edit as needed); the dataset is written to a sibling
    # folder named "a_dataset" next to it.
    source_dir = Path(r"/home/zhaoyinghan/py_ws/data/circle/huayan")
    output_dir = source_dir.parent / "a_dataset"

    # Suffix lists are passed explicitly so other formats can be added later.
    organize_data(
        source_dir,
        output_dir,
        ['.tiff', '.jpg', '.png'],
        ['.json'],
    )
|