ort_backend.rs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553
  1. use anyhow::Result;
  2. use clap::ValueEnum;
  3. use half::f16;
  4. use ndarray::{Array, CowArray, IxDyn};
  5. use ort::{
  6. CPUExecutionProvider, CUDAExecutionProvider, ExecutionProvider, ExecutionProviderDispatch,
  7. TensorRTExecutionProvider,
  8. };
  9. use ort::{Session, SessionBuilder};
  10. use ort::{TensorElementType, ValueType};
  11. use regex::Regex;
// YOLO task a model was exported for: chosen explicitly (this type derives
// clap's `ValueEnum`, so it can be parsed from the command line) or read from
// the `task` entry of the model's custom metadata by `OrtBackend::build`.
//
// NOTE: plain `//` comments are used deliberately. clap's derive can pick up
// `///` doc comments on variants as CLI help text, which would change
// user-facing output. The variant order also defines the derived
// `PartialOrd`/`Ord`, so do not reorder the variants.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
pub enum YOLOTask {
    // YOLO tasks
    // Image classification (metadata value "classify").
    Classify,
    // Object detection (metadata value "detect").
    Detect,
    // Keypoint / pose estimation (metadata value "pose").
    Pose,
    // Instance segmentation (metadata value "segment").
    Segment,
}
/// ONNX Runtime execution provider.
///
/// Used both for the provider requested in `OrtConfig` and for the one that
/// `OrtBackend::build` actually selected. Selection can fall back
/// TensorRT -> CUDA -> CPU when a provider reports it is unavailable.
///
/// The variant order defines the derived `PartialOrd`/`Ord`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum OrtEP {
    /// Default CPU execution provider.
    CPU,
    /// CUDA execution provider on the given device id.
    CUDA(i32),
    /// TensorRT execution provider on the given device id.
    Trt(i32),
}
  27. #[derive(Debug)]
  28. pub struct Batch {
  29. pub opt: u32,
  30. pub min: u32,
  31. pub max: u32,
  32. }
  33. impl Default for Batch {
  34. fn default() -> Self {
  35. Self {
  36. opt: 1,
  37. min: 1,
  38. max: 1,
  39. }
  40. }
  41. }
  42. #[derive(Debug, Default)]
  43. pub struct OrtInputs {
  44. // ONNX model inputs attrs
  45. pub shapes: Vec<Vec<i64>>,
  46. //pub dtypes: Vec<TensorElementDataType>,
  47. pub dtypes: Vec<TensorElementType>,
  48. pub names: Vec<String>,
  49. pub sizes: Vec<Vec<u32>>,
  50. }
  51. impl OrtInputs {
  52. pub fn new(session: &Session) -> Self {
  53. let mut shapes = Vec::new();
  54. let mut dtypes = Vec::new();
  55. let mut names = Vec::new();
  56. for i in session.inputs.iter() {
  57. /* let shape: Vec<i32> = i
  58. .dimensions()
  59. .map(|x| if let Some(x) = x { x as i32 } else { -1i32 })
  60. .collect();
  61. shapes.push(shape); */
  62. if let ort::ValueType::Tensor { ty, dimensions } = &i.input_type {
  63. dtypes.push(ty.clone());
  64. let shape = dimensions.clone();
  65. shapes.push(shape);
  66. } else {
  67. panic!("不支持的数据格式, {} - {}", file!(), line!());
  68. }
  69. //dtypes.push(i.input_type);
  70. names.push(i.name.clone());
  71. }
  72. Self {
  73. shapes,
  74. dtypes,
  75. names,
  76. ..Default::default()
  77. }
  78. }
  79. }
