d_data_spliter.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
"""
该脚本将源目录下同名的图像文件(如 .tiff / .jpg / .png)与 .json 标注文件配对,
随机划分为训练集和验证集,并按如下结构整理至目标目录:

output_dir/
├── images/
│   ├── train/
│   └── val/
└── labels/
    ├── train/
    └── val/

其中:
- images/train 和 images/val 存放的是图像文件(后缀由 image_extensions 指定)
- labels/train 和 labels/val 存放的是 `.json` 标注文件(后缀由 label_extensions 指定)
"""
  16. import os
  17. import shutil
  18. import random
  19. from pathlib import Path
  20. def organize_data(
  21. src_dir,
  22. dest_dir,
  23. image_extensions=['.tiff'],
  24. label_extensions=['.json'],
  25. val_ratio=0.2
  26. ):
  27. src_dir = Path(src_dir)
  28. dest_dir = Path(dest_dir)
  29. image_dir = dest_dir / 'images'
  30. label_dir = dest_dir / 'labels'
  31. # 创建文件夹结构
  32. for split in ['train', 'val']:
  33. (image_dir / split).mkdir(parents=True, exist_ok=True)
  34. (label_dir / split).mkdir(parents=True, exist_ok=True)
  35. # 获取所有文件
  36. files = list(src_dir.glob('*'))
  37. name_to_files = {}
  38. # 分组:同名文件归为一组
  39. for f in files:
  40. stem = f.stem
  41. name_to_files.setdefault(stem, []).append(f)
  42. # 筛选出同时包含 label 和 image 的样本
  43. paired_samples = []
  44. for name, file_group in name_to_files.items():
  45. image_file = next((f for f in file_group if f.suffix in image_extensions), None)
  46. label_file = next((f for f in file_group if f.suffix in label_extensions), None)
  47. if image_file and label_file:
  48. paired_samples.append((image_file, label_file))
  49. else:
  50. print(f"⚠️ Skipping unpaired files for: {name}")
  51. # 打乱并划分
  52. random.shuffle(paired_samples)
  53. split_idx = int(len(paired_samples) * (1 - val_ratio))
  54. train_samples = paired_samples[:split_idx]
  55. val_samples = paired_samples[split_idx:]
  56. # 拷贝函数
  57. def copy_samples(samples, split):
  58. for img, lbl in samples:
  59. shutil.copy(img, image_dir / split / img.name)
  60. shutil.copy(lbl, label_dir / split / lbl.name)
  61. # 执行拷贝
  62. copy_samples(train_samples, 'train')
  63. copy_samples(val_samples, 'val')
  64. print(f"\n✅ Done! Processed {len(paired_samples)} pairs.")
  65. print(f"Train: {len(train_samples)}, Val: {len(val_samples)}")
  66. if __name__ == "__main__":
  67. # 输入输出目录(可修改)
  68. source_dir = r"/home/zhaoyinghan/py_ws/data/circle/huayan"
  69. parent_dir = os.path.dirname(source_dir)
  70. output_dir = os.path.join(parent_dir, "a_dataset")
  71. # 后缀名列表,方便以后扩展其他格式
  72. image_exts = ['.tiff','.jpg','.png']
  73. label_exts = ['.json']
  74. organize_data(source_dir, output_dir, image_exts, label_exts)