# line_dataset.py
  1. from torch.utils.data.dataset import T_co
  2. from libs.vision_libs.utils import draw_keypoints
  3. from models.base.base_dataset import BaseDataset
  4. import glob
  5. import json
  6. import math
  7. import os
  8. import random
  9. import cv2
  10. import PIL
  11. import imageio
  12. import matplotlib.pyplot as plt
  13. import matplotlib as mpl
  14. from torchvision.utils import draw_bounding_boxes
  15. import torchvision.transforms.v2 as transforms
  16. import numpy as np
  17. import numpy.linalg as LA
  18. import torch
  19. from skimage import io
  20. from torch.utils.data import Dataset
  21. from torch.utils.data.dataloader import default_collate
  22. import matplotlib.pyplot as plt
  23. from models.dataset_tool import read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
  24. def validate_keypoints(keypoints, image_width, image_height):
  25. for kp in keypoints:
  26. x, y, v = kp
  27. if not (0 <= x < image_width and 0 <= y < image_height):
  28. raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
  29. def apply_transform_with_boxes_and_keypoints(img,target):
  30. """
  31. 对图像、边界框和关键点应用相同的变换。
  32. :param img_path: 图像文件路径
  33. :param boxes: 形状为 (N, 4) 的 Tensor,表示 N 个边界框的坐标 [x_min, y_min, x_max, y_max]
  34. :param keypoints: 形状为 (N, K, 3) 的 Tensor,表示 N 个实例的 K 个关键点的坐标和可见性 [x, y, visibility]
  35. :return: 变换后的图像、边界框和关键点
  36. """
  37. # 定义一系列用于数据增强的变换
  38. data_transforms = transforms.Compose([
  39. # 随机调整大小和随机裁剪
  40. # transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0), antialias=True),
  41. # 随机水平翻转
  42. transforms.RandomHorizontalFlip(p=0.5),
  43. # 颜色抖动: 改变亮度、对比度、饱和度和色调
  44. transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
  45. # # 转换为张量
  46. # transforms.ToTensor(),
  47. #
  48. # # 标准化
  49. # transforms.Normalize(mean=[0.485, 0.456, 0.406],
  50. # std=[0.229, 0.224, 0.225])
  51. ])
  52. boxes=target['boxes']
  53. keypoints=target['lines']
  54. # 将边界框转换为适合传递给 transforms 的格式
  55. boxes_format = [(box[0].item(), box[1].item(), box[2].item(), box[3].item()) for box in boxes]
  56. # 将关键点转换为适合传递给 transforms 的格式
  57. keypoints_format = [[(kp[0].item(), kp[1].item(), bool(kp[2].item())) for kp in keypoint] for keypoint in keypoints]
  58. # 应用变换
  59. transformed = data_transforms(img, {"boxes": boxes_format, "keypoints": keypoints_format})
  60. # 获取变换后的图像、边界框和关键点
  61. img_transformed = transformed[0]
  62. boxes_transformed = torch.tensor([(box[0], box[1], box[2], box[3]) for box in transformed[1]['boxes']],
  63. dtype=torch.float32)
  64. keypoints_transformed = torch.tensor(
  65. [[(kp[0], kp[1], int(kp[2])) for kp in keypoint] for keypoint in transformed[1]['keypoints']],
  66. dtype=torch.float32)
  67. target['boxes']=boxes_transformed
  68. target['lines']=keypoints_transformed
  69. return img_transformed, target
  70. """
  71. 直接读取xanlabel标注的数据集json格式
  72. """
  73. class LineDataset(BaseDataset):
  74. def __init__(self, dataset_path, data_type, transforms=None,augmentation=False, dataset_type=None,img_type='rgb', target_type='pixel'):
  75. super().__init__(dataset_path)
  76. self.data_path = dataset_path
  77. self.data_type = data_type
  78. print(f'data_path:{dataset_path}')
  79. self.transforms = transforms
  80. self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
  81. self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
  82. self.imgs = os.listdir(self.img_path)
  83. self.lbls = os.listdir(self.lbl_path)
  84. self.target_type = target_type
  85. self.img_type=img_type
  86. self.augmentation=augmentation
  87. # self.default_transform = DefaultTransform()
  88. def __getitem__(self, index) -> T_co:
  89. img_path = os.path.join(self.img_path, self.imgs[index])
  90. lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
  91. img = PIL.Image.open(img_path).convert('RGB')
  92. w, h = img.size
  93. # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  94. target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  95. if self.transforms:
  96. img, target = self.transforms(img, target)
  97. else:
  98. img = self.default_transform(img)
  99. # print(f'img:{img}')
  100. # print(f'img shape:{img.shape}')
  101. if self.augmentation:
  102. img, target=apply_transform_with_boxes_and_keypoints(img, target)
  103. return img, target
  104. def __len__(self):
  105. return len(self.imgs)
  106. def read_target(self, item, lbl_path, shape, extra=None):
  107. # print(f'shape:{shape}')
  108. # print(f'lbl_path:{lbl_path}')
  109. with open(lbl_path, 'r') as file:
  110. lable_all = json.load(file)
  111. objs = lable_all["shapes"]
  112. point_pairs=objs[0]['points']
  113. # print(f'point_pairs:{point_pairs}')
  114. target = {}
  115. target["image_id"] = torch.tensor(item)
  116. target["boxes"], lines = get_boxes_lines(objs,shape)
  117. # print(f'lines:{lines}')
  118. target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
  119. a = torch.full((lines.shape[0],), 2).unsqueeze(1)
  120. lines = torch.cat((lines, a), dim=1)
  121. target["lines"] = lines.to(torch.float32).view(-1,2,3)
  122. target["img_size"]=shape
  123. validate_keypoints(lines, shape[0], shape[1])
  124. return target
  125. def show(self, idx,show_type='all'):
  126. image, target = self.__getitem__(idx)
  127. cmap = plt.get_cmap("jet")
  128. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  129. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  130. sm.set_array([])
  131. # img_path = os.path.join(self.img_path, self.imgs[idx])
  132. img = image
  133. if show_type=='all':
  134. boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"],
  135. colors="yellow", width=1)
  136. keypoint_img=draw_keypoints(boxed_image,target['lines'],colors='red',width=3)
  137. plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
  138. plt.show()
  139. if show_type=='lines':
  140. keypoint_img=draw_keypoints((img * 255).to(torch.uint8),target['lines'],colors='red',width=3)
  141. plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
  142. plt.show()
  143. if show_type=='boxes':
  144. boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"],
  145. colors="yellow", width=1)
  146. plt.imshow(boxed_image.permute(1, 2, 0).numpy())
  147. plt.show()
  148. def show_img(self, img_path):
  149. pass
  150. def get_boxes_lines(objs,shape):
  151. boxes = []
  152. h,w=shape
  153. line_point_pairs = []
  154. for obj in objs:
  155. # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]无明确大小
  156. # print(f"points:{obj['points']}")
  157. a,b=obj['points'][0],obj['points'][1]
  158. line_point_pairs.append(a)
  159. line_point_pairs.append(b)
  160. xmin = max(0, (min(a[0], b[0]) - 6))
  161. xmax = min(w, (max(a[0], b[0]) + 6))
  162. ymin = max(0, (min(a[1], b[1]) - 6))
  163. ymax = min(h, (max(a[1], b[1]) + 6))
  164. boxes.append([ xmin,ymin, xmax,ymax])
  165. boxes=torch.tensor(boxes)
  166. line_point_pairs=torch.tensor(line_point_pairs)
  167. # print(f'boxes:{boxes.shape},line_point_pairs:{line_point_pairs.shape}')
  168. return boxes,line_point_pairs
  169. if __name__ == '__main__':
  170. path=r"\\192.168.50.222/share/rlq/datasets/0706_"
  171. dataset= LineDataset(dataset_path=path, dataset_type='train',augmentation=True, data_type='jpg')
  172. dataset.show(1,show_type='all')