line_dataset.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. import imageio
  2. import numpy as np
  3. from torch.utils.data.dataset import T_co
  4. from libs.vision_libs.utils import draw_keypoints
  5. from models.base.base_dataset import BaseDataset
  6. import json
  7. import os
  8. import PIL
  9. import matplotlib as mpl
  10. from torchvision.utils import draw_bounding_boxes
  11. import torchvision.transforms.v2 as transforms
  12. import torch
  13. import matplotlib.pyplot as plt
  14. from models.base.transforms import get_transforms
  15. def validate_keypoints(keypoints, image_width, image_height):
  16. for kp in keypoints:
  17. x, y, v = kp
  18. if not (0 <= x < image_width and 0 <= y < image_height):
  19. raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
  20. """
  21. 直接读取xanlabel标注的数据集json格式
  22. """
  23. class LineDataset(BaseDataset):
  24. def __init__(self, dataset_path, data_type, transforms=None,augmentation=False, dataset_type=None,img_type='rgb', target_type='pixel'):
  25. super().__init__(dataset_path)
  26. self.data_path = dataset_path
  27. self.data_type = data_type
  28. print(f'data_path:{dataset_path}')
  29. self.transforms = transforms
  30. self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
  31. self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
  32. self.imgs = os.listdir(self.img_path)
  33. self.lbls = os.listdir(self.lbl_path)
  34. self.target_type = target_type
  35. self.img_type=img_type
  36. self.augmentation=augmentation
  37. print(f'augmentation:{augmentation}')
  38. # self.default_transform = DefaultTransform()
  39. def __getitem__(self, index) -> T_co:
  40. img_path = os.path.join(self.img_path, self.imgs[index])
  41. if self.data_type == 'tiff':
  42. lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-4] + 'json')
  43. img = imageio.v3.imread(img_path)[:,:,0]
  44. print(f'img shape:{img.shape}')
  45. w, h = img.shape[:2]
  46. img=img.reshape(w,h,1)
  47. img_3channel = np.zeros((w, h, 3), dtype=img.dtype)
  48. img_3channel[:, :, 2] = img[:, :, 0]
  49. img = torch.from_numpy(img_3channel).permute(2, 1, 0)
  50. else:
  51. lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
  52. img = PIL.Image.open(img_path).convert('RGB')
  53. w, h = img.size
  54. # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  55. target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  56. self.transforms=get_transforms(augmention=self.augmentation)
  57. img, target = self.transforms(img, target)
  58. return img, target
  59. def __len__(self):
  60. return len(self.imgs)
  61. def read_target(self, item, lbl_path, shape, extra=None):
  62. # print(f'shape:{shape}')
  63. # print(f'lbl_path:{lbl_path}')
  64. with open(lbl_path, 'r') as file:
  65. lable_all = json.load(file)
  66. objs = lable_all["shapes"]
  67. point_pairs=objs[0]['points']
  68. # print(f'point_pairs:{point_pairs}')
  69. target = {}
  70. target["image_id"] = torch.tensor(item)
  71. boxes, lines, points, arc_mask,labels = get_boxes_lines(objs, shape)
  72. if points is not None:
  73. target["points"]=points
  74. if lines is not None:
  75. a = torch.full((lines.shape[0],), 2).unsqueeze(1)
  76. lines = torch.cat((lines, a), dim=1)
  77. target["lines"] = lines.to(torch.float32).view(-1, 2, 3)
  78. # print(f'lines shape:{ target["lines"].shape}')
  79. if arc_mask is not None:
  80. target['arc_mask']=arc_mask
  81. # print(f'arc_mask dataset')
  82. # else:
  83. # print(f'not arc_mask dataset')
  84. target["boxes"]=boxes
  85. target["labels"]=labels
  86. # target["boxes"], lines,target["points"], target["labels"] = get_boxes_lines(objs,shape)
  87. # print(f'lines:{lines}')
  88. # target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
  89. # print(f'target points:{target["points"]}')
  90. # target["lines"] = lines.to(torch.float32).view(-1,2,3)
  91. # print(f'')
  92. # print(f'lines:{target["lines"].shape}')
  93. target["img_size"]=shape
  94. # validate_keypoints(lines, shape[0], shape[1])
  95. return target
  96. def show(self, idx,show_type='all'):
  97. image, target = self.__getitem__(idx)
  98. cmap = plt.get_cmap("jet")
  99. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  100. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  101. sm.set_array([])
  102. # img_path = os.path.join(self.img_path, self.imgs[idx])
  103. # print(f'boxes:{target["boxes"]}')
  104. img = image
  105. if show_type=='all':
  106. boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"],
  107. colors="yellow", width=1)
  108. keypoint_img=draw_keypoints(boxed_image,target['points'].unsqueeze(1),colors='red',width=3)
  109. plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
  110. plt.show()
  111. # if show_type=='lines':
  112. # keypoint_img=draw_keypoints((img * 255).to(torch.uint8),target['lines'],colors='red',width=3)
  113. # plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
  114. # plt.show()
  115. if show_type=='points':
  116. # print(f'points:{target['points'].shape}')
  117. keypoint_img=draw_keypoints((img * 255).to(torch.uint8),target['points'].unsqueeze(1),colors='red',width=3)
  118. plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
  119. plt.show()
  120. if show_type=='boxes':
  121. boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"],
  122. colors="yellow", width=1)
  123. plt.imshow(boxed_image.permute(1, 2, 0).numpy())
  124. plt.show()
  125. def show_img(self, img_path):
  126. pass
  127. def get_boxes_lines(objs,shape):
  128. boxes = []
  129. labels=[]
  130. h,w=shape
  131. line_point_pairs = []
  132. points=[]
  133. line_mask=[]
  134. for obj in objs:
  135. # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]无明确大小
  136. # print(f"points:{obj['points']}")
  137. label=obj['label']
  138. if label =='line' or label=='dseam1':
  139. a,b=obj['points'][0],obj['points'][1]
  140. line_point_pairs.append(a)
  141. line_point_pairs.append(b)
  142. xmin = max(0, (min(a[0], b[0]) - 6))
  143. xmax = min(w, (max(a[0], b[0]) + 6))
  144. ymin = max(0, (min(a[1], b[1]) - 6))
  145. ymax = min(h, (max(a[1], b[1]) + 6))
  146. boxes.append([ xmin,ymin, xmax,ymax])
  147. labels.append(torch.tensor(2))
  148. elif label =='point':
  149. p= obj['points'][0]
  150. xmin=max(0,p[0]-12)
  151. xmax = min(w, p[0] +12)
  152. ymin=max(0,p[1]-12)
  153. ymax = min(h, p[1] + 12)
  154. points.append(p)
  155. labels.append(torch.tensor(1))
  156. boxes.append([xmin, ymin, xmax, ymax])
  157. elif label == 'arc' :
  158. line_mask.append(obj['points'])
  159. xmin = obj['xmin']
  160. xmax = obj['xmax']
  161. ymin = obj['ymin']
  162. ymax = obj['ymax']
  163. boxes.append([xmin, ymin, xmax, ymax])
  164. labels.append(torch.tensor(3))
  165. boxes=torch.tensor(boxes)
  166. print(f'boxes:{boxes.shape}')
  167. labels=torch.tensor(labels)
  168. if len(points)==0:
  169. points=None
  170. else:
  171. points=torch.tensor(points,dtype=torch.float32)
  172. print(f'read labels:{labels}')
  173. # print(f'read points:{points}')
  174. if len(line_point_pairs)==0:
  175. line_point_pairs=None
  176. else:
  177. line_point_pairs=torch.tensor(line_point_pairs)
  178. # print(f'line_point_pairs:{line_point_pairs.shape},{line_point_pairs.dtype}')
  179. # print(f'boxes:{boxes.shape},line_point_pairs:{line_point_pairs.shape}')
  180. if len(line_mask)==0:
  181. line_mask=None
  182. else:
  183. line_mask=torch.tensor(line_mask,dtype=torch.float32)
  184. print(f'arc_mask shape :{line_mask.shape},{line_mask.dtype}')
  185. return boxes,line_point_pairs,points,line_mask, labels
  186. if __name__ == '__main__':
  187. path=r"\\192.168.50.222\share\rlq\datasets\Dataset0709_"
  188. dataset= LineDataset(dataset_path=path, dataset_type='train',augmentation=False, data_type='jpg')
  189. dataset.show(1,show_type='all')