| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521 |
- import logging
- import random
- from typing import Any,Tuple
- import cv2
- import numpy as np
- import torch
- from PIL import Image
- from torch import nn, Tensor
- from libs.vision_libs .transforms import functional as F
- from libs.vision_libs import transforms
- class Compose:
- def __init__(self, transforms):
- self.transforms = transforms
- def __call__(self, img, target):
- for t in self.transforms:
- img, target = t(img, target)
- return img, target
- class RandomHorizontalFlip:
- def __init__(self, prob=0.5):
- self.prob = prob
- def __call__(self, img, target):
- if random.random() < self.prob:
- width = img.width if isinstance(img, Image.Image) else img.shape[-1]
- # Flip image
- img = F.hflip(img)
- # Flip boxes
- boxes = target["boxes"]
- x1, y1, x2, y2 = boxes.unbind(dim=1)
- boxes_flipped = torch.stack((width - x2, y1, width - x1, y2), dim=1)
- target["boxes"] = boxes_flipped
- # Flip lines
- if "lines" in target:
- lines = target["lines"].clone()
- # 只翻转 x 坐标,y 和 visibility 不变
- lines[..., 0] = width - lines[..., 0]
- target["lines"] = lines
- # Flip lines
- if "circle_masks" in target:
- lines = target["circle_masks"].clone()
- # 只翻转 x 坐标,y 和 visibility 不变
- lines[..., 0] = width - lines[..., 0]
- target["circle_masks"] = lines
- return img, target
- class RandomVerticalFlip:
- def __init__(self, prob=0.5):
- self.prob = prob
- def __call__(self, img, target):
- if random.random() < self.prob:
- height = img.height if isinstance(img, Image.Image) else img.shape[-2]
- # Flip image
- img = F.vflip(img)
- # Flip boxes
- boxes = target["boxes"]
- x1, y1, x2, y2 = boxes.unbind(dim=1)
- boxes_flipped = torch.stack((x1, height - y2, x2, height - y1), dim=1)
- target["boxes"] = boxes_flipped
- # Flip lines
- if "lines" in target:
- lines = target["lines"].clone()
- lines[..., 1] = height - lines[..., 1]
- target["lines"] = lines
- if "circle_masks" in target:
- lines = target["circle_masks"].clone()
- lines[..., 1] = height - lines[..., 1]
- target["circle_masks"] = lines
- return img, target
- class ColorJitter:
- def __init__(self, brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2):
- if not (0 <= hue <= 0.5):
- raise ValueError(f"Hue jitter value should be in [0, 0.5], but got {hue}")
- self.color_jitter = transforms.ColorJitter(
- brightness=brightness,
- contrast=contrast,
- saturation=saturation,
- hue=hue
- )
- def __call__(self, img, target):
- print(f"Original image type: {type(img)}")
- img = self.color_jitter(img)
- print("Color jitter applied successfully.")
- return img, target
- class RandomGrayscale:
- def __init__(self, p=0.1):
- self.p = p
- def __call__(self, img, target):
- print(f"RandomGrayscale Original image type: {type(img)}")
- if random.random() < self.p:
- img = F.to_grayscale(img, num_output_channels=3)
- return img, target
- class RandomResize:
- def __init__(self, min_size, max_size=None):
- self.min_size = min_size
- self.max_size = max_size
- def __call__(self, img, target):
- size = random.randint(self.min_size, self.max_size) if self.max_size else self.min_size
- w, h = img.size if isinstance(img, Image.Image) else (img.shape[2], img.shape[1])
- scale = size / min(h, w)
- new_h, new_w = int(scale * h), int(scale * w)
- img = F.resize(img, (new_h, new_w))
- # Update boxes
- boxes = target["boxes"]
- boxes = boxes * scale
- target["boxes"] = boxes
- # Update lines
- if "lines" in target:
- target["lines"] = target["lines"] * torch.tensor([scale, scale, 1], device=target["lines"].device)
- if "circle_masks" in target:
- target["circle_masks"] = target["circle_masks"] * torch.tensor([scale, scale, 1], device=target["circle_masks"].device)
- return img, target
- class RandomCrop:
- def __init__(self, size):
- self.size = size
- def __call__(self, img, target):
- width, height = F.get_image_size(img)
- crop_height, crop_width = self.size
- # 随机选择裁剪区域
- left = random.randint(0, max(width - crop_width, 0))
- top = random.randint(0, max(height - crop_height, 0))
- right = min(left + crop_width, width)
- bottom = min(top + crop_height, height)
- # 裁剪图像
- img = F.crop(img, top, left, bottom - top, right - left)
- if "boxes" in target:
- boxes = target["boxes"]
- labels = target["labels"] if "labels" in target else None
- # 将bounding boxes转换到裁剪区域坐标系
- cropped_boxes = boxes.clone()
- cropped_boxes[:, 0::2] -= left
- cropped_boxes[:, 1::2] -= top
- # 确保bounding boxes在裁剪区域内
- cropped_boxes[:, 0::2].clamp_(min=0, max=crop_width)
- cropped_boxes[:, 1::2].clamp_(min=0, max=crop_height)
- # 计算新的宽高
- w = cropped_boxes[:, 2] - cropped_boxes[:, 0]
- h = cropped_boxes[:, 3] - cropped_boxes[:, 1]
- # 过滤掉无效的bounding boxes(宽度或高度为0)
- valid_boxes_mask = (w > 0) & (h > 0)
- # 更新有效bounding boxes
- cropped_boxes = cropped_boxes[valid_boxes_mask]
- if labels is not None:
- labels = labels[valid_boxes_mask]
- # 更新target
- target["boxes"] = cropped_boxes
- if labels is not None:
- target["labels"] = labels
- return img, target
- class GaussianBlur:
- def __init__(self, kernel_size=5, sigma=(0.1, 2.0), prob=0.2):
- self.kernel_size = kernel_size if kernel_size % 2 == 1 else kernel_size + 1 # Ensure kernel size is odd
- self.sigma = sigma
- self.prob = prob
- def __call__(self, img, target):
- if random.random() < self.prob:
- # Convert PIL Image to Tensor if necessary
- if isinstance(img, Image.Image):
- img = transforms.ToTensor()(img)
- # Apply Gaussian blur using PyTorch's functional interface
- img = transforms.GaussianBlur(kernel_size=self.kernel_size, sigma=random.uniform(*self.sigma))(img)
- # If the original image was a PIL Image, convert it back
- if isinstance(img, Tensor) and not isinstance(target.get('original_image_format', None), Tensor):
- img = transforms.ToPILImage()(img)
- return img, target
- class RandomRotation:
- def __init__(self, degrees=15, prob=0.5):
- self.degrees = degrees
- self.prob = prob
- def rotate_boxes(self, boxes, angle, center):
- # Convert to numpy for easier rotation math
- boxes_np = boxes.cpu().numpy()
- center_np = np.array(center)
- corners = np.array([
- [boxes_np[:, 0], boxes_np[:, 1]], # top-left
- [boxes_np[:, 2], boxes_np[:, 1]], # top-right
- [boxes_np[:, 2], boxes_np[:, 3]], # bottom-right
- [boxes_np[:, 0], boxes_np[:, 3]] # bottom-left
- ]).transpose(2, 0, 1) # shape: (N, 4, 2)
- # Translate to origin
- corners -= center_np
- # Rotate points
- theta = np.radians(angle)
- c, s = np.cos(theta), np.sin(theta)
- R = np.array([[c, -s], [s, c]])
- rotated_corners = corners @ R
- # Translate back
- rotated_corners += center_np
- # Get new bounding box coordinates
- x_min = np.min(rotated_corners[:, :, 0], axis=1)
- y_min = np.min(rotated_corners[:, :, 1], axis=1)
- x_max = np.max(rotated_corners[:, :, 0], axis=1)
- y_max = np.max(rotated_corners[:, :, 1], axis=1)
- # Convert back to tensor and move to the same device
- device = boxes.device
- return torch.tensor(np.stack([x_min, y_min, x_max, y_max], axis=1), dtype=boxes.dtype, device=device)
- def rotate_lines(self, lines, angle, center):
- coords = lines[..., :2] # shape: (..., 2)
- visibility = lines[..., 2:] # shape: (..., N)
- # Translate to origin
- coords = coords - torch.tensor(center, dtype=coords.dtype, device=coords.device)
- # Rotation matrix
- theta = torch.deg2rad(torch.tensor(angle))
- cos_t = torch.cos(theta)
- sin_t = torch.sin(theta)
- R = torch.tensor([[cos_t, -sin_t], [sin_t, cos_t]], dtype=coords.dtype, device=coords.device)
- # Apply rotation using torch.matmul
- rotated_coords = torch.matmul(coords, R)
- # Translate back
- rotated_coords = rotated_coords + torch.tensor(center, dtype=coords.dtype, device=coords.device)
- # Concatenate with visibility
- rotated_lines = torch.cat([rotated_coords, visibility], dim=-1)
- return rotated_lines
- def __call__(self, img, target):
- if random.random() < self.prob:
- angle = random.uniform(-self.degrees, self.degrees)
- w, h = img.size if isinstance(img, Image.Image) else (img.shape[2], img.shape[1])
- center = (w / 2, h / 2)
- # Rotate image
- img = F.rotate(img, angle, center=center)
- # Rotate boxes
- if "boxes" in target:
- target["boxes"] = self.rotate_boxes(target["boxes"], angle, center)
- # Rotate lines
- if "lines" in target:
- target["lines"] = self.rotate_lines(target["lines"], angle, center)
- if "circle_masks" in target:
- target["circle_masks"] = self.rotate_lines(target["circle_masks"], angle, center)
- return img, target
- class RandomErasing:
- def __init__(self, prob=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[0.485, 0.456, 0.406]):
- """
- :param prob: 应用擦除的概率
- :param sl: 擦除面积比例的下界
- :param sh: 擦除面积比例的上界
- :param r1: 长宽比的下界
- :param mean: 用于填充擦除区域的像素值
- """
- self.prob = prob
- self.sl = sl
- self.sh = sh
- self.r1 = r1
- self.mean = mean
- def __call__(self, img, target):
- if random.random() < self.prob:
- # 如果是Tensor,则直接处理
- if isinstance(img, torch.Tensor):
- img = self._erase_tensor(img)
- # 如果是PIL Image,则转换为Tensor处理后再转回PIL Image
- elif isinstance(img, Image.Image):
- img_tensor = transforms.ToTensor()(img)
- img_tensor = self._erase_tensor(img_tensor)
- img = transforms.ToPILImage()(img_tensor)
- return img, target
- def _erase_tensor(self, img_tensor):
- """
- 对Tensor类型的图像执行随机擦除
- """
- img_c, img_h, img_w = img_tensor.shape
- area = img_h * img_w
- # 计算擦除区域的大小
- erase_area = random.uniform(self.sl, self.sh) * area
- aspect_ratio = random.uniform(self.r1, 1 / self.r1)
- h = int(round((erase_area * aspect_ratio) ** 0.5))
- w = int(round((erase_area / aspect_ratio) ** 0.5))
- # 确保不会超出图像边界
- if h < img_h and w < img_w:
- x = random.randint(0, img_w - w)
- y = random.randint(0, img_h - h)
- # 创建一个与擦除区域相同大小且填充指定均值的区域
- mean_tensor = torch.tensor(self.mean).view(img_c, 1, 1).expand(img_c, h, w)
- # 将该区域应用到原始图像上
- img_tensor[:, y:y + h, x:x + w] = mean_tensor
- return img_tensor
- """
- 有Bugs
- """
- class RandomPerspective:
- def __init__(self, distortion_scale=0.5, p=0.5):
- self.distortion_scale = distortion_scale
- self.p = p
- def _get_perspective_params(self, width, height, distortion_scale):
- half_w = width // 2
- half_h = height // 2
- w = int(width * distortion_scale)
- h = int(height * distortion_scale)
- startpoints = [
- [0, 0],
- [width - 1, 0],
- [width - 1, height - 1],
- [0, height - 1]
- ]
- endpoints = [
- [random.randint(0, w), random.randint(0, h)],
- [width - 1 - random.randint(0, w), random.randint(0, h)],
- [width - 1 - random.randint(0, w), height - 1 - random.randint(0, h)],
- [random.randint(0, w), height - 1 - random.randint(0, h)]
- ]
- return startpoints, endpoints
- def perspective_boxes(self, boxes, M, width, height):
- # 将boxes转换为角点形式
- corners = np.array([
- [boxes[:, 0], boxes[:, 1]], # top-left
- [boxes[:, 2], boxes[:, 1]], # top-right
- [boxes[:, 2], boxes[:, 3]], # bottom-right
- [boxes[:, 0], boxes[:, 3]] # bottom-left
- ]).transpose(2, 0, 1).reshape(-1, 2) # shape: (N*4, 2)
- # 应用透视变换
- ones = np.ones((corners.shape[0], 1))
- coords_homogeneous = np.hstack([corners, ones])
- transformed_coords = (M @ coords_homogeneous.T).T
- transformed_coords /= transformed_coords[:, 2].reshape(-1, 1) # 齐次除法
- transformed_coords = transformed_coords[:, :2]
- # 重新组合成bounding box
- transformed_coords = transformed_coords.reshape(-1, 4, 2)
- x_min = np.min(transformed_coords[:, :, 0], axis=1)
- y_min = np.min(transformed_coords[:, :, 1], axis=1)
- x_max = np.max(transformed_coords[:, :, 0], axis=1)
- y_max = np.max(transformed_coords[:, :, 1], axis=1)
- # 裁剪到图像范围内
- x_min = np.clip(x_min, 0, width)
- y_min = np.clip(y_min, 0, height)
- x_max = np.clip(x_max, 0, width)
- y_max = np.clip(y_max, 0, height)
- return torch.tensor(np.stack([x_min, y_min, x_max, y_max], axis=1), dtype=boxes.dtype, device=boxes.device)
- def perspective_lines(self, lines, M, width, height):
- # 提取坐标和可见性标志
- coords = lines[..., :2].cpu().numpy() # Shape: (N, L, 2)
- visibility = lines[..., 2:]
- # 确保coords是二维数组,如果它是三维的,则将其重塑为二维
- original_shape = coords.shape
- coords_reshaped = coords.reshape(-1, 2) # Reshape to (N*L, 2)
- # 添加齐次坐标
- ones = np.ones((coords_reshaped.shape[0], 1))
- coords_homogeneous = np.hstack([coords_reshaped, ones]) # Shape: (N*L, 3)
- # 应用透视变换矩阵
- transformed_coords_homogeneous = np.dot(M, coords_homogeneous.T).T
- transformed_coords = transformed_coords_homogeneous[:, :2] / transformed_coords_homogeneous[:, 2:] # 归一化
- # 将变换后的坐标恢复到原始形状
- transformed_coords = transformed_coords.reshape(original_shape) # Reshape back to (N, L, 2)
- # 裁剪到图像范围内
- transformed_coords = np.clip(transformed_coords, [0, 0], [width, height])
- # 转换回tensor
- transformed_coords = torch.tensor(transformed_coords, dtype=lines.dtype, device=lines.device)
- return torch.cat([transformed_coords, visibility], dim=-1)
- def __call__(self, img, target):
- if random.random() < self.p:
- width, height = img.size if isinstance(img, Image.Image) else (img.shape[2], img.shape[1])
- startpoints, endpoints = self._get_perspective_params(width, height, self.distortion_scale)
- # 使用 OpenCV 计算透视变换矩阵
- M = cv2.getPerspectiveTransform(
- np.float32(startpoints),
- np.float32(endpoints)
- )
- # 对图像应用透视变换
- if isinstance(img, Image.Image):
- img = img.transform((width, height), Image.PERSPECTIVE, M.flatten(), resample=Image.BILINEAR)
- elif isinstance(img, torch.Tensor):
- # 如果你需要用 TorchVision 实现,可以考虑使用 F.perspective,但更推荐配合PIL操作
- pil_img = F.to_pil_image(img)
- pil_img = pil_img.transform((width, height), Image.PERSPECTIVE, M.flatten(), resample=Image.BILINEAR)
- img = F.to_tensor(pil_img)
- # 对 boxes 变换
- if "boxes" in target:
- target["boxes"] = self.perspective_boxes(target["boxes"], M, width, height)
- # 对 lines 变换
- if "lines" in target:
- target["lines"] = self.perspective_lines(target["lines"], M, width, height)
- if "circle_masks" in target:
- target["circle_masks"] = self.perspective_lines(target["circle_masks"], M, width, height)
- return img, target
- class DefaultTransform(nn.Module):
- def forward(self, img: Tensor, target) -> Tuple[Tensor, Any]:
- if not isinstance(img, Tensor):
- img = F.pil_to_tensor(img)
- return F.convert_image_dtype(img, torch.float),target
- def __repr__(self) -> str:
- return self.__class__.__name__ + "()"
- def describe(self) -> str:
- return (
- "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
- "The images are rescaled to ``[0.0, 1.0]``."
- )
- class ToTensor:
- def __call__(self, img, target):
- img = F.to_tensor(img)
- return img, target
- def get_transforms(augmention=True):
- transforms_list = []
- if augmention:
- # transforms_list.append(ColorJitter())
- transforms_list.append(RandomGrayscale(0.1))
- transforms_list.append(GaussianBlur())
- # transforms_list.append(RandomErasing())
- transforms_list.append(RandomHorizontalFlip(0.5))
- transforms_list.append(RandomVerticalFlip(0.5))
- # transforms_list.append(RandomPerspective())
- transforms_list.append(RandomRotation(degrees=15))
- # transforms_list.append(RandomResize(512, 2048))
- # transforms_list.append(RandomCrop((512,512)))
- transforms_list.append(DefaultTransform())
- return Compose(transforms_list)
|