transforms.py

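"""Paired (image, target) transforms for detection-style training.

Note: the target convention described here is inferred from the code in this
file, not from upstream documentation. ``target["boxes"]`` is an ``(N, 4)``
tensor of ``(x1, y1, x2, y2)`` pixel coordinates, ``target["labels"]`` is an
``(N,)`` tensor, and ``target["lines"]`` stores ``(x, y, visibility)`` triples
in its last dimension.
"""
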
import logging
import math
import random
from typing import Any, Tuple

import cv2
import numpy as np
import torch
from PIL import Image
from torch import nn, Tensor

from libs.vision_libs.transforms import functional as F
from libs.vision_libs import transforms


class Compose:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, target):
        for t in self.transforms:
            img, target = t(img, target)
        return img, target


class RandomHorizontalFlip:
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, img, target):
        if random.random() < self.prob:
            width = img.width if isinstance(img, Image.Image) else img.shape[-1]
            # Flip image
            img = F.hflip(img)
            # Flip boxes
            boxes = target["boxes"]
            x1, y1, x2, y2 = boxes.unbind(dim=1)
            boxes_flipped = torch.stack((width - x2, y1, width - x1, y2), dim=1)
            target["boxes"] = boxes_flipped
            # Flip lines
            if "lines" in target:
                lines = target["lines"].clone()
                # Only the x coordinate is mirrored; y and visibility are unchanged
                lines[..., 0] = width - lines[..., 0]
                target["lines"] = lines
        return img, target
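
# Worked example (illustrative, not from the original source): with width=640,
# a box (x1=100, y1=50, x2=300, y2=200) maps to
# (640-300, 50, 640-100, 200) = (340, 50, 540, 200);
# swapping x1 and x2 keeps x1 < x2 after mirroring.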


class RandomVerticalFlip:
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, img, target):
        if random.random() < self.prob:
            height = img.height if isinstance(img, Image.Image) else img.shape[-2]
            # Flip image
            img = F.vflip(img)
            # Flip boxes
            boxes = target["boxes"]
            x1, y1, x2, y2 = boxes.unbind(dim=1)
            boxes_flipped = torch.stack((x1, height - y2, x2, height - y1), dim=1)
            target["boxes"] = boxes_flipped
            # Flip lines
            if "lines" in target:
                lines = target["lines"].clone()
                # Only the y coordinate is mirrored; x and visibility are unchanged
                lines[..., 1] = height - lines[..., 1]
                target["lines"] = lines
        return img, target


class ColorJitter:
    def __init__(self, brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2):
        if not (0 <= hue <= 0.5):
            raise ValueError(f"Hue jitter value should be in [0, 0.5], but got {hue}")
        self.color_jitter = transforms.ColorJitter(
            brightness=brightness,
            contrast=contrast,
            saturation=saturation,
            hue=hue,
        )

    def __call__(self, img, target):
        logging.debug("ColorJitter input type: %s", type(img))
        # Color jitter only touches pixel values, so the target needs no update
        img = self.color_jitter(img)
        return img, target


class RandomGrayscale:
    def __init__(self, p=0.1):
        self.p = p

    def __call__(self, img, target):
        logging.debug("RandomGrayscale input type: %s", type(img))
        if random.random() < self.p:
            # Keep three channels so downstream transforms see a consistent shape
            img = F.to_grayscale(img, num_output_channels=3)
        return img, target


class RandomResize:
    def __init__(self, min_size, max_size=None):
        self.min_size = min_size
        self.max_size = max_size

    def __call__(self, img, target):
        # Sample the target short side, then scale both dimensions uniformly
        size = random.randint(self.min_size, self.max_size) if self.max_size else self.min_size
        w, h = img.size if isinstance(img, Image.Image) else (img.shape[2], img.shape[1])
        scale = size / min(h, w)
        new_h, new_w = int(scale * h), int(scale * w)
        img = F.resize(img, (new_h, new_w))
        # Update boxes
        boxes = target["boxes"]
        boxes = boxes * scale
        target["boxes"] = boxes
        # Update lines (x and y are scaled; visibility is left untouched)
        if "lines" in target:
            target["lines"] = target["lines"] * torch.tensor(
                [scale, scale, 1], device=target["lines"].device
            )
        return img, target
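
# Worked example (illustrative, not from the original source): with
# min_size=800 and a 640x480 image, scale = 800 / min(480, 640) = 1.6667, so
# the image is resized to 1066x800 (the short side becomes 800) and boxes and
# lines are scaled by the same factor.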


class RandomCrop:
    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        width, height = F.get_image_size(img)
        crop_height, crop_width = self.size
        # Pick a random crop window
        left = random.randint(0, max(width - crop_width, 0))
        top = random.randint(0, max(height - crop_height, 0))
        right = min(left + crop_width, width)
        bottom = min(top + crop_height, height)
        # Crop the image
        img = F.crop(img, top, left, bottom - top, right - left)
        if "boxes" in target:
            boxes = target["boxes"]
            labels = target["labels"] if "labels" in target else None
            # Shift bounding boxes into the crop's coordinate frame
            cropped_boxes = boxes.clone()
            cropped_boxes[:, 0::2] -= left
            cropped_boxes[:, 1::2] -= top
            # Clamp bounding boxes to the crop window
            cropped_boxes[:, 0::2].clamp_(min=0, max=crop_width)
            cropped_boxes[:, 1::2].clamp_(min=0, max=crop_height)
            # Recompute widths and heights
            w = cropped_boxes[:, 2] - cropped_boxes[:, 0]
            h = cropped_boxes[:, 3] - cropped_boxes[:, 1]
            # Drop degenerate boxes (zero width or height after clamping)
            valid_boxes_mask = (w > 0) & (h > 0)
            cropped_boxes = cropped_boxes[valid_boxes_mask]
            if labels is not None:
                labels = labels[valid_boxes_mask]
            # Write the filtered annotations back
            target["boxes"] = cropped_boxes
            if labels is not None:
                target["labels"] = labels
        # Note: "lines" targets are not adjusted here, which is one reason this
        # transform is left disabled in get_transforms() below
        return img, target


class GaussianBlur:
    def __init__(self, kernel_size=5, sigma=(0.1, 2.0), prob=0.2):
        self.kernel_size = kernel_size if kernel_size % 2 == 1 else kernel_size + 1  # Ensure kernel size is odd
        self.sigma = sigma
        self.prob = prob

    def __call__(self, img, target):
        if random.random() < self.prob:
            # Convert PIL Image to Tensor if necessary, remembering the original format
            was_pil = isinstance(img, Image.Image)
            if was_pil:
                img = transforms.ToTensor()(img)
            # Apply Gaussian blur with a sigma sampled from the configured range
            img = transforms.GaussianBlur(kernel_size=self.kernel_size, sigma=random.uniform(*self.sigma))(img)
            # If the original image was a PIL Image, convert it back
            if was_pil:
                img = transforms.ToPILImage()(img)
        return img, target


class RandomRotation:
    def __init__(self, degrees=15, prob=0.5):
        self.degrees = degrees
        self.prob = prob

    def rotate_boxes(self, boxes, angle, center):
        # Convert to numpy for easier rotation math
        boxes_np = boxes.cpu().numpy()
        center_np = np.array(center)
        corners = np.array([
            [boxes_np[:, 0], boxes_np[:, 1]],  # top-left
            [boxes_np[:, 2], boxes_np[:, 1]],  # top-right
            [boxes_np[:, 2], boxes_np[:, 3]],  # bottom-right
            [boxes_np[:, 0], boxes_np[:, 3]],  # bottom-left
        ]).transpose(2, 0, 1)  # shape: (N, 4, 2)
        # Translate to origin
        corners -= center_np
        # Rotate points; for row vectors, corners @ R applies R.T = [[c, s], [-s, c]],
        # which matches torchvision's counter-clockwise rotation in y-down image coordinates
        theta = np.radians(angle)
        c, s = np.cos(theta), np.sin(theta)
        R = np.array([[c, -s], [s, c]])
        rotated_corners = corners @ R
        # Translate back
        rotated_corners += center_np
        # Take the axis-aligned bounding box of the rotated corners
        x_min = np.min(rotated_corners[:, :, 0], axis=1)
        y_min = np.min(rotated_corners[:, :, 1], axis=1)
        x_max = np.max(rotated_corners[:, :, 0], axis=1)
        y_max = np.max(rotated_corners[:, :, 1], axis=1)
        # Convert back to tensor on the original device
        device = boxes.device
        return torch.tensor(np.stack([x_min, y_min, x_max, y_max], axis=1), dtype=boxes.dtype, device=device)

    def rotate_lines(self, lines, angle, center):
        coords = lines[..., :2]      # shape: (..., 2)
        visibility = lines[..., 2:]  # shape: (..., 1)
        # Translate to origin
        coords = coords - torch.tensor(center, dtype=coords.dtype, device=coords.device)
        # Rotation matrix (same convention as rotate_boxes)
        theta = math.radians(angle)
        cos_t, sin_t = math.cos(theta), math.sin(theta)
        R = torch.tensor([[cos_t, -sin_t], [sin_t, cos_t]], dtype=coords.dtype, device=coords.device)
        # Apply rotation; coords @ R applies R.T, matching rotate_boxes
        rotated_coords = torch.matmul(coords, R)
        # Translate back
        rotated_coords = rotated_coords + torch.tensor(center, dtype=coords.dtype, device=coords.device)
        # Concatenate with visibility
        rotated_lines = torch.cat([rotated_coords, visibility], dim=-1)
        return rotated_lines

    def __call__(self, img, target):
        if random.random() < self.prob:
            angle = random.uniform(-self.degrees, self.degrees)
            w, h = img.size if isinstance(img, Image.Image) else (img.shape[2], img.shape[1])
            center = (w / 2, h / 2)
            # Rotate image
            img = F.rotate(img, angle, center=center)
            # Rotate boxes
            if "boxes" in target:
                target["boxes"] = self.rotate_boxes(target["boxes"], angle, center)
            # Rotate lines
            if "lines" in target:
                target["lines"] = self.rotate_lines(target["lines"], angle, center)
        return img, target
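
# Worked example (illustrative, not from the original source): with angle=90,
# a point at (center_x + 10, center_y) maps to (center_x, center_y - 10),
# i.e. straight up on screen, consistent with torchvision's counter-clockwise
# rotation for positive angles under the y-down image convention.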


class RandomErasing:
    def __init__(self, prob=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.485, 0.456, 0.406)):
        """
        :param prob: probability of applying the erasure
        :param sl: lower bound on the erased area ratio
        :param sh: upper bound on the erased area ratio
        :param r1: lower bound on the aspect ratio
        :param mean: pixel values used to fill the erased region
        """
        self.prob = prob
        self.sl = sl
        self.sh = sh
        self.r1 = r1
        self.mean = mean

    def __call__(self, img, target):
        if random.random() < self.prob:
            # Tensors are erased directly
            if isinstance(img, torch.Tensor):
                img = self._erase_tensor(img)
            # PIL Images take a round trip through a tensor
            elif isinstance(img, Image.Image):
                img_tensor = transforms.ToTensor()(img)
                img_tensor = self._erase_tensor(img_tensor)
                img = transforms.ToPILImage()(img_tensor)
        return img, target

    def _erase_tensor(self, img_tensor):
        """Apply random erasing to a tensor image."""
        img_c, img_h, img_w = img_tensor.shape
        area = img_h * img_w
        # Sample the size of the erased region
        erase_area = random.uniform(self.sl, self.sh) * area
        aspect_ratio = random.uniform(self.r1, 1 / self.r1)
        h = int(round((erase_area * aspect_ratio) ** 0.5))
        w = int(round((erase_area / aspect_ratio) ** 0.5))
        # Make sure the region stays inside the image
        if h < img_h and w < img_w:
            x = random.randint(0, img_w - w)
            y = random.randint(0, img_h - h)
            # Build a patch of the requested size filled with the mean values
            mean_tensor = torch.tensor(self.mean).view(img_c, 1, 1).expand(img_c, h, w)
            # Paste the patch over the image
            img_tensor[:, y:y + h, x:x + w] = mean_tensor
        return img_tensor
  258. """
  259. 有Bugs
  260. """
class RandomPerspective:
    def __init__(self, distortion_scale=0.5, p=0.5):
        self.distortion_scale = distortion_scale
        self.p = p

    def _get_perspective_params(self, width, height, distortion_scale):
        w = int(width * distortion_scale)
        h = int(height * distortion_scale)
        startpoints = [
            [0, 0],
            [width - 1, 0],
            [width - 1, height - 1],
            [0, height - 1],
        ]
        endpoints = [
            [random.randint(0, w), random.randint(0, h)],
            [width - 1 - random.randint(0, w), random.randint(0, h)],
            [width - 1 - random.randint(0, w), height - 1 - random.randint(0, h)],
            [random.randint(0, w), height - 1 - random.randint(0, h)],
        ]
        return startpoints, endpoints

    def perspective_boxes(self, boxes, M, width, height):
        # Expand boxes into their four corner points
        boxes_np = boxes.cpu().numpy()
        corners = np.array([
            [boxes_np[:, 0], boxes_np[:, 1]],  # top-left
            [boxes_np[:, 2], boxes_np[:, 1]],  # top-right
            [boxes_np[:, 2], boxes_np[:, 3]],  # bottom-right
            [boxes_np[:, 0], boxes_np[:, 3]],  # bottom-left
        ]).transpose(2, 0, 1).reshape(-1, 2)  # shape: (N*4, 2)
        # Apply the perspective transform in homogeneous coordinates
        ones = np.ones((corners.shape[0], 1))
        coords_homogeneous = np.hstack([corners, ones])
        transformed_coords = (M @ coords_homogeneous.T).T
        transformed_coords /= transformed_coords[:, 2].reshape(-1, 1)  # homogeneous division
        transformed_coords = transformed_coords[:, :2]
        # Reassemble axis-aligned bounding boxes
        transformed_coords = transformed_coords.reshape(-1, 4, 2)
        x_min = np.min(transformed_coords[:, :, 0], axis=1)
        y_min = np.min(transformed_coords[:, :, 1], axis=1)
        x_max = np.max(transformed_coords[:, :, 0], axis=1)
        y_max = np.max(transformed_coords[:, :, 1], axis=1)
        # Clip to the image bounds
        x_min = np.clip(x_min, 0, width)
        y_min = np.clip(y_min, 0, height)
        x_max = np.clip(x_max, 0, width)
        y_max = np.clip(y_max, 0, height)
        return torch.tensor(np.stack([x_min, y_min, x_max, y_max], axis=1), dtype=boxes.dtype, device=boxes.device)

    def perspective_lines(self, lines, M, width, height):
        # Split coordinates from visibility flags
        coords = lines[..., :2].cpu().numpy()  # shape: (N, L, 2)
        visibility = lines[..., 2:]
        # Flatten to a 2-D array of points
        original_shape = coords.shape
        coords_reshaped = coords.reshape(-1, 2)  # shape: (N*L, 2)
        # Append homogeneous coordinates
        ones = np.ones((coords_reshaped.shape[0], 1))
        coords_homogeneous = np.hstack([coords_reshaped, ones])  # shape: (N*L, 3)
        # Apply the perspective transform matrix
        transformed_coords_homogeneous = np.dot(M, coords_homogeneous.T).T
        transformed_coords = transformed_coords_homogeneous[:, :2] / transformed_coords_homogeneous[:, 2:]  # normalize
        # Restore the original shape
        transformed_coords = transformed_coords.reshape(original_shape)  # back to (N, L, 2)
        # Clip to the image bounds
        transformed_coords = np.clip(transformed_coords, [0, 0], [width, height])
        # Convert back to a tensor
        transformed_coords = torch.tensor(transformed_coords, dtype=lines.dtype, device=lines.device)
        return torch.cat([transformed_coords, visibility], dim=-1)

    def __call__(self, img, target):
        if random.random() < self.p:
            width, height = img.size if isinstance(img, Image.Image) else (img.shape[2], img.shape[1])
            startpoints, endpoints = self._get_perspective_params(width, height, self.distortion_scale)
            # M maps startpoints -> endpoints and is applied to the annotations;
            # PIL's Image.PERSPECTIVE expects the eight coefficients of the
            # inverse mapping (output pixel -> input pixel), so compute both
            M = cv2.getPerspectiveTransform(np.float32(startpoints), np.float32(endpoints))
            M_inv = cv2.getPerspectiveTransform(np.float32(endpoints), np.float32(startpoints))
            # Apply the perspective transform to the image
            if isinstance(img, Image.Image):
                img = img.transform((width, height), Image.PERSPECTIVE, M_inv.flatten()[:8], resample=Image.BILINEAR)
            elif isinstance(img, torch.Tensor):
                # Round-trip through PIL; torchvision's F.perspective would also work here
                pil_img = F.to_pil_image(img)
                pil_img = pil_img.transform((width, height), Image.PERSPECTIVE, M_inv.flatten()[:8], resample=Image.BILINEAR)
                img = F.to_tensor(pil_img)
            # Transform boxes
            if "boxes" in target:
                target["boxes"] = self.perspective_boxes(target["boxes"], M, width, height)
            # Transform lines
            if "lines" in target:
                target["lines"] = self.perspective_lines(target["lines"], M, width, height)
        return img, target


class DefaultTransform(nn.Module):
    def forward(self, img: Tensor, target) -> Tuple[Tensor, Any]:
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        return F.convert_image_dtype(img, torch.float), target

    def __repr__(self) -> str:
        return self.__class__.__name__ + "()"

    def describe(self) -> str:
        return (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
            "The images are rescaled to ``[0.0, 1.0]``."
        )


class ToTensor:
    def __call__(self, img, target):
        img = F.to_tensor(img)
        return img, target


def get_transforms(augmentation=True):
    transforms_list = []
    if augmentation:
        # transforms_list.append(ColorJitter())
        transforms_list.append(RandomGrayscale(0.1))
        transforms_list.append(GaussianBlur())
        # transforms_list.append(RandomErasing())
        transforms_list.append(RandomHorizontalFlip(0.5))
        transforms_list.append(RandomVerticalFlip(0.5))
        # transforms_list.append(RandomPerspective())
        transforms_list.append(RandomRotation(degrees=15))
        # transforms_list.append(RandomResize(512, 2048))
        # transforms_list.append(RandomCrop((512, 512)))
    transforms_list.append(DefaultTransform())
    return Compose(transforms_list)
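

# Minimal smoke-test sketch (not part of the original training code): builds a
# dummy image and target following the convention in the module docstring and
# runs them through the full pipeline. The dummy values are illustrative only.
if __name__ == "__main__":
    dummy_img = Image.new("RGB", (640, 480), color=(128, 128, 128))
    dummy_target = {
        "boxes": torch.tensor([[100.0, 120.0, 300.0, 240.0]]),
        "labels": torch.tensor([1]),
        "lines": torch.tensor([[[100.0, 120.0, 1.0], [300.0, 240.0, 1.0]]]),
    }
    pipeline = get_transforms(augmentation=True)
    out_img, out_target = pipeline(dummy_img, dummy_target)
    print(type(out_img), tuple(out_img.shape), out_target["boxes"])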