te.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. import numpy as np
  2. import cv2
  3. import torch
  4. def arc_to_mask(xc, yc, a, b, theta, phi1, phi2, H, W, line_width=1):
  5. """
  6. Generate a binary mask of an elliptical arc.
  7. Args:
  8. xc, yc (float): 椭圆中心
  9. a, b (float): 长半轴、短半轴 (a >= b)
  10. theta (float): 椭圆旋转角度(**弧度**,逆时针,相对于 x 轴)
  11. phi1, phi2 (float): 起始和终止参数角(**弧度**,在 [0, 2π) 内)
  12. H, W (int): 输出 mask 的高度和宽度
  13. line_width (int): 弧线宽度(像素)
  14. Returns:
  15. mask (Tensor): [H, W], dtype=torch.uint8, 0/255
  16. """
  17. # 确保 phi1 -> phi2 是正向(可处理跨 2π 的情况)
  18. if phi2 < phi1:
  19. phi2 += 2 * np.pi
  20. # 生成参数角(足够密集,避免断线)
  21. num_points = max(int(200 * abs(phi2 - phi1) / (2 * np.pi)), 10)
  22. phi = np.linspace(phi1, phi2, num_points)
  23. # 椭圆参数方程(先在未旋转坐标系下计算)
  24. x_local = a * np.cos(phi)
  25. y_local = b * np.sin(phi)
  26. # 应用旋转和平移
  27. cos_t = np.cos(theta)
  28. sin_t = np.sin(theta)
  29. x_rot = x_local * cos_t - y_local * sin_t + xc
  30. y_rot = x_local * cos_t + y_local * sin_t + yc
  31. # 转为整数坐标(OpenCV 需要 int32)
  32. points = np.stack([x_rot, y_rot], axis=1).astype(np.int32)
  33. # 创建空白图像
  34. img = np.zeros((H, W), dtype=np.uint8)
  35. # 绘制折线(antialias=False 更适合 mask)
  36. cv2.polylines(img, [points], isClosed=False, color=255, thickness=line_width, lineType=cv2.LINE_AA)
  37. return torch.from_numpy(img).byte() # [H, W], values: 0 or 255
  38. # 椭圆参数
  39. xc, yc = 100.0, 100.0
  40. a, b = 80.0, 40.0
  41. theta = np.radians(30) # 30度 → 弧度
  42. phi1 = np.radians(0) # 从右侧开始
  43. phi2 = np.radians(180) # 到左侧结束(上半椭圆)
  44. H, W = 200, 200
  45. mask = arc_to_mask(xc, yc, a, b, theta, phi1, phi2, H, W, line_width=2)
  46. # print(mask.shape) # torch.Size([200, 200])
  47. # print(mask.dtype) # torch.uint8
  48. # 可视化(调试用)
  49. import matplotlib.pyplot as plt
  50. plt.imshow(mask.numpy(), cmap='gray')
  51. plt.title("Elliptical Arc Mask")
  52. plt.show()