split12.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. import os
  2. import shutil
  3. import random
  4. from pathlib import Path
  5. def repartition_yolo_dataset_safe(
  6. dataset_root: str,
  7. train_ratio: float = 0.8,
  8. val_ratio: float = 0.1,
  9. test_ratio: float = 0.1,
  10. seed: int = 42,
  11. image_exts: tuple = ('.jpg', '.jpeg', '.png'),
  12. class_names: list = None,
  13. backup: bool = True
  14. ):
  15. dataset_root = Path(dataset_root)
  16. images_root = dataset_root / 'images'
  17. labels_root = dataset_root / 'labels'
  18. splits = ['train', 'val', 'test']
  19. # 检查原始结构是否存在
  20. for s in splits:
  21. if not (images_root / s).exists() or not (labels_root / s).exists():
  22. raise FileNotFoundError(f"缺少 {s} 目录,请确保数据集结构完整")
  23. # === 1. 备份原始数据 ===
  24. backup_dir = dataset_root / 'backup_original'
  25. if backup and not backup_dir.exists():
  26. print("📁 正在备份原始数据...")
  27. shutil.copytree(images_root, backup_dir / 'images')
  28. shutil.copytree(labels_root, backup_dir / 'labels')
  29. print("✅ 备份完成")
  30. # === 2. 收集所有图片路径(来自 train/val/test)===
  31. all_images = []
  32. for split in splits:
  33. split_img_dir = images_root / split
  34. for f in split_img_dir.iterdir():
  35. if f.is_file() and f.suffix.lower() in image_exts:
  36. all_images.append((split, f)) # 保存 (来源split, 文件路径)
  37. if not all_images:
  38. raise ValueError("未找到任何图片!")
  39. print(f"🔍 找到 {len(all_images)} 张图片")
  40. # === 3. 随机打乱并划分 ===
  41. random.seed(seed)
  42. random.shuffle(all_images)
  43. total = len(all_images)
  44. train_end = int(total * train_ratio)
  45. val_end = train_end + int(total * val_ratio)
  46. new_splits = {
  47. 'train': all_images[:train_end],
  48. 'val': all_images[train_end:val_end],
  49. 'test': all_images[val_end:]
  50. }
  51. # === 4. 创建新的目标目录(加 _new 后缀避免冲突)===
  52. new_images = dataset_root / 'images_new'
  53. new_labels = dataset_root / 'labels_new'
  54. for split in splits:
  55. (new_images / split).mkdir(parents=True, exist_ok=True)
  56. (new_labels / split).mkdir(parents=True, exist_ok=True)
  57. # === 5. 复制文件(从原位置 → 新位置)===
  58. missing_labels = 0
  59. for split_name, files in new_splits.items():
  60. print(f"📦 {split_name}: {len(files)} 张")
  61. for orig_split, img_path in files:
  62. # 复制图片
  63. dst_img = new_images / split_name / img_path.name
  64. shutil.copy2(img_path, dst_img)
  65. # 复制标签(标签在 labels/orig_split/ 下)
  66. label_path = labels_root / orig_split / (img_path.stem + '.txt')
  67. dst_label = new_labels / split_name / (img_path.stem + '.txt')
  68. if label_path.exists():
  69. shutil.copy2(label_path, dst_label)
  70. else:
  71. print(f"⚠️ 缺失标签: {label_path}")
  72. missing_labels += 1
  73. # === 6. 原子性替换:删除旧 images/labels,重命名 new 为正式名 ===
  74. shutil.rmtree(images_root)
  75. shutil.rmtree(labels_root)
  76. shutil.move(str(new_images), str(images_root))
  77. shutil.move(str(new_labels), str(labels_root))
  78. # === 7. 生成 dataset.yaml(可选)===
  79. if class_names:
  80. yaml_content = f"""path: {dataset_root.resolve()}
  81. train: images/train
  82. val: images/val
  83. test: images/test
  84. names: {class_names}
  85. """
  86. with open(dataset_root / 'dataset.yaml', 'w', encoding='utf-8') as f:
  87. f.write(yaml_content)
  88. print(f"\n📄 已生成 dataset.yaml")
  89. print(f"\n✅ 重划分成功!")
  90. print(f"训练: {len(new_splits['train'])}, 验证: {len(new_splits['val'])}, 测试: {len(new_splits['test'])}")
  91. if missing_labels:
  92. print(f"⚠️ 共 {missing_labels} 个标签缺失")
  93. if __name__ == "__main__":
  94. DATASET_ROOT = r"20251210/20251210"
  95. #CLASS_NAMES = ["dust"] # 替换为你的类别
  96. repartition_yolo_dataset_safe(
  97. dataset_root=DATASET_ROOT,
  98. train_ratio=0.8,
  99. val_ratio=0.1,
  100. test_ratio=0.1,
  101. seed=42,
  102. #class_names=CLASS_NAMES,
  103. backup=True
  104. )