diff --git a/runtime/rust/python-wheel/examples/triton_core/__init__.py b/runtime/rust/python-wheel/examples/triton_core/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/runtime/rust/python-wheel/examples/triton_core/client.py b/runtime/rust/python-wheel/examples/triton_core/client.py
new file mode 100644
index 00000000..a534a1f3
--- /dev/null
+++ b/runtime/rust/python-wheel/examples/triton_core/client.py
@@ -0,0 +1,52 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+
+import uvloop
+from triton_distributed_rs import DistributedRuntime, triton_worker
+
+
+@triton_worker()
+async def worker(runtime: DistributedRuntime):
+    namespace: str = "triton_core_example"
+    await init(runtime, namespace)
+
+
+async def init(runtime: DistributedRuntime, ns: str):
+    """
+    Instantiate a `backend` client and call the `generate` endpoint
+    """
+    # get endpoint
+    endpoint = runtime.namespace(ns).component("backend").endpoint("generate")
+
+    # create client
+    client = await endpoint.client()
+
+    # wait for an endpoint to be ready
+    await client.wait_for_endpoints()
+
+    # issue request
+    stream = await client.generate("hello world")
+
+    # process the stream
+    async for char in stream:
+        print(char)
+
+
+if __name__ == "__main__":
+    uvloop.install()
+    asyncio.run(worker())
diff --git a/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/1/model.py b/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/1/model.py
new file mode 100644
index 00000000..a1dd07b4
--- /dev/null
+++ b/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/1/model.py
@@ -0,0 +1,101 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import time
+
+import numpy as np
+import triton_python_backend_utils as pb_utils
+
+
+class TritonPythonModel:
+    def initialize(self, args):
+        self.model_config = json.loads(args["model_config"])
+        self.decoupled = self.model_config.get("model_transaction_policy", {}).get(
+            "decoupled"
+        )
+
+    def execute(self, requests):
+        if self.decoupled:
+            return self.exec_decoupled(requests)
+        else:
+            return self.exec(requests)
+
+    def exec(self, requests):
+        responses = []
+        for request in requests:
+            params = json.loads(request.parameters())
+            rep_count = params["REPETITION"] if "REPETITION" in params else 1
+
+            input_np = pb_utils.get_input_tensor_by_name(
+                request, "text_input"
+            ).as_numpy()
+            stream_np = pb_utils.get_input_tensor_by_name(request, "stream")
+            stream = False
+            if stream_np:
+                stream = stream_np.as_numpy().flatten()[0]
+            if stream:
+                responses.append(
+                    pb_utils.InferenceResponse(
+                        error=pb_utils.TritonError(
+                            "STREAM only supported in decoupled mode"
+                        )
+                    )
+                )
+            else:
+                out_tensor = pb_utils.Tensor(
+                    "text_output", np.repeat(input_np, rep_count, axis=1)
+                )
+                responses.append(pb_utils.InferenceResponse([out_tensor]))
+        return responses
+
+    def exec_decoupled(self, requests):
+        for request in requests:
+            params = json.loads(request.parameters())
+            rep_count = params["REPETITION"] if "REPETITION" in params else 1
+            fail_last = params["FAIL_LAST"] if "FAIL_LAST" in params else False
+            delay = params["DELAY"] if "DELAY" in params else None
+
+            sender = request.get_response_sender()
+            input_np = pb_utils.get_input_tensor_by_name(
+                request, "text_input"
+            ).as_numpy()
+            stream_np = pb_utils.get_input_tensor_by_name(request, "stream")
+            stream = False
+            if stream_np:
+                stream = stream_np.as_numpy().flatten()[0]
+            out_tensor = pb_utils.Tensor("text_output", input_np)
+            response = pb_utils.InferenceResponse([out_tensor])
+            # If stream enabled, just send multiple copies of response
+            # FIXME: Could split up response string into tokens, but this is simpler for now.
+            if stream:
+                for _ in range(rep_count):
+                    if delay is not None:
+                        time.sleep(delay)
+                    sender.send(response)
+                sender.send(
+                    None
+                    if not fail_last
+                    else pb_utils.InferenceResponse(
+                        error=pb_utils.TritonError("An Error Occurred")
+                    ),
+                    flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
+                )
+            # If stream disabled, just send one response
+            else:
+                sender.send(
+                    response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
+                )
+        return None
diff --git a/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/config.pbtxt b/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/config.pbtxt
new file mode 100644
index 00000000..ec48f0d5
--- /dev/null
+++ b/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/config.pbtxt
@@ -0,0 +1,57 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+backend: "python"
+
+max_batch_size: 0
+
+model_transaction_policy {
+  decoupled: True
+}
+
+input [
+  {
+    name: "text_input"
+    data_type: TYPE_STRING
+    dims: [ 1 ]
+  },
+  {
+    name: "stream"
+    data_type: TYPE_BOOL
+    dims: [ 1 ]
+    optional: True
+  },
+  {
+    name: "sampling_parameters"
+    data_type: TYPE_STRING
+    dims: [ 1 ]
+    optional: True
+  }
+]
+
+output [
+  {
+    name: "text_output"
+    data_type: TYPE_STRING
+    dims: [ -1 ]
+  }
+]
+
+instance_group [
+  {
+    count: 1
+    kind: KIND_MODEL
+  }
+]
diff --git a/runtime/rust/python-wheel/examples/triton_core/server.py b/runtime/rust/python-wheel/examples/triton_core/server.py
new file mode 100644
index 00000000..d298c813
--- /dev/null
+++ b/runtime/rust/python-wheel/examples/triton_core/server.py
@@ -0,0 +1,138 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+import os
+from typing import Any, AsyncIterator, Dict, List
+
+import uvloop
+from triton_distributed_rs import DistributedRuntime, triton_worker
+from tritonserver import ModelControlMode
+from tritonserver import Server as TritonCore
+from tritonserver import Tensor
+from tritonserver._api._response import InferenceResponse
+
+
+# FIXME: Can this be more generic to arbitrary Triton models?
+class RequestHandler:
+    """
+    Request handler for the generate endpoint that uses TritonCore
+    to process text generation requests.
+    """
+
+    def __init__(
+        self,
+        model_name: str = "mock_llm",
+        model_repository: str = os.path.join(os.path.dirname(__file__), "models"),
+    ):
+        self.model_name: str = model_name
+
+        # Initialize TritonCore
+        self._triton_core = TritonCore(
+            model_repository=model_repository,
+            log_info=True,
+            log_error=True,
+            model_control_mode=ModelControlMode.EXPLICIT,
+        ).start(wait_until_ready=True)
+
+        # Load only the requested model
+        self._triton_core.load(self.model_name)
+
+        # Get a handle to the requested model for re-use
+        self._model = self._triton_core.model(self.model_name)
+
+        # Validate the model has the expected inputs and outputs
+        self._validate_model_config()
+
+        print(f"Model {self.model_name} ready to generate")
+
+    # FIXME: Can this be more generic to arbitrary Triton models?
+    def _validate_model_config(self):
+        self._model_metadata = self._model.metadata()
+        self._inputs = self._model_metadata["inputs"]
+        self._outputs = self._model_metadata["outputs"]
+
+        # Validate the model has the expected input and output
+        self._expected_input_name: str = "text_input"
+        if not any(
+            input["name"] == self._expected_input_name for input in self._inputs
+        ):
+            raise ValueError(
+                f"Model {self.model_name} does not have an input named {self._expected_input_name}"
+            )
+
+        self._expected_output_name: str = "text_output"
+        if not any(
+            output["name"] == self._expected_output_name for output in self._outputs
+        ):
+            raise ValueError(
+                f"Model {self.model_name} does not have an output named {self._expected_output_name}"
+            )
+
+    async def generate(self, request: str) -> AsyncIterator[str]:
+        # FIXME: Iron out request type/schema
+        if not isinstance(request, str):
+            raise ValueError("Request must be a string")
+
+        try:
+            print(f"Processing generation request: {request}")
+            text_input: List[str] = [request]
+            stream: List[bool] = [True]
+
+            triton_core_inputs: Dict[str, Any] = {
+                "text_input": text_input,
+                "stream": stream,
+            }
+            responses: AsyncIterator[InferenceResponse] = self._model.async_infer(
+                inputs=triton_core_inputs
+            )
+
+            async for response in responses:
+                print(f"Received response: {response}")
+                text_output: str = ""
+
+                text_output_tensor: Tensor | None = response.outputs.get("text_output")
+                if text_output_tensor:
+                    text_output = text_output_tensor.to_string_array()[0]
+
+                if response.error:
+                    raise response.error
+
+                yield text_output
+
+        except Exception as e:
+            print(f"Error processing request: {e}")
+            raise
+
+
+@triton_worker()
+async def worker(runtime: DistributedRuntime):
+    """
+    Instantiate a `backend` component and serve the `generate` endpoint
+    A `Component` can serve multiple endpoints
+    """
+    namespace: str = "triton_core_example"
+    component = runtime.namespace(namespace).component("backend")
+    await component.create_service()
+
+    endpoint = component.endpoint("generate")
+    print("Started server instance")
+    await endpoint.serve_endpoint(RequestHandler().generate)
+
+
+if __name__ == "__main__":
+    uvloop.install()
+    asyncio.run(worker())
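
Usage sketch (an assumption for illustration, not part of the change): with the triton_distributed_rs wheel installed and the distributed runtime's supporting services already running, the server is presumably started first so it can load `mock_llm` from the adjacent `models/` repository, then the client issues the streamed "hello world" request against the `triton_core_example/backend/generate` endpoint:

    python3 server.py
    python3 client.py

The exact launch commands and environment setup are not specified in this diff; the client simply prints each streamed chunk yielded by `RequestHandler.generate`.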