feat: TritonCore DistributedRuntime Worker Example #121

Draft · wants to merge 6 commits into base: main
52 changes: 52 additions & 0 deletions runtime/rust/python-wheel/examples/triton_core/client.py
@@ -0,0 +1,52 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import asyncio

import uvloop
from triton_distributed_rs import DistributedRuntime, triton_worker


@triton_worker()
async def worker(runtime: DistributedRuntime):
namespace: str = "triton_core_example"
await init(runtime, namespace)


async def init(runtime: DistributedRuntime, ns: str):
"""
Instantiate a `backend` client and call the `generate` endpoint
"""
# get endpoint
endpoint = runtime.namespace(ns).component("backend").endpoint("generate")

# create client
client = await endpoint.client()

# wait for an endpoint to be ready
await client.wait_for_endpoints()

# issue request
stream = await client.generate("hello world")

    # process the stream of responses
    async for response in stream:
        print(response)


if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
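For illustration only (not part of this PR), a hedged variation of the same client flow that collects the streamed chunks into one string instead of printing them as they arrive; it assumes the same `endpoint.client()` / `client.generate()` API used above.

from triton_distributed_rs import DistributedRuntime


async def collect(runtime: DistributedRuntime, ns: str) -> str:
    """Gather every chunk streamed back by `generate` into a single string."""
    endpoint = runtime.namespace(ns).component("backend").endpoint("generate")
    client = await endpoint.client()
    await client.wait_for_endpoints()

    chunks = []
    async for chunk in await client.generate("hello world"):
        chunks.append(chunk)
    return "".join(chunks)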
@@ -0,0 +1,101 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import time

import numpy as np
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
def initialize(self, args):
self.model_config = json.loads(args["model_config"])
self.decoupled = self.model_config.get("model_transaction_policy", {}).get(
"decoupled"
)

def execute(self, requests):
if self.decoupled:
return self.exec_decoupled(requests)
else:
return self.exec(requests)

def exec(self, requests):
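        # Non-decoupled path: exactly one response is returned per request.
        # Streaming is not supported here, so a request with stream=True gets
        # an error response instead of output.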
responses = []
for request in requests:
params = json.loads(request.parameters())
rep_count = params["REPETITION"] if "REPETITION" in params else 1

input_np = pb_utils.get_input_tensor_by_name(
request, "text_input"
).as_numpy()
stream_np = pb_utils.get_input_tensor_by_name(request, "stream")
stream = False
if stream_np:
stream = stream_np.as_numpy().flatten()[0]
if stream:
responses.append(
pb_utils.InferenceResponse(
error=pb_utils.TritonError(
"STREAM only supported in decoupled mode"
)
)
)
else:
out_tensor = pb_utils.Tensor(
"text_output", np.repeat(input_np, rep_count, axis=1)
)
responses.append(pb_utils.InferenceResponse([out_tensor]))
return responses

def exec_decoupled(self, requests):
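        # Decoupled path: when stream=True, the echoed text is sent REPETITION
        # times via the response sender (sleeping DELAY seconds between sends
        # if given), then the stream is closed with the COMPLETE_FINAL flag or
        # with a final error response when FAIL_LAST is set. When stream=False,
        # a single response carrying the COMPLETE_FINAL flag is sent.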
for request in requests:
params = json.loads(request.parameters())
rep_count = params["REPETITION"] if "REPETITION" in params else 1
fail_last = params["FAIL_LAST"] if "FAIL_LAST" in params else False
delay = params["DELAY"] if "DELAY" in params else None

sender = request.get_response_sender()
input_np = pb_utils.get_input_tensor_by_name(
request, "text_input"
).as_numpy()
stream_np = pb_utils.get_input_tensor_by_name(request, "stream")
stream = False
if stream_np:
stream = stream_np.as_numpy().flatten()[0]
out_tensor = pb_utils.Tensor("text_output", input_np)
response = pb_utils.InferenceResponse([out_tensor])
# If stream enabled, just send multiple copies of response
# FIXME: Could split up response string into tokens, but this is simpler for now.
if stream:
for _ in range(rep_count):
if delay is not None:
time.sleep(delay)
sender.send(response)
sender.send(
None
if not fail_last
else pb_utils.InferenceResponse(
error=pb_utils.TritonError("An Error Occurred")
),
flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
)
# If stream disabled, just send one response
else:
sender.send(
response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
)
return None
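The optional `REPETITION`, `DELAY`, and `FAIL_LAST` knobs above arrive through the request parameters (`request.parameters()`). The server below only sets `text_input` and `stream`; a hypothetical way to exercise these knobs through the in-process `tritonserver` API is sketched here, assuming `async_infer` accepts a `parameters` mapping (that keyword is an assumption, not something shown in this PR):

# Hypothetical sketch only: `model` is a handle obtained from the in-process
# tritonserver API, and the `parameters` keyword is assumed rather than
# demonstrated anywhere in this PR.
responses = model.async_infer(
    inputs={"text_input": ["hello"], "stream": [True]},
    parameters={"REPETITION": 3, "DELAY": 0.1},  # echo 3 times, 100 ms apart
)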
@@ -0,0 +1,57 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

backend: "python"

max_batch_size: 0

model_transaction_policy {
decoupled: True
}

input [
{
name: "text_input"
data_type: TYPE_STRING
dims: [ 1 ]
},
{
name: "stream"
data_type: TYPE_BOOL
dims: [ 1 ]
optional: True
},
{
name: "sampling_parameters"
data_type: TYPE_STRING
dims: [ 1 ]
optional: True
}
]

output [
{
name: "text_output"
data_type: TYPE_STRING
dims: [ -1 ]
}
]

instance_group [
{
count: 1
kind: KIND_MODEL
}
]
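For reference, `server.py` below points its model repository at a `models/` directory next to the example and loads the `mock_llm` model explicitly, so the layout assumed here (following standard Triton model-repository conventions; the exact paths are not visible in this diff) is:

models/
└── mock_llm/
    ├── config.pbtxt      # the config above
    └── 1/
        └── model.py      # the mock backend above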
138 changes: 138 additions & 0 deletions runtime/rust/python-wheel/examples/triton_core/server.py
@@ -0,0 +1,138 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import asyncio
import os
from typing import Any, AsyncIterator, Dict, List

import uvloop
from triton_distributed_rs import DistributedRuntime, triton_worker
from tritonserver import ModelControlMode
from tritonserver import Server as TritonCore
from tritonserver import Tensor
from tritonserver._api._response import InferenceResponse


# FIXME: Can this be more generic to arbitrary Triton models?
class RequestHandler:
"""
Request handler for the generate endpoint that uses TritonCore
to process text generation requests.
"""

def __init__(
self,
model_name: str = "mock_llm",
model_repository: str = os.path.join(os.path.dirname(__file__), "models"),
):
self.model_name: str = model_name

# Initialize TritonCore
self._triton_core = TritonCore(
model_repository=model_repository,
log_info=True,
log_error=True,
model_control_mode=ModelControlMode.EXPLICIT,
).start(wait_until_ready=True)

# Load only the requested model
self._triton_core.load(self.model_name)

# Get a handle to the requested model for re-use
self._model = self._triton_core.model(self.model_name)

# Validate the model has the expected inputs and outputs
self._validate_model_config()

print(f"Model {self.model_name} ready to generate")

# FIXME: Can this be more generic to arbitrary Triton models?
def _validate_model_config(self):
self._model_metadata = self._model.metadata()
self._inputs = self._model_metadata["inputs"]
self._outputs = self._model_metadata["outputs"]

# Validate the model has the expected input and output
self._expected_input_name: str = "text_input"
if not any(
input["name"] == self._expected_input_name for input in self._inputs
):
raise ValueError(
f"Model {self.model_name} does not have an input named {self._expected_input_name}"
)

self._expected_output_name: str = "text_output"
if not any(
output["name"] == self._expected_output_name for output in self._outputs
):
raise ValueError(
f"Model {self.model_name} does not have an output named {self._expected_output_name}"
)

async def generate(self, request: str) -> AsyncIterator[str]:
# FIXME: Iron out request type/schema
if not isinstance(request, str):
raise ValueError("Request must be a string")

try:
print(f"Processing generation request: {request}")
text_input: List[str] = [request]
stream: List[bool] = [True]

triton_core_inputs: Dict[str, Any] = {
"text_input": text_input,
"stream": stream,
}
responses: AsyncIterator[InferenceResponse] = self._model.async_infer(
inputs=triton_core_inputs
)

async for response in responses:
print(f"Received response: {response}")
text_output: str = ""

text_output_tensor: Tensor | None = response.outputs.get("text_output")
if text_output_tensor:
text_output = text_output_tensor.to_string_array()[0]

if response.error:
raise response.error

yield text_output

except Exception as e:
print(f"Error processing request: {e}")
raise


@triton_worker()
async def worker(runtime: DistributedRuntime):
"""
    Instantiate a `backend` component and serve the `generate` endpoint.
    A `Component` can serve multiple endpoints.
"""
namespace: str = "triton_core_example"
component = runtime.namespace(namespace).component("backend")
await component.create_service()

endpoint = component.endpoint("generate")
print("Started server instance")
await endpoint.serve_endpoint(RequestHandler().generate)


if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
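End to end, the example is intended to run as two processes. A sketch of the expected flow, assuming the supporting services for `triton_distributed_rs` (e.g. NATS/etcd) are already running; those prerequisites are not part of this diff:

# terminal 1: start the worker, which loads mock_llm and serves `generate`
#   python3 server.py
# terminal 2: issue a request and print each streamed chunk
#   python3 client.py
# With the decoupled mock model and the default REPETITION of 1, the client
# prints the echoed input ("hello world") once.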