# transforms.py
import logging
import random
from typing import Any, Tuple

import cv2
import numpy as np
import torch
from PIL import Image
from torch import Tensor, nn

from libs.vision_libs import transforms
from libs.vision_libs.transforms import functional as F
  11. class Compose:
  12. def __init__(self, transforms):
  13. self.transforms = transforms
  14. def __call__(self, img, target):
  15. for t in self.transforms:
  16. img, target = t(img, target)
  17. return img, target
  18. class RandomHorizontalFlip:
  19. def __init__(self, prob=0.5):
  20. self.prob = prob
  21. def __call__(self, img, target):
  22. if random.random() < self.prob:
  23. width = img.width if isinstance(img, Image.Image) else img.shape[-1]
  24. # Flip image
  25. img = F.hflip(img)
  26. # Flip boxes
  27. boxes = target["boxes"]
  28. x1, y1, x2, y2 = boxes.unbind(dim=1)
  29. boxes_flipped = torch.stack((width - x2, y1, width - x1, y2), dim=1)
  30. target["boxes"] = boxes_flipped
  31. # Flip lines
  32. if "lines" in target:
  33. lines = target["lines"].clone()
  34. # 只翻转 x 坐标,y 和 visibility 不变
  35. lines[..., 0] = width - lines[..., 0]
  36. target["lines"] = lines
  37. # Flip lines
  38. if "circle_masks" in target:
  39. lines = target["circle_masks"].clone()
  40. # 只翻转 x 坐标,y 和 visibility 不变
  41. lines[..., 0] = width - lines[..., 0]
  42. target["circle_masks"] = lines
  43. return img, target
  44. class RandomVerticalFlip:
  45. def __init__(self, prob=0.5):
  46. self.prob = prob
  47. def __call__(self, img, target):
  48. if random.random() < self.prob:
  49. height = img.height if isinstance(img, Image.Image) else img.shape[-2]
  50. # Flip image
  51. img = F.vflip(img)
  52. # Flip boxes
  53. boxes = target["boxes"]
  54. x1, y1, x2, y2 = boxes.unbind(dim=1)
  55. boxes_flipped = torch.stack((x1, height - y2, x2, height - y1), dim=1)
  56. target["boxes"] = boxes_flipped
  57. # Flip lines
  58. if "lines" in target:
  59. lines = target["lines"].clone()
  60. lines[..., 1] = height - lines[..., 1]
  61. target["lines"] = lines
  62. if "circle_masks" in target:
  63. lines = target["circle_masks"].clone()
  64. lines[..., 1] = height - lines[..., 1]
  65. target["circle_masks"] = lines
  66. return img, target
  67. class ColorJitter:
  68. def __init__(self, brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2):
  69. if not (0 <= hue <= 0.5):
  70. raise ValueError(f"Hue jitter value should be in [0, 0.5], but got {hue}")
  71. self.color_jitter = transforms.ColorJitter(
  72. brightness=brightness,
  73. contrast=contrast,
  74. saturation=saturation,
  75. hue=hue
  76. )
  77. def __call__(self, img, target):
  78. print(f"Original image type: {type(img)}")
  79. img = self.color_jitter(img)
  80. print("Color jitter applied successfully.")
  81. return img, target
  82. class RandomGrayscale:
  83. def __init__(self, p=0.1):
  84. self.p = p
  85. def __call__(self, img, target):
  86. print(f"RandomGrayscale Original image type: {type(img)}")
  87. if random.random() < self.p:
  88. img = F.to_grayscale(img, num_output_channels=3)
  89. return img, target
  90. class RandomResize:
  91. def __init__(self, min_size, max_size=None):
  92. self.min_size = min_size
  93. self.max_size = max_size
  94. def __call__(self, img, target):
  95. size = random.randint(self.min_size, self.max_size) if self.max_size else self.min_size
  96. w, h = img.size if isinstance(img, Image.Image) else (img.shape[2], img.shape[1])
  97. scale = size / min(h, w)
  98. new_h, new_w = int(scale * h), int(scale * w)
  99. img = F.resize(img, (new_h, new_w))
  100. # Update boxes
  101. boxes = target["boxes"]
  102. boxes = boxes * scale
  103. target["boxes"] = boxes
  104. # Update lines
  105. if "lines" in target:
  106. target["lines"] = target["lines"] * torch.tensor([scale, scale, 1], device=target["lines"].device)
  107. if "circle_masks" in target:
  108. target["circle_masks"] = target["circle_masks"] * torch.tensor([scale, scale, 1], device=target["circle_masks"].device)
  109. return img, target
  110. class RandomCrop:
  111. def __init__(self, size):
  112. self.size = size
  113. def __call__(self, img, target):
  114. width, height = F.get_image_size(img)
  115. crop_height, crop_width = self.size
  116. # 随机选择裁剪区域
  117. left = random.randint(0, max(width - crop_width, 0))
  118. top = random.randint(0, max(height - crop_height, 0))
  119. right = min(left + crop_width, width)
  120. bottom = min(top + crop_height, height)
  121. # 裁剪图像
  122. img = F.crop(img, top, left, bottom - top, right - left)
  123. if "boxes" in target:
  124. boxes = target["boxes"]
  125. labels = target["labels"] if "labels" in target else None
  126. # 将bounding boxes转换到裁剪区域坐标系
  127. cropped_boxes = boxes.clone()
  128. cropped_boxes[:, 0::2] -= left
  129. cropped_boxes[:, 1::2] -= top
  130. # 确保bounding boxes在裁剪区域内
  131. cropped_boxes[:, 0::2].clamp_(min=0, max=crop_width)
  132. cropped_boxes[:, 1::2].clamp_(min=0, max=crop_height)
  133. # 计算新的宽高
  134. w = cropped_boxes[:, 2] - cropped_boxes[:, 0]
  135. h = cropped_boxes[:, 3] - cropped_boxes[:, 1]
  136. # 过滤掉无效的bounding boxes(宽度或高度为0)
  137. valid_boxes_mask = (w > 0) & (h > 0)
  138. # 更新有效bounding boxes
  139. cropped_boxes = cropped_boxes[valid_boxes_mask]
  140. if labels is not None:
  141. labels = labels[valid_boxes_mask]
  142. # 更新target
  143. target["boxes"] = cropped_boxes
  144. if labels is not None:
  145. target["labels"] = labels
  146. return img, target
  147. class GaussianBlur:
  148. def __init__(self, kernel_size=5, sigma=(0.1, 2.0), prob=0.2):
  149. self.kernel_size = kernel_size if kernel_size % 2 == 1 else kernel_size + 1 # Ensure kernel size is odd
  150. self.sigma = sigma
  151. self.prob = prob
  152. def __call__(self, img, target):
  153. if random.random() < self.prob:
  154. # Convert PIL Image to Tensor if necessary
  155. if isinstance(img, Image.Image):
  156. img = transforms.ToTensor()(img)
  157. # Apply Gaussian blur using PyTorch's functional interface
  158. img = transforms.GaussianBlur(kernel_size=self.kernel_size, sigma=random.uniform(*self.sigma))(img)
  159. # If the original image was a PIL Image, convert it back
  160. if isinstance(img, Tensor) and not isinstance(target.get('original_image_format', None), Tensor):
  161. img = transforms.ToPILImage()(img)
  162. return img, target
  163. class RandomRotation:
  164. def __init__(self, degrees=15, prob=0.5):
  165. self.degrees = degrees
  166. self.prob = prob
  167. def rotate_boxes(self, boxes, angle, center):
  168. # Convert to numpy for easier rotation math
  169. boxes_np = boxes.cpu().numpy()
  170. center_np = np.array(center)
  171. corners = np.array([
  172. [boxes_np[:, 0], boxes_np[:, 1]], # top-left
  173. [boxes_np[:, 2], boxes_np[:, 1]], # top-right
  174. [boxes_np[:, 2], boxes_np[:, 3]], # bottom-right
  175. [boxes_np[:, 0], boxes_np[:, 3]] # bottom-left
  176. ]).transpose(2, 0, 1) # shape: (N, 4, 2)
  177. # Translate to origin
  178. corners -= center_np
  179. # Rotate points
  180. theta = np.radians(angle)
  181. c, s = np.cos(theta), np.sin(theta)
  182. R = np.array([[c, -s], [s, c]])
  183. rotated_corners = corners @ R
  184. # Translate back
  185. rotated_corners += center_np
  186. # Get new bounding box coordinates
  187. x_min = np.min(rotated_corners[:, :, 0], axis=1)
  188. y_min = np.min(rotated_corners[:, :, 1], axis=1)
  189. x_max = np.max(rotated_corners[:, :, 0], axis=1)
  190. y_max = np.max(rotated_corners[:, :, 1], axis=1)
  191. # Convert back to tensor and move to the same device
  192. device = boxes.device
  193. return torch.tensor(np.stack([x_min, y_min, x_max, y_max], axis=1), dtype=boxes.dtype, device=device)
  194. def rotate_lines(self, lines, angle, center):
  195. coords = lines[..., :2] # shape: (..., 2)
  196. visibility = lines[..., 2:] # shape: (..., N)
  197. # Translate to origin
  198. coords = coords - torch.tensor(center, dtype=coords.dtype, device=coords.device)
  199. # Rotation matrix
  200. theta = torch.deg2rad(torch.tensor(angle))
  201. cos_t = torch.cos(theta)
  202. sin_t = torch.sin(theta)
  203. R = torch.tensor([[cos_t, -sin_t], [sin_t, cos_t]], dtype=coords.dtype, device=coords.device)
  204. # Apply rotation using torch.matmul
  205. rotated_coords = torch.matmul(coords, R)
  206. # Translate back
  207. rotated_coords = rotated_coords + torch.tensor(center, dtype=coords.dtype, device=coords.device)
  208. # Concatenate with visibility
  209. rotated_lines = torch.cat([rotated_coords, visibility], dim=-1)
  210. return rotated_lines
  211. def __call__(self, img, target):
  212. if random.random() < self.prob:
  213. angle = random.uniform(-self.degrees, self.degrees)
  214. w, h = img.size if isinstance(img, Image.Image) else (img.shape[2], img.shape[1])
  215. center = (w / 2, h / 2)
  216. # Rotate image
  217. img = F.rotate(img, angle, center=center)
  218. # Rotate boxes
  219. if "boxes" in target:
  220. target["boxes"] = self.rotate_boxes(target["boxes"], angle, center)
  221. # Rotate lines
  222. if "lines" in target:
  223. target["lines"] = self.rotate_lines(target["lines"], angle, center)
  224. if "circle_masks" in target:
  225. target["circle_masks"] = self.rotate_lines(target["circle_masks"], angle, center)
  226. return img, target
  227. class RandomErasing:
  228. def __init__(self, prob=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[0.485, 0.456, 0.406]):
  229. """
  230. :param prob: 应用擦除的概率
  231. :param sl: 擦除面积比例的下界
  232. :param sh: 擦除面积比例的上界
  233. :param r1: 长宽比的下界
  234. :param mean: 用于填充擦除区域的像素值
  235. """
  236. self.prob = prob
  237. self.sl = sl
  238. self.sh = sh
  239. self.r1 = r1
  240. self.mean = mean
  241. def __call__(self, img, target):
  242. if random.random() < self.prob:
  243. # 如果是Tensor,则直接处理
  244. if isinstance(img, torch.Tensor):
  245. img = self._erase_tensor(img)
  246. # 如果是PIL Image,则转换为Tensor处理后再转回PIL Image
  247. elif isinstance(img, Image.Image):
  248. img_tensor = transforms.ToTensor()(img)
  249. img_tensor = self._erase_tensor(img_tensor)
  250. img = transforms.ToPILImage()(img_tensor)
  251. return img, target
  252. def _erase_tensor(self, img_tensor):
  253. """
  254. 对Tensor类型的图像执行随机擦除
  255. """
  256. img_c, img_h, img_w = img_tensor.shape
  257. area = img_h * img_w
  258. # 计算擦除区域的大小
  259. erase_area = random.uniform(self.sl, self.sh) * area
  260. aspect_ratio = random.uniform(self.r1, 1 / self.r1)
  261. h = int(round((erase_area * aspect_ratio) ** 0.5))
  262. w = int(round((erase_area / aspect_ratio) ** 0.5))
  263. # 确保不会超出图像边界
  264. if h < img_h and w < img_w:
  265. x = random.randint(0, img_w - w)
  266. y = random.randint(0, img_h - h)
  267. # 创建一个与擦除区域相同大小且填充指定均值的区域
  268. mean_tensor = torch.tensor(self.mean).view(img_c, 1, 1).expand(img_c, h, w)
  269. # 将该区域应用到原始图像上
  270. img_tensor[:, y:y + h, x:x + w] = mean_tensor
  271. return img_tensor
