line_dataset_old.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. from torch.utils.data.dataset import T_co
  2. from libs.vision_libs.utils import draw_keypoints
  3. from models.base.base_dataset import BaseDataset
  4. import glob
  5. import json
  6. import math
  7. import os
  8. import random
  9. import cv2
  10. import PIL
  11. import imageio
  12. import matplotlib.pyplot as plt
  13. import matplotlib as mpl
  14. from torchvision.utils import draw_bounding_boxes
  15. import numpy as np
  16. import numpy.linalg as LA
  17. import torch
  18. from skimage import io
  19. from torch.utils.data import Dataset
  20. from torch.utils.data.dataloader import default_collate
  21. import matplotlib.pyplot as plt
  22. from models.dataset_tool import read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
  23. def validate_keypoints(keypoints, image_width, image_height):
  24. for kp in keypoints:
  25. x, y, v = kp
  26. if not (0 <= x < image_width and 0 <= y < image_height):
  27. raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
  28. class LineDataset(BaseDataset):
  29. def __init__(self, dataset_path, data_type, transforms=None, dataset_type=None,img_type='rgb', target_type='pixel'):
  30. super().__init__(dataset_path)
  31. self.data_path = dataset_path
  32. self.data_type = data_type
  33. print(f'data_path:{dataset_path}')
  34. self.transforms = transforms
  35. self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
  36. self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
  37. self.imgs = os.listdir(self.img_path)
  38. self.lbls = os.listdir(self.lbl_path)
  39. self.target_type = target_type
  40. self.img_type=img_type
  41. # self.default_transform = DefaultTransform()
  42. def __getitem__(self, index) -> T_co:
  43. img_path = os.path.join(self.img_path, self.imgs[index])
  44. if self.data_type == 'tiff':
  45. lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-4] + 'json')
  46. # img = imageio.v3.imread(img_path).reshape(512, 512, 1)
  47. img = imageio.v3.imread(img_path)[:, :, :3]
  48. # img_3channel = np.zeros((512, 512, 3), dtype=img.dtype)
  49. # img_3channel[:, :, 2] = img[:, :, 0]
  50. img_3channel=img
  51. w, h = img.shape[:2]
  52. img = torch.from_numpy(img_3channel).permute(2, 0, 1)
  53. else:
  54. lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
  55. img = PIL.Image.open(img_path).convert('RGB')
  56. w, h = img.size
  57. # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  58. target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  59. if self.transforms:
  60. img, target = self.transforms(img, target)
  61. else:
  62. img = self.default_transform(img)
  63. # print(f'img:{img}')
  64. return img, target
  65. def __len__(self):
  66. return len(self.imgs)
  67. def read_target(self, item, lbl_path, shape, extra=None):
  68. # print(f'shape:{shape}')
  69. # print(f'lbl_path:{lbl_path}')
  70. with open(lbl_path, 'r') as file:
  71. lable_all = json.load(file)
  72. n_stc_posl = 300
  73. n_stc_negl = 40
  74. use_cood = 0
  75. use_slop = 0
  76. wire = lable_all["wires"][0] # 字典
  77. line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl] # 不足,有多少取多少
  78. line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
  79. npos, nneg = len(line_pos_coords), len(line_neg_coords)
  80. lpre = np.concatenate([line_pos_coords, line_neg_coords], 0) # 正负样本坐标合在一起
  81. for i in range(len(lpre)):
  82. if random.random() > 0.5:
  83. lpre[i] = lpre[i, ::-1]
  84. ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
  85. ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
  86. feat = [
  87. lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
  88. ldir * use_slop,
  89. lpre[:, :, 2],
  90. ]
  91. feat = np.concatenate(feat, 1)
  92. wire_labels = {
  93. "junc_coords": torch.tensor(wire["junc_coords"]["content"]),
  94. "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
  95. "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
  96. # 真实存在线条的邻接矩阵
  97. "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
  98. "lpre": torch.tensor(lpre)[:, :, :2],
  99. "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]), # 样本对应标签 1,0
  100. "lpre_feat": torch.from_numpy(feat),
  101. "junc_map": torch.tensor(wire['junc_map']["content"]),
  102. "junc_offset": torch.tensor(wire['junc_offset']["content"]),
  103. "line_map": torch.tensor(wire['line_map']["content"]),
  104. }
  105. labels = []
  106. if self.target_type == 'polygon':
  107. labels, masks = read_masks_from_txt_wire(lbl_path, shape)
  108. elif self.target_type == 'pixel':
  109. labels = read_masks_from_pixels_wire(lbl_path, shape)
  110. # print(torch.stack(masks).shape) # [线段数, 512, 512]
  111. target = {}
  112. target["image_id"] = torch.tensor(item)
  113. # return wire_labels, target
  114. target["wires"] = wire_labels
  115. # target["labels"] = torch.stack(labels)
  116. # print(f'labels:{target["labels"]}')
  117. # target["boxes"] = line_boxes(target)
  118. target["boxes"], lines = get_boxes_lines(target)
  119. target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
  120. # keypoints=keypoints/512
  121. # visibility_flags = torch.ones((wire_labels["junc_coords"].shape[0], 1))
  122. # keypoints= wire_labels["junc_coords"]
  123. a = torch.full((lines.shape[0],), 2).unsqueeze(1)
  124. lines = torch.cat((lines, a), dim=1)
  125. target["lines"] = lines.to(torch.float32).view(-1,2,3)
  126. target["img_size"] = shape
  127. # print(f'boxes:{target["boxes"].shape}')
  128. # 在 __getitem__ 方法中调用此函数
  129. validate_keypoints(lines, shape[0], shape[1])
  130. # print(f'keypoints:{target["keypoints"].shape}')
  131. # print(f'target:{target}')
  132. return target
  133. def show(self, idx):
  134. image, target = self.__getitem__(idx)
  135. cmap = plt.get_cmap("jet")
  136. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  137. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  138. sm.set_array([])
  139. img_path = os.path.join(self.img_path, self.imgs[idx])
  140. img = PIL.Image.open(img_path).convert('RGB')
  141. boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
  142. colors="yellow", width=1)
  143. keypoint_img=draw_keypoints(boxed_image,target['keypoints'],colors='red',width=3)
  144. plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
  145. plt.show()
  146. def show_img(self, img_path):
  147. pass
  148. def get_boxes_lines(target):
  149. boxs = []
  150. lpre = target['wires']["lpre"].cpu().numpy()
  151. vecl_target = target['wires']["lpre_label"].cpu().numpy()
  152. lpre = lpre[vecl_target == 1]
  153. lines = lpre
  154. sline = np.ones(lpre.shape[0])
  155. line_point_pairs = []
  156. if len(lines) > 0 and not (lines[0] == 0).all():
  157. for i, ((a, b), s) in enumerate(zip(lines, sline)):
  158. if i > 0 and (lines[i] == lines[0]).all():
  159. break
  160. # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]无明确大小
  161. line_point_pairs.append([a[1], a[0]])
  162. line_point_pairs.append([b[1], b[0]])
  163. xmin = max(0, (min(a[0], b[0]) - 6))
  164. xmax = min(511, (max(a[0], b[0]) + 6))
  165. ymin = max(0, (min(a[1], b[1]) - 6))
  166. ymax = min(511, (max(a[1], b[1]) + 6))
  167. boxs.append([ymin, xmin, ymax, xmax])
  168. return torch.tensor(boxs), torch.tensor(line_point_pairs)
  169. if __name__ == '__main__':
  170. path=r"\\192.168.50.222/share/lm/Dataset_all"
  171. dataset= LineDataset(dataset_path=path, dataset_type='train')
  172. dataset.show(10)