from torch.utils.data.dataset import T_co

from libs.vision_libs.utils import draw_keypoints
from models.base.base_dataset import BaseDataset

import glob
import json
import math
import os
import random
import cv2
import PIL
import PIL.Image  # `import PIL` alone does not guarantee the Image submodule is loaded

import imageio
import matplotlib.pyplot as plt
import matplotlib as mpl
from torchvision.utils import draw_bounding_boxes

import numpy as np
import numpy.linalg as LA
import torch
from skimage import io
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate

from models.dataset_tool import read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix


def validate_keypoints(keypoints, image_width, image_height):
    """Raise if any keypoint lies outside the image.

    Args:
        keypoints: iterable of (x, y, v) triples; the third value is ignored here.
        image_width: valid x values are in ``[0, image_width)``.
        image_height: valid y values are in ``[0, image_height)``.

    Raises:
        ValueError: for the first keypoint found out of bounds.
    """
    for x, y, _visibility in keypoints:
        if not (0 <= x < image_width and 0 <= y < image_height):
            raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")


class LineDataset(BaseDataset):
    """Line (wire-frame) dataset reading images plus per-image JSON labels.

    Files are read from ``<dataset_path>/images/<dataset_type>/`` and
    ``<dataset_path>/labels/<dataset_type>/``; each label file has the image's
    base name with a ``.json`` extension.
    """

    def __init__(self, dataset_path, data_type, transforms=None, dataset_type=None, img_type='rgb', target_type='pixel'):
        """
        Args:
            dataset_path: dataset root directory.
            data_type: ``'tiff'`` loads images with imageio; any other value uses PIL.
            transforms: optional callable ``(img, target) -> (img, target)``.
            dataset_type: split sub-directory name (e.g. ``'train'``). NOTE(review):
                the default ``None`` cannot be concatenated into a path and raises
                TypeError, so callers must pass a value.
            img_type: stored as ``self.img_type``; not read by this class.
            target_type: ``'polygon'`` or ``'pixel'``; selects which mask reader
                ``read_target`` invokes.
        """
        super().__init__(dataset_path)

        self.data_path = dataset_path
        self.data_type = data_type
        print(f'data_path:{dataset_path}')
        self.transforms = transforms
        self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
        self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
        # os.listdir order is arbitrary; sort so an index always maps to the same file.
        self.imgs = sorted(os.listdir(self.img_path))
        self.lbls = sorted(os.listdir(self.lbl_path))
        self.target_type = target_type
        self.img_type = img_type

    def __getitem__(self, index) -> T_co:
        """Load image ``index`` and its target dict (see ``read_target``)."""
        img_name = self.imgs[index]
        img_path = os.path.join(self.img_path, img_name)
        # Label shares the image's base name; splitext handles any extension length.
        lbl_path = os.path.join(self.lbl_path, os.path.splitext(img_name)[0] + '.json')

        if self.data_type == 'tiff':
            # Keep only the first three channels.
            img_3channel = imageio.v3.imread(img_path)[:, :, :3]
            # numpy image arrays are (rows, cols, channels) == (height, width, C).
            h, w = img_3channel.shape[:2]
            img = torch.from_numpy(img_3channel).permute(2, 0, 1)
        else:
            img = PIL.Image.open(img_path).convert('RGB')
            w, h = img.size  # PIL reports (width, height)

        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))

        if self.transforms:
            img, target = self.transforms(img, target)
        else:
            # NOTE(review): default_transform is not assigned in __init__;
            # presumably provided by BaseDataset — verify.
            img = self.default_transform(img)

        return img, target

    def __len__(self):
        return len(self.imgs)

    def read_target(self, item, lbl_path, shape, extra=None):
        """Build the training target for one sample from its JSON label file.

        Args:
            item: sample index, stored as ``target["image_id"]``.
            lbl_path: path to the JSON label file.
            shape: ``(height, width)`` of the image.
            extra: unused; kept for interface compatibility.

        Returns:
            dict with keys ``"image_id"``, ``"wires"`` (raw wire-frame tensors),
            ``"boxes"`` (M, 4), ``"labels"`` (M,) all ones, and ``"lines"``
            (M, 2, 3) holding the two ``(x, y, visibility)`` endpoints per line.

        Raises:
            ValueError: if a line endpoint lies outside the image.
        """
        with open(lbl_path, 'r') as file:
            lable_all = json.load(file)

        n_stc_posl = 300
        n_stc_negl = 40
        use_cood = 0
        use_slop = 0

        wire = lable_all["wires"][0]  # dict
        # Random subsample; if fewer lines are available, take all of them.
        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[:n_stc_posl]
        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[:n_stc_negl]
        npos, nneg = len(line_pos_coords), len(line_neg_coords)
        # Positive samples first, then negatives.
        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)
        # Randomly swap the order of each line's two endpoints.
        for i in range(len(lpre)):
            if random.random() > 0.5:
                lpre[i] = lpre[i, ::-1]
        # Unit direction vector of each line (norm clipped to avoid division by zero).
        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
        # use_cood / use_slop are 0, so the coordinate and slope features are zeroed.
        feat = [
            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
            ldir * use_slop,
            lpre[:, :, 2],
        ]
        feat = np.concatenate(feat, 1)

        junc_content = wire["junc_coords"]["content"]
        junc_coords = torch.tensor(junc_content)
        n_junc = len(junc_content)
        wire_labels = {
            "junc_coords": junc_coords,
            "jtyp": junc_coords[:, 2].byte(),
            # Adjacency matrix of lines that actually exist.
            "line_pos_idx": adjacency_matrix(n_junc, wire["line_pos_idx"]["content"]),
            "line_neg_idx": adjacency_matrix(n_junc, wire["line_neg_idx"]["content"]),
            "lpre": torch.tensor(lpre)[:, :, :2],
            # Per-sample label: 1 for positives, 0 for negatives.
            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),
            "lpre_feat": torch.from_numpy(feat),
            "junc_map": torch.tensor(wire['junc_map']["content"]),
            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
            "line_map": torch.tensor(wire['line_map']["content"]),
        }

        # NOTE(review): the masks/labels read here are not used further in this method.
        labels = []
        if self.target_type == 'polygon':
            labels, masks = read_masks_from_txt_wire(lbl_path, shape)
        elif self.target_type == 'pixel':
            labels = read_masks_from_pixels_wire(lbl_path, shape)

        target = {}
        target["image_id"] = torch.tensor(item)
        target["wires"] = wire_labels
        target["boxes"], lines = get_boxes_lines(target, image_shape=shape)
        target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)

        # Append a visibility flag of 2 to every (x, y) endpoint -> (x, y, v).
        visibility = torch.full((lines.shape[0],), 2).unsqueeze(1)
        lines = torch.cat((lines, visibility), dim=1)
        target["lines"] = lines.to(torch.float32).view(-1, 2, 3)

        # validate_keypoints takes (width, height); shape is (height, width).
        validate_keypoints(lines, shape[1], shape[0])

        return target

    def show(self, idx):
        """Plot sample ``idx`` with its boxes (yellow) and line endpoints (red)."""
        image, target = self.__getitem__(idx)

        img_path = os.path.join(self.img_path, self.imgs[idx])
        img = PIL.Image.open(img_path).convert('RGB')
        boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
                                          colors="yellow", width=1)
        # read_target never sets "keypoints"; fall back to the line endpoints (x, y).
        # NOTE(review): assumes draw_keypoints accepts (num_instances, K, 2) like torchvision's — confirm.
        keypoints = target["keypoints"] if "keypoints" in target else target["lines"][..., :2]
        keypoint_img = draw_keypoints(boxed_image, keypoints, colors='red', width=3)
        plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
        plt.show()

    def show_img(self, img_path):
        """Placeholder; not implemented."""
        pass


