csv_read.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. import os
  2. import csv
  3. import json
  4. import shutil
  5. import math
  6. from typing import List, Union, Dict
  7. # === 文件夹配置 ===
  8. csv_folder = r"/data/share/zyh/master_dataset/pokou/merge/251121_251115/csv" # CSV 文件夹
  9. json_folder_json = r"/data/share/zyh/master_dataset/pokou/merge/251121_251115/json" # JSON 文件夹
  10. json_folder_img = r"/data/share/zyh/master_dataset/pokou/merge/251121_251115/image" # 图片文件夹
  11. output_folder = r"/data/share/zyh/master_dataset/dataset_net/pokou_251115_251121/to_dataset" # 输出文件夹
  12. os.makedirs(output_folder, exist_ok=True)
  13. # ==============================================================
  14. # 计算圆弧端点
  15. # ==============================================================
  16. def compute_arc_ends(points: List[List[float]]) -> List[List[float]]:
  17. if len(points) != 3:
  18. return [[0, 0], [0, 0]]
  19. p1, p2, p3 = points
  20. x1, y1 = p1
  21. x2, y2 = p2
  22. x3, y3 = p3
  23. A = 2 * (x2 - x1)
  24. B = 2 * (y2 - y1)
  25. C = x2**2 + y2**2 - x1**2 - y1**2
  26. D = 2 * (x3 - x2)
  27. E = 2 * (y3 - y2)
  28. F = x3**2 + y3**2 - x2**2 - y2**2
  29. denom = A * E - B * D
  30. if denom == 0:
  31. return [p1, p3]
  32. cx = (C * E - F * B) / denom
  33. cy = (A * F - D * C) / denom
  34. angles = [math.atan2(y - cy, x - cx) for x, y in points]
  35. def angle_diff(a1, a2):
  36. diff = (a2 - a1) % (2 * math.pi)
  37. if diff > math.pi:
  38. diff = 2 * math.pi - diff
  39. return diff
  40. pairs = [(0, 1), (0, 2), (1, 2)]
  41. max_diff = -1
  42. end_pair = (0, 1)
  43. for i, j in pairs:
  44. diff = angle_diff(angles[i], angles[j])
  45. if diff > max_diff:
  46. max_diff = diff
  47. end_pair = (i, j)
  48. return [points[end_pair[0]], points[end_pair[1]]]
  49. # ==============================================================
  50. # 根据点匹配到最近椭圆
  51. # ==============================================================
  52. def match_point_to_ellipse(point: List[float], ellipses: List[Dict]) -> int:
  53. x, y = point
  54. min_dist = float("inf")
  55. match_idx = -1
  56. for i, e in enumerate(ellipses):
  57. cx, cy = e["cx"], e["cy"]
  58. dist = math.hypot(x - cx, y - cy)
  59. if dist < min_dist:
  60. min_dist = dist
  61. match_idx = i
  62. return match_idx
  63. # ==============================================================
  64. # 从 CSV 读取椭圆参数映射
  65. # ==============================================================
  66. csv_ellipse_map = {} # filename -> list of ellipse params
  67. for csv_file in os.listdir(csv_folder):
  68. if not csv_file.endswith(".csv"):
  69. continue
  70. csv_path = os.path.join(csv_folder, csv_file)
  71. with open(csv_path, "r", encoding="utf-8-sig") as f:
  72. reader = csv.DictReader(f)
  73. for row in reader:
  74. filename = row["filename"].strip()
  75. shape_str = row["region_shape_attributes"]
  76. try:
  77. shape_data = json.loads(shape_str)
  78. except json.JSONDecodeError:
  79. shape_data = json.loads(shape_str.replace('""', '"'))
  80. if filename not in csv_ellipse_map:
  81. csv_ellipse_map[filename] = []
  82. csv_ellipse_map[filename].append(shape_data)
  83. # ==============================================================
  84. # 遍历 JSON 文件
  85. # ==============================================================
  86. for json_file in os.listdir(json_folder_json):
  87. if not json_file.endswith(".json"):
  88. continue
  89. json_path = os.path.join(json_folder_json, json_file)
  90. filename = json_file.replace(".json", ".jpg") # 图片的名字
  91. img_path = os.path.join(json_folder_img, filename) # 图片从独立文件夹读取
  92. # 图片存在性检查
  93. if not os.path.exists(img_path):
  94. print(f"[WARN] Image not found for: {filename}")
  95. continue
  96. # CSV 中必须有匹配的记录
  97. if filename not in csv_ellipse_map:
  98. print(f"[WARN] No CSV ellipse for: {filename}")
  99. continue
  100. # 读取 JSON
  101. with open(json_path, "r", encoding="utf-8") as jf:
  102. data = json.load(jf)
  103. if "shapes" not in data:
  104. data["shapes"] = []
  105. # 获取 JSON 中的单点 arc 标注
  106. arc_points = [
  107. s["points"][0]
  108. for s in data["shapes"]
  109. if s.get("label") == "arc" and "points" in s and len(s["points"]) == 1
  110. ]
  111. # 从 CSV 获取椭圆信息
  112. ellipses = csv_ellipse_map[filename]
  113. ellipse_point_map = {i: [] for i in range(len(ellipses))}
  114. # 将 arc 点匹配到最近的椭圆
  115. for pt in arc_points:
  116. idx = match_point_to_ellipse(pt, ellipses)
  117. ellipse_point_map[idx].append(pt)
  118. # 生成新的 arc shapes
  119. new_arc_shapes = []
  120. for idx, pts in ellipse_point_map.items():
  121. if len(pts) != 3:
  122. print(f"[WARN] {filename} ellipse {idx} has {len(pts)} points (expected 3)")
  123. ends = [[0, 0], [0, 0]]
  124. else:
  125. ends = compute_arc_ends(pts)
  126. e = ellipses[idx]
  127. arc_shape = {
  128. "label": "arc",
  129. "points": pts,
  130. "params": [
  131. e.get("cx", 0),
  132. e.get("cy", 0),
  133. e.get("rx", 0),
  134. e.get("ry", 0),
  135. e.get("theta", 0),
  136. ],
  137. "ends": ends,
  138. "group_id": None,
  139. "description": "",
  140. "difficult": False,
  141. "shape_type": "arc",
  142. "flags": {},
  143. "attributes": {},
  144. }
  145. new_arc_shapes.append(arc_shape)
  146. # 删除旧 arc,添加新 arc
  147. remaining = [s for s in data["shapes"] if s.get("label") != "arc"]
  148. data["shapes"] = remaining + new_arc_shapes
  149. # 输出 JSON
  150. output_json = os.path.join(output_folder, json_file)
  151. with open(output_json, "w", encoding="utf-8") as jf:
  152. json.dump(data, jf, ensure_ascii=False, indent=2)
  153. # 复制图片
  154. shutil.copy2(img_path, os.path.join(output_folder, filename))
  155. print(f"[OK] Saved merged data for: {filename}")
  156. print("\nAll done! Output in:", output_folder)