dataset_tool.py

# Wire point dataset and box helpers used by roi_head.
from torch.utils.data.dataset import T_co
from models.base.base_dataset import BaseDataset

import glob
import json
import math
import os
import random

import cv2
import PIL
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import numpy.linalg as LA
import torch
from skimage import io
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate
from torchvision.utils import draw_bounding_boxes

from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
from tools.presets import DetectionPresetTrain


def line_boxes1(target):
    """Convert line-segment endpoints into padded axis-aligned boxes."""
    boxs = []
    lines = target.cpu().numpy() * 4

    if len(lines) > 0 and not (lines[0] == 0).all():
        for i, (a, b) in enumerate(lines):
            # the list may be padded; stop once an entry repeats the first line
            if i > 0 and (lines[i] == lines[0]).all():
                break

            # expand each endpoint pair into a box with a 10-pixel margin
            if a[1] > b[1]:
                ymax = a[1] + 10
                ymin = b[1] - 10
            else:
                ymin = a[1] - 10
                ymax = b[1] + 10
            if a[0] > b[0]:
                xmax = a[0] + 10
                xmin = b[0] - 10
            else:
                xmin = a[0] - 10
                xmax = b[0] + 10
            boxs.append([ymin, xmin, ymax, xmax])

    # if boxs == []:
    #     print(target)
    return torch.tensor(boxs)
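
# A quick, illustrative sanity check of line_boxes1 (kept commented out so that
# importing this module has no side effects; the endpoint values are made up):
# one segment with label-space endpoints (10, 20) and (30, 5) is scaled by 4 and
# padded by 10 px, giving a single box [10., 30., 90., 130.].
#
# _demo_lines = torch.tensor([[[10.0, 20.0], [30.0, 5.0]]])
# print(line_boxes1(_demo_lines))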


class WirePointDataset(BaseDataset):
    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
        super().__init__(dataset_path)

        self.data_path = dataset_path
        print(f'data_path:{dataset_path}')
        self.transforms = transforms
        self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
        self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
        self.imgs = os.listdir(self.img_path)
        self.lbls = os.listdir(self.lbl_path)
        self.target_type = target_type
        # self.default_transform = DefaultTransform()

    def __getitem__(self, index) -> T_co:
        img_path = os.path.join(self.img_path, self.imgs[index])
        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
        img = PIL.Image.open(img_path).convert('RGB')
        w, h = img.size

        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
        if self.transforms:
            img, target = self.transforms(img, target)
        else:
            img = self.default_transform(img)

        # print(f'img:{img.shape}')
        return img, target

    def __len__(self):
        return len(self.imgs)

    def read_target(self, item, lbl_path, shape, extra=None):
        """Load one wire annotation JSON and assemble the wire label dict."""
        # print(f'lbl_path:{lbl_path}')
        with open(lbl_path, 'r') as file:
            label_all = json.load(file)

        n_stc_posl = 300
        n_stc_negl = 40
        use_cood = 0
        use_slop = 0

        wire = label_all["wires"][0]  # wire annotation block
        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # sampled positive lines
        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]  # sampled negative lines
        npos, nneg = len(line_pos_coords), len(line_neg_coords)
        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # positive and negative lines together
        for i in range(len(lpre)):
            if random.random() > 0.5:
                lpre[i] = lpre[i, ::-1]
        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
        feat = [
            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
            ldir * use_slop,
            lpre[:, :, 2],
        ]
        feat = np.concatenate(feat, 1)

        wire_labels = {
            "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
            # adjacency matrix of positive line junctions
            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
            # adjacency matrix of negative line junctions
            "lpre": torch.tensor(lpre)[:, :, :2],
            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # line labels: positive 1, negative 0
            "lpre_feat": torch.from_numpy(feat),
            "junc_map": torch.tensor(wire['junc_map']["content"]),
            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
            "line_map": torch.tensor(wire['line_map']["content"]),
        }

        labels = []
        # if self.target_type == 'polygon':
        #     labels, masks = read_masks_from_txt_wire(lbl_path, shape)
        # elif self.target_type == 'pixel':
        #     labels = read_masks_from_pixels_wire(lbl_path, shape)
        # print(torch.stack(masks).shape)  # [num_masks, 512, 512]

        target = {}
        # target["labels"] = torch.stack(labels)
        target["image_id"] = torch.tensor(item)

        # return wire_labels, target
        target["wires"] = wire_labels
        # target["boxes"] = line_boxes(target)
        target["boxes"] = line_boxes1(torch.tensor(wire["line_pos_coords"]["content"]))
        target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
        # print(f'target["labels"]:{target["labels"]}')
        # print(f'boxes:{target["boxes"].shape}')
        if target["boxes"].numel() == 0:
            print("Tensor is empty")
            print(f'path:{lbl_path}')
        return target

    def show(self, idx):
        image, target = self.__getitem__(idx)

        cmap = plt.get_cmap("jet")
        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])

        def imshow(im):
            plt.close()
            plt.tight_layout()
            plt.imshow(im)
            plt.colorbar(sm, fraction=0.046)
            plt.xlim([0, im.shape[0]])
            plt.ylim([im.shape[0], 0])

        def draw_vecl(lines, sline, juncs, junts, fn=None):
            img_path = os.path.join(self.img_path, self.imgs[idx])
            imshow(io.imread(img_path))
            if len(lines) > 0 and not (lines[0] == 0).all():
                for i, ((a, b), s) in enumerate(zip(lines, sline)):
                    if i > 0 and (lines[i] == lines[0]).all():
                        break
                    plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1] are the x coordinates
            if not (juncs[0] == 0).all():
                for i, j in enumerate(juncs):
                    if i > 0 and (i == juncs[0]).all():
                        break
                    plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # originally s=64

            img_path = os.path.join(self.img_path, self.imgs[idx])
            img = PIL.Image.open(img_path).convert('RGB')
            boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
                                              colors="yellow", width=1)
            plt.imshow(boxed_image.permute(1, 2, 0).numpy())
            plt.show()
            if fn is not None:
                plt.savefig(fn)

        junc = target['wires']['junc_coords'].cpu().numpy() * 4
        jtyp = target['wires']['jtyp'].cpu().numpy()
        juncs = junc[jtyp == 0]
        junts = junc[jtyp == 1]

        lpre = target['wires']["lpre"].cpu().numpy() * 4
        vecl_target = target['wires']["lpre_label"].cpu().numpy()
        lpre = lpre[vecl_target == 1]

        # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
        draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)

    def show_img(self, img_path):
        pass


# dataset_train = WirePointDataset("/data/lm/dataset/0424_", dataset_type='val')
# for i in dataset_train:
#     a = 1
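

# Minimal usage sketch, assuming a dataset root laid out as images/<split> and
# labels/<split> as the class above expects; "/path/to/dataset_root" is a
# placeholder, not a real path from this repository.
if __name__ == "__main__":
    dataset_val = WirePointDataset("/path/to/dataset_root", dataset_type='val')
    print(f'number of samples: {len(dataset_val)}')
    img, target = dataset_val[0]  # transformed image and the target dict built by read_target
    print(f'boxes: {target["boxes"].shape}, labels: {target["labels"].shape}')
    dataset_val.show(0)  # overlay lines, junctions, and boxes on the first image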