4manycircle.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. import os
  2. import json
  3. import shutil
  4. import numpy as np
  5. import cv2
  6. from scipy.optimize import leastsq
  7. from tqdm import tqdm
  8. def calc_R(x, y, xc, yc):
  9. return np.sqrt((x - xc) ** 2 + (y - yc) ** 2)
  10. def f(c, x, y):
  11. Ri = calc_R(x, y, *c)
  12. return Ri - Ri.mean()
  13. def fit_arc_and_create_mask(points, shape, point_step_deg=0.5, radius=2):
  14. x_data = points[:, 0]
  15. y_data = points[:, 1]
  16. center_estimate = np.mean(x_data), np.mean(y_data)
  17. try:
  18. center, _ = leastsq(f, center_estimate, args=(x_data, y_data))
  19. xc, yc = center
  20. Ri = calc_R(x_data, y_data, xc, yc)
  21. R = Ri.mean()
  22. h, w = shape
  23. if R <= 0 or not np.isfinite(R) or not np.isfinite(xc) or not np.isfinite(yc):
  24. raise ValueError("拟合圆非法")
  25. # 生成完整圆的点
  26. circle_angles = np.linspace(0, 2 * np.pi, int(360 / point_step_deg), endpoint=False)
  27. full_circle_points = np.stack([
  28. xc + R * np.cos(circle_angles),
  29. yc + R * np.sin(circle_angles)
  30. ], axis=-1).astype(np.float32)
  31. # 只保留图像内部的点
  32. in_bounds_mask = (
  33. (full_circle_points[:, 0] >= 0) & (full_circle_points[:, 0] < w) &
  34. (full_circle_points[:, 1] >= 0) & (full_circle_points[:, 1] < h)
  35. )
  36. clipped_points = full_circle_points[in_bounds_mask]
  37. if len(clipped_points) < 3:
  38. print(f"拟合失败使用的三个点为:\n{points}")
  39. raise ValueError("拟合圆点全部在图像外")
  40. return None, clipped_points.tolist()
  41. except Exception as e:
  42. print(f"⚠️ 圆拟合失败:{e}")
  43. return None, []
  44. def build_shapes_from_points(points, label="arc"):
  45. points_np = np.array(points)
  46. if points_np.shape[0] < 3:
  47. return []
  48. xs = points_np[:, 0]
  49. ys = points_np[:, 1]
  50. shape = {
  51. "label": label,
  52. "points": points,
  53. "shape_type": "polygon",
  54. "flags": {},
  55. "xmin": int(xs.min()),
  56. "ymin": int(ys.min()),
  57. "xmax": int(xs.max()),
  58. "ymax": int(ys.max())
  59. }
  60. return [shape]
  61. def process_folder_labelme(input_dir, output_dir, point_step_deg=0.5, radius=2):
  62. os.makedirs(output_dir, exist_ok=True)
  63. for file_name in tqdm(os.listdir(input_dir)):
  64. if not file_name.endswith(".json"):
  65. continue
  66. json_path = os.path.join(input_dir, file_name)
  67. image_path = json_path.replace(".json", ".jpg")
  68. if not os.path.exists(image_path):
  69. print(f"图像不存在,跳过:{image_path}")
  70. continue
  71. with open(json_path, "r") as f:
  72. label_data = json.load(f)
  73. arc_points_raw = [item['points'][0] for item in label_data['shapes']
  74. if item.get('label') == 'arc' and len(item.get('points', [])) == 1]
  75. if len(arc_points_raw) < 3:
  76. print(f"{file_name} 中 arc 点数不足 3,跳过")
  77. continue
  78. image = cv2.imread(image_path)
  79. h, w = image.shape[:2]
  80. arc_shapes = []
  81. num_groups = len(arc_points_raw) // 3
  82. for i in range(num_groups):
  83. group_points = arc_points_raw[i*3:(i+1)*3]
  84. if len(group_points) < 3:
  85. print(f"{file_name} 第 {i+1} 组点数不足 3,跳过")
  86. continue
  87. points = np.array(group_points)
  88. try:
  89. _, arc_pts = fit_arc_and_create_mask(points, (h, w), point_step_deg, radius)
  90. shapes = build_shapes_from_points(arc_pts, label="arc")
  91. arc_shapes.extend(shapes)
  92. except Exception as e:
  93. print(f"{file_name} 第 {i+1} 组拟合失败,跳过。错误:{e}")
  94. continue
  95. if len(arc_shapes) == 0:
  96. print(f"{file_name} 没有成功拟合的 arc 区域,跳过")
  97. continue
  98. output_json = {
  99. "version": "5.0.1",
  100. "flags": {},
  101. "shapes": arc_shapes,
  102. "imagePath": os.path.basename(image_path),
  103. "imageHeight": h,
  104. "imageWidth": w
  105. }
  106. base_name = os.path.splitext(os.path.basename(image_path))[0]
  107. out_json_path = os.path.join(output_dir, base_name + ".json")
  108. out_img_path = os.path.join(output_dir, base_name + ".jpg")
  109. with open(out_json_path, "w") as f:
  110. json.dump(output_json, f, indent=4)
  111. shutil.copy(image_path, out_img_path)
  112. print(f"\n✅ 所有 arc 区域已保存为 labelme 格式(图像 + json)到目录:{output_dir}")
  113. # ===== 修改为你自己的输入输出路径 =====
  114. input_dir = r"G:\python_ws_g\data\test\jsonjpg_weizhuang12"
  115. output_dir = r"G:\python_ws_g\data\test\circle"
  116. point_step_deg = 0.5
  117. draw_radius = 2
  118. if __name__ == "__main__":
  119. process_folder_labelme(input_dir, output_dir, point_step_deg, draw_radius)