dataset_LD.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. # ??roi_head??????????????
  2. from torch import dtype
  3. from torch.utils.data.dataset import T_co
  4. from models.base.base_dataset import BaseDataset
  5. import glob
  6. import json
  7. import math
  8. import os
  9. import random
  10. import cv2
  11. import PIL
  12. import matplotlib.pyplot as plt
  13. import matplotlib as mpl
  14. from torchvision.utils import draw_bounding_boxes
  15. import numpy as np
  16. import numpy.linalg as LA
  17. import torch
  18. from skimage import io
  19. from torch.utils.data import Dataset
  20. from torch.utils.data.dataloader import default_collate
  21. import matplotlib.pyplot as plt
  22. from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
  23. from tools.presets import DetectionPresetTrain
  24. def line_boxes1(target):
  25. boxs = []
  26. lines = target.cpu().numpy() * 4
  27. if len(lines) > 0 and not (lines[0] == 0).all():
  28. for i, ((a, b)) in enumerate(lines):
  29. if i > 0 and (lines[i] == lines[0]).all():
  30. break
  31. if a[1] > b[1]:
  32. ymax = a[1] + 10
  33. ymin = b[1] - 10
  34. else:
  35. ymin = a[1] - 10
  36. ymax = b[1] + 10
  37. if a[0] > b[0]:
  38. xmax = a[0] + 10
  39. xmin = b[0] - 10
  40. else:
  41. xmin = a[0] - 10
  42. xmax = b[0] + 10
  43. boxs.append([ymin, xmin, ymax, xmax])
  44. # if boxs == []:
  45. # print(target)
  46. return torch.tensor(boxs)
  47. class WirePointDataset(BaseDataset):
  48. def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
  49. super().__init__(dataset_path)
  50. self.data_path = dataset_path
  51. print(f'data_path:{dataset_path}')
  52. self.transforms = transforms
  53. self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
  54. self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
  55. self.imgs = os.listdir(self.img_path)
  56. self.lbls = os.listdir(self.lbl_path)
  57. self.target_type = target_type
  58. # self.default_transform = DefaultTransform()
  59. def __getitem__(self, index) -> T_co:
  60. img_path = os.path.join(self.img_path, self.imgs[index])
  61. lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
  62. # img = PIL.Image.open(img_path).convert('RGB')
  63. # w, h = img.size
  64. img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
  65. print(img.shape)
  66. w, h = img.shape[0:2]
  67. # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  68. target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  69. # if self.transforms:
  70. # img, target = self.transforms(img, target)
  71. # else:
  72. # img = self.default_transform(img)
  73. # 分离RGB和深度通道
  74. rgb_channels = img[:, :, :3]
  75. depth_channel = img[:, :, 3]
  76. # rgb_normalized = rgb_channels.astype(np.float32) / 255.0
  77. rgb_normalized = rgb_channels
  78. depth_normalized = (depth_channel - depth_channel.min()) / (depth_channel.max() - depth_channel.min())*255
  79. # 将归一化后的RGB和深度通道重新组合
  80. normalized_rgba_image = np.dstack((rgb_normalized, depth_normalized)) # 或者使用depth_normalized_fixed_range
  81. print("Normalized RGBA image shape:", normalized_rgba_image.shape)
  82. img = torch.tensor(normalized_rgba_image,dtype=torch.float32).permute(2,1,0)
  83. # new_channel = torch.zeros(1, 512, 512)
  84. # img=torch.cat((img,new_channel),dim=0)
  85. print(f'img:{img.shape}')
  86. # print(f'img dtype:{img.dtype}')
  87. return img, target
  88. def __len__(self):
  89. return len(self.imgs)
  90. def read_target(self, item, lbl_path, shape, extra=None):
  91. # print(f'lbl_path:{lbl_path}')
  92. with open(lbl_path, 'r') as file:
  93. lable_all = json.load(file)
  94. n_stc_posl = 300
  95. n_stc_negl = 40
  96. use_cood = 0
  97. use_slop = 0
  98. wire = lable_all["wires"][0] # ??
  99. line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl] # ?????????
  100. line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
  101. npos, nneg = len(line_pos_coords), len(line_neg_coords)
  102. lpre = np.concatenate([line_pos_coords, line_neg_coords], 0) # ??????????
  103. for i in range(len(lpre)):
  104. if random.random() > 0.5:
  105. lpre[i] = lpre[i, ::-1]
  106. ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
  107. ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
  108. feat = [
  109. lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
  110. ldir * use_slop,
  111. lpre[:, :, 2],
  112. ]
  113. feat = np.concatenate(feat, 1)
  114. wire_labels = {
  115. "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
  116. "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
  117. "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
  118. # ???????????
  119. "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
  120. # ??????????
  121. "lpre": torch.tensor(lpre)[:, :, :2],
  122. "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]), # ?????? 1?0
  123. "lpre_feat": torch.from_numpy(feat),
  124. "junc_map": torch.tensor(wire['junc_map']["content"]),
  125. "junc_offset": torch.tensor(wire['junc_offset']["content"]),
  126. "line_map": torch.tensor(wire['line_map']["content"]),
  127. }
  128. labels = []
  129. #
  130. # if self.target_type == 'polygon':
  131. # labels, masks = read_masks_from_txt_wire(lbl_path, shape)
  132. # elif self.target_type == 'pixel':
  133. # labels = read_masks_from_pixels_wire(lbl_path, shape)
  134. # print(torch.stack(masks).shape) # [???, 512, 512]
  135. target = {}
  136. # target["labels"] = torch.stack(labels)
  137. target["image_id"] = torch.tensor(item)
  138. # return wire_labels, target
  139. target["wires"] = wire_labels
  140. # target["boxes"] = line_boxes(target)
  141. target["boxes"] = line_boxes1(torch.tensor(wire["line_pos_coords"]["content"]))
  142. target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
  143. # print(f'target["labels"]:{ target["labels"]}')
  144. # print(f'boxes:{target["boxes"].shape}')
  145. if target["boxes"].numel() == 0:
  146. print("Tensor is empty")
  147. print(f'path:{lbl_path}')
  148. return target
  149. def show(self, idx):
  150. image, target = self.__getitem__(idx)
  151. cmap = plt.get_cmap("jet")
  152. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  153. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  154. sm.set_array([])
  155. def imshow(im):
  156. plt.close()
  157. plt.tight_layout()
  158. plt.imshow(im)
  159. plt.colorbar(sm, fraction=0.046)
  160. plt.xlim([0, im.shape[0]])
  161. plt.ylim([im.shape[0], 0])
  162. def draw_vecl(lines, sline, juncs, junts, fn=None):
  163. img_path = os.path.join(self.img_path, self.imgs[idx])
  164. imshow(io.imread(img_path))
  165. if len(lines) > 0 and not (lines[0] == 0).all():
  166. for i, ((a, b), s) in enumerate(zip(lines, sline)):
  167. if i > 0 and (lines[i] == lines[0]).all():
  168. break
  169. plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]?????
  170. if not (juncs[0] == 0).all():
  171. for i, j in enumerate(juncs):
  172. if i > 0 and (i == juncs[0]).all():
  173. break
  174. plt.scatter(j[1], j[0], c="red", s=2, zorder=100) # ? s=64
  175. img_path = os.path.join(self.img_path, self.imgs[idx])
  176. img = PIL.Image.open(img_path).convert('RGB')
  177. boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
  178. colors="yellow", width=1)
  179. plt.imshow(boxed_image.permute(1, 2, 0).numpy())
  180. plt.show()
  181. plt.show()
  182. if fn != None:
  183. plt.savefig(fn)
  184. junc = target['wires']['junc_coords'].cpu().numpy() * 4
  185. jtyp = target['wires']['jtyp'].cpu().numpy()
  186. juncs = junc[jtyp == 0]
  187. junts = junc[jtyp == 1]
  188. lpre = target['wires']["lpre"].cpu().numpy() * 4
  189. vecl_target = target['wires']["lpre_label"].cpu().numpy()
  190. lpre = lpre[vecl_target == 1]
  191. # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
  192. draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
  193. def show_img(self, img_path):
  194. pass
  195. # dataset_train = WirePointDataset("/data/lm/dataset/0424_", dataset_type='val')
  196. # for i in dataset_train:
  197. # a = 1