import json
import os
import random
import shutil
from pathlib import Path
from typing import Optional

# Sub-directories expected under both images/ and labels/.
_SPLITS = ('train', 'val', 'test')


def _validate_ratios(train_ratio: float, val_ratio: float, test_ratio: float) -> None:
    """Reject split ratios that cannot be honoured; warn if they do not sum to 1.

    The test split always receives every image left over after train and val,
    so ``test_ratio`` only takes part in this sanity check.

    Raises:
        ValueError: if any ratio is negative or ``train_ratio + val_ratio > 1``.
    """
    if min(train_ratio, val_ratio, test_ratio) < 0:
        raise ValueError("划分比例不能为负数")
    if train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError("train_ratio + val_ratio 不能大于 1")
    if abs(train_ratio + val_ratio + test_ratio - 1) > 1e-6:
        remainder = max(0.0, 1 - train_ratio - val_ratio)
        print(f"⚠️ 三个比例之和不为 1,测试集将使用剩余的 {remainder:.0%}")


def _collect_images(images_root: Path, exts: set) -> list:
    """Return ``(source_split, image_path)`` pairs for every image in images/<split>/.

    The result is sorted because ``iterdir()`` order depends on the filesystem;
    without sorting, the same seed could produce different splits on different
    machines.
    """
    found = []
    for split in _SPLITS:
        for f in (images_root / split).iterdir():
            if f.is_file() and f.suffix.lower() in exts:
                found.append((split, f))
    found.sort(key=lambda item: (item[0], item[1].name))
    return found


def _unique_stem(stem: str, orig_split: str, used: set) -> str:
    """Return a file stem not yet present in ``used`` and record it there.

    Images from different source splits may share a name (e.g. train/0001.jpg
    and val/0001.jpg). Copying both into one destination folder would silently
    overwrite an image and its label, so later duplicates are renamed to
    ``<orig_split>_<stem>`` (plus ``_<n>`` if that is taken as well).
    Comparison is case-insensitive because filesystems on Windows/macOS
    typically are.
    """
    candidate = stem
    n = 1
    while candidate.casefold() in used:
        n += 1
        candidate = f"{orig_split}_{stem}" if n == 2 else f"{orig_split}_{stem}_{n}"
    used.add(candidate.casefold())
    return candidate


def _write_dataset_yaml(dataset_root: Path, class_names) -> None:
    """Write a YOLO-style dataset.yaml next to images/ and labels/.

    Values are serialised with ``json.dumps``. JSON strings and arrays are
    valid YAML flow scalars/sequences, so paths or class names containing
    ':', '#', quotes or backslashes are quoted correctly.
    """
    yaml_content = (
        f"path: {json.dumps(str(dataset_root.resolve()), ensure_ascii=False)}\n"
        "train: images/train\n"
        "val: images/val\n"
        "test: images/test\n"
        f"names: {json.dumps(list(class_names), ensure_ascii=False)}\n"
    )
    with open(dataset_root / 'dataset.yaml', 'w', encoding='utf-8') as f:
        f.write(yaml_content)


