split_dataset.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 数据集划分脚本
  5. 将datasets_merged中的图片和标签按照8:1:1比例划分为train/test/val数据集
  6. """
  7. import os
  8. import shutil
  9. import random
  10. from pathlib import Path
  11. import argparse
  12. from typing import List, Tuple
  13. def get_paired_files(source_dir: str) -> List[Tuple[str, str]]:
  14. """
  15. 获取配对的图片和标签文件
  16. Args:
  17. source_dir: 源数据目录
  18. Returns:
  19. 配对文件列表 [(image_path, label_path), ...]
  20. """
  21. source_path = Path(source_dir)
  22. # 获取所有图片文件
  23. image_files = list(source_path.glob("*.jpg"))
  24. paired_files = []
  25. for img_file in image_files:
  26. # 构造对应的标签文件路径
  27. label_file = source_path / f"{img_file.stem}.txt"
  28. if label_file.exists():
  29. paired_files.append((str(img_file), str(label_file)))
  30. else:
  31. print(f"警告: 图片 {img_file.name} 没有对应的标签文件")
  32. return paired_files
  33. def create_directory_structure(output_dir: str):
  34. """
  35. 创建YOLO数据集目录结构
  36. Args:
  37. output_dir: 输出目录
  38. """
  39. output_path = Path(output_dir)
  40. # 创建主要目录结构
  41. directories = [
  42. "images/train",
  43. "images/test",
  44. "images/val",
  45. "labels/train",
  46. "labels/test",
  47. "labels/val"
  48. ]
  49. for directory in directories:
  50. dir_path = output_path / directory
  51. dir_path.mkdir(parents=True, exist_ok=True)
  52. print(f"创建目录: {dir_path}")
  53. def split_dataset(paired_files: List[Tuple[str, str]],
  54. train_ratio: float = 0.8,
  55. test_ratio: float = 0.1,
  56. val_ratio: float = 0.1) -> Tuple[List, List, List]:
  57. """
  58. 按比例划分数据集
  59. Args:
  60. paired_files: 配对文件列表
  61. train_ratio: 训练集比例
  62. test_ratio: 测试集比例
  63. val_ratio: 验证集比例
  64. Returns:
  65. (train_files, test_files, val_files)
  66. """
  67. # 验证比例
  68. total_ratio = train_ratio + test_ratio + val_ratio
  69. if abs(total_ratio - 1.0) > 1e-6:
  70. raise ValueError(f"比例总和必须为1.0,当前为: {total_ratio}")
  71. # 随机打乱数据
  72. random.shuffle(paired_files)
  73. total_files = len(paired_files)
  74. train_count = int(total_files * train_ratio)
  75. test_count = int(total_files * test_ratio)
  76. # 划分数据集
  77. train_files = paired_files[:train_count]
  78. test_files = paired_files[train_count:train_count + test_count]
  79. val_files = paired_files[train_count + test_count:]
  80. print(f"数据集划分统计:")
  81. print(f" 总文件数: {total_files}")
  82. print(f" 训练集: {len(train_files)} ({len(train_files)/total_files*100:.1f}%)")
  83. print(f" 测试集: {len(test_files)} ({len(test_files)/total_files*100:.1f}%)")
  84. print(f" 验证集: {len(val_files)} ({len(val_files)/total_files*100:.1f}%)")
  85. return train_files, test_files, val_files
  86. def copy_files(file_list: List[Tuple[str, str]],
  87. output_dir: str,
  88. subset_name: str):
  89. """
  90. 复制文件到目标目录
  91. Args:
  92. file_list: 文件列表
  93. output_dir: 输出目录
  94. subset_name: 子集名称 (train/test/val)
  95. """
  96. output_path = Path(output_dir)
  97. for img_path, label_path in file_list:
  98. img_file = Path(img_path)
  99. label_file = Path(label_path)
  100. # 目标路径
  101. target_img_dir = output_path / "images" / subset_name
  102. target_label_dir = output_path / "labels" / subset_name
  103. target_img_path = target_img_dir / img_file.name
  104. target_label_path = target_label_dir / label_file.name
  105. # 复制文件
  106. try:
  107. shutil.copy2(img_path, target_img_path)
  108. shutil.copy2(label_path, target_label_path)
  109. except Exception as e:
  110. print(f"复制文件失败: {img_file.name} - {e}")
  111. def _read_class_names(classes_file: str):
  112. """读取类别文件,支持以下格式:
  113. - 每行一个类名:"fire"
  114. - 带显式ID:"0 fire" 或 "fire 0" 或 "0,fire"
  115. - 行内注释:以#开始的内容忽略
  116. 返回:仅类名组成的列表
  117. """
  118. names = []
  119. with open(classes_file, 'r', encoding='utf-8') as f:
  120. for raw in f:
  121. line = raw.strip()
  122. if not line:
  123. continue
  124. # 去除行内注释
  125. if '#' in line:
  126. line = line.split('#', 1)[0].strip()
  127. if not line:
  128. continue
  129. # 逗号分隔:"id,name"
  130. if ',' in line:
  131. parts = [p.strip() for p in line.split(',') if p.strip()]
  132. if len(parts) == 2:
  133. if parts[0].isdigit():
  134. names.append(parts[1])
  135. continue
  136. tokens = [t for t in line.split() if t]
  137. if not tokens:
  138. continue
  139. if tokens[0].isdigit():
  140. # "ID 类名(可能包含空格)"
  141. names.append(' '.join(tokens[1:]))
  142. elif tokens[-1].isdigit():
  143. # "类名(可能包含空格) ID"
  144. names.append(' '.join(tokens[:-1]))
  145. else:
  146. names.append(line)
  147. return names
  148. def create_dataset_yaml(output_dir: str, classes_file: str = None):
  149. """
  150. 创建YOLO数据集配置文件
  151. Args:
  152. output_dir: 输出目录
  153. classes_file: 类别文件路径
  154. """
  155. output_path = Path(output_dir)
  156. yaml_path = output_path / "dataset.yaml"
  157. # 读取类别信息
  158. if classes_file and os.path.exists(classes_file):
  159. classes = _read_class_names(classes_file)
  160. else:
  161. # 默认类别(若未提供classes文件)
  162. classes = ['fire', 'dust', 'move_machine', 'open_machine', 'close_machine']
  163. # 生成YAML内容
  164. yaml_content = f"""# YOLO数据集配置文件
  165. # 数据集路径 (相对于此文件的路径)
  166. path: {output_path.absolute()}
  167. train: images/train
  168. val: images/val
  169. test: images/test
  170. # 类别数量
  171. nc: {len(classes)}
  172. # 类别名称
  173. names: {classes}
  174. """
  175. with open(yaml_path, 'w', encoding='utf-8') as f:
  176. f.write(yaml_content)
  177. print(f"创建数据集配置文件: {yaml_path}")
  178. def main():
  179. parser = argparse.ArgumentParser(description='YOLO数据集划分工具')
  180. parser.add_argument('source_dir', help='源数据目录路径')
  181. parser.add_argument('-o', '--output', default='./yolo_dataset',
  182. help='输出目录路径 (默认: ./yolo_dataset)')
  183. parser.add_argument('-c', '--classes', help='类别文件路径')
  184. parser.add_argument('--train-ratio', type=float, default=0.8,
  185. help='训练集比例 (默认: 0.8)')
  186. parser.add_argument('--test-ratio', type=float, default=0.1,
  187. help='测试集比例 (默认: 0.1)')
  188. parser.add_argument('--val-ratio', type=float, default=0.1,
  189. help='验证集比例 (默认: 0.1)')
  190. parser.add_argument('--seed', type=int, default=42,
  191. help='随机种子 (默认: 42)')
  192. parser.add_argument('--dry-run', action='store_true',
  193. help='仅显示划分统计,不实际复制文件')
  194. args = parser.parse_args()
  195. # 设置随机种子
  196. random.seed(args.seed)
  197. # 验证源目录
  198. if not os.path.exists(args.source_dir):
  199. print(f"错误: 源目录不存在: {args.source_dir}")
  200. return
  201. print(f"开始处理数据集...")
  202. print(f"源目录: {args.source_dir}")
  203. print(f"输出目录: {args.output}")
  204. print(f"划分比例: 训练集{args.train_ratio} : 测试集{args.test_ratio} : 验证集{args.val_ratio}")
  205. print(f"随机种子: {args.seed}")
  206. print("-" * 50)
  207. # 获取配对文件
  208. print("1. 扫描配对文件...")
  209. paired_files = get_paired_files(args.source_dir)
  210. if not paired_files:
  211. print("错误: 没有找到配对的图片和标签文件")
  212. return
  213. print(f"找到 {len(paired_files)} 对配对文件")
  214. # 划分数据集
  215. print("\n2. 划分数据集...")
  216. train_files, test_files, val_files = split_dataset(
  217. paired_files, args.train_ratio, args.test_ratio, args.val_ratio
  218. )
  219. if args.dry_run:
  220. print("\n[试运行模式] 仅显示统计信息,不实际复制文件")
  221. return
  222. # 创建目录结构
  223. print("\n3. 创建目录结构...")
  224. create_directory_structure(args.output)
  225. # 复制文件
  226. print("\n4. 复制文件...")
  227. print("复制训练集文件...")
  228. copy_files(train_files, args.output, "train")
  229. print("复制测试集文件...")
  230. copy_files(test_files, args.output, "test")
  231. print("复制验证集文件...")
  232. copy_files(val_files, args.output, "val")
  233. # 创建配置文件
  234. print("\n5. 创建数据集配置文件...")
  235. create_dataset_yaml(args.output, args.classes)
  236. print("\n" + "=" * 50)
  237. print("数据集划分完成!")
  238. print(f"输出目录: {os.path.abspath(args.output)}")
  239. print("\n目录结构:")
  240. print("yolo_dataset/")
  241. print("├── images/")
  242. print("│ ├── train/")
  243. print("│ ├── test/")
  244. print("│ └── val/")
  245. print("├── labels/")
  246. print("│ ├── train/")
  247. print("│ ├── test/")
  248. print("│ └── val/")
  249. print("└── dataset.yaml")
  250. if __name__ == "__main__":
  251. main()