/// Options consumed by `OrtBackend::build`.
#[derive(Debug)]
pub struct OrtConfig {
    /// Path of the ONNX model file to load.
    pub f: String,
    /// Task to run; when `None`, it is read from the `task` entry of the model metadata.
    pub task: Option<YOLOTask>,
    /// Requested execution provider; the backend may fall back to another one if unavailable.
    pub ep: OrtEP,
    /// Enables FP16 in the TensorRT provider; required on TensorRT when the model input is `Float16`.
    pub trt_fp16: bool,
    /// Batch sizes: `opt` must match a static model batch; `min`/`opt`/`max` feed the TensorRT profile.
    pub batch: Batch,
    /// `(height, width)`; only used when the model's input height/width is dynamic.
    pub image_size: (Option<u32>, Option<u32>),
}
/// ONNX Runtime inference backend for YOLO models, created with `OrtBackend::build`.
#[derive(Debug)]
pub struct OrtBackend {
    /// Session with the selected execution provider attached.
    session: Session,
    /// Task, given explicitly in the config or read from the model metadata.
    task: YOLOTask,
    /// Execution provider actually in use (after any fallback).
    ep: OrtEP,
    /// Batch sizes from the config (checked against a static model batch).
    batch: Batch,
    /// Input signature, plus the resolved `[height, width]` in `inputs.sizes`.
    inputs: OrtInputs,
}
  99. impl OrtBackend {
  100. pub fn build(args: OrtConfig) -> Result<Self> {
  101. // build env & session
  102. // in version 2.x environment is removed
  103. /* let env = ort::EnvironmentBuilder
  104. ::with_name("YOLOv8")
  105. .build()?
  106. .into_arc(); */
  107. let sessionbuilder = SessionBuilder::new()?;
  108. let session = sessionbuilder.commit_from_file(&args.f)?;
  109. //let session = SessionBuilder::new(&env)?.with_model_from_file(&args.f)?;
  110. // get inputs
  111. let mut inputs = OrtInputs::new(&session);
  112. // batch size
  113. let mut batch = args.batch;
  114. let batch = if inputs.shapes[0][0] == -1 {
  115. batch
  116. } else {
  117. assert_eq!(
  118. inputs.shapes[0][0] as u32, batch.opt,
  119. "Expected batch size: {}, got {}. Try using `--batch {}`.",
  120. inputs.shapes[0][0] as u32, batch.opt, inputs.shapes[0][0] as u32
  121. );
  122. batch.opt = inputs.shapes[0][0] as u32;
  123. batch
  124. };
  125. // input size: height and width
  126. let height = if inputs.shapes[0][2] == -1 {
  127. match args.image_size.0 {
  128. Some(height) => height,
  129. None => panic!("Failed to get model height. Make it explicit with `--height`"),
  130. }
  131. } else {
  132. inputs.shapes[0][2] as u32
  133. };
  134. let width = if inputs.shapes[0][3] == -1 {
  135. match args.image_size.1 {
  136. Some(width) => width,
  137. None => panic!("Failed to get model width. Make it explicit with `--width`"),
  138. }
  139. } else {
  140. inputs.shapes[0][3] as u32
  141. };
  142. inputs.sizes.push(vec![height, width]);
  143. // build provider
  144. let (ep, provider) = match args.ep {
  145. OrtEP::CUDA(device_id) => Self::set_ep_cuda(device_id),
  146. OrtEP::Trt(device_id) => Self::set_ep_trt(device_id, args.trt_fp16, &batch, &inputs),
  147. _ => (
  148. OrtEP::CPU,
  149. ExecutionProviderDispatch::from(CPUExecutionProvider::default()),
  150. ),
  151. };
  152. // build session again with the new provider
  153. let session = SessionBuilder::new()?
  154. // .with_optimization_level(ort::GraphOptimizationLevel::Level3)?
  155. .with_execution_providers([provider])?
  156. .commit_from_file(args.f)?;
  157. // task: using given one or guessing
  158. let task = match args.task {
  159. Some(task) => task,
  160. None => match session.metadata() {
  161. Err(_) => panic!("No metadata found. Try making it explicit by `--task`"),
  162. Ok(metadata) => match metadata.custom("task") {
  163. Err(_) => panic!("Can not get custom value. Try making it explicit by `--task`"),
  164. Ok(value) => match value {
  165. None => panic!("No corresponding value of `task` found in metadata. Make it explicit by `--task`"),
  166. Some(task) => match task.as_str() {
  167. "classify" => YOLOTask::Classify,
  168. "detect" => YOLOTask::Detect,
  169. "pose" => YOLOTask::Pose,
  170. "segment" => YOLOTask::Segment,
  171. x => todo!("{:?} is not supported for now!", x),
  172. },
  173. },
  174. },
  175. },
  176. };
  177. Ok(Self {
  178. session,
  179. task,
  180. ep,
  181. batch,
  182. inputs,
  183. })
  184. }
  185. pub fn fetch_inputs_from_session(
  186. session: &Session,
  187. ) -> (Vec<Vec<i64>>, Vec<TensorElementType>, Vec<String>) {
  188. // get inputs attrs from ONNX model
  189. let mut shapes = Vec::new();
  190. let mut dtypes = Vec::new();
  191. let mut names = Vec::new();
  192. for i in session.inputs.iter() {
  193. if let ort::ValueType::Tensor { ty, dimensions } = &i.input_type {
  194. dtypes.push(ty.clone());
  195. let shape = dimensions.clone();
  196. shapes.push(shape);
  197. } else {
  198. panic!("不支持的数据格式, {} - {}", file!(), line!());
  199. }
  200. names.push(i.name.clone());
  201. }
  202. (shapes, dtypes, names)
  203. }
  204. pub fn set_ep_cuda(device_id: i32) -> (OrtEP, ExecutionProviderDispatch) {
  205. let cuda_provider = CUDAExecutionProvider::default().with_device_id(device_id);
  206. if let Ok(true) = cuda_provider.is_available() {
  207. (
  208. OrtEP::CUDA(device_id),
  209. ExecutionProviderDispatch::from(cuda_provider), //PlantForm::CUDA(cuda_provider)
  210. )
  211. } else {
  212. println!("> CUDA is not available! Using CPU.");
  213. (
  214. OrtEP::CPU,
  215. ExecutionProviderDispatch::from(CPUExecutionProvider::default()), //PlantForm::CPU(CPUExecutionProvider::default())
  216. )
  217. }
  218. }
  219. pub fn set_ep_trt(
  220. device_id: i32,
  221. fp16: bool,
  222. batch: &Batch,
  223. inputs: &OrtInputs,
  224. ) -> (OrtEP, ExecutionProviderDispatch) {
  225. // set TensorRT
  226. let trt_provider = TensorRTExecutionProvider::default().with_device_id(device_id);
  227. //trt_provider.
  228. if let Ok(true) = trt_provider.is_available() {
  229. let (height, width) = (inputs.sizes[0][0], inputs.sizes[0][1]);
  230. if inputs.dtypes[0] == TensorElementType::Float16 && !fp16 {
  231. panic!(
  232. "Dtype mismatch! Expected: Float32, got: {:?}. You should use `--fp16`",
  233. inputs.dtypes[0]
  234. );
  235. }
  236. // dynamic shape: input_tensor_1:dim_1xdim_2x...,input_tensor_2:dim_3xdim_4x...,...
  237. let mut opt_string = String::new();
  238. let mut min_string = String::new();
  239. let mut max_string = String::new();
  240. for name in inputs.names.iter() {
  241. let s_opt = format!("{}:{}x3x{}x{},", name, batch.opt, height, width);
  242. let s_min = format!("{}:{}x3x{}x{},", name, batch.min, height, width);
  243. let s_max = format!("{}:{}x3x{}x{},", name, batch.max, height, width);
  244. opt_string.push_str(s_opt.as_str());
  245. min_string.push_str(s_min.as_str());
  246. max_string.push_str(s_max.as_str());
  247. }
  248. let _ = opt_string.pop();
  249. let _ = min_string.pop();
  250. let _ = max_string.pop();
  251. let trt_provider = trt_provider
  252. .with_profile_opt_shapes(opt_string)
  253. .with_profile_min_shapes(min_string)
  254. .with_profile_max_shapes(max_string)
  255. .with_fp16(fp16)
  256. .with_timing_cache(true);
  257. (
  258. OrtEP::Trt(device_id),
  259. ExecutionProviderDispatch::from(trt_provider),
  260. )
  261. } else {
  262. println!("> TensorRT is not available! Try using CUDA...");
  263. Self::set_ep_cuda(device_id)
  264. }
  265. }
  266. pub fn fetch_from_metadata(&self, key: &str) -> Option<String> {
  267. // fetch value from onnx model file by key
  268. match self.session.metadata() {
  269. Err(_) => None,
  270. Ok(metadata) => match metadata.custom(key) {
  271. Err(_) => None,
  272. Ok(value) => value,
  273. },
  274. }
  275. }
  276. pub fn run(&self, xs: Array<f32, IxDyn>, profile: bool) -> Result<Vec<Array<f32, IxDyn>>> {
  277. // ORT inference
  278. match self.dtype() {
  279. TensorElementType::Float16 => self.run_fp16(xs, profile),
  280. TensorElementType::Float32 => self.run_fp32(xs, profile),
  281. _ => todo!(),
  282. }
  283. }
  284. pub fn run_fp16(&self, xs: Array<f32, IxDyn>, profile: bool) -> Result<Vec<Array<f32, IxDyn>>> {
  285. // f32->f16
  286. let t = std::time::Instant::now();
  287. let xs = xs.mapv(f16::from_f32);
  288. if profile {
  289. println!("[ORT f32->f16]: {:?}", t.elapsed());
  290. }
  291. // h2d
  292. let t = std::time::Instant::now();
  293. let xs = CowArray::from(xs);
  294. if profile {
  295. println!("[ORT H2D]: {:?}", t.elapsed());
  296. }
  297. // run
  298. let t = std::time::Instant::now();
  299. let ys = self.session.run(ort::inputs![xs.view()]?)?;
  300. if profile {
  301. println!("[ORT Inference]: {:?}", t.elapsed());
  302. }
  303. // d2h
  304. Ok(ys
  305. .iter()
  306. .map(|(_k, v)| {
  307. // d2h
  308. let t = std::time::Instant::now();
  309. let v = v.try_extract_tensor().unwrap();
  310. //let v = v.try_extract::<_>().unwrap().view().clone().into_owned();
  311. if profile {
  312. println!("[ORT D2H]: {:?}", t.elapsed());
  313. }
  314. // f16->f32
  315. let t_ = std::time::Instant::now();
  316. let v = v.mapv(f16::to_f32);
  317. if profile {
  318. println!("[ORT f16->f32]: {:?}", t_.elapsed());
  319. }
  320. v
  321. })
  322. .collect::<Vec<Array<_, _>>>())
  323. }
  324. pub fn run_fp32(&self, xs: Array<f32, IxDyn>, profile: bool) -> Result<Vec<Array<f32, IxDyn>>> {
  325. // h2d
  326. let t = std::time::Instant::now();
  327. let xs = CowArray::from(xs);
  328. if profile {
  329. println!("[ORT H2D]: {:?}", t.elapsed());
  330. }
  331. // run
  332. let t = std::time::Instant::now();
  333. let ys = self.session.run(ort::inputs![xs.view()]?)?;
  334. if profile {
  335. println!("[ORT Inference]: {:?}", t.elapsed());
  336. }
  337. // d2h
  338. Ok(ys
  339. .iter()
  340. .map(|(_k, v)| {
  341. let t = std::time::Instant::now();
  342. let v = v.try_extract_tensor::<f32>().unwrap().into_owned();
  343. //let x = x.try_extract::<_>().unwrap().view().clone().into_owned();
  344. if profile {
  345. println!("[ORT D2H]: {:?}", t.elapsed());
  346. }
  347. v
  348. })
  349. .collect::<Vec<Array<_, _>>>())
  350. }
  351. pub fn output_shapes(&self) -> Vec<Vec<i64>> {
  352. let mut shapes = Vec::new();
  353. for output in &self.session.outputs {
  354. if let ValueType::Tensor { ty: _, dimensions } = &output.output_type {
  355. let shape = dimensions.clone();
  356. shapes.push(shape);
  357. } else {
  358. panic!("not support data format, {} - {}", file!(), line!());
  359. }
  360. }
  361. shapes
  362. }
  363. pub fn output_dtypes(&self) -> Vec<TensorElementType> {
  364. let mut dtypes = Vec::new();
  365. for output in &self.session.outputs {
  366. if let ValueType::Tensor { ty, dimensions: _ } = &output.output_type {
  367. dtypes.push(ty.clone());
  368. } else {
  369. panic!("not support data format, {} - {}", file!(), line!());
  370. }
  371. }
  372. dtypes
  373. }
  374. pub fn input_shapes(&self) -> &Vec<Vec<i64>> {
  375. &self.inputs.shapes
  376. }
  377. pub fn input_names(&self) -> &Vec<String> {
  378. &self.inputs.names
  379. }
  380. pub fn input_dtypes(&self) -> &Vec<TensorElementType> {
  381. &self.inputs.dtypes
  382. }
  383. pub fn dtype(&self) -> TensorElementType {
  384. self.input_dtypes()[0]
  385. }
  386. pub fn height(&self) -> u32 {
  387. self.inputs.sizes[0][0]
  388. }
  389. pub fn width(&self) -> u32 {
  390. self.inputs.sizes[0][1]
  391. }
  392. pub fn is_height_dynamic(&self) -> bool {
  393. self.input_shapes()[0][2] == -1
  394. }
  395. pub fn is_width_dynamic(&self) -> bool {
  396. self.input_shapes()[0][3] == -1
  397. }
  398. pub fn batch(&self) -> u32 {
  399. self.batch.opt
  400. }
  401. pub fn is_batch_dynamic(&self) -> bool {
  402. self.input_shapes()[0][0] == -1
  403. }
  404. pub fn ep(&self) -> &OrtEP {
  405. &self.ep
  406. }
  407. pub fn task(&self) -> YOLOTask {
  408. self.task.clone()
  409. }
  410. pub fn names(&self) -> Option<Vec<String>> {
  411. // class names, metadata parsing
  412. // String format: `{0: 'person', 1: 'bicycle', 2: 'sports ball', ..., 27: "yellow_lady's_slipper"}`
  413. match self.fetch_from_metadata("names") {
  414. Some(names) => {
  415. let re = Regex::new(r#"(['"])([-()\w '"]+)(['"])"#).unwrap();
  416. let mut names_ = vec![];
  417. for (_, [_, name, _]) in re.captures_iter(&names).map(|x| x.extract()) {
  418. names_.push(name.to_string());
  419. }
  420. Some(names_)
  421. }
  422. None => None,
  423. }
  424. }
  425. pub fn nk(&self) -> Option<u32> {
  426. // num_keypoints, metadata parsing: String `nk` in onnx model: `[17, 3]`
  427. match self.fetch_from_metadata("kpt_shape") {
  428. None => None,
  429. Some(kpt_string) => {
  430. let re = Regex::new(r"([0-9]+), ([0-9]+)").unwrap();
  431. let caps = re.captures(&kpt_string).unwrap();
  432. Some(caps.get(1).unwrap().as_str().parse::<u32>().unwrap())
  433. }
  434. }
  435. }
  436. pub fn nc(&self) -> Option<u32> {
  437. // num_classes
  438. match self.names() {
  439. // by names
  440. Some(names) => Some(names.len() as u32),
  441. None => match self.task() {
  442. // by task calculation
  443. YOLOTask::Classify => Some(self.output_shapes()[0][1] as u32),
  444. YOLOTask::Detect => {
  445. if self.output_shapes()[0][1] == -1 {
  446. None
  447. } else {
  448. // cxywhclss
  449. Some(self.output_shapes()[0][1] as u32 - 4)
  450. }
  451. }
  452. YOLOTask::Pose => {
  453. match self.nk() {
  454. None => None,
  455. Some(nk) => {
  456. if self.output_shapes()[0][1] == -1 {
  457. None
  458. } else {
  459. // cxywhclss3*kpt
  460. Some(self.output_shapes()[0][1] as u32 - 4 - 3 * nk)
  461. }
  462. }
  463. }
  464. }
  465. YOLOTask::Segment => {
  466. if self.output_shapes()[0][1] == -1 {
  467. None
  468. } else {
  469. // cxywhclssnm
  470. Some((self.output_shapes()[0][1] - self.output_shapes()[1][1]) as u32 - 4)
  471. }
  472. }
  473. },
  474. }
  475. }
  476. pub fn nm(&self) -> Option<u32> {
  477. // num_masks
  478. match self.task() {
  479. YOLOTask::Segment => Some(self.output_shapes()[1][1] as u32),
  480. _ => None,
  481. }
  482. }
  483. pub fn na(&self) -> Option<u32> {
  484. // num_anchors
  485. match self.task() {
  486. YOLOTask::Segment | YOLOTask::Detect | YOLOTask::Pose => {
  487. if self.output_shapes()[0][2] == -1 {
  488. None
  489. } else {
  490. Some(self.output_shapes()[0][2] as u32)
  491. }
  492. }
  493. _ => None,
  494. }
  495. }
  496. pub fn author(&self) -> Option<String> {
  497. self.fetch_from_metadata("author")
  498. }
  499. pub fn version(&self) -> Option<String> {
  500. self.fetch_from_metadata("version")
  501. }
  502. }