123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236 |
- use anyhow::Result;
- use clap::Parser;
- use usls::{
- models::YOLO, Annotator, DataLoader, Device, Options, Viewer, Vision, YOLOScale, YOLOTask,
- YOLOVersion, COCO_SKELETONS_16,
- };
- #[derive(Parser, Clone)]
- #[command(author, version, about, long_about = None)]
- pub struct Args {
- /// Path to the ONNX model
- #[arg(long)]
- pub model: Option<String>,
- /// Input source path
- #[arg(long, default_value_t = String::from("../../ultralytics/assets/bus.jpg"))]
- pub source: String,
- /// YOLO Task
- #[arg(long, value_enum, default_value_t = YOLOTask::Detect)]
- pub task: YOLOTask,
- /// YOLO Version
- #[arg(long, value_enum, default_value_t = YOLOVersion::V8)]
- pub ver: YOLOVersion,
- /// YOLO Scale
- #[arg(long, value_enum, default_value_t = YOLOScale::N)]
- pub scale: YOLOScale,
- /// Batch size
- #[arg(long, default_value_t = 1)]
- pub batch_size: usize,
- /// Minimum input width
- #[arg(long, default_value_t = 224)]
- pub width_min: isize,
- /// Input width
- #[arg(long, default_value_t = 640)]
- pub width: isize,
- /// Maximum input width
- #[arg(long, default_value_t = 1024)]
- pub width_max: isize,
- /// Minimum input height
- #[arg(long, default_value_t = 224)]
- pub height_min: isize,
- /// Input height
- #[arg(long, default_value_t = 640)]
- pub height: isize,
- /// Maximum input height
- #[arg(long, default_value_t = 1024)]
- pub height_max: isize,
- /// Number of classes
- #[arg(long, default_value_t = 80)]
- pub nc: usize,
- /// Class confidence
- #[arg(long)]
- pub confs: Vec<f32>,
- /// Enable TensorRT support
- #[arg(long)]
- pub trt: bool,
- /// Enable CUDA support
- #[arg(long)]
- pub cuda: bool,
- /// Enable CoreML support
- #[arg(long)]
- pub coreml: bool,
- /// Use TensorRT half precision
- #[arg(long)]
- pub half: bool,
- /// Device ID to use
- #[arg(long, default_value_t = 0)]
- pub device_id: usize,
- /// Enable performance profiling
- #[arg(long)]
- pub profile: bool,
- /// Disable contour drawing, for saving time
- #[arg(long)]
- pub no_contours: bool,
- /// Show result
- #[arg(long)]
- pub view: bool,
- /// Do not save output
- #[arg(long)]
- pub nosave: bool,
- }
- fn main() -> Result<()> {
- let args = Args::parse();
- // logger
- if args.profile {
- tracing_subscriber::fmt()
- .with_max_level(tracing::Level::INFO)
- .init();
- }
- // model path
- let path = match &args.model {
- None => format!(
- "yolo/{}-{}-{}.onnx",
- args.ver.name(),
- args.scale.name(),
- args.task.name()
- ),
- Some(x) => x.to_string(),
- };
- // saveout
- let saveout = match &args.model {
- None => format!(
- "{}-{}-{}",
- args.ver.name(),
- args.scale.name(),
- args.task.name()
- ),
- Some(x) => {
- let p = std::path::PathBuf::from(&x);
- p.file_stem().unwrap().to_str().unwrap().to_string()
- }
- };
- // device
- let device = if args.cuda {
- Device::Cuda(args.device_id)
- } else if args.trt {
- Device::Trt(args.device_id)
- } else if args.coreml {
- Device::CoreML(args.device_id)
- } else {
- Device::Cpu(args.device_id)
- };
- // build options
- let options = Options::new()
- .with_model(&path)?
- .with_yolo_version(args.ver)
- .with_yolo_task(args.task)
- .with_device(device)
- .with_trt_fp16(args.half)
- .with_ixx(0, 0, (1, args.batch_size as _, 4).into())
- .with_ixx(0, 2, (args.height_min, args.height, args.height_max).into())
- .with_ixx(0, 3, (args.width_min, args.width, args.width_max).into())
- .with_confs(if args.confs.is_empty() {
- &[0.2, 0.15]
- } else {
- &args.confs
- })
- .with_nc(args.nc)
- .with_find_contours(!args.no_contours) // find contours or not
- // .with_names(&COCO_CLASS_NAMES_80) // detection class names
- // .with_names2(&COCO_KEYPOINTS_17) // keypoints class names
- // .exclude_classes(&[0])
- // .retain_classes(&[0, 5])
- .with_profile(args.profile);
- // build model
- let mut model = YOLO::new(options)?;
- // build dataloader
- let dl = DataLoader::new(&args.source)?
- .with_batch(model.batch() as _)
- .build()?;
- // build annotator
- let annotator = Annotator::default()
- .with_skeletons(&COCO_SKELETONS_16)
- .without_masks(true) // no masks plotting when doing segment task
- .with_bboxes_thickness(3)
- .with_keypoints_name(false) // enable keypoints names
- .with_saveout_subs(&["YOLO"])
- .with_saveout(&saveout);
- // build viewer
- let mut viewer = if args.view {
- Some(Viewer::new().with_delay(5).with_scale(1.).resizable(true))
- } else {
- None
- };
- // run & annotate
- for (xs, _paths) in dl {
- let ys = model.forward(&xs, args.profile)?;
- let images_plotted = annotator.plot(&xs, &ys, !args.nosave)?;
- // show image
- match &mut viewer {
- Some(viewer) => viewer.imshow(&images_plotted)?,
- None => continue,
- }
- // check out window and key event
- match &mut viewer {
- Some(viewer) => {
- if !viewer.is_open() || viewer.is_key_pressed(usls::Key::Escape) {
- break;
- }
- }
- None => continue,
- }
- // write video
- if !args.nosave {
- match &mut viewer {
- Some(viewer) => viewer.write_batch(&images_plotted)?,
- None => continue,
- }
- }
- }
- // finish video write
- if !args.nosave {
- if let Some(viewer) = &mut viewer {
- viewer.finish_write()?;
- }
- }
- Ok(())
- }
|