# split_dataset.py — split a folder of images and their .json labels into train/val subsets.
import argparse
import os
import random
import shutil

from sklearn.model_selection import train_test_split
  5. # 定义路径
  6. data_dir = 'D:\python\PycharmProjects\data_20250223\lcnn新T型板十字板增强后(复件)' # 替换为你的数据文件夹路径
  7. output_dir = 'D:\python\PycharmProjects\data_20250223\lcnn_20250223' # 替换为你想要保存输出的文件夹路径
  8. # 创建输出目录结构
  9. images_train_dir = os.path.join(output_dir, 'images', 'train')
  10. images_val_dir = os.path.join(output_dir, 'images', 'val')
  11. labels_train_dir = os.path.join(output_dir, 'labels', 'train')
  12. labels_val_dir = os.path.join(output_dir, 'labels', 'val')
  13. os.makedirs(images_train_dir, exist_ok=True)
  14. os.makedirs(images_val_dir, exist_ok=True)
  15. os.makedirs(labels_train_dir, exist_ok=True)
  16. os.makedirs(labels_val_dir, exist_ok=True)
  17. # 获取所有图片文件名和对应的json文件名
  18. image_files = [f for f in os.listdir(data_dir) if f.endswith('.jpg')]
  19. json_files = {f.replace('.json', ''): f for f in os.listdir(data_dir) if f.endswith('.json')}
  20. # 提取图片名称(不包含扩展名)以便匹配json文件
  21. image_names = [os.path.splitext(f)[0] for f in image_files]
  22. # 按照9:1的比例划分数据集
  23. train_names, val_names = train_test_split(image_names, test_size=0.1, random_state=42)
  24. # 复制文件到相应目录
  25. for name in train_names:
  26. image_file = name + '.jpg'
  27. json_file = json_files[name]
  28. # 复制图片文件
  29. shutil.copy(os.path.join(data_dir, image_file), os.path.join(images_train_dir, image_file))
  30. # 复制json文件
  31. shutil.copy(os.path.join(data_dir, json_file), os.path.join(labels_train_dir, json_file))
  32. for name in val_names:
  33. image_file = name + '.jpg'
  34. json_file = json_files[name]
  35. # 复制图片文件
  36. shutil.copy(os.path.join(data_dir, image_file), os.path.join(images_val_dir, image_file))
  37. # 复制json文件
  38. shutil.copy(os.path.join(data_dir, json_file), os.path.join(labels_val_dir, json_file))
  39. # 示例调用
  40. if __name__ == "__main__":
  41. # 确保输出目录和子目录存在
  42. os.makedirs(images_train_dir, exist_ok=True)
  43. os.makedirs(images_val_dir, exist_ok=True)
  44. os.makedirs(labels_train_dir, exist_ok=True)
  45. os.makedirs(labels_val_dir, exist_ok=True)
  46. # 执行划分和文件复制
  47. train_names, val_names = train_test_split(image_names, test_size=0.1, random_state=42)
  48. for name in train_names:
  49. image_file = name + '.jpg'
  50. json_file = json_files[name]
  51. # 复制图片文件
  52. shutil.copy(os.path.join(data_dir, image_file), os.path.join(images_train_dir, image_file))
  53. # 复制json文件
  54. shutil.copy(os.path.join(data_dir, json_file), os.path.join(labels_train_dir, json_file))
  55. for name in val_names:
  56. image_file = name + '.jpg'
  57. json_file = json_files[name]
  58. # 复制图片文件
  59. shutil.copy(os.path.join(data_dir, image_file), os.path.join(images_val_dir, image_file))
  60. # 复制json文件
  61. shutil.copy(os.path.join(data_dir, json_file), os.path.join(labels_val_dir, json_file))