import imageio
import numpy as np
from torch.utils.data.dataset import T_co
from libs.vision_libs.utils import draw_keypoints
from models.base.base_dataset import BaseDataset
import json
import os
import PIL
import matplotlib as mpl
from torchvision.utils import draw_bounding_boxes
import torchvision.transforms.v2 as transforms
import torch
import matplotlib.pyplot as plt
from models.base.transforms import get_transforms


def validate_keypoints(keypoints, image_width, image_height):
    for kp in keypoints:
        x, y, v = kp
        if not (0 <= x < image_width and 0 <= y < image_height):
            raise ValueError(f"Keypoint ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")


class LineDataset(BaseDataset):
    """Reads datasets annotated with xanlabel directly from their JSON label format."""

    def __init__(self, dataset_path, data_type, transforms=None, augmentation=False, dataset_type=None,
                 img_type='rgb', target_type='pixel'):
        super().__init__(dataset_path)

        self.data_path = dataset_path
        self.data_type = data_type
        print(f'data_path:{dataset_path}')
        self.transforms = transforms
        self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
        self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)

        self.imgs = os.listdir(self.img_path)
        self.lbls = os.listdir(self.lbl_path)
        self.target_type = target_type
        self.img_type = img_type
        self.augmentation = augmentation
        print(f'augmentation:{augmentation}')
        # self.default_transform = DefaultTransform()

    def __getitem__(self, index) -> T_co:
        img_path = os.path.join(self.img_path, self.imgs[index])
        if self.data_type == 'tiff':
            lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-4] + 'json')
            # Single-channel tiff: keep the first channel only, pack it into the last
            # channel of a 3-channel array, then convert to a channel-first tensor.
            img = imageio.v3.imread(img_path)[:, :, 0]
            print(f'img shape:{img.shape}')
            w, h = img.shape[:2]
            img = img.reshape(w, h, 1)
            img_3channel = np.zeros((w, h, 3), dtype=img.dtype)
            img_3channel[:, :, 2] = img[:, :, 0]
            img = torch.from_numpy(img_3channel).permute(2, 1, 0)
        else:
            lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
            img = PIL.Image.open(img_path).convert('RGB')
            w, h = img.size

        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
        self.transforms = get_transforms(augmention=self.augmentation)
        img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.imgs)

    def read_target(self, item, lbl_path, shape, extra=None):
        # print(f'shape:{shape}')
        # print(f'lbl_path:{lbl_path}')
        with open(lbl_path, 'r') as file:
            label_all = json.load(file)

        objs = label_all["shapes"]
        point_pairs = objs[0]['points']
        # print(f'point_pairs:{point_pairs}')
        target = {}
        target["image_id"] = torch.tensor(item)

        boxes, lines, points, arc_mask, circle_4points, labels = get_boxes_lines(objs, shape)
        if points is not None:
            target["points"] = points
        if lines is not None:
            # Append a visibility flag of 2 to every endpoint, then group the endpoints
            # into pairs: (N*2, 3) -> (N, 2, 3).
            a = torch.full((lines.shape[0],), 2).unsqueeze(1)
            lines = torch.cat((lines, a), dim=1)
            target["lines"] = lines.to(torch.float32).view(-1, 2, 3)
            # print(f'lines shape:{target["lines"].shape}')

        if arc_mask is not None:
            target['arc_mask'] = arc_mask
            # print(f'arc_mask dataset')
        # else:
        #     print(f'not arc_mask dataset')

        if circle_4points is not None:
            target['circle'] = circle_4points

        target["boxes"] = boxes
        target["labels"] = labels

        # target["boxes"], lines, target["points"], target["labels"] = get_boxes_lines(objs, shape)
        # print(f'lines:{lines}')
        # target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
        # print(f'target points:{target["points"]}')
        # target["lines"] = lines.to(torch.float32).view(-1, 2, 3)
        # print(f'lines:{target["lines"].shape}')
        target["img_size"] = shape

        # validate_keypoints(lines, shape[0], shape[1])
        return target

    def show(self, idx, show_type='all'):
        image, target = self.__getitem__(idx)

        cmap = plt.get_cmap("jet")
        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])

        # img_path = os.path.join(self.img_path, self.imgs[idx])
        # print(f'boxes:{target["boxes"]}')
        img = image

        if show_type == 'all':
            boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"], colors="yellow", width=1)
            circle = target['circle']
            print(f'target circle:{circle.shape}')
            keypoint_img = draw_keypoints(boxed_image, circle, colors='red', width=3)
            plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
            plt.show()

        # if show_type == 'lines':
        #     keypoint_img = draw_keypoints((img * 255).to(torch.uint8), target['lines'], colors='red', width=3)
        #     plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
        #     plt.show()

        if show_type == 'points':
            # print(f"points:{target['points'].shape}")
            keypoint_img = draw_keypoints((img * 255).to(torch.uint8), target['points'].unsqueeze(1), colors='red', width=3)
            plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
            plt.show()

        if show_type == 'boxes':
            boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"], colors="yellow", width=1)
            plt.imshow(boxed_image.permute(1, 2, 0).numpy())
            plt.show()

    def show_img(self, img_path):
        pass


def get_boxes_lines(objs, shape):
    boxes = []
    labels = []
    h, w = shape
    line_point_pairs = []
    points = []
    line_mask = []
    circle_4points = []
    for obj in objs:
        # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1] and b[1] have no fixed ordering
        # print(f"points:{obj['points']}")
        label = obj['label']
        if label == 'line' or label == 'dseam1':
            # A line is stored as its two endpoints; its box is the endpoints' extent padded by 6 px.
            a, b = obj['points'][0], obj['points'][1]
            line_point_pairs.append(a)
            line_point_pairs.append(b)

            xmin = max(0, (min(a[0], b[0]) - 6))
            xmax = min(w, (max(a[0], b[0]) + 6))
            ymin = max(0, (min(a[1], b[1]) - 6))
            ymax = min(h, (max(a[1], b[1]) + 6))
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(2)
        elif label == 'point':
            # A point gets a 24 px square box centred on it.
            p = obj['points'][0]
            xmin = max(0, p[0] - 12)
            xmax = min(w, p[0] + 12)
            ymin = max(0, p[1] - 12)
            ymax = min(h, p[1] + 12)
            points.append(p)
            labels.append(1)
            boxes.append([xmin, ymin, xmax, ymax])
        elif label == 'arc':
            # Arcs carry their own bounding box in the annotation.
            line_mask.append(obj['points'])
            xmin = obj['xmin']
            xmax = obj['xmax']
            ymin = obj['ymin']
            ymax = obj['ymax']
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(3)
        elif label == 'circle':
            # Circles are stored as 4 points plus a bounding box, padded by 6 px.
            circle_4points.append(obj['points'])
            xmin = max(obj['xmin'] - 6, 0)
            xmax = min(obj['xmax'] + 6, w)
            ymin = max(obj['ymin'] - 6, 0)
            ymax = min(obj['ymax'] + 6, h)
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(3)

    boxes = torch.tensor(boxes)
    print(f'boxes:{boxes.shape}')
    labels = torch.tensor(labels)
    if len(points) == 0:
        points = None
    else:
        points = torch.tensor(points, dtype=torch.float32)
    print(f'read labels:{labels}')
    # print(f'read points:{points}')
    if len(line_point_pairs) == 0:
        line_point_pairs = None
    else:
        line_point_pairs = torch.tensor(line_point_pairs)
        # print(f'line_point_pairs:{line_point_pairs.shape},{line_point_pairs.dtype}')
    # print(f'boxes:{boxes.shape},line_point_pairs:{line_point_pairs.shape}')
    if len(line_mask) == 0:
        line_mask = None
    else:
        line_mask = torch.tensor(line_mask, dtype=torch.float32)
        print(f'arc_mask shape:{line_mask.shape},{line_mask.dtype}')
    if len(circle_4points) == 0:
        circle_4points = None
    else:
        circle_4points = torch.tensor(circle_4points, dtype=torch.float32)

    return boxes, line_point_pairs, points, line_mask, circle_4points, labels
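
# A minimal batching sketch (not part of the original file): each __getitem__ call returns an
# (image, target) pair whose target dict holds a variable number of boxes, lines and points,
# so a DataLoader cannot stack targets into a single tensor. The usual detection-style
# collate_fn simply regroups the batch into a tuple of images and a tuple of targets.
def collate_fn(batch):
    # batch is a list of (image, target) pairs; zip(*batch) turns it into
    # (images, targets), each a tuple of length batch_size.
    return tuple(zip(*batch))
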
if __name__ == '__main__':
    path = r"\\192.168.50.222/share/zyh/data/rgb_4point/a_dataset"
    dataset = LineDataset(dataset_path=path, dataset_type='train', augmentation=False, data_type='jpg')
    dataset.show(1, show_type='all')
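
    # Hedged usage sketch (assumption, not in the original script): iterate the dataset
    # through a DataLoader with the collate_fn defined above; batch_size and shuffle
    # are illustrative only.
    from torch.utils.data import DataLoader
    loader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=collate_fn)
    images, targets = next(iter(loader))
    print(f'batch of {len(images)} images, first target keys: {list(targets[0].keys())}')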