triton.py
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from typing import List
from urllib.parse import urlsplit

import numpy as np


class TritonRemoteModel:
    """
    Client for interacting with a remote Triton Inference Server model.

    Attributes:
        endpoint (str): The name of the model on the Triton server.
        url (str): The URL of the Triton server.
        triton_client: The Triton client (either HTTP or gRPC).
        InferInput: The input class for the Triton client.
        InferRequestedOutput: The output request class for the Triton client.
        input_formats (List[str]): The data types of the model inputs.
        np_input_formats (List[type]): The numpy data types of the model inputs.
        input_names (List[str]): The names of the model inputs.
        output_names (List[str]): The names of the model outputs.
        metadata: The model metadata parsed from the server config parameters (None if absent).
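
    Examples:
        Illustrative usage; the port and endpoint name below are placeholder values
        for a locally served detection model:

        >>> model = TritonRemoteModel(url="localhost:8000", endpoint="yolov8", scheme="http")
        >>> outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))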
  18. """
    def __init__(self, url: str, endpoint: str = "", scheme: str = ""):
        """
        Initialize the TritonRemoteModel.

        Arguments may be provided individually or parsed from a collective 'url' argument of the form
            <scheme>://<netloc>/<endpoint>/<task_name>

        Args:
            url (str): The URL of the Triton server.
            endpoint (str): The name of the model on the Triton server.
            scheme (str): The communication scheme ('http' or 'grpc').
        """
        if not endpoint and not scheme:  # Parse all args from URL string
            splits = urlsplit(url)
            endpoint = splits.path.strip("/").split("/")[0]
            scheme = splits.scheme
            url = splits.netloc

        self.endpoint = endpoint
        self.url = url

        # Choose the Triton client based on the communication scheme
        if scheme == "http":
            import tritonclient.http as client  # noqa

            self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
            config = self.triton_client.get_model_config(endpoint)
        else:
            import tritonclient.grpc as client  # noqa

            self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
            config = self.triton_client.get_model_config(endpoint, as_json=True)["config"]

        # Sort output names alphabetically, i.e. 'output0', 'output1', etc.
        config["output"] = sorted(config["output"], key=lambda x: x.get("name"))

        # Define model attributes
        type_map = {"TYPE_FP32": np.float32, "TYPE_FP16": np.float16, "TYPE_UINT8": np.uint8}
        self.InferRequestedOutput = client.InferRequestedOutput
        self.InferInput = client.InferInput
        self.input_formats = [x["data_type"] for x in config["input"]]
        self.np_input_formats = [type_map[x] for x in self.input_formats]
        self.input_names = [x["name"] for x in config["input"]]
        self.output_names = [x["name"] for x in config["output"]]
        # Model metadata is stored as a stringified Python dict in the config parameters; defaults to None if absent
        self.metadata = eval(config.get("parameters", {}).get("metadata", {}).get("string_value", "None"))

    def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
        """
        Call the model with the given inputs.

        Args:
            *inputs (np.ndarray): Input data to the model, one array per model input.

        Returns:
            (List[np.ndarray]): Model outputs, cast back to the dtype of the first input.
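
        Examples:
            Illustrative call on an already-constructed client, assuming one FP32
            image input of shape (1, 3, 640, 640):

            >>> outputs = model(np.zeros((1, 3, 640, 640), dtype=np.float32))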
  63. """
        infer_inputs = []
        input_format = inputs[0].dtype
        for i, x in enumerate(inputs):
            if x.dtype != self.np_input_formats[i]:
                x = x.astype(self.np_input_formats[i])  # Cast to the dtype the model expects
            infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace("TYPE_", ""))
            infer_input.set_data_from_numpy(x)
            infer_inputs.append(infer_input)

        infer_outputs = [self.InferRequestedOutput(output_name) for output_name in self.output_names]
        outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs)

        # Cast outputs back to the dtype of the original input
        return [outputs.as_numpy(output_name).astype(input_format) for output_name in self.output_names]
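

# Minimal usage sketch (illustrative; not part of the original module). Assumes a
# Triton server reachable on localhost serving a model named "yolov8" with a single
# FP32 input; the endpoint name, ports, and input shape are placeholder values.
# Triton's default ports are 8000 for HTTP and 8001 for gRPC.
if __name__ == "__main__":
    # Parse scheme, netloc, and endpoint from a single URL (HTTP) ...
    http_model = TritonRemoteModel("http://localhost:8000/yolov8")
    # ... or pass the pieces individually (gRPC).
    grpc_model = TritonRemoteModel(url="localhost:8001", endpoint="yolov8", scheme="grpc")
    outputs = http_model(np.random.rand(1, 3, 640, 640).astype(np.float32))
    print([o.shape for o in outputs])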