line_dataset.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522
  1. import cv2
  2. import imageio
  3. import numpy as np
  4. from skimage.draw import ellipse
  5. from torch.utils.data.dataset import T_co
  6. from libs.vision_libs.utils import draw_keypoints
  7. from models.base.base_dataset import BaseDataset
  8. import json
  9. import os
  10. import PIL
  11. import matplotlib as mpl
  12. from torchvision.utils import draw_bounding_boxes
  13. import torchvision.transforms.v2 as transforms
  14. import torch
  15. import matplotlib.pyplot as plt
  16. from models.base.transforms import get_transforms
  17. def validate_keypoints(keypoints, image_width, image_height):
  18. for kp in keypoints:
  19. x, y, v = kp
  20. if not (0 <= x < image_width and 0 <= y < image_height):
  21. raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
  22. """
  23. 直接读取xanlabel标注的数据集json格式
  24. """
  25. class LineDataset(BaseDataset):
  26. def __init__(self, dataset_path, data_type, transforms=None, augmentation=False, dataset_type=None, img_type='rgb',
  27. target_type='pixel'):
  28. super().__init__(dataset_path)
  29. self.data_path = dataset_path
  30. self.data_type = data_type
  31. print(f'data_path:{dataset_path}')
  32. self.transforms = transforms
  33. self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
  34. self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
  35. self.imgs = os.listdir(self.img_path)
  36. self.lbls = os.listdir(self.lbl_path)
  37. self.target_type = target_type
  38. self.img_type = img_type
  39. self.augmentation = augmentation
  40. print(f'augmentation:{augmentation}')
  41. # self.default_transform = DefaultTransform()
  42. def __getitem__(self, index) -> T_co:
  43. img_path = os.path.join(self.img_path, self.imgs[index])
  44. if self.data_type == 'tiff':
  45. lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-4] + 'json')
  46. img = imageio.v3.imread(img_path)[:, :, 0]
  47. print(f'img shape:{img.shape}')
  48. w, h = img.shape[:2]
  49. img = img.reshape(w, h, 1)
  50. img_3channel = np.zeros((w, h, 3), dtype=img.dtype)
  51. img_3channel[:, :, 2] = img[:, :, 0]
  52. img = torch.from_numpy(img_3channel).permute(2, 1, 0)
  53. else:
  54. lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
  55. img = PIL.Image.open(img_path).convert('RGB')
  56. w, h = img.size
  57. # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  58. target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  59. self.transforms = get_transforms(augmention=self.augmentation)
  60. img, target = self.transforms(img, target)
  61. return img, target
  62. def __len__(self):
  63. return len(self.imgs)
  64. def read_target(self, item, lbl_path, shape, extra=None):
  65. # print(f'shape:{shape}')
  66. # print(f'lbl_path:{lbl_path}')
  67. with open(lbl_path, 'r') as file:
  68. lable_all = json.load(file)
  69. objs = lable_all["shapes"]
  70. point_pairs = objs[0]['points']
  71. # print(f'point_pairs:{point_pairs}')
  72. target = {}
  73. target["image_id"] = torch.tensor(item)
  74. boxes, lines, points, arc_mask, circle_4points, labels, arc_ends, arc_params = get_boxes_lines(objs, shape)
  75. if points is not None:
  76. target["points"] = points
  77. if lines is not None:
  78. a = torch.full((lines.shape[0],), 2).unsqueeze(1)
  79. lines = torch.cat((lines, a), dim=1)
  80. target["lines"] = lines.to(torch.float32).view(-1, 2, 3)
  81. # print(f'lines shape:{ target["lines"].shape}')
  82. if arc_mask is not None:
  83. target['arc_mask'] = arc_mask
  84. # print(f'arc_mask dataset')
  85. # else:
  86. # print(f'not arc_mask dataset')
  87. if arc_ends is not None:
  88. target['mask_ends'] = arc_ends
  89. target['mask_params'] = arc_params
  90. arc_angles = compute_arc_angles(arc_ends, arc_params)
  91. # print(arc_angles)
  92. # print(arc_params)
  93. arc_masks = []
  94. for i in range(len(arc_params)):
  95. arc7=arc_params[i] + arc_angles[i].tolist()
  96. arc_masks.append(arc_to_mask(arc7, shape, line_width=1))
  97. print(f'arc_masks:{torch.stack(arc_masks, dim=0).shape}')
  98. target['arc_masks'] = torch.stack(arc_masks, dim=0)
  99. if circle_4points is not None:
  100. target['circles'] = circle_4points
  101. circle_masks = generate_ellipse_mask(shape, points_to_ellipse(circle_4points))
  102. target['circle_masks'] = torch.tensor(circle_masks, dtype=torch.float32).unsqueeze(0)
  103. target["boxes"] = boxes
  104. target["labels"] = labels
  105. # target["boxes"], lines,target["points"], target["labels"] = get_boxes_lines(objs,shape)
  106. # print(f'lines:{lines}')
  107. # target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
  108. # print(f'target points:{target["points"]}')
  109. # target["lines"] = lines.to(torch.float32).view(-1,2,3)
  110. # print(f'')
  111. # print(f'lines:{target["lines"].shape}')
  112. target["img_size"] = shape
  113. # validate_keypoints(lines, shape[0], shape[1])
  114. return target
  115. def show(self, idx, show_type='all'):
  116. image, target = self.__getitem__(idx)
  117. cmap = plt.get_cmap("jet")
  118. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  119. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  120. sm.set_array([])
  121. # img_path = os.path.join(self.img_path, self.imgs[idx])
  122. # print(f'boxes:{target["boxes"]}')
  123. img = image
  124. if show_type == 'arc_masks':
  125. boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"],
  126. colors="yellow", width=1)
  127. # arc = target['arc']
  128. arc_mask = target['arc_masks']
  129. # print(f'taget circle:{arc.shape}')
  130. print(f'target circle_masks:{arc_mask.shape}')
  131. plt.imshow(arc_mask.squeeze(0))
  132. plt.show()
  133. if show_type == 'circle_masks':
  134. boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"],
  135. colors="yellow", width=1)
  136. circle = target['circles']
  137. circle_mask = target['circle_masks']
  138. print(f'taget circle:{circle.shape}')
  139. print(f'target circle_masks:{circle_mask.shape}')
  140. plt.imshow(circle_mask.squeeze(0))
  141. keypoint_img = draw_keypoints(boxed_image, circle, colors='red', width=3)
  142. # plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
  143. plt.show()
  144. # if show_type=='lines':
  145. # keypoint_img=draw_keypoints((img * 255).to(torch.uint8),target['lines'],colors='red',width=3)
  146. # plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
  147. # plt.show()
  148. if show_type == 'points':
  149. # print(f'points:{target['points'].shape}')
  150. keypoint_img = draw_keypoints((img * 255).to(torch.uint8), target['points'].unsqueeze(1), colors='red',
  151. width=3)
  152. plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
  153. plt.show()
  154. if show_type == 'boxes':
  155. boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"],
  156. colors="yellow", width=1)
  157. plt.imshow(boxed_image.permute(1, 2, 0).numpy())
  158. plt.show()
  159. def show_img(self, img_path):
  160. pass
  161. def draw_el(all):
  162. # 解析椭圆参数
  163. if isinstance(all, torch.Tensor):
  164. all = all.cpu().numpy()
  165. x, y, a, b, q, q1, q2 = all
  166. theta = np.radians(q)
  167. phi1 = np.radians(q1) # 第一个点的参数角
  168. phi2 = np.radians(q2) # 第二个点的参数角
  169. # 生成椭圆上的点
  170. phi = np.linspace(0, 2 * np.pi, 500)
  171. x_ellipse = x + a * np.cos(phi) * np.cos(theta) - b * np.sin(phi) * np.sin(theta)
  172. y_ellipse = y + a * np.cos(phi) * np.sin(theta) + b * np.sin(phi) * np.cos(theta)
  173. # 计算两个指定点的坐标
  174. def param_to_point(phi, xc, yc, a, b, theta):
  175. x = xc + a * np.cos(phi) * np.cos(theta) - b * np.sin(phi) * np.sin(theta)
  176. y = yc + a * np.cos(phi) * np.sin(theta) + b * np.sin(phi) * np.cos(theta)
  177. return x, y
  178. P1 = param_to_point(phi1, x, y, a, b, theta)
  179. P2 = param_to_point(phi2, x, y, a, b, theta)
  180. # 创建画布并显示背景图片(使用传入的background_img,shape为[H, W, C])
  181. plt.figure(figsize=(10, 10))
  182. # plt.imshow(background_img) # 直接显示背景图
  183. # 绘制椭圆及相关元素
  184. plt.plot(x_ellipse, y_ellipse, 'b-', linewidth=2)
  185. plt.plot(x, y, 'ko', markersize=8)
  186. plt.plot(P1[0], P1[1], 'ro', markersize=10)
  187. plt.plot(P2[0], P2[1], 'go', markersize=10)
  188. plt.show()
  189. def arc_to_mask(arc7, shape, line_width=1):
  190. """
  191. Generate a binary mask of an elliptical arc.
  192. Args:
  193. xc, yc (float): 椭圆中心
  194. a, b (float): 长半轴、短半轴 (a >= b)
  195. theta (float): 椭圆旋转角度(**弧度**,逆时针,相对于 x 轴)
  196. phi1, phi2 (float): 起始和终止参数角(**弧度**,在 [0, 2π) 内)
  197. H, W (int): 输出 mask 的高度和宽度
  198. line_width (int): 弧线宽度(像素)
  199. Returns:
  200. mask (Tensor): [H, W], dtype=torch.uint8, 0/255
  201. """
  202. # 确保 phi1 -> phi2 是正向(可处理跨 2π 的情况)
  203. xc, yc, a, b, theta, phi1, phi2 = arc7
  204. H, W = shape
  205. if phi2 < phi1:
  206. phi2 += 2 * np.pi
  207. # 生成参数角(足够密集,避免断线)
  208. num_points = max(int(200 * abs(phi2 - phi1) / (2 * np.pi)), 10)
  209. phi = np.linspace(phi1, phi2, num_points)
  210. # 椭圆参数方程(先在未旋转坐标系下计算)
  211. x_local = a * np.cos(phi)
  212. y_local = b * np.sin(phi)
  213. # 应用旋转和平移
  214. cos_t = np.cos(theta)
  215. sin_t = np.sin(theta)
  216. x_rot = x_local * cos_t - y_local * sin_t + xc
  217. y_rot = x_local * cos_t + y_local * sin_t + yc
  218. # 转为整数坐标(OpenCV 需要 int32)
  219. points = np.stack([x_rot, y_rot], axis=1).astype(np.int32)
  220. # 创建空白图像
  221. img = np.zeros((H, W), dtype=np.uint8)
  222. # 绘制折线(antialias=False 更适合 mask)
  223. cv2.polylines(img, [points], isClosed=False, color=255, thickness=line_width, lineType=cv2.LINE_AA)
  224. return torch.from_numpy(img).byte() # [H, W], values: 0 or 255
  225. def compute_arc_angles(gt_mask_ends, gt_mask_params):
  226. """
  227. 给定椭圆上的一个点,计算其对应的参数角 phi(弧度)。
  228. Parameters:
  229. point: tuple or array-like, (x, y)
  230. ellipse_param: tuple or array-like, (xc, yc, a, b, theta)
  231. Returns:
  232. phi: float, in [0, 2*pi)
  233. """
  234. results = []
  235. gt_mask_params_tensor = torch.tensor(gt_mask_params,
  236. dtype=gt_mask_ends.dtype,
  237. device=gt_mask_ends.device)
  238. for ends_img, params_img in zip(gt_mask_ends, gt_mask_params_tensor):
  239. # print(f'params_img:{params_img}')
  240. if torch.norm(params_img) < 1e-6: # L2 norm near zero
  241. results.append(torch.zeros(2, device=params_img.device, dtype=params_img.dtype))
  242. continue
  243. x, y = ends_img
  244. xc, yc, a, b, theta = params_img
  245. # 1. 平移到中心
  246. dx = x - xc
  247. dy = y - yc
  248. # 2. 逆旋转(旋转 -theta)
  249. cos_t = torch.cos(theta)
  250. sin_t = torch.sin(theta)
  251. X = dx * cos_t + dy * sin_t
  252. Y = -dx * sin_t + dy * cos_t
  253. # 3. 归一化到单位圆(除以 a, b)
  254. cos_phi = X / a
  255. sin_phi = Y / b
  256. # 4. 用 atan2 求角度(自动处理象限)
  257. phi = torch.atan2(sin_phi, cos_phi)
  258. # 5. 转换到 [0, 2π)
  259. phi = torch.where(phi < 0, phi + 2 * torch.pi, phi)
  260. results.append(phi)
  261. return results
  262. def points_to_ellipse(points):
  263. """
  264. 根据提供的四个点估计椭圆参数。
  265. :param points: Tensor of shape (4, 2) 表示椭圆上的四个点
  266. :return: 返回 (cx, cy, r1, r2, orientation) 其中 cx, cy 是中心坐标,r1, r2 分别是长轴和短轴半径,orientation 是椭圆的方向(弧度)
  267. """
  268. # 转换为numpy数组进行计算
  269. pts = points.numpy()
  270. pts = pts.reshape(-1, 2)
  271. center = np.mean(pts, axis=0)
  272. A = np.hstack(
  273. [pts[:, 0:1] ** 2, pts[:, 0:1] * pts[:, 1:2], pts[:, 1:2] ** 2, pts[:, :2], np.ones((pts.shape[0], 1))])
  274. b = np.ones(pts.shape[0])
  275. x = np.linalg.lstsq(A, b, rcond=None)[0]
  276. # 解析解参见 https://en.wikipedia.org/wiki/Ellipse#General_ellipse
  277. a, b, c, d, f, g = x.ravel()
  278. numerator = 2 * (a * f * f + c * d * d + g * b * b - 2 * b * d * f - a * c * g)
  279. denominator1 = (b * b - a * c) * ((c - a) * np.sqrt(1 + 4 * b * b / ((a - c) * (a - c))) - (c + a))
  280. denominator2 = (b * b - a * c) * ((a - c) * np.sqrt(1 + 4 * b * b / ((a - c) * (a - c))) - (c + a))
  281. major_axis = np.sqrt(numerator / denominator1)
  282. minor_axis = np.sqrt(numerator / denominator2)
  283. distances = np.linalg.norm(pts - center, axis=1)
  284. long_axis_length = np.max(distances) * 2
  285. short_axis_length = np.min(distances) * 2
  286. orientation = np.arctan2(pts[1, 1] - pts[0, 1], pts[1, 0] - pts[0, 0])
  287. return center[0], center[1], long_axis_length / 2, short_axis_length / 2, orientation
  288. def generate_ellipse_mask(shape, ellipse_params):
  289. """
  290. 在指定形状的图像上生成椭圆mask。
  291. :param shape: 输出mask的形状 (HxW)
  292. :param ellipse_params: 椭圆参数 (cx, cy, rx, ry, orientation)
  293. :return: 椭圆mask
  294. """
  295. cx, cy, rx, ry, orientation = ellipse_params
  296. img = np.zeros(shape, dtype=np.uint8)
  297. cx, cy, rx, ry = int(cx), int(cy), int(rx), int(ry)
  298. rr, cc = ellipse(cy, cx, ry, rx, shape)
  299. img[rr, cc] = 1
  300. return img
  301. def sort_points_clockwise(points):
  302. points = np.array(points)
  303. top_left_idx = np.lexsort((points[:, 0], points[:, 1]))[0]
  304. reference_point = points[top_left_idx]
  305. def angle_to_reference(point):
  306. return np.arctan2(point[1] - reference_point[1], point[0] - reference_point[0])
  307. angles = np.apply_along_axis(angle_to_reference, 1, points)
  308. angles[angles < 0] += 2 * np.pi
  309. sorted_indices = np.argsort(angles)
  310. sorted_points = points[sorted_indices]
  311. return sorted_points.tolist()
  312. def get_boxes_lines(objs, shape):
  313. boxes = []
  314. labels = []
  315. h, w = shape
  316. line_point_pairs = []
  317. points = []
  318. arc_mask = []
  319. arc_ends = []
  320. arc_params = []
  321. circle_4points = []
  322. for obj in objs:
  323. # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]无明确大小
  324. # print(f"points:{obj['points']}")
  325. label = obj['label']
  326. if label == 'line' or label == 'dseam1':
  327. a, b = obj['points'][0], obj['points'][1]
  328. line_point_pairs.append(a)
  329. line_point_pairs.append(b)
  330. xmin = max(0, (min(a[0], b[0]) - 6))
  331. xmax = min(w, (max(a[0], b[0]) + 6))
  332. ymin = max(0, (min(a[1], b[1]) - 6))
  333. ymax = min(h, (max(a[1], b[1]) + 6))
  334. boxes.append([xmin, ymin, xmax, ymax])
  335. labels.append(torch.tensor(2))
  336. elif label == 'point':
  337. p = obj['points'][0]
  338. xmin = max(0, p[0] - 12)
  339. xmax = min(w, p[0] + 12)
  340. ymin = max(0, p[1] - 12)
  341. ymax = min(h, p[1] + 12)
  342. points.append(p)
  343. labels.append(torch.tensor(1))
  344. boxes.append([xmin, ymin, xmax, ymax])
  345. elif label == 'arc':
  346. arc_points = obj['points']
  347. params = obj['params']
  348. ends = obj['ends']
  349. arc_ends.append(ends)
  350. arc_params.append(params)
  351. xs = [p[0] for p in arc_points]
  352. ys = [p[1] for p in arc_points]
  353. xmin, xmax = min(xs), max(xs)
  354. ymin, ymax = min(ys), max(ys)
  355. boxes.append([xmin, ymin, xmax, ymax])
  356. labels.append(torch.tensor(3))
  357. elif label == 'circle':
  358. # print(f'len circle_4points: {len(obj['points'])}')
  359. points = sort_points_clockwise(obj['points'])
  360. circle_4points.append(points)
  361. xmin = max(obj['xmin'] - 40, 0)
  362. xmax = min(obj['xmax'] + 40, w)
  363. ymin = max(obj['ymin'] - 40, 0)
  364. ymax = min(obj['ymax'] + 40, h)
  365. boxes.append([xmin, ymin, xmax, ymax])
  366. labels.append(torch.tensor(4))
  367. boxes = torch.tensor(boxes, dtype=torch.float32)
  368. print(f'boxes:{boxes.shape}')
  369. labels = torch.tensor(labels)
  370. if len(points) == 0:
  371. points = None
  372. else:
  373. points = torch.tensor(points, dtype=torch.float32)
  374. print(f'read labels:{labels}')
  375. # print(f'read points:{points}')
  376. if len(line_point_pairs) == 0:
  377. line_point_pairs = None
  378. else:
  379. line_point_pairs = torch.tensor(line_point_pairs)
  380. # print(f'line_point_pairs:{line_point_pairs.shape},{line_point_pairs.dtype}')
  381. # print(f'boxes:{boxes.shape},line_point_pairs:{line_point_pairs.shape}')
  382. if len(arc_mask) == 0:
  383. arc_mask = None
  384. else:
  385. arc_mask = torch.tensor(arc_mask, dtype=torch.float32)
  386. print(f'arc_mask shape :{arc_mask.shape},{arc_mask.dtype}')
  387. if len(arc_ends) == 0:
  388. arc_ends = None
  389. else:
  390. arc_ends = torch.tensor(arc_ends, dtype=torch.float32)
  391. if len(circle_4points) == 0:
  392. circle_4points = None
  393. else:
  394. # for circle_4point in circle_4points:
  395. # print(f'circle_4point len111:{len(circle_4point)}')
  396. circle_4points = torch.tensor(circle_4points, dtype=torch.float32)
  397. # print(f'circle_4points shape:{circle_4points.shape}')
  398. return boxes, line_point_pairs, points, arc_mask, circle_4points, labels, arc_ends, arc_params
  399. if __name__ == '__main__':
  400. path = r'\\192.168.50.222/share/lm/1112/a_dataset'
  401. dataset = LineDataset(dataset_path=path, dataset_type='train', augmentation=False, data_type='jpg')
  402. dataset.show(9, show_type='arc_masks')