| 
					
				 | 
			
			
				@@ -1,107 +1,3 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-<<<<<<<< HEAD:models/keypoint/keypoint_dataset.py 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-======== 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# import glob 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# import json 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# import math 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# import os 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# import random 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# import numpy as np 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# import numpy.linalg as LA 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# import torch 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# from skimage import io 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# from torch.utils.data import Dataset 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# from torch.utils.data.dataloader import default_collate 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# from lcnn.config import M 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# from .dataset_tool import line_boxes_faster, read_masks_from_txt_wire, read_masks_from_pixels_wire 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# class WireframeDataset(Dataset): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#     def __init__(self, rootdir, split): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         self.rootdir = rootdir 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         filelist = glob.glob(f"{rootdir}/{split}/*_label.npz") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         filelist.sort() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         # print(f"n{split}:", len(filelist)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         self.split = split 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         self.filelist = filelist 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#     def __len__(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         return len(self.filelist) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#     def __getitem__(self, idx): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         iname = self.filelist[idx][:-10].replace("_a0", "").replace("_a1", "") + ".png" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         image = io.imread(iname).astype(float)[:, :, :3] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         if "a1" in self.filelist[idx]: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#             image = image[:, ::-1, :] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         image = (image - M.image.mean) / M.image.stddev 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         image = np.rollaxis(image, 2).copy() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         with np.load(self.filelist[idx]) as npz: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#             target = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#                 name: torch.from_numpy(npz[name]).float() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#                 for name in ["jmap", "joff", "lmap"] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#             } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#             lpos = np.random.permutation(npz["lpos"])[: M.n_stc_posl] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#             lneg = np.random.permutation(npz["lneg"])[: M.n_stc_negl] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#             npos, nneg = len(lpos), len(lneg) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#             lpre = np.concatenate([lpos, lneg], 0) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#             for i in range(len(lpre)): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#                 if random.random() > 0.5: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#                     lpre[i] = lpre[i, ::-1] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#             ldir = lpre[:, 0, :2] - lpre[:, 1, :2] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#             ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#             feat = [ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#                 lpre[:, :, :2].reshape(-1, 4) / 128 * M.use_cood, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#                 ldir * M.use_slop, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#                 lpre[:, :, 2], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#             ] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#             feat = np.concatenate(feat, 1) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#             meta = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#                 "junc": torch.from_numpy(npz["junc"][:, :2]), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#                 "jtyp": torch.from_numpy(npz["junc"][:, 2]).byte(), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#                 "Lpos": self.adjacency_matrix(len(npz["junc"]), npz["Lpos"]), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#                 "Lneg": self.adjacency_matrix(len(npz["junc"]), npz["Lneg"]), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#                 "lpre": torch.from_numpy(lpre[:, :, :2]), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#                 "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#                 "lpre_feat": torch.from_numpy(feat), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#             } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         labels = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         labels = read_masks_from_pixels_wire(iname, (512, 512)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         # if self.target_type == 'polygon': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         #     labels, masks = read_masks_from_txt_wire(iname, (512, 512)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         # elif self.target_type == 'pixel': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         #     labels = read_masks_from_pixels_wire(iname, (512, 512)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         target["labels"] = torch.stack(labels) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         target["boxes"] = line_boxes_faster(meta) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         return torch.from_numpy(image).float(), meta, target 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#     def adjacency_matrix(self, n, link): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         link = torch.from_numpy(link) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         if len(link) > 0: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#             mat[link[:, 0], link[:, 1]] = 1 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#             mat[link[:, 1], link[:, 0]] = 1 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         return mat 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# def collate(batch): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#     return ( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         default_collate([b[0] for b in batch]), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         [b[1] for b in batch], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#         default_collate([b[2] for b in batch]), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#     ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# 原LCNN数据格式,改了属性名,加了box相关 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				->>>>>>>> 8c208a87b75e9fde2fe6dcf23f2e589aeb87f172:lcnn/datasets.py 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from torch.utils.data.dataset import T_co 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from models.base.base_dataset import BaseDataset 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -134,12 +30,8 @@ def validate_keypoints(keypoints, image_width, image_height): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if not (0 <= x < image_width and 0 <= y < image_height): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-<<<<<<<< HEAD:models/keypoint/keypoint_dataset.py 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 class KeypointDataset(BaseDataset): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-======== 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-class  WireframeDataset(BaseDataset): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				->>>>>>>> 8c208a87b75e9fde2fe6dcf23f2e589aeb87f172:lcnn/datasets.py 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         super().__init__(dataset_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -307,211 +199,7 @@ class  WireframeDataset(BaseDataset): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-<<<<<<<< HEAD:models/keypoint/keypoint_dataset.py 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 if __name__ == '__main__': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     path=r"I:\datasets\wirenet_1000" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     dataset= KeypointDataset(dataset_path=path, dataset_type='train') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    dataset.show(7) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-======== 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-''' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# 使用roi_head数据格式有要求,更改数据格式 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from torch.utils.data.dataset import T_co 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from models.base.base_dataset import BaseDataset 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import glob 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import json 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import math 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import os 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import random 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import cv2 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import PIL 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import matplotlib.pyplot as plt 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import matplotlib as mpl 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from torchvision.utils import draw_bounding_boxes 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import numpy as np 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import numpy.linalg as LA 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import torch 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from skimage import io 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from torch.utils.data import Dataset 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from torch.utils.data.dataloader import default_collate 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import matplotlib.pyplot as plt 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from tools.presets import DetectionPresetTrain 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-class WireframeDataset(BaseDataset): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        super().__init__(dataset_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.data_path = dataset_path 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        print(f'data_path:{dataset_path}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.transforms = transforms 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.img_path = os.path.join(dataset_path, "images\\" + dataset_type) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.lbl_path = os.path.join(dataset_path, "labels\\" + dataset_type) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.imgs = os.listdir(self.img_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.lbls = os.listdir(self.lbl_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.target_type = target_type 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # self.default_transform = DefaultTransform() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.data_augmentation = DetectionPresetTrain(data_augmentation="hflip")  # multiscale会改变图像大小 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def __getitem__(self, index) -> T_co: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        img_path = os.path.join(self.img_path, self.imgs[index]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        img = PIL.Image.open(img_path).convert('RGB') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        w, h = img.size 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # if self.transforms: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #     img, target = self.transforms(img, target) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #     img = self.default_transform(img) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        img, target = self.data_augmentation(img, target) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        print(f'img:{img.shape}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        return img, target 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def __len__(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        return len(self.imgs) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def read_target(self, item, lbl_path, shape, extra=None): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # print(f'lbl_path:{lbl_path}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        with open(lbl_path, 'r') as file: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            lable_all = json.load(file) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        n_stc_posl = 300 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        n_stc_negl = 40 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        use_cood = 0 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        use_slop = 0 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        wire = lable_all["wires"][0]  # 字典 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # 不足,有多少取多少 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        npos, nneg = len(line_pos_coords), len(line_neg_coords) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # 正负样本坐标合在一起 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        for i in range(len(lpre)): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            if random.random() > 0.5: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                lpre[i] = lpre[i, ::-1] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        ldir = lpre[:, 0, :2] - lpre[:, 1, :2] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        feat = [ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            ldir * use_slop, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            lpre[:, :, 2], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        ] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        feat = np.concatenate(feat, 1) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        wire_labels = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            # 真实存在线条的邻接矩阵 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            # 不存在线条的临界矩阵 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            "lpre": torch.tensor(lpre)[:, :, :2], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # 样本对应标签 1,0 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            "lpre_feat": torch.from_numpy(feat), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            "junc_map": torch.tensor(wire['junc_map']["content"]), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            "junc_offset": torch.tensor(wire['junc_offset']["content"]), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            "line_map": torch.tensor(wire['line_map']["content"]), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        labels = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # if self.target_type == 'polygon': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #     labels, masks = read_masks_from_txt_wire(lbl_path, shape) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # elif self.target_type == 'pixel': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #     labels = read_masks_from_pixels_wire(lbl_path, shape) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # print(torch.stack(masks).shape)    # [线段数, 512, 512] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        target = {} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # target["labels"] = torch.stack(labels) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        target["image_id"] = torch.tensor(item) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # return wire_labels, target 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        target["wires"] = wire_labels 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        target["boxes"] = line_boxes(target) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        return target 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def show(self, idx): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        image, target = self.__getitem__(idx) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        img_path = os.path.join(self.img_path, self.imgs[idx]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self._draw_vecl(img_path, target) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def show_img(self, img_path): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        """根据给定的图片路径展示图像及其标注信息""" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # 获取对应的标签文件路径 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        img_name = os.path.basename(img_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        img_path = os.path.join(self.img_path, img_name) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        print(img_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        lbl_name = img_name[:-3] + 'json' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        lbl_path = os.path.join(self.lbl_path, lbl_name) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        print(lbl_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if not os.path.exists(lbl_path): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            raise FileNotFoundError(f"Label file {lbl_path} does not exist.") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        img = PIL.Image.open(img_path).convert('RGB') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        w, h = img.size 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        target = self.read_target(0, lbl_path, shape=(h, w)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # 调用绘图函数 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self._draw_vecl(img_path, target) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def _draw_vecl(self, img_path, target, fn=None): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        cmap = plt.get_cmap("jet") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        sm.set_array([]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        def imshow(im): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            plt.close() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            plt.tight_layout() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            plt.imshow(im) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            plt.colorbar(sm, fraction=0.046) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            plt.xlim([0, im.shape[0]]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            plt.ylim([im.shape[0], 0]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        junc = target['wires']['junc_coords'].cpu().numpy() * 4 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        jtyp = target['wires']['jtyp'].cpu().numpy() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        juncs = junc[jtyp == 0] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        junts = junc[jtyp == 1] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        lpre = target['wires']["lpre"].cpu().numpy() * 4 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        vecl_target = target['wires']["lpre_label"].cpu().numpy() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        lpre = lpre[vecl_target == 1] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        lines = lpre 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        sline = np.ones(lpre.shape[0]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        imshow(io.imread(img_path)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if len(lines) > 0 and not (lines[0] == 0).all(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            for i, ((a, b), s) in enumerate(zip(lines, sline)): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                if i > 0 and (lines[i] == lines[0]).all(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    break 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1]无明确大小 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if not (juncs[0] == 0).all(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            for i, j in enumerate(juncs): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                if i > 0 and (i == juncs[0]).all(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    break 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # 原 s=64 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        img = PIL.Image.open(img_path).convert('RGB') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                                          colors="yellow", width=1) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        plt.imshow(boxed_image.permute(1, 2, 0).numpy()) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        plt.show() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if fn != None: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            plt.savefig(fn) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-''' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				->>>>>>>> 8c208a87b75e9fde2fe6dcf23f2e589aeb87f172:lcnn/datasets.py 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    dataset.show(7) 
			 |