# roi_head requires a specific target format, so the data is converted to that format here.
from torch.utils.data.dataset import T_co

from models.base.base_dataset import BaseDataset

import glob
import json
import math
import os
import random
import cv2
import PIL

import matplotlib.pyplot as plt
import matplotlib as mpl
from torchvision.utils import draw_bounding_boxes

import numpy as np
import numpy.linalg as LA
import torch
from skimage import io
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate

from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix

from tools.presets import DetectionPresetTrain


class WirePointDataset(BaseDataset):
    """Wireframe dataset: returns an image plus a detection-style target dict
    ("boxes", "labels", "image_id") that also carries the raw wire labels under "wires"."""

    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
        super().__init__(dataset_path)

        self.data_path = dataset_path
        print(f'data_path:{dataset_path}')
        self.transforms = transforms
        self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
        self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
        self.imgs = os.listdir(self.img_path)
        self.lbls = os.listdir(self.lbl_path)
        self.target_type = target_type
        # self.default_transform = DefaultTransform()

    def __getitem__(self, index) -> T_co:
        img_path = os.path.join(self.img_path, self.imgs[index])
        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')

        img = PIL.Image.open(img_path).convert('RGB')
        w, h = img.size

        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
        if self.transforms:
            img, target = self.transforms(img, target)
        else:
            img = self.default_transform(img)

        # print(f'img:{img}')
        return img, target

    def __len__(self):
        return len(self.imgs)

    def read_target(self, item, lbl_path, shape, extra=None):
        # print(f'lbl_path:{lbl_path}')
        with open(lbl_path, 'r') as file:
            label_all = json.load(file)

        n_stc_posl = 300  # max number of positive line samples per image
        n_stc_negl = 40   # max number of negative line samples per image

        use_cood = 0
        use_slop = 0

        wire = label_all["wires"][0]  # dict with junction / line annotations
        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # if fewer than n_stc_posl, keep whatever is available
        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
        npos, nneg = len(line_pos_coords), len(line_neg_coords)
        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # positive and negative sample coordinates stacked together
        for i in range(len(lpre)):
            # randomly flip endpoint order to augment line direction
            if random.random() > 0.5:
                lpre[i] = lpre[i, ::-1]
        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
        feat = [
            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
            ldir * use_slop,
            lpre[:, :, 2],
        ]
        feat = np.concatenate(feat, 1)

        wire_labels = {
            "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
            # adjacency matrix of lines that actually exist
            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
            # adjacency matrix of lines that do not exist
            "lpre": torch.tensor(lpre)[:, :, :2],
            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # sample labels: 1 = positive, 0 = negative
            "lpre_feat": torch.from_numpy(feat),
            "junc_map": torch.tensor(wire['junc_map']["content"]),
            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
            "line_map": torch.tensor(wire['line_map']["content"]),
        }

        labels = []
        # if self.target_type == 'polygon':
        #     labels, masks = read_masks_from_txt_wire(lbl_path, shape)
        # elif self.target_type == 'pixel':
        #     labels = read_masks_from_pixels_wire(lbl_path, shape)
        # print(torch.stack(masks).shape)  # [num_segments, 512, 512]

        target = {}
        # target["labels"] = torch.stack(labels)
        target["image_id"] = torch.tensor(item)
        # return wire_labels, target
        target["wires"] = wire_labels
        target["boxes"] = line_boxes(target)
        target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
        # print(f'target["labels"]:{ target["labels"]}')
        # print(f'boxes:{target["boxes"].shape}')
        return target

    def show(self, idx):
        image, target = self.__getitem__(idx)

        cmap = plt.get_cmap("jet")
        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])

        def imshow(im):
            plt.close()
            plt.tight_layout()
            plt.imshow(im)
            plt.colorbar(sm, fraction=0.046)
            plt.xlim([0, im.shape[0]])
            plt.ylim([im.shape[0], 0])

        def draw_vecl(lines, sline, juncs, junts, fn=None):
            img_path = os.path.join(self.img_path, self.imgs[idx])
            imshow(io.imread(img_path))
            if len(lines) > 0 and not (lines[0] == 0).all():
                for i, ((a, b), s) in enumerate(zip(lines, sline)):
                    if i > 0 and (lines[i] == lines[0]).all():
                        break
                    plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1] have no guaranteed ordering
            if not (juncs[0] == 0).all():
                for i, j in enumerate(juncs):
                    if i > 0 and (j == juncs[0]).all():
                        # stop once the repeated/padded junctions begin
                        break
                    plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # originally s=64

            img_path = os.path.join(self.img_path, self.imgs[idx])
            img = PIL.Image.open(img_path).convert('RGB')
            boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
                                              colors="yellow", width=1)
            plt.imshow(boxed_image.permute(1, 2, 0).numpy())
            plt.show()

            if fn is not None:
                plt.savefig(fn)

        junc = target['wires']['junc_coords'].cpu().numpy() * 4
        jtyp = target['wires']['jtyp'].cpu().numpy()
        juncs = junc[jtyp == 0]
        junts = junc[jtyp == 1]

        lpre = target['wires']["lpre"].cpu().numpy() * 4
        vecl_target = target['wires']["lpre_label"].cpu().numpy()
        lpre = lpre[vecl_target == 1]  # keep only the positive line samples

        # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
        draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)

    def show_img(self, img_path):
        pass
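

# A minimal usage sketch, not part of the original module: it assumes the dataset is laid out as
# <dataset_path>/images/<split> and <dataset_path>/labels/<split> with one JSON label per image,
# and that BaseDataset supplies a default_transform when no transforms are passed. The path and
# batch size below are placeholders. default_collate cannot batch the per-image target dicts
# (their tensors differ in size), so a simple tuple-of-lists collate is used instead.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = WirePointDataset(dataset_path="path/to/dataset", dataset_type="train")  # placeholder path
    loader = DataLoader(
        dataset,
        batch_size=2,
        shuffle=True,
        collate_fn=lambda batch: tuple(zip(*batch)),  # keep images and targets as per-sample lists
    )

    images, targets = next(iter(loader))
    print(len(images), targets[0]["boxes"].shape)

    # Visualize the annotations (lines, junctions, boxes) of the first sample.
    dataset.show(0)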