Explorar o código

linedataset 添加数据增强

RenLiqiang hai 5 meses
pai
achega
6806148273

+ 59 - 2
models/line_detect/line_dataset.py

@@ -14,7 +14,7 @@ import imageio
 import matplotlib.pyplot as plt
 import matplotlib as mpl
 from torchvision.utils import draw_bounding_boxes
-
+import torchvision.transforms.v2 as transforms
 import numpy as np
 import numpy.linalg as LA
 import torch
@@ -31,12 +31,66 @@ def validate_keypoints(keypoints, image_width, image_height):
         if not (0 <= x < image_width and 0 <= y < image_height):
             raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
 
+
+def apply_transform_with_boxes_and_keypoints(img,target):
+    """
+    对图像、边界框和关键点应用相同的变换。
+
+    :param img: 输入图像
+    :param target: 标注字典,其中 'boxes' 为 (N, 4) 的 Tensor [x_min, y_min, x_max, y_max],
+                   'lines' 为 (N, K, 3) 的 Tensor,表示每个实例 K 个关键点的坐标和可见性 [x, y, visibility]
+    :return: 变换后的图像和更新后的 target(其中 'boxes' 与 'lines' 已同步变换)
+    """
+
+
+    # 定义一系列用于数据增强的变换
+    data_transforms = transforms.Compose([
+        # 随机调整大小和随机裁剪
+        transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0), antialias=True),
+
+        # 随机水平翻转
+        transforms.RandomHorizontalFlip(p=0.5),
+
+        # 颜色抖动: 改变亮度、对比度、饱和度和色调
+        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
+
+        # # 转换为张量
+        # transforms.ToTensor(),
+        #
+        # # 标准化
+        # transforms.Normalize(mean=[0.485, 0.456, 0.406],
+        #                      std=[0.229, 0.224, 0.225])
+    ])
+
+    boxes=target['boxes']
+    keypoints=target['lines']
+    # 将边界框转换为适合传递给 transforms 的格式
+    boxes_format = [(box[0].item(), box[1].item(), box[2].item(), box[3].item()) for box in boxes]
+
+    # 将关键点转换为适合传递给 transforms 的格式
+    keypoints_format = [[(kp[0].item(), kp[1].item(), bool(kp[2].item())) for kp in keypoint] for keypoint in keypoints]
+
+    # 应用变换
+    transformed = data_transforms(img, {"boxes": boxes_format, "keypoints": keypoints_format})
+
+    # 获取变换后的图像、边界框和关键点
+    img_transformed = transformed[0]
+    boxes_transformed = torch.tensor([(box[0], box[1], box[2], box[3]) for box in transformed[1]['boxes']],
+                                     dtype=torch.float32)
+    keypoints_transformed = torch.tensor(
+        [[(kp[0], kp[1], int(kp[2])) for kp in keypoint] for keypoint in transformed[1]['keypoints']],
+        dtype=torch.float32)
+
+    target['boxes']=boxes_transformed
+    target['lines']=keypoints_transformed
+    return img_transformed, target
+
 """
 直接读取xanlabel标注的数据集json格式
 
 """
 class LineDataset(BaseDataset):
-    def __init__(self, dataset_path, data_type, transforms=None, dataset_type=None,img_type='rgb', target_type='pixel'):
+    def __init__(self, dataset_path, data_type, transforms=None,augmentation=False, dataset_type=None,img_type='rgb', target_type='pixel'):
         super().__init__(dataset_path)
 
         self.data_path = dataset_path
@@ -49,6 +103,7 @@ class LineDataset(BaseDataset):
         self.lbls = os.listdir(self.lbl_path)
         self.target_type = target_type
         self.img_type=img_type
+        self.augmentation=augmentation
         # self.default_transform = DefaultTransform()
 
     def __getitem__(self, index) -> T_co:
@@ -67,6 +122,8 @@ class LineDataset(BaseDataset):
 
         # print(f'img:{img}')
         # print(f'img shape:{img.shape}')
+        if self.augmentation:
+            img, target=apply_transform_with_boxes_and_keypoints(img, target)
         return img, target
 
     def __len__(self):

+ 1 - 1
models/line_detect/loi_heads.py

@@ -624,7 +624,7 @@ def line_iou_loss(x, boxes, gt_lines, matched_idx, img_size=511):
 
     if not losses:  # 如果损失列表为空,则返回默认值或抛出自定义异常
         print("Warning: No valid losses were computed.")
-        return torch.tensor(1, requires_grad=True).to(x.device)  # 返回一个标量张量
+        return torch.tensor(1.0, requires_grad=True).to(x.device)  # 返回一个标量张量
 
     total_loss = torch.mean(torch.cat(losses))
     return total_loss

+ 1 - 0
models/line_detect/train.yaml

@@ -13,6 +13,7 @@ train_params:
   num_workers: 8
   batch_size: 4
   max_epoch: 80000
+  augmentation: True
   optim:
     name: Adam
     lr: 4.0e-4

+ 4 - 3
models/line_detect/trainer.py

@@ -105,6 +105,7 @@ class Trainer(BaseTrainer):
             self.best_train_model_path = os.path.join(self.wts_path, 'best_train.pth')
             self.best_val_model_path = os.path.join(self.wts_path, 'best_val.pth')
             self.max_epoch = kwargs['train_params']['max_epoch']
+            self.augmentation= kwargs['train_params']["augmentation"]
 
     def move_to_device(self, data, device):
         if isinstance(data, (list, tuple)):
@@ -243,8 +244,8 @@ class Trainer(BaseTrainer):
 
         self.init_params(**kwargs)
 
-        dataset_train = LineDataset(dataset_path=self.dataset_path, data_type=self.data_type, dataset_type='train')
-        dataset_val = LineDataset(dataset_path=self.dataset_path, data_type=self.data_type, dataset_type='val')
+        dataset_train = LineDataset(dataset_path=self.dataset_path,augmentation=self.augmentation, data_type=self.data_type, dataset_type='train')
+        dataset_val = LineDataset(dataset_path=self.dataset_path, augmentation=self.augmentation, data_type=self.data_type, dataset_type='val')
 
         train_sampler = torch.utils.data.RandomSampler(dataset_train)
         val_sampler = torch.utils.data.RandomSampler(dataset_val)
@@ -254,7 +255,7 @@ class Trainer(BaseTrainer):
         val_collate_fn = utils.collate_fn
 
         data_loader_train = torch.utils.data.DataLoader(
-            dataset_train, batch_sampler=train_batch_sampler, num_workers=self.num_workers, collate_fn=train_collate_fn
+            dataset_train, batch_sampler=train_batch_sampler,  num_workers=self.num_workers, collate_fn=train_collate_fn
         )
         data_loader_val = torch.utils.data.DataLoader(
             dataset_val, batch_sampler=val_batch_sampler, num_workers=self.num_workers, collate_fn=val_collate_fn