
Add dynamic data augmentation

RenLiqiang · 5 months ago
commit 6cfebd9e97

+ 9 - 3
libs/vision_libs/transforms/_functional_pil.py

@@ -109,10 +109,16 @@ def adjust_hue(img: Image.Image, hue_factor: float) -> Image.Image:
     h, s, v = img.convert("HSV").split()

     np_h = np.array(h, dtype=np.uint8)
-    # uint8 addition take cares of rotation across boundaries
+
+    # # uint8 addition take cares of rotation across boundaries
+    # with np.errstate(over="ignore"):
+    #     np_h += np.uint8(hue_factor * 255)
+    # h = Image.fromarray(np_h, "L")
+
+    # Use int16 to prevent overflow, then convert back to uint8
     with np.errstate(over="ignore"):
-        np_h += np.uint8(hue_factor * 255)
-    h = Image.fromarray(np_h, "L")
+        np_h = (np_h.astype(np.int16) + int(hue_factor * 255)) % 256
+    h = Image.fromarray(np_h.astype(np.uint8), "L")

     img = Image.merge("HSV", (h, s, v)).convert(input_mode)
     return img
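
The new path widens to int16 before shifting, so the wrap-around is explicit instead of relying on uint8 overflow semantics. A minimal sketch of the same computation (plain NumPy, sample values chosen here for illustration):

import numpy as np

hue_factor = 0.1                                   # valid range is [-0.5, 0.5]
np_h = np.array([250, 10, 128], dtype=np.uint8)    # sample hue-channel values

# widen to int16 so the addition cannot overflow, then wrap explicitly
shifted = (np_h.astype(np.int16) + int(hue_factor * 255)) % 256
np_h_new = shifted.astype(np.uint8)                # -> [ 19  35 153]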

+ 3 - 1
libs/vision_libs/transforms/functional.py

@@ -8,6 +8,7 @@ from typing import Any, List, Optional, Tuple, Union
 import numpy as np
 import torch
 from PIL import Image
+# from PIL.Image import Image
 from torch import Tensor

 try:
@@ -916,7 +917,7 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
     return F_t.adjust_saturation(img, saturation_factor)


-def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
+def adjust_hue(img: Tensor, hue_factor: float) ->  Tensor:
     """Adjust hue of an image.

     The image hue is adjusted by converting the image to HSV and
@@ -947,6 +948,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
     Returns:
         PIL Image or Tensor: Hue adjusted image.
     """
+    print(f'hue_factor:{hue_factor}')
     if not torch.jit.is_scripting() and not torch.jit.is_tracing():
         _log_api_usage_once(adjust_hue)
     if not isinstance(img, torch.Tensor):
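
For reference, a small usage sketch of the functional API touched here, assuming the vendored libs.vision_libs.transforms.functional keeps torchvision's adjust_hue signature; hue_factor must lie in [-0.5, 0.5]:

import torch
from libs.vision_libs.transforms import functional as F

img = torch.rand(3, 64, 64)          # dummy RGB tensor in [0, 1]
shifted = F.adjust_hue(img, 0.25)    # rotate the hue channel by a quarter cycle
neutral = F.adjust_hue(img, 0.0)     # a factor of 0 leaves the hue unchanged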

+ 1 - 0
libs/vision_libs/transforms/transforms.py

@@ -1197,6 +1197,7 @@ class ColorJitter(torch.nn.Module):
     ) -> None:
         super().__init__()
         _log_api_usage_once(self)
+        # print(f'hue:{hue}')
         self.brightness = self._check_input(brightness, "brightness")
         self.contrast = self._check_input(contrast, "contrast")
         self.saturation = self._check_input(saturation, "saturation")

+ 0 - 16
models/base/high_reso_resnet.py

@@ -150,22 +150,6 @@ class ResNet(nn.Module):
         self.encoder1 = self._make_layer(block, 64, layers[0],stride=2)
         self.encoder2 = self._make_layer(block, 128, layers[1], stride=2)
         self.encoder3 = self._make_layer(block, 256, layers[2], stride=2)
-        # self.encoder4 = self._make_layer(block, 512, 3, stride=2)
-        # self.encoder5 = self._make_layer(block, 512, 3, stride=2)
-        # self.body = nn.ModuleDict({
-        #     'encoder0': self.encoder0,
-        #     'encoder1': self.encoder1,
-        #     'encoder2': self.encoder2,
-        #     'encoder3': self.encoder3,
-        #     'encoder4': self.encoder4
-        # })
-        # self.fpn = self.get_convnext_fpn(
-        #     backbone=self.body,
-        #     trainable_layers=5,
-        #     returned_layers=[0, 1, 2, 3, 4],
-        #     extra_blocks=None,
-        #     norm_layer=None
-        # )




+ 487 - 0
models/base/transforms.py

@@ -0,0 +1,487 @@
+import logging
+import random
+from typing import Any
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from torch import nn, Tensor
+
+from libs.vision_libs .transforms import functional as F
+
+from libs.vision_libs import transforms
+
+
+class Compose:
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, img, target):
+        for t in self.transforms:
+            img, target = t(img, target)
+
+
+        return img, target
+
+
+class RandomHorizontalFlip:
+    def __init__(self, prob=0.5):
+        self.prob = prob
+
+    def __call__(self, img, target):
+        if random.random() < self.prob:
+            width = img.width if isinstance(img, Image.Image) else img.shape[-1]
+
+            # Flip image
+            img = F.hflip(img)
+
+            # Flip boxes
+            boxes = target["boxes"]
+            x1, y1, x2, y2 = boxes.unbind(dim=1)
+            boxes_flipped = torch.stack((width - x2, y1, width - x1, y2), dim=1)
+            target["boxes"] = boxes_flipped
+
+            # Flip lines
+            if "lines" in target:
+                lines = target["lines"].clone()
+                # Only flip the x coordinate; y and visibility stay unchanged
+                lines[..., 0] = width - lines[..., 0]
+                target["lines"] = lines
+
+        return img, target
+
+class RandomVerticalFlip:
+    def __init__(self, prob=0.5):
+        self.prob = prob
+
+    def __call__(self, img, target):
+        if random.random() < self.prob:
+            height = img.height if isinstance(img, Image.Image) else img.shape[-2]
+
+            # Flip image
+            img = F.vflip(img)
+
+            # Flip boxes
+            boxes = target["boxes"]
+            x1, y1, x2, y2 = boxes.unbind(dim=1)
+            boxes_flipped = torch.stack((x1, height - y2, x2, height - y1), dim=1)
+            target["boxes"] = boxes_flipped
+
+            # Flip lines
+            if "lines" in target:
+                lines = target["lines"].clone()
+                lines[..., 1] = height - lines[..., 1]
+                target["lines"] = lines
+
+        return img, target
+
+
+class ColorJitter:
+    def __init__(self, brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2):
+        if not (0 <= hue <= 0.5):
+            raise ValueError(f"Hue jitter value should be in [0, 0.5], but got {hue}")
+
+        self.color_jitter = transforms.ColorJitter(
+            brightness=brightness,
+            contrast=contrast,
+            saturation=saturation,
+            hue=hue
+        )
+
+    def __call__(self, img, target):
+        print(f"Original image type: {type(img)}")
+        img = self.color_jitter(img)
+        print("Color jitter applied successfully.")
+        return img, target
+
+
+class RandomGrayscale:
+    def __init__(self, p=0.1):
+        self.p = p
+
+    def __call__(self, img, target):
+        print(f"RandomGrayscale Original image type: {type(img)}")
+        if random.random() < self.p:
+            img = F.to_grayscale(img, num_output_channels=3)
+        return img, target
+
+
+class RandomResize:
+    def __init__(self, min_size, max_size=None):
+        self.min_size = min_size
+        self.max_size = max_size
+
+    def __call__(self, img, target):
+        size = random.randint(self.min_size, self.max_size) if self.max_size else self.min_size
+        w, h = img.size if isinstance(img, Image.Image) else (img.shape[2], img.shape[1])
+        scale = size / min(h, w)
+        new_h, new_w = int(scale * h), int(scale * w)
+        img = F.resize(img, (new_h, new_w))
+
+        # Update boxes
+        boxes = target["boxes"]
+        boxes = boxes * scale
+        target["boxes"] = boxes
+
+        # Update lines
+        if "lines" in target:
+            target["lines"] = target["lines"] * torch.tensor([scale, scale, 1], device=target["lines"].device)
+
+        return img, target
+
+
+class RandomCrop:
+    def __init__(self, size):
+        self.size = size
+
+    def __call__(self, img, target):
+        w, h = img.size if isinstance(img, Image.Image) else (img.shape[2], img.shape[1])
+        th, tw = self.size
+
+        if h <= th or w <= tw:
+            return img, target
+
+        i = random.randint(0, h - th)
+        j = random.randint(0, w - tw)
+
+        img = F.crop(img, i, j, th, tw)
+
+        # Adjust boxes
+        boxes = target["boxes"]
+        boxes = boxes - torch.tensor([j, i, j, i], device=boxes.device)
+        boxes = torch.clamp(boxes, min=0)
+        xmax, ymax = tw, th
+        boxes[:, [0, 2]] = boxes[:, [0, 2]].clamp(max=xmax)
+        boxes[:, [1, 3]] = boxes[:, [1, 3]].clamp(max=ymax)
+        target["boxes"] = boxes
+
+        # Adjust lines
+        if "lines" in target:
+            lines = target["lines"].clone()
+            lines[..., 0] -= j
+            lines[..., 1] -= i
+            lines = torch.clamp(lines, min=0)
+            lines[..., 0] = torch.clamp(lines[..., 0], max=tw)
+            lines[..., 1] = torch.clamp(lines[..., 1], max=th)
+            target["lines"] = lines
+
+        return img, target
+
+
+class GaussianBlur:
+    def __init__(self, kernel_size=5, sigma=(0.1, 2.0), prob=0.2):
+        self.kernel_size = kernel_size if kernel_size % 2 == 1 else kernel_size + 1  # Ensure kernel size is odd
+        self.sigma = sigma
+        self.prob = prob
+
+    def __call__(self, img, target):
+        if random.random() < self.prob:
+            # Convert PIL Image to Tensor if necessary
+            if isinstance(img, Image.Image):
+                img = transforms.ToTensor()(img)
+
+            # Apply Gaussian blur using PyTorch's functional interface
+            img = transforms.GaussianBlur(kernel_size=self.kernel_size, sigma=random.uniform(*self.sigma))(img)
+
+            # If the original image was a PIL Image, convert it back
+            if isinstance(img, Tensor) and not isinstance(target.get('original_image_format', None), Tensor):
+                img = transforms.ToPILImage()(img)
+
+        return img, target
+class RandomRotation:
+    def __init__(self, degrees=15, prob=0.5):
+        self.degrees = degrees
+        self.prob = prob
+
+    def rotate_boxes(self, boxes, angle, center):
+        # Convert to numpy for easier rotation math
+        boxes_np = boxes.cpu().numpy()
+        center_np = np.array(center)
+
+        corners = np.array([
+            [boxes_np[:, 0], boxes_np[:, 1]],  # top-left
+            [boxes_np[:, 2], boxes_np[:, 1]],  # top-right
+            [boxes_np[:, 2], boxes_np[:, 3]],  # bottom-right
+            [boxes_np[:, 0], boxes_np[:, 3]]  # bottom-left
+        ]).transpose(2, 0, 1)  # shape: (N, 4, 2)
+
+        # Translate to origin
+        corners -= center_np
+
+        # Rotate points
+        theta = np.radians(angle)
+        c, s = np.cos(theta), np.sin(theta)
+        R = np.array([[c, -s], [s, c]])
+        rotated_corners = corners @ R
+
+        # Translate back
+        rotated_corners += center_np
+
+        # Get new bounding box coordinates
+        x_min = np.min(rotated_corners[:, :, 0], axis=1)
+        y_min = np.min(rotated_corners[:, :, 1], axis=1)
+        x_max = np.max(rotated_corners[:, :, 0], axis=1)
+        y_max = np.max(rotated_corners[:, :, 1], axis=1)
+
+        # Convert back to tensor and move to the same device
+        device = boxes.device
+        return torch.tensor(np.stack([x_min, y_min, x_max, y_max], axis=1), dtype=boxes.dtype, device=device)
+
+    def rotate_lines(self, lines, angle, center):
+        coords = lines[..., :2]  # shape: (..., 2)
+        visibility = lines[..., 2:]  # shape: (..., N)
+
+        # Translate to origin
+        coords = coords - torch.tensor(center, dtype=coords.dtype, device=coords.device)
+
+        # Rotation matrix
+        theta = torch.deg2rad(torch.tensor(angle))
+        cos_t = torch.cos(theta)
+        sin_t = torch.sin(theta)
+        R = torch.tensor([[cos_t, -sin_t], [sin_t, cos_t]], dtype=coords.dtype, device=coords.device)
+
+        # Apply rotation using torch.matmul
+        rotated_coords = torch.matmul(coords, R)
+
+        # Translate back
+        rotated_coords = rotated_coords + torch.tensor(center, dtype=coords.dtype, device=coords.device)
+
+        # Concatenate with visibility
+        rotated_lines = torch.cat([rotated_coords, visibility], dim=-1)
+        return rotated_lines
+
+    def __call__(self, img, target):
+        if random.random() < self.prob:
+            angle = random.uniform(-self.degrees, self.degrees)
+            w, h = img.size if isinstance(img, Image.Image) else (img.shape[2], img.shape[1])
+            center = (w / 2, h / 2)
+
+            # Rotate image
+            img = F.rotate(img, angle, center=center)
+
+            # Rotate boxes
+            if "boxes" in target:
+                target["boxes"] = self.rotate_boxes(target["boxes"], angle, center)
+
+            # Rotate lines
+            if "lines" in target:
+                target["lines"] = self.rotate_lines(target["lines"], angle, center)
+
+        return img, target
+
+
+class RandomErasing:
+    def __init__(self, prob=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[0.485, 0.456, 0.406]):
+        """
+        :param prob: probability of applying the erasing
+        :param sl: lower bound of the erased area ratio
+        :param sh: upper bound of the erased area ratio
+        :param r1: lower bound of the aspect ratio
+        :param mean: pixel values used to fill the erased region
+        """
+        self.prob = prob
+        self.sl = sl
+        self.sh = sh
+        self.r1 = r1
+        self.mean = mean
+
+    def __call__(self, img, target):
+        if random.random() < self.prob:
+            # If the image is a Tensor, process it directly
+            if isinstance(img, torch.Tensor):
+                img = self._erase_tensor(img)
+            # If it is a PIL Image, convert it to a Tensor, process it, then convert back to a PIL Image
+            elif isinstance(img, Image.Image):
+                img_tensor = transforms.ToTensor()(img)
+                img_tensor = self._erase_tensor(img_tensor)
+                img = transforms.ToPILImage()(img_tensor)
+
+        return img, target
+
+    def _erase_tensor(self, img_tensor):
+        """
+        Perform random erasing on a Tensor image
+        """
+        img_c, img_h, img_w = img_tensor.shape
+        area = img_h * img_w
+
+        # Compute the size of the erased region
+        erase_area = random.uniform(self.sl, self.sh) * area
+        aspect_ratio = random.uniform(self.r1, 1 / self.r1)
+
+        h = int(round((erase_area * aspect_ratio) ** 0.5))
+        w = int(round((erase_area / aspect_ratio) ** 0.5))
+
+        # Make sure the region does not exceed the image bounds
+        if h < img_h and w < img_w:
+            x = random.randint(0, img_w - w)
+            y = random.randint(0, img_h - h)
+
+            # Create a region of the same size as the erased area, filled with the given mean values
+            mean_tensor = torch.tensor(self.mean).view(img_c, 1, 1).expand(img_c, h, w)
+
+            # Apply this region onto the original image
+            img_tensor[:, y:y + h, x:x + w] = mean_tensor
+
+        return img_tensor
+
+"""
+Known to be buggy
+"""
+class RandomPerspective:
+    def __init__(self, distortion_scale=0.5, p=0.5):
+        self.distortion_scale = distortion_scale
+        self.p = p
+
+    def _get_perspective_params(self, width, height, distortion_scale):
+        half_w = width // 2
+        half_h = height // 2
+        w = int(width * distortion_scale)
+        h = int(height * distortion_scale)
+
+        startpoints = [
+            [0, 0],
+            [width - 1, 0],
+            [width - 1, height - 1],
+            [0, height - 1]
+        ]
+        endpoints = [
+            [random.randint(0, w), random.randint(0, h)],
+            [width - 1 - random.randint(0, w), random.randint(0, h)],
+            [width - 1 - random.randint(0, w), height - 1 - random.randint(0, h)],
+            [random.randint(0, w), height - 1 - random.randint(0, h)]
+        ]
+        return startpoints, endpoints
+
+    def perspective_boxes(self, boxes, M, width, height):
+        # Convert boxes to their corner-point form
+        corners = np.array([
+            [boxes[:, 0], boxes[:, 1]],  # top-left
+            [boxes[:, 2], boxes[:, 1]],  # top-right
+            [boxes[:, 2], boxes[:, 3]],  # bottom-right
+            [boxes[:, 0], boxes[:, 3]]   # bottom-left
+        ]).transpose(2, 0, 1).reshape(-1, 2)  # shape: (N*4, 2)
+
+        # Apply the perspective transform
+        ones = np.ones((corners.shape[0], 1))
+        coords_homogeneous = np.hstack([corners, ones])
+        transformed_coords = (M @ coords_homogeneous.T).T
+        transformed_coords /= transformed_coords[:, 2].reshape(-1, 1)  # homogeneous division
+        transformed_coords = transformed_coords[:, :2]
+
+        # Reassemble into bounding boxes
+        transformed_coords = transformed_coords.reshape(-1, 4, 2)
+        x_min = np.min(transformed_coords[:, :, 0], axis=1)
+        y_min = np.min(transformed_coords[:, :, 1], axis=1)
+        x_max = np.max(transformed_coords[:, :, 0], axis=1)
+        y_max = np.max(transformed_coords[:, :, 1], axis=1)
+
+        # Clip to the image bounds
+        x_min = np.clip(x_min, 0, width)
+        y_min = np.clip(y_min, 0, height)
+        x_max = np.clip(x_max, 0, width)
+        y_max = np.clip(y_max, 0, height)
+
+        return torch.tensor(np.stack([x_min, y_min, x_max, y_max], axis=1), dtype=boxes.dtype, device=boxes.device)
+
+    def perspective_lines(self, lines, M, width, height):
+        # Extract coordinates and visibility flags
+        coords = lines[..., :2].cpu().numpy()  # Shape: (N, L, 2)
+        visibility = lines[..., 2:]
+
+        # Ensure coords is a 2-D array; if it is 3-D, reshape it to 2-D
+        original_shape = coords.shape
+        coords_reshaped = coords.reshape(-1, 2)  # Reshape to (N*L, 2)
+
+        # Add homogeneous coordinates
+        ones = np.ones((coords_reshaped.shape[0], 1))
+        coords_homogeneous = np.hstack([coords_reshaped, ones])  # Shape: (N*L, 3)
+
+        # Apply the perspective transform matrix
+        transformed_coords_homogeneous = np.dot(M, coords_homogeneous.T).T
+        transformed_coords = transformed_coords_homogeneous[:, :2] / transformed_coords_homogeneous[:, 2:]  # normalize
+
+        # Restore the transformed coordinates to their original shape
+        transformed_coords = transformed_coords.reshape(original_shape)  # Reshape back to (N, L, 2)
+
+        # Clip to the image bounds
+        transformed_coords = np.clip(transformed_coords, [0, 0], [width, height])
+
+        # Convert back to a tensor
+        transformed_coords = torch.tensor(transformed_coords, dtype=lines.dtype, device=lines.device)
+        return torch.cat([transformed_coords, visibility], dim=-1)
+
+    def __call__(self, img, target):
+        if random.random() < self.p:
+            width, height = img.size if isinstance(img, Image.Image) else (img.shape[2], img.shape[1])
+            startpoints, endpoints = self._get_perspective_params(width, height, self.distortion_scale)
+
+            # Compute the perspective transform matrix with OpenCV
+            M = cv2.getPerspectiveTransform(
+                np.float32(startpoints),
+                np.float32(endpoints)
+            )
+
+            # Apply the perspective transform to the image
+            if isinstance(img, Image.Image):
+                img = img.transform((width, height), Image.PERSPECTIVE, M.flatten(), resample=Image.BILINEAR)
+            elif isinstance(img, torch.Tensor):
+                # A TorchVision implementation could use F.perspective, but going through PIL is preferred here
+                pil_img = F.to_pil_image(img)
+                pil_img = pil_img.transform((width, height), Image.PERSPECTIVE, M.flatten(), resample=Image.BILINEAR)
+                img = F.to_tensor(pil_img)
+
+            # Transform the boxes
+            if "boxes" in target:
+                target["boxes"] = self.perspective_boxes(target["boxes"], M, width, height)
+
+            # Transform the lines
+            if "lines" in target:
+                target["lines"] = self.perspective_lines(target["lines"], M, width, height)
+
+        return img, target
+
+class DefaultTransform(nn.Module):
+    def forward(self, img: Tensor,target) -> tuple[Tensor, Any]:
+        if not isinstance(img, Tensor):
+            img = F.pil_to_tensor(img)
+        return F.convert_image_dtype(img, torch.float),target
+
+    def __repr__(self) -> str:
+        return self.__class__.__name__ + "()"
+
+    def describe(self) -> str:
+        return (
+            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
+            "The images are rescaled to ``[0.0, 1.0]``."
+        )
+
+
+class ToTensor:
+    def __call__(self, img, target):
+        img = F.to_tensor(img)
+        return img, target
+
+
+def get_transforms(augmention=True):
+    transforms_list = []
+
+    if augmention:
+
+        transforms_list.append(ColorJitter())
+        transforms_list.append(RandomGrayscale(0.1))
+
+        transforms_list.append(GaussianBlur())
+        transforms_list.append(RandomErasing())
+        transforms_list.append(RandomHorizontalFlip(0.5))
+        transforms_list.append(RandomVerticalFlip(0.2))
+        # transforms_list.append(RandomPerspective())
+        transforms_list.append(RandomRotation(degrees=15))
+        transforms_list.append(RandomResize(512, 2048))
+
+        transforms_list.append(RandomCrop((512,512)))
+
+    transforms_list.append(DefaultTransform())
+
+    return Compose(transforms_list)
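
A usage sketch of the new joint pipeline: get_transforms applies each augmentation to the image and to the "boxes"/"lines" entries of the target dict together. The dummy shapes below follow the dataset code later in this commit ((N, 4) xyxy boxes, (N, 2, 3) lines as two endpoints with a visibility flag); the values themselves are illustrative only:

import torch
from PIL import Image
from models.base.transforms import get_transforms

img = Image.new("RGB", (1024, 768))
target = {
    "boxes": torch.tensor([[100., 120., 300., 260.]]),               # (N, 4) xyxy
    "lines": torch.tensor([[[110., 130., 1.], [290., 250., 1.]]]),   # (N, 2, 3) x, y, visibility
}

pipeline = get_transforms(augmention=True)   # the parameter is spelled "augmention" in this commit
img_t, target_t = pipeline(img, target)      # img_t ends up as a float tensor scaled to [0, 1]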

+ 8 - 72
models/line_detect/line_dataset.py

@@ -3,27 +3,17 @@ from torch.utils.data.dataset import T_co
 from libs.vision_libs.utils import draw_keypoints
 from models.base.base_dataset import BaseDataset

-import glob
 import json
-import math
 import os
-import random
-import cv2
 import PIL
-import imageio
-import matplotlib.pyplot as plt
 import matplotlib as mpl
 from torchvision.utils import draw_bounding_boxes
 import torchvision.transforms.v2 as transforms
-import numpy as np
-import numpy.linalg as LA
 import torch
-from skimage import io
-from torch.utils.data import Dataset
-from torch.utils.data.dataloader import default_collate

 import matplotlib.pyplot as plt
-from models.dataset_tool import read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
+from models.base.transforms import get_transforms
+

 def validate_keypoints(keypoints, image_width, image_height):
     for kp in keypoints:
@@ -32,58 +22,6 @@ def validate_keypoints(keypoints, image_width, image_height):
             raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")


-def apply_transform_with_boxes_and_keypoints(img,target):
-    """
-    Apply the same transforms to the image, bounding boxes, and keypoints.
-
-    :param img_path: path to the image file
-    :param boxes: Tensor of shape (N, 4) giving the coordinates [x_min, y_min, x_max, y_max] of N bounding boxes
-    :param keypoints: Tensor of shape (N, K, 3) giving the coordinates and visibility [x, y, visibility] of the K keypoints for each of N instances
-    :return: the transformed image, bounding boxes, and keypoints
-    """
-
-
-    # Define a series of transforms used for data augmentation
-    data_transforms = transforms.Compose([
-        # Random resize and random crop
-        # transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0), antialias=True),
-
-        # Random horizontal flip
-        transforms.RandomHorizontalFlip(p=0.5),
-
-        # Color jitter: change brightness, contrast, saturation and hue
-        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
-
-        # # Convert to tensor
-        # transforms.ToTensor(),
-        #
-        # # Normalize
-        # transforms.Normalize(mean=[0.485, 0.456, 0.406],
-        #                      std=[0.229, 0.224, 0.225])
-    ])
-
-    boxes=target['boxes']
-    keypoints=target['lines']
-    # Convert the bounding boxes into a format suitable for transforms
-    boxes_format = [(box[0].item(), box[1].item(), box[2].item(), box[3].item()) for box in boxes]
-
-    # Convert the keypoints into a format suitable for transforms
-    keypoints_format = [[(kp[0].item(), kp[1].item(), bool(kp[2].item())) for kp in keypoint] for keypoint in keypoints]
-
-    # Apply the transforms
-    transformed = data_transforms(img, {"boxes": boxes_format, "keypoints": keypoints_format})
-
-    # Get the transformed image, bounding boxes, and keypoints
-    img_transformed = transformed[0]
-    boxes_transformed = torch.tensor([(box[0], box[1], box[2], box[3]) for box in transformed[1]['boxes']],
-                                     dtype=torch.float32)
-    keypoints_transformed = torch.tensor(
-        [[(kp[0], kp[1], int(kp[2])) for kp in keypoint] for keypoint in transformed[1]['keypoints']],
-        dtype=torch.float32)
-
-    target['boxes']=boxes_transformed
-    target['lines']=keypoints_transformed
-    return img_transformed, target

 """
 Directly read the xanlabel-annotated dataset in json format
@@ -114,16 +52,13 @@ class LineDataset(BaseDataset):
         w, h = img.size
         # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
         target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
-        if self.transforms:
-            img, target = self.transforms(img, target)

-        else:
-            img = self.default_transform(img)

-        # print(f'img:{img}')
-        # print(f'img shape:{img.shape}')
-        if self.augmentation:
-            img, target=apply_transform_with_boxes_and_keypoints(img, target)
+        self.transforms=get_transforms(augmention=self.augmentation)
+
+        img, target = self.transforms(img, target)
+
+
         return img, target

     def __len__(self):
@@ -154,6 +89,7 @@ class LineDataset(BaseDataset):
         lines = torch.cat((lines, a), dim=1)

         target["lines"] = lines.to(torch.float32).view(-1,2,3)
+        print(f'lines:{target["lines"].shape}')
         target["img_size"]=shape

         validate_keypoints(lines, shape[0], shape[1])
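
For orientation, the "lines" target assembled above stores each line as two endpoints with an appended visibility flag; a tiny illustration (dummy numbers) of what the view(-1, 2, 3) produces:

import torch

# four (x, y, visibility) rows, i.e. two line segments
flat = torch.tensor([[ 10.,  20., 1.],
                     [200., 220., 1.],
                     [ 50.,  60., 1.],
                     [ 80., 400., 1.]])

lines = flat.view(-1, 2, 3)   # shape (2, 2, 3): 2 lines x 2 endpoints x (x, y, visibility)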

+ 14 - 54
models/line_detect/line_detect.py

@@ -50,7 +50,7 @@ class LineDetect(BaseDetectionNet):
     def __init__(
             self,
             backbone,
-            num_classes=None,
+            num_classes=2,
             # transform parameters
             min_size=512,
             max_size=2048,
@@ -85,7 +85,7 @@ class LineDetect(BaseDetectionNet):
             line_roi_pool=None,
             line_head=None,
             line_predictor=None,
-            num_keypoints=None,
+            num_points=3,
             **kwargs,
     ):

@@ -149,30 +149,13 @@



-        if not isinstance(line_roi_pool, (MultiScaleRoIAlign, type(None))):
-            raise TypeError(
-                "keypoint_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(keypoint_roi_pool)}"
-            )
-        if min_size is None:
-            min_size = (640, 672, 704, 736, 768, 800)
-
-        if num_keypoints is not None:
-            if line_predictor is not None:
-                raise ValueError("num_keypoints should be None when keypoint_predictor is specified")
-        else:
-            num_keypoints = 2
-
-
-        if line_roi_pool is None:
-            line_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
-
         if line_head is None:
             keypoint_layers = tuple(1 for _ in range(8))
             line_head = LineHeads(8, keypoint_layers)

-        if line_predictor is None:
-            keypoint_dim_reduced = 512  # == keypoint_layers[-1]
-            line_predictor = LinePredictor(keypoint_dim_reduced)
+        # if line_predictor is None:
+        #     keypoint_dim_reduced = 512  # == keypoint_layers[-1]
+        #     line_predictor = LinePredictor(keypoint_dim_reduced)


         self.roi_heads.line_roi_pool = line_roi_pool
@@ -303,30 +286,7 @@ class LineHeads(nn.Sequential):
                 nn.init.constant_(m.bias, 0)


-class LinePredictor(nn.Module):
-    def __init__(self, in_channels, out_channels=1 ):
-        super().__init__()
-        input_features = in_channels
-        deconv_kernel = 4
-        self.kps_score_lowres = nn.ConvTranspose2d(
-            input_features,
-            out_channels,
-            deconv_kernel,
-            stride=2,
-            padding=deconv_kernel // 2 - 1,
-        )
-        nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
-        nn.init.constant_(self.kps_score_lowres.bias, 0)
-        self.up_scale = 2
-        self.out_channels = out_channels

-    def forward(self, x):
-        print(f'before kps_score_lowres x:{x.shape}')
-        x = self.kps_score_lowres(x)
-        print(f'kps_score_lowres x:{x.shape}')
-        return torch.nn.functional.interpolate(
-            x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
-        )
 
 
 def linedetect_newresnet18fpn(
         *,
@@ -339,9 +299,9 @@ def linedetect_newresnet18fpn(
     # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
     # weights_backbone = ResNet50_Weights.verify(weights_backbone)
     if num_classes is None:
-        num_classes = 2
+        num_classes = 3
     if num_points is None:
-        num_points = 2
+        num_points = 3


     backbone =resnet18fpn()
@@ -361,7 +321,7 @@

     anchor_generator =  AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)

-    model = LineDetect(backbone, num_classes, num_keypoints=num_points,rpn_anchor_generator=anchor_generator,box_roi_pool=roi_pooler, **kwargs)
+    model = LineDetect(backbone, num_classes, num_points=num_points, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler, **kwargs)

     return model

@@ -375,12 +335,12 @@ def linedetect_resnet18_fpn(
 ) -> LineDetect:

     if num_classes is None:
-        num_classes = 2
+        num_classes = 3
     if num_points is None:
-        num_points = 2
+        num_points = 3

     backbone = resnet_fpn_backbone(backbone_name='resnet18',weights=None)
-    model = LineDetect(backbone, num_classes, num_keypoints=num_points, **kwargs)
+    model = LineDetect(backbone, num_classes, num_points=num_points, **kwargs)

     return model

@@ -391,12 +351,12 @@ def linedetect_resnet50_fpn(
         **kwargs: Any,
 ) -> LineDetect:
     if num_classes is None:
-        num_classes = 2
+        num_classes = 3
     if num_points is None:
-        num_points = 2
+        num_points = 3

     backbone = resnet_fpn_backbone(backbone_name='resnet18', weights=None)
-    model = LineDetect(backbone, num_classes, num_keypoints=num_points, **kwargs)
+    model = LineDetect(backbone, num_classes, num_points=num_points, **kwargs)


     return model

+ 8 - 14
models/line_detect/loi_heads.py

@@ -191,15 +191,7 @@ def line_points_to_heatmap(keypoints, rois, heatmap_size):
     # type: (Tensor, Tensor, int) -> Tensor
     print(f'rois:{rois.shape}')
     print(f'heatmap_size:{heatmap_size}')
-    # offset_x = rois[:, 0]
-    # offset_y = rois[:, 1]
-    # scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
-    # scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
-    #
-    # offset_x = offset_x[:, None]
-    # offset_y = offset_y[:, None]
-    # scale_x = scale_x[:, None]
-    # scale_y = scale_y[:, None]
+

     print(f'keypoints.shape:{keypoints.shape}')
     # batch_size, num_keypoints, _ = keypoints.shape
@@ -1070,12 +1062,12 @@ class RoIHeads(nn.Module):
         return True

     def has_line(self):
-        if self.line_roi_pool is None:
-            return False
+        # if self.line_roi_pool is None:
+        #     return False
         if self.line_head is None:
             return False
-        if self.line_predictor is None:
-            return False
+        # if self.line_predictor is None:
+        #     return False
         return True

     def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
@@ -1351,12 +1343,14 @@ class RoIHeads(nn.Module):
             # line_features = self.line_roi_pool(features, line_proposals, image_shapes)

             # print(f'line_features from line_roi_pool:{line_features.shape}')
-
+            #(b,256,512,512)
             line_features = self.channel_compress(features['0'])
+            #(b,8,512,512)

             line_features = lines_features_align(line_features, line_proposals, image_shapes)

             line_features = self.line_head(line_features)
+            #(N,1,512,512)
             print(f'line_features from line_head:{line_features.shape}')
             # line_logits = self.line_predictor(line_features)

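The shape comments added above trace the line branch. A rough sketch of that flow, assuming channel_compress is a 1x1 convolution from 256 to 8 channels (its definition is not part of this diff):

import torch
from torch import nn

# assumption: a 1x1 conv standing in for self.channel_compress, which this diff does not show
channel_compress = nn.Conv2d(256, 8, kernel_size=1)

features_0 = torch.rand(2, 256, 64, 64)     # stands in for the (b, 256, 512, 512) FPN map '0'
compressed = channel_compress(features_0)   # -> (b, 8, 64, 64), matching the (b, 8, ...) comment

# lines_features_align then gathers a per-proposal map, and line_head reduces it to a
# single-channel heatmap, i.e. (N, 1, H, W) as noted above.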

+ 2 - 2
models/line_detect/train.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: train_results
-  datadir: /data/share/zyh/202507/a_dataset
+  datadir: \\192.168.50.222/share/rlq/datasets/0706_
   data_type: rgb
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
@@ -11,7 +11,7 @@ io:
 train_params:
   resume_from:
   num_workers: 8
-  batch_size: 2
+  batch_size: 1
   max_epoch: 80000
   augmentation: True
   optim:
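
The keys changed above can be checked with a plain YAML load (standard PyYAML here, independent of the project's read_yaml helper):

import yaml

with open("models/line_detect/train.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["io"]["datadir"])                # \\192.168.50.222/share/rlq/datasets/0706_
print(cfg["train_params"]["batch_size"])   # 1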

+ 1 - 1
models/line_detect/train_demo.py

@@ -16,6 +16,6 @@ if __name__ == '__main__':
     # model = lineDetect_resnet18_fpn()

     # model=linedetect_resnet18_fpn()
-    model=linedetect_newresnet18fpn()
+    model=linedetect_newresnet18fpn(num_points=2)

     model.start_train(cfg='train.yaml')

+ 1 - 1
models/line_detect/trainer.py

@@ -12,7 +12,7 @@ from libs.vision_libs.utils import draw_bounding_boxes, draw_keypoints
 from models.base.base_model import BaseModel
 from models.base.base_trainer import BaseTrainer
 from models.config.config_tool import read_yaml
-from models.line_detect.line_dataset_old import LineDataset
+from models.line_detect.line_dataset import LineDataset

 from models.line_net.dataset_LD import WirePointDataset
 from models.wirenet.postprocess import postprocess