autolabel.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import os
  2. from ultralytics import YOLO
  3. def auto_label(image_dir, output_dir, model_path='model/best.pt', conf=0.5, save_segmentation=True):
  4. """
  5. 使用 YOLO 模型自动标注图片并生成 TXT 标签文件。
  6. """
  7. # 检查模型文件是否存在
  8. if not os.path.exists(model_path):
  9. print(f"错误: 模型文件不存在 -> {model_path}")
  10. # 尝试使用根目录下的 yolov8n-seg.pt 作为备选
  11. fallback_model = 'yolov8n-seg.pt'
  12. if os.path.exists(fallback_model):
  13. print(f"尝试使用备选模型 -> {fallback_model}")
  14. model_path = fallback_model
  15. else:
  16. return
  17. # 加载模型
  18. print(f"正在加载模型: {model_path} ...")
  19. model = YOLO(model_path)
  20. # 确保输出目录存在
  21. if not os.path.exists(output_dir):
  22. os.makedirs(output_dir)
  23. print(f"创建输出目录: {output_dir}")
  24. # 支持的图片格式
  25. img_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
  26. # 获取图片列表
  27. image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(img_extensions)]
  28. total_files = len(image_files)
  29. print(f"找到 {total_files} 张图片,开始自动标注...")
  30. count = 0
  31. for idx, img_file in enumerate(image_files):
  32. img_path = os.path.join(image_dir, img_file)
  33. txt_filename = os.path.splitext(img_file)[0] + ".txt"
  34. txt_path = os.path.join(output_dir, txt_filename)
  35. # 推理
  36. # stream=True 可以节省内存,但对于单张处理差异不大
  37. results = model(img_path, conf=conf, verbose=False)
  38. for result in results:
  39. with open(txt_path, 'w') as f:
  40. # 优先尝试保存分割掩码
  41. if save_segmentation and result.masks is not None:
  42. # result.masks.xyn 获取归一化的多边形坐标片段
  43. if hasattr(result.masks, 'xyn'):
  44. for i, seg in enumerate(result.masks.xyn):
  45. cls = int(result.boxes.cls[i]) # 对应的类别
  46. # 格式: class x1 y1 x2 y2 ...
  47. # flatten() 将数组展平,tolist() 转为列表
  48. coords = " ".join([f"{p[0]:.6f} {p[1]:.6f}" for p in seg])
  49. f.write(f"{cls} {coords}\n")
  50. else:
  51. pass # 兼容性处理
  52. # 如果没有分割结果或未启用分割,保存检测框
  53. elif result.boxes is not None:
  54. for box in result.boxes:
  55. cls = int(box.cls)
  56. # xywhn: x_center, y_center, width, height (normalized)
  57. x, y, w, h = box.xywhn[0].tolist()
  58. f.write(f"{cls} {x:.6f} {y:.6f} {w:.6f} {h:.6f}\n")
  59. count += 1
  60. if count % 10 == 0:
  61. print(f"进度: {count}/{total_files}")
  62. print(f"完成!已生成 {count} 个标签文件。保存位置: {output_dir}")
  63. if __name__ == "__main__":
  64. # --- 配置区域 (请在此修改路径) ---
  65. # 待标注图片文件夹 (例如: d:\data\20251204)
  66. IMAGE_DIR = r'd:\data\20251204'
  67. # 标签保存文件夹 (例如: d:\data\20251204\labels_auto)
  68. OUTPUT_DIR = r'd:\data\20251204\labels_auto'
  69. # 模型路径 (例如: d:\data\model\best.pt)
  70. MODEL_PATH = r'd:\data\best.pt'
  71. # 置信度阈值
  72. CONF_THRESHOLD = 0.5
  73. # 是否保存分割点 (True=保存多边形, False=只保存矩形框)
  74. SAVE_SEGMENTATION = True
  75. # --------------------------------
  76. # 检查输入目录是否存在,防止报错
  77. if not os.path.exists(IMAGE_DIR):
  78. print(f"错误: 图片目录不存在 -> {IMAGE_DIR}")
  79. print("请打开脚本修改 IMAGE_DIR 为正确的路径。")
  80. else:
  81. auto_label(IMAGE_DIR, OUTPUT_DIR, MODEL_PATH, CONF_THRESHOLD, SAVE_SEGMENTATION)