#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 数据集划分脚本 将datasets_merged中的图片和标签按照8:1:1比例划分为train/test/val数据集 """ import os import shutil import random from pathlib import Path import argparse from typing import List, Tuple def get_paired_files(source_dir: str) -> List[Tuple[str, str]]: """ 获取配对的图片和标签文件 Args: source_dir: 源数据目录 Returns: 配对文件列表 [(image_path, label_path), ...] """ source_path = Path(source_dir) # 获取所有图片文件 image_files = list(source_path.glob("*.jpg")) paired_files = [] for img_file in image_files: # 构造对应的标签文件路径 label_file = source_path / f"{img_file.stem}.txt" if label_file.exists(): paired_files.append((str(img_file), str(label_file))) else: print(f"警告: 图片 {img_file.name} 没有对应的标签文件") return paired_files def create_directory_structure(output_dir: str): """ 创建YOLO数据集目录结构 Args: output_dir: 输出目录 """ output_path = Path(output_dir) # 创建主要目录结构 directories = [ "images/train", "images/test", "images/val", "labels/train", "labels/test", "labels/val" ] for directory in directories: dir_path = output_path / directory dir_path.mkdir(parents=True, exist_ok=True) print(f"创建目录: {dir_path}") def split_dataset(paired_files: List[Tuple[str, str]], train_ratio: float = 0.8, test_ratio: float = 0.1, val_ratio: float = 0.1) -> Tuple[List, List, List]: """ 按比例划分数据集 Args: paired_files: 配对文件列表 train_ratio: 训练集比例 test_ratio: 测试集比例 val_ratio: 验证集比例 Returns: (train_files, test_files, val_files) """ # 验证比例 total_ratio = train_ratio + test_ratio + val_ratio if abs(total_ratio - 1.0) > 1e-6: raise ValueError(f"比例总和必须为1.0,当前为: {total_ratio}") # 随机打乱数据 random.shuffle(paired_files) total_files = len(paired_files) train_count = int(total_files * train_ratio) test_count = int(total_files * test_ratio) # 划分数据集 train_files = paired_files[:train_count] test_files = paired_files[train_count:train_count + test_count] val_files = paired_files[train_count + test_count:] print(f"数据集划分统计:") print(f" 总文件数: {total_files}") print(f" 训练集: {len(train_files)} ({len(train_files)/total_files*100:.1f}%)") print(f" 测试集: {len(test_files)} ({len(test_files)/total_files*100:.1f}%)") print(f" 验证集: {len(val_files)} ({len(val_files)/total_files*100:.1f}%)") return train_files, test_files, val_files def copy_files(file_list: List[Tuple[str, str]], output_dir: str, subset_name: str): """ 复制文件到目标目录 Args: file_list: 文件列表 output_dir: 输出目录 subset_name: 子集名称 (train/test/val) """ output_path = Path(output_dir) for img_path, label_path in file_list: img_file = Path(img_path) label_file = Path(label_path) # 目标路径 target_img_dir = output_path / "images" / subset_name target_label_dir = output_path / "labels" / subset_name target_img_path = target_img_dir / img_file.name target_label_path = target_label_dir / label_file.name # 复制文件 try: shutil.copy2(img_path, target_img_path) shutil.copy2(label_path, target_label_path) except Exception as e: print(f"复制文件失败: {img_file.name} - {e}") def _read_class_names(classes_file: str): """读取类别文件,支持以下格式: - 每行一个类名:"fire" - 带显式ID:"0 fire" 或 "fire 0" 或 "0,fire" - 行内注释:以#开始的内容忽略 返回:仅类名组成的列表 """ names = [] with open(classes_file, 'r', encoding='utf-8') as f: for raw in f: line = raw.strip() if not line: continue # 去除行内注释 if '#' in line: line = line.split('#', 1)[0].strip() if not line: continue # 逗号分隔:"id,name" if ',' in line: parts = [p.strip() for p in line.split(',') if p.strip()] if len(parts) == 2: if parts[0].isdigit(): names.append(parts[1]) continue tokens = [t for t in line.split() if t] if not tokens: continue if tokens[0].isdigit(): # "ID 类名(可能包含空格)" names.append(' '.join(tokens[1:])) elif tokens[-1].isdigit(): # "类名(可能包含空格) ID" names.append(' '.join(tokens[:-1])) else: names.append(line) return names def create_dataset_yaml(output_dir: str, classes_file: str = None): """ 创建YOLO数据集配置文件 Args: output_dir: 输出目录 classes_file: 类别文件路径 """ output_path = Path(output_dir) yaml_path = output_path / "dataset.yaml" # 读取类别信息 if classes_file and os.path.exists(classes_file): classes = _read_class_names(classes_file) else: # 默认类别(若未提供classes文件) classes = ['fire', 'dust', 'move_machine', 'open_machine', 'close_machine'] # 生成YAML内容 yaml_content = f"""# YOLO数据集配置文件 # 数据集路径 (相对于此文件的路径) path: {output_path.absolute()} train: images/train val: images/val test: images/test # 类别数量 nc: {len(classes)} # 类别名称 names: {classes} """ with open(yaml_path, 'w', encoding='utf-8') as f: f.write(yaml_content) print(f"创建数据集配置文件: {yaml_path}") def main(): parser = argparse.ArgumentParser(description='YOLO数据集划分工具') parser.add_argument('source_dir', help='源数据目录路径') parser.add_argument('-o', '--output', default='./yolo_dataset', help='输出目录路径 (默认: ./yolo_dataset)') parser.add_argument('-c', '--classes', help='类别文件路径') parser.add_argument('--train-ratio', type=float, default=0.8, help='训练集比例 (默认: 0.8)') parser.add_argument('--test-ratio', type=float, default=0.1, help='测试集比例 (默认: 0.1)') parser.add_argument('--val-ratio', type=float, default=0.1, help='验证集比例 (默认: 0.1)') parser.add_argument('--seed', type=int, default=42, help='随机种子 (默认: 42)') parser.add_argument('--dry-run', action='store_true', help='仅显示划分统计,不实际复制文件') args = parser.parse_args() # 设置随机种子 random.seed(args.seed) # 验证源目录 if not os.path.exists(args.source_dir): print(f"错误: 源目录不存在: {args.source_dir}") return print(f"开始处理数据集...") print(f"源目录: {args.source_dir}") print(f"输出目录: {args.output}") print(f"划分比例: 训练集{args.train_ratio} : 测试集{args.test_ratio} : 验证集{args.val_ratio}") print(f"随机种子: {args.seed}") print("-" * 50) # 获取配对文件 print("1. 扫描配对文件...") paired_files = get_paired_files(args.source_dir) if not paired_files: print("错误: 没有找到配对的图片和标签文件") return print(f"找到 {len(paired_files)} 对配对文件") # 划分数据集 print("\n2. 划分数据集...") train_files, test_files, val_files = split_dataset( paired_files, args.train_ratio, args.test_ratio, args.val_ratio ) if args.dry_run: print("\n[试运行模式] 仅显示统计信息,不实际复制文件") return # 创建目录结构 print("\n3. 创建目录结构...") create_directory_structure(args.output) # 复制文件 print("\n4. 复制文件...") print("复制训练集文件...") copy_files(train_files, args.output, "train") print("复制测试集文件...") copy_files(test_files, args.output, "test") print("复制验证集文件...") copy_files(val_files, args.output, "val") # 创建配置文件 print("\n5. 创建数据集配置文件...") create_dataset_yaml(args.output, args.classes) print("\n" + "=" * 50) print("数据集划分完成!") print(f"输出目录: {os.path.abspath(args.output)}") print("\n目录结构:") print("yolo_dataset/") print("├── images/") print("│ ├── train/") print("│ ├── test/") print("│ └── val/") print("├── labels/") print("│ ├── train/") print("│ ├── test/") print("│ └── val/") print("└── dataset.yaml") if __name__ == "__main__": main()