瀏覽代碼

fixed keypoint_dataset

RenLiqiang 3 月之前
父節點
當前提交
47c02eb7c5
共有 1 個文件被更改,包括 1 次插入 和 313 次删除
  1. 1 313
      models/keypoint/keypoint_dataset.py

+ 1 - 313
models/keypoint/keypoint_dataset.py

@@ -1,107 +1,3 @@
-<<<<<<<< HEAD:models/keypoint/keypoint_dataset.py
-========
-# import glob
-# import json
-# import math
-# import os
-# import random
-#
-# import numpy as np
-# import numpy.linalg as LA
-# import torch
-# from skimage import io
-# from torch.utils.data import Dataset
-# from torch.utils.data.dataloader import default_collate
-#
-# from lcnn.config import M
-#
-# from .dataset_tool import line_boxes_faster, read_masks_from_txt_wire, read_masks_from_pixels_wire
-#
-#
-# class WireframeDataset(Dataset):
-#     def __init__(self, rootdir, split):
-#         self.rootdir = rootdir
-#         filelist = glob.glob(f"{rootdir}/{split}/*_label.npz")
-#         filelist.sort()
-#
-#         # print(f"n{split}:", len(filelist))
-#         self.split = split
-#         self.filelist = filelist
-#
-#     def __len__(self):
-#         return len(self.filelist)
-#
-#     def __getitem__(self, idx):
-#         iname = self.filelist[idx][:-10].replace("_a0", "").replace("_a1", "") + ".png"
-#         image = io.imread(iname).astype(float)[:, :, :3]
-#         if "a1" in self.filelist[idx]:
-#             image = image[:, ::-1, :]
-#         image = (image - M.image.mean) / M.image.stddev
-#         image = np.rollaxis(image, 2).copy()
-#
-#         with np.load(self.filelist[idx]) as npz:
-#             target = {
-#                 name: torch.from_numpy(npz[name]).float()
-#                 for name in ["jmap", "joff", "lmap"]
-#             }
-#             lpos = np.random.permutation(npz["lpos"])[: M.n_stc_posl]
-#             lneg = np.random.permutation(npz["lneg"])[: M.n_stc_negl]
-#             npos, nneg = len(lpos), len(lneg)
-#             lpre = np.concatenate([lpos, lneg], 0)
-#             for i in range(len(lpre)):
-#                 if random.random() > 0.5:
-#                     lpre[i] = lpre[i, ::-1]
-#             ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
-#             ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
-#             feat = [
-#                 lpre[:, :, :2].reshape(-1, 4) / 128 * M.use_cood,
-#                 ldir * M.use_slop,
-#                 lpre[:, :, 2],
-#             ]
-#             feat = np.concatenate(feat, 1)
-#             meta = {
-#                 "junc": torch.from_numpy(npz["junc"][:, :2]),
-#                 "jtyp": torch.from_numpy(npz["junc"][:, 2]).byte(),
-#                 "Lpos": self.adjacency_matrix(len(npz["junc"]), npz["Lpos"]),
-#                 "Lneg": self.adjacency_matrix(len(npz["junc"]), npz["Lneg"]),
-#                 "lpre": torch.from_numpy(lpre[:, :, :2]),
-#                 "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),
-#                 "lpre_feat": torch.from_numpy(feat),
-#             }
-#
-#         labels = []
-#         labels = read_masks_from_pixels_wire(iname, (512, 512))
-#         # if self.target_type == 'polygon':
-#         #     labels, masks = read_masks_from_txt_wire(iname, (512, 512))
-#         # elif self.target_type == 'pixel':
-#         #     labels = read_masks_from_pixels_wire(iname, (512, 512))
-#
-#         target["labels"] = torch.stack(labels)
-#         target["boxes"] = line_boxes_faster(meta)
-#
-#
-#         return torch.from_numpy(image).float(), meta, target
-#
-#     def adjacency_matrix(self, n, link):
-#         mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8)
-#         link = torch.from_numpy(link)
-#         if len(link) > 0:
-#             mat[link[:, 0], link[:, 1]] = 1
-#             mat[link[:, 1], link[:, 0]] = 1
-#         return mat
-#
-#
-# def collate(batch):
-#     return (
-#         default_collate([b[0] for b in batch]),
-#         [b[1] for b in batch],
-#         default_collate([b[2] for b in batch]),
-#     )
-
-
-# 原LCNN数据格式,改了属性名,加了box相关
-
->>>>>>>> 8c208a87b75e9fde2fe6dcf23f2e589aeb87f172:lcnn/datasets.py
 from torch.utils.data.dataset import T_co
 
 from models.base.base_dataset import BaseDataset
