# wirepoint_dataset.py (5.8 KB)

  1. from torch.utils.data.dataset import T_co
  2. from models.base.base_dataset import BaseDataset
  3. import glob
  4. import json
  5. import math
  6. import os
  7. import random
  8. import cv2
  9. import PIL
  10. import numpy as np
  11. import numpy.linalg as LA
  12. import torch
  13. from skimage import io
  14. from torch.utils.data import Dataset
  15. from torch.utils.data.dataloader import default_collate
  16. import matplotlib.pyplot as plt
  17. from models.dataset_tool import masks_to_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
  18. class WirePointDataset(BaseDataset):
  19. def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
  20. super().__init__(dataset_path)
  21. self.data_path = dataset_path
  22. print(f'data_path:{dataset_path}')
  23. self.transforms = transforms
  24. self.img_path = os.path.join(dataset_path, "images\\" + dataset_type)
  25. self.lbl_path = os.path.join(dataset_path, "labels\\" + dataset_type)
  26. self.imgs = os.listdir(self.img_path)
  27. self.lbls = os.listdir(self.lbl_path)
  28. self.target_type = target_type
  29. # self.default_transform = DefaultTransform()
  30. def __getitem__(self, index) -> T_co:
  31. img_path = os.path.join(self.img_path, self.imgs[index])
  32. lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
  33. img = PIL.Image.open(img_path).convert('RGB')
  34. w, h = img.size
  35. # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  36. target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
  37. if self.transforms:
  38. img, target = self.transforms(img, target)
  39. else:
  40. img = self.default_transform(img)
  41. # print(f'img:{img}')
  42. return img, target
  43. def __len__(self):
  44. return len(self.imgs)
  45. def read_target(self, item, lbl_path, shape, extra=None):
  46. # print(f'lbl_path:{lbl_path}')
  47. with open(lbl_path, 'r') as file:
  48. lable_all = json.load(file)
  49. n_stc_posl = 300
  50. n_stc_negl = 40
  51. use_cood = 0
  52. use_slop = 0
  53. wire = lable_all["wires"][0] # 字典
  54. line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl] # 不足,有多少取多少
  55. line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
  56. npos, nneg = len(line_pos_coords), len(line_neg_coords)
  57. lpre = np.concatenate([line_pos_coords, line_neg_coords], 0) # 正负样本坐标合在一起
  58. for i in range(len(lpre)):
  59. if random.random() > 0.5:
  60. lpre[i] = lpre[i, ::-1]
  61. ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
  62. ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
  63. feat = [
  64. lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
  65. ldir * use_slop,
  66. lpre[:, :, 2],
  67. ]
  68. feat = np.concatenate(feat, 1)
  69. wire_labels = {
  70. "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
  71. "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
  72. "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
  73. # 真实存在线条的邻接矩阵
  74. "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
  75. # 不存在线条的临界矩阵
  76. "lpre": torch.tensor(lpre)[:, :, :2],
  77. "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]), # 样本对应标签 1,0
  78. "lpre_feat": torch.from_numpy(feat),
  79. "junc_map": torch.tensor(wire['junc_map']["content"]),
  80. "junc_offset": torch.tensor(wire['junc_offset']["content"]),
  81. "line_map": torch.tensor(wire['line_map']["content"]),
  82. }
  83. h, w = shape
  84. labels = []
  85. masks = []
  86. if self.target_type == 'polygon':
  87. labels, masks = read_masks_from_txt_wire(lbl_path, shape)
  88. elif self.target_type == 'pixel':
  89. labels, masks = read_masks_from_pixels_wire(lbl_path, shape)
  90. target = {}
  91. target["boxes"] = masks_to_boxes(torch.stack(masks))
  92. target["labels"] = torch.stack(labels)
  93. target["masks"] = torch.stack(masks)
  94. target["image_id"] = torch.tensor(item)
  95. # return wire_labels, target
  96. target["wires"] = wire_labels
  97. return target
  98. def show(self, idx):
  99. img_path = os.path.join(self.img_path, self.imgs[idx])
  100. lbl_path = os.path.join(self.lbl_path, self.imgs[idx][:-3] + 'json')
  101. with open(lbl_path, 'r') as file:
  102. lable_all = json.load(file)
  103. # 可视化图像和标注
  104. image = cv2.imread(img_path) # [H,W,3] # 默认为BGR格式
  105. # print(image.shape)
  106. # 绘制每个标注的多边形
  107. # for ann in lable_all["segmentations"]:
  108. # segmentation = [[x * 512 for x in ann['data']]]
  109. # # segmentation = [ann['data']]
  110. # # for i in range(len(ann['data'])):
  111. # # if i % 2 == 0:
  112. # # segmentation[0][i] *= image.shape[0]
  113. # # else:
  114. # # segmentation[0][i] *= image.shape[0]
  115. #
  116. # # if isinstance(segmentation, list):
  117. # # for seg in segmentation:
  118. # # poly = np.array(seg).reshape((-1, 2)).astype(int)
  119. # # cv2.polylines(image, [poly], isClosed=True, color=(0, 255, 0), thickness=2)
  120. # # cv2.fillPoly(image, [poly], color=(0, 255, 0))
  121. #
  122. # # 显示图像
  123. # cv2.namedWindow('Image with Segmentations', cv2.WINDOW_NORMAL)
  124. # cv2.imshow('Image with Segmentations', image)
  125. # cv2.waitKey(0)
  126. # cv2.destroyAllWindows()
  127. def show_img(self,img_path):
  128. pass