dataset_LD.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. # ??roi_head??????????????
  2. import imageio
  3. from torch.utils.data.dataset import T_co
  4. from models.base.base_dataset import BaseDataset
  5. import glob
  6. import json
  7. import math
  8. import os
  9. import random
  10. import cv2
  11. import PIL
  12. import matplotlib.pyplot as plt
  13. import matplotlib as mpl
  14. from torchvision.utils import draw_bounding_boxes
  15. import numpy as np
  16. import numpy.linalg as LA
  17. import torch
  18. from skimage import io
  19. from torch.utils.data import Dataset
  20. from torch.utils.data.dataloader import default_collate
  21. import matplotlib.pyplot as plt
  22. from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
  23. from tools.presets import DetectionPresetTrain
  24. def line_boxes1(target):
  25. boxs = []
  26. lines = target.cpu().numpy() * 4
  27. if len(lines) > 0 and not (lines[0] == 0).all():
  28. for i, ((a, b)) in enumerate(lines):
  29. if i > 0 and (lines[i] == lines[0]).all():
  30. break
  31. if a[1] > b[1]:
  32. ymax = a[1] + 10
  33. ymin = b[1] - 10
  34. else:
  35. ymin = a[1] - 10
  36. ymax = b[1] + 10
  37. if a[0] > b[0]:
  38. xmax = a[0] + 10
  39. xmin = b[0] - 10
  40. else:
  41. xmin = a[0] - 10
  42. xmax = b[0] + 10
  43. boxs.append([ymin, xmin, ymax, xmax])
  44. # if boxs == []:
  45. # print(target)
  46. return torch.tensor(boxs)
  47. class WirePointDataset(BaseDataset):
  48. def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
  49. super().__init__(dataset_path)
  50. self.data_path = dataset_path
  51. print(f'data_path:{dataset_path}')
  52. self.transforms = transforms
  53. self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
  54. self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
  55. self.imgs = os.listdir(self.img_path)
  56. self.lbls = os.listdir(self.lbl_path)
  57. self.target_type = target_type
  58. # self.default_transform = DefaultTransform()
  59. def __getitem__(self, index) -> T_co:
  60. img_path = os.path.join(self.img_path, self.imgs[index])
  61. lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-4] + 'json')
  62. img=imageio.v3.imread(img_path).reshape(512,512,1)
  63. img_3channel = np.zeros((512, 512, 3), dtype=img.dtype)
  64. # 将原始通道复制到第一个通道
  65. img_3channel[:, :, 2] = img[:, :, 0]
  66. # print(f'dataset img shape:{img.shape}')
  67. # img = PIL.Image.open(img_path).convert('RGB')
  68. w, h = img.shape[:2]
  69. img=torch.from_numpy(img_3channel).permute(2, 0, 1)
  70. img=self.zscore_normalize_depth(img)
  71. # img=img.transpose(2,0,1)
  72. # print(f'dataset img shape2:{img.shape}')
  73. # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  74. target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  75. if self.transforms:
  76. img, target = self.transforms(img, target)
  77. else:
  78. img = self.default_transform(img)
  79. # print(f'img:{img.shape}')
  80. return img, target
  81. def __len__(self):
  82. return len(self.imgs)
  83. def zscore_normalize_depth(self,img):
  84. depth = img[2]
  85. mean = depth.mean()
  86. std = depth.std()
  87. depth_normalized = (depth - mean) / (std + 1e-8)
  88. img_normalized = img.clone()
  89. img_normalized[2] = depth_normalized
  90. return img_normalized
  91. def read_target(self, item, lbl_path, shape, extra=None):
  92. # print(f'lbl_path:{lbl_path}')
  93. with open(lbl_path, 'r') as file:
  94. lable_all = json.load(file)
  95. n_stc_posl = 300
  96. n_stc_negl = 40
  97. use_cood = 0
  98. use_slop = 0
  99. wire = lable_all["wires"][0] # ??
  100. line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl] # ?????????
  101. line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
  102. npos, nneg = len(line_pos_coords), len(line_neg_coords)
  103. lpre = np.concatenate([line_pos_coords, line_neg_coords], 0) # ??????????
  104. for i in range(len(lpre)):
  105. if random.random() > 0.5:
  106. lpre[i] = lpre[i, ::-1]
  107. ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
  108. ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
  109. feat = [
  110. lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
  111. ldir * use_slop,
  112. lpre[:, :, 2],
  113. ]
  114. feat = np.concatenate(feat, 1)
  115. wire_labels = {
  116. "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
  117. "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
  118. "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
  119. # ???????????
  120. "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
  121. # ??????????
  122. "lpre": torch.tensor(lpre)[:, :, :2],
  123. "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]), # ?????? 1?0
  124. "lpre_feat": torch.from_numpy(feat),
  125. "junc_map": torch.tensor(wire['junc_map']["content"]),
  126. "junc_offset": torch.tensor(wire['junc_offset']["content"]),
  127. "line_map": torch.tensor(wire['line_map']["content"]),
  128. }
  129. labels = []
  130. #
  131. # if self.target_type == 'polygon':
  132. # labels, masks = read_masks_from_txt_wire(lbl_path, shape)
  133. # elif self.target_type == 'pixel':
  134. # labels = read_masks_from_pixels_wire(lbl_path, shape)
  135. # print(torch.stack(masks).shape) # [???, 512, 512]
  136. target = {}
  137. # target["labels"] = torch.stack(labels)
  138. target["image_id"] = torch.tensor(item)
  139. # return wire_labels, target
  140. target["wires"] = wire_labels
  141. # target["boxes"] = line_boxes(target)
  142. target["boxes"] = line_boxes1(torch.tensor(wire["line_pos_coords"]["content"]))
  143. target["labels"]= torch.ones(len(target["boxes"]),dtype=torch.int64)
  144. # print(f'target["labels"]:{ target["labels"]}')
  145. # print(f'boxes:{target["boxes"].shape}')
  146. if target["boxes"].numel() == 0:
  147. print("Tensor is empty")
  148. print(f'path:{lbl_path}')
  149. return target
  150. def show(self, idx):
  151. image, target = self.__getitem__(idx)
  152. cmap = plt.get_cmap("jet")
  153. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  154. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  155. sm.set_array([])
  156. def imshow(im):
  157. plt.close()
  158. plt.tight_layout()
  159. plt.imshow(im)
  160. plt.colorbar(sm, fraction=0.046)
  161. plt.xlim([0, im.shape[0]])
  162. plt.ylim([im.shape[0], 0])
  163. def draw_vecl(lines, sline, juncs, junts, fn=None):
  164. img_path = os.path.join(self.img_path, self.imgs[idx])
  165. imshow(io.imread(img_path))
  166. if len(lines) > 0 and not (lines[0] == 0).all():
  167. for i, ((a, b), s) in enumerate(zip(lines, sline)):
  168. if i > 0 and (lines[i] == lines[0]).all():
  169. break
  170. plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]?????
  171. if not (juncs[0] == 0).all():
  172. for i, j in enumerate(juncs):
  173. if i > 0 and (i == juncs[0]).all():
  174. break
  175. plt.scatter(j[1], j[0], c="red", s=2, zorder=100) # ? s=64
  176. img_path = os.path.join(self.img_path, self.imgs[idx])
  177. img = PIL.Image.open(img_path).convert('RGB')
  178. boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
  179. colors="yellow", width=1)
  180. plt.imshow(boxed_image.permute(1, 2, 0).numpy())
  181. plt.show()
  182. plt.show()
  183. if fn != None:
  184. plt.savefig(fn)
  185. junc = target['wires']['junc_coords'].cpu().numpy() * 4
  186. jtyp = target['wires']['jtyp'].cpu().numpy()
  187. juncs = junc[jtyp == 0]
  188. junts = junc[jtyp == 1]
  189. lpre = target['wires']["lpre"].cpu().numpy() * 4
  190. vecl_target = target['wires']["lpre_label"].cpu().numpy()
  191. lpre = lpre[vecl_target == 1]
  192. # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
  193. draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
  194. def show_img(self, img_path):
  195. pass
  196. # dataset_train = WirePointDataset("/data/lm/dataset/0424_", dataset_type='val')
  197. # for i in dataset_train:
  198. # a = 1