_box_convert.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. import torch
  2. from torch import Tensor
  3. def _box_cxcywh_to_xyxy(boxes: Tensor) -> Tensor:
  4. """
  5. Converts bounding boxes from (cx, cy, w, h) format to (x1, y1, x2, y2) format.
  6. (cx, cy) refers to center of bounding box
  7. (w, h) are width and height of bounding box
  8. Args:
  9. boxes (Tensor[N, 4]): boxes in (cx, cy, w, h) format which will be converted.
  10. Returns:
  11. boxes (Tensor(N, 4)): boxes in (x1, y1, x2, y2) format.
  12. """
  13. # We need to change all 4 of them so some temporary variable is needed.
  14. cx, cy, w, h = boxes.unbind(-1)
  15. x1 = cx - 0.5 * w
  16. y1 = cy - 0.5 * h
  17. x2 = cx + 0.5 * w
  18. y2 = cy + 0.5 * h
  19. boxes = torch.stack((x1, y1, x2, y2), dim=-1)
  20. return boxes
  21. def _box_xyxy_to_cxcywh(boxes: Tensor) -> Tensor:
  22. """
  23. Converts bounding boxes from (x1, y1, x2, y2) format to (cx, cy, w, h) format.
  24. (x1, y1) refer to top left of bounding box
  25. (x2, y2) refer to bottom right of bounding box
  26. Args:
  27. boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format which will be converted.
  28. Returns:
  29. boxes (Tensor(N, 4)): boxes in (cx, cy, w, h) format.
  30. """
  31. x1, y1, x2, y2 = boxes.unbind(-1)
  32. cx = (x1 + x2) / 2
  33. cy = (y1 + y2) / 2
  34. w = x2 - x1
  35. h = y2 - y1
  36. boxes = torch.stack((cx, cy, w, h), dim=-1)
  37. return boxes
  38. def _box_xywh_to_xyxy(boxes: Tensor) -> Tensor:
  39. """
  40. Converts bounding boxes from (x, y, w, h) format to (x1, y1, x2, y2) format.
  41. (x, y) refers to top left of bounding box.
  42. (w, h) refers to width and height of box.
  43. Args:
  44. boxes (Tensor[N, 4]): boxes in (x, y, w, h) which will be converted.
  45. Returns:
  46. boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format.
  47. """
  48. x, y, w, h = boxes.unbind(-1)
  49. boxes = torch.stack([x, y, x + w, y + h], dim=-1)
  50. return boxes
  51. def _box_xyxy_to_xywh(boxes: Tensor) -> Tensor:
  52. """
  53. Converts bounding boxes from (x1, y1, x2, y2) format to (x, y, w, h) format.
  54. (x1, y1) refer to top left of bounding box
  55. (x2, y2) refer to bottom right of bounding box
  56. Args:
  57. boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) which will be converted.
  58. Returns:
  59. boxes (Tensor[N, 4]): boxes in (x, y, w, h) format.
  60. """
  61. x1, y1, x2, y2 = boxes.unbind(-1)
  62. w = x2 - x1 # x2 - x1
  63. h = y2 - y1 # y2 - y1
  64. boxes = torch.stack((x1, y1, w, h), dim=-1)
  65. return boxes