def get_boxes_lines(target, image_shape=(512, 512), margin=6):
    """Derive padded bounding boxes and endpoint pairs for the positive line samples.

    Endpoints in ``target["wires"]["lpre"]`` are indexed so that this function
    treats index 0 as the row (y) and index 1 as the column (x).

    Args:
        target: dict whose ``["wires"]["lpre"]`` is an (N, 2, 2) tensor of line
            endpoints and ``["wires"]["lpre_label"]`` marks positives with 1.
        image_shape: ``(height, width)`` used to clamp boxes to the image. The
            default reproduces the former hard-coded clamp at 511.
        margin: pixels of padding added around each line's extent.

    Returns:
        ``(boxes, line_point_pairs)``: ``boxes`` is (M, 4) as ``[x1, y1, x2, y2]``;
        ``line_point_pairs`` is (2M, 2) of ``(x, y)`` with consecutive rows being
        the two endpoints of one line. Both have a leading dimension of 0 when
        there are no positive lines.
    """
    height, width = image_shape
    lpre = target['wires']["lpre"].cpu().numpy()
    sample_labels = target['wires']["lpre_label"].cpu().numpy()
    lines = lpre[sample_labels == 1]

    boxes = []
    line_point_pairs = []
    # An all-zero first line is treated as "no lines".
    if len(lines) > 0 and not (lines[0] == 0).all():
        for i, (a, b) in enumerate(lines):
            # Stop at the first repetition of line 0.
            if i > 0 and (lines[i] == lines[0]).all():
                break
            # a[1], b[1] have no guaranteed ordering, hence min/max below.
            line_point_pairs.append([a[1], a[0]])
            line_point_pairs.append([b[1], b[0]])

            row_min = max(0, (min(a[0], b[0]) - margin))
            row_max = min(height - 1, (max(a[0], b[0]) + margin))
            col_min = max(0, (min(a[1], b[1]) - margin))
            col_max = min(width - 1, (max(a[1], b[1]) + margin))
            boxes.append([col_min, row_min, col_max, row_max])

    # reshape keeps the (M, 4) / (2M, 2) layout even when no line was found.
    return torch.tensor(boxes).reshape(-1, 4), torch.tensor(line_point_pairs).reshape(-1, 2)


if __name__ == '__main__':
    path = r"\\192.168.50.222/share/lm/Dataset_all"
    # data_type is required: 'tiff' uses imageio, anything else uses PIL. Adjust to the dataset.
    dataset = LineDataset(dataset_path=path, data_type='jpg', dataset_type='train')
    dataset.show(10)