mask.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. import os
  2. import json
  3. import shutil
  4. import numpy as np
  5. import cv2
  6. from scipy.optimize import leastsq
  7. from tqdm import tqdm
  8. def calc_R(x, y, xc, yc):
  9. return np.sqrt((x - xc) ** 2 + (y - yc) ** 2)
  10. def f(c, x, y):
  11. Ri = calc_R(x, y, *c)
  12. return Ri - Ri.mean()
  13. def fit_arc_and_create_mask(points, shape, point_step_deg=0.5, radius=2):
  14. x_data = points[:, 0]
  15. y_data = points[:, 1]
  16. center_estimate = np.mean(x_data), np.mean(y_data)
  17. center, _ = leastsq(f, center_estimate, args=(x_data, y_data))
  18. xc, yc = center
  19. Ri = calc_R(x_data, y_data, xc, yc)
  20. R = Ri.mean()
  21. # 极角
  22. angles = np.arctan2(points[:, 1] - yc, points[:, 0] - xc)
  23. start_angle = np.min(angles)
  24. end_angle = np.max(angles)
  25. # 确保角度顺序正确
  26. if end_angle - start_angle > np.pi:
  27. start_angle, end_angle = end_angle, start_angle + 2 * np.pi
  28. arc_angles = np.arange(start_angle, end_angle, np.deg2rad(point_step_deg))
  29. arc_points = np.stack([
  30. xc + R * np.cos(arc_angles),
  31. yc + R * np.sin(arc_angles)
  32. ], axis=-1).astype(np.int32)
  33. # 创建 mask 并画出拟合的弧线(扩展为像素块)
  34. mask = np.zeros(shape, dtype=np.uint8)
  35. for pt in arc_points:
  36. cv2.circle(mask, tuple(pt), radius=radius, color=255, thickness=-1)
  37. return mask, arc_points.tolist()
  38. def build_shapes_from_mask(mask, label="arc"):
  39. contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  40. shapes = []
  41. for contour in contours:
  42. contour = contour.squeeze()
  43. if len(contour.shape) != 2 or contour.shape[0] < 3:
  44. continue # 不是合法 polygon
  45. points = contour.tolist()
  46. xs = [p[0] for p in points]
  47. ys = [p[1] for p in points]
  48. shape = {
  49. "label": label,
  50. "points": points,
  51. "shape_type": "polygon",
  52. "flags": {},
  53. "xmin": int(min(xs)),
  54. "ymin": int(min(ys)),
  55. "xmax": int(max(xs)),
  56. "ymax": int(max(ys))
  57. }
  58. shapes.append(shape)
  59. return shapes
  60. def process_folder_labelme(input_dir, output_dir, point_step_deg=0.5, radius=2):
  61. os.makedirs(output_dir, exist_ok=True)
  62. for file_name in tqdm(os.listdir(input_dir)):
  63. if not file_name.endswith(".json"):
  64. continue
  65. json_path = os.path.join(input_dir, file_name)
  66. image_path = json_path.replace(".json", ".jpg")
  67. if not os.path.exists(image_path):
  68. print(f"图像不存在,跳过:{image_path}")
  69. continue
  70. with open(json_path, "r") as f:
  71. label_data = json.load(f)
  72. points_list = [item['points'][0] for item in label_data['shapes'] if item['label'] == 'arc']
  73. if len(points_list) != 3:
  74. print(f"{file_name} 中 arc 点数不足 3,跳过")
  75. continue
  76. image = cv2.imread(image_path)
  77. h, w = image.shape[:2]
  78. points = np.array(points_list)
  79. mask, arc_points = fit_arc_and_create_mask(points, (h, w), point_step_deg, radius)
  80. shapes = build_shapes_from_mask(mask, label="arc")
  81. if len(shapes) == 0:
  82. print(f"{file_name} 拟合失败,跳过")
  83. continue
  84. output_json = {
  85. "version": "5.0.1",
  86. "flags": {},
  87. "shapes": shapes,
  88. "imagePath": os.path.basename(image_path),
  89. "imageHeight": h,
  90. "imageWidth": w
  91. }
  92. base_name = os.path.splitext(os.path.basename(image_path))[0]
  93. out_json_path = os.path.join(output_dir, base_name + ".json")
  94. out_img_path = os.path.join(output_dir, base_name + ".jpg")
  95. with open(out_json_path, "w") as f:
  96. json.dump(output_json, f, indent=4)
  97. shutil.copy(image_path, out_img_path)
  98. print(f"\n✅ 所有 arc 区域已保存为 labelme 格式(图像 + json)到目录:{output_dir}")
  99. # ===== 修改为你自己的输入输出路径 =====
  100. input_dir = r"G:\python_ws_g\data\arc"
  101. output_dir = r"G:\python_ws_g\data\arcout"
  102. point_step_deg = 0.2
  103. draw_radius = 3
  104. if __name__ == "__main__":
  105. process_folder_labelme(input_dir, output_dir, point_step_deg, draw_radius)