def repartition_yolo_dataset_safe(
    dataset_root: str,
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    test_ratio: float = 0.1,
    seed: int = 42,
    image_exts: tuple = ('.jpg', '.jpeg', '.png'),
    class_names: Optional[list] = None,
    backup: bool = True,
):
    """Pool all images of a YOLO dataset and randomly re-split them into train/val/test.

    Expects ``dataset_root/images/{train,val,test}`` and
    ``dataset_root/labels/{train,val,test}``, with the label for
    ``images/<split>/<stem>.<ext>`` at ``labels/<split>/<stem>.txt``.

    The new split is first built in ``images_new``/``labels_new``. The originals
    are then parked as ``images_old``/``labels_old``, the new trees are moved
    into place, and only then are the parked originals deleted. An interruption
    therefore never leaves the dataset without its data.

    Args:
        dataset_root: Dataset root directory.
        train_ratio: Fraction of images assigned to train (floored).
        val_ratio: Fraction of images assigned to val (floored).
        test_ratio: Only used for validation; test receives all remaining images.
        seed: Seed of the private RNG used for shuffling (the global ``random``
            state is not touched).
        image_exts: Image file suffixes to include, matched case-insensitively.
        class_names: If non-empty, a ``dataset.yaml`` with these names is written.
        backup: Copy images/ and labels/ into ``backup_original`` first, unless
            that directory already exists.

    Raises:
        FileExistsError: ``images_old``/``labels_old`` are left over from an
            interrupted run and may hold the only copy of the original data.
        FileNotFoundError: a required split directory is missing.
        ValueError: invalid ratios, or no images were found.
    """
    dataset_root = Path(dataset_root)
    images_root = dataset_root / 'images'
    labels_root = dataset_root / 'labels'
    old_images = dataset_root / 'images_old'
    old_labels = dataset_root / 'labels_old'

    _validate_ratios(train_ratio, val_ratio, test_ratio)

    # Checked before the structure check: after an interrupted swap, images/ may
    # be missing while images_old/ still holds the original data.
    for leftover in (old_images, old_labels):
        if leftover.exists():
            raise FileExistsError(
                f"{leftover} 已存在(可能是上次中断遗留的原始数据),请先检查并手动处理"
            )

    # Check that the original structure exists.
    for s in _SPLITS:
        if not (images_root / s).exists() or not (labels_root / s).exists():
            raise FileNotFoundError(f"缺少 {s} 目录,请确保数据集结构完整")

    # === 1. Back up the original data ===
    backup_dir = dataset_root / 'backup_original'
    if backup and not backup_dir.exists():
        print("📁 正在备份原始数据...")
        shutil.copytree(images_root, backup_dir / 'images')
        shutil.copytree(labels_root, backup_dir / 'labels')
        print("✅ 备份完成")
    elif backup:
        print(f"ℹ️ 备份目录已存在,跳过备份: {backup_dir}")

    # === 2. Collect every image path (from train/val/test) ===
    exts = {e.lower() for e in image_exts}
    all_images = _collect_images(images_root, exts)
    if not all_images:
        raise ValueError("未找到任何图片!")
    print(f"🔍 找到 {len(all_images)} 张图片")

    # === 3. Shuffle and split ===
    rng = random.Random(seed)  # private RNG: do not reseed the caller's global state
    rng.shuffle(all_images)
    total = len(all_images)
    train_end = int(total * train_ratio)
    val_end = train_end + int(total * val_ratio)
    new_splits = {
        'train': all_images[:train_end],
        'val': all_images[train_end:val_end],
        'test': all_images[val_end:],
    }

    # === 4. Create fresh scratch directories (suffix _new avoids clashes) ===
    new_images = dataset_root / 'images_new'
    new_labels = dataset_root / 'labels_new'
    for scratch in (new_images, new_labels):
        if scratch.exists():
            # Leftovers from a failed run would otherwise be merged into this
            # split and could place the same image in several splits.
            print(f"🧹 清理上次遗留的临时目录: {scratch}")
            shutil.rmtree(scratch)
    for split in _SPLITS:
        (new_images / split).mkdir(parents=True, exist_ok=True)
        (new_labels / split).mkdir(parents=True, exist_ok=True)

    # === 5. Copy files (original location -> new location) ===
    missing_labels = 0
    renamed = 0
    used_stems = set()
    for split_name, files in new_splits.items():
        print(f"📦 {split_name}: {len(files)} 张")
        for orig_split, img_path in files:
            stem = _unique_stem(img_path.stem, orig_split, used_stems)
            if stem != img_path.stem:
                renamed += 1
                print(f"⚠️ 文件名冲突,重命名: {img_path} → {stem}{img_path.suffix}")
            shutil.copy2(img_path, new_images / split_name / (stem + img_path.suffix))

            # The label lives under labels/<orig_split>/ and is renamed with its image.
            label_path = labels_root / orig_split / (img_path.stem + '.txt')
            if label_path.exists():
                shutil.copy2(label_path, new_labels / split_name / (stem + '.txt'))
            else:
                print(f"⚠️ 缺失标签: {label_path}")
                missing_labels += 1

    # === 6. Swap in the new trees; the originals are deleted only after the swap ===
    shutil.move(str(images_root), str(old_images))
    shutil.move(str(labels_root), str(old_labels))
    shutil.move(str(new_images), str(images_root))
    shutil.move(str(new_labels), str(labels_root))
    shutil.rmtree(old_images)
    shutil.rmtree(old_labels)

    # === 7. Generate dataset.yaml (optional) ===
    if class_names:
        _write_dataset_yaml(dataset_root, class_names)
        print(f"\n📄 已生成 dataset.yaml")

    print(f"\n✅ 重划分成功!")
    print(f"训练: {len(new_splits['train'])}, 验证: {len(new_splits['val'])}, 测试: {len(new_splits['test'])}")
    if missing_labels:
        print(f"⚠️ 共 {missing_labels} 个标签缺失")
    if renamed:
        print(f"⚠️ 共 {renamed} 个文件因重名被重命名")


if __name__ == "__main__":
    DATASET_ROOT = r"20251210/20251210"
    # CLASS_NAMES = ["dust"]  # replace with your own class names
    repartition_yolo_dataset_safe(
        dataset_root=DATASET_ROOT,
        train_ratio=0.8,
        val_ratio=0.1,
        test_ratio=0.1,
        seed=42,
        # class_names=CLASS_NAMES,
        backup=True,
    )