lib.rs 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. #![allow(clippy::type_complexity)]
  2. use std::io::{Read, Write};
  3. pub mod cli;
  4. pub mod model;
  5. pub mod ort_backend;
  6. pub mod yolo_result;
  7. pub use crate::cli::Args;
  8. pub use crate::model::YOLOv8;
  9. pub use crate::ort_backend::{Batch, OrtBackend, OrtConfig, OrtEP, YOLOTask};
  10. pub use crate::yolo_result::{Bbox, Embedding, Point2, YOLOResult};
  11. pub fn non_max_suppression(
  12. xs: &mut Vec<(Bbox, Option<Vec<Point2>>, Option<Vec<f32>>)>,
  13. iou_threshold: f32,
  14. ) {
  15. xs.sort_by(|b1, b2| b2.0.confidence().partial_cmp(&b1.0.confidence()).unwrap());
  16. let mut current_index = 0;
  17. for index in 0..xs.len() {
  18. let mut drop = false;
  19. for prev_index in 0..current_index {
  20. let iou = xs[prev_index].0.iou(&xs[index].0);
  21. if iou > iou_threshold {
  22. drop = true;
  23. break;
  24. }
  25. }
  26. if !drop {
  27. xs.swap(current_index, index);
  28. current_index += 1;
  29. }
  30. }
  31. xs.truncate(current_index);
  32. }
  33. pub fn gen_time_string(delimiter: &str) -> String {
  34. let offset = chrono::FixedOffset::east_opt(8 * 60 * 60).unwrap(); // Beijing
  35. let t_now = chrono::Utc::now().with_timezone(&offset);
  36. let fmt = format!(
  37. "%Y{}%m{}%d{}%H{}%M{}%S{}%f",
  38. delimiter, delimiter, delimiter, delimiter, delimiter, delimiter
  39. );
  40. t_now.format(&fmt).to_string()
  41. }
  /// Index pairs describing the edges of a pose skeleton, 16 edges over keypoint indices 0..=16.
  ///
  /// NOTE(review): the indices presumably follow the 17-keypoint COCO ordering. That would make
  /// 0–4 the face, 5–10 shoulders/elbows/wrists, and 11–16 hips/knees/ankles. Confirm against the
  /// pose model's keypoint layout before relying on these labels.
  #[rustfmt::skip]
  pub const SKELETON: [(usize, usize); 16] = [
      // Indices 0..=4, fanning out from 0.
      (0, 1), (0, 2), (1, 3), (2, 4),
      // Closed loop over 5, 6, 12, 11.
      (5, 6), (5, 11), (6, 12), (11, 12),
      // Chains 5-7-9 and 6-8-10.
      (5, 7), (6, 8), (7, 9), (8, 10),
      // Chains 11-13-15 and 12-14-16.
      (11, 13), (12, 14), (13, 15), (14, 16),
  ];
  60. pub fn check_font(font: &str) -> rusttype::Font<'static> {
  61. // check then load font
  62. // ultralytics font path
  63. let font_path_config = match dirs::config_dir() {
  64. Some(mut d) => {
  65. d.push("Ultralytics");
  66. d.push(font);
  67. d
  68. }
  69. None => panic!("Unsupported operating system. Now support Linux, MacOS, Windows."),
  70. };
  71. // current font path
  72. let font_path_current = std::path::PathBuf::from(font);
  73. // check font
  74. let font_path = if font_path_config.exists() {
  75. font_path_config
  76. } else if font_path_current.exists() {
  77. font_path_current
  78. } else {
  79. println!("Downloading font...");
  80. let source_url = "https://ultralytics.com/assets/Arial.ttf";
  81. let resp = ureq::get(source_url)
  82. .timeout(std::time::Duration::from_secs(500))
  83. .call()
  84. .unwrap_or_else(|err| panic!("> Failed to download font: {source_url}: {err:?}"));
  85. // read to buffer
  86. let mut buffer = vec![];
  87. let total_size = resp
  88. .header("Content-Length")
  89. .and_then(|s| s.parse::<u64>().ok())
  90. .unwrap();
  91. let _reader = resp
  92. .into_reader()
  93. .take(total_size)
  94. .read_to_end(&mut buffer)
  95. .unwrap();
  96. // save
  97. let _path = std::fs::File::create(font).unwrap();
  98. let mut writer = std::io::BufWriter::new(_path);
  99. writer.write_all(&buffer).unwrap();
  100. println!("Font saved at: {:?}", font_path_current.display());
  101. font_path_current
  102. };
  103. // load font
  104. let buffer = std::fs::read(font_path).unwrap();
  105. rusttype::Font::try_from_vec(buffer).unwrap()
  106. }
  107. use ab_glyph::FontArc;
  108. pub fn load_font() -> FontArc {
  109. use std::path::Path;
  110. let font_path = Path::new("./font/Arial.ttf");
  111. match font_path.try_exists() {
  112. Ok(true) => {
  113. let buffer = std::fs::read(font_path).unwrap();
  114. FontArc::try_from_vec(buffer).unwrap()
  115. }
  116. Ok(false) => {
  117. std::fs::create_dir_all("./font").unwrap();
  118. println!("Downloading font...");
  119. let source_url = "https://ultralytics.com/assets/Arial.ttf";
  120. let resp = ureq::get(source_url)
  121. .timeout(std::time::Duration::from_secs(500))
  122. .call()
  123. .unwrap_or_else(|err| panic!("> Failed to download font: {source_url}: {err:?}"));
  124. // read to buffer
  125. let mut buffer = vec![];
  126. let total_size = resp
  127. .header("Content-Length")
  128. .and_then(|s| s.parse::<u64>().ok())
  129. .unwrap();
  130. let _reader = resp
  131. .into_reader()
  132. .take(total_size)
  133. .read_to_end(&mut buffer)
  134. .unwrap();
  135. // save
  136. let mut fd = std::fs::File::create(font_path).unwrap();
  137. fd.write_all(&buffer).unwrap();
  138. println!("Font saved at: {:?}", font_path.display());
  139. FontArc::try_from_vec(buffer).unwrap()
  140. }
  141. Err(e) => {
  142. panic!("Failed to load font {}", e);
  143. }
  144. }
  145. }