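"""Keypoint dataset for wireframe-style line annotations.

Loads images with per-image JSON labels (junction coordinates, positive and
negative line-segment samples, junction/line maps) and converts them into
detection-style targets with bounding boxes and endpoint keypoints.
"""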
from torch.utils.data.dataset import T_co
from libs.vision_libs.utils import draw_keypoints
from models.base.base_dataset import BaseDataset
import json
import os
import random

import PIL.Image
import matplotlib.pyplot as plt
import numpy as np
import numpy.linalg as LA
import torch
from torchvision.utils import draw_bounding_boxes

from models.dataset_tool import read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix


def validate_keypoints(keypoints, image_width, image_height):
    """Raise ValueError if any (x, y, v) keypoint lies outside the image bounds."""
    for kp in keypoints:
        x, y, v = kp
        if not (0 <= x < image_width and 0 <= y < image_height):
            raise ValueError(f"Keypoint ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
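
# Example (illustrative values only): a keypoint at (10.0, 20.0) with
# visibility flag 2 passes for a 128x128 image, while (130.0, 20.0) would
# raise ValueError:
#   validate_keypoints([(10.0, 20.0, 2)], 128, 128)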


class KeypointDataset(BaseDataset):
    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
        super().__init__(dataset_path)
        self.data_path = dataset_path
        print(f'data_path:{dataset_path}')
        self.transforms = transforms
        self.img_path = os.path.join(dataset_path, "images", dataset_type)
        self.lbl_path = os.path.join(dataset_path, "labels", dataset_type)
        # Sort so images and labels stay aligned; os.listdir order is arbitrary.
        self.imgs = sorted(os.listdir(self.img_path))
        self.lbls = sorted(os.listdir(self.lbl_path))
        self.target_type = target_type
        # self.default_transform is assumed to be provided by BaseDataset and
        # is used below whenever no transforms are passed in.

    def __getitem__(self, index) -> T_co:
        img_path = os.path.join(self.img_path, self.imgs[index])
        # Swap the image extension for .json to locate the matching label file.
        lbl_path = os.path.join(self.lbl_path, os.path.splitext(self.imgs[index])[0] + '.json')
        img = PIL.Image.open(img_path).convert('RGB')
        w, h = img.size
        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
        if self.transforms:
            img, target = self.transforms(img, target)
        else:
            img = self.default_transform(img)
        return img, target

    def __len__(self):
        return len(self.imgs)

    def read_target(self, item, lbl_path, shape, extra=None):
        with open(lbl_path, 'r') as file:
            label_all = json.load(file)

        n_stc_posl = 300
        n_stc_negl = 40
        use_cood = 0
        use_slop = 0

        wire = label_all["wires"][0]  # dict
        # Randomly sample up to n_stc_posl positive and n_stc_negl negative line
        # segments (if there are fewer, take as many as exist).
        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]
        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
        npos, nneg = len(line_pos_coords), len(line_neg_coords)
        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # positive and negative samples together
        # Randomly swap the endpoint order of each segment.
        for i in range(len(lpre)):
            if random.random() > 0.5:
                lpre[i] = lpre[i, ::-1]
        # Unit direction vector of each segment.
        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
        # Per-segment features: normalized endpoint coordinates, direction, and
        # the per-endpoint third channel. With use_cood = use_slop = 0 the first
        # two blocks are zeroed out.
        feat = [
            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
            ldir * use_slop,
            lpre[:, :, 2],
        ]
        feat = np.concatenate(feat, 1)
        wire_labels = {
            "junc_coords": torch.tensor(wire["junc_coords"]["content"]),
            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
            # Junction adjacency matrices: ground-truth lines vs. negative samples.
            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
            "lpre": torch.tensor(lpre)[:, :, :2],
            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # 1 for positive, 0 for negative
            "lpre_feat": torch.from_numpy(feat),
            "junc_map": torch.tensor(wire['junc_map']["content"]),
            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
            "line_map": torch.tensor(wire['line_map']["content"]),
        }
        labels = []
        if self.target_type == 'polygon':
            labels, masks = read_masks_from_txt_wire(lbl_path, shape)
        elif self.target_type == 'pixel':
            labels = read_masks_from_pixels_wire(lbl_path, shape)

        target = {}
        target["image_id"] = torch.tensor(item)
        target["wires"] = wire_labels
        target["boxes"], keypoints = line_boxes(target)
        target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
        # Append a visibility flag of 2 ("labeled and visible") to every keypoint,
        # then reshape to (num_lines, 2 endpoints, 3).
        visibility = torch.full((keypoints.shape[0], 1), 2, dtype=keypoints.dtype)
        keypoints = torch.cat((keypoints, visibility), dim=1)
        target["keypoints"] = keypoints.to(torch.float32).view(-1, 2, 3)
        # shape is (h, w), so the image width is shape[1] and the height shape[0].
        validate_keypoints(keypoints, shape[1], shape[0])
        return target

    def show(self, idx):
        image, target = self.__getitem__(idx)
        img_path = os.path.join(self.img_path, self.imgs[idx])
        img = PIL.Image.open(img_path).convert('RGB')
        boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
                                          colors="yellow", width=1)
        # draw_keypoints expects (num_instances, K, 2) coordinates, so drop the
        # visibility column before drawing.
        keypoint_img = draw_keypoints(boxed_image, target['keypoints'][..., :2], colors='red', width=3)
        plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
        plt.show()

    def show_img(self, img_path):
        pass


def line_boxes(target):
    """Build xyxy boxes and endpoint keypoints from the positive line segments."""
    boxs = []
    lpre = target['wires']["lpre"].cpu().numpy()
    vecl_target = target['wires']["lpre_label"].cpu().numpy()
    lines = lpre[vecl_target == 1]  # keep only the positive segments

    keypoints = []
    if len(lines) > 0 and not (lines[0] == 0).all():
        for i, (a, b) in enumerate(lines):
            # Stop at the first repetition of the leading segment (padding).
            if i > 0 and (lines[i] == lines[0]).all():
                break
            # Endpoints are stored as (y, x); keypoints and boxes use (x, y).
            keypoints.append([a[1], a[0]])
            keypoints.append([b[1], b[0]])
            # 1-pixel padding around each segment, with no assumption about
            # which endpoint has the larger coordinate.
            xmin = min(a[1], b[1]) - 1
            xmax = max(a[1], b[1]) + 1
            ymin = min(a[0], b[0]) - 1
            ymax = max(a[0], b[0]) + 1
            boxs.append([xmin, ymin, xmax, ymax])
    return torch.tensor(boxs), torch.tensor(keypoints)
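

# A minimal collate sketch (an assumption, not part of the original pipeline):
# each target holds variable-sized tensors (boxes, keypoints, wires), so
# torch's default collate cannot stack them; returning plain lists is the
# usual detection-style workaround.
def collate_fn(batch):
    images, targets = zip(*batch)
    return list(images), list(targets)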


if __name__ == '__main__':
    path = r"\\192.168.50.222/share/lm/Dataset_all"
    dataset = KeypointDataset(dataset_path=path, dataset_type='train')
    dataset.show(10)
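    # Illustrative only: batch the dataset with the list-based collate_fn above
    # (batch_size and shuffle are arbitrary example values).
    from torch.utils.data import DataLoader
    loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
    imgs, targets = next(iter(loader))
    print(len(imgs), targets[0]["boxes"].shape)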