// model.rs

  1. #![allow(clippy::type_complexity)]
  2. use ab_glyph::FontArc;
  3. use anyhow::Result;
  4. use image::{DynamicImage, GenericImageView, ImageBuffer};
  5. use ndarray::{s, Array, Axis, IxDyn};
  6. use rand::{thread_rng, Rng};
  7. use std::path::PathBuf;
  8. use crate::{
  9. gen_time_string, load_font, non_max_suppression, Args, Batch, Bbox, Embedding, OrtBackend,
  10. OrtConfig, OrtEP, Point2, YOLOResult, YOLOTask, SKELETON,
  11. };
  12. pub struct YOLOv8 {
  13. // YOLOv8 model for all yolo-tasks
  14. engine: OrtBackend,
  15. nc: u32,
  16. nk: u32,
  17. nm: u32,
  18. height: u32,
  19. width: u32,
  20. batch: u32,
  21. task: YOLOTask,
  22. conf: f32,
  23. kconf: f32,
  24. iou: f32,
  25. names: Vec<String>,
  26. color_palette: Vec<(u8, u8, u8)>,
  27. profile: bool,
  28. plot: bool,
  29. }
  30. impl YOLOv8 {
  31. pub fn new(config: Args) -> Result<Self> {
  32. // execution provider
  33. let ep = if config.trt {
  34. OrtEP::Trt(config.device_id)
  35. } else if config.cuda {
  36. OrtEP::CUDA(config.device_id)
  37. } else {
  38. OrtEP::CPU
  39. };
  40. // batch
  41. let batch = Batch {
  42. opt: config.batch,
  43. min: config.batch_min,
  44. max: config.batch_max,
  45. };
  46. // build ort engine
  47. let ort_args = OrtConfig {
  48. ep,
  49. batch,
  50. f: config.model,
  51. task: config.task,
  52. trt_fp16: config.fp16,
  53. image_size: (config.height, config.width),
  54. };
  55. let engine = OrtBackend::build(ort_args)?;
  56. // get batch, height, width, tasks, nc, nk, nm
  57. let (batch, height, width, task) = (
  58. engine.batch(),
  59. engine.height(),
  60. engine.width(),
  61. engine.task(),
  62. );
  63. let nc = engine.nc().or(config.nc).unwrap_or_else(|| {
  64. panic!("Failed to get num_classes, make it explicit with `--nc`");
  65. });
  66. let (nk, nm) = match task {
  67. YOLOTask::Pose => {
  68. let nk = engine.nk().or(config.nk).unwrap_or_else(|| {
  69. panic!("Failed to get num_keypoints, make it explicit with `--nk`");
  70. });
  71. (nk, 0)
  72. }
  73. YOLOTask::Segment => {
  74. let nm = engine.nm().or(config.nm).unwrap_or_else(|| {
  75. panic!("Failed to get num_masks, make it explicit with `--nm`");
  76. });
  77. (0, nm)
  78. }
  79. _ => (0, 0),
  80. };
  81. // class names
  82. let names = engine.names().unwrap_or(vec!["Unknown".to_string()]);
  83. // color palette
  84. let mut rng = thread_rng();
  85. let color_palette: Vec<_> = names
  86. .iter()
  87. .map(|_| {
  88. (
  89. rng.gen_range(0..=255),
  90. rng.gen_range(0..=255),
  91. rng.gen_range(0..=255),
  92. )
  93. })
  94. .collect();
  95. Ok(Self {
  96. engine,
  97. names,
  98. conf: config.conf,
  99. kconf: config.kconf,
  100. iou: config.iou,
  101. color_palette,
  102. profile: config.profile,
  103. plot: config.plot,
  104. nc,
  105. nk,
  106. nm,
  107. height,
  108. width,
  109. batch,
  110. task,
  111. })
  112. }
  113. pub fn scale_wh(&self, w0: f32, h0: f32, w1: f32, h1: f32) -> (f32, f32, f32) {
  114. let r = (w1 / w0).min(h1 / h0);
  115. (r, (w0 * r).round(), (h0 * r).round())
  116. }
  117. pub fn preprocess(&mut self, xs: &Vec<DynamicImage>) -> Result<Array<f32, IxDyn>> {
  118. let mut ys =
  119. Array::ones((xs.len(), 3, self.height() as usize, self.width() as usize)).into_dyn();
  120. ys.fill(144.0 / 255.0);
  121. for (idx, x) in xs.iter().enumerate() {
  122. let img = match self.task() {
  123. YOLOTask::Classify => x.resize_exact(
  124. self.width(),
  125. self.height(),
  126. image::imageops::FilterType::Triangle,
  127. ),
  128. _ => {
  129. let (w0, h0) = x.dimensions();
  130. let w0 = w0 as f32;
  131. let h0 = h0 as f32;
  132. let (_, w_new, h_new) =
  133. self.scale_wh(w0, h0, self.width() as f32, self.height() as f32); // f32 round
  134. x.resize_exact(
  135. w_new as u32,
  136. h_new as u32,
  137. if let YOLOTask::Segment = self.task() {
  138. image::imageops::FilterType::CatmullRom
  139. } else {
  140. image::imageops::FilterType::Triangle
  141. },
  142. )
  143. }
  144. };
  145. for (x, y, rgb) in img.pixels() {
  146. let x = x as usize;
  147. let y = y as usize;
  148. let [r, g, b, _] = rgb.0;
  149. ys[[idx, 0, y, x]] = (r as f32) / 255.0;
  150. ys[[idx, 1, y, x]] = (g as f32) / 255.0;
  151. ys[[idx, 2, y, x]] = (b as f32) / 255.0;
  152. }
  153. }
  154. Ok(ys)
  155. }
  156. pub fn run(&mut self, xs: &Vec<DynamicImage>) -> Result<Vec<YOLOResult>> {
  157. // pre-process
  158. let t_pre = std::time::Instant::now();
  159. let xs_ = self.preprocess(xs)?;
  160. if self.profile {
  161. println!("[Model Preprocess]: {:?}", t_pre.elapsed());
  162. }
  163. // run
  164. let t_run = std::time::Instant::now();
  165. let ys = self.engine.run(xs_, self.profile)?;
  166. if self.profile {
  167. println!("[Model Inference]: {:?}", t_run.elapsed());
  168. }
  169. // post-process
  170. let t_post = std::time::Instant::now();
  171. let ys = self.postprocess(ys, xs)?;
  172. if self.profile {
  173. println!("[Model Postprocess]: {:?}", t_post.elapsed());
  174. }
  175. // plot and save
  176. if self.plot {
  177. self.plot_and_save(&ys, xs, Some(&SKELETON));
  178. }
  179. Ok(ys)
  180. }
  181. pub fn postprocess(
  182. &self,
  183. xs: Vec<Array<f32, IxDyn>>,
  184. xs0: &[DynamicImage],
  185. ) -> Result<Vec<YOLOResult>> {
  186. if let YOLOTask::Classify = self.task() {
  187. let mut ys = Vec::new();
  188. let preds = &xs[0];
  189. for batch in preds.axis_iter(Axis(0)) {
  190. ys.push(YOLOResult::new(
  191. Some(Embedding::new(batch.into_owned())),
  192. None,
  193. None,
  194. None,
  195. ));
  196. }
  197. Ok(ys)
  198. } else {
  199. const CXYWH_OFFSET: usize = 4; // cxcywh
  200. const KPT_STEP: usize = 3; // xyconf
  201. let preds = &xs[0];
  202. let protos = {
  203. if xs.len() > 1 {
  204. Some(&xs[1])
  205. } else {
  206. None
  207. }
  208. };
  209. let mut ys = Vec::new();
  210. for (idx, anchor) in preds.axis_iter(Axis(0)).enumerate() {
  211. // [bs, 4 + nc + nm, anchors]
  212. // input image
  213. let width_original = xs0[idx].width() as f32;
  214. let height_original = xs0[idx].height() as f32;
  215. let ratio = (self.width() as f32 / width_original)
  216. .min(self.height() as f32 / height_original);
  217. // save each result
  218. let mut data: Vec<(Bbox, Option<Vec<Point2>>, Option<Vec<f32>>)> = Vec::new();
  219. for pred in anchor.axis_iter(Axis(1)) {
  220. // split preds for different tasks
  221. let bbox = pred.slice(s![0..CXYWH_OFFSET]);
  222. let clss = pred.slice(s![CXYWH_OFFSET..CXYWH_OFFSET + self.nc() as usize]);
  223. let kpts = {
  224. if let YOLOTask::Pose = self.task() {
  225. Some(pred.slice(s![pred.len() - KPT_STEP * self.nk() as usize..]))
  226. } else {
  227. None
  228. }
  229. };
  230. let coefs = {
  231. if let YOLOTask::Segment = self.task() {
  232. Some(pred.slice(s![pred.len() - self.nm() as usize..]).to_vec())
  233. } else {
  234. None
  235. }
  236. };
  237. // confidence and id
  238. let (id, &confidence) = clss
  239. .into_iter()
  240. .enumerate()
  241. .reduce(|max, x| if x.1 > max.1 { x } else { max })
  242. .unwrap(); // definitely will not panic!
  243. // confidence filter
  244. if confidence < self.conf {
  245. continue;
  246. }
  247. // bbox re-scale
  248. let cx = bbox[0] / ratio;
  249. let cy = bbox[1] / ratio;
  250. let w = bbox[2] / ratio;
  251. let h = bbox[3] / ratio;
  252. let x = cx - w / 2.;
  253. let y = cy - h / 2.;
  254. let y_bbox = Bbox::new(
  255. x.max(0.0f32).min(width_original),
  256. y.max(0.0f32).min(height_original),
  257. w,
  258. h,
  259. id,
  260. confidence,
  261. );
  262. // kpts
  263. let y_kpts = {
  264. if let Some(kpts) = kpts {
  265. let mut kpts_ = Vec::new();
  266. // rescale
  267. for i in 0..self.nk() as usize {
  268. let kx = kpts[KPT_STEP * i] / ratio;
  269. let ky = kpts[KPT_STEP * i + 1] / ratio;
  270. let kconf = kpts[KPT_STEP * i + 2];
  271. if kconf < self.kconf {
  272. kpts_.push(Point2::default());
  273. } else {
  274. kpts_.push(Point2::new_with_conf(
  275. kx.max(0.0f32).min(width_original),
  276. ky.max(0.0f32).min(height_original),
  277. kconf,
  278. ));
  279. }
  280. }
  281. Some(kpts_)
  282. } else {
  283. None
  284. }
  285. };
  286. // data merged
  287. data.push((y_bbox, y_kpts, coefs));
  288. }
  289. // nms
  290. non_max_suppression(&mut data, self.iou);
  291. // decode
  292. let mut y_bboxes: Vec<Bbox> = Vec::new();
  293. let mut y_kpts: Vec<Vec<Point2>> = Vec::new();
  294. let mut y_masks: Vec<Vec<u8>> = Vec::new();
  295. for elem in data.into_iter() {
  296. if let Some(kpts) = elem.1 {
  297. y_kpts.push(kpts)
  298. }
  299. // decode masks
  300. if let Some(coefs) = elem.2 {
  301. let proto = protos.unwrap().slice(s![idx, .., .., ..]);
  302. let (nm, nh, nw) = proto.dim();
  303. // coefs * proto -> mask
  304. let coefs = Array::from_shape_vec((1, nm), coefs)?; // (n, nm)
  305. let proto = proto.to_owned();
  306. let proto = proto.to_shape((nm, nh * nw))?; // (nm, nh*nw)
  307. let mask = coefs.dot(&proto); // (nh, nw, n)
  308. let mask = mask.to_shape((nh, nw, 1))?;
  309. // build image from ndarray
  310. let mask_im: ImageBuffer<image::Luma<_>, Vec<f32>> =
  311. match ImageBuffer::from_raw(
  312. nw as u32,
  313. nh as u32,
  314. mask.to_owned().into_raw_vec_and_offset().0,
  315. ) {
  316. Some(image) => image,
  317. None => panic!("can not create image from ndarray"),
  318. };
  319. let mut mask_im = image::DynamicImage::from(mask_im); // -> dyn
  320. // rescale masks
  321. let (_, w_mask, h_mask) =
  322. self.scale_wh(width_original, height_original, nw as f32, nh as f32);
  323. let mask_cropped = mask_im.crop(0, 0, w_mask as u32, h_mask as u32);
  324. let mask_original = mask_cropped.resize_exact(
  325. // resize_to_fill
  326. width_original as u32,
  327. height_original as u32,
  328. match self.task() {
  329. YOLOTask::Segment => image::imageops::FilterType::CatmullRom,
  330. _ => image::imageops::FilterType::Triangle,
  331. },
  332. );
  333. // crop-mask with bbox
  334. let mut mask_original_cropped = mask_original.into_luma8();
  335. for y in 0..height_original as usize {
  336. for x in 0..width_original as usize {
  337. if x < elem.0.xmin() as usize
  338. || x > elem.0.xmax() as usize
  339. || y < elem.0.ymin() as usize
  340. || y > elem.0.ymax() as usize
  341. {
  342. mask_original_cropped.put_pixel(
  343. x as u32,
  344. y as u32,
  345. image::Luma([0u8]),
  346. );
  347. }
  348. }
  349. }
  350. y_masks.push(mask_original_cropped.into_raw());
  351. }
  352. y_bboxes.push(elem.0);
  353. }
  354. // save each result
  355. let y = YOLOResult {
  356. probs: None,
  357. bboxes: if !y_bboxes.is_empty() {
  358. Some(y_bboxes)
  359. } else {
  360. None
  361. },
  362. keypoints: if !y_kpts.is_empty() {
  363. Some(y_kpts)
  364. } else {
  365. None
  366. },
  367. masks: if !y_masks.is_empty() {
  368. Some(y_masks)
  369. } else {
  370. None
  371. },
  372. };
  373. ys.push(y);
  374. }
  375. Ok(ys)
  376. }
  377. }
  378. pub fn plot_and_save(
  379. &self,
  380. ys: &[YOLOResult],
  381. xs0: &[DynamicImage],
  382. skeletons: Option<&[(usize, usize)]>,
  383. ) {
  384. // check font then load
  385. let font: FontArc = load_font();
  386. for (_idb, (img0, y)) in xs0.iter().zip(ys.iter()).enumerate() {
  387. let mut img = img0.to_rgb8();
  388. // draw for classifier
  389. if let Some(probs) = y.probs() {
  390. for (i, k) in probs.topk(5).iter().enumerate() {
  391. let legend = format!("{} {:.2}%", self.names[k.0], k.1);
  392. let scale = 32;
  393. let legend_size = img.width().max(img.height()) / scale;
  394. let x = img.width() / 20;
  395. let y = img.height() / 20 + i as u32 * legend_size;
  396. imageproc::drawing::draw_text_mut(
  397. &mut img,
  398. image::Rgb([0, 255, 0]),
  399. x as i32,
  400. y as i32,
  401. legend_size as f32,
  402. &font,
  403. &legend,
  404. );
  405. }
  406. }
  407. // draw bboxes & keypoints
  408. if let Some(bboxes) = y.bboxes() {
  409. for (_idx, bbox) in bboxes.iter().enumerate() {
  410. // rect
  411. imageproc::drawing::draw_hollow_rect_mut(
  412. &mut img,
  413. imageproc::rect::Rect::at(bbox.xmin() as i32, bbox.ymin() as i32)
  414. .of_size(bbox.width() as u32, bbox.height() as u32),
  415. image::Rgb(self.color_palette[bbox.id()].into()),
  416. );
  417. // text
  418. let legend = format!("{} {:.2}%", self.names[bbox.id()], bbox.confidence());
  419. let scale = 40;
  420. let legend_size = img.width().max(img.height()) / scale;
  421. imageproc::drawing::draw_text_mut(
  422. &mut img,
  423. image::Rgb(self.color_palette[bbox.id()].into()),
  424. bbox.xmin() as i32,
  425. (bbox.ymin() - legend_size as f32) as i32,
  426. legend_size as f32,
  427. &font,
  428. &legend,
  429. );
  430. }
  431. }
  432. // draw kpts
  433. if let Some(keypoints) = y.keypoints() {
  434. for kpts in keypoints.iter() {
  435. for kpt in kpts.iter() {
  436. // filter
  437. if kpt.confidence() < self.kconf {
  438. continue;
  439. }
  440. // draw point
  441. imageproc::drawing::draw_filled_circle_mut(
  442. &mut img,
  443. (kpt.x() as i32, kpt.y() as i32),
  444. 2,
  445. image::Rgb([0, 255, 0]),
  446. );
  447. }
  448. // draw skeleton if has
  449. if let Some(skeletons) = skeletons {
  450. for &(idx1, idx2) in skeletons.iter() {
  451. let kpt1 = &kpts[idx1];
  452. let kpt2 = &kpts[idx2];
  453. if kpt1.confidence() < self.kconf || kpt2.confidence() < self.kconf {
  454. continue;
  455. }
  456. imageproc::drawing::draw_line_segment_mut(
  457. &mut img,
  458. (kpt1.x(), kpt1.y()),
  459. (kpt2.x(), kpt2.y()),
  460. image::Rgb([233, 14, 57]),
  461. );
  462. }
  463. }
  464. }
  465. }
  466. // draw mask
  467. if let Some(masks) = y.masks() {
  468. for (mask, _bbox) in masks.iter().zip(y.bboxes().unwrap().iter()) {
  469. let mask_nd: ImageBuffer<image::Luma<_>, Vec<u8>> =
  470. match ImageBuffer::from_vec(img.width(), img.height(), mask.to_vec()) {
  471. Some(image) => image,
  472. None => panic!("can not crate image from ndarray"),
  473. };
  474. for _x in 0..img.width() {
  475. for _y in 0..img.height() {
  476. let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_nd, _x, _y);
  477. if mask_p.0[0] > 0 {
  478. let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, _x, _y);
  479. // img_p.0[2] = self.color_palette[bbox.id()].2 / 2;
  480. // img_p.0[1] = self.color_palette[bbox.id()].1 / 2;
  481. // img_p.0[0] = self.color_palette[bbox.id()].0 / 2;
  482. img_p.0[2] /= 2;
  483. img_p.0[1] = 255 - (255 - img_p.0[2]) / 2;
  484. img_p.0[0] /= 2;
  485. imageproc::drawing::Canvas::draw_pixel(&mut img, _x, _y, img_p)
  486. }
  487. }
  488. }
  489. }
  490. }
  491. // mkdir and save
  492. let mut runs = PathBuf::from("runs");
  493. if !runs.exists() {
  494. std::fs::create_dir_all(&runs).unwrap();
  495. }
  496. runs.push(gen_time_string("-"));
  497. let saveout = format!("{}.jpg", runs.to_str().unwrap());
  498. let _ = img.save(saveout);
  499. }
  500. }
  501. pub fn summary(&self) {
  502. println!(
  503. "\nSummary:\n\
  504. > Task: {:?}{}\n\
  505. > EP: {:?} {}\n\
  506. > Dtype: {:?}\n\
  507. > Batch: {} ({}), Height: {} ({}), Width: {} ({})\n\
  508. > nc: {} nk: {}, nm: {}, conf: {}, kconf: {}, iou: {}\n\
  509. ",
  510. self.task(),
  511. match self.engine.author().zip(self.engine.version()) {
  512. Some((author, ver)) => format!(" ({} {})", author, ver),
  513. None => String::from(""),
  514. },
  515. self.engine.ep(),
  516. if let OrtEP::CPU = self.engine.ep() {
  517. ""
  518. } else {
  519. "(May still fall back to CPU)"
  520. },
  521. self.engine.dtype(),
  522. self.batch(),
  523. if self.engine.is_batch_dynamic() {
  524. "Dynamic"
  525. } else {
  526. "Const"
  527. },
  528. self.height(),
  529. if self.engine.is_height_dynamic() {
  530. "Dynamic"
  531. } else {
  532. "Const"
  533. },
  534. self.width(),
  535. if self.engine.is_width_dynamic() {
  536. "Dynamic"
  537. } else {
  538. "Const"
  539. },
  540. self.nc(),
  541. self.nk(),
  542. self.nm(),
  543. self.conf,
  544. self.kconf,
  545. self.iou,
  546. );
  547. }
  548. pub fn engine(&self) -> &OrtBackend {
  549. &self.engine
  550. }
  551. pub fn conf(&self) -> f32 {
  552. self.conf
  553. }
  554. pub fn set_conf(&mut self, val: f32) {
  555. self.conf = val;
  556. }
  557. pub fn conf_mut(&mut self) -> &mut f32 {
  558. &mut self.conf
  559. }
  560. pub fn kconf(&self) -> f32 {
  561. self.kconf
  562. }
  563. pub fn iou(&self) -> f32 {
  564. self.iou
  565. }
  566. pub fn task(&self) -> &YOLOTask {
  567. &self.task
  568. }
  569. pub fn batch(&self) -> u32 {
  570. self.batch
  571. }
  572. pub fn width(&self) -> u32 {
  573. self.width
  574. }
  575. pub fn height(&self) -> u32 {
  576. self.height
  577. }
  578. pub fn nc(&self) -> u32 {
  579. self.nc
  580. }
  581. pub fn nk(&self) -> u32 {
  582. self.nk
  583. }
  584. pub fn nm(&self) -> u32 {
  585. self.nm
  586. }
  587. pub fn names(&self) -> &Vec<String> {
  588. &self.names
  589. }
  590. }