line_dataset.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. from torch.utils.data.dataset import T_co
  2. from libs.vision_libs.utils import draw_keypoints
  3. from models.base.base_dataset import BaseDataset
  4. import glob
  5. import json
  6. import math
  7. import os
  8. import random
  9. import cv2
  10. import PIL
  11. import imageio
  12. import matplotlib.pyplot as plt
  13. import matplotlib as mpl
  14. from torchvision.utils import draw_bounding_boxes
  15. import numpy as np
  16. import numpy.linalg as LA
  17. import torch
  18. from skimage import io
  19. from torch.utils.data import Dataset
  20. from torch.utils.data.dataloader import default_collate
  21. import matplotlib.pyplot as plt
  22. from models.dataset_tool import read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
  23. def validate_keypoints(keypoints, image_width, image_height):
  24. for kp in keypoints:
  25. x, y, v = kp
  26. if not (0 <= x < image_width and 0 <= y < image_height):
  27. raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
  28. """
  29. 直接读取xanlabel标注的数据集json格式
  30. """
  31. class LineDataset(BaseDataset):
  32. def __init__(self, dataset_path, data_type, transforms=None, dataset_type=None,img_type='rgb', target_type='pixel'):
  33. super().__init__(dataset_path)
  34. self.data_path = dataset_path
  35. self.data_type = data_type
  36. print(f'data_path:{dataset_path}')
  37. self.transforms = transforms
  38. self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
  39. self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
  40. self.imgs = os.listdir(self.img_path)
  41. self.lbls = os.listdir(self.lbl_path)
  42. self.target_type = target_type
  43. self.img_type=img_type
  44. # self.default_transform = DefaultTransform()
  45. def __getitem__(self, index) -> T_co:
  46. img_path = os.path.join(self.img_path, self.imgs[index])
  47. lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
  48. img = PIL.Image.open(img_path).convert('RGB')
  49. w, h = img.size
  50. # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  51. target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  52. if self.transforms:
  53. img, target = self.transforms(img, target)
  54. else:
  55. img = self.default_transform(img)
  56. # print(f'img:{img}')
  57. # print(f'img shape:{img.shape}')
  58. return img, target
  59. def __len__(self):
  60. return len(self.imgs)
  61. def read_target(self, item, lbl_path, shape, extra=None):
  62. # print(f'shape:{shape}')
  63. # print(f'lbl_path:{lbl_path}')
  64. with open(lbl_path, 'r') as file:
  65. lable_all = json.load(file)
  66. objs = lable_all["shapes"]
  67. point_pairs=objs[0]['points']
  68. # print(f'point_pairs:{point_pairs}')
  69. target = {}
  70. target["image_id"] = torch.tensor(item)
  71. target["boxes"], lines = get_boxes_lines(objs,shape)
  72. # print(f'lines:{lines}')
  73. target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
  74. a = torch.full((lines.shape[0],), 2).unsqueeze(1)
  75. lines = torch.cat((lines, a), dim=1)
  76. target["lines"] = lines.to(torch.float32).view(-1,2,3)
  77. target["img_size"]=shape
  78. validate_keypoints(lines, shape[0], shape[1])
  79. return target
  80. def show(self, idx,show_type='all'):
  81. image, target = self.__getitem__(idx)
  82. cmap = plt.get_cmap("jet")
  83. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  84. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  85. sm.set_array([])
  86. img_path = os.path.join(self.img_path, self.imgs[idx])
  87. img = PIL.Image.open(img_path).convert('RGB')
  88. if show_type=='all':
  89. boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
  90. colors="yellow", width=1)
  91. keypoint_img=draw_keypoints(boxed_image,target['lines'],colors='red',width=3)
  92. plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
  93. plt.show()
  94. if show_type=='lines':
  95. keypoint_img=draw_keypoints((self.default_transform(img) * 255).to(torch.uint8),target['lines'],colors='red',width=3)
  96. plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
  97. plt.show()
  98. if show_type=='boxes':
  99. boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
  100. colors="yellow", width=1)
  101. plt.imshow(boxed_image.permute(1, 2, 0).numpy())
  102. plt.show()
  103. def show_img(self, img_path):
  104. pass
  105. def get_boxes_lines(objs,shape):
  106. boxes = []
  107. h,w=shape
  108. line_point_pairs = []
  109. for obj in objs:
  110. # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]无明确大小
  111. # print(f"points:{obj['points']}")
  112. a,b=obj['points'][0],obj['points'][1]
  113. line_point_pairs.append(a)
  114. line_point_pairs.append(b)
  115. xmin = max(0, (min(a[0], b[0]) - 6))
  116. xmax = min(w, (max(a[0], b[0]) + 6))
  117. ymin = max(0, (min(a[1], b[1]) - 6))
  118. ymax = min(h, (max(a[1], b[1]) + 6))
  119. boxes.append([ xmin,ymin, xmax,ymax])
  120. boxes=torch.tensor(boxes)
  121. line_point_pairs=torch.tensor(line_point_pairs)
  122. # print(f'boxes:{boxes.shape},line_point_pairs:{line_point_pairs.shape}')
  123. return boxes,line_point_pairs
  124. if __name__ == '__main__':
  125. path=r"\\192.168.50.222/share/rlq/datasets/0706_"
  126. dataset= LineDataset(dataset_path=path, dataset_type='train',data_type='jpg')
  127. dataset.show(1,show_type='lines')