From 0faf75a797480ba66ca3841bbedf4902b59688e9 Mon Sep 17 00:00:00 2001 From: Ryan McCormick Date: Wed, 5 Feb 2025 18:47:30 -0800 Subject: [PATCH 1/5] Start PoC of TritonCore based RequestHandler hooked up to DistributedRuntime --- .../examples/triton_core/client.py | 36 ++++++ .../triton_core/models/mock_llm/1/model.py | 112 +++++++++++++++++ .../triton_core/models/mock_llm/config.pbtxt | 67 ++++++++++ .../examples/triton_core/server.py | 116 ++++++++++++++++++ 4 files changed, 331 insertions(+) create mode 100644 runtime/rust/python-wheel/examples/triton_core/client.py create mode 100644 runtime/rust/python-wheel/examples/triton_core/models/mock_llm/1/model.py create mode 100644 runtime/rust/python-wheel/examples/triton_core/models/mock_llm/config.pbtxt create mode 100644 runtime/rust/python-wheel/examples/triton_core/server.py diff --git a/runtime/rust/python-wheel/examples/triton_core/client.py b/runtime/rust/python-wheel/examples/triton_core/client.py new file mode 100644 index 00000000..94d6fb24 --- /dev/null +++ b/runtime/rust/python-wheel/examples/triton_core/client.py @@ -0,0 +1,36 @@ +import asyncio + +import uvloop +from triton_distributed_rs import DistributedRuntime, triton_worker + + +@triton_worker() +async def worker(runtime: DistributedRuntime): + namespace: str = "triton_core_example" + await init(runtime, namespace) + + +async def init(runtime: DistributedRuntime, ns: str): + """ + Instantiate a `backend` client and call the `generate` endpoint + """ + # get endpoint + endpoint = runtime.namespace(ns).component("backend").endpoint("generate") + + # create client + client = await endpoint.client() + + # wait for an endpoint to be ready + await client.wait_for_endpoints() + + # issue request + stream = await client.generate("hello world") + + # process the stream + async for char in stream: + print(char) + + +if __name__ == "__main__": + uvloop.install() + asyncio.run(worker()) diff --git a/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/1/model.py b/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/1/model.py new file mode 100644 index 00000000..b82dedaa --- /dev/null +++ b/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/1/model.py @@ -0,0 +1,112 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import json +import time + +import numpy as np +import triton_python_backend_utils as pb_utils + + +class TritonPythonModel: + def initialize(self, args): + self.model_config = json.loads(args["model_config"]) + self.decoupled = self.model_config.get("model_transaction_policy", {}).get( + "decoupled" + ) + + def execute(self, requests): + if self.decoupled: + return self.exec_decoupled(requests) + else: + return self.exec(requests) + + def exec(self, requests): + responses = [] + for request in requests: + params = json.loads(request.parameters()) + rep_count = params["REPETITION"] if "REPETITION" in params else 1 + + input_np = pb_utils.get_input_tensor_by_name( + request, "text_input" + ).as_numpy() + stream_np = pb_utils.get_input_tensor_by_name(request, "stream") + stream = False + if stream_np: + stream = stream_np.as_numpy().flatten()[0] + if stream: + responses.append( + pb_utils.InferenceResponse( + error=pb_utils.TritonError( + "STREAM only supported in decoupled mode" + ) + ) + ) + else: + out_tensor = pb_utils.Tensor( + "text_output", np.repeat(input_np, rep_count, axis=1) + ) + responses.append(pb_utils.InferenceResponse([out_tensor])) + return responses + + def exec_decoupled(self, requests): + for request in requests: + params = json.loads(request.parameters()) + rep_count = params["REPETITION"] if "REPETITION" in params else 1 + fail_last = params["FAIL_LAST"] if "FAIL_LAST" in params else False + delay = params["DELAY"] if "DELAY" in params else None + + sender = request.get_response_sender() + input_np = pb_utils.get_input_tensor_by_name( + request, "text_input" + ).as_numpy() + stream_np = pb_utils.get_input_tensor_by_name(request, "stream") + stream = False + if stream_np: + stream = stream_np.as_numpy().flatten()[0] + out_tensor = pb_utils.Tensor("text_output", input_np) + response = pb_utils.InferenceResponse([out_tensor]) + # If stream enabled, just send multiple copies of response + # FIXME: Could split up response string into tokens, but this is simpler for now. + if stream: + for _ in range(rep_count): + if delay is not None: + time.sleep(delay) + sender.send(response) + sender.send( + None + if not fail_last + else pb_utils.InferenceResponse( + error=pb_utils.TritonError("An Error Occurred") + ), + flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL, + ) + # If stream disabled, just send one response + else: + sender.send( + response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL + ) + return None diff --git a/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/config.pbtxt b/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/config.pbtxt new file mode 100644 index 00000000..e9e640ad --- /dev/null +++ b/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/config.pbtxt @@ -0,0 +1,67 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +backend: "python" + +max_batch_size: 0 + +model_transaction_policy { + decoupled: True +} + +input [ + { + name: "text_input" + data_type: TYPE_STRING + dims: [ 1 ] + }, + { + name: "stream" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: True + }, + { + name: "sampling_parameters" + data_type: TYPE_STRING + dims: [ 1 ] + optional: True + } +] + +output [ + { + name: "text_output" + data_type: TYPE_STRING + dims: [ -1 ] + } +] + +instance_group [ + { + count: 1 + kind: KIND_MODEL + } +] diff --git a/runtime/rust/python-wheel/examples/triton_core/server.py b/runtime/rust/python-wheel/examples/triton_core/server.py new file mode 100644 index 00000000..7dfb4175 --- /dev/null +++ b/runtime/rust/python-wheel/examples/triton_core/server.py @@ -0,0 +1,116 @@ +import asyncio +from typing import Any, AsyncIterator, Dict, List + +import uvloop +from triton_distributed_rs import DistributedRuntime, triton_worker +from tritonserver import ModelControlMode +from tritonserver import Server as TritonCore +from tritonserver import Tensor +from tritonserver._api._response import InferenceResponse + + +# FIXME: Can this be more generic to arbitrary Triton models? +class RequestHandler: + """ + Request handler for the generate endpoint that uses TritonCoreOperator + to process text generation requests. 
+ """ + + def __init__(self, name: str = "mock_llm", repository: str = "./models"): + self.name: str = name + + # Initialize TritonCore + self._triton_core = TritonCore( + model_repository=repository, + log_info=True, + log_error=True, + model_control_mode=ModelControlMode.EXPLICIT, + ).start(wait_until_ready=True) + + # Load only the requested model + self._triton_core.load(name) + + # Get a handle to the requested model for re-use + self._model = self._triton_core.model(name) + + # Validate the model has the expected inputs and outputs + self._validate_model_config() + + print(f"Model {self.name} ready to generate") + + def _validate_model_config(self): + self._model_metadata = self._model.metadata() + self._inputs = self._model_metadata["inputs"] + self._outputs = self._model_metadata["outputs"] + + # Validate the model has the expected input and output + self._expected_input_name: str = "text_input" + if not any( + input["name"] == self._expected_input_name for input in self._inputs + ): + raise ValueError( + f"Model {self.name} does not have an input named {self._expected_input_name}" + ) + + self._expected_output_name: str = "text_output" + if not any( + output["name"] == self._expected_output_name for output in self._outputs + ): + raise ValueError( + f"Model {self.name} does not have an output named {self._expected_output_name}" + ) + + async def generate(self, request: str): + # FIXME: Iron out request type/schema + if not isinstance(request, str): + raise ValueError("Request must be a string") + + try: + print(f"Processing generation request: {request}") + text_input: List[str] = [request] + stream: List[bool] = [True] + + triton_core_inputs: Dict[str, Any] = { + "text_input": text_input, + "stream": stream, + } + responses: AsyncIterator[InferenceResponse] = self._model.async_infer( + inputs=triton_core_inputs + ) + + async for response in responses: + print(f"Received response: {response}") + text_output: str = "" + + text_output_tensor: Tensor = response.outputs.get("text_output") + if text_output_tensor: + text_output: str = text_output_tensor.to_string_array()[0] + + if response.error: + raise response.error + + yield text_output + + except Exception as e: + print(f"Error processing request: {e}") + raise + + +@triton_worker() +async def worker(runtime: DistributedRuntime): + """ + Instantiate a `backend` component and serve the `generate` endpoint + A `Component` can serve multiple endpoints + """ + namespace: str = "triton_core_example" + component = runtime.namespace(namespace).component("backend") + await component.create_service() + + endpoint = component.endpoint("generate") + print("Started server instance") + await endpoint.serve_endpoint(RequestHandler().generate) + + +if __name__ == "__main__": + uvloop.install() + asyncio.run(worker()) From 41716f56a561aa89cf71db152df5006579fb4d3b Mon Sep 17 00:00:00 2001 From: Ryan McCormick Date: Wed, 5 Feb 2025 18:55:04 -0800 Subject: [PATCH 2/5] Cleanup --- .../triton_core/models/mock_llm/1/model.py | 26 ------------------- .../triton_core/models/mock_llm/config.pbtxt | 25 ------------------ .../examples/triton_core/server.py | 15 ++++++----- 3 files changed, 8 insertions(+), 58 deletions(-) diff --git a/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/1/model.py b/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/1/model.py index b82dedaa..456a36d2 100644 --- a/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/1/model.py +++ 
b/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/1/model.py @@ -1,29 +1,3 @@ -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of NVIDIA CORPORATION nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - import json import time diff --git a/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/config.pbtxt b/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/config.pbtxt index e9e640ad..28889ee3 100644 --- a/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/config.pbtxt +++ b/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/config.pbtxt @@ -1,28 +1,3 @@ -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of NVIDIA CORPORATION nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. backend: "python" max_batch_size: 0 diff --git a/runtime/rust/python-wheel/examples/triton_core/server.py b/runtime/rust/python-wheel/examples/triton_core/server.py index 7dfb4175..e5b6be7e 100644 --- a/runtime/rust/python-wheel/examples/triton_core/server.py +++ b/runtime/rust/python-wheel/examples/triton_core/server.py @@ -16,8 +16,8 @@ class RequestHandler: to process text generation requests. """ - def __init__(self, name: str = "mock_llm", repository: str = "./models"): - self.name: str = name + def __init__(self, model_name: str = "mock_llm", repository: str = "./models"): + self.model_name: str = model_name # Initialize TritonCore self._triton_core = TritonCore( @@ -28,16 +28,17 @@ def __init__(self, name: str = "mock_llm", repository: str = "./models"): ).start(wait_until_ready=True) # Load only the requested model - self._triton_core.load(name) + self._triton_core.load(self.model_name) # Get a handle to the requested model for re-use - self._model = self._triton_core.model(name) + self._model = self._triton_core.model(self.model_name) # Validate the model has the expected inputs and outputs self._validate_model_config() - print(f"Model {self.name} ready to generate") + print(f"Model {self.model_name} ready to generate") + # FIXME: Can this be more generic to arbitrary Triton models? def _validate_model_config(self): self._model_metadata = self._model.metadata() self._inputs = self._model_metadata["inputs"] @@ -49,7 +50,7 @@ def _validate_model_config(self): input["name"] == self._expected_input_name for input in self._inputs ): raise ValueError( - f"Model {self.name} does not have an input named {self._expected_input_name}" + f"Model {self.model_name} does not have an input named {self._expected_input_name}" ) self._expected_output_name: str = "text_output" @@ -57,7 +58,7 @@ def _validate_model_config(self): output["name"] == self._expected_output_name for output in self._outputs ): raise ValueError( - f"Model {self.name} does not have an output named {self._expected_output_name}" + f"Model {self.model_name} does not have an output named {self._expected_output_name}" ) async def generate(self, request: str): From 5eb03d1f4a91af74f5a4d17a5256a97507f862ff Mon Sep 17 00:00:00 2001 From: Ryan McCormick Date: Wed, 5 Feb 2025 19:03:49 -0800 Subject: [PATCH 3/5] Add copyright --- .../python-wheel/examples/triton_core/client.py | 16 ++++++++++++++++ .../triton_core/models/mock_llm/1/model.py | 15 +++++++++++++++ .../triton_core/models/mock_llm/config.pbtxt | 15 +++++++++++++++ .../python-wheel/examples/triton_core/server.py | 16 ++++++++++++++++ 4 files changed, 62 insertions(+) diff --git a/runtime/rust/python-wheel/examples/triton_core/client.py b/runtime/rust/python-wheel/examples/triton_core/client.py index 94d6fb24..a534a1f3 100644 --- a/runtime/rust/python-wheel/examples/triton_core/client.py +++ b/runtime/rust/python-wheel/examples/triton_core/client.py @@ -1,3 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import asyncio import uvloop diff --git a/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/1/model.py b/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/1/model.py index 456a36d2..a1dd07b4 100644 --- a/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/1/model.py +++ b/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/1/model.py @@ -1,3 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import time diff --git a/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/config.pbtxt b/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/config.pbtxt index 28889ee3..ec48f0d5 100644 --- a/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/config.pbtxt +++ b/runtime/rust/python-wheel/examples/triton_core/models/mock_llm/config.pbtxt @@ -1,3 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + backend: "python" max_batch_size: 0 diff --git a/runtime/rust/python-wheel/examples/triton_core/server.py b/runtime/rust/python-wheel/examples/triton_core/server.py index e5b6be7e..161018d7 100644 --- a/runtime/rust/python-wheel/examples/triton_core/server.py +++ b/runtime/rust/python-wheel/examples/triton_core/server.py @@ -1,3 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import asyncio from typing import Any, AsyncIterator, Dict, List From 55db145bc39eaf98de8da6f16a3c6cbc5f12d21a Mon Sep 17 00:00:00 2001 From: Ryan McCormick Date: Wed, 5 Feb 2025 19:16:16 -0800 Subject: [PATCH 4/5] Make default model repository a relative path, cleanup a few types/imports --- .../python-wheel/examples/triton_core/server.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/runtime/rust/python-wheel/examples/triton_core/server.py b/runtime/rust/python-wheel/examples/triton_core/server.py index 161018d7..49195f40 100644 --- a/runtime/rust/python-wheel/examples/triton_core/server.py +++ b/runtime/rust/python-wheel/examples/triton_core/server.py @@ -15,6 +15,7 @@ import asyncio +import os from typing import Any, AsyncIterator, Dict, List import uvloop @@ -28,16 +29,20 @@ # FIXME: Can this be more generic to arbitrary Triton models? class RequestHandler: """ - Request handler for the generate endpoint that uses TritonCoreOperator + Request handler for the generate endpoint that uses TritonCore to process text generation requests. """ - def __init__(self, model_name: str = "mock_llm", repository: str = "./models"): + def __init__( + self, + model_name: str = "mock_llm", + model_repository: str = os.path.join(os.path.dirname(__file__), "models"), + ): self.model_name: str = model_name # Initialize TritonCore self._triton_core = TritonCore( - model_repository=repository, + model_repository=model_repository, log_info=True, log_error=True, model_control_mode=ModelControlMode.EXPLICIT, @@ -77,7 +82,7 @@ def _validate_model_config(self): f"Model {self.model_name} does not have an output named {self._expected_output_name}" ) - async def generate(self, request: str): + async def generate(self, request: str) -> AsyncIterator[str]: # FIXME: Iron out request type/schema if not isinstance(request, str): raise ValueError("Request must be a string") From 4a8275d7e6a765ef259412ec3f6776e807992118 Mon Sep 17 00:00:00 2001 From: Ryan McCormick Date: Wed, 5 Feb 2025 19:19:12 -0800 Subject: [PATCH 5/5] mypy fixes --- runtime/rust/python-wheel/examples/triton_core/__init__.py | 0 runtime/rust/python-wheel/examples/triton_core/server.py | 4 ++-- 2 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 runtime/rust/python-wheel/examples/triton_core/__init__.py diff --git a/runtime/rust/python-wheel/examples/triton_core/__init__.py b/runtime/rust/python-wheel/examples/triton_core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/runtime/rust/python-wheel/examples/triton_core/server.py b/runtime/rust/python-wheel/examples/triton_core/server.py index 49195f40..d298c813 100644 --- a/runtime/rust/python-wheel/examples/triton_core/server.py +++ b/runtime/rust/python-wheel/examples/triton_core/server.py @@ -104,9 +104,9 @@ async def generate(self, request: str) -> AsyncIterator[str]: print(f"Received response: {response}") text_output: str = "" - text_output_tensor: Tensor = response.outputs.get("text_output") + text_output_tensor: Tensor | None = response.outputs.get("text_output") if text_output_tensor: - text_output: str = 
text_output_tensor.to_string_array()[0] + text_output = text_output_tensor.to_string_array()[0] if response.error: raise response.error
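
---

Usage note (not part of the series): the intended flow is to launch server.py and client.py as separate processes, with whatever services the DistributedRuntime expects already running. For a quicker sanity check of just the TritonCore path, the handler can also be driven directly, since generate() is an ordinary async generator. The sketch below is illustrative only: it assumes the tritonserver wheel is installed, that it is run from runtime/rust/python-wheel/examples/triton_core/ so `from server import RequestHandler` resolves (and that importing server.py has no side effects beyond defining the class), and `local_check.py` is a hypothetical file name.

    # local_check.py -- minimal sketch; bypasses the DistributedRuntime and
    # calls RequestHandler.generate() directly.
    import asyncio

    from server import RequestHandler  # server.py from this example


    async def main() -> None:
        # __init__ starts an in-process TritonCore and loads mock_llm from
        # the models/ directory next to server.py (the PATCH 4 default).
        handler = RequestHandler()

        # The mock model echoes the input once per REPETITION (default 1);
        # the final response marker may yield a trailing empty string.
        async for text in handler.generate("hello world"):
            print(text)


    if __name__ == "__main__":
        asyncio.run(main())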