wirepoint_dataset.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. from torch.utils.data.dataset import T_co
  2. from models.base.base_dataset import BaseDataset
  3. import glob
  4. import json
  5. import math
  6. import os
  7. import random
  8. import cv2
  9. import PIL
  10. import matplotlib.pyplot as plt
  11. import matplotlib as mpl
  12. from torchvision.utils import draw_bounding_boxes
  13. import numpy as np
  14. import numpy.linalg as LA
  15. import torch
  16. from skimage import io
  17. from torch.utils.data import Dataset
  18. from torch.utils.data.dataloader import default_collate
  19. import matplotlib.pyplot as plt
  20. from models.dataset_tool import masks_to_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
  21. class WirePointDataset(BaseDataset):
  22. def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
  23. super().__init__(dataset_path)
  24. self.data_path = dataset_path
  25. print(f'data_path:{dataset_path}')
  26. self.transforms = transforms
  27. self.img_path = os.path.join(dataset_path, "images\\" + dataset_type)
  28. self.lbl_path = os.path.join(dataset_path, "labels\\" + dataset_type)
  29. self.imgs = os.listdir(self.img_path)
  30. self.lbls = os.listdir(self.lbl_path)
  31. self.target_type = target_type
  32. # self.default_transform = DefaultTransform()
  33. def __getitem__(self, index) -> T_co:
  34. img_path = os.path.join(self.img_path, self.imgs[index])
  35. lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
  36. img = PIL.Image.open(img_path).convert('RGB')
  37. w, h = img.size
  38. # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  39. target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  40. if self.transforms:
  41. img, target = self.transforms(img, target)
  42. else:
  43. img = self.default_transform(img)
  44. # print(f'img:{img}')
  45. return img, target
  46. def __len__(self):
  47. return len(self.imgs)
  48. def read_target(self, item, lbl_path, shape, extra=None):
  49. # print(f'lbl_path:{lbl_path}')
  50. with open(lbl_path, 'r') as file:
  51. lable_all = json.load(file)
  52. n_stc_posl = 300
  53. n_stc_negl = 40
  54. use_cood = 0
  55. use_slop = 0
  56. wire = lable_all["wires"][0] # 字典
  57. line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl] # 不足,有多少取多少
  58. line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
  59. npos, nneg = len(line_pos_coords), len(line_neg_coords)
  60. lpre = np.concatenate([line_pos_coords, line_neg_coords], 0) # 正负样本坐标合在一起
  61. for i in range(len(lpre)):
  62. if random.random() > 0.5:
  63. lpre[i] = lpre[i, ::-1]
  64. ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
  65. ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
  66. feat = [
  67. lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
  68. ldir * use_slop,
  69. lpre[:, :, 2],
  70. ]
  71. feat = np.concatenate(feat, 1)
  72. wire_labels = {
  73. "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
  74. "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
  75. "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
  76. # 真实存在线条的邻接矩阵
  77. "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
  78. # 不存在线条的临界矩阵
  79. "lpre": torch.tensor(lpre)[:, :, :2],
  80. "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]), # 样本对应标签 1,0
  81. "lpre_feat": torch.from_numpy(feat),
  82. "junc_map": torch.tensor(wire['junc_map']["content"]),
  83. "junc_offset": torch.tensor(wire['junc_offset']["content"]),
  84. "line_map": torch.tensor(wire['line_map']["content"]),
  85. }
  86. h, w = shape
  87. labels = []
  88. masks = []
  89. if self.target_type == 'polygon':
  90. labels, masks = read_masks_from_txt_wire(lbl_path, shape)
  91. elif self.target_type == 'pixel':
  92. labels, masks = read_masks_from_pixels_wire(lbl_path, shape)
  93. # print(torch.stack(masks).shape) # [线段数, 512, 512]
  94. target = {}
  95. # target["boxes"] = masks_to_boxes(torch.stack(masks))
  96. # print(target["boxes"])
  97. target["labels"] = torch.stack(labels)
  98. target["masks"] = torch.stack(masks)
  99. target["image_id"] = torch.tensor(item)
  100. # return wire_labels, target
  101. target["wires"] = wire_labels
  102. boxs = []
  103. junc = target['wires']['junc_coords'].cpu().numpy() * 4
  104. lpre = target['wires']["lpre"].cpu().numpy() * 4
  105. vecl_target = target['wires']["lpre_label"].cpu().numpy()
  106. lpre = lpre[vecl_target == 1]
  107. lines = lpre
  108. sline = np.ones(lpre.shape[0])
  109. if len(lines) > 0 and not (lines[0] == 0).all():
  110. for i, ((a, b), s) in enumerate(zip(lines, sline)):
  111. if i > 0 and (lines[i] == lines[0]).all():
  112. break
  113. # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]无明确大小
  114. if a[1] > b[1]:
  115. ymax = a[1] + 1
  116. ymin = b[1] - 1
  117. else:
  118. ymin = a[1] - 1
  119. ymax = b[1] + 1
  120. if a[0] > b[0]:
  121. xmax = a[0] + 1
  122. xmin = b[0] - 1
  123. else:
  124. xmin = a[0] - 1
  125. xmax = b[0] + 1
  126. boxs.append([ymin, xmin, ymax, xmax])
  127. # plt.Rectangle([a[1] - 1, b[1] + 1], [a[0] + 1, b[0] - 1], c="g", linewidth=1)
  128. target["line_boxes"] = torch.tensor(boxs)
  129. return target
  130. def show(self, idx):
  131. image, target = self.__getitem__(idx)
  132. cmap = plt.get_cmap("jet")
  133. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  134. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  135. sm.set_array([])
  136. def imshow(im):
  137. plt.close()
  138. plt.tight_layout()
  139. plt.imshow(im)
  140. plt.colorbar(sm, fraction=0.046)
  141. plt.xlim([0, im.shape[0]])
  142. plt.ylim([im.shape[0], 0])
  143. def draw_vecl(lines, sline, juncs, junts, fn=None):
  144. img_path = os.path.join(self.img_path, self.imgs[idx])
  145. imshow(io.imread(img_path))
  146. if len(lines) > 0 and not (lines[0] == 0).all():
  147. for i, ((a, b), s) in enumerate(zip(lines, sline)):
  148. if i > 0 and (lines[i] == lines[0]).all():
  149. break
  150. plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]无明确大小
  151. if not (juncs[0] == 0).all():
  152. for i, j in enumerate(juncs):
  153. if i > 0 and (i == juncs[0]).all():
  154. break
  155. plt.scatter(j[1], j[0], c="red", s=2, zorder=100) # 原 s=64
  156. img_path = os.path.join(self.img_path, self.imgs[idx])
  157. img = PIL.Image.open(img_path).convert('RGB')
  158. boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["line_boxes"],
  159. colors="yellow", width=1)
  160. plt.imshow(boxed_image.permute(1, 2, 0).numpy())
  161. plt.show()
  162. plt.show()
  163. if fn != None:
  164. plt.savefig(fn)
  165. junc = target['wires']['junc_coords'].cpu().numpy() * 4
  166. jtyp = target['wires']['jtyp'].cpu().numpy()
  167. juncs = junc[jtyp == 0]
  168. junts = junc[jtyp == 1]
  169. lpre = target['wires']["lpre"].cpu().numpy() * 4
  170. vecl_target = target['wires']["lpre_label"].cpu().numpy()
  171. lpre = lpre[vecl_target == 1]
  172. # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
  173. draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
  174. def show_img(self, img_path):
  175. pass