line_dataset.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. from torch.utils.data.dataset import T_co
  2. from libs.vision_libs.utils import draw_keypoints
  3. from models.base.base_dataset import BaseDataset
  4. import json
  5. import os
  6. import PIL
  7. import matplotlib as mpl
  8. from torchvision.utils import draw_bounding_boxes
  9. import torchvision.transforms.v2 as transforms
  10. import torch
  11. import matplotlib.pyplot as plt
  12. from models.base.transforms import get_transforms
  13. def validate_keypoints(keypoints, image_width, image_height):
  14. for kp in keypoints:
  15. x, y, v = kp
  16. if not (0 <= x < image_width and 0 <= y < image_height):
  17. raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
"""
Reads the JSON-format dataset annotated with xanlabel directly.
"""
  21. class LineDataset(BaseDataset):
  22. def __init__(self, dataset_path, data_type, transforms=None,augmentation=False, dataset_type=None,img_type='rgb', target_type='pixel'):
  23. super().__init__(dataset_path)
  24. self.data_path = dataset_path
  25. self.data_type = data_type
  26. print(f'data_path:{dataset_path}')
  27. self.transforms = transforms
  28. self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
  29. self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
  30. self.imgs = os.listdir(self.img_path)
  31. self.lbls = os.listdir(self.lbl_path)
  32. self.target_type = target_type
  33. self.img_type=img_type
  34. self.augmentation=augmentation
  35. # self.default_transform = DefaultTransform()
  36. def __getitem__(self, index) -> T_co:
  37. img_path = os.path.join(self.img_path, self.imgs[index])
  38. lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
  39. img = PIL.Image.open(img_path).convert('RGB')
  40. w, h = img.size
  41. # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  42. target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  43. self.transforms=get_transforms(augmention=self.augmentation)
  44. img, target = self.transforms(img, target)
  45. return img, target
  46. def __len__(self):
  47. return len(self.imgs)
  48. def read_target(self, item, lbl_path, shape, extra=None):
  49. # print(f'shape:{shape}')
  50. # print(f'lbl_path:{lbl_path}')
  51. with open(lbl_path, 'r') as file:
  52. lable_all = json.load(file)
  53. objs = lable_all["shapes"]
  54. point_pairs=objs[0]['points']
  55. # print(f'point_pairs:{point_pairs}')
  56. target = {}
  57. target["image_id"] = torch.tensor(item)
  58. target["boxes"], lines = get_boxes_lines(objs,shape)
  59. # print(f'lines:{lines}')
  60. target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
  61. a = torch.full((lines.shape[0],), 2).unsqueeze(1)
  62. lines = torch.cat((lines, a), dim=1)
  63. target["lines"] = lines.to(torch.float32).view(-1,2,3)
  64. print(f'lines:{target["lines"].shape}')
  65. target["img_size"]=shape
  66. validate_keypoints(lines, shape[0], shape[1])
  67. return target
  68. def show(self, idx,show_type='all'):
  69. image, target = self.__getitem__(idx)
  70. cmap = plt.get_cmap("jet")
  71. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  72. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  73. sm.set_array([])
  74. # img_path = os.path.join(self.img_path, self.imgs[idx])
  75. img = image
  76. if show_type=='all':
  77. boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"],
  78. colors="yellow", width=1)
  79. keypoint_img=draw_keypoints(boxed_image,target['lines'],colors='red',width=3)
  80. plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
  81. plt.show()
  82. if show_type=='lines':
  83. keypoint_img=draw_keypoints((img * 255).to(torch.uint8),target['lines'],colors='red',width=3)
  84. plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
  85. plt.show()
  86. if show_type=='boxes':
  87. boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"],
  88. colors="yellow", width=1)
  89. plt.imshow(boxed_image.permute(1, 2, 0).numpy())
  90. plt.show()
  91. def show_img(self, img_path):
  92. pass
  93. def get_boxes_lines(objs,shape):
  94. boxes = []
  95. h,w=shape
  96. line_point_pairs = []
  97. for obj in objs:
  98. # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]无明确大小
  99. # print(f"points:{obj['points']}")
  100. a,b=obj['points'][0],obj['points'][1]
  101. line_point_pairs.append(a)
  102. line_point_pairs.append(b)
  103. xmin = max(0, (min(a[0], b[0]) - 6))
  104. xmax = min(w, (max(a[0], b[0]) + 6))
  105. ymin = max(0, (min(a[1], b[1]) - 6))
  106. ymax = min(h, (max(a[1], b[1]) + 6))
  107. boxes.append([ xmin,ymin, xmax,ymax])
  108. boxes=torch.tensor(boxes)
  109. line_point_pairs=torch.tensor(line_point_pairs)
  110. # print(f'boxes:{boxes.shape},line_point_pairs:{line_point_pairs.shape}')
  111. return boxes,line_point_pairs
  112. if __name__ == '__main__':
  113. path=r"\\192.168.50.222/share/rlq/datasets/0706_"
  114. dataset= LineDataset(dataset_path=path, dataset_type='train',augmentation=True, data_type='jpg')
  115. dataset.show(1,show_type='all')