dataset_tool.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. import cv2
  2. import numpy as np
  3. import torch
  4. import torchvision
  5. from matplotlib import pyplot as plt
  6. import tools.transforms as reference_transforms
  7. from collections import defaultdict
  8. from tools import presets
  9. import json
  10. def get_modules(use_v2):
  11. # We need a protected import to avoid the V2 warning in case just V1 is used
  12. if use_v2:
  13. import torchvision.transforms.v2
  14. import torchvision.tv_tensors
  15. return torchvision.transforms.v2, torchvision.tv_tensors
  16. else:
  17. return reference_transforms, None
  18. class Augmentation:
  19. # Note: this transform assumes that the input to forward() are always PIL
  20. # images, regardless of the backend parameter.
  21. def __init__(
  22. self,
  23. *,
  24. data_augmentation,
  25. hflip_prob=0.5,
  26. mean=(123.0, 117.0, 104.0),
  27. backend="pil",
  28. use_v2=False,
  29. ):
  30. T, tv_tensors = get_modules(use_v2)
  31. transforms = []
  32. backend = backend.lower()
  33. if backend == "tv_tensor":
  34. transforms.append(T.ToImage())
  35. elif backend == "tensor":
  36. transforms.append(T.PILToTensor())
  37. elif backend != "pil":
  38. raise ValueError(f"backend can be 'tv_tensor', 'tensor' or 'pil', but got {backend}")
  39. if data_augmentation == "hflip":
  40. transforms += [T.RandomHorizontalFlip(p=hflip_prob)]
  41. elif data_augmentation == "lsj":
  42. transforms += [
  43. T.ScaleJitter(target_size=(1024, 1024), antialias=True),
  44. # TODO: FixedSizeCrop below doesn't work on tensors!
  45. reference_transforms.FixedSizeCrop(size=(1024, 1024), fill=mean),
  46. T.RandomHorizontalFlip(p=hflip_prob),
  47. ]
  48. elif data_augmentation == "multiscale":
  49. transforms += [
  50. T.RandomShortestSize(min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333),
  51. T.RandomHorizontalFlip(p=hflip_prob),
  52. ]
  53. elif data_augmentation == "ssd":
  54. fill = defaultdict(lambda: mean, {tv_tensors.Mask: 0}) if use_v2 else list(mean)
  55. transforms += [
  56. T.RandomPhotometricDistort(),
  57. T.RandomZoomOut(fill=fill),
  58. T.RandomIoUCrop(),
  59. T.RandomHorizontalFlip(p=hflip_prob),
  60. ]
  61. elif data_augmentation == "ssdlite":
  62. transforms += [
  63. T.RandomIoUCrop(),
  64. T.RandomHorizontalFlip(p=hflip_prob),
  65. ]
  66. else:
  67. raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')
  68. if backend == "pil":
  69. # Note: we could just convert to pure tensors even in v2.
  70. transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
  71. transforms += [T.ToDtype(torch.float, scale=True)]
  72. if use_v2:
  73. transforms += [
  74. T.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.XYXY),
  75. T.SanitizeBoundingBoxes(),
  76. T.ToPureTensor(),
  77. ]
  78. self.transforms = T.Compose(transforms)
  79. def __call__(self, img, target):
  80. return self.transforms(img, target)
  81. def read_polygon_points(lbl_path, shape):
  82. """读取 YOLOv8 格式的标注文件并解析多边形轮廓"""
  83. polygon_points = []
  84. w, h = shape[:2]
  85. with open(lbl_path, 'r') as f:
  86. lines = f.readlines()
  87. for line in lines:
  88. parts = line.strip().split()
  89. class_id = int(parts[0])
  90. points = np.array(parts[1:], dtype=np.float32).reshape(-1, 2) # 读取点坐标
  91. points[:, 0] *= h
  92. points[:, 1] *= w
  93. polygon_points.append((class_id, points))
  94. return polygon_points
  95. def read_masks_from_pixels(lbl_path, shape):
  96. """读取纯像素点格式的文件,不是轮廓像素点"""
  97. h, w = shape
  98. masks = []
  99. labels = []
  100. with open(lbl_path, 'r') as reader:
  101. lines = reader.readlines()
  102. mask_points = []
  103. for line in lines:
  104. mask = torch.zeros((h, w), dtype=torch.uint8)
  105. parts = line.strip().split()
  106. # print(f'parts:{parts}')
  107. cls = torch.tensor(int(parts[0]), dtype=torch.int64)
  108. labels.append(cls)
  109. x_array = parts[1::2]
  110. y_array = parts[2::2]
  111. for x, y in zip(x_array, y_array):
  112. x = float(x)
  113. y = float(y)
  114. mask_points.append((int(y * h), int(x * w)))
  115. for p in mask_points:
  116. mask[p] = 1
  117. masks.append(mask)
  118. reader.close()
  119. return labels, masks
  120. def create_masks_from_polygons(polygons, image_shape):
  121. """创建一个与图像尺寸相同的掩码,并填充多边形轮廓"""
  122. colors = np.array([plt.cm.Spectral(each) for each in np.linspace(0, 1, len(polygons))])
  123. masks = []
  124. for polygon_data, col in zip(polygons, colors):
  125. mask = np.zeros(image_shape[:2], dtype=np.uint8)
  126. # 将多边形顶点转换为 NumPy 数组
  127. _, polygon = polygon_data
  128. pts = np.array(polygon, np.int32).reshape((-1, 1, 2))
  129. # 使用 OpenCV 的 fillPoly 函数填充多边形
  130. # print(f'color:{col[:3]}')
  131. cv2.fillPoly(mask, [pts], np.array(col[:3]) * 255)
  132. mask = torch.from_numpy(mask)
  133. mask[mask != 0] = 1
  134. masks.append(mask)
  135. return masks
  136. def read_masks_from_txt(label_path, shape):
  137. polygon_points = read_polygon_points(label_path, shape)
  138. masks = create_masks_from_polygons(polygon_points, shape)
  139. labels = [torch.tensor(item[0]) for item in polygon_points]
  140. return labels, masks
  141. def masks_to_boxes(masks: torch.Tensor, ) -> torch.Tensor:
  142. """
  143. Compute the bounding boxes around the provided masks.
  144. Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with
  145. ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
  146. Args:
  147. masks (Tensor[N, H, W]): masks to transform where N is the number of masks
  148. and (H, W) are the spatial dimensions.
  149. Returns:
  150. Tensor[N, 4]: bounding boxes
  151. """
  152. # if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  153. # _log_api_usage_once(masks_to_boxes)
  154. if masks.numel() == 0:
  155. return torch.zeros((0, 4), device=masks.device, dtype=torch.float)
  156. n = masks.shape[0]
  157. bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.float)
  158. for index, mask in enumerate(masks):
  159. y, x = torch.where(mask != 0)
  160. bounding_boxes[index, 0] = torch.min(x)
  161. bounding_boxes[index, 1] = torch.min(y)
  162. bounding_boxes[index, 2] = torch.max(x)
  163. bounding_boxes[index, 3] = torch.max(y)
  164. # debug to pixel datasets
  165. if bounding_boxes[index, 0] == bounding_boxes[index, 2]:
  166. bounding_boxes[index, 2] = bounding_boxes[index, 2] + 1
  167. bounding_boxes[index, 0] = bounding_boxes[index, 0] - 1
  168. if bounding_boxes[index, 1] == bounding_boxes[index, 3]:
  169. bounding_boxes[index, 3] = bounding_boxes[index, 3] + 1
  170. bounding_boxes[index, 1] = bounding_boxes[index, 1] - 1
  171. return bounding_boxes
  172. def line_boxes(target):
  173. boxs = []
  174. lpre = target['wires']["lpre"].cpu().numpy() * 4
  175. vecl_target = target['wires']["lpre_label"].cpu().numpy()
  176. lpre = lpre[vecl_target == 1]
  177. lines = lpre
  178. sline = np.ones(lpre.shape[0])
  179. if len(lines) > 0 and not (lines[0] == 0).all():
  180. for i, ((a, b), s) in enumerate(zip(lines, sline)):
  181. if i > 0 and (lines[i] == lines[0]).all():
  182. break
  183. # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]无明确大小
  184. if a[1] > b[1]:
  185. ymax = a[1] + 10
  186. ymin = b[1] - 10
  187. else:
  188. ymin = a[1] - 10
  189. ymax = b[1] + 10
  190. if a[0] > b[0]:
  191. xmax = a[0] + 10
  192. xmin = b[0] - 10
  193. else:
  194. xmin = a[0] - 10
  195. xmax = b[0] + 10
  196. boxs.append([ymin, xmin, ymax, xmax])
  197. return torch.tensor(boxs)
  198. def read_polygon_points_wire(lbl_path, shape):
  199. """读取 YOLOv8 格式的标注文件并解析多边形轮廓"""
  200. polygon_points = []
  201. w, h = shape[:2]
  202. with open(lbl_path, 'r') as f:
  203. lines = json.load(f)
  204. for line in lines["segmentations"]:
  205. parts = line["data"]
  206. class_id = int(line["cls_id"])
  207. points = np.array(parts, dtype=np.float32).reshape(-1, 2) # 读取点坐标
  208. points[:, 0] *= h
  209. points[:, 1] *= w
  210. polygon_points.append((class_id, points))
  211. return polygon_points
  212. def read_masks_from_txt_wire(label_path, shape):
  213. polygon_points = read_polygon_points_wire(label_path, shape)
  214. masks = create_masks_from_polygons(polygon_points, shape)
  215. labels = [torch.tensor(item[0]) for item in polygon_points]
  216. return labels, masks
  217. def read_masks_from_pixels_wire(lbl_path, shape):
  218. """读取纯像素点格式的文件,不是轮廓像素点"""
  219. h, w = shape
  220. masks = []
  221. labels = []
  222. with open(lbl_path, 'r') as reader:
  223. lines = json.load(reader)
  224. mask_points = []
  225. for line in lines["segmentations"]:
  226. # mask = torch.zeros((h, w), dtype=torch.uint8)
  227. # parts = line["data"]
  228. # print(f'parts:{parts}')
  229. cls = torch.tensor(int(line["cls_id"]), dtype=torch.int64)
  230. labels.append(cls)
  231. # x_array = parts[0::2]
  232. # y_array = parts[1::2]
  233. #
  234. # for x, y in zip(x_array, y_array):
  235. # x = float(x)
  236. # y = float(y)
  237. # mask_points.append((int(y * h), int(x * w)))
  238. # for p in mask_points:
  239. # mask[p] = 1
  240. # masks.append(mask)
  241. reader.close()
  242. return labels
  243. def adjacency_matrix(n, link): # 邻接矩阵
  244. mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8)
  245. link = torch.tensor(link)
  246. if len(link) > 0:
  247. mat[link[:, 0], link[:, 1]] = 1
  248. mat[link[:, 1], link[:, 0]] = 1
  249. return mat