yolo_result.rs 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. use ndarray::{Array, Axis, IxDyn};
  2. #[derive(Clone, PartialEq, Default)]
  3. pub struct YOLOResult {
  4. // YOLO tasks results of an image
  5. pub probs: Option<Embedding>,
  6. pub bboxes: Option<Vec<Bbox>>,
  7. pub keypoints: Option<Vec<Vec<Point2>>>,
  8. pub masks: Option<Vec<Vec<u8>>>,
  9. }
  10. impl std::fmt::Debug for YOLOResult {
  11. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  12. f.debug_struct("YOLOResult")
  13. .field(
  14. "Probs(top5)",
  15. &format_args!("{:?}", self.probs().map(|probs| probs.topk(5))),
  16. )
  17. .field("Bboxes", &self.bboxes)
  18. .field("Keypoints", &self.keypoints)
  19. .field(
  20. "Masks",
  21. &format_args!("{:?}", self.masks().map(|masks| masks.len())),
  22. )
  23. .finish()
  24. }
  25. }
  26. impl YOLOResult {
  27. pub fn new(
  28. probs: Option<Embedding>,
  29. bboxes: Option<Vec<Bbox>>,
  30. keypoints: Option<Vec<Vec<Point2>>>,
  31. masks: Option<Vec<Vec<u8>>>,
  32. ) -> Self {
  33. Self {
  34. probs,
  35. bboxes,
  36. keypoints,
  37. masks,
  38. }
  39. }
  40. pub fn probs(&self) -> Option<&Embedding> {
  41. self.probs.as_ref()
  42. }
  43. pub fn keypoints(&self) -> Option<&Vec<Vec<Point2>>> {
  44. self.keypoints.as_ref()
  45. }
  46. pub fn masks(&self) -> Option<&Vec<Vec<u8>>> {
  47. self.masks.as_ref()
  48. }
  49. pub fn bboxes(&self) -> Option<&Vec<Bbox>> {
  50. self.bboxes.as_ref()
  51. }
  52. pub fn bboxes_mut(&mut self) -> Option<&mut Vec<Bbox>> {
  53. self.bboxes.as_mut()
  54. }
  55. }
  /// A 2-D point with an attached confidence value.
  #[derive(Debug, PartialEq, Clone, Default)]
  pub struct Point2 {
      /// Horizontal coordinate.
      x: f32,
      /// Vertical coordinate.
      y: f32,
      /// Confidence attached to the point; `0.0` when none was supplied.
      confidence: f32,
  }

  impl Point2 {
      /// Creates a point at (`x`, `y`) carrying the given `confidence`.
      pub fn new_with_conf(x: f32, y: f32, confidence: f32) -> Self {
          Self { x, y, confidence }
      }

      /// Creates a point at (`x`, `y`) with the default confidence (`0.0`).
      pub fn new(x: f32, y: f32) -> Self {
          Self::new_with_conf(x, y, f32::default())
      }

      /// Horizontal coordinate.
      pub fn x(&self) -> f32 {
          self.x
      }

      /// Vertical coordinate.
      pub fn y(&self) -> f32 {
          self.y
      }

      /// Confidence attached to this point.
      pub fn confidence(&self) -> f32 {
          self.confidence
      }
  }
  84. #[derive(Debug, Clone, PartialEq, Default)]
  85. pub struct Embedding {
  86. // An float32 n-dims tensor
  87. data: Array<f32, IxDyn>,
  88. }
  89. impl Embedding {
  90. pub fn new(data: Array<f32, IxDyn>) -> Self {
  91. Self { data }
  92. }
  93. pub fn data(&self) -> &Array<f32, IxDyn> {
  94. &self.data
  95. }
  96. pub fn topk(&self, k: usize) -> Vec<(usize, f32)> {
  97. let mut probs = self
  98. .data
  99. .iter()
  100. .enumerate()
  101. .map(|(a, b)| (a, *b))
  102. .collect::<Vec<_>>();
  103. probs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
  104. let mut topk = Vec::new();
  105. for &(id, confidence) in probs.iter().take(k) {
  106. topk.push((id, confidence));
  107. }
  108. topk
  109. }
  110. pub fn norm(&self) -> Array<f32, IxDyn> {
  111. let std_ = self.data.mapv(|x| x * x).sum_axis(Axis(0)).mapv(f32::sqrt);
  112. self.data.clone() / std_
  113. }
  114. pub fn top1(&self) -> (usize, f32) {
  115. self.topk(1)[0]
  116. }
  117. }
  118. #[derive(Debug, Clone, PartialEq, Default)]
  119. pub struct Bbox {
  120. // a bounding box around an object
  121. xmin: f32,
  122. ymin: f32,
  123. width: f32,
  124. height: f32,
  125. id: usize,
  126. confidence: f32,
  127. }
  128. impl Bbox {
  129. pub fn new_from_xywh(xmin: f32, ymin: f32, width: f32, height: f32) -> Self {
  130. Self {
  131. xmin,
  132. ymin,
  133. width,
  134. height,
  135. ..Default::default()
  136. }
  137. }
  138. pub fn new(xmin: f32, ymin: f32, width: f32, height: f32, id: usize, confidence: f32) -> Self {
  139. Self {
  140. xmin,
  141. ymin,
  142. width,
  143. height,
  144. id,
  145. confidence,
  146. }
  147. }
  148. pub fn width(&self) -> f32 {
  149. self.width
  150. }
  151. pub fn height(&self) -> f32 {
  152. self.height
  153. }
  154. pub fn xmin(&self) -> f32 {
  155. self.xmin
  156. }
  157. pub fn ymin(&self) -> f32 {
  158. self.ymin
  159. }
  160. pub fn xmax(&self) -> f32 {
  161. self.xmin + self.width
  162. }
  163. pub fn ymax(&self) -> f32 {
  164. self.ymin + self.height
  165. }
  166. pub fn tl(&self) -> Point2 {
  167. Point2::new(self.xmin, self.ymin)
  168. }
  169. pub fn br(&self) -> Point2 {
  170. Point2::new(self.xmax(), self.ymax())
  171. }
  172. pub fn cxcy(&self) -> Point2 {
  173. Point2::new(self.xmin + self.width / 2., self.ymin + self.height / 2.)
  174. }
  175. pub fn id(&self) -> usize {
  176. self.id
  177. }
  178. pub fn confidence(&self) -> f32 {
  179. self.confidence
  180. }
  181. pub fn area(&self) -> f32 {
  182. self.width * self.height
  183. }
  184. pub fn intersection_area(&self, another: &Bbox) -> f32 {
  185. let l = self.xmin.max(another.xmin);
  186. let r = (self.xmin + self.width).min(another.xmin + another.width);
  187. let t = self.ymin.max(another.ymin);
  188. let b = (self.ymin + self.height).min(another.ymin + another.height);
  189. (r - l + 1.).max(0.) * (b - t + 1.).max(0.)
  190. }
  191. pub fn union(&self, another: &Bbox) -> f32 {
  192. self.area() + another.area() - self.intersection_area(another)
  193. }
  194. pub fn iou(&self, another: &Bbox) -> f32 {
  195. self.intersection_area(another) / self.union(another)
  196. }
  197. }