main.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. import argparse
  3. import cv2
  4. import numpy as np
  5. import onnxruntime as ort
  6. from ultralytics.utils import ASSETS, yaml_load
  7. from ultralytics.utils.checks import check_yaml
  8. from ultralytics.utils.plotting import Colors
  9. class YOLOv8Seg:
  10. """YOLOv8 segmentation model."""
  11. def __init__(self, onnx_model):
  12. """
  13. Initialization.
  14. Args:
  15. onnx_model (str): Path to the ONNX model.
  16. """
  17. # Build Ort session
  18. self.session = ort.InferenceSession(
  19. onnx_model,
  20. providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
  21. if ort.get_device() == "GPU"
  22. else ["CPUExecutionProvider"],
  23. )
  24. # Numpy dtype: support both FP32 and FP16 onnx model
  25. self.ndtype = np.half if self.session.get_inputs()[0].type == "tensor(float16)" else np.single
  26. # Get model width and height(YOLOv8-seg only has one input)
  27. self.model_height, self.model_width = [x.shape for x in self.session.get_inputs()][0][-2:]
  28. # Load COCO class names
  29. self.classes = yaml_load(check_yaml("coco8.yaml"))["names"]
  30. # Create color palette
  31. self.color_palette = Colors()
  32. def __call__(self, im0, conf_threshold=0.4, iou_threshold=0.45, nm=32):
  33. """
  34. The whole pipeline: pre-process -> inference -> post-process.
  35. Args:
  36. im0 (Numpy.ndarray): original input image.
  37. conf_threshold (float): confidence threshold for filtering predictions.
  38. iou_threshold (float): iou threshold for NMS.
  39. nm (int): the number of masks.
  40. Returns:
  41. boxes (List): list of bounding boxes.
  42. segments (List): list of segments.
  43. masks (np.ndarray): [N, H, W], output masks.
  44. """
  45. # Pre-process
  46. im, ratio, (pad_w, pad_h) = self.preprocess(im0)
  47. # Ort inference
  48. preds = self.session.run(None, {self.session.get_inputs()[0].name: im})
  49. # Post-process
  50. boxes, segments, masks = self.postprocess(
  51. preds,
  52. im0=im0,
  53. ratio=ratio,
  54. pad_w=pad_w,
  55. pad_h=pad_h,
  56. conf_threshold=conf_threshold,
  57. iou_threshold=iou_threshold,
  58. nm=nm,
  59. )
  60. return boxes, segments, masks
  61. def preprocess(self, img):
  62. """
  63. Pre-processes the input image.
  64. Args:
  65. img (Numpy.ndarray): image about to be processed.
  66. Returns:
  67. img_process (Numpy.ndarray): image preprocessed for inference.
  68. ratio (tuple): width, height ratios in letterbox.
  69. pad_w (float): width padding in letterbox.
  70. pad_h (float): height padding in letterbox.
  71. """
  72. # Resize and pad input image using letterbox() (Borrowed from Ultralytics)
  73. shape = img.shape[:2] # original image shape
  74. new_shape = (self.model_height, self.model_width)
  75. r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
  76. ratio = r, r
  77. new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
  78. pad_w, pad_h = (new_shape[1] - new_unpad[0]) / 2, (new_shape[0] - new_unpad[1]) / 2 # wh padding
  79. if shape[::-1] != new_unpad: # resize
  80. img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
  81. top, bottom = int(round(pad_h - 0.1)), int(round(pad_h + 0.1))
  82. left, right = int(round(pad_w - 0.1)), int(round(pad_w + 0.1))
  83. img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))
  84. # Transforms: HWC to CHW -> BGR to RGB -> div(255) -> contiguous -> add axis(optional)
  85. img = np.ascontiguousarray(np.einsum("HWC->CHW", img)[::-1], dtype=self.ndtype) / 255.0
  86. img_process = img[None] if len(img.shape) == 3 else img
  87. return img_process, ratio, (pad_w, pad_h)
  88. def postprocess(self, preds, im0, ratio, pad_w, pad_h, conf_threshold, iou_threshold, nm=32):
  89. """
  90. Post-process the prediction.
  91. Args:
  92. preds (Numpy.ndarray): predictions come from ort.session.run().
  93. im0 (Numpy.ndarray): [h, w, c] original input image.
  94. ratio (tuple): width, height ratios in letterbox.
  95. pad_w (float): width padding in letterbox.
  96. pad_h (float): height padding in letterbox.
  97. conf_threshold (float): conf threshold.
  98. iou_threshold (float): iou threshold.
  99. nm (int): the number of masks.
  100. Returns:
  101. boxes (List): list of bounding boxes.
  102. segments (List): list of segments.
  103. masks (np.ndarray): [N, H, W], output masks.
  104. """
  105. x, protos = preds[0], preds[1] # Two outputs: predictions and protos
  106. # Transpose dim 1: (Batch_size, xywh_conf_cls_nm, Num_anchors) -> (Batch_size, Num_anchors, xywh_conf_cls_nm)
  107. x = np.einsum("bcn->bnc", x)
  108. # Predictions filtering by conf-threshold
  109. x = x[np.amax(x[..., 4:-nm], axis=-1) > conf_threshold]
  110. # Create a new matrix which merge these(box, score, cls, nm) into one
  111. # For more details about `numpy.c_()`: https://numpy.org/doc/1.26/reference/generated/numpy.c_.html
  112. x = np.c_[x[..., :4], np.amax(x[..., 4:-nm], axis=-1), np.argmax(x[..., 4:-nm], axis=-1), x[..., -nm:]]
  113. # NMS filtering
  114. x = x[cv2.dnn.NMSBoxes(x[:, :4], x[:, 4], conf_threshold, iou_threshold)]
  115. # Decode and return
  116. if len(x) > 0:
  117. # Bounding boxes format change: cxcywh -> xyxy
  118. x[..., [0, 1]] -= x[..., [2, 3]] / 2
  119. x[..., [2, 3]] += x[..., [0, 1]]
  120. # Rescales bounding boxes from model shape(model_height, model_width) to the shape of original image
  121. x[..., :4] -= [pad_w, pad_h, pad_w, pad_h]
  122. x[..., :4] /= min(ratio)
  123. # Bounding boxes boundary clamp
  124. x[..., [0, 2]] = x[:, [0, 2]].clip(0, im0.shape[1])
  125. x[..., [1, 3]] = x[:, [1, 3]].clip(0, im0.shape[0])
  126. # Process masks
  127. masks = self.process_mask(protos[0], x[:, 6:], x[:, :4], im0.shape)
  128. # Masks -> Segments(contours)
  129. segments = self.masks2segments(masks)
  130. return x[..., :6], segments, masks # boxes, segments, masks
  131. else:
  132. return [], [], []
  133. @staticmethod
  134. def masks2segments(masks):
  135. """
  136. Takes a list of masks(n,h,w) and returns a list of segments(n,xy), from
  137. https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/ops.py.
  138. Args:
  139. masks (numpy.ndarray): the output of the model, which is a tensor of shape (batch_size, 160, 160).
  140. Returns:
  141. segments (List): list of segment masks.
  142. """
  143. segments = []
  144. for x in masks.astype("uint8"):
  145. c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[0] # CHAIN_APPROX_SIMPLE
  146. if c:
  147. c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
  148. else:
  149. c = np.zeros((0, 2)) # no segments found
  150. segments.append(c.astype("float32"))
  151. return segments
  152. @staticmethod
  153. def crop_mask(masks, boxes):
  154. """
  155. Takes a mask and a bounding box, and returns a mask that is cropped to the bounding box, from
  156. https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/ops.py.
  157. Args:
  158. masks (Numpy.ndarray): [n, h, w] tensor of masks.
  159. boxes (Numpy.ndarray): [n, 4] tensor of bbox coordinates in relative point form.
  160. Returns:
  161. (Numpy.ndarray): The masks are being cropped to the bounding box.
  162. """
  163. n, h, w = masks.shape
  164. x1, y1, x2, y2 = np.split(boxes[:, :, None], 4, 1)
  165. r = np.arange(w, dtype=x1.dtype)[None, None, :]
  166. c = np.arange(h, dtype=x1.dtype)[None, :, None]
  167. return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
  168. def process_mask(self, protos, masks_in, bboxes, im0_shape):
  169. """
  170. Takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher
  171. quality but is slower, from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/ops.py.
  172. Args:
  173. protos (numpy.ndarray): [mask_dim, mask_h, mask_w].
  174. masks_in (numpy.ndarray): [n, mask_dim], n is number of masks after nms.
  175. bboxes (numpy.ndarray): bboxes re-scaled to original image shape.
  176. im0_shape (tuple): the size of the input image (h,w,c).
  177. Returns:
  178. (numpy.ndarray): The upsampled masks.
  179. """
  180. c, mh, mw = protos.shape
  181. masks = np.matmul(masks_in, protos.reshape((c, -1))).reshape((-1, mh, mw)).transpose(1, 2, 0) # HWN
  182. masks = np.ascontiguousarray(masks)
  183. masks = self.scale_mask(masks, im0_shape) # re-scale mask from P3 shape to original input image shape
  184. masks = np.einsum("HWN -> NHW", masks) # HWN -> NHW
  185. masks = self.crop_mask(masks, bboxes)
  186. return np.greater(masks, 0.5)
  187. @staticmethod
  188. def scale_mask(masks, im0_shape, ratio_pad=None):
  189. """
  190. Takes a mask, and resizes it to the original image size, from
  191. https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/ops.py.
  192. Args:
  193. masks (np.ndarray): resized and padded masks/images, [h, w, num]/[h, w, 3].
  194. im0_shape (tuple): the original image shape.
  195. ratio_pad (tuple): the ratio of the padding to the original image.
  196. Returns:
  197. masks (np.ndarray): The masks that are being returned.
  198. """
  199. im1_shape = masks.shape[:2]
  200. if ratio_pad is None: # calculate from im0_shape
  201. gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1]) # gain = old / new
  202. pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2 # wh padding
  203. else:
  204. pad = ratio_pad[1]
  205. # Calculate tlbr of mask
  206. top, left = int(round(pad[1] - 0.1)), int(round(pad[0] - 0.1)) # y, x
  207. bottom, right = int(round(im1_shape[0] - pad[1] + 0.1)), int(round(im1_shape[1] - pad[0] + 0.1))
  208. if len(masks.shape) < 2:
  209. raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}')
  210. masks = masks[top:bottom, left:right]
  211. masks = cv2.resize(
  212. masks, (im0_shape[1], im0_shape[0]), interpolation=cv2.INTER_LINEAR
  213. ) # INTER_CUBIC would be better
  214. if len(masks.shape) == 2:
  215. masks = masks[:, :, None]
  216. return masks
  217. def draw_and_visualize(self, im, bboxes, segments, vis=False, save=True):
  218. """
  219. Draw and visualize results.
  220. Args:
  221. im (np.ndarray): original image, shape [h, w, c].
  222. bboxes (numpy.ndarray): [n, 4], n is number of bboxes.
  223. segments (List): list of segment masks.
  224. vis (bool): imshow using OpenCV.
  225. save (bool): save image annotated.
  226. Returns:
  227. None
  228. """
  229. # Draw rectangles and polygons
  230. im_canvas = im.copy()
  231. for (*box, conf, cls_), segment in zip(bboxes, segments):
  232. # draw contour and fill mask
  233. cv2.polylines(im, np.int32([segment]), True, (255, 255, 255), 2) # white borderline
  234. cv2.fillPoly(im_canvas, np.int32([segment]), self.color_palette(int(cls_), bgr=True))
  235. # draw bbox rectangle
  236. cv2.rectangle(
  237. im,
  238. (int(box[0]), int(box[1])),
  239. (int(box[2]), int(box[3])),
  240. self.color_palette(int(cls_), bgr=True),
  241. 1,
  242. cv2.LINE_AA,
  243. )
  244. cv2.putText(
  245. im,
  246. f"{self.classes[cls_]}: {conf:.3f}",
  247. (int(box[0]), int(box[1] - 9)),
  248. cv2.FONT_HERSHEY_SIMPLEX,
  249. 0.7,
  250. self.color_palette(int(cls_), bgr=True),
  251. 2,
  252. cv2.LINE_AA,
  253. )
  254. # Mix image
  255. im = cv2.addWeighted(im_canvas, 0.3, im, 0.7, 0)
  256. # Show image
  257. if vis:
  258. cv2.imshow("demo", im)
  259. cv2.waitKey(0)
  260. cv2.destroyAllWindows()
  261. # Save image
  262. if save:
  263. cv2.imwrite("demo.jpg", im)
  264. if __name__ == "__main__":
  265. # Create an argument parser to handle command-line arguments
  266. parser = argparse.ArgumentParser()
  267. parser.add_argument("--model", type=str, required=True, help="Path to ONNX model")
  268. parser.add_argument("--source", type=str, default=str(ASSETS / "bus.jpg"), help="Path to input image")
  269. parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold")
  270. parser.add_argument("--iou", type=float, default=0.45, help="NMS IoU threshold")
  271. args = parser.parse_args()
  272. # Build model
  273. model = YOLOv8Seg(args.model)
  274. # Read image by OpenCV
  275. img = cv2.imread(args.source)
  276. # Inference
  277. boxes, segments, _ = model(img, conf_threshold=args.conf, iou_threshold=args.iou)
  278. # Draw bboxes and polygons
  279. if len(boxes) > 0:
  280. model.draw_and_visualize(img, boxes, segments, vis=False, save=True)