main.rs 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. use anyhow::Result;
  2. use clap::Parser;
  3. use usls::{
  4. models::YOLO, Annotator, DataLoader, Device, Options, Viewer, Vision, YOLOScale, YOLOTask,
  5. YOLOVersion, COCO_SKELETONS_16,
  6. };
  7. #[derive(Parser, Clone)]
  8. #[command(author, version, about, long_about = None)]
  9. pub struct Args {
  10. /// Path to the ONNX model
  11. #[arg(long)]
  12. pub model: Option<String>,
  13. /// Input source path
  14. #[arg(long, default_value_t = String::from("../../ultralytics/assets/bus.jpg"))]
  15. pub source: String,
  16. /// YOLO Task
  17. #[arg(long, value_enum, default_value_t = YOLOTask::Detect)]
  18. pub task: YOLOTask,
  19. /// YOLO Version
  20. #[arg(long, value_enum, default_value_t = YOLOVersion::V8)]
  21. pub ver: YOLOVersion,
  22. /// YOLO Scale
  23. #[arg(long, value_enum, default_value_t = YOLOScale::N)]
  24. pub scale: YOLOScale,
  25. /// Batch size
  26. #[arg(long, default_value_t = 1)]
  27. pub batch_size: usize,
  28. /// Minimum input width
  29. #[arg(long, default_value_t = 224)]
  30. pub width_min: isize,
  31. /// Input width
  32. #[arg(long, default_value_t = 640)]
  33. pub width: isize,
  34. /// Maximum input width
  35. #[arg(long, default_value_t = 1024)]
  36. pub width_max: isize,
  37. /// Minimum input height
  38. #[arg(long, default_value_t = 224)]
  39. pub height_min: isize,
  40. /// Input height
  41. #[arg(long, default_value_t = 640)]
  42. pub height: isize,
  43. /// Maximum input height
  44. #[arg(long, default_value_t = 1024)]
  45. pub height_max: isize,
  46. /// Number of classes
  47. #[arg(long, default_value_t = 80)]
  48. pub nc: usize,
  49. /// Class confidence
  50. #[arg(long)]
  51. pub confs: Vec<f32>,
  52. /// Enable TensorRT support
  53. #[arg(long)]
  54. pub trt: bool,
  55. /// Enable CUDA support
  56. #[arg(long)]
  57. pub cuda: bool,
  58. /// Enable CoreML support
  59. #[arg(long)]
  60. pub coreml: bool,
  61. /// Use TensorRT half precision
  62. #[arg(long)]
  63. pub half: bool,
  64. /// Device ID to use
  65. #[arg(long, default_value_t = 0)]
  66. pub device_id: usize,
  67. /// Enable performance profiling
  68. #[arg(long)]
  69. pub profile: bool,
  70. /// Disable contour drawing, for saving time
  71. #[arg(long)]
  72. pub no_contours: bool,
  73. /// Show result
  74. #[arg(long)]
  75. pub view: bool,
  76. /// Do not save output
  77. #[arg(long)]
  78. pub nosave: bool,
  79. }
  80. fn main() -> Result<()> {
  81. let args = Args::parse();
  82. // logger
  83. if args.profile {
  84. tracing_subscriber::fmt()
  85. .with_max_level(tracing::Level::INFO)
  86. .init();
  87. }
  88. // model path
  89. let path = match &args.model {
  90. None => format!(
  91. "yolo/{}-{}-{}.onnx",
  92. args.ver.name(),
  93. args.scale.name(),
  94. args.task.name()
  95. ),
  96. Some(x) => x.to_string(),
  97. };
  98. // saveout
  99. let saveout = match &args.model {
  100. None => format!(
  101. "{}-{}-{}",
  102. args.ver.name(),
  103. args.scale.name(),
  104. args.task.name()
  105. ),
  106. Some(x) => {
  107. let p = std::path::PathBuf::from(&x);
  108. p.file_stem().unwrap().to_str().unwrap().to_string()
  109. }
  110. };
  111. // device
  112. let device = if args.cuda {
  113. Device::Cuda(args.device_id)
  114. } else if args.trt {
  115. Device::Trt(args.device_id)
  116. } else if args.coreml {
  117. Device::CoreML(args.device_id)
  118. } else {
  119. Device::Cpu(args.device_id)
  120. };
  121. // build options
  122. let options = Options::new()
  123. .with_model(&path)?
  124. .with_yolo_version(args.ver)
  125. .with_yolo_task(args.task)
  126. .with_device(device)
  127. .with_trt_fp16(args.half)
  128. .with_ixx(0, 0, (1, args.batch_size as _, 4).into())
  129. .with_ixx(0, 2, (args.height_min, args.height, args.height_max).into())
  130. .with_ixx(0, 3, (args.width_min, args.width, args.width_max).into())
  131. .with_confs(if args.confs.is_empty() {
  132. &[0.2, 0.15]
  133. } else {
  134. &args.confs
  135. })
  136. .with_nc(args.nc)
  137. .with_find_contours(!args.no_contours) // find contours or not
  138. // .with_names(&COCO_CLASS_NAMES_80) // detection class names
  139. // .with_names2(&COCO_KEYPOINTS_17) // keypoints class names
  140. // .exclude_classes(&[0])
  141. // .retain_classes(&[0, 5])
  142. .with_profile(args.profile);
  143. // build model
  144. let mut model = YOLO::new(options)?;
  145. // build dataloader
  146. let dl = DataLoader::new(&args.source)?
  147. .with_batch(model.batch() as _)
  148. .build()?;
  149. // build annotator
  150. let annotator = Annotator::default()
  151. .with_skeletons(&COCO_SKELETONS_16)
  152. .without_masks(true) // no masks plotting when doing segment task
  153. .with_bboxes_thickness(3)
  154. .with_keypoints_name(false) // enable keypoints names
  155. .with_saveout_subs(&["YOLO"])
  156. .with_saveout(&saveout);
  157. // build viewer
  158. let mut viewer = if args.view {
  159. Some(Viewer::new().with_delay(5).with_scale(1.).resizable(true))
  160. } else {
  161. None
  162. };
  163. // run & annotate
  164. for (xs, _paths) in dl {
  165. let ys = model.forward(&xs, args.profile)?;
  166. let images_plotted = annotator.plot(&xs, &ys, !args.nosave)?;
  167. // show image
  168. match &mut viewer {
  169. Some(viewer) => viewer.imshow(&images_plotted)?,
  170. None => continue,
  171. }
  172. // check out window and key event
  173. match &mut viewer {
  174. Some(viewer) => {
  175. if !viewer.is_open() || viewer.is_key_pressed(usls::Key::Escape) {
  176. break;
  177. }
  178. }
  179. None => continue,
  180. }
  181. // write video
  182. if !args.nosave {
  183. match &mut viewer {
  184. Some(viewer) => viewer.write_batch(&images_plotted)?,
  185. None => continue,
  186. }
  187. }
  188. }
  189. // finish video write
  190. if !args.nosave {
  191. if let Some(viewer) = &mut viewer {
  192. viewer.finish_write()?;
  193. }
  194. }
  195. Ok(())
  196. }