dataset_LD.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. # ??roi_head??????????????
  2. from torch.utils.data.dataset import T_co
  3. from models.base.base_dataset import BaseDataset
  4. import glob
  5. import json
  6. import math
  7. import os
  8. import random
  9. import cv2
  10. import PIL
  11. import matplotlib.pyplot as plt
  12. import matplotlib as mpl
  13. from torchvision.utils import draw_bounding_boxes
  14. import numpy as np
  15. import numpy.linalg as LA
  16. import torch
  17. from skimage import io
  18. from torch.utils.data import Dataset
  19. from torch.utils.data.dataloader import default_collate
  20. import matplotlib.pyplot as plt
  21. from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
  22. from tools.presets import DetectionPresetTrain
  23. def line_boxes1(target):
  24. boxs = []
  25. lines = target.cpu().numpy() * 4
  26. if len(lines) > 0 and not (lines[0] == 0).all():
  27. for i, ((a, b)) in enumerate(lines):
  28. if i > 0 and (lines[i] == lines[0]).all():
  29. break
  30. if a[1] > b[1]:
  31. ymax = a[1] + 10
  32. ymin = b[1] - 10
  33. else:
  34. ymin = a[1] - 10
  35. ymax = b[1] + 10
  36. if a[0] > b[0]:
  37. xmax = a[0] + 10
  38. xmin = b[0] - 10
  39. else:
  40. xmin = a[0] - 10
  41. xmax = b[0] + 10
  42. boxs.append([ymin, xmin, ymax, xmax])
  43. # if boxs == []:
  44. # print(target)
  45. return torch.tensor(boxs)
  46. class WirePointDataset(BaseDataset):
  47. def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
  48. super().__init__(dataset_path)
  49. self.data_path = dataset_path
  50. print(f'data_path:{dataset_path}')
  51. self.transforms = transforms
  52. self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
  53. self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
  54. self.imgs = os.listdir(self.img_path)
  55. self.lbls = os.listdir(self.lbl_path)
  56. self.target_type = target_type
  57. # self.default_transform = DefaultTransform()
  58. def __getitem__(self, index) -> T_co:
  59. img_path = os.path.join(self.img_path, self.imgs[index])
  60. lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
  61. # img = PIL.Image.open(img_path).convert('RGB')
  62. # w, h = img.size
  63. img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
  64. print(img.shape)
  65. w, h = img.shape[0:2]
  66. # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  67. target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  68. # if self.transforms:
  69. # img, target = self.transforms(img, target)
  70. # else:
  71. # img = self.default_transform(img)
  72. # 分离RGB和深度通道
  73. rgb_channels = img[:, :, :3]
  74. depth_channel = img[:, :, 3]
  75. # rgb_normalized = rgb_channels.astype(np.float32) / 255.0
  76. rgb_normalized = rgb_channels
  77. depth_normalized = (depth_channel - depth_channel.min()) / (depth_channel.max() - depth_channel.min())*255
  78. # 将归一化后的RGB和深度通道重新组合
  79. normalized_rgba_image = np.dstack((rgb_normalized, depth_normalized)) # 或者使用depth_normalized_fixed_range
  80. print("Normalized RGBA image shape:", normalized_rgba_image.shape)
  81. img = torch.tensor(normalized_rgba_image,dtype=torch.float32).permute(2,1,0)
  82. # new_channel = torch.zeros(1, 512, 512)
  83. # img=torch.cat((img,new_channel),dim=0)
  84. print(f'img:{img.shape}')
  85. # print(f'img dtype:{img.dtype}')
  86. return img, target
  87. def __len__(self):
  88. return len(self.imgs)
  89. def read_target(self, item, lbl_path, shape, extra=None):
  90. # print(f'lbl_path:{lbl_path}')
  91. with open(lbl_path, 'r') as file:
  92. lable_all = json.load(file)
  93. n_stc_posl = 300
  94. n_stc_negl = 40
  95. use_cood = 0
  96. use_slop = 0
  97. wire = lable_all["wires"][0] # ??
  98. line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl] # ?????????
  99. line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
  100. npos, nneg = len(line_pos_coords), len(line_neg_coords)
  101. lpre = np.concatenate([line_pos_coords, line_neg_coords], 0) # ??????????
  102. for i in range(len(lpre)):
  103. if random.random() > 0.5:
  104. lpre[i] = lpre[i, ::-1]
  105. ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
  106. ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
  107. feat = [
  108. lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
  109. ldir * use_slop,
  110. lpre[:, :, 2],
  111. ]
  112. feat = np.concatenate(feat, 1)
  113. wire_labels = {
  114. "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
  115. "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
  116. "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
  117. # ???????????
  118. "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
  119. # ??????????
  120. "lpre": torch.tensor(lpre)[:, :, :2],
  121. "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]), # ?????? 1?0
  122. "lpre_feat": torch.from_numpy(feat),
  123. "junc_map": torch.tensor(wire['junc_map']["content"]),
  124. "junc_offset": torch.tensor(wire['junc_offset']["content"]),
  125. "line_map": torch.tensor(wire['line_map']["content"]),
  126. }
  127. labels = []
  128. #
  129. # if self.target_type == 'polygon':
  130. # labels, masks = read_masks_from_txt_wire(lbl_path, shape)
  131. # elif self.target_type == 'pixel':
  132. # labels = read_masks_from_pixels_wire(lbl_path, shape)
  133. # print(torch.stack(masks).shape) # [???, 512, 512]
  134. target = {}
  135. # target["labels"] = torch.stack(labels)
  136. target["image_id"] = torch.tensor(item)
  137. # return wire_labels, target
  138. target["wires"] = wire_labels
  139. # target["boxes"] = line_boxes(target)
  140. target["boxes"] = line_boxes1(torch.tensor(wire["line_pos_coords"]["content"]))
  141. target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
  142. # print(f'target["labels"]:{ target["labels"]}')
  143. # print(f'boxes:{target["boxes"].shape}')
  144. if target["boxes"].numel() == 0:
  145. print("Tensor is empty")
  146. print(f'path:{lbl_path}')
  147. return target
  148. def show(self, idx):
  149. image, target = self.__getitem__(idx)
  150. cmap = plt.get_cmap("jet")
  151. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  152. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  153. sm.set_array([])
  154. def imshow(im):
  155. plt.close()
  156. plt.tight_layout()
  157. plt.imshow(im)
  158. plt.colorbar(sm, fraction=0.046)
  159. plt.xlim([0, im.shape[0]])
  160. plt.ylim([im.shape[0], 0])
  161. def draw_vecl(lines, sline, juncs, junts, fn=None):
  162. img_path = os.path.join(self.img_path, self.imgs[idx])
  163. imshow(io.imread(img_path))
  164. if len(lines) > 0 and not (lines[0] == 0).all():
  165. for i, ((a, b), s) in enumerate(zip(lines, sline)):
  166. if i > 0 and (lines[i] == lines[0]).all():
  167. break
  168. plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]?????
  169. if not (juncs[0] == 0).all():
  170. for i, j in enumerate(juncs):
  171. if i > 0 and (i == juncs[0]).all():
  172. break
  173. plt.scatter(j[1], j[0], c="red", s=2, zorder=100) # ? s=64
  174. img_path = os.path.join(self.img_path, self.imgs[idx])
  175. img = PIL.Image.open(img_path).convert('RGB')
  176. boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
  177. colors="yellow", width=1)
  178. plt.imshow(boxed_image.permute(1, 2, 0).numpy())
  179. plt.show()
  180. plt.show()
  181. if fn != None:
  182. plt.savefig(fn)
  183. junc = target['wires']['junc_coords'].cpu().numpy() * 4
  184. jtyp = target['wires']['jtyp'].cpu().numpy()
  185. juncs = junc[jtyp == 0]
  186. junts = junc[jtyp == 1]
  187. lpre = target['wires']["lpre"].cpu().numpy() * 4
  188. vecl_target = target['wires']["lpre_label"].cpu().numpy()
  189. lpre = lpre[vecl_target == 1]
  190. # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
  191. draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
  192. def show_img(self, img_path):
  193. pass
  194. # dataset_train = WirePointDataset("/data/lm/dataset/0424_", dataset_type='val')
  195. # for i in dataset_train:
  196. # a = 1