s.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. import os
  2. import random
  3. import shutil
  4. import time
  5. # ================= 🔧 配置区域 =================
  6. # 请确保这个路径和你截图里的路径一模一样
  7. root_path = r"D:\data\20251210\20251210"
  8. # 划分比例 8:1:1
  9. split_ratios = {"train": 0.8, "val": 0.1, "test": 0.1}
  10. # 后缀配置
  11. valid_image_exts = ['.jpg', '.jpeg', '.png', '.bmp']
  12. valid_label_exts = ['.txt', '.xml', '.json']
  13. # ===============================================
  14. def main():
  15. print("==========================================")
  16. print(" 正在启动 8:1:1 随机划分程序")
  17. print("==========================================\n")
  18. images_root = os.path.join(root_path, "images")
  19. labels_root = os.path.join(root_path, "labels")
  20. # 1. 检查根目录
  21. if not os.path.exists(images_root):
  22. print(f"❌ 错误:找不到 images 文件夹!\n 试图寻找路径: {images_root}")
  23. return
  24. if not os.path.exists(labels_root):
  25. print(f"❌ 错误:找不到 labels 文件夹!\n 试图寻找路径: {labels_root}")
  26. return
  27. print(f"✅ 路径检查通过: {root_path}")
  28. # 2. 收集文件
  29. all_pairs = []
  30. # 扫描 images 下的所有文件夹(包括 train, test, val 或者其他)
  31. # 只要是在 images 目录下的子文件夹,都会被扫描
  32. sub_dirs = [d for d in os.listdir(images_root) if os.path.isdir(os.path.join(images_root, d))]
  33. # 如果 images 下没有子文件夹,可能是图片直接放在了 images 根目录下?
  34. # 为了兼容,如果下面没有文件夹,就扫描 images 本身
  35. if not sub_dirs:
  36. sub_dirs = ["."] # 代表当前目录
  37. print("⚠️ 提示:images 下没有子文件夹,将扫描 images 根目录...")
  38. print(f"📂 正在扫描以下文件夹: {sub_dirs}")
  39. for sub in sub_dirs:
  40. # 处理路径:如果是 "." 则不拼接子目录
  41. sub_img_dir = images_root if sub == "." else os.path.join(images_root, sub)
  42. sub_lbl_dir = labels_root if sub == "." else os.path.join(labels_root, sub)
  43. if not os.path.exists(sub_lbl_dir):
  44. # 如果对应的 label 文件夹不存在,跳过
  45. continue
  46. files = os.listdir(sub_img_dir)
  47. count_folder = 0
  48. for f in files:
  49. stem, ext = os.path.splitext(f)
  50. if ext.lower() in valid_image_exts:
  51. img_path = os.path.join(sub_img_dir, f)
  52. # 找标签
  53. lbl_path = None
  54. for lbl_ext in valid_label_exts:
  55. potential_lbl = os.path.join(sub_lbl_dir, stem + lbl_ext)
  56. if os.path.exists(potential_lbl):
  57. lbl_path = potential_lbl
  58. break
  59. if lbl_path:
  60. all_pairs.append({'img': img_path, 'lbl': lbl_path})
  61. count_folder += 1
  62. print(f" -> 在 [{sub}] 中找到 {count_folder} 对数据")
  63. total = len(all_pairs)
  64. print(f"\n📦 总共收集到: {total} 组数据")
  65. if total == 0:
  66. print("❌ 没有找到任何匹配的图片和标签,请检查文件名是否对应(例如 001.jpg 是否有 001.txt)。")
  67. return
  68. # 3. 打乱与计算
  69. print("🎲 正在打乱顺序...")
  70. random.shuffle(all_pairs)
  71. n_train = int(total * split_ratios["train"])
  72. n_val = int(total * split_ratios["val"])
  73. n_test = total - n_train - n_val
  74. print(f"📊 划分数量 -> Train: {n_train}, Val: {n_val}, Test: {n_test}")
  75. split_data = {
  76. "train": all_pairs[:n_train],
  77. "val": all_pairs[n_train : n_train + n_val],
  78. "test": all_pairs[n_train + n_val :]
  79. }
  80. # 4. 移动文件
  81. print("🚚 开始移动文件...")
  82. for split_name, items in split_data.items():
  83. # 目标目录
  84. dst_img_dir = os.path.join(images_root, split_name)
  85. dst_lbl_dir = os.path.join(labels_root, split_name)
  86. os.makedirs(dst_img_dir, exist_ok=True)
  87. os.makedirs(dst_lbl_dir, exist_ok=True)
  88. processed = 0
  89. for item in items:
  90. src_img, src_lbl = item['img'], item['lbl']
  91. # 只有当源路径不在目标路径时才移动
  92. if os.path.dirname(src_img) != dst_img_dir:
  93. try:
  94. shutil.move(src_img, os.path.join(dst_img_dir, os.path.basename(src_img)))
  95. shutil.move(src_lbl, os.path.join(dst_lbl_dir, os.path.basename(src_lbl)))
  96. except Exception as e:
  97. print(f" [Error] 移动失败: {os.path.basename(src_img)} -> {e}")
  98. processed += 1
  99. # 每处理 500 个打印一次进度,防止看着像死机
  100. if processed % 500 == 0:
  101. print(f" [{split_name}] 已处理 {processed} / {len(items)}")
  102. print(f" ✅ {split_name} 完成。")
  103. print("\n🎉🎉🎉 全部处理完毕! 🎉🎉🎉")
  104. if __name__ == "__main__":
  105. try:
  106. main()
  107. except Exception as e:
  108. print(f"\n❌ 发生严重错误: {e}")
  109. # 这里的 input 是为了防止窗口直接关闭
  110. print("\n--------------------------------")
  111. input("程序运行结束,请按回车键关闭窗口...")