transforms.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
  1. import logging
  2. import random
  3. from typing import Any
  4. import cv2
  5. import numpy as np
  6. import torch
  7. from PIL import Image
  8. from torch import nn, Tensor
  9. from libs.vision_libs .transforms import functional as F
  10. from libs.vision_libs import transforms
  class Compose:
      """Chain paired ``(image, target)`` transforms into one callable.

      Every entry of ``transforms`` is invoked as ``t(img, target)`` and must
      return an ``(img, target)`` pair, which is handed to the next entry.
      """

      def __init__(self, transforms):
          # Kept by reference (no copy is made).
          self.transforms = transforms

      def __call__(self, img, target):
          """Apply each transform in order and return the final pair."""
          for stage in self.transforms:
              img, target = stage(img, target)
          return img, target
  class RandomHorizontalFlip:
      """Mirror the image left-right with probability ``prob``.

      When the flip fires, ``target["boxes"]`` (columns ``x1, y1, x2, y2``)
      and, if present, ``target["lines"]`` (x in channel 0 of the last axis)
      are mirrored as well. ``target`` is updated in place.
      """

      def __init__(self, prob=0.5):
          self.prob = prob

      def __call__(self, img, target):
          """Return ``(img, target)``, flipped together or left untouched."""
          if not (random.random() < self.prob):
              return img, target

          # Width is read before the flip; PIL exposes ``.width``, anything
          # else is indexed by its last axis.
          width = img.width if isinstance(img, Image.Image) else img.shape[-1]
          img = F.hflip(img)

          # Mirroring swaps the roles of the left and right box edges.
          left, top, right, bottom = target["boxes"].unbind(dim=1)
          target["boxes"] = torch.stack((width - right, top, width - left, bottom), dim=1)

          if "lines" in target:
              # Only the x channel changes; y and visibility are kept as-is.
              mirrored = target["lines"].clone()
              mirrored[..., 0] = width - mirrored[..., 0]
              target["lines"] = mirrored

          return img, target
  class RandomVerticalFlip:
      """Mirror the image top-bottom with probability ``prob``.

      When the flip fires, ``target["boxes"]`` (columns ``x1, y1, x2, y2``)
      and, if present, ``target["lines"]`` (y in channel 1 of the last axis)
      are mirrored as well. ``target`` is updated in place.
      """

      def __init__(self, prob=0.5):
          self.prob = prob

      def __call__(self, img, target):
          """Return ``(img, target)``, flipped together or left untouched."""
          if not (random.random() < self.prob):
              return img, target

          # Height is read before the flip; PIL exposes ``.height``, anything
          # else is indexed by its second-to-last axis.
          height = img.height if isinstance(img, Image.Image) else img.shape[-2]
          img = F.vflip(img)

          # Mirroring swaps the roles of the top and bottom box edges.
          left, top, right, bottom = target["boxes"].unbind(dim=1)
          target["boxes"] = torch.stack((left, height - bottom, right, height - top), dim=1)

          if "lines" in target:
              # Only the y channel changes; x and visibility are kept as-is.
              mirrored = target["lines"].clone()
              mirrored[..., 1] = height - mirrored[..., 1]
              target["lines"] = mirrored

          return img, target
  class ColorJitter:
      """Apply photometric jitter to the image; the target passes through.

      Thin paired-transform wrapper around ``transforms.ColorJitter``.

      Args:
          brightness, contrast, saturation, hue: forwarded unchanged to
              ``transforms.ColorJitter``.

      Raises:
          ValueError: if ``hue`` is outside ``[0, 0.5]``.
      """

      def __init__(self, brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2):
          if not (0 <= hue <= 0.5):
              raise ValueError(f"Hue jitter value should be in [0, 0.5], but got {hue}")
          self.color_jitter = transforms.ColorJitter(
              brightness=brightness,
              contrast=contrast,
              saturation=saturation,
              hue=hue,
          )

      def __call__(self, img, target):
          """Return ``(jittered_img, target)``."""
          # Diagnostics go through logging at DEBUG level instead of print(),
          # so a data-loading loop does not write to stdout for every sample.
          logging.getLogger(__name__).debug("ColorJitter input image type: %s", type(img))
          img = self.color_jitter(img)
          return img, target
  class RandomGrayscale:
      """Convert the image to 3-channel grayscale with probability ``p``.

      The target is returned unchanged.
      """

      def __init__(self, p=0.1):
          self.p = p

      def __call__(self, img, target):
          """Return ``(img, target)`` with ``img`` possibly desaturated."""
          # Diagnostics go through logging at DEBUG level instead of print(),
          # so a data-loading loop does not write to stdout for every sample.
          logging.getLogger(__name__).debug("RandomGrayscale input image type: %s", type(img))
          if random.random() < self.p:
              # Three output channels keep the image shape compatible with
              # the colour transforms around it.
              img = F.to_grayscale(img, num_output_channels=3)
          return img, target
  class RandomResize:
      """Resize so the shorter image side matches a randomly drawn length.

      Args:
          min_size: lower bound (inclusive) for the shorter side.
          max_size: upper bound (inclusive). When falsy, ``min_size`` is
              always used.

      ``target["boxes"]`` and, if present, ``target["lines"]`` are rescaled
      to match. ``target`` is updated in place.
      """

      def __init__(self, min_size, max_size=None):
          self.min_size = min_size
          self.max_size = max_size

      def __call__(self, img, target):
          """Return the resized image and the rescaled target."""
          size = random.randint(self.min_size, self.max_size) if self.max_size else self.min_size
          w, h = img.size if isinstance(img, Image.Image) else (img.shape[-1], img.shape[-2])
          scale = size / min(h, w)

          # round() rather than int(): (size / h) * h can land a hair below
          # `size` in floating point, and truncation would then produce a
          # short side of `size - 1`.
          new_h, new_w = int(round(scale * h)), int(round(scale * w))
          img = F.resize(img, (new_h, new_w))

          # Rescale annotations by the ratio that was actually applied on each
          # axis, which differs slightly from `scale` once sizes are rounded.
          scale_x, scale_y = new_w / w, new_h / h

          boxes = target["boxes"]
          box_factors = torch.tensor([scale_x, scale_y, scale_x, scale_y], device=boxes.device)
          if boxes.is_floating_point():
              box_factors = box_factors.to(boxes.dtype)
          target["boxes"] = boxes * box_factors

          if "lines" in target:
              lines = target["lines"]
              # Last axis is (x, y, visibility); visibility is left unscaled.
              target["lines"] = lines * torch.tensor([scale_x, scale_y, 1], device=lines.device)

          return img, target
  class RandomCrop:
      """Crop a random ``(th, tw)`` window and shift annotations into it.

      Args:
          size: ``(height, width)`` of the crop window.

      Images smaller than the window in either dimension are returned
      unchanged. Otherwise ``target["boxes"]`` and, if present,
      ``target["lines"]`` are translated and clamped to the window, and
      ``target`` is updated in place.
      """

      def __init__(self, size):
          self.size = size

      def __call__(self, img, target):
          """Return the cropped image and the adjusted target."""
          w, h = img.size if isinstance(img, Image.Image) else (img.shape[2], img.shape[1])
          th, tw = self.size

          # Strict comparison: an image whose side already equals the target
          # can still be cropped along the other axis (offset range is [0, 0]).
          if h < th or w < tw:
              return img, target

          i = random.randint(0, h - th)  # top offset
          j = random.randint(0, w - tw)  # left offset
          img = F.crop(img, i, j, th, tw)

          boxes = target["boxes"]
          boxes = boxes - torch.tensor([j, i, j, i], device=boxes.device)
          boxes[:, [0, 2]] = boxes[:, [0, 2]].clamp(min=0, max=tw)
          boxes[:, [1, 3]] = boxes[:, [1, 3]].clamp(min=0, max=th)
          # NOTE(review): boxes that fall entirely outside the window collapse
          # to zero area here and are not removed -- confirm that downstream
          # code tolerates degenerate boxes.
          target["boxes"] = boxes

          if "lines" in target:
              lines = target["lines"].clone()
              # Clamp only the coordinate channels; the trailing visibility
              # channel is left untouched.
              lines[..., 0] = (lines[..., 0] - j).clamp(min=0, max=tw)
              lines[..., 1] = (lines[..., 1] - i).clamp(min=0, max=th)
              target["lines"] = lines

          return img, target
  class GaussianBlur:
      """Blur the image with probability ``prob``; the target passes through.

      Args:
          kernel_size: blur kernel size; an even value is bumped to the next
              odd one.
          sigma: ``(low, high)`` range from which a sigma is drawn uniformly
              on each application.
          prob: probability of applying the blur.
      """

      def __init__(self, kernel_size=5, sigma=(0.1, 2.0), prob=0.2):
          # The kernel size must be odd.
          self.kernel_size = kernel_size if kernel_size % 2 == 1 else kernel_size + 1
          self.sigma = sigma
          self.prob = prob

      def __call__(self, img, target):
          """Return ``(img, target)``; ``img`` keeps its input type."""
          if random.random() < self.prob:
              # Remember the input type up front so the output can mirror it.
              # The previous check looked at target["original_image_format"],
              # which nothing in this module sets, so Tensor inputs came back
              # as PIL images.
              was_pil = isinstance(img, Image.Image)
              if was_pil:
                  img = transforms.ToTensor()(img)

              blur = transforms.GaussianBlur(
                  kernel_size=self.kernel_size,
                  sigma=random.uniform(*self.sigma),
              )
              img = blur(img)

              if was_pil:
                  img = transforms.ToPILImage()(img)
          return img, target
  class RandomRotation:
      """Rotate image, boxes and lines about the image centre.

      Args:
          degrees: the angle is drawn uniformly from ``[-degrees, degrees]``.
          prob: probability of applying the rotation.

      Rotated annotations are not clipped to the image bounds.
      """

      def __init__(self, degrees=15, prob=0.5):
          self.degrees = degrees
          self.prob = prob

      @staticmethod
      def _rotate_points(points, angle, center):
          """Rotate floating-point ``(..., 2)`` xy points by ``angle`` degrees about ``center``."""
          theta = float(np.radians(angle))
          cos_t, sin_t = float(np.cos(theta)), float(np.sin(theta))
          # Applied to row vectors: (x, y) @ R.
          rotation = points.new_tensor([[cos_t, -sin_t], [sin_t, cos_t]])
          pivot = points.new_tensor(center)
          return torch.matmul(points - pivot, rotation) + pivot

      def rotate_boxes(self, boxes, angle, center):
          """Return the axis-aligned boxes enclosing each rotated input box.

          ``boxes`` holds ``(x1, y1, x2, y2)`` rows. The result has the same
          dtype and device as the input; integer boxes are rotated in floating
          point and rounded back instead of raising.
          """
          if boxes.shape[0] == 0:
              return boxes.clone()

          work = boxes if boxes.is_floating_point() else boxes.to(torch.get_default_dtype())
          x1, y1, x2, y2 = work.unbind(dim=1)
          corners = torch.stack(
              (
                  torch.stack((x1, y1), dim=1),  # top-left
                  torch.stack((x2, y1), dim=1),  # top-right
                  torch.stack((x2, y2), dim=1),  # bottom-right
                  torch.stack((x1, y2), dim=1),  # bottom-left
              ),
              dim=1,
          )  # shape: (N, 4, 2)

          rotated = self._rotate_points(corners, angle, center)
          hull = torch.cat((rotated.min(dim=1).values, rotated.max(dim=1).values), dim=1)
          if not boxes.is_floating_point():
              hull = hull.round()
          return hull.to(boxes.dtype)

      def rotate_lines(self, lines, angle, center):
          """Rotate the xy channels of ``lines``; trailing channels are kept.

          The first two entries of the last axis are treated as ``(x, y)``;
          everything after them (e.g. visibility) is passed through.
          """
          work = lines if lines.is_floating_point() else lines.to(torch.get_default_dtype())
          rotated_xy = self._rotate_points(work[..., :2], angle, center)
          return torch.cat([rotated_xy, work[..., 2:]], dim=-1)

      def __call__(self, img, target):
          """Return ``(img, target)``, rotated together or left untouched."""
          if random.random() < self.prob:
              angle = random.uniform(-self.degrees, self.degrees)
              w, h = img.size if isinstance(img, Image.Image) else (img.shape[2], img.shape[1])
              center = (w / 2, h / 2)

              img = F.rotate(img, angle, center=center)

              if "boxes" in target:
                  target["boxes"] = self.rotate_boxes(target["boxes"], angle, center)
              if "lines" in target:
                  target["lines"] = self.rotate_lines(target["lines"], angle, center)
          return img, target
  class RandomErasing:
      """Overwrite a random rectangle of the image with a constant fill.

      The target is returned unchanged.
      """

      def __init__(self, prob=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.485, 0.456, 0.406)):
          """
          :param prob: probability of applying the erase
          :param sl: lower bound of the erased-area ratio
          :param sh: upper bound of the erased-area ratio
          :param r1: lower bound of the aspect ratio (upper bound is 1 / r1)
          :param mean: per-channel value used to fill the erased region
          """
          self.prob = prob
          self.sl = sl
          self.sh = sh
          self.r1 = r1
          # Immutable default: a list default would be shared by every instance.
          self.mean = mean

      def __call__(self, img, target):
          """Return ``(img, target)``; the input image is never modified."""
          if random.random() < self.prob:
              if isinstance(img, torch.Tensor):
                  # Work on a copy so the caller's tensor is left intact.
                  img = self._erase_tensor(img.clone())
              elif isinstance(img, Image.Image):
                  # PIL images are erased via a tensor round-trip.
                  img_tensor = transforms.ToTensor()(img)
                  img_tensor = self._erase_tensor(img_tensor)
                  img = transforms.ToPILImage()(img_tensor)
          return img, target

      def _erase_tensor(self, img_tensor):
          """
          Erase a random rectangle of a ``(C, H, W)`` tensor in place and return it.

          A single candidate rectangle is drawn; if it does not fit strictly
          inside the image, the tensor is returned unmodified.
          """
          img_c, img_h, img_w = img_tensor.shape
          area = img_h * img_w

          # Draw the erased area and aspect ratio, then derive the rectangle size.
          erase_area = random.uniform(self.sl, self.sh) * area
          aspect_ratio = random.uniform(self.r1, 1 / self.r1)
          h = int(round((erase_area * aspect_ratio) ** 0.5))
          w = int(round((erase_area / aspect_ratio) ** 0.5))

          if h < img_h and w < img_w:
              x = random.randint(0, img_w - w)
              y = random.randint(0, img_h - h)
              # Build the fill with the image's dtype and device, then let
              # broadcasting expand it over the rectangle.
              # NOTE(review): requires len(self.mean) == C -- confirm callers
              # never pass single-channel images.
              fill = torch.as_tensor(
                  self.mean, dtype=img_tensor.dtype, device=img_tensor.device
              ).view(img_c, 1, 1)
              img_tensor[:, y:y + h, x:x + w] = fill
          return img_tensor
  # NOTE: this class was previously flagged as buggy. Known defects addressed
  # below: (1) PIL received the forward matrix although Image.PERSPECTIVE
  # expects output->input coefficients, so the image and its annotations were
  # warped differently; (2) corner offsets were drawn from the full image size
  # and could produce crossed or degenerate quads; (3) boxes were converted to
  # NumPy without first moving them to the CPU.
  class RandomPerspective:
      """Apply a random perspective warp to image, boxes and lines.

      Args:
          distortion_scale: fraction of half the image size by which each
              corner may move inwards.
          p: probability of applying the warp.
      """

      def __init__(self, distortion_scale=0.5, p=0.5):
          self.distortion_scale = distortion_scale
          self.p = p

      def _get_perspective_params(self, width, height, distortion_scale):
          """Return ``(startpoints, endpoints)``: image corners and their targets.

          Corners are ordered top-left, top-right, bottom-right, bottom-left.
          """
          # Offsets are bounded by half the image size so that opposite
          # corners cannot cross for distortion_scale < 1.
          max_dx = int(distortion_scale * (width // 2))
          max_dy = int(distortion_scale * (height // 2))
          startpoints = [
              [0, 0],
              [width - 1, 0],
              [width - 1, height - 1],
              [0, height - 1],
          ]
          endpoints = [
              [random.randint(0, max_dx), random.randint(0, max_dy)],
              [width - 1 - random.randint(0, max_dx), random.randint(0, max_dy)],
              [width - 1 - random.randint(0, max_dx), height - 1 - random.randint(0, max_dy)],
              [random.randint(0, max_dx), height - 1 - random.randint(0, max_dy)],
          ]
          return startpoints, endpoints

      @staticmethod
      def _project(points, M):
          """Map ``(K, 2)`` NumPy points through the 3x3 homography ``M``."""
          homogeneous = np.hstack([points, np.ones((points.shape[0], 1))])
          mapped = homogeneous @ np.asarray(M, dtype=np.float64).T
          # Homogeneous divide.
          return mapped[:, :2] / mapped[:, 2:3]

      def perspective_boxes(self, boxes, M, width, height):
          """Return axis-aligned boxes enclosing each warped box, clipped to the image."""
          boxes_np = boxes.detach().cpu().numpy().astype(np.float64)
          x1, y1, x2, y2 = boxes_np.T
          # Expand every box to its four corners: (N, 4, 2) -> (N * 4, 2).
          corners = np.stack(
              [
                  np.stack([x1, y1], axis=1),  # top-left
                  np.stack([x2, y1], axis=1),  # top-right
                  np.stack([x2, y2], axis=1),  # bottom-right
                  np.stack([x1, y2], axis=1),  # bottom-left
              ],
              axis=1,
          ).reshape(-1, 2)

          warped = self._project(corners, M).reshape(-1, 4, 2)
          hull = np.concatenate([warped.min(axis=1), warped.max(axis=1)], axis=1)

          # Clip to the image extent.
          hull[:, 0::2] = np.clip(hull[:, 0::2], 0, width)
          hull[:, 1::2] = np.clip(hull[:, 1::2], 0, height)
          return torch.tensor(hull, dtype=boxes.dtype, device=boxes.device)

      def perspective_lines(self, lines, M, width, height):
          """Warp the xy channels of ``lines``; trailing channels are kept."""
          coords = lines[..., :2].detach().cpu().numpy()
          visibility = lines[..., 2:]

          # Flatten to (K, 2) for the projection, then restore the shape.
          warped = self._project(coords.reshape(-1, 2), M).reshape(coords.shape)

          # Clip to the image extent.
          warped = np.clip(warped, [0, 0], [width, height])

          warped = torch.tensor(warped, dtype=lines.dtype, device=lines.device)
          return torch.cat([warped, visibility], dim=-1)

      def __call__(self, img, target):
          """Return ``(img, target)``, warped together or left untouched."""
          if random.random() < self.p:
              width, height = img.size if isinstance(img, Image.Image) else (img.shape[2], img.shape[1])
              startpoints, endpoints = self._get_perspective_params(width, height, self.distortion_scale)

              # Forward homography (source -> destination), used for annotations.
              M = cv2.getPerspectiveTransform(
                  np.float32(startpoints),
                  np.float32(endpoints)
              )

              # PIL maps each OUTPUT pixel back to an INPUT pixel, so it needs
              # the first eight coefficients of the inverse matrix, normalised
              # so that the ninth equals 1.
              M_inv = np.linalg.inv(M)
              coeffs = (M_inv / M_inv[2, 2]).flatten()[:8].tolist()

              if isinstance(img, Image.Image):
                  img = img.transform((width, height), Image.PERSPECTIVE, coeffs, resample=Image.BILINEAR)
              elif isinstance(img, torch.Tensor):
                  # Tensors are warped via a PIL round-trip.
                  pil_img = F.to_pil_image(img)
                  pil_img = pil_img.transform((width, height), Image.PERSPECTIVE, coeffs, resample=Image.BILINEAR)
                  img = F.to_tensor(pil_img)

              if "boxes" in target:
                  target["boxes"] = self.perspective_boxes(target["boxes"], M, width, height)
              if "lines" in target:
                  target["lines"] = self.perspective_lines(target["lines"], M, width, height)
          return img, target
  class DefaultTransform(nn.Module):
      """Final pipeline step: hand back a float image tensor plus the target."""

      def forward(self, img: Tensor, target) -> tuple[Tensor, Any]:
          """Convert ``img`` to a float tensor; ``target`` passes through."""
          # Anything that is not already a Tensor goes through pil_to_tensor.
          as_tensor = img if isinstance(img, Tensor) else F.pil_to_tensor(img)
          return F.convert_image_dtype(as_tensor, torch.float), target

      def __repr__(self) -> str:
          return f"{self.__class__.__name__}()"

      def describe(self) -> str:
          """Return a human-readable summary of the accepted inputs."""
          sentences = (
              "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. ",
              "The images are rescaled to ``[0.0, 1.0]``.",
          )
          return "".join(sentences)
  class ToTensor:
      """Paired transform: convert the image with ``F.to_tensor``; keep the target."""

      def __call__(self, img, target):
          """Return ``(F.to_tensor(img), target)``."""
          return F.to_tensor(img), target
  def get_transforms(augmention=True):
      """Build the transform pipeline.

      With ``augmention`` truthy, the photometric and geometric augmentations
      run first; ``DefaultTransform`` is always the final step.
      """
      pipeline = []
      if augmention:
          # RandomPerspective is intentionally left out of this pipeline.
          pipeline.extend([
              ColorJitter(),
              RandomGrayscale(0.1),
              GaussianBlur(),
              RandomErasing(),
              RandomHorizontalFlip(0.5),
              RandomVerticalFlip(0.2),
              RandomRotation(degrees=15),
              RandomResize(512, 2048),
              RandomCrop((512, 512)),
          ])
      pipeline.append(DefaultTransform())
      return Compose(pipeline)