# 3weizhuang.py (3.0 KB)
  1. import os
  2. import json
  3. import shutil
  4. import numpy as np
  5. from tqdm import tqdm
  6. def random_12_points(h, w, num_points=12, margin=20):
  7. xs = np.random.uniform(margin, w - margin, num_points)
  8. ys = np.random.uniform(margin, h - margin, num_points)
  9. points = np.stack([xs, ys], axis=-1).astype(np.float32)
  10. return points.tolist()
  11. def build_point_shape(pt, label="arc"):
  12. return {
  13. "label": label,
  14. "points": [pt],
  15. "shape_type": "point",
  16. "flags": {},
  17. "attributes": {},
  18. "group_id": None,
  19. "description": "",
  20. "difficult": False,
  21. "kie_linking": []
  22. }
  23. def process_json_folder(input_dir, output_dir, num_points=12, margin=20):
  24. os.makedirs(output_dir, exist_ok=True)
  25. for file_name in tqdm(os.listdir(input_dir)):
  26. if not file_name.endswith(".json"):
  27. continue
  28. json_path = os.path.join(input_dir, file_name)
  29. with open(json_path, "r", encoding="utf-8") as f:
  30. data = json.load(f)
  31. arc_points_raw = [item['points'][0] for item in data['shapes']
  32. if item.get('label') == 'arc' and item.get('shape_type') == 'point']
  33. if len(arc_points_raw) < 3:
  34. print(f"{file_name} 中 arc 点数不足 3,跳过")
  35. continue
  36. num_groups = len(arc_points_raw) // 3
  37. new_shapes = []
  38. h = data.get("imageHeight", 2000)
  39. w = data.get("imageWidth", 2000)
  40. for i in range(num_groups):
  41. # 这里不使用拟合,直接随机生成12个点
  42. try:
  43. circle_pts = random_12_points(h, w, num_points=num_points, margin=margin)
  44. for pt in circle_pts:
  45. shape = build_point_shape(pt, label="arc")
  46. new_shapes.append(shape)
  47. except Exception as e:
  48. print(f"{file_name} 第 {i+1} 组生成失败,跳过。错误:{e}")
  49. continue
  50. if not new_shapes:
  51. continue
  52. new_json = {
  53. "version": data.get("version", "5.0.1"),
  54. "flags": {},
  55. "shapes": new_shapes,
  56. "imagePath": data["imagePath"],
  57. "imageHeight": h,
  58. "imageWidth": w,
  59. "imageData": None
  60. }
  61. base_name = os.path.splitext(file_name)[0]
  62. out_json_path = os.path.join(output_dir, base_name + ".json")
  63. out_img_path = os.path.join(output_dir, data["imagePath"])
  64. with open(out_json_path, "w", encoding="utf-8") as f:
  65. json.dump(new_json, f, indent=4)
  66. old_img_path = os.path.join(input_dir, data["imagePath"])
  67. if os.path.exists(old_img_path):
  68. shutil.copy(old_img_path, out_img_path)
  69. print(f"\n✅ 已处理完成,输出保存至:{output_dir}")
  70. # ========= 路径设置 =========
  71. input_dir = r"G:\python_ws_g\data\test\jsonjpg"
  72. output_dir = r"G:\python_ws_g\data\test\jsonjpg_weizhuang12"
  73. if __name__ == "__main__":
  74. process_json_folder(input_dir, output_dir, num_points=12, margin=20)