# datasets.py

# Previous npz-based version of this dataset, left commented out:
#
# import glob
# import json
# import math
# import os
# import random
#
# import numpy as np
# import numpy.linalg as LA
# import torch
# from skimage import io
# from torch.utils.data import Dataset
# from torch.utils.data.dataloader import default_collate
#
# from lcnn.config import M
#
# from .dataset_tool import line_boxes_faster, read_masks_from_txt_wire, read_masks_from_pixels_wire
#
#
# class WireframeDataset(Dataset):
#     def __init__(self, rootdir, split):
#         self.rootdir = rootdir
#         filelist = glob.glob(f"{rootdir}/{split}/*_label.npz")
#         filelist.sort()
#
#         # print(f"n{split}:", len(filelist))
#         self.split = split
#         self.filelist = filelist
#
#     def __len__(self):
#         return len(self.filelist)
#
#     def __getitem__(self, idx):
#         iname = self.filelist[idx][:-10].replace("_a0", "").replace("_a1", "") + ".png"
#         image = io.imread(iname).astype(float)[:, :, :3]
#         if "a1" in self.filelist[idx]:
#             image = image[:, ::-1, :]
#         image = (image - M.image.mean) / M.image.stddev
#         image = np.rollaxis(image, 2).copy()
#
#         with np.load(self.filelist[idx]) as npz:
#             target = {
#                 name: torch.from_numpy(npz[name]).float()
#                 for name in ["jmap", "joff", "lmap"]
#             }
#             lpos = np.random.permutation(npz["lpos"])[: M.n_stc_posl]
#             lneg = np.random.permutation(npz["lneg"])[: M.n_stc_negl]
#             npos, nneg = len(lpos), len(lneg)
#             lpre = np.concatenate([lpos, lneg], 0)
#             for i in range(len(lpre)):
#                 if random.random() > 0.5:
#                     lpre[i] = lpre[i, ::-1]
#             ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
#             ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
#             feat = [
#                 lpre[:, :, :2].reshape(-1, 4) / 128 * M.use_cood,
#                 ldir * M.use_slop,
#                 lpre[:, :, 2],
#             ]
#             feat = np.concatenate(feat, 1)
#             meta = {
#                 "junc": torch.from_numpy(npz["junc"][:, :2]),
#                 "jtyp": torch.from_numpy(npz["junc"][:, 2]).byte(),
#                 "Lpos": self.adjacency_matrix(len(npz["junc"]), npz["Lpos"]),
#                 "Lneg": self.adjacency_matrix(len(npz["junc"]), npz["Lneg"]),
#                 "lpre": torch.from_numpy(lpre[:, :, :2]),
#                 "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),
#                 "lpre_feat": torch.from_numpy(feat),
#             }
#
#         labels = []
#         labels = read_masks_from_pixels_wire(iname, (512, 512))
#         # if self.target_type == 'polygon':
#         #     labels, masks = read_masks_from_txt_wire(iname, (512, 512))
#         # elif self.target_type == 'pixel':
#         #     labels = read_masks_from_pixels_wire(iname, (512, 512))
#
#         target["labels"] = torch.stack(labels)
#         target["boxes"] = line_boxes_faster(meta)
#
#         return torch.from_numpy(image).float(), meta, target
#
#     def adjacency_matrix(self, n, link):
#         mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8)
#         link = torch.from_numpy(link)
#         if len(link) > 0:
#             mat[link[:, 0], link[:, 1]] = 1
#             mat[link[:, 1], link[:, 0]] = 1
#         return mat
#
#
# def collate(batch):
#     return (
#         default_collate([b[0] for b in batch]),
#         [b[1] for b in batch],
#         default_collate([b[2] for b in batch]),
#     )

import glob
import json
import math
import os
import random

import cv2
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import numpy.linalg as LA
import PIL
import torch
from skimage import io
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate
from torch.utils.data.dataset import T_co
from torchvision.utils import draw_bounding_boxes

from .models.base.base_dataset import BaseDataset
from .dataset_tool import line_boxes_faster, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
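
# The adjacency_matrix helper imported above is assumed to behave like the legacy
# method in the commented block at the top of this file: it builds a symmetric
# (n + 1) x (n + 1) uint8 matrix with ones at every junction index pair listed in
# `link`. A minimal sketch under that assumption (not the dataset_tool source):
#
# def adjacency_matrix(n, link):
#     mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8)
#     link = torch.from_numpy(np.asarray(link))
#     if len(link) > 0:
#         mat[link[:, 0], link[:, 1]] = 1
#         mat[link[:, 1], link[:, 0]] = 1
#     return mat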


class WireframeDataset(BaseDataset):
    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
        super().__init__(dataset_path)
        self.data_path = dataset_path
        print(f'data_path:{dataset_path}')
        self.transforms = transforms
        self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
        self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
        self.imgs = os.listdir(self.img_path)
        self.lbls = os.listdir(self.lbl_path)
        self.target_type = target_type
        # self.default_transform = DefaultTransform()

    def __getitem__(self, index) -> T_co:
        img_path = os.path.join(self.img_path, self.imgs[index])
        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
        img = PIL.Image.open(img_path).convert('RGB')
        w, h = img.size
        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
        meta, target, target_b = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
        img = self.default_transform(img)
        # print(f'img:{img}')
        return img, meta, target, target_b

    def __len__(self):
        return len(self.imgs)

    def read_target(self, item, lbl_path, shape, extra=None):
        # print(f'shape:{shape}')
        # print(f'lbl_path:{lbl_path}')
        with open(lbl_path, 'r') as file:
            label_all = json.load(file)

        n_stc_posl = 300
        n_stc_negl = 40
        use_cood = 0
        use_slop = 0

        wire = label_all["wires"][0]  # dict holding the wireframe annotations
        # Sample up to n_stc_posl positive and n_stc_negl negative lines;
        # if fewer are available, take as many as there are.
        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]
        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
        npos, nneg = len(line_pos_coords), len(line_neg_coords)
        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # positive and negative samples stacked together
        for i in range(len(lpre)):
            if random.random() > 0.5:
                lpre[i] = lpre[i, ::-1]  # randomly swap the endpoint order
        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
        # Per-line feature: 4 normalized endpoint coordinates + 2 direction
        # components + 2 endpoint scores (8 columns in total).
        feat = [
            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
            ldir * use_slop,
            lpre[:, :, 2],
        ]
        feat = np.concatenate(feat, 1)
        meta = {
            "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
            # adjacency matrices over junctions: real (positive) and sampled negative lines
            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
            "lpre": torch.tensor(lpre)[:, :, :2],
            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # sample labels: 1 = positive, 0 = negative
            "lpre_feat": torch.from_numpy(feat),
        }

        target = {
            "junc_map": torch.tensor(wire['junc_map']["content"]),
            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
            "line_map": torch.tensor(wire['line_map']["content"]),
        }

        labels = []
        if self.target_type == 'polygon':
            labels, masks = read_masks_from_txt_wire(lbl_path, shape)
        elif self.target_type == 'pixel':
            labels = read_masks_from_pixels_wire(lbl_path, shape)
        # print(torch.stack(masks).shape)  # [num_lines, 512, 512]

        target_b = {}
        # target_b["image_id"] = torch.tensor(item)
        target_b["labels"] = torch.stack(labels)
        target_b["boxes"] = line_boxes_faster(meta)

        return meta, target, target_b

    def show(self, idx):
        img, meta, target, target_b = self.__getitem__(idx)
        cmap = plt.get_cmap("jet")
        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])

        def imshow(im):
            plt.close()
            plt.tight_layout()
            plt.imshow(im)
            plt.colorbar(sm, fraction=0.046)
            plt.xlim([0, im.shape[0]])
            plt.ylim([im.shape[0], 0])

        def draw_vecl(lines, sline, juncs, junts, fn=None):
            img_path = os.path.join(self.img_path, self.imgs[idx])
            imshow(io.imread(img_path))
            if len(lines) > 0 and not (lines[0] == 0).all():
                for i, ((a, b), s) in enumerate(zip(lines, sline)):
                    if i > 0 and (lines[i] == lines[0]).all():
                        break
                    # endpoints are (y, x); a[1] and b[1] carry no guaranteed ordering
                    plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)
            if not (juncs[0] == 0).all():
                for i, j in enumerate(juncs):
                    if i > 0 and (i == juncs[0]).all():
                        break
                    plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # originally s=64
            img = PIL.Image.open(img_path).convert('RGB')
            boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8),
                                              target_b["boxes"], colors="yellow", width=1)
            plt.imshow(boxed_image.permute(1, 2, 0).numpy())
            plt.show()
            if fn is not None:
                plt.savefig(fn)

        junc = meta['junc_coords'].cpu().numpy() * 4
        jtyp = meta['jtyp'].cpu().numpy()
        juncs = junc[jtyp == 0]
        junts = junc[jtyp == 1]

        lpre = meta["lpre"].cpu().numpy() * 4
        vecl_target = meta["lpre_label"].cpu().numpy()
        lpre = lpre[vecl_target == 1]
        # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
        draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)

    def show_img(self, img_path):
        pass


def collate(batch):
    return (
        default_collate([b[0] for b in batch]),
        [b[1] for b in batch],
        default_collate([b[2] for b in batch]),
        [b[3] for b in batch],
    )


# if __name__ == '__main__':
#     path = r"D:\python\PycharmProjects\data"
#     dataset = WireframeDataset(dataset_path=path, dataset_type='train')
#     dataset.show(0)
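

# A minimal DataLoader sketch (not from the original project): the path below is a
# placeholder, and it assumes BaseDataset.default_transform returns a CHW float
# tensor. The custom `collate` stacks the images and the dense targets with
# default_collate, while `meta` and `target_b` stay as plain per-sample lists
# because their tensors differ in length from image to image.
#
# from torch.utils.data import DataLoader
#
# if __name__ == '__main__':
#     dataset = WireframeDataset(dataset_path='/path/to/data', dataset_type='train')
#     loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate)
#     images, metas, targets, targets_b = next(iter(loader))
#     print(images.shape)   # e.g. (2, 3, H, W), stacked by default_collate
#     print(len(metas))     # 2 per-sample meta dicts, left unbatched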