# dataset_LD.py
# Dataset that prepares line/box targets for roi_head
from torch.utils.data.dataset import T_co
from libs.vision_libs.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
from models.base.base_dataset import BaseDataset
from torchvision.transforms import functional as F
import glob
import json
import math
import os
import random

import cv2
import PIL
import matplotlib.pyplot as plt
import matplotlib as mpl
from torchvision.utils import draw_bounding_boxes
import numpy as np
import numpy.linalg as LA
import torch
from skimage import io
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate
from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
from tools.presets import DetectionPresetTrain

def line_boxes1(target):
    # Lines are stored as (y, x) endpoint pairs at 1/4 resolution; rescale by 4
    # and pad each line by 10 px to get an enclosing box.  Because the endpoints
    # are (y, x), the appended box is (xmin, ymin, xmax, ymax) in image
    # coordinates despite the swapped variable names.
    boxs = []
    lines = target.cpu().numpy() * 4
    if len(lines) > 0 and not (lines[0] == 0).all():
        for i, (a, b) in enumerate(lines):
            # stop at the first repeated line (zero-padded annotations)
            if i > 0 and (lines[i] == lines[0]).all():
                break
            if a[1] > b[1]:
                ymax = a[1] + 10
                ymin = b[1] - 10
            else:
                ymin = a[1] - 10
                ymax = b[1] + 10
            if a[0] > b[0]:
                xmax = a[0] + 10
                xmin = b[0] - 10
            else:
                xmin = a[0] - 10
                xmax = b[0] + 10
            boxs.append([ymin, xmin, ymax, xmax])
    # if boxs == []:
    #     print(target)
    return torch.tensor(boxs)
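

# A minimal sketch of what line_boxes1 returns, assuming a single (y, x) line
# whose coordinates were already divided by 4 (the function rescales by 4):
#
#   lines = torch.tensor([[[10.0, 20.0], [30.0, 5.0]]]) / 4
#   line_boxes1(lines)
#   # -> tensor([[-5., 0., 30., 40.]])  # (xmin, ymin, xmax, ymax), 10 px padding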


class WirePointDataset(BaseDataset):
    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
        super().__init__(dataset_path)
        self.data_path = dataset_path
        print(f'data_path:{dataset_path}')
        self.transforms = transforms
        self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
        self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
        self.imgs = os.listdir(self.img_path)
        self.lbls = os.listdir(self.lbl_path)
        self.target_type = target_type
        # The rest of the class references self.default_transform, so assign the
        # Mask R-CNN preprocessing transform under that name.
        self.default_transform = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms()
        # self.default_transform = DefaultTransform()
    def __getitem__(self, index) -> T_co:
        img_path = os.path.join(self.img_path, self.imgs[index])
        # replace the 3-character image extension with 'json'
        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
        # img = PIL.Image.open(img_path).convert('RGB')
        img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)  # keep the 4th (depth) channel
        img_rgb = img[:, :, :3]
        print(f'img shape:{img.shape}')
        img_rgb = cv2.cvtColor(img_rgb, cv2.COLOR_BGR2RGB)
        # img = np.array(img, copy=True)
        # img = self.default_transform(img)
        # print(f'pil img:{img.dtype}')
        # w, h = img.size
        # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        # cv2.imshow('img', img)
        # cv2.waitKey(1000000)
        # print(img.shape)
        h, w = img.shape[0:2]  # cv2 images are (height, width, channels)
        # w, h = img.size
        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
        print(f'self.default_transform:{self.default_transform}')
        # if self.transforms:
        #     img, target = self.transforms(img, target)
        # else:
        #     img = self.default_transform(img)
        # # separate the RGB and depth channels
        # rgb_channels = img[:, :, :3]
        # depth_channel = img[:, :, 3]
        #
        # rgb_normalized = rgb_channels / 255
        # depth_normalized = (depth_channel - depth_channel.min()) / (depth_channel.max() - depth_channel.min()) * 255
        #
        # # recombine the normalized RGB and depth channels
        # normalized_rgba_image = np.dstack((rgb_normalized, depth_normalized))  # or use a fixed-range depth normalization
        #
        # print("Normalized RGBA image shape:", normalized_rgba_image.shape)
        #
        # img = torch.tensor(normalized_rgba_image, dtype=torch.uint8).permute(2, 1, 0)
        # cv2.imshow('img', img[:3].permute(1, 2, 0).numpy().astype(np.uint8))
        # cv2.waitKey(10000000)
        # plt.imshow(img[:3].permute(1, 2, 0).numpy())
        # plt.show()
        # new_channel = torch.zeros(1, 512, 512)
        # img = torch.cat((img, new_channel), dim=0)
        img = np.dstack((img_rgb, img[:, :, 3]))  # RGB + depth -> 4-channel image
        img = torch.as_tensor(img).permute(2, 0, 1)
        img = self.default_transform(img)
        # img = F.convert_image_dtype(img, torch.float)
        print(f'img:{img.shape}')
        # print(f'img dtype:{img.dtype}')
        return img, target
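
    # Each target is a dict of variable-sized tensors, so default_collate cannot
    # batch samples from this dataset.  A detection-style collate_fn is one way
    # to build a loader (a sketch, not part of the original file):
    #
    #   loader = torch.utils.data.DataLoader(
    #       dataset, batch_size=2, shuffle=True,
    #       collate_fn=lambda batch: tuple(zip(*batch)),
    #   )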
    def __len__(self):
        return len(self.imgs)
    def read_target(self, item, lbl_path, shape, extra=None):
        # print(f'lbl_path:{lbl_path}')
        with open(lbl_path, 'r') as file:
            label_all = json.load(file)

        n_stc_posl = 300
        n_stc_negl = 40
        use_cood = 0
        use_slop = 0

        wire = label_all["wires"][0]
        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # randomly sampled positive lines
        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
        npos, nneg = len(line_pos_coords), len(line_neg_coords)
        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # concatenate positive and negative lines
        for i in range(len(lpre)):
            # randomly flip the endpoint order of each line
            if random.random() > 0.5:
                lpre[i] = lpre[i, ::-1]
        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
        feat = [
            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
            ldir * use_slop,
            lpre[:, :, 2],
        ]
        feat = np.concatenate(feat, 1)

        wire_labels = {
            "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
            # adjacency matrix of positive lines
            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
            # adjacency matrix of negative lines
            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
            "lpre": torch.tensor(lpre)[:, :, :2],
            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # 1 for positive, 0 for negative
            "lpre_feat": torch.from_numpy(feat),
            "junc_map": torch.tensor(wire['junc_map']["content"]),
            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
            "line_map": torch.tensor(wire['line_map']["content"]),
        }

        labels = []
        # if self.target_type == 'polygon':
        #     labels, masks = read_masks_from_txt_wire(lbl_path, shape)
        # elif self.target_type == 'pixel':
        #     labels = read_masks_from_pixels_wire(lbl_path, shape)
        # print(torch.stack(masks).shape)  # [num_masks, 512, 512]

        target = {}
        # target["labels"] = torch.stack(labels)
        target["image_id"] = torch.tensor(item)
        # return wire_labels, target
        target["wires"] = wire_labels
        # target["boxes"] = line_boxes(target)
        target["boxes"] = line_boxes1(torch.tensor(wire["line_pos_coords"]["content"]))
        target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
        # print(f'target["labels"]:{target["labels"]}')
        # print(f'boxes:{target["boxes"].shape}')
        if target["boxes"].numel() == 0:
            print("Tensor is empty")
            print(f'path:{lbl_path}')
        return target
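
    # The returned target has roughly this structure, read off the code above
    # (exact shapes depend on the label files):
    #
    #   {
    #       "image_id": scalar tensor,
    #       "wires": {
    #           "junc_coords":  (J, 2) junction (y, x) coordinates,
    #           "jtyp":         (J,) uint8 junction types,
    #           "line_pos_idx": (J, J) adjacency matrix of positive lines,
    #           "line_neg_idx": (J, J) adjacency matrix of negative lines,
    #           "lpre":         (npos + nneg, 2, 2) sampled line endpoints,
    #           "lpre_label":   (npos + nneg,) 1 = positive, 0 = negative,
    #           "lpre_feat":    per-line feature vectors,
    #           "junc_map", "junc_offset", "line_map": dense supervision maps,
    #       },
    #       "boxes":  (N, 4) boxes around positive lines,
    #       "labels": (N,) all ones (single foreground class),
    #   }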
    def show(self, idx):
        image, target = self.__getitem__(idx)
        cmap = plt.get_cmap("jet")
        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])

        def imshow(im):
            plt.close()
            plt.tight_layout()
            plt.imshow(im)
            plt.colorbar(sm, fraction=0.046)
            plt.xlim([0, im.shape[1]])  # width is shape[1], not shape[0]
            plt.ylim([im.shape[0], 0])

        def draw_vecl(lines, sline, juncs, junts, fn=None):
            img_path = os.path.join(self.img_path, self.imgs[idx])
            imshow(io.imread(img_path))
            if len(lines) > 0 and not (lines[0] == 0).all():
                for i, ((a, b), s) in enumerate(zip(lines, sline)):
                    if i > 0 and (lines[i] == lines[0]).all():
                        break
                    # endpoints are (y, x), so x comes from index 1
                    plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)
            if len(juncs) > 0 and not (juncs[0] == 0).all():
                for i, j in enumerate(juncs):
                    # compare the junction itself, not the loop index
                    if i > 0 and (j == juncs[0]).all():
                        break
                    plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # originally s=64
            img_path = os.path.join(self.img_path, self.imgs[idx])
            img = PIL.Image.open(img_path).convert('RGB')
            boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
                                              colors="yellow", width=1)
            plt.imshow(boxed_image.permute(1, 2, 0).numpy())
            if fn is not None:
                plt.savefig(fn)  # save before show, otherwise the figure is already cleared
            plt.show()

        junc = target['wires']['junc_coords'].cpu().numpy() * 4
        jtyp = target['wires']['jtyp'].cpu().numpy()
        juncs = junc[jtyp == 0]
        junts = junc[jtyp == 1]

        lpre = target['wires']["lpre"].cpu().numpy() * 4
        vecl_target = target['wires']["lpre_label"].cpu().numpy()
        lpre = lpre[vecl_target == 1]

        # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
        draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)

    def show_img(self, img_path):
        pass

# dataset_train = WirePointDataset("/data/lm/dataset/0424_", dataset_type='val')
# for i in dataset_train:
#     a = 1
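

# A minimal guarded smoke test, a sketch assuming the layout described in
# __init__ ("images/<dataset_type>" and "labels/<dataset_type>" with matching
# .json labels); the dataset path is the placeholder from the example above.
if __name__ == "__main__":
    dataset_val = WirePointDataset("/data/lm/dataset/0424_", dataset_type='val')
    img, target = dataset_val[0]
    print(img.shape, target["boxes"].shape, target["labels"].shape)
    dataset_val.show(0)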