
Add dynamic data augmentation

RenLiqiang 5 months ago
Parent
Current commit
6cfebd9e97

+ 9 - 3
libs/vision_libs/transforms/_functional_pil.py

@@ -109,10 +109,16 @@ def adjust_hue(img: Image.Image, hue_factor: float) -> Image.Image:
     h, s, v = img.convert("HSV").split()
 
     np_h = np.array(h, dtype=np.uint8)
-    # uint8 addition take cares of rotation across boundaries
+
+    # # uint8 addition take cares of rotation across boundaries
+    # with np.errstate(over="ignore"):
+    #     np_h += np.uint8(hue_factor * 255)
+    # h = Image.fromarray(np_h, "L")
+
+    # Use int16 to prevent overflow, then convert back to uint8
     with np.errstate(over="ignore"):
-        np_h += np.uint8(hue_factor * 255)
-    h = Image.fromarray(np_h, "L")
+        np_h = (np_h.astype(np.int16) + int(hue_factor * 255)) % 256
+    h = Image.fromarray(np_h.astype(np.uint8), "L")
 
     img = Image.merge("HSV", (h, s, v)).convert(input_mode)
     return img
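
The replacement widens the hue channel to int16 and wraps it with an explicit modulo instead of relying on uint8 overflow, and it no longer casts a possibly negative shift to uint8. A standalone sketch (not part of the commit) showing the two paths agree for a positive shift:

    import numpy as np

    np_h = np.array([10, 200, 250], dtype=np.uint8)
    shift = int(0.3 * 255)  # 76

    old_way = np_h + np.uint8(shift)                                    # wraps via uint8 overflow
    new_way = ((np_h.astype(np.int16) + shift) % 256).astype(np.uint8)  # explicit modulo-256 wrap

    assert np.array_equal(old_way, new_way)  # both give [86, 20, 70]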

+ 3 - 1
libs/vision_libs/transforms/functional.py

@@ -8,6 +8,7 @@ from typing import Any, List, Optional, Tuple, Union
 import numpy as np
 import torch
 from PIL import Image
+# from PIL.Image import Image
 from torch import Tensor
 
 try:
@@ -916,7 +917,7 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
     return F_t.adjust_saturation(img, saturation_factor)
 
 
-def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
+def adjust_hue(img: Tensor, hue_factor: float) ->  Tensor:
     """Adjust hue of an image.
 
     The image hue is adjusted by converting the image to HSV and
@@ -947,6 +948,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
     Returns:
         PIL Image or Tensor: Hue adjusted image.
     """
+    print(f'hue_factor:{hue_factor}')
     if not torch.jit.is_scripting() and not torch.jit.is_tracing():
         _log_api_usage_once(adjust_hue)
     if not isinstance(img, torch.Tensor):

+ 1 - 0
libs/vision_libs/transforms/transforms.py

@@ -1197,6 +1197,7 @@ class ColorJitter(torch.nn.Module):
     ) -> None:
         super().__init__()
         _log_api_usage_once(self)
+        # print(f'hue:{hue}')
         self.brightness = self._check_input(brightness, "brightness")
         self.contrast = self._check_input(contrast, "contrast")
         self.saturation = self._check_input(saturation, "saturation")

+ 0 - 16
models/base/high_reso_resnet.py

@@ -150,22 +150,6 @@ class ResNet(nn.Module):
         self.encoder1 = self._make_layer(block, 64, layers[0],stride=2)
         self.encoder2 = self._make_layer(block, 128, layers[1], stride=2)
         self.encoder3 = self._make_layer(block, 256, layers[2], stride=2)
-        # self.encoder4 = self._make_layer(block, 512, 3, stride=2)
-        # self.encoder5 = self._make_layer(block, 512, 3, stride=2)
-        # self.body = nn.ModuleDict({
-        #     'encoder0': self.encoder0,
-        #     'encoder1': self.encoder1,
-        #     'encoder2': self.encoder2,
-        #     'encoder3': self.encoder3,
-        #     'encoder4': self.encoder4
-        # })
-        # self.fpn = self.get_convnext_fpn(
-        #     backbone=self.body,
-        #     trainable_layers=5,
-        #     returned_layers=[0, 1, 2, 3, 4],
-        #     extra_blocks=None,
-        #     norm_layer=None
-        # )
 
 
 

+ 487 - 0
models/base/transforms.py

@@ -0,0 +1,487 @@
+import logging
+import random
+from typing import Any
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from torch import nn, Tensor
+
+from libs.vision_libs.transforms import functional as F
+
+from libs.vision_libs import transforms
+
+
+class Compose:
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, img, target):
+        for t in self.transforms:
+            img, target = t(img, target)
+
+
+        return img, target
+
+
+class RandomHorizontalFlip:
+    def __init__(self, prob=0.5):
+        self.prob = prob
+
+    def __call__(self, img, target):
+        if random.random() < self.prob:
+            width = img.width if isinstance(img, Image.Image) else img.shape[-1]
+
+            # Flip image
+            img = F.hflip(img)
+
+            # Flip boxes
+            boxes = target["boxes"]
+            x1, y1, x2, y2 = boxes.unbind(dim=1)
+            boxes_flipped = torch.stack((width - x2, y1, width - x1, y2), dim=1)
+            target["boxes"] = boxes_flipped
+
+            # Flip lines
+            if "lines" in target:
+                lines = target["lines"].clone()
+                # Only flip the x coordinate; y and visibility stay unchanged
+                lines[..., 0] = width - lines[..., 0]
+                target["lines"] = lines
+
+        return img, target
+
+class RandomVerticalFlip:
+    def __init__(self, prob=0.5):
+        self.prob = prob
+
+    def __call__(self, img, target):
+        if random.random() < self.prob:
+            height = img.height if isinstance(img, Image.Image) else img.shape[-2]
+
+            # Flip image
+            img = F.vflip(img)
+
+            # Flip boxes
+            boxes = target["boxes"]
+            x1, y1, x2, y2 = boxes.unbind(dim=1)
+            boxes_flipped = torch.stack((x1, height - y2, x2, height - y1), dim=1)
+            target["boxes"] = boxes_flipped
+
+            # Flip lines
+            if "lines" in target:
+                lines = target["lines"].clone()
+                lines[..., 1] = height - lines[..., 1]
+                target["lines"] = lines
+
+        return img, target
+
+
+class ColorJitter:
+    def __init__(self, brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2):
+        if not (0 <= hue <= 0.5):
+            raise ValueError(f"Hue jitter value should be in [0, 0.5], but got {hue}")
+
+        self.color_jitter = transforms.ColorJitter(
+            brightness=brightness,
+            contrast=contrast,
+            saturation=saturation,
+            hue=hue
+        )
+
+    def __call__(self, img, target):
+        print(f"Original image type: {type(img)}")
+        img = self.color_jitter(img)
+        print("Color jitter applied successfully.")
+        return img, target
+
+
+class RandomGrayscale:
+    def __init__(self, p=0.1):
+        self.p = p
+
+    def __call__(self, img, target):
+        print(f"RandomGrayscale Original image type: {type(img)}")
+        if random.random() < self.p:
+            img = F.to_grayscale(img, num_output_channels=3)
+        return img, target
+
+
+class RandomResize:
+    def __init__(self, min_size, max_size=None):
+        self.min_size = min_size
+        self.max_size = max_size
+
+    def __call__(self, img, target):
+        size = random.randint(self.min_size, self.max_size) if self.max_size else self.min_size
+        w, h = img.size if isinstance(img, Image.Image) else (img.shape[2], img.shape[1])
+        scale = size / min(h, w)
+        new_h, new_w = int(scale * h), int(scale * w)
+        img = F.resize(img, (new_h, new_w))
+
+        # Update boxes
+        boxes = target["boxes"]
+        boxes = boxes * scale
+        target["boxes"] = boxes
+
+        # Update lines
+        if "lines" in target:
+            target["lines"] = target["lines"] * torch.tensor([scale, scale, 1], device=target["lines"].device)
+
+        return img, target
+
+
+class RandomCrop:
+    def __init__(self, size):
+        self.size = size
+
+    def __call__(self, img, target):
+        w, h = img.size if isinstance(img, Image.Image) else (img.shape[2], img.shape[1])
+        th, tw = self.size
+
+        if h <= th or w <= tw:
+            return img, target
+
+        i = random.randint(0, h - th)
+        j = random.randint(0, w - tw)
+
+        img = F.crop(img, i, j, th, tw)
+
+        # Adjust boxes
+        boxes = target["boxes"]
+        boxes = boxes - torch.tensor([j, i, j, i], device=boxes.device)
+        boxes = torch.clamp(boxes, min=0)
+        xmax, ymax = tw, th
+        boxes[:, [0, 2]] = boxes[:, [0, 2]].clamp(max=xmax)
+        boxes[:, [1, 3]] = boxes[:, [1, 3]].clamp(max=ymax)
+        target["boxes"] = boxes
+
+        # Adjust lines
+        if "lines" in target:
+            lines = target["lines"].clone()
+            lines[..., 0] -= j
+            lines[..., 1] -= i
+            lines = torch.clamp(lines, min=0)
+            lines[..., 0] = torch.clamp(lines[..., 0], max=tw)
+            lines[..., 1] = torch.clamp(lines[..., 1], max=th)
+            target["lines"] = lines
+
+        return img, target
+
+
+class GaussianBlur:
+    def __init__(self, kernel_size=5, sigma=(0.1, 2.0), prob=0.2):
+        self.kernel_size = kernel_size if kernel_size % 2 == 1 else kernel_size + 1  # Ensure kernel size is odd
+        self.sigma = sigma
+        self.prob = prob
+
+    def __call__(self, img, target):
+        if random.random() < self.prob:
+            # Convert PIL Image to Tensor if necessary
+            if isinstance(img, Image.Image):
+                img = transforms.ToTensor()(img)
+
+            # Apply Gaussian blur using PyTorch's functional interface
+            img = transforms.GaussianBlur(kernel_size=self.kernel_size, sigma=random.uniform(*self.sigma))(img)
+
+            # If the original image was a PIL Image, convert it back
+            if isinstance(img, Tensor) and not isinstance(target.get('original_image_format', None), Tensor):
+                img = transforms.ToPILImage()(img)
+
+        return img, target
+class RandomRotation:
+    def __init__(self, degrees=15, prob=0.5):
+        self.degrees = degrees
+        self.prob = prob
+
+    def rotate_boxes(self, boxes, angle, center):
+        # Convert to numpy for easier rotation math
+        boxes_np = boxes.cpu().numpy()
+        center_np = np.array(center)
+
+        corners = np.array([
+            [boxes_np[:, 0], boxes_np[:, 1]],  # top-left
+            [boxes_np[:, 2], boxes_np[:, 1]],  # top-right
+            [boxes_np[:, 2], boxes_np[:, 3]],  # bottom-right
+            [boxes_np[:, 0], boxes_np[:, 3]]  # bottom-left
+        ]).transpose(2, 0, 1)  # shape: (N, 4, 2)
+
+        # Translate to origin
+        corners -= center_np
+
+        # Rotate points
+        theta = np.radians(angle)
+        c, s = np.cos(theta), np.sin(theta)
+        R = np.array([[c, -s], [s, c]])
+        rotated_corners = corners @ R
+
+        # Translate back
+        rotated_corners += center_np
+
+        # Get new bounding box coordinates
+        x_min = np.min(rotated_corners[:, :, 0], axis=1)
+        y_min = np.min(rotated_corners[:, :, 1], axis=1)
+        x_max = np.max(rotated_corners[:, :, 0], axis=1)
+        y_max = np.max(rotated_corners[:, :, 1], axis=1)
+
+        # Convert back to tensor and move to the same device
+        device = boxes.device
+        return torch.tensor(np.stack([x_min, y_min, x_max, y_max], axis=1), dtype=boxes.dtype, device=device)
+
+    def rotate_lines(self, lines, angle, center):
+        coords = lines[..., :2]  # shape: (..., 2)
+        visibility = lines[..., 2:]  # shape: (..., N)
+
+        # Translate to origin
+        coords = coords - torch.tensor(center, dtype=coords.dtype, device=coords.device)
+
+        # Rotation matrix
+        theta = torch.deg2rad(torch.tensor(angle))
+        cos_t = torch.cos(theta)
+        sin_t = torch.sin(theta)
+        R = torch.tensor([[cos_t, -sin_t], [sin_t, cos_t]], dtype=coords.dtype, device=coords.device)
+
+        # Apply rotation using torch.matmul
+        rotated_coords = torch.matmul(coords, R)
+
+        # Translate back
+        rotated_coords = rotated_coords + torch.tensor(center, dtype=coords.dtype, device=coords.device)
+
+        # Concatenate with visibility
+        rotated_lines = torch.cat([rotated_coords, visibility], dim=-1)
+        return rotated_lines
+
+    def __call__(self, img, target):
+        if random.random() < self.prob:
+            angle = random.uniform(-self.degrees, self.degrees)
+            w, h = img.size if isinstance(img, Image.Image) else (img.shape[2], img.shape[1])
+            center = (w / 2, h / 2)
+
+            # Rotate image
+            img = F.rotate(img, angle, center=center)
+
+            # Rotate boxes
+            if "boxes" in target:
+                target["boxes"] = self.rotate_boxes(target["boxes"], angle, center)
+
+            # Rotate lines
+            if "lines" in target:
+                target["lines"] = self.rotate_lines(target["lines"], angle, center)
+
+        return img, target
+
+
+class RandomErasing:
+    def __init__(self, prob=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[0.485, 0.456, 0.406]):
+        """
+        :param prob: probability of applying the erasing
+        :param sl: lower bound of the erased area ratio
+        :param sh: upper bound of the erased area ratio
+        :param r1: lower bound of the aspect ratio
+        :param mean: pixel values used to fill the erased region
+        """
+        self.prob = prob
+        self.sl = sl
+        self.sh = sh
+        self.r1 = r1
+        self.mean = mean
+
+    def __call__(self, img, target):
+        if random.random() < self.prob:
+            # If it is a Tensor, process it directly
+            if isinstance(img, torch.Tensor):
+                img = self._erase_tensor(img)
+            # If it is a PIL Image, convert it to a Tensor, process it, then convert back to a PIL Image
+            elif isinstance(img, Image.Image):
+                img_tensor = transforms.ToTensor()(img)
+                img_tensor = self._erase_tensor(img_tensor)
+                img = transforms.ToPILImage()(img_tensor)
+
+        return img, target
+
+    def _erase_tensor(self, img_tensor):
+        """
+        Perform random erasing on a Tensor image
+        """
+        img_c, img_h, img_w = img_tensor.shape
+        area = img_h * img_w
+
+        # Compute the size of the erased region
+        erase_area = random.uniform(self.sl, self.sh) * area
+        aspect_ratio = random.uniform(self.r1, 1 / self.r1)
+
+        h = int(round((erase_area * aspect_ratio) ** 0.5))
+        w = int(round((erase_area / aspect_ratio) ** 0.5))
+
+        # Make sure the region stays within the image bounds
+        if h < img_h and w < img_w:
+            x = random.randint(0, img_w - w)
+            y = random.randint(0, img_h - h)
+
+            # Create a region of the same size as the erased area, filled with the given mean values
+            mean_tensor = torch.tensor(self.mean).view(img_c, 1, 1).expand(img_c, h, w)
+
+            # Apply the region to the original image
+            img_tensor[:, y:y + h, x:x + w] = mean_tensor
+
+        return img_tensor
+
+"""
+Has bugs
+"""
+class RandomPerspective:
+    def __init__(self, distortion_scale=0.5, p=0.5):
+        self.distortion_scale = distortion_scale
+        self.p = p
+
+    def _get_perspective_params(self, width, height, distortion_scale):
+        half_w = width // 2
+        half_h = height // 2
+        w = int(width * distortion_scale)
+        h = int(height * distortion_scale)
+
+        startpoints = [
+            [0, 0],
+            [width - 1, 0],
+            [width - 1, height - 1],
+            [0, height - 1]
+        ]
+        endpoints = [
+            [random.randint(0, w), random.randint(0, h)],
+            [width - 1 - random.randint(0, w), random.randint(0, h)],
+            [width - 1 - random.randint(0, w), height - 1 - random.randint(0, h)],
+            [random.randint(0, w), height - 1 - random.randint(0, h)]
+        ]
+        return startpoints, endpoints
+
+    def perspective_boxes(self, boxes, M, width, height):
+        # Convert boxes to corner-point form
+        corners = np.array([
+            [boxes[:, 0], boxes[:, 1]],  # top-left
+            [boxes[:, 2], boxes[:, 1]],  # top-right
+            [boxes[:, 2], boxes[:, 3]],  # bottom-right
+            [boxes[:, 0], boxes[:, 3]]   # bottom-left
+        ]).transpose(2, 0, 1).reshape(-1, 2)  # shape: (N*4, 2)
+
+        # Apply the perspective transform
+        ones = np.ones((corners.shape[0], 1))
+        coords_homogeneous = np.hstack([corners, ones])
+        transformed_coords = (M @ coords_homogeneous.T).T
+        transformed_coords /= transformed_coords[:, 2].reshape(-1, 1)  # homogeneous division
+        transformed_coords = transformed_coords[:, :2]
+
+        # Reassemble into bounding boxes
+        transformed_coords = transformed_coords.reshape(-1, 4, 2)
+        x_min = np.min(transformed_coords[:, :, 0], axis=1)
+        y_min = np.min(transformed_coords[:, :, 1], axis=1)
+        x_max = np.max(transformed_coords[:, :, 0], axis=1)
+        y_max = np.max(transformed_coords[:, :, 1], axis=1)
+
+        # Clip to the image bounds
+        x_min = np.clip(x_min, 0, width)
+        y_min = np.clip(y_min, 0, height)
+        x_max = np.clip(x_max, 0, width)
+        y_max = np.clip(y_max, 0, height)
+
+        return torch.tensor(np.stack([x_min, y_min, x_max, y_max], axis=1), dtype=boxes.dtype, device=boxes.device)
+
+    def perspective_lines(self, lines, M, width, height):
+        # Extract coordinates and visibility flags
+        coords = lines[..., :2].cpu().numpy()  # Shape: (N, L, 2)
+        visibility = lines[..., 2:]
+
+        # Make sure coords is 2-D; if it is 3-D, reshape it to 2-D
+        original_shape = coords.shape
+        coords_reshaped = coords.reshape(-1, 2)  # Reshape to (N*L, 2)
+
+        # Add homogeneous coordinates
+        ones = np.ones((coords_reshaped.shape[0], 1))
+        coords_homogeneous = np.hstack([coords_reshaped, ones])  # Shape: (N*L, 3)
+
+        # Apply the perspective transform matrix
+        transformed_coords_homogeneous = np.dot(M, coords_homogeneous.T).T
+        transformed_coords = transformed_coords_homogeneous[:, :2] / transformed_coords_homogeneous[:, 2:]  # normalize
+
+        # Restore the transformed coordinates to the original shape
+        transformed_coords = transformed_coords.reshape(original_shape)  # Reshape back to (N, L, 2)
+
+        # Clip to the image bounds
+        transformed_coords = np.clip(transformed_coords, [0, 0], [width, height])
+
+        # Convert back to a tensor
+        transformed_coords = torch.tensor(transformed_coords, dtype=lines.dtype, device=lines.device)
+        return torch.cat([transformed_coords, visibility], dim=-1)
+
+    def __call__(self, img, target):
+        if random.random() < self.p:
+            width, height = img.size if isinstance(img, Image.Image) else (img.shape[2], img.shape[1])
+            startpoints, endpoints = self._get_perspective_params(width, height, self.distortion_scale)
+
+            # Compute the perspective transform matrix with OpenCV
+            M = cv2.getPerspectiveTransform(
+                np.float32(startpoints),
+                np.float32(endpoints)
+            )
+
+            # Apply the perspective transform to the image
+            if isinstance(img, Image.Image):
+                img = img.transform((width, height), Image.PERSPECTIVE, M.flatten(), resample=Image.BILINEAR)
+            elif isinstance(img, torch.Tensor):
+                # A TorchVision-based path could use F.perspective, but working through PIL is preferred here
+                pil_img = F.to_pil_image(img)
+                pil_img = pil_img.transform((width, height), Image.PERSPECTIVE, M.flatten(), resample=Image.BILINEAR)
+                img = F.to_tensor(pil_img)
+
+            # Transform the boxes
+            if "boxes" in target:
+                target["boxes"] = self.perspective_boxes(target["boxes"], M, width, height)
+
+            # Transform the lines
+            if "lines" in target:
+                target["lines"] = self.perspective_lines(target["lines"], M, width, height)
+
+        return img, target
+
+class DefaultTransform(nn.Module):
+    def forward(self, img: Tensor,target) -> tuple[Tensor, Any]:
+        if not isinstance(img, Tensor):
+            img = F.pil_to_tensor(img)
+        return F.convert_image_dtype(img, torch.float),target
+
+    def __repr__(self) -> str:
+        return self.__class__.__name__ + "()"
+
+    def describe(self) -> str:
+        return (
+            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
+            "The images are rescaled to ``[0.0, 1.0]``."
+        )
+
+
+class ToTensor:
+    def __call__(self, img, target):
+        img = F.to_tensor(img)
+        return img, target
+
+
+def get_transforms(augmention=True):
+    transforms_list = []
+
+    if augmention:
+
+        transforms_list.append(ColorJitter())
+        transforms_list.append(RandomGrayscale(0.1))
+
+        transforms_list.append(GaussianBlur())
+        transforms_list.append(RandomErasing())
+        transforms_list.append(RandomHorizontalFlip(0.5))
+        transforms_list.append(RandomVerticalFlip(0.2))
+        # transforms_list.append(RandomPerspective())
+        transforms_list.append(RandomRotation(degrees=15))
+        transforms_list.append(RandomResize(512, 2048))
+
+        transforms_list.append(RandomCrop((512,512)))
+
+    transforms_list.append(DefaultTransform())
+
+    return Compose(transforms_list)
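
get_transforms wires these classes into a Compose pipeline and always appends DefaultTransform, so the output image is a float tensor in [0, 1]. A minimal usage sketch (not from the repo), assuming it is run from the repository root and that target carries "boxes" as an (N, 4) xyxy float tensor and "lines" as an (N, 2, 3) tensor of (x, y, visibility), matching what line_dataset.py builds:

    import torch
    from PIL import Image
    from models.base.transforms import get_transforms

    img = Image.new("RGB", (1024, 768), color=(128, 128, 128))
    target = {
        "boxes": torch.tensor([[100.0, 150.0, 400.0, 300.0]]),                # one xyxy box
        "lines": torch.tensor([[[120.0, 160.0, 1.0], [380.0, 290.0, 1.0]]]),  # one line, two (x, y, visibility) endpoints
    }

    pipeline = get_transforms(augmention=True)  # note: the keyword is spelled "augmention" in this commit
    img_t, target_t = pipeline(img, target)

    print(img_t.shape, img_t.dtype)             # float32 CHW tensor; size depends on the random resize/crop
    print(target_t["boxes"], target_t["lines"].shape)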

+ 8 - 72
models/line_detect/line_dataset.py

@@ -3,27 +3,17 @@ from torch.utils.data.dataset import T_co
 from libs.vision_libs.utils import draw_keypoints
 from models.base.base_dataset import BaseDataset
 
-import glob
 import json
-import math
 import os
-import random
-import cv2
 import PIL
-import imageio
-import matplotlib.pyplot as plt
 import matplotlib as mpl
 from torchvision.utils import draw_bounding_boxes
 import torchvision.transforms.v2 as transforms
-import numpy as np
-import numpy.linalg as LA
 import torch
-from skimage import io
-from torch.utils.data import Dataset
-from torch.utils.data.dataloader import default_collate
 
 import matplotlib.pyplot as plt
-from models.dataset_tool import read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
+from models.base.transforms import get_transforms
+
 
 def validate_keypoints(keypoints, image_width, image_height):
     for kp in keypoints:
@@ -32,58 +22,6 @@ def validate_keypoints(keypoints, image_width, image_height):
             raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
 
 
-def apply_transform_with_boxes_and_keypoints(img,target):
-    """
-    Apply the same transforms to the image, bounding boxes and keypoints.
-
-    :param img_path: path to the image file
-    :param boxes: Tensor of shape (N, 4) with the coordinates of N bounding boxes [x_min, y_min, x_max, y_max]
-    :param keypoints: Tensor of shape (N, K, 3) with the coordinates and visibility of the K keypoints of N instances [x, y, visibility]
-    :return: transformed image, bounding boxes and keypoints
-    """
-
-
-    # Define a series of transforms for data augmentation
-    data_transforms = transforms.Compose([
-        # Random resize and random crop
-        # transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0), antialias=True),
-
-        # Random horizontal flip
-        transforms.RandomHorizontalFlip(p=0.5),
-
-        # Color jitter: change brightness, contrast, saturation and hue
-        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
-
-        # # Convert to tensor
-        # transforms.ToTensor(),
-        #
-        # # Normalize
-        # transforms.Normalize(mean=[0.485, 0.456, 0.406],
-        #                      std=[0.229, 0.224, 0.225])
-    ])
-
-    boxes=target['boxes']
-    keypoints=target['lines']
-    # Convert boxes into a format suitable for passing to transforms
-    boxes_format = [(box[0].item(), box[1].item(), box[2].item(), box[3].item()) for box in boxes]
-
-    # Convert keypoints into a format suitable for passing to transforms
-    keypoints_format = [[(kp[0].item(), kp[1].item(), bool(kp[2].item())) for kp in keypoint] for keypoint in keypoints]
-
-    # Apply the transforms
-    transformed = data_transforms(img, {"boxes": boxes_format, "keypoints": keypoints_format})
-
-    # Get the transformed image, boxes and keypoints
-    img_transformed = transformed[0]
-    boxes_transformed = torch.tensor([(box[0], box[1], box[2], box[3]) for box in transformed[1]['boxes']],
-                                     dtype=torch.float32)
-    keypoints_transformed = torch.tensor(
-        [[(kp[0], kp[1], int(kp[2])) for kp in keypoint] for keypoint in transformed[1]['keypoints']],
-        dtype=torch.float32)
-
-    target['boxes']=boxes_transformed
-    target['lines']=keypoints_transformed
-    return img_transformed, target
 
 """
 Directly reads the dataset in xanlabel-annotated JSON format
@@ -114,16 +52,13 @@ class LineDataset(BaseDataset):
         w, h = img.size
         # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
         target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
-        if self.transforms:
-            img, target = self.transforms(img, target)
 
-        else:
-            img = self.default_transform(img)
 
-        # print(f'img:{img}')
-        # print(f'img shape:{img.shape}')
-        if self.augmentation:
-            img, target=apply_transform_with_boxes_and_keypoints(img, target)
+        self.transforms=get_transforms(augmention=self.augmentation)
+
+        img, target = self.transforms(img, target)
+
+
         return img, target
 
     def __len__(self):
@@ -154,6 +89,7 @@ class LineDataset(BaseDataset):
         lines = torch.cat((lines, a), dim=1)
 
         target["lines"] = lines.to(torch.float32).view(-1,2,3)
+        print(f'lines:{target["lines"].shape}')
         target["img_size"]=shape
 
         validate_keypoints(lines, shape[0], shape[1])

+ 14 - 54
models/line_detect/line_detect.py

@@ -50,7 +50,7 @@ class LineDetect(BaseDetectionNet):
     def __init__(
             self,
             backbone,
-            num_classes=None,
+            num_classes=2,
             # transform parameters
             min_size=512,
             max_size=2048,
@@ -85,7 +85,7 @@ class LineDetect(BaseDetectionNet):
             line_roi_pool=None,
             line_head=None,
             line_predictor=None,
-            num_keypoints=None,
+            num_points=3,
             **kwargs,
     ):
 
@@ -149,30 +149,13 @@ class LineDetect(BaseDetectionNet):
 
 
 
-        if not isinstance(line_roi_pool, (MultiScaleRoIAlign, type(None))):
-            raise TypeError(
-                "keypoint_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(keypoint_roi_pool)}"
-            )
-        if min_size is None:
-            min_size = (640, 672, 704, 736, 768, 800)
-
-        if num_keypoints is not None:
-            if line_predictor is not None:
-                raise ValueError("num_keypoints should be None when keypoint_predictor is specified")
-        else:
-            num_keypoints = 2
-
-
-        if line_roi_pool is None:
-            line_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
-
         if line_head is None:
             keypoint_layers = tuple(1 for _ in range(8))
             line_head = LineHeads(8, keypoint_layers)
 
-        if line_predictor is None:
-            keypoint_dim_reduced = 512  # == keypoint_layers[-1]
-            line_predictor = LinePredictor(keypoint_dim_reduced)
+        # if line_predictor is None:
+        #     keypoint_dim_reduced = 512  # == keypoint_layers[-1]
+        #     line_predictor = LinePredictor(keypoint_dim_reduced)
 
 
         self.roi_heads.line_roi_pool = line_roi_pool
@@ -303,30 +286,7 @@ class LineHeads(nn.Sequential):
                 nn.init.constant_(m.bias, 0)
 
 
-class LinePredictor(nn.Module):
-    def __init__(self, in_channels, out_channels=1 ):
-        super().__init__()
-        input_features = in_channels
-        deconv_kernel = 4
-        self.kps_score_lowres = nn.ConvTranspose2d(
-            input_features,
-            out_channels,
-            deconv_kernel,
-            stride=2,
-            padding=deconv_kernel // 2 - 1,
-        )
-        nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
-        nn.init.constant_(self.kps_score_lowres.bias, 0)
-        self.up_scale = 2
-        self.out_channels = out_channels
 
-    def forward(self, x):
-        print(f'before kps_score_lowres x:{x.shape}')
-        x = self.kps_score_lowres(x)
-        print(f'kps_score_lowres x:{x.shape}')
-        return torch.nn.functional.interpolate(
-            x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
-        )
 
 def linedetect_newresnet18fpn(
         *,
@@ -339,9 +299,9 @@ def linedetect_newresnet18fpn(
     # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
     # weights_backbone = ResNet50_Weights.verify(weights_backbone)
     if num_classes is None:
-        num_classes = 2
+        num_classes = 3
     if num_points is None:
-        num_points = 2
+        num_points = 3
 
 
     backbone =resnet18fpn()
@@ -361,7 +321,7 @@ def linedetect_newresnet18fpn(
 
     anchor_generator =  AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)
 
-    model = LineDetect(backbone, num_classes, num_keypoints=num_points,rpn_anchor_generator=anchor_generator,box_roi_pool=roi_pooler, **kwargs)
+    model = LineDetect(backbone, num_classes, num_points=num_points, rpn_anchor_generator=anchor_generator, box_roi_pool=roi_pooler, **kwargs)
 
     return model
 
@@ -375,12 +335,12 @@ def linedetect_resnet18_fpn(
 ) -> LineDetect:
 
     if num_classes is None:
-        num_classes = 2
+        num_classes = 3
     if num_points is None:
-        num_points = 2
+        num_points = 3
 
     backbone = resnet_fpn_backbone(backbone_name='resnet18',weights=None)
-    model = LineDetect(backbone, num_classes, num_keypoints=num_points, **kwargs)
+    model = LineDetect(backbone, num_classes, num_points=num_points, **kwargs)
 
     return model
 
@@ -391,12 +351,12 @@ def linedetect_resnet50_fpn(
         **kwargs: Any,
 ) -> LineDetect:
     if num_classes is None:
-        num_classes = 2
+        num_classes = 3
     if num_points is None:
-        num_points = 2
+        num_points = 3
 
     backbone = resnet_fpn_backbone(backbone_name='resnet18', weights=None)
-    model = LineDetect(backbone, num_classes, num_keypoints=num_points, **kwargs)
+    model = LineDetect(backbone, num_classes, num_points=num_points, **kwargs)
 
 
     return model

+ 8 - 14
models/line_detect/loi_heads.py

@@ -191,15 +191,7 @@ def line_points_to_heatmap(keypoints, rois, heatmap_size):
     # type: (Tensor, Tensor, int) -> Tensor
     print(f'rois:{rois.shape}')
     print(f'heatmap_size:{heatmap_size}')
-    # offset_x = rois[:, 0]
-    # offset_y = rois[:, 1]
-    # scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
-    # scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
-    #
-    # offset_x = offset_x[:, None]
-    # offset_y = offset_y[:, None]
-    # scale_x = scale_x[:, None]
-    # scale_y = scale_y[:, None]
+
 
     print(f'keypoints.shape:{keypoints.shape}')
     # batch_size, num_keypoints, _ = keypoints.shape
@@ -1070,12 +1062,12 @@ class RoIHeads(nn.Module):
         return True
 
     def has_line(self):
-        if self.line_roi_pool is None:
-            return False
+        # if self.line_roi_pool is None:
+        #     return False
         if self.line_head is None:
             return False
-        if self.line_predictor is None:
-            return False
+        # if self.line_predictor is None:
+        #     return False
         return True
 
     def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
@@ -1351,12 +1343,14 @@ class RoIHeads(nn.Module):
             # line_features = self.line_roi_pool(features, line_proposals, image_shapes)
 
             # print(f'line_features from line_roi_pool:{line_features.shape}')
-
+            # (b, 256, 512, 512)
             line_features = self.channel_compress(features['0'])
+            # (b, 8, 512, 512)
 
             line_features = lines_features_align(line_features, line_proposals, image_shapes)
 
             line_features = self.line_head(line_features)
+            # (N, 1, 512, 512)
             print(f'line_features from line_head:{line_features.shape}')
             # line_logits = self.line_predictor(line_features)
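
The shape comments above imply that channel_compress reduces the 256-channel FPN map features['0'] to the 8 channels that LineHeads(8, ...) expects (see line_detect.py in this commit), but its definition is not part of the diff. A purely hypothetical sketch of one plausible choice, a 1x1 convolution:

    import torch
    from torch import nn

    # Hypothetical stand-in for self.channel_compress (not shown in this diff):
    # a 1x1 conv mapping 256 FPN channels to the 8 channels LineHeads consumes.
    channel_compress = nn.Conv2d(256, 8, kernel_size=1)

    feat = torch.randn(1, 256, 64, 64)   # stand-in for features['0'] (the diff's comments show 512x512 maps)
    print(channel_compress(feat).shape)  # torch.Size([1, 8, 64, 64])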
 

+ 2 - 2
models/line_detect/train.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: train_results
-  datadir: /data/share/zyh/202507/a_dataset
+  datadir: \\192.168.50.222/share/rlq/datasets/0706_
   data_type: rgb
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
@@ -11,7 +11,7 @@ io:
 train_params:
   resume_from:
   num_workers: 8
-  batch_size: 2
+  batch_size: 1
   max_epoch: 80000
   augmentation: True
   optim:

+ 1 - 1
models/line_detect/train_demo.py

@@ -16,6 +16,6 @@ if __name__ == '__main__':
     # model = lineDetect_resnet18_fpn()
 
     # model=linedetect_resnet18_fpn()
-    model=linedetect_newresnet18fpn()
+    model=linedetect_newresnet18fpn(num_points=2)
 
     model.start_train(cfg='train.yaml')

+ 1 - 1
models/line_detect/trainer.py

@@ -12,7 +12,7 @@ from libs.vision_libs.utils import draw_bounding_boxes, draw_keypoints
 from models.base.base_model import BaseModel
 from models.base.base_trainer import BaseTrainer
 from models.config.config_tool import read_yaml
-from models.line_detect.line_dataset_old import LineDataset
+from models.line_detect.line_dataset import LineDataset
 
 from models.line_net.dataset_LD import WirePointDataset
 from models.wirenet.postprocess import postprocess