@@ -134,12 +30,8 @@ def validate_keypoints(keypoints, image_width, image_height):
         if not (0 <= x < image_width and 0 <= y < image_height):
             raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
 
-<<<<<<<< HEAD:models/keypoint/keypoint_dataset.py
 
 class KeypointDataset(BaseDataset):
-========
-class  WireframeDataset(BaseDataset):
->>>>>>>> 8c208a87b75e9fde2fe6dcf23f2e589aeb87f172:lcnn/datasets.py
     def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
         super().__init__(dataset_path)
 
@@ -307,211 +199,7 @@ class  WireframeDataset(BaseDataset):
 
 
 
-<<<<<<<< HEAD:models/keypoint/keypoint_dataset.py
 if __name__ == '__main__':
     path=r"I:\datasets\wirenet_1000"
     dataset= KeypointDataset(dataset_path=path, dataset_type='train')
-    dataset.show(7)
-========
-
-'''
-# 使用roi_head数据格式有要求,更改数据格式
-from torch.utils.data.dataset import T_co
-
-from models.base.base_dataset import BaseDataset
-
-import glob
-import json
-import math
-import os
-import random
-import cv2
-import PIL
-
-import matplotlib.pyplot as plt
-import matplotlib as mpl
-from torchvision.utils import draw_bounding_boxes
-
-import numpy as np
-import numpy.linalg as LA
-import torch
-from skimage import io
-from torch.utils.data import Dataset
-from torch.utils.data.dataloader import default_collate
-
-import matplotlib.pyplot as plt
-from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
-
-from tools.presets import DetectionPresetTrain
-
-
-class WireframeDataset(BaseDataset):
-    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
-        super().__init__(dataset_path)
-
-        self.data_path = dataset_path
-        print(f'data_path:{dataset_path}')
-        self.transforms = transforms
-        self.img_path = os.path.join(dataset_path, "images\\" + dataset_type)
-        self.lbl_path = os.path.join(dataset_path, "labels\\" + dataset_type)
-        self.imgs = os.listdir(self.img_path)
-        self.lbls = os.listdir(self.lbl_path)
-        self.target_type = target_type
-        # self.default_transform = DefaultTransform()
-        self.data_augmentation = DetectionPresetTrain(data_augmentation="hflip")  # multiscale会改变图像大小
-
-    def __getitem__(self, index) -> T_co:
-        img_path = os.path.join(self.img_path, self.imgs[index])
-        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
-
-        img = PIL.Image.open(img_path).convert('RGB')
-        w, h = img.size
-
-        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
-        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
-        # if self.transforms:
-        #     img, target = self.transforms(img, target)
-        # else:
-        #     img = self.default_transform(img)
-
-        img, target = self.data_augmentation(img, target)
-
-        print(f'img:{img.shape}')
-        return img, target
-
-    def __len__(self):
-        return len(self.imgs)
-
-    def read_target(self, item, lbl_path, shape, extra=None):
-        # print(f'lbl_path:{lbl_path}')
-        with open(lbl_path, 'r') as file:
-            lable_all = json.load(file)
-
-        n_stc_posl = 300
-        n_stc_negl = 40
-        use_cood = 0
-        use_slop = 0
-
-        wire = lable_all["wires"][0]  # 字典
-        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # 不足,有多少取多少
-        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
-        npos, nneg = len(line_pos_coords), len(line_neg_coords)
-        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # 正负样本坐标合在一起
-        for i in range(len(lpre)):
-            if random.random() > 0.5:
-                lpre[i] = lpre[i, ::-1]
-        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
-        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
-        feat = [
-            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
-            ldir * use_slop,
-            lpre[:, :, 2],
-        ]
-        feat = np.concatenate(feat, 1)
-
-        wire_labels = {
-            "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
-            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
-            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
-            # 真实存在线条的邻接矩阵
-            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
-            # 不存在线条的临界矩阵
-            "lpre": torch.tensor(lpre)[:, :, :2],
-            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # 样本对应标签 1,0
-            "lpre_feat": torch.from_numpy(feat),
-            "junc_map": torch.tensor(wire['junc_map']["content"]),
-            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
-            "line_map": torch.tensor(wire['line_map']["content"]),
-        }
-
-        labels = []
-        # if self.target_type == 'polygon':
-        #     labels, masks = read_masks_from_txt_wire(lbl_path, shape)
-        # elif self.target_type == 'pixel':
-        #     labels = read_masks_from_pixels_wire(lbl_path, shape)
-
-        # print(torch.stack(masks).shape)    # [线段数, 512, 512]
-        target = {}
-        # target["labels"] = torch.stack(labels)
-        target["image_id"] = torch.tensor(item)
-        # return wire_labels, target
-        target["wires"] = wire_labels
-        target["boxes"] = line_boxes(target)
-        return target
-
-    def show(self, idx):
-        image, target = self.__getitem__(idx)
-        img_path = os.path.join(self.img_path, self.imgs[idx])
-        self._draw_vecl(img_path, target)
-
-    def show_img(self, img_path):
-
-        """根据给定的图片路径展示图像及其标注信息"""
-        # 获取对应的标签文件路径
-        img_name = os.path.basename(img_path)
-        img_path = os.path.join(self.img_path, img_name)
-        print(img_path)
-        lbl_name = img_name[:-3] + 'json'
-        lbl_path = os.path.join(self.lbl_path, lbl_name)
-        print(lbl_path)
-
-        if not os.path.exists(lbl_path):
-            raise FileNotFoundError(f"Label file {lbl_path} does not exist.")
-
-        img = PIL.Image.open(img_path).convert('RGB')
-        w, h = img.size
-
-        target = self.read_target(0, lbl_path, shape=(h, w))
-
-        # 调用绘图函数
-        self._draw_vecl(img_path, target)
-
-
-    def _draw_vecl(self, img_path, target, fn=None):
-        cmap = plt.get_cmap("jet")
-        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
-        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
-        sm.set_array([])
-
-        def imshow(im):
-            plt.close()
-            plt.tight_layout()
-            plt.imshow(im)
-            plt.colorbar(sm, fraction=0.046)
-            plt.xlim([0, im.shape[0]])
-            plt.ylim([im.shape[0], 0])
-
-        junc = target['wires']['junc_coords'].cpu().numpy() * 4
-        jtyp = target['wires']['jtyp'].cpu().numpy()
-        juncs = junc[jtyp == 0]
-        junts = junc[jtyp == 1]
-
-        lpre = target['wires']["lpre"].cpu().numpy() * 4
-        vecl_target = target['wires']["lpre_label"].cpu().numpy()
-        lpre = lpre[vecl_target == 1]
-
-        lines = lpre
-        sline = np.ones(lpre.shape[0])
-        imshow(io.imread(img_path))
-        if len(lines) > 0 and not (lines[0] == 0).all():
-            for i, ((a, b), s) in enumerate(zip(lines, sline)):
-                if i > 0 and (lines[i] == lines[0]).all():
-                    break
-                plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1]无明确大小
-        if not (juncs[0] == 0).all():
-            for i, j in enumerate(juncs):
-                if i > 0 and (i == juncs[0]).all():
-                    break
-                plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # 原 s=64
-
-        img = PIL.Image.open(img_path).convert('RGB')
-        boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
-                                          colors="yellow", width=1)
-        plt.imshow(boxed_image.permute(1, 2, 0).numpy())
-        plt.show()
-
-        if fn != None:
-            plt.savefig(fn)
-
-'''
->>>>>>>> 8c208a87b75e9fde2fe6dcf23f2e589aeb87f172:lcnn/datasets.py
+    dataset.show(7)