line_dataset.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. import imageio
  2. import numpy as np
  3. from skimage.draw import ellipse
  4. from torch.utils.data.dataset import T_co
  5. from libs.vision_libs.utils import draw_keypoints
  6. from models.base.base_dataset import BaseDataset
  7. import json
  8. import os
  9. import PIL
  10. import matplotlib as mpl
  11. from torchvision.utils import draw_bounding_boxes
  12. import torchvision.transforms.v2 as transforms
  13. import torch
  14. import matplotlib.pyplot as plt
  15. from models.base.transforms import get_transforms
  16. def validate_keypoints(keypoints, image_width, image_height):
  17. for kp in keypoints:
  18. x, y, v = kp
  19. if not (0 <= x < image_width and 0 <= y < image_height):
  20. raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
"""
Dataset that directly reads the JSON annotation format produced by xanlabel.
"""
  24. class LineDataset(BaseDataset):
  25. def __init__(self, dataset_path, data_type, transforms=None,augmentation=False, dataset_type=None,img_type='rgb', target_type='pixel'):
  26. super().__init__(dataset_path)
  27. self.data_path = dataset_path
  28. self.data_type = data_type
  29. print(f'data_path:{dataset_path}')
  30. self.transforms = transforms
  31. self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
  32. self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
  33. self.imgs = os.listdir(self.img_path)
  34. self.lbls = os.listdir(self.lbl_path)
  35. self.target_type = target_type
  36. self.img_type=img_type
  37. self.augmentation=augmentation
  38. print(f'augmentation:{augmentation}')
  39. # self.default_transform = DefaultTransform()
  40. def __getitem__(self, index) -> T_co:
  41. img_path = os.path.join(self.img_path, self.imgs[index])
  42. if self.data_type == 'tiff':
  43. lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-4] + 'json')
  44. img = imageio.v3.imread(img_path)[:,:,0]
  45. print(f'img shape:{img.shape}')
  46. w, h = img.shape[:2]
  47. img=img.reshape(w,h,1)
  48. img_3channel = np.zeros((w, h, 3), dtype=img.dtype)
  49. img_3channel[:, :, 2] = img[:, :, 0]
  50. img = torch.from_numpy(img_3channel).permute(2, 1, 0)
  51. else:
  52. lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
  53. img = PIL.Image.open(img_path).convert('RGB')
  54. w, h = img.size
  55. # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  56. target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  57. self.transforms=get_transforms(augmention=self.augmentation)
  58. img, target = self.transforms(img, target)
  59. return img, target
  60. def __len__(self):
  61. return len(self.imgs)
  62. def read_target(self, item, lbl_path, shape, extra=None):
  63. # print(f'shape:{shape}')
  64. # print(f'lbl_path:{lbl_path}')
  65. with open(lbl_path, 'r') as file:
  66. lable_all = json.load(file)
  67. objs = lable_all["shapes"]
  68. point_pairs=objs[0]['points']
  69. # print(f'point_pairs:{point_pairs}')
  70. target = {}
  71. target["image_id"] = torch.tensor(item)
  72. boxes, lines, points, arc_mask,circle_4points,labels = get_boxes_lines(objs, shape)
  73. if points is not None:
  74. target["points"]=points
  75. if lines is not None:
  76. a = torch.full((lines.shape[0],), 2).unsqueeze(1)
  77. lines = torch.cat((lines, a), dim=1)
  78. target["lines"] = lines.to(torch.float32).view(-1, 2, 3)
  79. # print(f'lines shape:{ target["lines"].shape}')
  80. if arc_mask is not None:
  81. target['arc_mask']=arc_mask
  82. # print(f'arc_mask dataset')
  83. # else:
  84. # print(f'not arc_mask dataset')
  85. if circle_4points is not None:
  86. target['circles']=circle_4points
  87. circle_masks=generate_ellipse_mask(shape,points_to_ellipse(circle_4points))
  88. target['circle_masks'] =torch.tensor(circle_masks,dtype=torch.float32).unsqueeze(0)
  89. target["boxes"]=boxes
  90. target["labels"]=labels
  91. # target["boxes"], lines,target["points"], target["labels"] = get_boxes_lines(objs,shape)
  92. # print(f'lines:{lines}')
  93. # target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
  94. # print(f'target points:{target["points"]}')
  95. # target["lines"] = lines.to(torch.float32).view(-1,2,3)
  96. # print(f'')
  97. # print(f'lines:{target["lines"].shape}')
  98. target["img_size"]=shape
  99. # validate_keypoints(lines, shape[0], shape[1])
  100. return target
  101. def show(self, idx,show_type='all'):
  102. image, target = self.__getitem__(idx)
  103. cmap = plt.get_cmap("jet")
  104. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  105. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  106. sm.set_array([])
  107. # img_path = os.path.join(self.img_path, self.imgs[idx])
  108. # print(f'boxes:{target["boxes"]}')
  109. img = image
  110. if show_type=='all':
  111. boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"],
  112. colors="yellow", width=1)
  113. circle=target['circles']
  114. circle_mask=target['circle_masks']
  115. print(f'taget circle:{circle.shape}')
  116. print(f'target circle_masks:{circle_mask.shape}')
  117. plt.imshow(circle_mask.squeeze(0))
  118. keypoint_img=draw_keypoints(boxed_image,circle,colors='red',width=3)
  119. # plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
  120. plt.show()
  121. # if show_type=='lines':
  122. # keypoint_img=draw_keypoints((img * 255).to(torch.uint8),target['lines'],colors='red',width=3)
  123. # plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
  124. # plt.show()
  125. if show_type=='points':
  126. # print(f'points:{target['points'].shape}')
  127. keypoint_img=draw_keypoints((img * 255).to(torch.uint8),target['points'].unsqueeze(1),colors='red',width=3)
  128. plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
  129. plt.show()
  130. if show_type=='boxes':
  131. boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"],
  132. colors="yellow", width=1)
  133. plt.imshow(boxed_image.permute(1, 2, 0).numpy())
  134. plt.show()
  135. def show_img(self, img_path):
  136. pass
  137. def points_to_ellipse(points):
  138. """
  139. 根据提供的四个点估计椭圆参数。
  140. :param points: Tensor of shape (4, 2) 表示椭圆上的四个点
  141. :return: 返回 (cx, cy, r1, r2, orientation) 其中 cx, cy 是中心坐标,r1, r2 分别是长轴和短轴半径,orientation 是椭圆的方向(弧度)
  142. """
  143. # 转换为numpy数组进行计算
  144. pts = points.numpy()
  145. pts = pts.reshape(-1, 2)
  146. # 计算中心点
  147. center = np.mean(pts, axis=0)
  148. # 使用最小二乘法拟合椭圆
  149. A = np.hstack(
  150. [pts[:, 0:1] ** 2, pts[:, 0:1] * pts[:, 1:2], pts[:, 1:2] ** 2, pts[:, :2], np.ones((pts.shape[0], 1))])
  151. b = np.ones(pts.shape[0])
  152. x = np.linalg.lstsq(A, b, rcond=None)[0]
  153. # 解析解参见 https://en.wikipedia.org/wiki/Ellipse#General_ellipse
  154. a, b, c, d, f, g = x.ravel()
  155. numerator = 2 * (a * f * f + c * d * d + g * b * b - 2 * b * d * f - a * c * g)
  156. denominator1 = (b * b - a * c) * ((c - a) * np.sqrt(1 + 4 * b * b / ((a - c) * (a - c))) - (c + a))
  157. denominator2 = (b * b - a * c) * ((a - c) * np.sqrt(1 + 4 * b * b / ((a - c) * (a - c))) - (c + a))
  158. major_axis = np.sqrt(numerator / denominator1)
  159. minor_axis = np.sqrt(numerator / denominator2)
  160. # 简化处理:直接用点间距离估算长短轴
  161. distances = np.linalg.norm(pts - center, axis=1)
  162. long_axis_length = np.max(distances) * 2
  163. short_axis_length = np.min(distances) * 2
  164. # 方向可以通过两点之间的连线斜率来近似计算
  165. orientation = np.arctan2(pts[1, 1] - pts[0, 1], pts[1, 0] - pts[0, 0])
  166. return center[0], center[1], long_axis_length / 2, short_axis_length / 2, orientation
  167. def generate_ellipse_mask(shape, ellipse_params):
  168. """
  169. 在指定形状的图像上生成椭圆mask。
  170. :param shape: 输出mask的形状 (HxW)
  171. :param ellipse_params: 椭圆参数 (cx, cy, rx, ry, orientation)
  172. :return: 椭圆mask
  173. """
  174. cx, cy, rx, ry, orientation = ellipse_params
  175. img = np.zeros(shape, dtype=np.uint8)
  176. cx, cy, rx, ry = int(cx), int(cy), int(rx), int(ry)
  177. # 注意skimage的ellipse函数不直接支持旋转,所以这里简化处理
  178. rr, cc = ellipse(cy, cx, ry, rx, shape)
  179. img[rr, cc] = 1
  180. return img
  181. def sort_points_clockwise(points):
  182. points = np.array(points)
  183. top_left_idx = np.lexsort((points[:, 0], points[:, 1]))[0]
  184. reference_point = points[top_left_idx]
  185. def angle_to_reference(point):
  186. return np.arctan2(point[1] - reference_point[1], point[0] - reference_point[0])
  187. angles = np.apply_along_axis(angle_to_reference, 1, points)
  188. angles[angles < 0] += 2 * np.pi
  189. sorted_indices = np.argsort(angles)
  190. sorted_points = points[sorted_indices]
  191. return sorted_points.tolist()
  192. def get_boxes_lines(objs,shape):
  193. boxes = []
  194. labels=[]
  195. h,w=shape
  196. line_point_pairs = []
  197. points=[]
  198. line_mask=[]
  199. circle_4points=[]
  200. for obj in objs:
  201. # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]无明确大小
  202. # print(f"points:{obj['points']}")
  203. label=obj['label']
  204. if label =='line' or label=='dseam1':
  205. a,b=obj['points'][0],obj['points'][1]
  206. line_point_pairs.append(a)
  207. line_point_pairs.append(b)
  208. xmin = max(0, (min(a[0], b[0]) - 6))
  209. xmax = min(w, (max(a[0], b[0]) + 6))
  210. ymin = max(0, (min(a[1], b[1]) - 6))
  211. ymax = min(h, (max(a[1], b[1]) + 6))
  212. boxes.append([ xmin,ymin, xmax,ymax])
  213. labels.append(torch.tensor(2))
  214. elif label =='point':
  215. p= obj['points'][0]
  216. xmin=max(0,p[0]-12)
  217. xmax = min(w, p[0] +12)
  218. ymin=max(0,p[1]-12)
  219. ymax = min(h, p[1] + 12)
  220. points.append(p)
  221. labels.append(torch.tensor(1))
  222. boxes.append([xmin, ymin, xmax, ymax])
  223. elif label == 'arc' :
  224. line_mask.append(obj['points'])
  225. xmin = obj['xmin']
  226. xmax = obj['xmax']
  227. ymin = obj['ymin']
  228. ymax = obj['ymax']
  229. boxes.append([xmin, ymin, xmax, ymax])
  230. labels.append(torch.tensor(3))
  231. elif label == 'circle' :
  232. # print(f'len circle_4points: {len(obj['points'])}')
  233. points=sort_points_clockwise(obj['points'])
  234. circle_4points.append(points)
  235. xmin = max(obj['xmin'] - 40, 0)
  236. xmax = min(obj['xmax'] + 40, w)
  237. ymin = max(obj['ymin'] - 40, 0)
  238. ymax = min(obj['ymax'] + 40, h)
  239. boxes.append([xmin, ymin, xmax, ymax])
  240. labels.append(torch.tensor(4))
  241. boxes=torch.tensor(boxes,dtype=torch.float32)
  242. print(f'boxes:{boxes.shape}')
  243. labels=torch.tensor(labels)
  244. if len(points)==0:
  245. points=None
  246. else:
  247. points=torch.tensor(points,dtype=torch.float32)
  248. print(f'read labels:{labels}')
  249. # print(f'read points:{points}')
  250. if len(line_point_pairs)==0:
  251. line_point_pairs=None
  252. else:
  253. line_point_pairs=torch.tensor(line_point_pairs)
  254. # print(f'line_point_pairs:{line_point_pairs.shape},{line_point_pairs.dtype}')
  255. # print(f'boxes:{boxes.shape},line_point_pairs:{line_point_pairs.shape}')
  256. if len(line_mask)==0:
  257. line_mask=None
  258. else:
  259. line_mask=torch.tensor(line_mask,dtype=torch.float32)
  260. print(f'arc_mask shape :{line_mask.shape},{line_mask.dtype}')
  261. if len(circle_4points)==0:
  262. circle_4points=None
  263. else:
  264. # for circle_4point in circle_4points:
  265. # print(f'circle_4point len111:{len(circle_4point)}')
  266. circle_4points=torch.tensor(circle_4points,dtype=torch.float32)
  267. # print(f'circle_4points shape:{circle_4points.shape}')
  268. return boxes,line_point_pairs,points,line_mask,circle_4points, labels
  269. if __name__ == '__main__':
  270. path=r'/data/share/zyh/master_dataset/circle/huyan_eclipse/a_dataset'
  271. dataset= LineDataset(dataset_path=path, dataset_type='train',augmentation=False, data_type='jpg')
  272. dataset.show(9,show_type='all')