dataset_tool.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. import cv2
  2. import numpy as np
  3. import torch
  4. import torchvision
  5. from matplotlib import pyplot as plt
  6. import tools.transforms as reference_transforms
  7. from collections import defaultdict
  8. from tools import presets
  9. import json
  10. def get_modules(use_v2):
  11. # We need a protected import to avoid the V2 warning in case just V1 is used
  12. if use_v2:
  13. import torchvision.transforms.v2
  14. import torchvision.tv_tensors
  15. return torchvision.transforms.v2, torchvision.tv_tensors
  16. else:
  17. return reference_transforms, None
  18. class Augmentation:
  19. # Note: this transform assumes that the input to forward() are always PIL
  20. # images, regardless of the backend parameter.
  21. def __init__(
  22. self,
  23. *,
  24. data_augmentation,
  25. hflip_prob=0.5,
  26. mean=(123.0, 117.0, 104.0),
  27. backend="pil",
  28. use_v2=False,
  29. ):
  30. T, tv_tensors = get_modules(use_v2)
  31. transforms = []
  32. backend = backend.lower()
  33. if backend == "tv_tensor":
  34. transforms.append(T.ToImage())
  35. elif backend == "tensor":
  36. transforms.append(T.PILToTensor())
  37. elif backend != "pil":
  38. raise ValueError(f"backend can be 'tv_tensor', 'tensor' or 'pil', but got {backend}")
  39. if data_augmentation == "hflip":
  40. transforms += [T.RandomHorizontalFlip(p=hflip_prob)]
  41. elif data_augmentation == "lsj":
  42. transforms += [
  43. T.ScaleJitter(target_size=(1024, 1024), antialias=True),
  44. # TODO: FixedSizeCrop below doesn't work on tensors!
  45. reference_transforms.FixedSizeCrop(size=(1024, 1024), fill=mean),
  46. T.RandomHorizontalFlip(p=hflip_prob),
  47. ]
  48. elif data_augmentation == "multiscale":
  49. transforms += [
  50. T.RandomShortestSize(min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333),
  51. T.RandomHorizontalFlip(p=hflip_prob),
  52. ]
  53. elif data_augmentation == "ssd":
  54. fill = defaultdict(lambda: mean, {tv_tensors.Mask: 0}) if use_v2 else list(mean)
  55. transforms += [
  56. T.RandomPhotometricDistort(),
  57. T.RandomZoomOut(fill=fill),
  58. T.RandomIoUCrop(),
  59. T.RandomHorizontalFlip(p=hflip_prob),
  60. ]
  61. elif data_augmentation == "ssdlite":
  62. transforms += [
  63. T.RandomIoUCrop(),
  64. T.RandomHorizontalFlip(p=hflip_prob),
  65. ]
  66. else:
  67. raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')
  68. if backend == "pil":
  69. # Note: we could just convert to pure tensors even in v2.
  70. transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
  71. transforms += [T.ToDtype(torch.float, scale=True)]
  72. if use_v2:
  73. transforms += [
  74. T.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.XYXY),
  75. T.SanitizeBoundingBoxes(),
  76. T.ToPureTensor(),
  77. ]
  78. self.transforms = T.Compose(transforms)
  79. def __call__(self, img, target):
  80. return self.transforms(img, target)
  81. def read_polygon_points(lbl_path, shape):
  82. """读取 YOLOv8 格式的标注文件并解析多边形轮廓"""
  83. polygon_points = []
  84. w, h = shape[:2]
  85. with open(lbl_path, 'r') as f:
  86. lines = f.readlines()
  87. for line in lines:
  88. parts = line.strip().split()
  89. class_id = int(parts[0])
  90. points = np.array(parts[1:], dtype=np.float32).reshape(-1, 2) # 读取点坐标
  91. points[:, 0] *= h
  92. points[:, 1] *= w
  93. polygon_points.append((class_id, points))
  94. return polygon_points
  95. def read_masks_from_pixels(lbl_path, shape):
  96. """读取纯像素点格式的文件,不是轮廓像素点"""
  97. h, w = shape
  98. masks = []
  99. labels = []
  100. with open(lbl_path, 'r') as reader:
  101. lines = reader.readlines()
  102. mask_points = []
  103. for line in lines:
  104. mask = torch.zeros((h, w), dtype=torch.uint8)
  105. parts = line.strip().split()
  106. # print(f'parts:{parts}')
  107. cls = torch.tensor(int(parts[0]), dtype=torch.int64)
  108. labels.append(cls)
  109. x_array = parts[1::2]
  110. y_array = parts[2::2]
  111. for x, y in zip(x_array, y_array):
  112. x = float(x)
  113. y = float(y)
  114. mask_points.append((int(y * h), int(x * w)))
  115. for p in mask_points:
  116. mask[p] = 1
  117. masks.append(mask)
  118. reader.close()
  119. return labels, masks
  120. def create_masks_from_polygons(polygons, image_shape):
  121. """创建一个与图像尺寸相同的掩码,并填充多边形轮廓"""
  122. colors = np.array([plt.cm.Spectral(each) for each in np.linspace(0, 1, len(polygons))])
  123. masks = []
  124. for polygon_data, col in zip(polygons, colors):
  125. mask = np.zeros(image_shape[:2], dtype=np.uint8)
  126. # 将多边形顶点转换为 NumPy 数组
  127. _, polygon = polygon_data
  128. pts = np.array(polygon, np.int32).reshape((-1, 1, 2))
  129. # 使用 OpenCV 的 fillPoly 函数填充多边形
  130. # print(f'color:{col[:3]}')
  131. cv2.fillPoly(mask, [pts], np.array(col[:3]) * 255)
  132. mask = torch.from_numpy(mask)
  133. mask[mask != 0] = 1
  134. masks.append(mask)
  135. return masks
  136. def read_masks_from_txt(label_path, shape):
  137. polygon_points = read_polygon_points(label_path, shape)
  138. masks = create_masks_from_polygons(polygon_points, shape)
  139. labels = [torch.tensor(item[0]) for item in polygon_points]
  140. return labels, masks
  141. def masks_to_boxes(masks: torch.Tensor, ) -> torch.Tensor:
  142. """
  143. Compute the bounding boxes around the provided masks.
  144. Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with
  145. ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
  146. Args:
  147. masks (Tensor[N, H, W]): masks to transform where N is the number of masks
  148. and (H, W) are the spatial dimensions.
  149. Returns:
  150. Tensor[N, 4]: bounding boxes
  151. """
  152. # if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  153. # _log_api_usage_once(masks_to_boxes)
  154. if masks.numel() == 0:
  155. return torch.zeros((0, 4), device=masks.device, dtype=torch.float)
  156. n = masks.shape[0]
  157. bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.float)
  158. for index, mask in enumerate(masks):
  159. y, x = torch.where(mask != 0)
  160. bounding_boxes[index, 0] = torch.min(x)
  161. bounding_boxes[index, 1] = torch.min(y)
  162. bounding_boxes[index, 2] = torch.max(x)
  163. bounding_boxes[index, 3] = torch.max(y)
  164. # debug to pixel datasets
  165. if bounding_boxes[index, 0] == bounding_boxes[index, 2]:
  166. bounding_boxes[index, 2] = bounding_boxes[index, 2] + 1
  167. bounding_boxes[index, 0] = bounding_boxes[index, 0] - 1
  168. if bounding_boxes[index, 1] == bounding_boxes[index, 3]:
  169. bounding_boxes[index, 3] = bounding_boxes[index, 3] + 1
  170. bounding_boxes[index, 1] = bounding_boxes[index, 1] - 1
  171. return bounding_boxes
  172. def line_boxes_faster(target):
  173. boxs = []
  174. lpre = target["lpre"].cpu().numpy() * 4
  175. vecl_target = target["lpre_label"].cpu().numpy()
  176. lpre = lpre[vecl_target == 1]
  177. lines = lpre
  178. sline = np.ones(lpre.shape[0])
  179. if len(lines) > 0 and not (lines[0] == 0).all():
  180. for i, ((a, b), s) in enumerate(zip(lines, sline)):
  181. if i > 0 and (lines[i] == lines[0]).all():
  182. break
  183. # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]无明确大小
  184. if a[1] > b[1]:
  185. ymax = a[1] + 10
  186. ymin = b[1] - 10
  187. else:
  188. ymin = a[1] - 10
  189. ymax = b[1] + 10
  190. if a[0] > b[0]:
  191. xmax = a[0] + 10
  192. xmin = b[0] - 10
  193. else:
  194. xmin = a[0] - 10
  195. xmax = b[0] + 10
  196. boxs.append([max(0, ymin), max(0, xmin), min(512, ymax), min(512, xmax)])
  197. return torch.tensor(boxs)
  198. # 将线段变为 [起点,终点]形式 传入[num,2,2]
  199. def a_to_b(line):
  200. result_pairs = []
  201. for a, b in line:
  202. min_x = min(a[0], b[0])
  203. min_y = min(a[1], b[1])
  204. new_top_left_x = max((min_x - 10), 0)
  205. new_top_left_y = max((min_y - 10), 0)
  206. dist_a = (a[0] - new_top_left_x) ** 2 + (a[1] - new_top_left_y) ** 2
  207. dist_b = (b[0] - new_top_left_x) ** 2 + (b[1] - new_top_left_y) ** 2
  208. # 根据距离选择起点并设置标签
  209. if dist_a <= dist_b: # 如果a点离新左上角更近或两者距离相等
  210. result_pairs.append([a, b]) # 将a设为起点,b为终点
  211. else: # 如果b点离新左上角更近
  212. result_pairs.append([b, a]) # 将b设为起点,a为终点
  213. result_tensor = torch.stack([torch.stack(row) for row in result_pairs])
  214. return result_tensor
  215. def read_polygon_points_wire(lbl_path, shape):
  216. """读取 YOLOv8 格式的标注文件并解析多边形轮廓"""
  217. polygon_points = []
  218. w, h = shape[:2]
  219. with open(lbl_path, 'r') as f:
  220. lines = json.load(f)
  221. for line in lines["segmentations"]:
  222. parts = line["data"]
  223. class_id = int(line["cls_id"])
  224. points = np.array(parts, dtype=np.float32).reshape(-1, 2) # 读取点坐标
  225. points[:, 0] *= h
  226. points[:, 1] *= w
  227. polygon_points.append((class_id, points))
  228. return polygon_points
  229. def read_masks_from_txt_wire(label_path, shape):
  230. polygon_points = read_polygon_points_wire(label_path, shape)
  231. masks = create_masks_from_polygons(polygon_points, shape)
  232. labels = [torch.tensor(item[0]) for item in polygon_points]
  233. return labels, masks
  234. def read_masks_from_pixels_wire(lbl_path, shape):
  235. """读取纯像素点格式的文件,不是轮廓像素点"""
  236. h, w = shape
  237. masks = []
  238. labels = []
  239. with open(lbl_path, 'r') as reader:
  240. lines = json.load(reader)
  241. mask_points = []
  242. for line in lines["segmentations"]:
  243. # mask = torch.zeros((h, w), dtype=torch.uint8)
  244. # parts = line["data"]
  245. # print(f'parts:{parts}')
  246. cls = torch.tensor(int(line["cls_id"]), dtype=torch.int64)
  247. labels.append(cls)
  248. # x_array = parts[0::2]
  249. # y_array = parts[1::2]
  250. #
  251. # for x, y in zip(x_array, y_array):
  252. # x = float(x)
  253. # y = float(y)
  254. # mask_points.append((int(y * h), int(x * w)))
  255. # for p in mask_points:
  256. # mask[p] = 1
  257. # masks.append(mask)
  258. reader.close()
  259. return labels
  260. def read_lable_keypoint(lbl_path):
  261. """判断线段的起点终点, 起点 lable=0, 终点 lable=1"""
  262. labels = []
  263. with open(lbl_path, 'r') as reader:
  264. lines = json.load(reader)
  265. aa = lines["wires"][0]["line_pos_coords"]["content"]
  266. result_pairs = []
  267. for a, b in aa:
  268. min_x = min(a[0], b[0])
  269. min_y = min(a[1], b[1])
  270. # 定义新的左上角位置
  271. new_top_left_x = max((min_x - 10), 0)
  272. new_top_left_y = max((min_y - 10), 0)
  273. # Step 2: 计算各点到新左上角的距离平方(避免浮点运算误差)
  274. dist_a = (a[0] - new_top_left_x) ** 2 + (a[1] - new_top_left_y) ** 2
  275. dist_b = (b[0] - new_top_left_x) ** 2 + (b[1] - new_top_left_y) ** 2
  276. # Step 3 & 4: 根据距离选择起点并设置标签
  277. if dist_a <= dist_b: # 如果a点离新左上角更近或两者距离相等
  278. result_pairs.append([a, b]) # 将a设为起点,b为终点
  279. else: # 如果b点离新左上角更近
  280. result_pairs.append([b, a]) # 将b设为起点,a为终点
  281. # x_ = abs(a[0] - b[0])
  282. # y_ = abs(a[1] - b[1])
  283. # if x_ > y_: # x小的离左上角近
  284. # if a[0] < b[0]: # 视为起点,lable=0
  285. # label = [0, 1]
  286. # else:
  287. # label = [1, 0]
  288. # else: # x大的是起点
  289. # if a[0] > b[0]: # 视为起点,lable=0
  290. # label = [0, 1]
  291. # else:
  292. # label = [1, 0]
  293. # labels.append(label)
  294. # print(result_pairs )
  295. reader.close()
  296. return labels
  297. def adjacency_matrix(n, link): # 邻接矩阵
  298. mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8)
  299. link = torch.tensor(link)
  300. if len(link) > 0:
  301. mat[link[:, 0], link[:, 1]] = 1
  302. mat[link[:, 1], link[:, 0]] = 1
  303. return mat