dataset_LD.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. # ??roi_head??????????????
  2. import imageio
  3. from torch.utils.data.dataset import T_co
  4. from models.base.base_dataset import BaseDataset
  5. import glob
  6. import json
  7. import math
  8. import os
  9. import random
  10. import cv2
  11. import PIL
  12. import matplotlib.pyplot as plt
  13. import matplotlib as mpl
  14. from torchvision.utils import draw_bounding_boxes
  15. import numpy as np
  16. import numpy.linalg as LA
  17. import torch
  18. from skimage import io
  19. from torch.utils.data import Dataset
  20. from torch.utils.data.dataloader import default_collate
  21. import matplotlib.pyplot as plt
  22. from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
  23. from tools.presets import DetectionPresetTrain
  def line_boxes1(target):
      """Build padded, axis-aligned boxes around line segments.

      Args:
          target: Tensor of line segments indexed as ``[line, endpoint, coord]``
              with two endpoints per line. Only coordinates 0 and 1 of each
              endpoint are used for the box; any further columns only take
              part in the padding-sentinel comparison below.

      Returns:
          Tensor of shape ``(N, 4)``. Each row is
          ``[min(c1) - 10, min(c0) - 10, max(c1) + 10, max(c0) + 10]`` where
          ``c0`` / ``c1`` are coordinates 0 / 1 of the two endpoints after the
          input has been multiplied by 4. When there is nothing to box the
          result is an empty ``(0, 4)`` float32 tensor (previously a 1-D
          ``(0,)`` tensor, which is not a valid ``[N, 4]`` box tensor).

      NOTE(review): ``WirePointDataset.show`` plots coordinate 1 on the x axis
      and coordinate 0 on the y axis, so the rows appear to be
      ``(xmin, ymin, xmax, ymax)`` in image space -- confirm against the
      label format.
      """
      # The factor 4 presumably maps label coordinates to image pixels
      # (TODO confirm); it is kept exactly as before.
      lines = target.cpu().numpy() * 4

      no_boxes = torch.zeros((0, 4), dtype=torch.float32)
      # An all-zero first line is treated as "no lines at all".
      if len(lines) == 0 or (lines[0] == 0).all():
          return no_boxes

      # A later line identical to the first one marks the start of padding:
      # keep only the lines before the first such repeat.
      repeats = (lines[1:] == lines[0]).all(axis=tuple(range(1, lines.ndim)))
      if repeats.any():
          lines = lines[: 1 + int(np.argmax(repeats))]

      first_pt = lines[:, 0, :2]
      second_pt = lines[:, 1, :2]
      # Reverse the two coordinate columns so that coordinate 1 comes first,
      # matching the original output column order; pad by 10 on every side.
      low = np.minimum(first_pt, second_pt)[:, ::-1] - 10
      high = np.maximum(first_pt, second_pt)[:, ::-1] + 10
      return torch.from_numpy(np.concatenate([low, high], axis=1))
  class WirePointDataset(BaseDataset):
      """Dataset of 512x512 single-channel images with wire-frame labels.

      Images are listed from ``<dataset_path>/images/<dataset_type>``; the
      annotation of an image is the JSON file with the same stem under
      ``<dataset_path>/labels/<dataset_type>``.
      """

      def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
          """Index the image and label folders of one dataset split.

          Args:
              dataset_path: Root folder containing ``images/`` and ``labels/``.
              transforms: Optional callable ``(img, target) -> (img, target)``
                  applied in ``__getitem__``.
              dataset_type: Name of the split sub-folder. Must be a string;
                  the default ``None`` fails when the paths are built.
              target_type: Stored on the instance; not read anywhere in this
                  class.
          """
          super().__init__(dataset_path)
          self.data_path = dataset_path
          print(f'data_path:{dataset_path}')
          self.transforms = transforms
          self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
          self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
          # os.listdir returns entries in arbitrary order; sorting makes the
          # index -> file mapping (and therefore image_id) reproducible.
          self.imgs = sorted(os.listdir(self.img_path))
          self.lbls = sorted(os.listdir(self.lbl_path))
          self.target_type = target_type
          # NOTE(review): ``self.default_transform`` is not assigned here, yet
          # ``__getitem__`` and ``show`` call it -- presumably it is provided by
          # BaseDataset; confirm.

      def __getitem__(self, index) -> T_co:
          """Load one sample.

          Returns:
              ``(img, target)`` where ``img`` starts as a ``(3, 512, 512)``
              tensor and ``target`` is the dict built by ``read_target``; both
              then go through ``self.transforms`` when it is set, otherwise
              only ``img`` goes through ``self.default_transform``.
          """
          img_name = self.imgs[index]
          img_path = os.path.join(self.img_path, img_name)
          # Swap the real extension for ".json". The previous ``[:-4] + 'json'``
          # only produced a valid name for four-letter extensions (".tiff").
          lbl_path = os.path.join(self.lbl_path, os.path.splitext(img_name)[0] + '.json')

          img = imageio.v3.imread(img_path).reshape(512, 512, 1)
          # Place the single input channel in channel index 2 of a 3-channel
          # image; the other two channels stay zero.
          img_3channel = np.zeros((512, 512, 3), dtype=img.dtype)
          img_3channel[:, :, 2] = img[:, :, 0]
          h, w = img.shape[:2]  # array layout is (rows, cols, channels)
          img = torch.from_numpy(img_3channel).permute(2, 0, 1)

          target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
          if self.transforms:
              img, target = self.transforms(img, target)
          else:
              img = self.default_transform(img)
          return img, target

      def __len__(self):
          """Return the number of files found in the image folder."""
          return len(self.imgs)

      def read_target(self, item, lbl_path, shape, extra=None):
          """Read one label JSON file and build the training target.

          Args:
              item: Sample index, stored as ``target["image_id"]``.
              lbl_path: Path of the JSON label file.
              shape: Unused; kept for interface compatibility.
              extra: Unused; kept for interface compatibility.

          Returns:
              Dict with keys ``image_id``, ``wires`` (junction / line tensors
              taken from the first entry of ``"wires"`` in the file),
              ``boxes`` (one box per positive line, from ``line_boxes1``) and
              ``labels`` (all ones, int64, one per box).
          """
          with open(lbl_path, 'r', encoding='utf-8') as file:
              label_all = json.load(file)

          n_stc_posl = 300  # keep at most this many positive lines
          n_stc_negl = 40   # keep at most this many negative lines
          use_cood = 0      # multiplier for the coordinate features (0 zeroes them)
          use_slop = 0      # multiplier for the direction features (0 zeroes them)

          wire = label_all["wires"][0]  # only the first "wires" entry is used

          # Random subset of the positive / negative lines, in random order.
          line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]
          line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
          npos, nneg = len(line_pos_coords), len(line_neg_coords)
          lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)

          # Swap the two endpoints of roughly half of the lines at random.
          for i in range(len(lpre)):
              if random.random() > 0.5:
                  lpre[i] = lpre[i, ::-1]

          # Unit direction of every line; the clip avoids dividing by zero for
          # degenerate (zero-length) lines.
          ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
          ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)

          feat = np.concatenate(
              [
                  lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
                  ldir * use_slop,
                  lpre[:, :, 2],
              ],
              1,
          )

          junc_content = wire["junc_coords"]["content"]
          n_junc = len(junc_content)
          junc = torch.tensor(junc_content)

          wire_labels = {
              "junc_coords": junc[:, :2],
              "jtyp": junc[:, 2].byte(),
              "line_pos_idx": adjacency_matrix(n_junc, wire["line_pos_idx"]["content"]),
              "line_neg_idx": adjacency_matrix(n_junc, wire["line_neg_idx"]["content"]),
              "lpre": torch.tensor(lpre)[:, :, :2],
              # 1 for the positive lines, 0 for the negative ones, in lpre order.
              "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),
              "lpre_feat": torch.from_numpy(feat),
              "junc_map": torch.tensor(wire['junc_map']["content"]),
              "junc_offset": torch.tensor(wire['junc_offset']["content"]),
              "line_map": torch.tensor(wire['line_map']["content"]),
          }

          target = {}
          target["image_id"] = torch.tensor(item)
          target["wires"] = wire_labels
          # Boxes are built from ALL positive lines of the file, not from the
          # sampled subset above.
          target["boxes"] = line_boxes1(torch.tensor(wire["line_pos_coords"]["content"]))
          target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)

          if target["boxes"].numel() == 0:
              print("Tensor is empty")
              print(f'path:{lbl_path}')
          return target

      def show(self, idx):
          """Plot sample ``idx``: positive lines, type-0 junctions and boxes."""
          _, target = self[idx]

          cmap = plt.get_cmap("jet")
          norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
          sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
          sm.set_array([])

          def imshow(im):
              """Start a fresh figure showing ``im`` with a colour bar."""
              plt.close()
              plt.tight_layout()
              plt.imshow(im)
              plt.colorbar(sm, fraction=0.046)
              # x spans the image width, y the height (origin at the top).
              plt.xlim([0, im.shape[1]])
              plt.ylim([im.shape[0], 0])

          def draw_vecl(lines, sline, juncs, junts, fn=None):
              """Draw lines, junctions and boxes; save to ``fn`` when given."""
              img_path = os.path.join(self.img_path, self.imgs[idx])
              imshow(io.imread(img_path))

              if len(lines) > 0 and not (lines[0] == 0).all():
                  for i, ((a, b), _score) in enumerate(zip(lines, sline)):
                      # A repeat of the first line marks the start of padding.
                      if i > 0 and (lines[i] == lines[0]).all():
                          break
                      # Coordinate 1 goes on the x axis, coordinate 0 on y.
                      plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)

              # Guard against an empty junction array before touching juncs[0].
              if len(juncs) > 0 and not (juncs[0] == 0).all():
                  for i, j in enumerate(juncs):
                      # Compare the junction itself (not the loop index) with
                      # the first junction to detect padding.
                      if i > 0 and (j == juncs[0]).all():
                          break
                      plt.scatter(j[1], j[0], c="red", s=2, zorder=100)

              img = PIL.Image.open(img_path).convert('RGB')
              boxed_image = draw_bounding_boxes(
                  (self.default_transform(img) * 255).to(torch.uint8),
                  target["boxes"],
                  colors="yellow",
                  width=1,
              )
              plt.imshow(boxed_image.permute(1, 2, 0).numpy())

              # Save before showing: once plt.show() returns the figure may
              # already be closed, which would write an empty image.
              if fn is not None:
                  plt.savefig(fn)
              plt.show()

          junc = target['wires']['junc_coords'].cpu().numpy() * 4
          jtyp = target['wires']['jtyp'].cpu().numpy()
          juncs = junc[jtyp == 0]
          junts = junc[jtyp == 1]

          lpre = target['wires']["lpre"].cpu().numpy() * 4
          vecl_target = target['wires']["lpre_label"].cpu().numpy()
          lpre = lpre[vecl_target == 1]  # keep only the positive lines

          draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)

      def show_img(self, img_path):
          """Placeholder; currently does nothing."""
          pass
  187. # dataset_train = WirePointDataset("/data/lm/dataset/0424_", dataset_type='val')
  188. # for i in dataset_train:
  189. # a = 1