datasets.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496
  1. # import glob
  2. # import json
  3. # import math
  4. # import os
  5. # import random
  6. #
  7. # import numpy as np
  8. # import numpy.linalg as LA
  9. # import torch
  10. # from skimage import io
  11. # from torch.utils.data import Dataset
  12. # from torch.utils.data.dataloader import default_collate
  13. #
  14. # from lcnn.config import M
  15. #
  16. # from .dataset_tool import line_boxes_faster, read_masks_from_txt_wire, read_masks_from_pixels_wire
  17. #
  18. #
  19. # class WireframeDataset(Dataset):
  20. # def __init__(self, rootdir, split):
  21. # self.rootdir = rootdir
  22. # filelist = glob.glob(f"{rootdir}/{split}/*_label.npz")
  23. # filelist.sort()
  24. #
  25. # # print(f"n{split}:", len(filelist))
  26. # self.split = split
  27. # self.filelist = filelist
  28. #
  29. # def __len__(self):
  30. # return len(self.filelist)
  31. #
  32. # def __getitem__(self, idx):
  33. # iname = self.filelist[idx][:-10].replace("_a0", "").replace("_a1", "") + ".png"
  34. # image = io.imread(iname).astype(float)[:, :, :3]
  35. # if "a1" in self.filelist[idx]:
  36. # image = image[:, ::-1, :]
  37. # image = (image - M.image.mean) / M.image.stddev
  38. # image = np.rollaxis(image, 2).copy()
  39. #
  40. # with np.load(self.filelist[idx]) as npz:
  41. # target = {
  42. # name: torch.from_numpy(npz[name]).float()
  43. # for name in ["jmap", "joff", "lmap"]
  44. # }
  45. # lpos = np.random.permutation(npz["lpos"])[: M.n_stc_posl]
  46. # lneg = np.random.permutation(npz["lneg"])[: M.n_stc_negl]
  47. # npos, nneg = len(lpos), len(lneg)
  48. # lpre = np.concatenate([lpos, lneg], 0)
  49. # for i in range(len(lpre)):
  50. # if random.random() > 0.5:
  51. # lpre[i] = lpre[i, ::-1]
  52. # ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
  53. # ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
  54. # feat = [
  55. # lpre[:, :, :2].reshape(-1, 4) / 128 * M.use_cood,
  56. # ldir * M.use_slop,
  57. # lpre[:, :, 2],
  58. # ]
  59. # feat = np.concatenate(feat, 1)
  60. # meta = {
  61. # "junc": torch.from_numpy(npz["junc"][:, :2]),
  62. # "jtyp": torch.from_numpy(npz["junc"][:, 2]).byte(),
  63. # "Lpos": self.adjacency_matrix(len(npz["junc"]), npz["Lpos"]),
  64. # "Lneg": self.adjacency_matrix(len(npz["junc"]), npz["Lneg"]),
  65. # "lpre": torch.from_numpy(lpre[:, :, :2]),
  66. # "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),
  67. # "lpre_feat": torch.from_numpy(feat),
  68. # }
  69. #
  70. # labels = []
  71. # labels = read_masks_from_pixels_wire(iname, (512, 512))
  72. # # if self.target_type == 'polygon':
  73. # # labels, masks = read_masks_from_txt_wire(iname, (512, 512))
  74. # # elif self.target_type == 'pixel':
  75. # # labels = read_masks_from_pixels_wire(iname, (512, 512))
  76. #
  77. # target["labels"] = torch.stack(labels)
  78. # target["boxes"] = line_boxes_faster(meta)
  79. #
  80. #
  81. # return torch.from_numpy(image).float(), meta, target
  82. #
  83. # def adjacency_matrix(self, n, link):
  84. # mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8)
  85. # link = torch.from_numpy(link)
  86. # if len(link) > 0:
  87. # mat[link[:, 0], link[:, 1]] = 1
  88. # mat[link[:, 1], link[:, 0]] = 1
  89. # return mat
  90. #
  91. #
  92. # def collate(batch):
  93. # return (
  94. # default_collate([b[0] for b in batch]),
  95. # [b[1] for b in batch],
  96. # default_collate([b[2] for b in batch]),
  97. # )
  98. # Original LCNN data format, with attribute names renamed and box-related fields added
  99. from torch.utils.data.dataset import T_co
  100. from .models.base.base_dataset import BaseDataset
  101. import glob
  102. import json
  103. import math
  104. import os
  105. import random
  106. import cv2
  107. import PIL
  108. import matplotlib.pyplot as plt
  109. import matplotlib as mpl
  110. from torchvision.utils import draw_bounding_boxes
  111. import numpy as np
  112. import numpy.linalg as LA
  113. import torch
  114. from skimage import io
  115. from torch.utils.data import Dataset
  116. from torch.utils.data.dataloader import default_collate
  117. import matplotlib.pyplot as plt
  118. from .dataset_tool import line_boxes_faster, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
  119. class WireframeDataset(BaseDataset):
  120. def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
  121. super().__init__(dataset_path)
  122. self.data_path = dataset_path
  123. print(f'data_path:{dataset_path}')
  124. self.transforms = transforms
  125. self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
  126. self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
  127. self.imgs = os.listdir(self.img_path)
  128. self.lbls = os.listdir(self.lbl_path)
  129. self.target_type = target_type
  130. # self.default_transform = DefaultTransform()
  131. def __getitem__(self, index) -> T_co:
  132. img_path = os.path.join(self.img_path, self.imgs[index])
  133. lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
  134. img = PIL.Image.open(img_path).convert('RGB')
  135. w, h = img.size
  136. # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  137. meta, target, target_b = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  138. img = self.default_transform(img)
  139. # print(f'img:{img}')
  140. return img, meta, target, target_b
  141. def __len__(self):
  142. return len(self.imgs)
  143. def read_target(self, item, lbl_path, shape, extra=None):
  144. # print(f'shape:{shape}')
  145. # print(f'lbl_path:{lbl_path}')
  146. with open(lbl_path, 'r') as file:
  147. lable_all = json.load(file)
  148. n_stc_posl = 300
  149. n_stc_negl = 40
  150. use_cood = 0
  151. use_slop = 0
  152. wire = lable_all["wires"][0] # 字典
  153. line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl] # 不足,有多少取多少
  154. line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
  155. npos, nneg = len(line_pos_coords), len(line_neg_coords)
  156. lpre = np.concatenate([line_pos_coords, line_neg_coords], 0) # 正负样本坐标合在一起
  157. for i in range(len(lpre)):
  158. if random.random() > 0.5:
  159. lpre[i] = lpre[i, ::-1]
  160. ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
  161. ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
  162. feat = [
  163. lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
  164. ldir * use_slop,
  165. lpre[:, :, 2],
  166. ]
  167. feat = np.concatenate(feat, 1)
  168. meta = {
  169. "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
  170. "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
  171. "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
  172. # 真实存在线条的邻接矩阵
  173. "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
  174. "lpre": torch.tensor(lpre)[:, :, :2],
  175. "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]), # 样本对应标签 1,0
  176. "lpre_feat": torch.from_numpy(feat),
  177. }
  178. target = {
  179. "junc_map": torch.tensor(wire['junc_map']["content"]),
  180. "junc_offset": torch.tensor(wire['junc_offset']["content"]),
  181. "line_map": torch.tensor(wire['line_map']["content"]),
  182. }
  183. labels = []
  184. if self.target_type == 'polygon':
  185. labels, masks = read_masks_from_txt_wire(lbl_path, shape)
  186. elif self.target_type == 'pixel':
  187. labels = read_masks_from_pixels_wire(lbl_path, shape)
  188. # print(torch.stack(masks).shape) # [线段数, 512, 512]
  189. target_b = {}
  190. # target_b["image_id"] = torch.tensor(item)
  191. target_b["labels"] = torch.stack(labels)
  192. target_b["boxes"] = line_boxes_faster(meta)
  193. return meta, target, target_b
  194. def show(self, idx):
  195. image, target = self.__getitem__(idx)
  196. cmap = plt.get_cmap("jet")
  197. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  198. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  199. sm.set_array([])
  200. def imshow(im):
  201. plt.close()
  202. plt.tight_layout()
  203. plt.imshow(im)
  204. plt.colorbar(sm, fraction=0.046)
  205. plt.xlim([0, im.shape[0]])
  206. plt.ylim([im.shape[0], 0])
  207. def draw_vecl(lines, sline, juncs, junts, fn=None):
  208. img_path = os.path.join(self.img_path, self.imgs[idx])
  209. imshow(io.imread(img_path))
  210. if len(lines) > 0 and not (lines[0] == 0).all():
  211. for i, ((a, b), s) in enumerate(zip(lines, sline)):
  212. if i > 0 and (lines[i] == lines[0]).all():
  213. break
  214. plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]无明确大小
  215. if not (juncs[0] == 0).all():
  216. for i, j in enumerate(juncs):
  217. if i > 0 and (i == juncs[0]).all():
  218. break
  219. plt.scatter(j[1], j[0], c="red", s=2, zorder=100) # 原 s=64
  220. img_path = os.path.join(self.img_path, self.imgs[idx])
  221. img = PIL.Image.open(img_path).convert('RGB')
  222. boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
  223. colors="yellow", width=1)
  224. plt.imshow(boxed_image.permute(1, 2, 0).numpy())
  225. plt.show()
  226. plt.show()
  227. if fn != None:
  228. plt.savefig(fn)
  229. junc = target['wires']['junc_coords'].cpu().numpy() * 4
  230. jtyp = target['wires']['jtyp'].cpu().numpy()
  231. juncs = junc[jtyp == 0]
  232. junts = junc[jtyp == 1]
  233. lpre = target['wires']["lpre"].cpu().numpy() * 4
  234. vecl_target = target['wires']["lpre_label"].cpu().numpy()
  235. lpre = lpre[vecl_target == 1]
  236. # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
  237. draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
  238. def show_img(self, img_path):
  239. pass
  240. def collate(batch):
  241. return (
  242. default_collate([b[0] for b in batch]),
  243. [b[1] for b in batch],
  244. default_collate([b[2] for b in batch]),
  245. [b[3] for b in batch],
  246. )
  247. # if __name__ == '__main__':
  248. # path = r"D:\python\PycharmProjects\data"
  249. # dataset = WireframeDataset(dataset_path=path, dataset_type='train')
  250. # dataset.show(0)
  251. '''
  252. # 使用roi_head数据格式有要求,更改数据格式
  253. from torch.utils.data.dataset import T_co
  254. from models.base.base_dataset import BaseDataset
  255. import glob
  256. import json
  257. import math
  258. import os
  259. import random
  260. import cv2
  261. import PIL
  262. import matplotlib.pyplot as plt
  263. import matplotlib as mpl
  264. from torchvision.utils import draw_bounding_boxes
  265. import numpy as np
  266. import numpy.linalg as LA
  267. import torch
  268. from skimage import io
  269. from torch.utils.data import Dataset
  270. from torch.utils.data.dataloader import default_collate
  271. import matplotlib.pyplot as plt
  272. from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
  273. from tools.presets import DetectionPresetTrain
  274. class WireframeDataset(BaseDataset):
  275. def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
  276. super().__init__(dataset_path)
  277. self.data_path = dataset_path
  278. print(f'data_path:{dataset_path}')
  279. self.transforms = transforms
  280. self.img_path = os.path.join(dataset_path, "images\\" + dataset_type)
  281. self.lbl_path = os.path.join(dataset_path, "labels\\" + dataset_type)
  282. self.imgs = os.listdir(self.img_path)
  283. self.lbls = os.listdir(self.lbl_path)
  284. self.target_type = target_type
  285. # self.default_transform = DefaultTransform()
  286. self.data_augmentation = DetectionPresetTrain(data_augmentation="hflip") # multiscale会改变图像大小
  287. def __getitem__(self, index) -> T_co:
  288. img_path = os.path.join(self.img_path, self.imgs[index])
  289. lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
  290. img = PIL.Image.open(img_path).convert('RGB')
  291. w, h = img.size
  292. # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  293. target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  294. # if self.transforms:
  295. # img, target = self.transforms(img, target)
  296. # else:
  297. # img = self.default_transform(img)
  298. img, target = self.data_augmentation(img, target)
  299. print(f'img:{img.shape}')
  300. return img, target
  301. def __len__(self):
  302. return len(self.imgs)
  303. def read_target(self, item, lbl_path, shape, extra=None):
  304. # print(f'lbl_path:{lbl_path}')
  305. with open(lbl_path, 'r') as file:
  306. lable_all = json.load(file)
  307. n_stc_posl = 300
  308. n_stc_negl = 40
  309. use_cood = 0
  310. use_slop = 0
  311. wire = lable_all["wires"][0] # 字典
  312. line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl] # 不足,有多少取多少
  313. line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
  314. npos, nneg = len(line_pos_coords), len(line_neg_coords)
  315. lpre = np.concatenate([line_pos_coords, line_neg_coords], 0) # 正负样本坐标合在一起
  316. for i in range(len(lpre)):
  317. if random.random() > 0.5:
  318. lpre[i] = lpre[i, ::-1]
  319. ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
  320. ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
  321. feat = [
  322. lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
  323. ldir * use_slop,
  324. lpre[:, :, 2],
  325. ]
  326. feat = np.concatenate(feat, 1)
  327. wire_labels = {
  328. "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
  329. "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
  330. "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
  331. # 真实存在线条的邻接矩阵
  332. "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
  333. # 不存在线条的临界矩阵
  334. "lpre": torch.tensor(lpre)[:, :, :2],
  335. "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]), # 样本对应标签 1,0
  336. "lpre_feat": torch.from_numpy(feat),
  337. "junc_map": torch.tensor(wire['junc_map']["content"]),
  338. "junc_offset": torch.tensor(wire['junc_offset']["content"]),
  339. "line_map": torch.tensor(wire['line_map']["content"]),
  340. }
  341. labels = []
  342. # if self.target_type == 'polygon':
  343. # labels, masks = read_masks_from_txt_wire(lbl_path, shape)
  344. # elif self.target_type == 'pixel':
  345. # labels = read_masks_from_pixels_wire(lbl_path, shape)
  346. # print(torch.stack(masks).shape) # [线段数, 512, 512]
  347. target = {}
  348. # target["labels"] = torch.stack(labels)
  349. target["image_id"] = torch.tensor(item)
  350. # return wire_labels, target
  351. target["wires"] = wire_labels
  352. target["boxes"] = line_boxes(target)
  353. return target
  354. def show(self, idx):
  355. image, target = self.__getitem__(idx)
  356. img_path = os.path.join(self.img_path, self.imgs[idx])
  357. self._draw_vecl(img_path, target)
  358. def show_img(self, img_path):
  359. """根据给定的图片路径展示图像及其标注信息"""
  360. # 获取对应的标签文件路径
  361. img_name = os.path.basename(img_path)
  362. img_path = os.path.join(self.img_path, img_name)
  363. print(img_path)
  364. lbl_name = img_name[:-3] + 'json'
  365. lbl_path = os.path.join(self.lbl_path, lbl_name)
  366. print(lbl_path)
  367. if not os.path.exists(lbl_path):
  368. raise FileNotFoundError(f"Label file {lbl_path} does not exist.")
  369. img = PIL.Image.open(img_path).convert('RGB')
  370. w, h = img.size
  371. target = self.read_target(0, lbl_path, shape=(h, w))
  372. # 调用绘图函数
  373. self._draw_vecl(img_path, target)
  374. def _draw_vecl(self, img_path, target, fn=None):
  375. cmap = plt.get_cmap("jet")
  376. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  377. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  378. sm.set_array([])
  379. def imshow(im):
  380. plt.close()
  381. plt.tight_layout()
  382. plt.imshow(im)
  383. plt.colorbar(sm, fraction=0.046)
  384. plt.xlim([0, im.shape[0]])
  385. plt.ylim([im.shape[0], 0])
  386. junc = target['wires']['junc_coords'].cpu().numpy() * 4
  387. jtyp = target['wires']['jtyp'].cpu().numpy()
  388. juncs = junc[jtyp == 0]
  389. junts = junc[jtyp == 1]
  390. lpre = target['wires']["lpre"].cpu().numpy() * 4
  391. vecl_target = target['wires']["lpre_label"].cpu().numpy()
  392. lpre = lpre[vecl_target == 1]
  393. lines = lpre
  394. sline = np.ones(lpre.shape[0])
  395. imshow(io.imread(img_path))
  396. if len(lines) > 0 and not (lines[0] == 0).all():
  397. for i, ((a, b), s) in enumerate(zip(lines, sline)):
  398. if i > 0 and (lines[i] == lines[0]).all():
  399. break
  400. plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]无明确大小
  401. if not (juncs[0] == 0).all():
  402. for i, j in enumerate(juncs):
  403. if i > 0 and (i == juncs[0]).all():
  404. break
  405. plt.scatter(j[1], j[0], c="red", s=2, zorder=100) # 原 s=64
  406. img = PIL.Image.open(img_path).convert('RGB')
  407. boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
  408. colors="yellow", width=1)
  409. plt.imshow(boxed_image.permute(1, 2, 0).numpy())
  410. plt.show()
  411. if fn != None:
  412. plt.savefig(fn)
  413. '''