json_to_yolo.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. JSON标签到YOLO格式转换脚本
  5. 支持多种常见的JSON标注格式转换为YOLO格式
  6. 功能特性:
  7. - 支持LabelMe、COCO、YOLO等多种JSON格式
  8. - 矩形标注转换为YOLO边界框格式 (class_id x_center y_center width height)
  9. - 多边形标注保留所有点位信息 (class_id x1 y1 x2 y2 ... xn yn)
  10. - 自动归一化坐标到[0,1]范围
  11. - 支持自定义类别映射文件
  12. """
  13. import os
  14. import json
  15. import glob
  16. from pathlib import Path
  17. import argparse
  18. def convert_bbox_to_yolo(bbox, img_width, img_height, format_type="xywh"):
  19. """
  20. 将边界框坐标转换为YOLO格式(归一化的中心点坐标和宽高)
  21. Args:
  22. bbox: 边界框坐标
  23. img_width: 图片宽度
  24. img_height: 图片高度
  25. format_type: 输入格式类型 ("xywh", "xyxy", "coco")
  26. Returns:
  27. tuple: (center_x, center_y, width, height) 归一化坐标
  28. """
  29. if format_type == "xyxy":
  30. # 格式: [x_min, y_min, x_max, y_max]
  31. x_min, y_min, x_max, y_max = bbox
  32. width = x_max - x_min
  33. height = y_max - y_min
  34. center_x = x_min + width / 2
  35. center_y = y_min + height / 2
  36. elif format_type == "xywh":
  37. # 格式: [x, y, width, height] (左上角坐标)
  38. x, y, width, height = bbox
  39. center_x = x + width / 2
  40. center_y = y + height / 2
  41. elif format_type == "coco":
  42. # COCO格式: [x, y, width, height] (左上角坐标)
  43. x, y, width, height = bbox
  44. center_x = x + width / 2
  45. center_y = y + height / 2
  46. else:
  47. raise ValueError(f"不支持的格式类型: {format_type}")
  48. # 归一化坐标
  49. center_x_norm = center_x / img_width
  50. center_y_norm = center_y / img_height
  51. width_norm = width / img_width
  52. height_norm = height / img_height
  53. return center_x_norm, center_y_norm, width_norm, height_norm
  54. def convert_polygon_to_yolo(points, img_width, img_height):
  55. """
  56. 将多边形点位转换为YOLO格式(归一化坐标)
  57. Args:
  58. points: 多边形点位列表 [[x1, y1], [x2, y2], ...]
  59. img_width: 图片宽度
  60. img_height: 图片高度
  61. Returns:
  62. list: 归一化的点位坐标 [x1_norm, y1_norm, x2_norm, y2_norm, ...]
  63. """
  64. normalized_points = []
  65. for point in points:
  66. x, y = point
  67. # 归一化坐标
  68. x_norm = x / img_width
  69. y_norm = y / img_height
  70. normalized_points.extend([x_norm, y_norm])
  71. return normalized_points
  72. def parse_labelme_json(json_data):
  73. """
  74. 解析LabelMe格式的JSON文件
  75. Args:
  76. json_data: JSON数据
  77. Returns:
  78. list: 包含(class_name, bbox)的列表
  79. """
  80. annotations = []
  81. img_width = json_data.get('imageWidth', 0)
  82. img_height = json_data.get('imageHeight', 0)
  83. if img_width == 0 or img_height == 0:
  84. raise ValueError("JSON文件中缺少图片尺寸信息")
  85. for shape in json_data.get('shapes', []):
  86. label = shape.get('label', '')
  87. shape_type = shape.get('shape_type', 'rectangle')
  88. points = shape.get('points', [])
  89. if shape_type == 'rectangle' and len(points) == 2:
  90. # 矩形格式: [[x1, y1], [x2, y2]]
  91. x1, y1 = points[0]
  92. x2, y2 = points[1]
  93. # 确保坐标顺序正确
  94. x_min = min(x1, x2)
  95. y_min = min(y1, y2)
  96. x_max = max(x1, x2)
  97. y_max = max(y1, y2)
  98. bbox = [x_min, y_min, x_max, y_max]
  99. annotations.append((label, bbox, "xyxy", img_width, img_height))
  100. elif shape_type == 'polygon' and len(points) >= 3:
  101. # 多边形格式: 保留所有点位信息
  102. annotations.append((label, points, "polygon", img_width, img_height))
  103. return annotations
  104. def parse_coco_json(json_data):
  105. """
  106. 解析COCO格式的JSON文件
  107. Args:
  108. json_data: JSON数据
  109. Returns:
  110. dict: 按图片ID分组的标注信息
  111. """
  112. # 构建类别映射
  113. categories = {cat['id']: cat['name'] for cat in json_data.get('categories', [])}
  114. # 构建图片信息映射
  115. images = {img['id']: img for img in json_data.get('images', [])}
  116. # 按图片分组标注
  117. annotations_by_image = {}
  118. for ann in json_data.get('annotations', []):
  119. image_id = ann['image_id']
  120. category_id = ann['category_id']
  121. bbox = ann['bbox'] # COCO格式: [x, y, width, height]
  122. if image_id not in annotations_by_image:
  123. annotations_by_image[image_id] = []
  124. if image_id in images:
  125. img_info = images[image_id]
  126. img_width = img_info['width']
  127. img_height = img_info['height']
  128. class_name = categories.get(category_id, f'class_{category_id}')
  129. annotations_by_image[image_id].append((
  130. class_name, bbox, "coco", img_width, img_height, img_info['file_name']
  131. ))
  132. return annotations_by_image
  133. def parse_yolo_json(json_data):
  134. """
  135. 解析自定义YOLO JSON格式
  136. Args:
  137. json_data: JSON数据
  138. Returns:
  139. list: 包含(class_name, bbox)的列表
  140. """
  141. annotations = []
  142. img_width = json_data.get('image_width', json_data.get('width', 0))
  143. img_height = json_data.get('image_height', json_data.get('height', 0))
  144. if img_width == 0 or img_height == 0:
  145. raise ValueError("JSON文件中缺少图片尺寸信息")
  146. for obj in json_data.get('objects', json_data.get('annotations', [])):
  147. class_name = obj.get('class', obj.get('category', obj.get('label', '')))
  148. # 支持多种边界框格式
  149. if 'bbox' in obj:
  150. bbox = obj['bbox']
  151. bbox_format = obj.get('bbox_format', 'xywh')
  152. elif 'bounding_box' in obj:
  153. bbox = obj['bounding_box']
  154. bbox_format = obj.get('bbox_format', 'xywh')
  155. elif all(k in obj for k in ['x', 'y', 'width', 'height']):
  156. bbox = [obj['x'], obj['y'], obj['width'], obj['height']]
  157. bbox_format = 'xywh'
  158. elif all(k in obj for k in ['x_min', 'y_min', 'x_max', 'y_max']):
  159. bbox = [obj['x_min'], obj['y_min'], obj['x_max'], obj['y_max']]
  160. bbox_format = 'xyxy'
  161. else:
  162. print(f"警告: 无法解析对象的边界框格式: {obj}")
  163. continue
  164. annotations.append((class_name, bbox, bbox_format, img_width, img_height))
  165. return annotations
  166. def convert_json_to_yolo(json_file_path, output_dir, class_mapping=None, json_format="auto"):
  167. """
  168. 将JSON标注文件转换为YOLO格式
  169. Args:
  170. json_file_path: JSON文件路径
  171. output_dir: 输出目录
  172. class_mapping: 类别名称到ID的映射字典
  173. json_format: JSON格式类型 ("auto", "labelme", "coco", "yolo")
  174. """
  175. with open(json_file_path, 'r', encoding='utf-8') as f:
  176. json_data = json.load(f)
  177. # 自动检测JSON格式
  178. if json_format == "auto":
  179. if 'shapes' in json_data and 'imageWidth' in json_data:
  180. json_format = "labelme"
  181. elif 'categories' in json_data and 'annotations' in json_data and 'images' in json_data:
  182. json_format = "coco"
  183. else:
  184. json_format = "yolo"
  185. print(f"检测到JSON格式: {json_format}")
  186. # 解析JSON数据
  187. if json_format == "labelme":
  188. annotations = parse_labelme_json(json_data)
  189. # 为LabelMe格式生成单个txt文件
  190. base_name = Path(json_file_path).stem
  191. output_file = os.path.join(output_dir, f"{base_name}.txt")
  192. with open(output_file, 'w', encoding='utf-8') as f:
  193. for class_name, data, data_format, img_width, img_height in annotations:
  194. # 获取类别ID
  195. if class_mapping and class_name in class_mapping:
  196. class_id = class_mapping[class_name]
  197. else:
  198. class_id = 0 # 默认类别ID
  199. if data_format == "polygon":
  200. # 处理多边形点位
  201. normalized_points = convert_polygon_to_yolo(data, img_width, img_height)
  202. # 写入YOLO格式的多边形标注
  203. points_str = ' '.join([f"{coord:.6f}" for coord in normalized_points])
  204. f.write(f"{class_id} {points_str}\n")
  205. else:
  206. # 处理边界框
  207. center_x, center_y, width, height = convert_bbox_to_yolo(
  208. data, img_width, img_height, data_format
  209. )
  210. # 写入YOLO格式的边界框标注
  211. f.write(f"{class_id} {center_x:.6f} {center_y:.6f} {width:.6f} {height:.6f}\n")
  212. print(f"已生成: {output_file}")
  213. elif json_format == "coco":
  214. annotations_by_image = parse_coco_json(json_data)
  215. for image_id, annotations in annotations_by_image.items():
  216. if not annotations:
  217. continue
  218. # 使用第一个标注的文件名信息
  219. file_name = annotations[0][5] # file_name
  220. base_name = Path(file_name).stem
  221. output_file = os.path.join(output_dir, f"{base_name}.txt")
  222. with open(output_file, 'w', encoding='utf-8') as f:
  223. for class_name, bbox, bbox_format, img_width, img_height, _ in annotations:
  224. # 获取类别ID
  225. if class_mapping and class_name in class_mapping:
  226. class_id = class_mapping[class_name]
  227. else:
  228. class_id = 0 # 默认类别ID
  229. # 转换为YOLO格式
  230. center_x, center_y, width, height = convert_bbox_to_yolo(
  231. bbox, img_width, img_height, bbox_format
  232. )
  233. # 写入YOLO格式
  234. f.write(f"{class_id} {center_x:.6f} {center_y:.6f} {width:.6f} {height:.6f}\n")
  235. print(f"已生成: {output_file}")
  236. elif json_format == "yolo":
  237. annotations = parse_yolo_json(json_data)
  238. base_name = Path(json_file_path).stem
  239. output_file = os.path.join(output_dir, f"{base_name}.txt")
  240. with open(output_file, 'w', encoding='utf-8') as f:
  241. for class_name, bbox, bbox_format, img_width, img_height in annotations:
  242. # 获取类别ID
  243. if class_mapping and class_name in class_mapping:
  244. class_id = class_mapping[class_name]
  245. else:
  246. class_id = 0 # 默认类别ID
  247. # 转换为YOLO格式
  248. center_x, center_y, width, height = convert_bbox_to_yolo(
  249. bbox, img_width, img_height, bbox_format
  250. )
  251. # 写入YOLO格式
  252. f.write(f"{class_id} {center_x:.6f} {center_y:.6f} {width:.6f} {height:.6f}\n")
  253. print(f"已生成: {output_file}")
  254. def load_class_mapping(mapping_file):
  255. """
  256. 从文件加载类别映射
  257. Args:
  258. mapping_file: 映射文件路径 (支持txt和json格式)
  259. Returns:
  260. dict: 类别名称到ID的映射
  261. """
  262. if not os.path.exists(mapping_file):
  263. return None
  264. mapping = {}
  265. if mapping_file.endswith('.json'):
  266. with open(mapping_file, 'r', encoding='utf-8') as f:
  267. mapping = json.load(f)
  268. else:
  269. # txt格式兼容:
  270. # 1) "类名"(行号作为ID)
  271. # 2) "ID 类名" 或 "ID,类名"(显式ID与类名)
  272. # 3) "类名 ID"(显式ID在末尾)
  273. # 会自动忽略行首/行尾的空白与注释(# 开始的内容)
  274. with open(mapping_file, 'r', encoding='utf-8') as f:
  275. for i, raw in enumerate(f):
  276. line = raw.strip()
  277. if not line:
  278. continue
  279. # 去除行内注释
  280. if '#' in line:
  281. line = line.split('#', 1)[0].strip()
  282. if not line:
  283. continue
  284. cls_name = None
  285. cls_id = None
  286. # 尝试按逗号分隔(例如:"0,fire")
  287. if ',' in line:
  288. parts = [p.strip() for p in line.split(',') if p.strip()]
  289. if len(parts) == 2 and parts[0].isdigit():
  290. cls_id = int(parts[0])
  291. cls_name = parts[1]
  292. # 若未解析到,尝试按空白分隔(例如:"0 fire" 或 "fire 0" 或 "fire")
  293. if cls_name is None:
  294. tokens = [t for t in line.split() if t]
  295. if len(tokens) == 1:
  296. # 仅类名:按行号作为ID
  297. cls_name = tokens[0]
  298. cls_id = i
  299. elif len(tokens) >= 2:
  300. # 两段或以上:尝试识别前后是否为ID
  301. if tokens[0].isdigit():
  302. # "ID 类名(可能包含空格)"
  303. cls_id = int(tokens[0])
  304. cls_name = ' '.join(tokens[1:])
  305. elif tokens[-1].isdigit():
  306. # "类名(可能包含空格) ID"
  307. cls_id = int(tokens[-1])
  308. cls_name = ' '.join(tokens[:-1])
  309. else:
  310. # 都不是数字,则将整行视为类名,按行号作为ID
  311. cls_name = ' '.join(tokens)
  312. cls_id = i
  313. if cls_name:
  314. mapping[cls_name] = cls_id
  315. return mapping
  316. def main():
  317. parser = argparse.ArgumentParser(description='JSON标签到YOLO格式转换工具')
  318. parser.add_argument('input_path', help='输入JSON文件或包含JSON文件的目录')
  319. parser.add_argument('-o', '--output', default='./20251124/yolo_labels', help='输出目录 (默认: ./yolo_labels)')
  320. parser.add_argument('-c', '--classes', help='类别映射文件 (txt或json格式)')
  321. parser.add_argument('-f', '--format', choices=['auto', 'labelme', 'coco', 'yolo'],
  322. default='auto', help='JSON格式类型 (默认: auto)')
  323. parser.add_argument('--test', action='store_true', help='测试模式,仅显示解析结果不生成文件')
  324. args = parser.parse_args()
  325. # 创建输出目录
  326. output_dir = args.output
  327. if not args.test:
  328. os.makedirs(output_dir, exist_ok=True)
  329. # 加载类别映射
  330. class_mapping = None
  331. if args.classes:
  332. class_mapping = load_class_mapping(args.classes)
  333. if class_mapping:
  334. print(f"已加载类别映射: {class_mapping}")
  335. else:
  336. print(f"警告: 无法加载类别映射文件: {args.classes}")
  337. # 处理输入路径
  338. input_path = args.input_path
  339. if os.path.isfile(input_path):
  340. # 单个文件
  341. json_files = [input_path]
  342. elif os.path.isdir(input_path):
  343. # 目录中的所有JSON文件
  344. json_files = glob.glob(os.path.join(input_path, "*.json"))
  345. else:
  346. print(f"错误: 输入路径不存在: {input_path}")
  347. return
  348. if not json_files:
  349. print(f"错误: 在 {input_path} 中没有找到JSON文件")
  350. return
  351. print(f"找到 {len(json_files)} 个JSON文件")
  352. # 转换文件
  353. success_count = 0
  354. error_count = 0
  355. for json_file in json_files:
  356. try:
  357. print(f"\n处理文件: {json_file}")
  358. if args.test:
  359. # 测试模式:仅解析和显示信息
  360. with open(json_file, 'r', encoding='utf-8') as f:
  361. json_data = json.load(f)
  362. print(f" JSON键: {list(json_data.keys())}")
  363. if 'shapes' in json_data:
  364. print(f" LabelMe格式,包含 {len(json_data['shapes'])} 个标注")
  365. elif 'annotations' in json_data:
  366. print(f" COCO格式,包含 {len(json_data['annotations'])} 个标注")
  367. else:
  368. print(f" 自定义格式")
  369. else:
  370. convert_json_to_yolo(json_file, output_dir, class_mapping, args.format)
  371. success_count += 1
  372. except Exception as e:
  373. print(f" 错误: {e}")
  374. error_count += 1
  375. print(f"\n转换完成:")
  376. print(f" 成功: {success_count} 个文件")
  377. print(f" 失败: {error_count} 个文件")
  378. if not args.test and success_count > 0:
  379. print(f" 输出目录: {output_dir}")
  380. if __name__ == "__main__":
  381. # 如果没有命令行参数,使用交互模式
  382. import sys
  383. if len(sys.argv) == 1:
  384. print("JSON标签到YOLO格式转换工具")
  385. print("=" * 50)
  386. # 交互式输入
  387. input_path = input("请输入JSON文件或目录路径: ").strip()
  388. if not input_path:
  389. print("错误: 必须提供输入路径")
  390. sys.exit(1)
  391. output_dir = input("请输入输出目录 (默认: ./yolo_labels): ").strip()
  392. if not output_dir:
  393. output_dir = "./yolo_labels"
  394. classes_file = input("请输入类别映射文件路径 (可选): ").strip()
  395. json_format = input("请输入JSON格式 (auto/labelme/coco/yolo, 默认: auto): ").strip()
  396. if not json_format:
  397. json_format = "auto"
  398. test_mode = input("是否启用测试模式?(y/N): ").strip().lower() == 'y'
  399. # 模拟命令行参数
  400. sys.argv = ['json_to_yolo.py', input_path, '-o', output_dir, '-f', json_format]
  401. if classes_file:
  402. sys.argv.extend(['-c', classes_file])
  403. if test_mode:
  404. sys.argv.append('--test')
  405. main()