|
@@ -1,107 +1,3 @@
|
|
|
-<<<<<<<< HEAD:models/keypoint/keypoint_dataset.py
|
|
|
-========
|
|
|
-# import glob
|
|
|
-# import json
|
|
|
-# import math
|
|
|
-# import os
|
|
|
-# import random
|
|
|
-#
|
|
|
-# import numpy as np
|
|
|
-# import numpy.linalg as LA
|
|
|
-# import torch
|
|
|
-# from skimage import io
|
|
|
-# from torch.utils.data import Dataset
|
|
|
-# from torch.utils.data.dataloader import default_collate
|
|
|
-#
|
|
|
-# from lcnn.config import M
|
|
|
-#
|
|
|
-# from .dataset_tool import line_boxes_faster, read_masks_from_txt_wire, read_masks_from_pixels_wire
|
|
|
-#
|
|
|
-#
|
|
|
-# class WireframeDataset(Dataset):
|
|
|
-# def __init__(self, rootdir, split):
|
|
|
-# self.rootdir = rootdir
|
|
|
-# filelist = glob.glob(f"{rootdir}/{split}/*_label.npz")
|
|
|
-# filelist.sort()
|
|
|
-#
|
|
|
-# # print(f"n{split}:", len(filelist))
|
|
|
-# self.split = split
|
|
|
-# self.filelist = filelist
|
|
|
-#
|
|
|
-# def __len__(self):
|
|
|
-# return len(self.filelist)
|
|
|
-#
|
|
|
-# def __getitem__(self, idx):
|
|
|
-# iname = self.filelist[idx][:-10].replace("_a0", "").replace("_a1", "") + ".png"
|
|
|
-# image = io.imread(iname).astype(float)[:, :, :3]
|
|
|
-# if "a1" in self.filelist[idx]:
|
|
|
-# image = image[:, ::-1, :]
|
|
|
-# image = (image - M.image.mean) / M.image.stddev
|
|
|
-# image = np.rollaxis(image, 2).copy()
|
|
|
-#
|
|
|
-# with np.load(self.filelist[idx]) as npz:
|
|
|
-# target = {
|
|
|
-# name: torch.from_numpy(npz[name]).float()
|
|
|
-# for name in ["jmap", "joff", "lmap"]
|
|
|
-# }
|
|
|
-# lpos = np.random.permutation(npz["lpos"])[: M.n_stc_posl]
|
|
|
-# lneg = np.random.permutation(npz["lneg"])[: M.n_stc_negl]
|
|
|
-# npos, nneg = len(lpos), len(lneg)
|
|
|
-# lpre = np.concatenate([lpos, lneg], 0)
|
|
|
-# for i in range(len(lpre)):
|
|
|
-# if random.random() > 0.5:
|
|
|
-# lpre[i] = lpre[i, ::-1]
|
|
|
-# ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
|
|
|
-# ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
|
|
|
-# feat = [
|
|
|
-# lpre[:, :, :2].reshape(-1, 4) / 128 * M.use_cood,
|
|
|
-# ldir * M.use_slop,
|
|
|
-# lpre[:, :, 2],
|
|
|
-# ]
|
|
|
-# feat = np.concatenate(feat, 1)
|
|
|
-# meta = {
|
|
|
-# "junc": torch.from_numpy(npz["junc"][:, :2]),
|
|
|
-# "jtyp": torch.from_numpy(npz["junc"][:, 2]).byte(),
|
|
|
-# "Lpos": self.adjacency_matrix(len(npz["junc"]), npz["Lpos"]),
|
|
|
-# "Lneg": self.adjacency_matrix(len(npz["junc"]), npz["Lneg"]),
|
|
|
-# "lpre": torch.from_numpy(lpre[:, :, :2]),
|
|
|
-# "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),
|
|
|
-# "lpre_feat": torch.from_numpy(feat),
|
|
|
-# }
|
|
|
-#
|
|
|
-# labels = []
|
|
|
-# labels = read_masks_from_pixels_wire(iname, (512, 512))
|
|
|
-# # if self.target_type == 'polygon':
|
|
|
-# # labels, masks = read_masks_from_txt_wire(iname, (512, 512))
|
|
|
-# # elif self.target_type == 'pixel':
|
|
|
-# # labels = read_masks_from_pixels_wire(iname, (512, 512))
|
|
|
-#
|
|
|
-# target["labels"] = torch.stack(labels)
|
|
|
-# target["boxes"] = line_boxes_faster(meta)
|
|
|
-#
|
|
|
-#
|
|
|
-# return torch.from_numpy(image).float(), meta, target
|
|
|
-#
|
|
|
-# def adjacency_matrix(self, n, link):
|
|
|
-# mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8)
|
|
|
-# link = torch.from_numpy(link)
|
|
|
-# if len(link) > 0:
|
|
|
-# mat[link[:, 0], link[:, 1]] = 1
|
|
|
-# mat[link[:, 1], link[:, 0]] = 1
|
|
|
-# return mat
|
|
|
-#
|
|
|
-#
|
|
|
-# def collate(batch):
|
|
|
-# return (
|
|
|
-# default_collate([b[0] for b in batch]),
|
|
|
-# [b[1] for b in batch],
|
|
|
-# default_collate([b[2] for b in batch]),
|
|
|
-# )
|
|
|
-
|
|
|
-
|
|
|
-# 原LCNN数据格式,改了属性名,加了box相关
|
|
|
-
|
|
|
->>>>>>>> 8c208a87b75e9fde2fe6dcf23f2e589aeb87f172:lcnn/datasets.py
|
|
|
from torch.utils.data.dataset import T_co
|
|
|
|
|
|
from models.base.base_dataset import BaseDataset
|
|
@@ -134,12 +30,8 @@ def validate_keypoints(keypoints, image_width, image_height):
|
|
|
if not (0 <= x < image_width and 0 <= y < image_height):
|
|
|
raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
|
|
|
|
|
|
-<<<<<<<< HEAD:models/keypoint/keypoint_dataset.py
|
|
|
|
|
|
class KeypointDataset(BaseDataset):
|
|
|
-========
|
|
|
-class WireframeDataset(BaseDataset):
|
|
|
->>>>>>>> 8c208a87b75e9fde2fe6dcf23f2e589aeb87f172:lcnn/datasets.py
|
|
|
def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
|
|
|
super().__init__(dataset_path)
|
|
|
|
|
@@ -307,211 +199,7 @@ class WireframeDataset(BaseDataset):
|
|
|
|
|
|
|
|
|
|
|
|
-<<<<<<<< HEAD:models/keypoint/keypoint_dataset.py
|
|
|
if __name__ == '__main__':
|
|
|
path=r"I:\datasets\wirenet_1000"
|
|
|
dataset= KeypointDataset(dataset_path=path, dataset_type='train')
|
|
|
- dataset.show(7)
|
|
|
-========
|
|
|
-
|
|
|
-'''
|
|
|
-# 使用roi_head数据格式有要求,更改数据格式
|
|
|
-from torch.utils.data.dataset import T_co
|
|
|
-
|
|
|
-from models.base.base_dataset import BaseDataset
|
|
|
-
|
|
|
-import glob
|
|
|
-import json
|
|
|
-import math
|
|
|
-import os
|
|
|
-import random
|
|
|
-import cv2
|
|
|
-import PIL
|
|
|
-
|
|
|
-import matplotlib.pyplot as plt
|
|
|
-import matplotlib as mpl
|
|
|
-from torchvision.utils import draw_bounding_boxes
|
|
|
-
|
|
|
-import numpy as np
|
|
|
-import numpy.linalg as LA
|
|
|
-import torch
|
|
|
-from skimage import io
|
|
|
-from torch.utils.data import Dataset
|
|
|
-from torch.utils.data.dataloader import default_collate
|
|
|
-
|
|
|
-import matplotlib.pyplot as plt
|
|
|
-from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
|
|
|
-
|
|
|
-from tools.presets import DetectionPresetTrain
|
|
|
-
|
|
|
-
|
|
|
-class WireframeDataset(BaseDataset):
|
|
|
- def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
|
|
|
- super().__init__(dataset_path)
|
|
|
-
|
|
|
- self.data_path = dataset_path
|
|
|
- print(f'data_path:{dataset_path}')
|
|
|
- self.transforms = transforms
|
|
|
- self.img_path = os.path.join(dataset_path, "images\\" + dataset_type)
|
|
|
- self.lbl_path = os.path.join(dataset_path, "labels\\" + dataset_type)
|
|
|
- self.imgs = os.listdir(self.img_path)
|
|
|
- self.lbls = os.listdir(self.lbl_path)
|
|
|
- self.target_type = target_type
|
|
|
- # self.default_transform = DefaultTransform()
|
|
|
- self.data_augmentation = DetectionPresetTrain(data_augmentation="hflip") # multiscale会改变图像大小
|
|
|
-
|
|
|
- def __getitem__(self, index) -> T_co:
|
|
|
- img_path = os.path.join(self.img_path, self.imgs[index])
|
|
|
- lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
|
|
|
-
|
|
|
- img = PIL.Image.open(img_path).convert('RGB')
|
|
|
- w, h = img.size
|
|
|
-
|
|
|
- # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
|
|
|
- target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
|
|
|
- # if self.transforms:
|
|
|
- # img, target = self.transforms(img, target)
|
|
|
- # else:
|
|
|
- # img = self.default_transform(img)
|
|
|
-
|
|
|
- img, target = self.data_augmentation(img, target)
|
|
|
-
|
|
|
- print(f'img:{img.shape}')
|
|
|
- return img, target
|
|
|
-
|
|
|
- def __len__(self):
|
|
|
- return len(self.imgs)
|
|
|
-
|
|
|
- def read_target(self, item, lbl_path, shape, extra=None):
|
|
|
- # print(f'lbl_path:{lbl_path}')
|
|
|
- with open(lbl_path, 'r') as file:
|
|
|
- lable_all = json.load(file)
|
|
|
-
|
|
|
- n_stc_posl = 300
|
|
|
- n_stc_negl = 40
|
|
|
- use_cood = 0
|
|
|
- use_slop = 0
|
|
|
-
|
|
|
- wire = lable_all["wires"][0] # 字典
|
|
|
- line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl] # 不足,有多少取多少
|
|
|
- line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
|
|
|
- npos, nneg = len(line_pos_coords), len(line_neg_coords)
|
|
|
- lpre = np.concatenate([line_pos_coords, line_neg_coords], 0) # 正负样本坐标合在一起
|
|
|
- for i in range(len(lpre)):
|
|
|
- if random.random() > 0.5:
|
|
|
- lpre[i] = lpre[i, ::-1]
|
|
|
- ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
|
|
|
- ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
|
|
|
- feat = [
|
|
|
- lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
|
|
|
- ldir * use_slop,
|
|
|
- lpre[:, :, 2],
|
|
|
- ]
|
|
|
- feat = np.concatenate(feat, 1)
|
|
|
-
|
|
|
- wire_labels = {
|
|
|
- "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
|
|
|
- "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
|
|
|
- "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
|
|
|
- # 真实存在线条的邻接矩阵
|
|
|
- "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
|
|
|
- # 不存在线条的临界矩阵
|
|
|
- "lpre": torch.tensor(lpre)[:, :, :2],
|
|
|
- "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]), # 样本对应标签 1,0
|
|
|
- "lpre_feat": torch.from_numpy(feat),
|
|
|
- "junc_map": torch.tensor(wire['junc_map']["content"]),
|
|
|
- "junc_offset": torch.tensor(wire['junc_offset']["content"]),
|
|
|
- "line_map": torch.tensor(wire['line_map']["content"]),
|
|
|
- }
|
|
|
-
|
|
|
- labels = []
|
|
|
- # if self.target_type == 'polygon':
|
|
|
- # labels, masks = read_masks_from_txt_wire(lbl_path, shape)
|
|
|
- # elif self.target_type == 'pixel':
|
|
|
- # labels = read_masks_from_pixels_wire(lbl_path, shape)
|
|
|
-
|
|
|
- # print(torch.stack(masks).shape) # [线段数, 512, 512]
|
|
|
- target = {}
|
|
|
- # target["labels"] = torch.stack(labels)
|
|
|
- target["image_id"] = torch.tensor(item)
|
|
|
- # return wire_labels, target
|
|
|
- target["wires"] = wire_labels
|
|
|
- target["boxes"] = line_boxes(target)
|
|
|
- return target
|
|
|
-
|
|
|
- def show(self, idx):
|
|
|
- image, target = self.__getitem__(idx)
|
|
|
- img_path = os.path.join(self.img_path, self.imgs[idx])
|
|
|
- self._draw_vecl(img_path, target)
|
|
|
-
|
|
|
- def show_img(self, img_path):
|
|
|
-
|
|
|
- """根据给定的图片路径展示图像及其标注信息"""
|
|
|
- # 获取对应的标签文件路径
|
|
|
- img_name = os.path.basename(img_path)
|
|
|
- img_path = os.path.join(self.img_path, img_name)
|
|
|
- print(img_path)
|
|
|
- lbl_name = img_name[:-3] + 'json'
|
|
|
- lbl_path = os.path.join(self.lbl_path, lbl_name)
|
|
|
- print(lbl_path)
|
|
|
-
|
|
|
- if not os.path.exists(lbl_path):
|
|
|
- raise FileNotFoundError(f"Label file {lbl_path} does not exist.")
|
|
|
-
|
|
|
- img = PIL.Image.open(img_path).convert('RGB')
|
|
|
- w, h = img.size
|
|
|
-
|
|
|
- target = self.read_target(0, lbl_path, shape=(h, w))
|
|
|
-
|
|
|
- # 调用绘图函数
|
|
|
- self._draw_vecl(img_path, target)
|
|
|
-
|
|
|
-
|
|
|
- def _draw_vecl(self, img_path, target, fn=None):
|
|
|
- cmap = plt.get_cmap("jet")
|
|
|
- norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
|
|
|
- sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
|
|
- sm.set_array([])
|
|
|
-
|
|
|
- def imshow(im):
|
|
|
- plt.close()
|
|
|
- plt.tight_layout()
|
|
|
- plt.imshow(im)
|
|
|
- plt.colorbar(sm, fraction=0.046)
|
|
|
- plt.xlim([0, im.shape[0]])
|
|
|
- plt.ylim([im.shape[0], 0])
|
|
|
-
|
|
|
- junc = target['wires']['junc_coords'].cpu().numpy() * 4
|
|
|
- jtyp = target['wires']['jtyp'].cpu().numpy()
|
|
|
- juncs = junc[jtyp == 0]
|
|
|
- junts = junc[jtyp == 1]
|
|
|
-
|
|
|
- lpre = target['wires']["lpre"].cpu().numpy() * 4
|
|
|
- vecl_target = target['wires']["lpre_label"].cpu().numpy()
|
|
|
- lpre = lpre[vecl_target == 1]
|
|
|
-
|
|
|
- lines = lpre
|
|
|
- sline = np.ones(lpre.shape[0])
|
|
|
- imshow(io.imread(img_path))
|
|
|
- if len(lines) > 0 and not (lines[0] == 0).all():
|
|
|
- for i, ((a, b), s) in enumerate(zip(lines, sline)):
|
|
|
- if i > 0 and (lines[i] == lines[0]).all():
|
|
|
- break
|
|
|
- plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]无明确大小
|
|
|
- if not (juncs[0] == 0).all():
|
|
|
- for i, j in enumerate(juncs):
|
|
|
- if i > 0 and (i == juncs[0]).all():
|
|
|
- break
|
|
|
- plt.scatter(j[1], j[0], c="red", s=2, zorder=100) # 原 s=64
|
|
|
-
|
|
|
- img = PIL.Image.open(img_path).convert('RGB')
|
|
|
- boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
|
|
|
- colors="yellow", width=1)
|
|
|
- plt.imshow(boxed_image.permute(1, 2, 0).numpy())
|
|
|
- plt.show()
|
|
|
-
|
|
|
- if fn != None:
|
|
|
- plt.savefig(fn)
|
|
|
-
|
|
|
-'''
|
|
|
->>>>>>>> 8c208a87b75e9fde2fe6dcf23f2e589aeb87f172:lcnn/datasets.py
|
|
|
+ dataset.show(7)
|