# NOTE: the original author marked the RandomPerspective transform below as "has bugs".
  275. class RandomPerspective:
  276. def __init__(self, distortion_scale=0.5, p=0.5):
  277. self.distortion_scale = distortion_scale
  278. self.p = p
  279. def _get_perspective_params(self, width, height, distortion_scale):
  280. half_w = width // 2
  281. half_h = height // 2
  282. w = int(width * distortion_scale)
  283. h = int(height * distortion_scale)
  284. startpoints = [
  285. [0, 0],
  286. [width - 1, 0],
  287. [width - 1, height - 1],
  288. [0, height - 1]
  289. ]
  290. endpoints = [
  291. [random.randint(0, w), random.randint(0, h)],
  292. [width - 1 - random.randint(0, w), random.randint(0, h)],
  293. [width - 1 - random.randint(0, w), height - 1 - random.randint(0, h)],
  294. [random.randint(0, w), height - 1 - random.randint(0, h)]
  295. ]
  296. return startpoints, endpoints
  297. def perspective_boxes(self, boxes, M, width, height):
  298. # 将boxes转换为角点形式
  299. corners = np.array([
  300. [boxes[:, 0], boxes[:, 1]], # top-left
  301. [boxes[:, 2], boxes[:, 1]], # top-right
  302. [boxes[:, 2], boxes[:, 3]], # bottom-right
  303. [boxes[:, 0], boxes[:, 3]] # bottom-left
  304. ]).transpose(2, 0, 1).reshape(-1, 2) # shape: (N*4, 2)
  305. # 应用透视变换
  306. ones = np.ones((corners.shape[0], 1))
  307. coords_homogeneous = np.hstack([corners, ones])
  308. transformed_coords = (M @ coords_homogeneous.T).T
  309. transformed_coords /= transformed_coords[:, 2].reshape(-1, 1) # 齐次除法
  310. transformed_coords = transformed_coords[:, :2]
  311. # 重新组合成bounding box
  312. transformed_coords = transformed_coords.reshape(-1, 4, 2)
  313. x_min = np.min(transformed_coords[:, :, 0], axis=1)
  314. y_min = np.min(transformed_coords[:, :, 1], axis=1)
  315. x_max = np.max(transformed_coords[:, :, 0], axis=1)
  316. y_max = np.max(transformed_coords[:, :, 1], axis=1)
  317. # 裁剪到图像范围内
  318. x_min = np.clip(x_min, 0, width)
  319. y_min = np.clip(y_min, 0, height)
  320. x_max = np.clip(x_max, 0, width)
  321. y_max = np.clip(y_max, 0, height)
  322. return torch.tensor(np.stack([x_min, y_min, x_max, y_max], axis=1), dtype=boxes.dtype, device=boxes.device)
  323. def perspective_lines(self, lines, M, width, height):
  324. # 提取坐标和可见性标志
  325. coords = lines[..., :2].cpu().numpy() # Shape: (N, L, 2)
  326. visibility = lines[..., 2:]
  327. # 确保coords是二维数组,如果它是三维的,则将其重塑为二维
  328. original_shape = coords.shape
  329. coords_reshaped = coords.reshape(-1, 2) # Reshape to (N*L, 2)
  330. # 添加齐次坐标
  331. ones = np.ones((coords_reshaped.shape[0], 1))
  332. coords_homogeneous = np.hstack([coords_reshaped, ones]) # Shape: (N*L, 3)
  333. # 应用透视变换矩阵
  334. transformed_coords_homogeneous = np.dot(M, coords_homogeneous.T).T
  335. transformed_coords = transformed_coords_homogeneous[:, :2] / transformed_coords_homogeneous[:, 2:] # 归一化
  336. # 将变换后的坐标恢复到原始形状
  337. transformed_coords = transformed_coords.reshape(original_shape) # Reshape back to (N, L, 2)
  338. # 裁剪到图像范围内
  339. transformed_coords = np.clip(transformed_coords, [0, 0], [width, height])
  340. # 转换回tensor
  341. transformed_coords = torch.tensor(transformed_coords, dtype=lines.dtype, device=lines.device)
  342. return torch.cat([transformed_coords, visibility], dim=-1)
  343. def __call__(self, img, target):
  344. if random.random() < self.p:
  345. width, height = img.size if isinstance(img, Image.Image) else (img.shape[2], img.shape[1])
  346. startpoints, endpoints = self._get_perspective_params(width, height, self.distortion_scale)
  347. # 使用 OpenCV 计算透视变换矩阵
  348. M = cv2.getPerspectiveTransform(
  349. np.float32(startpoints),
  350. np.float32(endpoints)
  351. )
  352. # 对图像应用透视变换
  353. if isinstance(img, Image.Image):
  354. img = img.transform((width, height), Image.PERSPECTIVE, M.flatten(), resample=Image.BILINEAR)
  355. elif isinstance(img, torch.Tensor):
  356. # 如果你需要用 TorchVision 实现,可以考虑使用 F.perspective,但更推荐配合PIL操作
  357. pil_img = F.to_pil_image(img)
  358. pil_img = pil_img.transform((width, height), Image.PERSPECTIVE, M.flatten(), resample=Image.BILINEAR)
  359. img = F.to_tensor(pil_img)
  360. # 对 boxes 变换
  361. if "boxes" in target:
  362. target["boxes"] = self.perspective_boxes(target["boxes"], M, width, height)
  363. # 对 lines 变换
  364. if "lines" in target:
  365. target["lines"] = self.perspective_lines(target["lines"], M, width, height)
  366. if "circle_masks" in target:
  367. target["circle_masks"] = self.perspective_lines(target["circle_masks"], M, width, height)
  368. return img, target
  369. class DefaultTransform(nn.Module):
  370. def forward(self, img: Tensor, target) -> Tuple[Tensor, Any]:
  371. if not isinstance(img, Tensor):
  372. img = F.pil_to_tensor(img)
  373. return F.convert_image_dtype(img, torch.float),target
  374. def __repr__(self) -> str:
  375. return self.__class__.__name__ + "()"
  376. def describe(self) -> str:
  377. return (
  378. "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
  379. "The images are rescaled to ``[0.0, 1.0]``."
  380. )
  381. class ToTensor:
  382. def __call__(self, img, target):
  383. img = F.to_tensor(img)
  384. return img, target
  385. def get_transforms(augmention=True):
  386. transforms_list = []
  387. if augmention:
  388. # transforms_list.append(ColorJitter())
  389. transforms_list.append(RandomGrayscale(0.1))
  390. transforms_list.append(GaussianBlur())
  391. # transforms_list.append(RandomErasing())
  392. transforms_list.append(RandomHorizontalFlip(0.5))
  393. transforms_list.append(RandomVerticalFlip(0.5))
  394. # transforms_list.append(RandomPerspective())
  395. transforms_list.append(RandomRotation(degrees=15))
  396. # transforms_list.append(RandomResize(512, 2048))
  397. # transforms_list.append(RandomCrop((512,512)))
  398. transforms_list.append(DefaultTransform())
  399. return Compose(transforms_list)