<<<<<<<< HEAD:models/keypoint/keypoint_dataset.py ======== # import glob # import json # import math # import os # import random # # import numpy as np # import numpy.linalg as LA # import torch # from skimage import io # from torch.utils.data import Dataset # from torch.utils.data.dataloader import default_collate # # from lcnn.config import M # # from .dataset_tool import line_boxes_faster, read_masks_from_txt_wire, read_masks_from_pixels_wire # # # class WireframeDataset(Dataset): # def __init__(self, rootdir, split): # self.rootdir = rootdir # filelist = glob.glob(f"{rootdir}/{split}/*_label.npz") # filelist.sort() # # # print(f"n{split}:", len(filelist)) # self.split = split # self.filelist = filelist # # def __len__(self): # return len(self.filelist) # # def __getitem__(self, idx): # iname = self.filelist[idx][:-10].replace("_a0", "").replace("_a1", "") + ".png" # image = io.imread(iname).astype(float)[:, :, :3] # if "a1" in self.filelist[idx]: # image = image[:, ::-1, :] # image = (image - M.image.mean) / M.image.stddev # image = np.rollaxis(image, 2).copy() # # with np.load(self.filelist[idx]) as npz: # target = { # name: torch.from_numpy(npz[name]).float() # for name in ["jmap", "joff", "lmap"] # } # lpos = np.random.permutation(npz["lpos"])[: M.n_stc_posl] # lneg = np.random.permutation(npz["lneg"])[: M.n_stc_negl] # npos, nneg = len(lpos), len(lneg) # lpre = np.concatenate([lpos, lneg], 0) # for i in range(len(lpre)): # if random.random() > 0.5: # lpre[i] = lpre[i, ::-1] # ldir = lpre[:, 0, :2] - lpre[:, 1, :2] # ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None) # feat = [ # lpre[:, :, :2].reshape(-1, 4) / 128 * M.use_cood, # ldir * M.use_slop, # lpre[:, :, 2], # ] # feat = np.concatenate(feat, 1) # meta = { # "junc": torch.from_numpy(npz["junc"][:, :2]), # "jtyp": torch.from_numpy(npz["junc"][:, 2]).byte(), # "Lpos": self.adjacency_matrix(len(npz["junc"]), npz["Lpos"]), # "Lneg": self.adjacency_matrix(len(npz["junc"]), npz["Lneg"]), # "lpre": torch.from_numpy(lpre[:, :, :2]), # "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]), # "lpre_feat": torch.from_numpy(feat), # } # # labels = [] # labels = read_masks_from_pixels_wire(iname, (512, 512)) # # if self.target_type == 'polygon': # # labels, masks = read_masks_from_txt_wire(iname, (512, 512)) # # elif self.target_type == 'pixel': # # labels = read_masks_from_pixels_wire(iname, (512, 512)) # # target["labels"] = torch.stack(labels) # target["boxes"] = line_boxes_faster(meta) # # # return torch.from_numpy(image).float(), meta, target # # def adjacency_matrix(self, n, link): # mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8) # link = torch.from_numpy(link) # if len(link) > 0: # mat[link[:, 0], link[:, 1]] = 1 # mat[link[:, 1], link[:, 0]] = 1 # return mat # # # def collate(batch): # return ( # default_collate([b[0] for b in batch]), # [b[1] for b in batch], # default_collate([b[2] for b in batch]), # ) # 原LCNN数据格式,改了属性名,加了box相关 >>>>>>>> 8c208a87b75e9fde2fe6dcf23f2e589aeb87f172:lcnn/datasets.py from torch.utils.data.dataset import T_co from models.base.base_dataset import BaseDataset import glob import json import math import os import random import cv2 import PIL import matplotlib.pyplot as plt import matplotlib as mpl from torchvision.utils import draw_bounding_boxes import numpy as np import numpy.linalg as LA import torch from skimage import io from torch.utils.data import Dataset from torch.utils.data.dataloader import default_collate import matplotlib.pyplot as plt from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix def validate_keypoints(keypoints, image_width, image_height): for kp in keypoints: x, y, v = kp if not (0 <= x < image_width and 0 <= y < image_height): raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})") <<<<<<<< HEAD:models/keypoint/keypoint_dataset.py class KeypointDataset(BaseDataset): ======== class WireframeDataset(BaseDataset): >>>>>>>> 8c208a87b75e9fde2fe6dcf23f2e589aeb87f172:lcnn/datasets.py def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'): super().__init__(dataset_path) self.data_path = dataset_path print(f'data_path:{dataset_path}') self.transforms = transforms self.img_path = os.path.join(dataset_path, "images/" + dataset_type) self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type) self.imgs = os.listdir(self.img_path) self.lbls = os.listdir(self.lbl_path) self.target_type = target_type # self.default_transform = DefaultTransform() def __getitem__(self, index) -> T_co: img_path = os.path.join(self.img_path, self.imgs[index]) lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json') img = PIL.Image.open(img_path).convert('RGB') w, h = img.size # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w)) target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w)) if self.transforms: img, target = self.transforms(img, target) else: img = self.default_transform(img) # print(f'img:{img}') return img, target def __len__(self): return len(self.imgs) def read_target(self, item, lbl_path, shape, extra=None): # print(f'shape:{shape}') # print(f'lbl_path:{lbl_path}') with open(lbl_path, 'r') as file: lable_all = json.load(file) n_stc_posl = 300 n_stc_negl = 40 use_cood = 0 use_slop = 0 wire = lable_all["wires"][0] # 字典 line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl] # 不足,有多少取多少 line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl] npos, nneg = len(line_pos_coords), len(line_neg_coords) lpre = np.concatenate([line_pos_coords, line_neg_coords], 0) # 正负样本坐标合在一起 for i in range(len(lpre)): if random.random() > 0.5: lpre[i] = lpre[i, ::-1] ldir = lpre[:, 0, :2] - lpre[:, 1, :2] ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None) feat = [ lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood, ldir * use_slop, lpre[:, :, 2], ] feat = np.concatenate(feat, 1) wire_labels = { "junc_coords": torch.tensor(wire["junc_coords"]["content"]), "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(), "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]), # 真实存在线条的邻接矩阵 "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]), "lpre": torch.tensor(lpre)[:, :, :2], "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]), # 样本对应标签 1,0 "lpre_feat": torch.from_numpy(feat), "junc_map": torch.tensor(wire['junc_map']["content"]), "junc_offset": torch.tensor(wire['junc_offset']["content"]), "line_map": torch.tensor(wire['line_map']["content"]), } labels = [] if self.target_type == 'polygon': labels, masks = read_masks_from_txt_wire(lbl_path, shape) elif self.target_type == 'pixel': labels = read_masks_from_pixels_wire(lbl_path, shape) # print(torch.stack(masks).shape) # [线段数, 512, 512] target = {} target["image_id"] = torch.tensor(item) # return wire_labels, target target["wires"] = wire_labels target["labels"] = torch.stack(labels) # print(f'labels:{target["labels"]}') # target["boxes"] = line_boxes(target) target["boxes"], keypoints = line_boxes(target) # keypoints=keypoints/512 # visibility_flags = torch.ones((wire_labels["junc_coords"].shape[0], 1)) # keypoints= wire_labels["junc_coords"] a = torch.full((keypoints.shape[0],), 2).unsqueeze(1) keypoints = torch.cat((keypoints, a), dim=1) target["keypoints"] = keypoints.to(torch.float32).view(-1,2,3) # print(f'boxes:{target["boxes"].shape}') # 在 __getitem__ 方法中调用此函数 validate_keypoints(keypoints, shape[0], shape[1]) # print(f'keypoints:{target["keypoints"].shape}') return target def show(self, idx): image, target = self.__getitem__(idx) cmap = plt.get_cmap("jet") norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0) sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) sm.set_array([]) def imshow(im): plt.close() plt.tight_layout() plt.imshow(im) plt.colorbar(sm, fraction=0.046) plt.xlim([0, im.shape[0]]) plt.ylim([im.shape[0], 0]) def draw_vecl(lines, sline, juncs, junts, fn=None): img_path = os.path.join(self.img_path, self.imgs[idx]) imshow(io.imread(img_path)) if len(lines) > 0 and not (lines[0] == 0).all(): for i, ((a, b), s) in enumerate(zip(lines, sline)): if i > 0 and (lines[i] == lines[0]).all(): break plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]无明确大小 if not (juncs[0] == 0).all(): for i, j in enumerate(juncs): if i > 0 and (i == juncs[0]).all(): break plt.scatter(j[1], j[0], c="red", s=2, zorder=100) # 原 s=64 img_path = os.path.join(self.img_path, self.imgs[idx]) img = PIL.Image.open(img_path).convert('RGB') boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"], colors="yellow", width=1) plt.imshow(boxed_image.permute(1, 2, 0).numpy()) plt.show() plt.show() if fn != None: plt.savefig(fn) junc = target['wires']['junc_coords'].cpu().numpy() * 4 jtyp = target['wires']['jtyp'].cpu().numpy() juncs = junc[jtyp == 0] junts = junc[jtyp == 1] lpre = target['wires']["lpre"].cpu().numpy() * 4 vecl_target = target['wires']["lpre_label"].cpu().numpy() lpre = lpre[vecl_target == 1] # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path) draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts) def show_img(self, img_path): pass <<<<<<<< HEAD:models/keypoint/keypoint_dataset.py if __name__ == '__main__': path=r"I:\datasets\wirenet_1000" dataset= KeypointDataset(dataset_path=path, dataset_type='train') dataset.show(7) ======== ''' # 使用roi_head数据格式有要求,更改数据格式 from torch.utils.data.dataset import T_co from models.base.base_dataset import BaseDataset import glob import json import math import os import random import cv2 import PIL import matplotlib.pyplot as plt import matplotlib as mpl from torchvision.utils import draw_bounding_boxes import numpy as np import numpy.linalg as LA import torch from skimage import io from torch.utils.data import Dataset from torch.utils.data.dataloader import default_collate import matplotlib.pyplot as plt from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix from tools.presets import DetectionPresetTrain class WireframeDataset(BaseDataset): def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'): super().__init__(dataset_path) self.data_path = dataset_path print(f'data_path:{dataset_path}') self.transforms = transforms self.img_path = os.path.join(dataset_path, "images\\" + dataset_type) self.lbl_path = os.path.join(dataset_path, "labels\\" + dataset_type) self.imgs = os.listdir(self.img_path) self.lbls = os.listdir(self.lbl_path) self.target_type = target_type # self.default_transform = DefaultTransform() self.data_augmentation = DetectionPresetTrain(data_augmentation="hflip") # multiscale会改变图像大小 def __getitem__(self, index) -> T_co: img_path = os.path.join(self.img_path, self.imgs[index]) lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json') img = PIL.Image.open(img_path).convert('RGB') w, h = img.size # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w)) target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w)) # if self.transforms: # img, target = self.transforms(img, target) # else: # img = self.default_transform(img) img, target = self.data_augmentation(img, target) print(f'img:{img.shape}') return img, target def __len__(self): return len(self.imgs) def read_target(self, item, lbl_path, shape, extra=None): # print(f'lbl_path:{lbl_path}') with open(lbl_path, 'r') as file: lable_all = json.load(file) n_stc_posl = 300 n_stc_negl = 40 use_cood = 0 use_slop = 0 wire = lable_all["wires"][0] # 字典 line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl] # 不足,有多少取多少 line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl] npos, nneg = len(line_pos_coords), len(line_neg_coords) lpre = np.concatenate([line_pos_coords, line_neg_coords], 0) # 正负样本坐标合在一起 for i in range(len(lpre)): if random.random() > 0.5: lpre[i] = lpre[i, ::-1] ldir = lpre[:, 0, :2] - lpre[:, 1, :2] ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None) feat = [ lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood, ldir * use_slop, lpre[:, :, 2], ] feat = np.concatenate(feat, 1) wire_labels = { "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2], "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(), "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]), # 真实存在线条的邻接矩阵 "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]), # 不存在线条的临界矩阵 "lpre": torch.tensor(lpre)[:, :, :2], "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]), # 样本对应标签 1,0 "lpre_feat": torch.from_numpy(feat), "junc_map": torch.tensor(wire['junc_map']["content"]), "junc_offset": torch.tensor(wire['junc_offset']["content"]), "line_map": torch.tensor(wire['line_map']["content"]), } labels = [] # if self.target_type == 'polygon': # labels, masks = read_masks_from_txt_wire(lbl_path, shape) # elif self.target_type == 'pixel': # labels = read_masks_from_pixels_wire(lbl_path, shape) # print(torch.stack(masks).shape) # [线段数, 512, 512] target = {} # target["labels"] = torch.stack(labels) target["image_id"] = torch.tensor(item) # return wire_labels, target target["wires"] = wire_labels target["boxes"] = line_boxes(target) return target def show(self, idx): image, target = self.__getitem__(idx) img_path = os.path.join(self.img_path, self.imgs[idx]) self._draw_vecl(img_path, target) def show_img(self, img_path): """根据给定的图片路径展示图像及其标注信息""" # 获取对应的标签文件路径 img_name = os.path.basename(img_path) img_path = os.path.join(self.img_path, img_name) print(img_path) lbl_name = img_name[:-3] + 'json' lbl_path = os.path.join(self.lbl_path, lbl_name) print(lbl_path) if not os.path.exists(lbl_path): raise FileNotFoundError(f"Label file {lbl_path} does not exist.") img = PIL.Image.open(img_path).convert('RGB') w, h = img.size target = self.read_target(0, lbl_path, shape=(h, w)) # 调用绘图函数 self._draw_vecl(img_path, target) def _draw_vecl(self, img_path, target, fn=None): cmap = plt.get_cmap("jet") norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0) sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) sm.set_array([]) def imshow(im): plt.close() plt.tight_layout() plt.imshow(im) plt.colorbar(sm, fraction=0.046) plt.xlim([0, im.shape[0]]) plt.ylim([im.shape[0], 0]) junc = target['wires']['junc_coords'].cpu().numpy() * 4 jtyp = target['wires']['jtyp'].cpu().numpy() juncs = junc[jtyp == 0] junts = junc[jtyp == 1] lpre = target['wires']["lpre"].cpu().numpy() * 4 vecl_target = target['wires']["lpre_label"].cpu().numpy() lpre = lpre[vecl_target == 1] lines = lpre sline = np.ones(lpre.shape[0]) imshow(io.imread(img_path)) if len(lines) > 0 and not (lines[0] == 0).all(): for i, ((a, b), s) in enumerate(zip(lines, sline)): if i > 0 and (lines[i] == lines[0]).all(): break plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]无明确大小 if not (juncs[0] == 0).all(): for i, j in enumerate(juncs): if i > 0 and (i == juncs[0]).all(): break plt.scatter(j[1], j[0], c="red", s=2, zorder=100) # 原 s=64 img = PIL.Image.open(img_path).convert('RGB') boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"], colors="yellow", width=1) plt.imshow(boxed_image.permute(1, 2, 0).numpy()) plt.show() if fn != None: plt.savefig(fn) ''' >>>>>>>> 8c208a87b75e9fde2fe6dcf23f2e589aeb87f172:lcnn/datasets.py