import glob
import json
import math
import os
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
import numpy.linalg as LA
import PIL.Image  # explicit submodule import: a bare `import PIL` does not load `PIL.Image`
import torch
from skimage import io
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate
from torch.utils.data.dataset import T_co

from models.base.base_dataset import BaseDataset
from models.dataset_tool import masks_to_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix


def _split_dir(root, kind, dataset_type):
    """Return ``root/kind/dataset_type``, or ``root/kind`` when *dataset_type* is None.

    Every component goes through ``os.path.join`` so the result uses the
    platform's separator instead of a hard-coded Windows backslash.
    """
    if dataset_type is None:
        return os.path.join(root, kind)
    return os.path.join(root, kind, dataset_type)


class WirePointDataset(BaseDataset):
    """Dataset of images paired with per-image JSON label files.

    Directory layout read by this class::

        <dataset_path>/images/<dataset_type>/<name>.<ext>
        <dataset_path>/labels/<dataset_type>/<name>.json

    Each item is ``(img, target)``. ``target`` holds ``boxes``, ``labels``,
    ``masks`` and ``image_id``, plus a ``"wires"`` dict of junction/line tensors
    built from the label file.
    """

    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
        """
        Args:
            dataset_path: root directory containing ``images/`` and ``labels/``.
            transforms: optional callable ``(img, target) -> (img, target)``.
            dataset_type: sub-directory name under ``images/`` and ``labels/``;
                when None, files are read directly from ``images/`` and ``labels/``.
            target_type: ``'polygon'`` or ``'pixel'``; selects the mask reader.
        """
        super().__init__(dataset_path)

        self.data_path = dataset_path
        print(f'data_path:{dataset_path}')
        self.transforms = transforms
        self.img_path = _split_dir(dataset_path, "images", dataset_type)
        self.lbl_path = _split_dir(dataset_path, "labels", dataset_type)
        # os.listdir order is filesystem-dependent; sort so that the
        # index -> file mapping is reproducible.
        self.imgs = sorted(os.listdir(self.img_path))
        self.lbls = sorted(os.listdir(self.lbl_path))
        self.target_type = target_type
        # self.default_transform = DefaultTransform()

    def _label_path(self, index):
        """Return the JSON label path for image ``self.imgs[index]``.

        The image extension is replaced with ``.json`` whatever its length.
        """
        stem, _ = os.path.splitext(self.imgs[index])
        return os.path.join(self.lbl_path, stem + '.json')

    def __getitem__(self, index) -> T_co:
        """Load image ``index`` as RGB, build its target, and apply the transforms."""
        img_path = os.path.join(self.img_path, self.imgs[index])
        lbl_path = self._label_path(index)

        # `convert` returns a new image, so the file can be closed right away.
        with PIL.Image.open(img_path) as raw:
            img = raw.convert('RGB')
        w, h = img.size

        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
        if self.transforms:
            img, target = self.transforms(img, target)
        else:
            # NOTE(review): `default_transform` is not assigned in __init__ (the
            # assignment is commented out); this relies on BaseDataset providing
            # it -- confirm.
            img = self.default_transform(img)
        return img, target

    def __len__(self):
        """Number of images found in the image directory."""
        return len(self.imgs)

    def read_target(self, item, lbl_path, shape, extra=None):
        """Build the training target for one image from its JSON label file.

        Args:
            item: dataset index, stored as ``target["image_id"]``.
            lbl_path: path to the image's JSON label file.
            shape: ``(h, w)`` of the image, forwarded to the mask reader.
            extra: unused; kept for interface compatibility.

        Returns:
            dict with keys ``boxes``, ``labels``, ``masks``, ``image_id`` and
            ``wires`` (the junction/line tensors).

        Raises:
            ValueError: if ``self.target_type`` is not ``'polygon'`` or
                ``'pixel'``, or the label file has no line samples at all.
        """
        if self.target_type == 'polygon':
            read_masks = read_masks_from_txt_wire
        elif self.target_type == 'pixel':
            read_masks = read_masks_from_pixels_wire
        else:
            raise ValueError(
                f"unknown target_type {self.target_type!r}; expected 'polygon' or 'pixel'")

        with open(lbl_path, 'r', encoding='utf-8') as file:
            label_all = json.load(file)

        n_stc_posl = 300  # max number of positive line samples kept per image
        n_stc_negl = 40   # max number of negative line samples kept per image
        use_cood = 0      # 0 zeroes the endpoint-coordinate columns of lpre_feat
        use_slop = 0      # 0 zeroes the line-direction columns of lpre_feat

        wire = label_all["wires"][0]  # dict
        # Shuffle and keep at most n samples (all of them if there are fewer).
        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[:n_stc_posl]
        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[:n_stc_negl]
        npos, nneg = len(line_pos_coords), len(line_neg_coords)

        # Positive samples first, then negative ones. An empty list comes back
        # from np.random.permutation as a 1-D array, which cannot be
        # concatenated with the 3-D samples, so skip empty groups.
        samples = [coords for coords in (line_pos_coords, line_neg_coords) if len(coords)]
        if not samples:
            raise ValueError(f"no line samples in {lbl_path}")
        lpre = np.concatenate(samples, 0)

        # Randomly reverse the order along axis 1 (the two endpoints) of each sample.
        for i in range(len(lpre)):
            if random.random() > 0.5:
                lpre[i] = lpre[i, ::-1]

        # Unit direction of each line. Out-of-place division promotes integer
        # coordinates to float; the original in-place `/=` raised a casting error.
        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
        ldir = ldir / np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)

        feat = np.concatenate([
            # NOTE(review): 128 looks like a normalisation by map size -- confirm.
            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
            ldir * use_slop,
            lpre[:, :, 2],
        ], 1)

        junc_content = wire["junc_coords"]["content"]
        junc = torch.tensor(junc_content)
        n_junc = len(junc_content)
        wire_labels = {
            "junc_coords": junc[:, :2],
            "jtyp": junc[:, 2].byte(),
            # adjacency matrix of junction pairs joined by a real line
            "line_pos_idx": adjacency_matrix(n_junc, wire["line_pos_idx"]["content"]),
            # adjacency matrix of junction pairs with no line between them
            "line_neg_idx": adjacency_matrix(n_junc, wire["line_neg_idx"]["content"]),
            "lpre": torch.tensor(lpre)[:, :, :2],
            # per-sample label: 1 for positive samples, 0 for negative samples
            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),
            "lpre_feat": torch.from_numpy(feat),
            "junc_map": torch.tensor(wire['junc_map']["content"]),
            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
            "line_map": torch.tensor(wire['line_map']["content"]),
        }

        labels, masks = read_masks(lbl_path, shape)
        stacked_masks = torch.stack(masks)

        target = {
            "boxes": masks_to_boxes(stacked_masks),
            "labels": torch.stack(labels),
            "masks": stacked_masks,
            "image_id": torch.tensor(item),
            "wires": wire_labels,
        }
        return target

    def show(self, idx):
        """Load image ``idx`` and its label file for visualisation.

        NOTE: the drawing/display step is currently disabled. This method only
        reads the label JSON and the image (``cv2.imread``, BGR order) and
        returns None.
        """
        img_path = os.path.join(self.img_path, self.imgs[idx])
        lbl_path = self._label_path(idx)
        with open(lbl_path, 'r', encoding='utf-8') as file:
            label_all = json.load(file)

        image = cv2.imread(img_path)  # OpenCV loads in BGR channel order
        # Disabled visualisation: the previous (commented-out) code scaled each
        # `label_all["segmentations"][k]['data']` by 512, drew it with
        # cv2.polylines / cv2.fillPoly, and showed it in a cv2 window named
        # 'Image with Segmentations'.

    def show_img(self, img_path):
        """Placeholder; not implemented."""
        pass