| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124 |
- import os
- import shutil
- import random
- from pathlib import Path
def repartition_yolo_dataset_safe(
    dataset_root: str,
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    test_ratio: float = 0.1,
    seed: int = 42,
    image_exts: tuple = ('.jpg', '.jpeg', '.png'),
    class_names: list = None,
    backup: bool = True
):
    """Pool the train/val/test images of a YOLO dataset and re-split them.

    The dataset is expected to look like ``<root>/images/{train,val,test}``
    and ``<root>/labels/{train,val,test}``, where the label of ``foo.jpg``
    is ``foo.txt`` in the matching ``labels`` sub-directory.  All images
    from the three splits are pooled, shuffled with ``seed`` and copied
    into staging directories (``images_new`` / ``labels_new``).  Only after
    every copy has succeeded are the staging directories swapped in for
    the live ones.

    Args:
        dataset_root: Directory containing ``images`` and ``labels``.
        train_ratio: Fraction of the pooled images assigned to ``train``.
        val_ratio: Fraction of the pooled images assigned to ``val``.
        test_ratio: Expected fraction for ``test``.  The test split always
            receives whatever remains after train and val; a warning is
            printed if this value disagrees with that remainder.
        seed: Seed for the shuffle.  A private ``random.Random`` is used,
            so the global random state is left untouched.
        image_exts: File extensions (with leading dot) treated as images;
            compared case-insensitively.
        class_names: When truthy, a ``dataset.yaml`` listing these names is
            written into ``dataset_root``.
        backup: When true and ``<root>/backup_original`` does not exist
            yet, copy ``images`` and ``labels`` there before changing
            anything.

    Raises:
        FileNotFoundError: A split directory is missing under ``images``
            or ``labels``.
        ValueError: A ratio is negative, ``train_ratio + val_ratio``
            exceeds 1, no image was found, or the same file stem occurs in
            more than one source split (pooling would overwrite one file
            with the other).
    """
    dataset_root = Path(dataset_root)
    images_root = dataset_root / 'images'
    labels_root = dataset_root / 'labels'
    splits = ['train', 'val', 'test']

    # --- Validate everything before touching the disk -------------------
    if train_ratio < 0 or val_ratio < 0 or test_ratio < 0:
        raise ValueError(
            f"划分比例不能为负数: train={train_ratio}, val={val_ratio}, test={test_ratio}"
        )
    if train_ratio + val_ratio > 1 + 1e-6:
        raise ValueError(
            f"train_ratio + val_ratio 不能大于 1: {train_ratio} + {val_ratio}"
        )
    if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
        # Not fatal: test_ratio was never used for slicing, so callers that
        # only set train/val keep working.
        print(
            f"⚠️ test_ratio={test_ratio} 与 1 - train_ratio - val_ratio 不一致,"
            "测试集将使用剩余的全部图片"
        )

    # The original directory layout must be complete.
    for s in splits:
        if not (images_root / s).exists() or not (labels_root / s).exists():
            raise FileNotFoundError(f"缺少 {s} 目录,请确保数据集结构完整")

    # --- Collect every image together with the split it came from -------
    if isinstance(image_exts, str):
        image_exts = (image_exts,)
    wanted_exts = {ext.lower() for ext in image_exts}

    all_images = []
    for split in splits:
        # iterdir() order is OS-dependent; sorting makes the seeded shuffle
        # reproducible across machines.
        for f in sorted((images_root / split).iterdir()):
            if f.is_file() and f.suffix.lower() in wanted_exts:
                all_images.append((split, f))  # (source split, file path)
    if not all_images:
        raise ValueError("未找到任何图片!")
    print(f"🔍 找到 {len(all_images)} 张图片")

    # Files sharing a stem across source splits would overwrite each
    # other's image and/or label once pooled, so refuse to continue.
    stem_sources = {}
    for split, f in all_images:
        stem_sources.setdefault(f.stem, set()).add(split)
    clashes = sorted(stem for stem, found_in in stem_sources.items() if len(found_in) > 1)
    if clashes:
        preview = ', '.join(clashes[:5])
        raise ValueError(
            f"不同划分中存在同名文件(共 {len(clashes)} 个,例如: {preview}),"
            "合并后会互相覆盖,请先重命名"
        )

    # --- 1. Back up the original data ------------------------------------
    backup_dir = dataset_root / 'backup_original'
    if backup and not backup_dir.exists():
        print("📁 正在备份原始数据...")
        shutil.copytree(images_root, backup_dir / 'images')
        shutil.copytree(labels_root, backup_dir / 'labels')
        print("✅ 备份完成")

    # --- 2. Shuffle and slice --------------------------------------------
    random.Random(seed).shuffle(all_images)
    total = len(all_images)
    train_end = int(total * train_ratio)
    val_end = train_end + int(total * val_ratio)
    new_splits = {
        'train': all_images[:train_end],
        'val': all_images[train_end:val_end],
        'test': all_images[val_end:]
    }

    # --- 3. Create clean staging directories -----------------------------
    new_images = dataset_root / 'images_new'
    new_labels = dataset_root / 'labels_new'
    for staging_root in (new_images, new_labels):
        # Left-overs from an interrupted run must not leak into this split.
        if staging_root.exists():
            shutil.rmtree(staging_root)
        for split in splits:
            (staging_root / split).mkdir(parents=True, exist_ok=True)

    # --- 4. Copy files from their original location into staging ---------
    missing_labels = 0
    for split_name, files in new_splits.items():
        print(f"📦 {split_name}: {len(files)} 张")
        for orig_split, img_path in files:
            shutil.copy2(img_path, new_images / split_name / img_path.name)
            # The label lives under labels/<source split>/.
            label_name = img_path.stem + '.txt'
            label_path = labels_root / orig_split / label_name
            if label_path.exists():
                shutil.copy2(label_path, new_labels / split_name / label_name)
            else:
                print(f"⚠️ 缺失标签: {label_path}")
                missing_labels += 1

    # --- 5. Swap staging into place ---------------------------------------
    # The live directory is renamed aside first and deleted only after the
    # staged one has taken its place, so a failure cannot leave the
    # dataset without that directory.
    for live_dir, staged_dir in ((images_root, new_images), (labels_root, new_labels)):
        retired_dir = live_dir.with_name(live_dir.name + '_old')
        suffix = 0
        while retired_dir.exists():  # never clobber an unrelated directory
            suffix += 1
            retired_dir = live_dir.with_name(f"{live_dir.name}_old{suffix}")
        live_dir.rename(retired_dir)
        try:
            staged_dir.rename(live_dir)
        except OSError:
            retired_dir.rename(live_dir)  # roll back to the original data
            raise
        shutil.rmtree(retired_dir)

    # --- 6. Optionally write dataset.yaml ---------------------------------
    if class_names:
        yaml_lines = [
            f"path: {dataset_root.resolve()}",
            "train: images/train",
            "val: images/val",
            "test: images/test",
            f"names: {class_names}",
        ]
        with open(dataset_root / 'dataset.yaml', 'w', encoding='utf-8') as f:
            f.write("\n".join(yaml_lines) + "\n")
        print("\n📄 已生成 dataset.yaml")

    print("\n✅ 重划分成功!")
    print(f"训练: {len(new_splits['train'])}, 验证: {len(new_splits['val'])}, 测试: {len(new_splits['test'])}")
    if missing_labels:
        print(f"⚠️ 共 {missing_labels} 个标签缺失")
if __name__ == "__main__":
    # Dataset root: the directory that holds images/ and labels/.
    DATASET_ROOT = r"20251210/20251210"

    # To also generate dataset.yaml, add class_names=[...] to the call
    # below with your own class list, e.g. ["dust"].
    repartition_yolo_dataset_safe(
        DATASET_ROOT,
        train_ratio=0.8,
        val_ratio=0.1,
        test_ratio=0.1,
        seed=42,
        backup=True,
    )
|