import json
import os

import PIL
import matplotlib as mpl
import matplotlib.pyplot as plt
import torch
import torchvision.transforms.v2 as transforms
from torch.utils.data.dataset import T_co
from torchvision import tv_tensors
from torchvision.utils import draw_bounding_boxes

from libs.vision_libs.utils import draw_keypoints
from models.base.base_dataset import BaseDataset


def validate_keypoints(keypoints, image_width, image_height):
    for kp in keypoints:
        x, y, v = kp
        if not (0 <= x < image_width and 0 <= y < image_height):
            raise ValueError(f"Keypoint ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
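

# Sanity check of the bounds logic above: (10, 20) lies inside a 640x480 image,
# while x == 640 is rejected because the check is half-open (0 <= x < width).
# validate_keypoints(torch.tensor([[10.0, 20.0, 2.0]]), 640, 480)   # passes
# validate_keypoints(torch.tensor([[640.0, 20.0, 2.0]]), 640, 480)  # raises ValueError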


def apply_transform_with_boxes_and_keypoints(img, target):
    """
    Apply the same augmentation to an image, its bounding boxes, and its keypoints.

    :param img: image tensor of shape (C, H, W)
    :param target: dict with 'boxes', an (N, 4) tensor of [x_min, y_min, x_max, y_max],
                   and 'lines', an (N, K, 3) tensor of keypoints [x, y, visibility]
    :return: transformed image and target
    """
    # Augmentation pipeline.
    data_transforms = transforms.Compose([
        # Random resized crop (disabled).
        # transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0), antialias=True),
        # Random horizontal flip.
        transforms.RandomHorizontalFlip(p=0.5),
        # Color jitter: brightness, contrast, saturation, and hue.
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        # # Convert to tensor.
        # transforms.ToTensor(),
        #
        # # Normalize.
        # transforms.Normalize(mean=[0.485, 0.456, 0.406],
        #                      std=[0.229, 0.224, 0.225])
    ])
    boxes = target['boxes']
    keypoints = target['lines']
    # Wrap the boxes as a tv_tensor so geometric transforms (here the random
    # horizontal flip) update the coordinates; plain tensors and lists pass
    # through v2 transforms untouched, which silently desynchronizes the
    # annotations from the image.
    h, w = img.shape[-2:]
    boxes_tv = tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=(h, w))
    # v2 transforms accept arbitrary nested structures and return the same
    # structure with every recognized leaf transformed consistently.
    img_transformed, target_transformed = data_transforms(
        img, {"boxes": boxes_tv, "keypoints": keypoints})
    target['boxes'] = target_transformed['boxes'].to(torch.float32)
    # NOTE: the keypoints travel through the pipeline as a plain tensor and are
    # therefore NOT flipped here; only recent torchvision releases ship a
    # KeyPoints tv_tensor that would let them be handled like the boxes above.
    target['lines'] = target_transformed['keypoints'].to(torch.float32)
    return img_transformed, target
- """
- 直接读取xanlabel标注的数据集json格式
- """


class LineDataset(BaseDataset):
    def __init__(self, dataset_path, data_type, transforms=None, augmentation=False,
                 dataset_type=None, img_type='rgb', target_type='pixel'):
        super().__init__(dataset_path)
        self.data_path = dataset_path
        self.data_type = data_type
        print(f'data_path:{dataset_path}')
        self.transforms = transforms
        self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
        self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
        self.imgs = os.listdir(self.img_path)
        self.lbls = os.listdir(self.lbl_path)
        self.target_type = target_type
        self.img_type = img_type
        self.augmentation = augmentation
        # self.default_transform = DefaultTransform()

    def __getitem__(self, index) -> T_co:
        img_path = os.path.join(self.img_path, self.imgs[index])
        # Labels share the image's basename, with a .json extension.
        lbl_path = os.path.join(self.lbl_path, os.path.splitext(self.imgs[index])[0] + '.json')
        img = PIL.Image.open(img_path).convert('RGB')
        w, h = img.size
        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
        if self.transforms:
            img, target = self.transforms(img, target)
        else:
            img = self.default_transform(img)
        if self.augmentation:
            img, target = apply_transform_with_boxes_and_keypoints(img, target)
        return img, target

    def __len__(self):
        return len(self.imgs)

    def read_target(self, item, lbl_path, shape, extra=None):
        with open(lbl_path, 'r') as file:
            label_all = json.load(file)
        objs = label_all["shapes"]
        target = {}
        target["image_id"] = torch.tensor(item)
        target["boxes"], lines = get_boxes_lines(objs, shape)
        target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
        # Append a visibility flag of 2 ("labeled and visible") to every endpoint.
        a = torch.full((lines.shape[0],), 2).unsqueeze(1)
        lines = torch.cat((lines, a), dim=1)
        # Fold the flat endpoint list into (N, 2, 3): N lines with 2 endpoints each.
        target["lines"] = lines.to(torch.float32).view(-1, 2, 3)
        target["img_size"] = shape
        # shape is (h, w), so the width is shape[1] and the height is shape[0].
        validate_keypoints(lines, shape[1], shape[0])
        return target
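
    # For a label file with N annotated segments, read_target() produces:
    #   target["boxes"]  -> (N, 4)    padded boxes around each segment
    #   target["labels"] -> (N,)      all ones (single foreground class)
    #   target["lines"]  -> (N, 2, 3) endpoints as [x, y, visibility]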

    def show(self, idx, show_type='all'):
        image, target = self.__getitem__(idx)
        img = image
        if show_type == 'all':
            boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"],
                                              colors="yellow", width=1)
            keypoint_img = draw_keypoints(boxed_image, target['lines'], colors='red', width=3)
            plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
            plt.show()
        elif show_type == 'lines':
            keypoint_img = draw_keypoints((img * 255).to(torch.uint8), target['lines'], colors='red', width=3)
            plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
            plt.show()
        elif show_type == 'boxes':
            boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"],
                                              colors="yellow", width=1)
            plt.imshow(boxed_image.permute(1, 2, 0).numpy())
            plt.show()

    def show_img(self, img_path):
        pass


def get_boxes_lines(objs, shape):
    boxes = []
    h, w = shape
    line_point_pairs = []
    for obj in objs:
        a, b = obj['points'][0], obj['points'][1]
        line_point_pairs.append(a)
        line_point_pairs.append(b)
        # The two endpoints carry no ordering guarantee, so take min/max on each
        # axis, pad by 6 px, and clamp to the image bounds.
        xmin = max(0, (min(a[0], b[0]) - 6))
        xmax = min(w, (max(a[0], b[0]) + 6))
        ymin = max(0, (min(a[1], b[1]) - 6))
        ymax = min(h, (max(a[1], b[1]) + 6))
        boxes.append([xmin, ymin, xmax, ymax])
    boxes = torch.tensor(boxes)
    line_point_pairs = torch.tensor(line_point_pairs)
    return boxes, line_point_pairs
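

# Worked example: endpoints a = (100, 40) and b = (120, 60) in a 480x640 image
# (h=480, w=640) yield the padded box [94, 34, 126, 66], and both endpoints are
# appended to line_point_pairs in annotation order.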


if __name__ == '__main__':
    path = r"\\192.168.50.222/share/rlq/datasets/0706_"
    dataset = LineDataset(dataset_path=path, dataset_type='train', augmentation=True, data_type='jpg')
    dataset.show(1, show_type='all')