# keypoint_dataset.py
# Legacy LCNN WireframeDataset (from 8c208a87:lcnn/datasets.py), kept for reference only:
# import glob
# import json
# import math
# import os
# import random
#
# import numpy as np
# import numpy.linalg as LA
# import torch
# from skimage import io
# from torch.utils.data import Dataset
# from torch.utils.data.dataloader import default_collate
#
# from lcnn.config import M
#
# from .dataset_tool import line_boxes_faster, read_masks_from_txt_wire, read_masks_from_pixels_wire
#
#
# class WireframeDataset(Dataset):
# def __init__(self, rootdir, split):
# self.rootdir = rootdir
# filelist = glob.glob(f"{rootdir}/{split}/*_label.npz")
# filelist.sort()
#
# # print(f"n{split}:", len(filelist))
# self.split = split
# self.filelist = filelist
#
# def __len__(self):
# return len(self.filelist)
#
# def __getitem__(self, idx):
# iname = self.filelist[idx][:-10].replace("_a0", "").replace("_a1", "") + ".png"
# image = io.imread(iname).astype(float)[:, :, :3]
# if "a1" in self.filelist[idx]:
# image = image[:, ::-1, :]
# image = (image - M.image.mean) / M.image.stddev
# image = np.rollaxis(image, 2).copy()
#
# with np.load(self.filelist[idx]) as npz:
# target = {
# name: torch.from_numpy(npz[name]).float()
# for name in ["jmap", "joff", "lmap"]
# }
# lpos = np.random.permutation(npz["lpos"])[: M.n_stc_posl]
# lneg = np.random.permutation(npz["lneg"])[: M.n_stc_negl]
# npos, nneg = len(lpos), len(lneg)
# lpre = np.concatenate([lpos, lneg], 0)
# for i in range(len(lpre)):
# if random.random() > 0.5:
# lpre[i] = lpre[i, ::-1]
# ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
# ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
# feat = [
# lpre[:, :, :2].reshape(-1, 4) / 128 * M.use_cood,
# ldir * M.use_slop,
# lpre[:, :, 2],
# ]
# feat = np.concatenate(feat, 1)
# meta = {
# "junc": torch.from_numpy(npz["junc"][:, :2]),
# "jtyp": torch.from_numpy(npz["junc"][:, 2]).byte(),
# "Lpos": self.adjacency_matrix(len(npz["junc"]), npz["Lpos"]),
# "Lneg": self.adjacency_matrix(len(npz["junc"]), npz["Lneg"]),
# "lpre": torch.from_numpy(lpre[:, :, :2]),
# "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),
# "lpre_feat": torch.from_numpy(feat),
# }
#
# labels = []
# labels = read_masks_from_pixels_wire(iname, (512, 512))
# # if self.target_type == 'polygon':
# # labels, masks = read_masks_from_txt_wire(iname, (512, 512))
# # elif self.target_type == 'pixel':
# # labels = read_masks_from_pixels_wire(iname, (512, 512))
#
# target["labels"] = torch.stack(labels)
# target["boxes"] = line_boxes_faster(meta)
#
#
# return torch.from_numpy(image).float(), meta, target
#
# def adjacency_matrix(self, n, link):
# mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8)
# link = torch.from_numpy(link)
# if len(link) > 0:
# mat[link[:, 0], link[:, 1]] = 1
# mat[link[:, 1], link[:, 0]] = 1
# return mat
#
#
# def collate(batch):
# return (
# default_collate([b[0] for b in batch]),
# [b[1] for b in batch],
# default_collate([b[2] for b in batch]),
# )
# Original LCNN data format, with attributes renamed and box-related fields added.
  102. from torch.utils.data.dataset import T_co
  103. from models.base.base_dataset import BaseDataset
  104. import glob
  105. import json
  106. import math
  107. import os
  108. import random
  109. import cv2
  110. import PIL
  111. import matplotlib.pyplot as plt
  112. import matplotlib as mpl
  113. from torchvision.utils import draw_bounding_boxes
  114. import numpy as np
  115. import numpy.linalg as LA
  116. import torch
  117. from skimage import io
  118. from torch.utils.data import Dataset
  119. from torch.utils.data.dataloader import default_collate
  120. import matplotlib.pyplot as plt
  121. from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
  122. def validate_keypoints(keypoints, image_width, image_height):
  123. for kp in keypoints:
  124. x, y, v = kp
  125. if not (0 <= x < image_width and 0 <= y < image_height):
  126. raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
  127. <<<<<<<< HEAD:models/keypoint/keypoint_dataset.py
  128. class KeypointDataset(BaseDataset):
  129. ========
  130. class WireframeDataset(BaseDataset):
  131. >>>>>>>> 8c208a87b75e9fde2fe6dcf23f2e589aeb87f172:lcnn/datasets.py
  132. def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
  133. super().__init__(dataset_path)
  134. self.data_path = dataset_path
  135. print(f'data_path:{dataset_path}')
  136. self.transforms = transforms
  137. self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
  138. self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
  139. self.imgs = os.listdir(self.img_path)
  140. self.lbls = os.listdir(self.lbl_path)
  141. self.target_type = target_type
  142. # self.default_transform = DefaultTransform()
  143. def __getitem__(self, index) -> T_co:
  144. img_path = os.path.join(self.img_path, self.imgs[index])
  145. lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
  146. img = PIL.Image.open(img_path).convert('RGB')
  147. w, h = img.size
  148. # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  149. target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  150. if self.transforms:
  151. img, target = self.transforms(img, target)
  152. else:
  153. img = self.default_transform(img)
  154. # print(f'img:{img}')
  155. return img, target
  156. def __len__(self):
  157. return len(self.imgs)
  158. def read_target(self, item, lbl_path, shape, extra=None):
  159. # print(f'shape:{shape}')
  160. # print(f'lbl_path:{lbl_path}')
  161. with open(lbl_path, 'r') as file:
  162. lable_all = json.load(file)
  163. n_stc_posl = 300
  164. n_stc_negl = 40
  165. use_cood = 0
  166. use_slop = 0
  167. wire = lable_all["wires"][0] # 字典
  168. line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl] # 不足,有多少取多少
  169. line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
  170. npos, nneg = len(line_pos_coords), len(line_neg_coords)
  171. lpre = np.concatenate([line_pos_coords, line_neg_coords], 0) # 正负样本坐标合在一起
  172. for i in range(len(lpre)):
  173. if random.random() > 0.5:
  174. lpre[i] = lpre[i, ::-1]
  175. ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
  176. ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
  177. feat = [
  178. lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
  179. ldir * use_slop,
  180. lpre[:, :, 2],
  181. ]
  182. feat = np.concatenate(feat, 1)
  183. wire_labels = {
  184. "junc_coords": torch.tensor(wire["junc_coords"]["content"]),
  185. "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
  186. "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
  187. # 真实存在线条的邻接矩阵
  188. "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
  189. "lpre": torch.tensor(lpre)[:, :, :2],
  190. "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]), # 样本对应标签 1,0
  191. "lpre_feat": torch.from_numpy(feat),
  192. "junc_map": torch.tensor(wire['junc_map']["content"]),
  193. "junc_offset": torch.tensor(wire['junc_offset']["content"]),
  194. "line_map": torch.tensor(wire['line_map']["content"]),
  195. }
  196. labels = []
  197. if self.target_type == 'polygon':
  198. labels, masks = read_masks_from_txt_wire(lbl_path, shape)
  199. elif self.target_type == 'pixel':
  200. labels = read_masks_from_pixels_wire(lbl_path, shape)
  201. # print(torch.stack(masks).shape) # [线段数, 512, 512]
  202. target = {}
  203. target["image_id"] = torch.tensor(item)
  204. # return wire_labels, target
  205. target["wires"] = wire_labels
  206. target["labels"] = torch.stack(labels)
  207. # print(f'labels:{target["labels"]}')
  208. # target["boxes"] = line_boxes(target)
  209. target["boxes"], keypoints = line_boxes(target)
  210. # keypoints=keypoints/512
  211. # visibility_flags = torch.ones((wire_labels["junc_coords"].shape[0], 1))
  212. # keypoints= wire_labels["junc_coords"]
  213. a = torch.full((keypoints.shape[0],), 2).unsqueeze(1)
  214. keypoints = torch.cat((keypoints, a), dim=1)
  215. target["keypoints"] = keypoints.to(torch.float32).view(-1,2,3)
  216. # print(f'boxes:{target["boxes"].shape}')
  217. # 在 __getitem__ 方法中调用此函数
  218. validate_keypoints(keypoints, shape[0], shape[1])
  219. # print(f'keypoints:{target["keypoints"].shape}')
  220. return target
  221. def show(self, idx):
  222. image, target = self.__getitem__(idx)
  223. cmap = plt.get_cmap("jet")
  224. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  225. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  226. sm.set_array([])
  227. def imshow(im):
  228. plt.close()
  229. plt.tight_layout()
  230. plt.imshow(im)
  231. plt.colorbar(sm, fraction=0.046)
  232. plt.xlim([0, im.shape[0]])
  233. plt.ylim([im.shape[0], 0])
  234. def draw_vecl(lines, sline, juncs, junts, fn=None):
  235. img_path = os.path.join(self.img_path, self.imgs[idx])
  236. imshow(io.imread(img_path))
  237. if len(lines) > 0 and not (lines[0] == 0).all():
  238. for i, ((a, b), s) in enumerate(zip(lines, sline)):
  239. if i > 0 and (lines[i] == lines[0]).all():
  240. break
  241. plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]无明确大小
  242. if not (juncs[0] == 0).all():
  243. for i, j in enumerate(juncs):
  244. if i > 0 and (i == juncs[0]).all():
  245. break
  246. plt.scatter(j[1], j[0], c="red", s=2, zorder=100) # 原 s=64
  247. img_path = os.path.join(self.img_path, self.imgs[idx])
  248. img = PIL.Image.open(img_path).convert('RGB')
  249. boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
  250. colors="yellow", width=1)
  251. plt.imshow(boxed_image.permute(1, 2, 0).numpy())
  252. plt.show()
  253. plt.show()
  254. if fn != None:
  255. plt.savefig(fn)
  256. junc = target['wires']['junc_coords'].cpu().numpy() * 4
  257. jtyp = target['wires']['jtyp'].cpu().numpy()
  258. juncs = junc[jtyp == 0]
  259. junts = junc[jtyp == 1]
  260. lpre = target['wires']["lpre"].cpu().numpy() * 4
  261. vecl_target = target['wires']["lpre_label"].cpu().numpy()
  262. lpre = lpre[vecl_target == 1]
  263. # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
  264. draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
  265. def show_img(self, img_path):
  266. pass
  267. <<<<<<<< HEAD:models/keypoint/keypoint_dataset.py
  268. if __name__ == '__main__':
  269. path=r"I:\datasets\wirenet_1000"
  270. dataset= KeypointDataset(dataset_path=path, dataset_type='train')
  271. dataset.show(7)
  272. ========
  273. '''
  274. # 使用roi_head数据格式有要求,更改数据格式
  275. from torch.utils.data.dataset import T_co
  276. from models.base.base_dataset import BaseDataset
  277. import glob
  278. import json
  279. import math
  280. import os
  281. import random
  282. import cv2
  283. import PIL
  284. import matplotlib.pyplot as plt
  285. import matplotlib as mpl
  286. from torchvision.utils import draw_bounding_boxes
  287. import numpy as np
  288. import numpy.linalg as LA
  289. import torch
  290. from skimage import io
  291. from torch.utils.data import Dataset
  292. from torch.utils.data.dataloader import default_collate
  293. import matplotlib.pyplot as plt
  294. from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
  295. from tools.presets import DetectionPresetTrain
  296. class WireframeDataset(BaseDataset):
  297. def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
  298. super().__init__(dataset_path)
  299. self.data_path = dataset_path
  300. print(f'data_path:{dataset_path}')
  301. self.transforms = transforms
  302. self.img_path = os.path.join(dataset_path, "images\\" + dataset_type)
  303. self.lbl_path = os.path.join(dataset_path, "labels\\" + dataset_type)
  304. self.imgs = os.listdir(self.img_path)
  305. self.lbls = os.listdir(self.lbl_path)
  306. self.target_type = target_type
  307. # self.default_transform = DefaultTransform()
  308. self.data_augmentation = DetectionPresetTrain(data_augmentation="hflip") # multiscale会改变图像大小
  309. def __getitem__(self, index) -> T_co:
  310. img_path = os.path.join(self.img_path, self.imgs[index])
  311. lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
  312. img = PIL.Image.open(img_path).convert('RGB')
  313. w, h = img.size
  314. # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  315. target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  316. # if self.transforms:
  317. # img, target = self.transforms(img, target)
  318. # else:
  319. # img = self.default_transform(img)
  320. img, target = self.data_augmentation(img, target)
  321. print(f'img:{img.shape}')
  322. return img, target
  323. def __len__(self):
  324. return len(self.imgs)
  325. def read_target(self, item, lbl_path, shape, extra=None):
  326. # print(f'lbl_path:{lbl_path}')
  327. with open(lbl_path, 'r') as file:
  328. lable_all = json.load(file)
  329. n_stc_posl = 300
  330. n_stc_negl = 40
  331. use_cood = 0
  332. use_slop = 0
  333. wire = lable_all["wires"][0] # 字典
  334. line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl] # 不足,有多少取多少
  335. line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
  336. npos, nneg = len(line_pos_coords), len(line_neg_coords)
  337. lpre = np.concatenate([line_pos_coords, line_neg_coords], 0) # 正负样本坐标合在一起
  338. for i in range(len(lpre)):
  339. if random.random() > 0.5:
  340. lpre[i] = lpre[i, ::-1]
  341. ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
  342. ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
  343. feat = [
  344. lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
  345. ldir * use_slop,
  346. lpre[:, :, 2],
  347. ]
  348. feat = np.concatenate(feat, 1)
  349. wire_labels = {
  350. "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
  351. "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
  352. "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
  353. # 真实存在线条的邻接矩阵
  354. "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
  355. # 不存在线条的临界矩阵
  356. "lpre": torch.tensor(lpre)[:, :, :2],
  357. "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]), # 样本对应标签 1,0
  358. "lpre_feat": torch.from_numpy(feat),
  359. "junc_map": torch.tensor(wire['junc_map']["content"]),
  360. "junc_offset": torch.tensor(wire['junc_offset']["content"]),
  361. "line_map": torch.tensor(wire['line_map']["content"]),
  362. }
  363. labels = []
  364. # if self.target_type == 'polygon':
  365. # labels, masks = read_masks_from_txt_wire(lbl_path, shape)
  366. # elif self.target_type == 'pixel':
  367. # labels = read_masks_from_pixels_wire(lbl_path, shape)
  368. # print(torch.stack(masks).shape) # [线段数, 512, 512]
  369. target = {}
  370. # target["labels"] = torch.stack(labels)
  371. target["image_id"] = torch.tensor(item)
  372. # return wire_labels, target
  373. target["wires"] = wire_labels
  374. target["boxes"] = line_boxes(target)
  375. return target
  376. def show(self, idx):
  377. image, target = self.__getitem__(idx)
  378. img_path = os.path.join(self.img_path, self.imgs[idx])
  379. self._draw_vecl(img_path, target)
  380. def show_img(self, img_path):
  381. """根据给定的图片路径展示图像及其标注信息"""
  382. # 获取对应的标签文件路径
  383. img_name = os.path.basename(img_path)
  384. img_path = os.path.join(self.img_path, img_name)
  385. print(img_path)
  386. lbl_name = img_name[:-3] + 'json'
  387. lbl_path = os.path.join(self.lbl_path, lbl_name)
  388. print(lbl_path)
  389. if not os.path.exists(lbl_path):
  390. raise FileNotFoundError(f"Label file {lbl_path} does not exist.")
  391. img = PIL.Image.open(img_path).convert('RGB')
  392. w, h = img.size
  393. target = self.read_target(0, lbl_path, shape=(h, w))
  394. # 调用绘图函数
  395. self._draw_vecl(img_path, target)
  396. def _draw_vecl(self, img_path, target, fn=None):
  397. cmap = plt.get_cmap("jet")
  398. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  399. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  400. sm.set_array([])
  401. def imshow(im):
  402. plt.close()
  403. plt.tight_layout()
  404. plt.imshow(im)
  405. plt.colorbar(sm, fraction=0.046)
  406. plt.xlim([0, im.shape[0]])
  407. plt.ylim([im.shape[0], 0])
  408. junc = target['wires']['junc_coords'].cpu().numpy() * 4
  409. jtyp = target['wires']['jtyp'].cpu().numpy()
  410. juncs = junc[jtyp == 0]
  411. junts = junc[jtyp == 1]
  412. lpre = target['wires']["lpre"].cpu().numpy() * 4
  413. vecl_target = target['wires']["lpre_label"].cpu().numpy()
  414. lpre = lpre[vecl_target == 1]
  415. lines = lpre
  416. sline = np.ones(lpre.shape[0])
  417. imshow(io.imread(img_path))
  418. if len(lines) > 0 and not (lines[0] == 0).all():
  419. for i, ((a, b), s) in enumerate(zip(lines, sline)):
  420. if i > 0 and (lines[i] == lines[0]).all():
  421. break
  422. plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]无明确大小
  423. if not (juncs[0] == 0).all():
  424. for i, j in enumerate(juncs):
  425. if i > 0 and (i == juncs[0]).all():
  426. break
  427. plt.scatter(j[1], j[0], c="red", s=2, zorder=100) # 原 s=64
  428. img = PIL.Image.open(img_path).convert('RGB')
  429. boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
  430. colors="yellow", width=1)
  431. plt.imshow(boxed_image.permute(1, 2, 0).numpy())
  432. plt.show()
  433. if fn != None:
  434. plt.savefig(fn)
  435. '''
  436. >>>>>>>> 8c208a87b75e9fde2fe6dcf23f2e589aeb87f172:lcnn/datasets.py