a_4point.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. import os
  2. import json
  3. import shutil
  4. import numpy as np
  5. import cv2
  6. from scipy.optimize import leastsq
  7. from tqdm import tqdm
  8. def calc_R(x, y, xc, yc):
  9. return np.sqrt((x - xc) ** 2 + (y - yc) ** 2)
  10. def f(c, x, y):
  11. Ri = calc_R(x, y, *c)
  12. return Ri - Ri.mean()
  13. def fit_arc_and_create_mask(points, shape, point_step_deg=0.5, radius=2):
  14. x_data = points[:, 0]
  15. y_data = points[:, 1]
  16. center_estimate = np.mean(x_data), np.mean(y_data)
  17. try:
  18. center, _ = leastsq(f, center_estimate, args=(x_data, y_data))
  19. xc, yc = center
  20. Ri = calc_R(x_data, y_data, xc, yc)
  21. R = Ri.mean()
  22. h, w = shape
  23. if R <= 0 or not np.isfinite(R) or not np.isfinite(xc) or not np.isfinite(yc):
  24. raise ValueError("拟合圆非法")
  25. # 生成完整圆的点
  26. circle_angles = np.linspace(0, 2 * np.pi, int(360 / point_step_deg), endpoint=False)
  27. full_circle_points = np.stack([
  28. xc + R * np.cos(circle_angles),
  29. yc + R * np.sin(circle_angles)
  30. ], axis=-1).astype(np.float32)
  31. # 只保留图像内部的点
  32. in_bounds_mask = (
  33. (full_circle_points[:, 0] >= 0) & (full_circle_points[:, 0] < w) &
  34. (full_circle_points[:, 1] >= 0) & (full_circle_points[:, 1] < h)
  35. )
  36. clipped_points = full_circle_points[in_bounds_mask]
  37. if len(clipped_points) < 3:
  38. print(f"拟合失败使用的三个点为:\n{points}")
  39. raise ValueError("拟合圆点全部在图像外")
  40. return None, clipped_points.tolist()
  41. except Exception as e:
  42. print(f"⚠️ 圆拟合失败:{e}")
  43. return None, []
  44. def build_shapes_from_points(points, label="circle"):
  45. points_np = np.array(points)
  46. if points_np.shape[0] < 3:
  47. return []
  48. xs = points_np[:, 0]
  49. ys = points_np[:, 1]
  50. xmin, xmax = xs.min(), xs.max()
  51. ymin, ymax = ys.min(), ys.max()
  52. eps = 1e-3
  53. # 边界点:每类只能找一个
  54. def find_one(points, cond):
  55. for pt in points:
  56. if cond(pt):
  57. return pt
  58. return None
  59. pt_xmin = find_one(points, lambda pt: abs(pt[0] - xmin) < eps)
  60. pt_xmax = find_one(points, lambda pt: abs(pt[0] - xmax) < eps)
  61. pt_ymin = find_one(points, lambda pt: abs(pt[1] - ymin) < eps)
  62. pt_ymax = find_one(points, lambda pt: abs(pt[1] - ymax) < eps)
  63. boundary_points = []
  64. for pt in [pt_xmin, pt_xmax, pt_ymin, pt_ymax]:
  65. if pt is not None and pt not in boundary_points:
  66. boundary_points.append(pt)
  67. if len(boundary_points) != 4:
  68. print("⚠️ 边界点不足 4 个,跳过该 shape")
  69. return []
  70. shape = {
  71. "label": label,
  72. "points": boundary_points,
  73. "shape_type": "polygon",
  74. "flags": {},
  75. "xmin": int(xmin),
  76. "ymin": int(ymin),
  77. "xmax": int(xmax),
  78. "ymax": int(ymax)
  79. }
  80. return [shape]
  81. def process_folder_labelme(input_dir, output_dir, point_step_deg=0.5, radius=2):
  82. os.makedirs(output_dir, exist_ok=True)
  83. for file_name in tqdm(os.listdir(input_dir)):
  84. if not file_name.endswith(".json"):
  85. continue
  86. json_path = os.path.join(input_dir, file_name)
  87. image_path = json_path.replace(".json", ".jpg")
  88. if not os.path.exists(image_path):
  89. print(f"图像不存在,跳过:{image_path}")
  90. continue
  91. with open(json_path, "r") as f:
  92. label_data = json.load(f)
  93. arc_points_raw = [item['points'][0] for item in label_data['shapes']
  94. if item.get('label') == 'circle' and len(item.get('points', [])) == 1]
  95. if len(arc_points_raw) < 3:
  96. print(f"{file_name} 中 circle 点数不足 3,跳过")
  97. continue
  98. image = cv2.imread(image_path)
  99. h, w = image.shape[:2]
  100. arc_shapes = []
  101. num_groups = len(arc_points_raw) // 3
  102. for i in range(num_groups):
  103. group_points = arc_points_raw[i*3:(i+1)*3]
  104. if len(group_points) < 3:
  105. print(f"{file_name} 第 {i+1} 组点数不足 3,跳过")
  106. continue
  107. points = np.array(group_points)
  108. try:
  109. _, arc_pts = fit_arc_and_create_mask(points, (h, w), point_step_deg, radius)
  110. shapes = build_shapes_from_points(arc_pts, label="circle")
  111. arc_shapes.extend(shapes)
  112. except Exception as e:
  113. print(f"{file_name} 第 {i+1} 组拟合失败,跳过。错误:{e}")
  114. continue
  115. if len(arc_shapes) == 0:
  116. print(f"{file_name} 没有成功拟合的 circle 区域,跳过")
  117. continue
  118. output_json = {
  119. "version": "5.0.1",
  120. "flags": {},
  121. "shapes": arc_shapes,
  122. "imagePath": os.path.basename(image_path),
  123. "imageHeight": h,
  124. "imageWidth": w
  125. }
  126. base_name = os.path.splitext(os.path.basename(image_path))[0]
  127. out_json_path = os.path.join(output_dir, base_name + ".json")
  128. out_img_path = os.path.join(output_dir, base_name + ".jpg")
  129. with open(out_json_path, "w") as f:
  130. json.dump(output_json, f, indent=4)
  131. shutil.copy(image_path, out_img_path)
  132. print(f"\n✅ 所有 circle 区域已保存为 labelme 格式(图像 + json)到目录:{output_dir}")
  133. # ===== 修改为你自己的输入输出路径 =====
  134. input_dir = r"\\192.168.50.222\share\zyh\data\rgb_453_source_json\dataset_for_guanban"
  135. output_dir = r"\\192.168.50.222\share\zyh\data\rgb_453_source_json\4point"
  136. point_step_deg = 0.5
  137. draw_radius = 2
  138. if __name__ == "__main__":
  139. process_folder_labelme(input_dir, output_dir, point_step_deg, draw_radius)