- use anyhow::Result;
- use clap::ValueEnum;
- use half::f16;
- use ndarray::{Array, CowArray, IxDyn};
- use ort::{
- CPUExecutionProvider, CUDAExecutionProvider, ExecutionProvider, ExecutionProviderDispatch,
- TensorRTExecutionProvider,
- };
- use ort::{Session, SessionBuilder};
- use ort::{TensorElementType, ValueType};
- use regex::Regex;
- #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
- pub enum YOLOTask {
- // YOLO tasks
- Classify,
- Detect,
- Pose,
- Segment,
- }
- #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
- pub enum OrtEP {
- // ONNXRuntime execution provider
- CPU,
- CUDA(i32),
- Trt(i32),
- }
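- // For example, `OrtEP::CUDA(0)` runs on GPU 0 via the CUDA provider, while `OrtEP::Trt(0)`
- // runs on GPU 0 via TensorRT (falling back to CUDA, then CPU, if unavailable).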
- #[derive(Debug)]
- pub struct Batch {
- pub opt: u32,
- pub min: u32,
- pub max: u32,
- }
- impl Default for Batch {
- fn default() -> Self {
- Self {
- opt: 1,
- min: 1,
- max: 1,
- }
- }
- }
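- // Example: `Batch { opt: 1, min: 1, max: 4 }` asks the TensorRT provider below to build an
- // optimization profile covering batch sizes 1 through 4 (see `set_ep_trt`).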
- #[derive(Debug, Default)]
- pub struct OrtInputs {
- // ONNX model input attributes
- pub shapes: Vec<Vec<i64>>,
- //pub dtypes: Vec<TensorElementDataType>,
- pub dtypes: Vec<TensorElementType>,
- pub names: Vec<String>,
- pub sizes: Vec<Vec<u32>>,
- }
- impl OrtInputs {
- pub fn new(session: &Session) -> Self {
- let mut shapes = Vec::new();
- let mut dtypes = Vec::new();
- let mut names = Vec::new();
- for i in session.inputs.iter() {
- /* let shape: Vec<i32> = i
- .dimensions()
- .map(|x| if let Some(x) = x { x as i32 } else { -1i32 })
- .collect();
- shapes.push(shape); */
- if let ort::ValueType::Tensor { ty, dimensions } = &i.input_type {
- dtypes.push(ty.clone());
- let shape = dimensions.clone();
- shapes.push(shape);
- } else {
- panic!("不支持的数据格式, {} - {}", file!(), line!());
- }
- //dtypes.push(i.input_type);
- names.push(i.name.clone());
- }
- Self {
- shapes,
- dtypes,
- names,
- ..Default::default()
- }
- }
- }
- #[derive(Debug)]
- pub struct OrtConfig {
- // ORT config
- pub f: String,
- pub task: Option<YOLOTask>,
- pub ep: OrtEP,
- pub trt_fp16: bool,
- pub batch: Batch,
- pub image_size: (Option<u32>, Option<u32>),
- }
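- // A sketch of constructing a config (hypothetical path and values):
- // let cfg = OrtConfig {
- //     f: "yolov8n.onnx".to_string(),      // path to the ONNX model file
- //     task: Some(YOLOTask::Detect),       // or None to read the task from model metadata
- //     ep: OrtEP::CUDA(0),
- //     trt_fp16: false,
- //     batch: Batch::default(),
- //     image_size: (Some(640), Some(640)), // fallbacks if the model height/width are dynamic
- // };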
- #[derive(Debug)]
- pub struct OrtBackend {
- // ORT engine
- session: Session,
- task: YOLOTask,
- ep: OrtEP,
- batch: Batch,
- inputs: OrtInputs,
- }
- impl OrtBackend {
- pub fn build(args: OrtConfig) -> Result<Self> {
- // build env & session
- // NOTE: in ort 2.x the explicit Environment has been removed
- /* let env = ort::EnvironmentBuilder
- ::with_name("YOLOv8")
- .build()?
- .into_arc(); */
- let sessionbuilder = SessionBuilder::new()?;
- let session = sessionbuilder.commit_from_file(&args.f)?;
- //let session = SessionBuilder::new(&env)?.with_model_from_file(&args.f)?;
- // get inputs
- let mut inputs = OrtInputs::new(&session);
- // batch size
- let mut batch = args.batch;
- let batch = if inputs.shapes[0][0] == -1 {
- batch
- } else {
- assert_eq!(
- inputs.shapes[0][0] as u32, batch.opt,
- "Expected batch size: {}, got {}. Try using `--batch {}`.",
- inputs.shapes[0][0] as u32, batch.opt, inputs.shapes[0][0] as u32
- );
- batch.opt = inputs.shapes[0][0] as u32;
- batch
- };
- // input size: height and width
- let height = if inputs.shapes[0][2] == -1 {
- match args.image_size.0 {
- Some(height) => height,
- None => panic!("Failed to get model height. Make it explicit with `--height`"),
- }
- } else {
- inputs.shapes[0][2] as u32
- };
- let width = if inputs.shapes[0][3] == -1 {
- match args.image_size.1 {
- Some(width) => width,
- None => panic!("Failed to get model width. Make it explicit with `--width`"),
- }
- } else {
- inputs.shapes[0][3] as u32
- };
- inputs.sizes.push(vec![height, width]);
- // build provider
- let (ep, provider) = match args.ep {
- OrtEP::CUDA(device_id) => Self::set_ep_cuda(device_id),
- OrtEP::Trt(device_id) => Self::set_ep_trt(device_id, args.trt_fp16, &batch, &inputs),
- _ => (
- OrtEP::CPU,
- ExecutionProviderDispatch::from(CPUExecutionProvider::default()),
- ),
- };
- // build session again with the new provider
- let session = SessionBuilder::new()?
- // .with_optimization_level(ort::GraphOptimizationLevel::Level3)?
- .with_execution_providers([provider])?
- .commit_from_file(args.f)?;
- // task: use the given one or guess it from model metadata
- let task = match args.task {
- Some(task) => task,
- None => match session.metadata() {
- Err(_) => panic!("No metadata found. Try making it explicit by `--task`"),
- Ok(metadata) => match metadata.custom("task") {
- Err(_) => panic!("Can not get custom value. Try making it explicit by `--task`"),
- Ok(value) => match value {
- None => panic!("No corresponding value of `task` found in metadata. Make it explicit by `--task`"),
- Some(task) => match task.as_str() {
- "classify" => YOLOTask::Classify,
- "detect" => YOLOTask::Detect,
- "pose" => YOLOTask::Pose,
- "segment" => YOLOTask::Segment,
- x => todo!("{:?} is not supported for now!", x),
- },
- },
- },
- },
- };
- Ok(Self {
- session,
- task,
- ep,
- batch,
- inputs,
- })
- }
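- // Usage sketch (hypothetical input), given an `OrtConfig` like the one sketched above:
- // let model = OrtBackend::build(cfg)?;
- // let xs = ndarray::Array::zeros((1, 3, 640, 640)).into_dyn(); // NCHW f32 input
- // let ys = model.run(xs, false)?;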
- pub fn fetch_inputs_from_session(
- session: &Session,
- ) -> (Vec<Vec<i64>>, Vec<TensorElementType>, Vec<String>) {
- // get input attributes from the ONNX model
- let mut shapes = Vec::new();
- let mut dtypes = Vec::new();
- let mut names = Vec::new();
- for i in session.inputs.iter() {
- if let ort::ValueType::Tensor { ty, dimensions } = &i.input_type {
- dtypes.push(ty.clone());
- let shape = dimensions.clone();
- shapes.push(shape);
- } else {
- panic!("不支持的数据格式, {} - {}", file!(), line!());
- }
- names.push(i.name.clone());
- }
- (shapes, dtypes, names)
- }
- pub fn set_ep_cuda(device_id: i32) -> (OrtEP, ExecutionProviderDispatch) {
- let cuda_provider = CUDAExecutionProvider::default().with_device_id(device_id);
- if let Ok(true) = cuda_provider.is_available() {
- (
- OrtEP::CUDA(device_id),
- ExecutionProviderDispatch::from(cuda_provider),
- )
- } else {
- println!("> CUDA is not available! Using CPU.");
- (
- OrtEP::CPU,
- ExecutionProviderDispatch::from(CPUExecutionProvider::default()),
- )
- }
- }
- pub fn set_ep_trt(
- device_id: i32,
- fp16: bool,
- batch: &Batch,
- inputs: &OrtInputs,
- ) -> (OrtEP, ExecutionProviderDispatch) {
- // set TensorRT
- let trt_provider = TensorRTExecutionProvider::default().with_device_id(device_id);
- if let Ok(true) = trt_provider.is_available() {
- let (height, width) = (inputs.sizes[0][0], inputs.sizes[0][1]);
- if inputs.dtypes[0] == TensorElementType::Float16 && !fp16 {
- panic!(
- "Dtype mismatch! Expected: Float32, got: {:?}. You should use `--fp16`",
- inputs.dtypes[0]
- );
- }
- // dynamic shape: input_tensor_1:dim_1xdim_2x...,input_tensor_2:dim_3xdim_4x...,...
- let mut opt_string = String::new();
- let mut min_string = String::new();
- let mut max_string = String::new();
- for name in inputs.names.iter() {
- let s_opt = format!("{}:{}x3x{}x{},", name, batch.opt, height, width);
- let s_min = format!("{}:{}x3x{}x{},", name, batch.min, height, width);
- let s_max = format!("{}:{}x3x{}x{},", name, batch.max, height, width);
- opt_string.push_str(s_opt.as_str());
- min_string.push_str(s_min.as_str());
- max_string.push_str(s_max.as_str());
- }
- let _ = opt_string.pop();
- let _ = min_string.pop();
- let _ = max_string.pop();
- let trt_provider = trt_provider
- .with_profile_opt_shapes(opt_string)
- .with_profile_min_shapes(min_string)
- .with_profile_max_shapes(max_string)
- .with_fp16(fp16)
- .with_timing_cache(true);
- (
- OrtEP::Trt(device_id),
- ExecutionProviderDispatch::from(trt_provider),
- )
- } else {
- println!("> TensorRT is not available! Try using CUDA...");
- Self::set_ep_cuda(device_id)
- }
- }
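- // With a single input named "images" (hypothetical), Batch::default(), and a 640x640 model,
- // the profile strings built above become e.g. "images:1x3x640x640" (trailing comma popped).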
- pub fn fetch_from_metadata(&self, key: &str) -> Option<String> {
- // fetch value from onnx model file by key
- match self.session.metadata() {
- Err(_) => None,
- Ok(metadata) => match metadata.custom(key) {
- Err(_) => None,
- Ok(value) => value,
- },
- }
- }
- pub fn run(&self, xs: Array<f32, IxDyn>, profile: bool) -> Result<Vec<Array<f32, IxDyn>>> {
- // ORT inference
- match self.dtype() {
- TensorElementType::Float16 => self.run_fp16(xs, profile),
- TensorElementType::Float32 => self.run_fp32(xs, profile),
- _ => todo!(),
- }
- }
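- // run_fp16 converts the f32 input to f16 on the host, runs the session, then converts the
- // f16 output back to f32; run_fp32 feeds the input through unchanged.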
- pub fn run_fp16(&self, xs: Array<f32, IxDyn>, profile: bool) -> Result<Vec<Array<f32, IxDyn>>> {
- // f32->f16
- let t = std::time::Instant::now();
- let xs = xs.mapv(f16::from_f32);
- if profile {
- println!("[ORT f32->f16]: {:?}", t.elapsed());
- }
- // h2d
- let t = std::time::Instant::now();
- let xs = CowArray::from(xs);
- if profile {
- println!("[ORT H2D]: {:?}", t.elapsed());
- }
- // run
- let t = std::time::Instant::now();
- let ys = self.session.run(ort::inputs![xs.view()]?)?;
- if profile {
- println!("[ORT Inference]: {:?}", t.elapsed());
- }
- // d2h
- Ok(ys
- .iter()
- .map(|(_k, v)| {
- // d2h
- let t = std::time::Instant::now();
- let v = v.try_extract_tensor::<f16>().unwrap();
- //let v = v.try_extract::<_>().unwrap().view().clone().into_owned();
- if profile {
- println!("[ORT D2H]: {:?}", t.elapsed());
- }
- // f16->f32
- let t_ = std::time::Instant::now();
- let v = v.mapv(f16::to_f32);
- if profile {
- println!("[ORT f16->f32]: {:?}", t_.elapsed());
- }
- v
- })
- .collect::<Vec<Array<_, _>>>())
- }
- pub fn run_fp32(&self, xs: Array<f32, IxDyn>, profile: bool) -> Result<Vec<Array<f32, IxDyn>>> {
- // h2d
- let t = std::time::Instant::now();
- let xs = CowArray::from(xs);
- if profile {
- println!("[ORT H2D]: {:?}", t.elapsed());
- }
- // run
- let t = std::time::Instant::now();
- let ys = self.session.run(ort::inputs![xs.view()]?)?;
- if profile {
- println!("[ORT Inference]: {:?}", t.elapsed());
- }
- // d2h
- Ok(ys
- .iter()
- .map(|(_k, v)| {
- let t = std::time::Instant::now();
- let v = v.try_extract_tensor::<f32>().unwrap().into_owned();
- //let x = x.try_extract::<_>().unwrap().view().clone().into_owned();
- if profile {
- println!("[ORT D2H]: {:?}", t.elapsed());
- }
- v
- })
- .collect::<Vec<Array<_, _>>>())
- }
- pub fn output_shapes(&self) -> Vec<Vec<i64>> {
- let mut shapes = Vec::new();
- for output in &self.session.outputs {
- if let ValueType::Tensor { ty: _, dimensions } = &output.output_type {
- let shape = dimensions.clone();
- shapes.push(shape);
- } else {
- panic!("not support data format, {} - {}", file!(), line!());
- }
- }
- shapes
- }
- pub fn output_dtypes(&self) -> Vec<TensorElementType> {
- let mut dtypes = Vec::new();
- for output in &self.session.outputs {
- if let ValueType::Tensor { ty, dimensions: _ } = &output.output_type {
- dtypes.push(ty.clone());
- } else {
- panic!("not support data format, {} - {}", file!(), line!());
- }
- }
- dtypes
- }
- pub fn input_shapes(&self) -> &Vec<Vec<i64>> {
- &self.inputs.shapes
- }
- pub fn input_names(&self) -> &Vec<String> {
- &self.inputs.names
- }
- pub fn input_dtypes(&self) -> &Vec<TensorElementType> {
- &self.inputs.dtypes
- }
- pub fn dtype(&self) -> TensorElementType {
- self.input_dtypes()[0]
- }
- pub fn height(&self) -> u32 {
- self.inputs.sizes[0][0]
- }
- pub fn width(&self) -> u32 {
- self.inputs.sizes[0][1]
- }
- pub fn is_height_dynamic(&self) -> bool {
- self.input_shapes()[0][2] == -1
- }
- pub fn is_width_dynamic(&self) -> bool {
- self.input_shapes()[0][3] == -1
- }
- pub fn batch(&self) -> u32 {
- self.batch.opt
- }
- pub fn is_batch_dynamic(&self) -> bool {
- self.input_shapes()[0][0] == -1
- }
- pub fn ep(&self) -> &OrtEP {
- &self.ep
- }
- pub fn task(&self) -> YOLOTask {
- self.task.clone()
- }
- pub fn names(&self) -> Option<Vec<String>> {
- // class names, metadata parsing
- // String format: `{0: 'person', 1: 'bicycle', 2: 'sports ball', ..., 27: "yellow_lady's_slipper"}`
- match self.fetch_from_metadata("names") {
- Some(names) => {
- let re = Regex::new(r#"(['"])([-()\w '"]+)(['"])"#).unwrap();
- let mut names_ = vec![];
- for (_, [_, name, _]) in re.captures_iter(&names).map(|x| x.extract()) {
- names_.push(name.to_string());
- }
- Some(names_)
- }
- None => None,
- }
- }
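- // e.g. metadata `{0: 'person', 1: 'bicycle'}` yields Some(vec!["person", "bicycle"]) as owned Strings.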
- pub fn nk(&self) -> Option<u32> {
- // num_keypoints, parsed from the `kpt_shape` metadata string, e.g. `[17, 3]`
- match self.fetch_from_metadata("kpt_shape") {
- None => None,
- Some(kpt_string) => {
- let re = Regex::new(r"([0-9]+), ([0-9]+)").unwrap();
- let caps = re.captures(&kpt_string).unwrap();
- Some(caps.get(1).unwrap().as_str().parse::<u32>().unwrap())
- }
- }
- }
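- // e.g. a pose model with `kpt_shape` metadata "[17, 3]" yields Some(17) keypoints.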
- pub fn nc(&self) -> Option<u32> {
- // num_classes
- match self.names() {
- // by names
- Some(names) => Some(names.len() as u32),
- None => match self.task() {
- // by task calculation
- YOLOTask::Classify => Some(self.output_shapes()[0][1] as u32),
- YOLOTask::Detect => {
- if self.output_shapes()[0][1] == -1 {
- None
- } else {
- // [cx, cy, w, h, class scores...]
- Some(self.output_shapes()[0][1] as u32 - 4)
- }
- }
- YOLOTask::Pose => {
- match self.nk() {
- None => None,
- Some(nk) => {
- if self.output_shapes()[0][1] == -1 {
- None
- } else {
- // [cx, cy, w, h, class scores..., 3 * nk keypoint values]
- Some(self.output_shapes()[0][1] as u32 - 4 - 3 * nk)
- }
- }
- }
- }
- YOLOTask::Segment => {
- if self.output_shapes()[0][1] == -1 {
- None
- } else {
- // [cx, cy, w, h, class scores..., nm mask coefficients]
- Some((self.output_shapes()[0][1] - self.output_shapes()[1][1]) as u32 - 4)
- }
- }
- },
- }
- }
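- // e.g. a detection output of shape (batch, 84, anchors) implies nc = 84 - 4 = 80 classes;
- // for segmentation, the nm mask coefficients are subtracted as well.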
- pub fn nm(&self) -> Option<u32> {
- // num_masks
- match self.task() {
- YOLOTask::Segment => Some(self.output_shapes()[1][1] as u32),
- _ => None,
- }
- }
- pub fn na(&self) -> Option<u32> {
- // num_anchors
- match self.task() {
- YOLOTask::Segment | YOLOTask::Detect | YOLOTask::Pose => {
- if self.output_shapes()[0][2] == -1 {
- None
- } else {
- Some(self.output_shapes()[0][2] as u32)
- }
- }
- _ => None,
- }
- }
- pub fn author(&self) -> Option<String> {
- self.fetch_from_metadata("author")
- }
- pub fn version(&self) -> Option<String> {
- self.fetch_from_metadata("version")
- }
- }