# csv_show.py: draw the "arc" and "line" annotation shapes stored in JSON files
# onto their matching .jpg images and save the results for visual inspection.
  1. import os
  2. import json
  3. import cv2
  4. import numpy as np
  5. import math
  6. # === 配置 ===
  7. input_folder = r"/data/share/zyh/master_dataset/dataset_net/pokou_251115_251121/resize" # 输入 JSON 和图片
  8. output_folder = os.path.join(os.path.dirname(input_folder), "show")
  9. os.makedirs(output_folder, exist_ok=True)
  10. # === 工具函数:生成椭圆轨迹点 ===
  11. def ellipse_points(cx, cy, rx, ry, theta, num=100):
  12. """
  13. 返回椭圆轨迹上的 num 个点
  14. theta 单位为弧度
  15. """
  16. t = np.linspace(0, 2 * np.pi, num)
  17. x = rx * np.cos(t)
  18. y = ry * np.sin(t)
  19. # 旋转
  20. xr = x * np.cos(theta) - y * np.sin(theta)
  21. yr = x * np.sin(theta) + y * np.cos(theta)
  22. # 平移
  23. xr += cx
  24. yr += cy
  25. return np.stack([xr, yr], axis=1).astype(np.int32)
  26. # === 遍历输出文件夹 ===
  27. for file in os.listdir(input_folder):
  28. if not file.endswith(".json"):
  29. continue
  30. json_path = os.path.join(input_folder, file)
  31. img_name = file.replace(".json", ".jpg")
  32. img_path = os.path.join(input_folder, img_name)
  33. if not os.path.exists(img_path):
  34. print(f"?? Image not found: {img_name}")
  35. continue
  36. img = cv2.imread(img_path)
  37. if img is None:
  38. continue
  39. with open(json_path, "r", encoding="utf-8") as jf:
  40. data = json.load(jf)
  41. # === 绘制 shapes ===
  42. for shape in data.get("shapes", []):
  43. if shape.get("label") == "arc":
  44. pts = np.array(shape.get("points", []), dtype=np.int32)
  45. ends = np.array(shape.get("ends", []), dtype=np.int32)
  46. params = shape.get("params", [0, 0, 0, 0, 0])
  47. cx, cy, rx, ry, theta = params
  48. # 绘制三点
  49. for p in pts:
  50. cv2.circle(img, (int(p[0]), int(p[1])), 5, (0, 0, 255), -1) # 红色
  51. # 绘制端点
  52. for p in ends:
  53. cv2.circle(img, (int(p[0]), int(p[1])), 7, (0, 255, 0), -1) # 绿色
  54. # 绘制椭圆轨迹
  55. ep = ellipse_points(cx, cy, rx, ry, theta)
  56. for i in range(len(ep) - 1):
  57. cv2.line(img, tuple(ep[i]), tuple(ep[i + 1]), (255, 0, 0), 2) # 蓝色轨迹
  58. elif shape.get("shape_type") == "line":
  59. pts = shape.get("points", [])
  60. if len(pts) >= 2:
  61. cv2.line(img, tuple(map(int, pts[0])), tuple(map(int, pts[1])), (0, 255, 255), 2) # 黄色线
  62. # === 保存结果 ===
  63. save_path = os.path.join(output_folder, img_name)
  64. cv2.imwrite(save_path, img)
  65. print(f"Saved visualization for {img_name} -> {save_path}")
  66. print("\nAll done! Visualizations saved in:", output_folder)