| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- 数据集划分脚本
- 将datasets_merged中的图片和标签按照8:1:1比例划分为train/test/val数据集
- """
- import os
- import shutil
- import random
- from pathlib import Path
- import argparse
- from typing import List, Tuple
def get_paired_files(source_dir: str,
                     image_exts: Tuple[str, ...] = (".jpg",)) -> List[Tuple[str, str]]:
    """
    Collect (image, label) path pairs from a flat dataset directory.

    An image is paired with the ``.txt`` file that has the same stem and
    lives in the same directory. Images without a label file are reported
    with a warning on stdout and left out of the result.

    Args:
        source_dir: Directory that holds both the images and the label files.
        image_exts: Image suffixes (including the leading dot) to look for.
            Defaults to ``(".jpg",)``.

    Returns:
        List of ``(image_path, label_path)`` string tuples, sorted by image
        path.
    """
    source_path = Path(source_dir)

    # glob() yields entries in an arbitrary, filesystem-dependent order.
    # Sorting makes the later seeded shuffle reproducible across machines;
    # the set removes duplicates when suffix patterns overlap.
    image_files = sorted({
        candidate
        for ext in image_exts
        for candidate in source_path.glob(f"*{ext}")
        if candidate.is_file()
    })

    paired_files = []
    for img_file in image_files:
        # The label shares the image's stem and sits next to it.
        label_file = source_path / f"{img_file.stem}.txt"

        if label_file.exists():
            paired_files.append((str(img_file), str(label_file)))
        else:
            print(f"警告: 图片 {img_file.name} 没有对应的标签文件")

    return paired_files
def create_directory_structure(output_dir: str):
    """
    Build the YOLO dataset folder layout under ``output_dir``.

    Creates ``images/{train,test,val}`` followed by
    ``labels/{train,test,val}``, including any missing parents, and prints
    one line per directory. Directories that already exist are left as is.

    Args:
        output_dir: Root directory of the dataset to create.
    """
    root = Path(output_dir)

    # Same creation order as a flat list: all image subsets, then all labels.
    for kind in ("images", "labels"):
        for subset in ("train", "test", "val"):
            target = root / f"{kind}/{subset}"
            target.mkdir(parents=True, exist_ok=True)
            print(f"创建目录: {target}")
def split_dataset(paired_files: List[Tuple[str, str]],
                  train_ratio: float = 0.8,
                  test_ratio: float = 0.1,
                  val_ratio: float = 0.1) -> Tuple[List, List, List]:
    """
    Randomly split the paired files into train, test and val subsets.

    The input list is not modified; a shuffled copy is sliced instead. The
    train and test sizes are ``int(total * ratio)`` (rounded down) and the
    validation subset receives whatever remains. Shuffling uses the
    module-level ``random`` generator, so seed it beforehand for a
    reproducible split. Split statistics are printed to stdout.

    Args:
        paired_files: List of ``(image_path, label_path)`` tuples.
        train_ratio: Fraction of the files assigned to the training subset.
        test_ratio: Fraction of the files assigned to the test subset.
        val_ratio: Fraction of the files assigned to the validation subset.

    Returns:
        ``(train_files, test_files, val_files)``

    Raises:
        ValueError: If the ratios do not sum to 1.0 (within 1e-6) or if any
            ratio is negative.
    """
    total_ratio = train_ratio + test_ratio + val_ratio
    if abs(total_ratio - 1.0) > 1e-6:
        raise ValueError(f"比例总和必须为1.0,当前为: {total_ratio}")
    if min(train_ratio, test_ratio, val_ratio) < 0:
        # Negative ratios can still sum to 1.0 but produce nonsensical slices.
        raise ValueError(
            f"比例不能为负数: {train_ratio}, {test_ratio}, {val_ratio}"
        )

    # Shuffle a copy so the caller's list keeps its order.
    shuffled = list(paired_files)
    random.shuffle(shuffled)

    total_files = len(shuffled)
    train_count = int(total_files * train_ratio)
    test_count = int(total_files * test_ratio)

    train_files = shuffled[:train_count]
    test_files = shuffled[train_count:train_count + test_count]
    val_files = shuffled[train_count + test_count:]

    def _percent(count: int) -> float:
        # Avoid ZeroDivisionError when there are no files at all.
        return count / total_files * 100 if total_files else 0.0

    print("数据集划分统计:")
    print(f" 总文件数: {total_files}")
    print(f" 训练集: {len(train_files)} ({_percent(len(train_files)):.1f}%)")
    print(f" 测试集: {len(test_files)} ({_percent(len(test_files)):.1f}%)")
    print(f" 验证集: {len(val_files)} ({_percent(len(val_files)):.1f}%)")

    return train_files, test_files, val_files
def copy_files(file_list: List[Tuple[str, str]],
               output_dir: str,
               subset_name: str):
    """
    Copy image/label pairs into one subset of the output dataset.

    Images go to ``<output_dir>/images/<subset_name>`` and labels to
    ``<output_dir>/labels/<subset_name>``; both directories are created if
    they do not exist. Copying is best effort: a failure is printed and the
    remaining pairs are still processed. If the image was copied but its
    label could not be, the copied image is removed again so the output
    never holds an image without a label.

    Args:
        file_list: List of ``(image_path, label_path)`` tuples.
        output_dir: Root directory of the output dataset.
        subset_name: Subset name (train/test/val).
    """
    output_path = Path(output_dir)

    # The target directories are the same for every pair; compute them once.
    target_img_dir = output_path / "images" / subset_name
    target_label_dir = output_path / "labels" / subset_name
    try:
        target_img_dir.mkdir(parents=True, exist_ok=True)
        target_label_dir.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        print(f"创建目录失败: {output_path} - {e}")
        return

    for img_path, label_path in file_list:
        img_file = Path(img_path)
        label_file = Path(label_path)

        target_img_path = target_img_dir / img_file.name
        target_label_path = target_label_dir / label_file.name

        # shutil raises OSError subclasses (including shutil.Error), so
        # catching OSError covers copy failures without hiding other bugs.
        try:
            shutil.copy2(img_path, target_img_path)
        except OSError as e:
            print(f"复制文件失败: {img_file.name} - {e}")
            continue

        try:
            shutil.copy2(label_path, target_label_path)
        except OSError as e:
            print(f"复制文件失败: {img_file.name} - {e}")
            # Only reached after the image copy succeeded, so the target
            # image is a real copy and safe to remove.
            try:
                target_img_path.unlink()
            except OSError as cleanup_error:
                print(f"清理孤立图片失败: {target_img_path.name} - {cleanup_error}")
- def _read_class_names(classes_file: str):
- """读取类别文件,支持以下格式:
- - 每行一个类名:"fire"
- - 带显式ID:"0 fire" 或 "fire 0" 或 "0,fire"
- - 行内注释:以#开始的内容忽略
- 返回:仅类名组成的列表
- """
- names = []
- with open(classes_file, 'r', encoding='utf-8') as f:
- for raw in f:
- line = raw.strip()
- if not line:
- continue
- # 去除行内注释
- if '#' in line:
- line = line.split('#', 1)[0].strip()
- if not line:
- continue
- # 逗号分隔:"id,name"
- if ',' in line:
- parts = [p.strip() for p in line.split(',') if p.strip()]
- if len(parts) == 2:
- if parts[0].isdigit():
- names.append(parts[1])
- continue
- tokens = [t for t in line.split() if t]
- if not tokens:
- continue
- if tokens[0].isdigit():
- # "ID 类名(可能包含空格)"
- names.append(' '.join(tokens[1:]))
- elif tokens[-1].isdigit():
- # "类名(可能包含空格) ID"
- names.append(' '.join(tokens[:-1]))
- else:
- names.append(line)
- return names
def create_dataset_yaml(output_dir: str, classes_file: str = None):
    """
    Write the YOLO ``dataset.yaml`` configuration into ``output_dir``.

    Class names are read from ``classes_file`` when it points to an existing
    file; otherwise a built-in default list is used, with a warning if a
    path was given but could not be found. The output directory is created
    if it does not exist.

    Args:
        output_dir: Root directory of the dataset; the file is written to
            ``<output_dir>/dataset.yaml``.
        classes_file: Optional path to the classes file.
    """
    # Local import keeps this function self-contained; the module's
    # top-level imports do not include json.
    import json

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    yaml_path = output_path / "dataset.yaml"

    if classes_file and os.path.isfile(classes_file):
        classes = _read_class_names(classes_file)
    else:
        if classes_file:
            # Make the fallback visible instead of silently ignoring a typo.
            print(f"警告: 类别文件不存在: {classes_file},使用默认类别")
        # Default classes (used when no classes file is available)
        classes = ['fire', 'dust', 'move_machine', 'open_machine', 'close_machine']

    # JSON strings and arrays are valid YAML flow scalars/sequences, so
    # json.dumps gives correct quoting for names and paths that contain
    # quotes, backslashes, '#' or ': '.
    quoted_root = json.dumps(str(output_path.absolute()), ensure_ascii=False)
    quoted_names = json.dumps(classes, ensure_ascii=False)

    yaml_lines = [
        "# YOLO数据集配置文件",
        "# 数据集根目录 (绝对路径)",
        f"path: {quoted_root}",
        "train: images/train",
        "val: images/val",
        "test: images/test",
        "# 类别数量",
        f"nc: {len(classes)}",
        "# 类别名称",
        f"names: {quoted_names}",
    ]

    with open(yaml_path, 'w', encoding='utf-8') as f:
        f.write("\n".join(yaml_lines) + "\n")

    print(f"创建数据集配置文件: {yaml_path}")
def main():
    """
    Command-line entry point: scan, split, copy and write ``dataset.yaml``.

    Steps: parse the arguments, seed the random generator, collect the
    image/label pairs from the source directory, split them by the requested
    ratios, then (unless ``--dry-run`` is set) create the output tree, copy
    the files and write the dataset configuration.

    Returns:
        Process exit status: 0 on success (including a dry run), 1 when the
        source directory is missing or holds no image/label pairs. Invalid
        ratios end the program through ``parser.error`` (exit status 2).
    """
    parser = argparse.ArgumentParser(description='YOLO数据集划分工具')
    parser.add_argument('source_dir', help='源数据目录路径')
    parser.add_argument('-o', '--output', default='./yolo_dataset',
                        help='输出目录路径 (默认: ./yolo_dataset)')
    parser.add_argument('-c', '--classes', help='类别文件路径')
    parser.add_argument('--train-ratio', type=float, default=0.8,
                        help='训练集比例 (默认: 0.8)')
    parser.add_argument('--test-ratio', type=float, default=0.1,
                        help='测试集比例 (默认: 0.1)')
    parser.add_argument('--val-ratio', type=float, default=0.1,
                        help='验证集比例 (默认: 0.1)')
    parser.add_argument('--seed', type=int, default=42,
                        help='随机种子 (默认: 42)')
    parser.add_argument('--dry-run', action='store_true',
                        help='仅显示划分统计,不实际复制文件')

    args = parser.parse_args()

    # Reject bad ratios up front with a usage error rather than a traceback
    # raised after the source directory has already been scanned.
    ratios = (args.train_ratio, args.test_ratio, args.val_ratio)
    if min(ratios) < 0 or abs(sum(ratios) - 1.0) > 1e-6:
        parser.error(f"划分比例必须非负且总和为1.0,当前为: {ratios}")

    # Seed the shared generator so the shuffle is reproducible.
    random.seed(args.seed)

    # The source must be an existing directory, not just an existing path.
    if not os.path.isdir(args.source_dir):
        print(f"错误: 源目录不存在: {args.source_dir}")
        return 1

    print("开始处理数据集...")
    print(f"源目录: {args.source_dir}")
    print(f"输出目录: {args.output}")
    print(f"划分比例: 训练集{args.train_ratio} : 测试集{args.test_ratio} : 验证集{args.val_ratio}")
    print(f"随机种子: {args.seed}")
    print("-" * 50)

    # Step 1: collect image/label pairs
    print("1. 扫描配对文件...")
    paired_files = get_paired_files(args.source_dir)

    if not paired_files:
        print("错误: 没有找到配对的图片和标签文件")
        return 1

    print(f"找到 {len(paired_files)} 对配对文件")

    # Step 2: split
    print("\n2. 划分数据集...")
    train_files, test_files, val_files = split_dataset(
        paired_files, args.train_ratio, args.test_ratio, args.val_ratio
    )

    if args.dry_run:
        print("\n[试运行模式] 仅显示统计信息,不实际复制文件")
        return 0

    # Step 3: output tree
    print("\n3. 创建目录结构...")
    create_directory_structure(args.output)

    # Step 4: copy each subset
    print("\n4. 复制文件...")
    print("复制训练集文件...")
    copy_files(train_files, args.output, "train")

    print("复制测试集文件...")
    copy_files(test_files, args.output, "test")

    print("复制验证集文件...")
    copy_files(val_files, args.output, "val")

    # Step 5: dataset configuration
    print("\n5. 创建数据集配置文件...")
    create_dataset_yaml(args.output, args.classes)

    output_abs = os.path.abspath(args.output)
    print("\n" + "=" * 50)
    print("数据集划分完成!")
    print(f"输出目录: {output_abs}")
    print("\n目录结构:")
    # Show the directory name that was actually used, not a fixed one.
    print(f"{os.path.basename(output_abs)}/")
    print("├── images/")
    print("│ ├── train/")
    print("│ ├── test/")
    print("│ └── val/")
    print("├── labels/")
    print("│ ├── train/")
    print("│ ├── test/")
    print("│ └── val/")
    print("└── dataset.yaml")
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so shells and CI can detect failures.
    raise SystemExit(main())
|