|
@@ -14,7 +14,7 @@ import imageio
|
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.pyplot as plt
|
|
|
import matplotlib as mpl
|
|
import matplotlib as mpl
|
|
|
from torchvision.utils import draw_bounding_boxes
|
|
from torchvision.utils import draw_bounding_boxes
|
|
|
-
|
|
|
|
|
|
|
+import torchvision.transforms.v2 as transforms
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
import numpy.linalg as LA
|
|
import numpy.linalg as LA
|
|
|
import torch
|
|
import torch
|
|
@@ -31,12 +31,66 @@ def validate_keypoints(keypoints, image_width, image_height):
|
|
|
if not (0 <= x < image_width and 0 <= y < image_height):
|
|
if not (0 <= x < image_width and 0 <= y < image_height):
|
|
|
raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
|
|
raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
|
|
|
|
|
|
|
|
|
|
+
|
|
|
|
|
+def apply_transform_with_boxes_and_keypoints(img,target):
|
|
|
|
|
+ """
|
|
|
|
|
+ 对图像、边界框和关键点应用相同的变换。
|
|
|
|
|
+
|
|
|
|
|
+ :param img_path: 图像文件路径
|
|
|
|
|
+ :param boxes: 形状为 (N, 4) 的 Tensor,表示 N 个边界框的坐标 [x_min, y_min, x_max, y_max]
|
|
|
|
|
+ :param keypoints: 形状为 (N, K, 3) 的 Tensor,表示 N 个实例的 K 个关键点的坐标和可见性 [x, y, visibility]
|
|
|
|
|
+ :return: 变换后的图像、边界框和关键点
|
|
|
|
|
+ """
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+ # 定义一系列用于数据增强的变换
|
|
|
|
|
+ data_transforms = transforms.Compose([
|
|
|
|
|
+ # 随机调整大小和随机裁剪
|
|
|
|
|
+ transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0), antialias=True),
|
|
|
|
|
+
|
|
|
|
|
+ # 随机水平翻转
|
|
|
|
|
+ transforms.RandomHorizontalFlip(p=0.5),
|
|
|
|
|
+
|
|
|
|
|
+ # 颜色抖动: 改变亮度、对比度、饱和度和色调
|
|
|
|
|
+ transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
|
|
|
|
|
+
|
|
|
|
|
+ # # 转换为张量
|
|
|
|
|
+ # transforms.ToTensor(),
|
|
|
|
|
+ #
|
|
|
|
|
+ # # 标准化
|
|
|
|
|
+ # transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
|
|
|
|
+ # std=[0.229, 0.224, 0.225])
|
|
|
|
|
+ ])
|
|
|
|
|
+
|
|
|
|
|
+ boxes=target['boxes']
|
|
|
|
|
+ keypoints=target['lines']
|
|
|
|
|
+ # 将边界框转换为适合传递给 transforms 的格式
|
|
|
|
|
+ boxes_format = [(box[0].item(), box[1].item(), box[2].item(), box[3].item()) for box in boxes]
|
|
|
|
|
+
|
|
|
|
|
+ # 将关键点转换为适合传递给 transforms 的格式
|
|
|
|
|
+ keypoints_format = [[(kp[0].item(), kp[1].item(), bool(kp[2].item())) for kp in keypoint] for keypoint in keypoints]
|
|
|
|
|
+
|
|
|
|
|
+ # 应用变换
|
|
|
|
|
+ transformed = data_transforms(img, {"boxes": boxes_format, "keypoints": keypoints_format})
|
|
|
|
|
+
|
|
|
|
|
+ # 获取变换后的图像、边界框和关键点
|
|
|
|
|
+ img_transformed = transformed[0]
|
|
|
|
|
+ boxes_transformed = torch.tensor([(box[0], box[1], box[2], box[3]) for box in transformed[1]['boxes']],
|
|
|
|
|
+ dtype=torch.float32)
|
|
|
|
|
+ keypoints_transformed = torch.tensor(
|
|
|
|
|
+ [[(kp[0], kp[1], int(kp[2])) for kp in keypoint] for keypoint in transformed[1]['keypoints']],
|
|
|
|
|
+ dtype=torch.float32)
|
|
|
|
|
+
|
|
|
|
|
+ target['boxes']=boxes_transformed
|
|
|
|
|
+ target['lines']=keypoints_transformed
|
|
|
|
|
+ return img_transformed, target
|
|
|
|
|
+
|
|
|
"""
|
|
"""
|
|
|
直接读取xanlabel标注的数据集json格式
|
|
直接读取xanlabel标注的数据集json格式
|
|
|
|
|
|
|
|
"""
|
|
"""
|
|
|
class LineDataset(BaseDataset):
|
|
class LineDataset(BaseDataset):
|
|
|
- def __init__(self, dataset_path, data_type, transforms=None, dataset_type=None,img_type='rgb', target_type='pixel'):
|
|
|
|
|
|
|
+ def __init__(self, dataset_path, data_type, transforms=None,augmentation=False, dataset_type=None,img_type='rgb', target_type='pixel'):
|
|
|
super().__init__(dataset_path)
|
|
super().__init__(dataset_path)
|
|
|
|
|
|
|
|
self.data_path = dataset_path
|
|
self.data_path = dataset_path
|
|
@@ -49,6 +103,7 @@ class LineDataset(BaseDataset):
|
|
|
self.lbls = os.listdir(self.lbl_path)
|
|
self.lbls = os.listdir(self.lbl_path)
|
|
|
self.target_type = target_type
|
|
self.target_type = target_type
|
|
|
self.img_type=img_type
|
|
self.img_type=img_type
|
|
|
|
|
+ self.augmentation=augmentation
|
|
|
# self.default_transform = DefaultTransform()
|
|
# self.default_transform = DefaultTransform()
|
|
|
|
|
|
|
|
def __getitem__(self, index) -> T_co:
|
|
def __getitem__(self, index) -> T_co:
|
|
@@ -67,6 +122,8 @@ class LineDataset(BaseDataset):
|
|
|
|
|
|
|
|
# print(f'img:{img}')
|
|
# print(f'img:{img}')
|
|
|
# print(f'img shape:{img.shape}')
|
|
# print(f'img shape:{img.shape}')
|
|
|
|
|
+ if self.augmentation:
|
|
|
|
|
+ img, target=apply_transform_with_boxes_and_keypoints(img, target)
|
|
|
return img, target
|
|
return img, target
|
|
|
|
|
|
|
|
def __len__(self):
|
|
def __len__(self):
|