dataset_LD.py
# The roi_head expects a specific data format, so convert the dataset's format accordingly.
from torch.utils.data.dataset import T_co
from models.base.base_dataset import BaseDataset
import glob
import json
import math
import os
import random
import cv2
import PIL
import matplotlib.pyplot as plt
import matplotlib as mpl
from torchvision.utils import draw_bounding_boxes
import numpy as np
import numpy.linalg as LA
import torch
from skimage import io
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate

from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
from tools.presets import DetectionPresetTrain


def line_boxes1(target):
    """Convert line-segment endpoints (feature-map coordinates) into padded boxes in image coordinates."""
    boxs = []
    lines = target.cpu().numpy() * 4  # scale from the 128x128 feature map to the 512x512 image
    if len(lines) > 0:
        for i, (a, b) in enumerate(lines):
            if i > 0 and (lines[i] == lines[0]).all():
                break  # stop at padded entries that repeat the first segment
            # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1] have no guaranteed ordering
            # if a[-1] == 0. and b[-1] == 0.:
            #     continue
            # if a[:2].tolist() == [0., 0.] and b[:2].tolist() == [0., 0.]:
            #     continue
            if a[1] > b[1]:
                ymax = a[1] + 10
                ymin = b[1] - 10
            else:
                ymin = a[1] - 10
                ymax = b[1] + 10
            if a[0] > b[0]:
                xmax = a[0] + 10
                xmin = b[0] - 10
            else:
                xmin = a[0] - 10
                xmax = b[0] + 10
            # Clamp the padded box to the 512x512 image. Endpoints are stored as (y, x), so the
            # "y*" variables above actually hold column values; the appended box therefore comes
            # out as [xmin, ymin, xmax, ymax], the order torchvision expects.
            boxs.append([max(ymin, 0), max(xmin, 0), min(ymax, 512), min(xmax, 512)])
    # print(f'box:{boxs}')
    # if boxs == []:
    #     print(f'box:{boxs}')
    #     print(f'target:{target}')
    return torch.tensor(boxs)
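
# Example (a minimal sketch, not from the original file): two segments given as (y, x)
# endpoint pairs in 128x128 feature-map coordinates yield one box per segment, scaled
# by 4, padded by 10 px and clamped to the 512x512 image:
#   lines = torch.tensor([[[10., 20.], [30., 40.]],
#                         [[50., 60.], [70., 80.]]])
#   line_boxes1(lines)  # -> tensor([[ 70.,  30., 170., 130.], [230., 190., 330., 290.]])
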

class WirePointDataset(BaseDataset):
    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
        super().__init__(dataset_path)

        self.data_path = dataset_path
        print(f'data_path:{dataset_path}')
        self.transforms = transforms
        self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
        self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
        self.imgs = os.listdir(self.img_path)
        self.lbls = os.listdir(self.lbl_path)
        self.target_type = target_type
        # self.default_transform = DefaultTransform()
    def __getitem__(self, index) -> T_co:
        img_path = os.path.join(self.img_path, self.imgs[index])
        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
        img = PIL.Image.open(img_path).convert('RGB')
        w, h = img.size

        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
        if self.transforms:
            img, target = self.transforms(img, target)
        else:
            img = self.default_transform(img)

        # print(f'img:{img.shape}')
        return img, target

    def __len__(self):
        return len(self.imgs)
    def read_target(self, item, lbl_path, shape, extra=None):
        # print(f'lbl_path:{lbl_path}')
        with open(lbl_path, 'r') as file:
            label_all = json.load(file)

        n_stc_posl = 300
        n_stc_negl = 40
        use_cood = 0
        use_slop = 0

        wire = label_all["wires"][0]  # a dict
        # If fewer samples exist than requested, take as many as are available.
        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]
        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
        npos, nneg = len(line_pos_coords), len(line_neg_coords)
        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # positive and negative sample coordinates together
        for i in range(len(lpre)):
            if random.random() > 0.5:
                lpre[i] = lpre[i, ::-1]  # randomly flip endpoint order
        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
        feat = [
            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
            ldir * use_slop,
            lpre[:, :, 2],
        ]
        feat = np.concatenate(feat, 1)

        wire_labels = {
            "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
            # adjacency matrix of the lines that really exist
            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
            # adjacency matrix of the lines that do not exist
            "lpre": torch.tensor(lpre)[:, :, :2],
            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # sample labels: 1 (positive), 0 (negative)
            "lpre_feat": torch.from_numpy(feat),
            "junc_map": torch.tensor(wire['junc_map']["content"]),
            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
            "line_map": torch.tensor(wire['line_map']["content"]),
        }

        labels = []
        # if self.target_type == 'polygon':
        #     labels, masks = read_masks_from_txt_wire(lbl_path, shape)
        # elif self.target_type == 'pixel':
        #     labels = read_masks_from_pixels_wire(lbl_path, shape)
        # print(torch.stack(masks).shape)  # [num_line_segments, 512, 512]

        target = {}
        # target["labels"] = torch.stack(labels)
        target["image_id"] = torch.tensor(item)
        # return wire_labels, target
        target["wires"] = wire_labels
        # target["boxes"] = line_boxes(target)
        target["boxes"] = line_boxes1(torch.tensor(wire["line_pos_coords"]["content"]))
        target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
        # print(f'target["labels"]:{target["labels"]}')
        # if target["boxes"].shape == [0]:
        #     print(f'box is null:{lbl_path}')
        # print(f'boxes:{target["boxes"].shape}')
        return target
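
    # Sketch of the label JSON layout this loader assumes (inferred from the keys read
    # above; the exact field shapes are illustrative, not guaranteed by this file):
    # {
    #   "wires": [{
    #     "junc_coords":     {"content": [[c0, c1, jtyp], ...]},              # junction coords + type
    #     "line_pos_idx":    {"content": [[j_a, j_b], ...]},                  # junction-index pairs of real lines
    #     "line_neg_idx":    {"content": [[j_a, j_b], ...]},                  # junction-index pairs of non-lines
    #     "line_pos_coords": {"content": [[[c0, c1, v], [c0, c1, v]], ...]},  # two endpoints per segment
    #     "line_neg_coords": {"content": [[[c0, c1, v], [c0, c1, v]], ...]},
    #     "junc_map":        {"content": ...},
    #     "junc_offset":     {"content": ...},
    #     "line_map":        {"content": ...}
    #   }]
    # }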

    def show(self, idx):
        image, target = self.__getitem__(idx)
        cmap = plt.get_cmap("jet")
        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])

        def imshow(im):
            plt.close()
            plt.tight_layout()
            plt.imshow(im)
            plt.colorbar(sm, fraction=0.046)
            plt.xlim([0, im.shape[0]])
            plt.ylim([im.shape[0], 0])
        def draw_vecl(lines, sline, juncs, junts, fn=None):
            img_path = os.path.join(self.img_path, self.imgs[idx])
            imshow(io.imread(img_path))
            if len(lines) > 0 and not (lines[0] == 0).all():
                for i, ((a, b), s) in enumerate(zip(lines, sline)):
                    if i > 0 and (lines[i] == lines[0]).all():
                        break
                    plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1] have no guaranteed ordering
            if not (juncs[0] == 0).all():
                for i, j in enumerate(juncs):
                    if i > 0 and (juncs[i] == juncs[0]).all():
                        break
                    plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # originally s=64

            img_path = os.path.join(self.img_path, self.imgs[idx])
            img = PIL.Image.open(img_path).convert('RGB')
            boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
                                              colors="yellow", width=1)
            plt.imshow(boxed_image.permute(1, 2, 0).numpy())
            if fn is not None:
                plt.savefig(fn)  # save before show(), otherwise most backends write an empty figure
            plt.show()

        junc = target['wires']['junc_coords'].cpu().numpy() * 4
        jtyp = target['wires']['jtyp'].cpu().numpy()
        juncs = junc[jtyp == 0]
        junts = junc[jtyp == 1]

        lpre = target['wires']["lpre"].cpu().numpy() * 4
        vecl_target = target['wires']["lpre_label"].cpu().numpy()
        lpre = lpre[vecl_target == 1]

        # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
        draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)

    def show_img(self, img_path):
        pass


# dataset_train = WirePointDataset(r"\\192.168.50.222\share\lm\04\424-转分好的zjf", dataset_type='train')
# for i in dataset_train:
#     a = 1
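

# Minimal usage sketch (an assumption, not part of the original module): it presumes a
# local dataset rooted at "./data" with images/train and labels/train sub-directories,
# and that BaseDataset supplies default_transform. The detection-style targets are dicts
# of variable-sized tensors, so the collate function keeps the (img, target) pairs as
# tuples instead of stacking them.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset_train = WirePointDataset("./data", dataset_type="train")  # hypothetical path
    loader = DataLoader(
        dataset_train,
        batch_size=2,
        shuffle=True,
        collate_fn=lambda batch: tuple(zip(*batch)),  # -> (imgs, targets) tuples, no stacking
    )
    imgs, targets = next(iter(loader))
    print(len(imgs), targets[0]["boxes"].shape)
    dataset_train.show(0)  # visualize junctions, line segments and derived boxes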