4 changes: 3 additions & 1 deletion lib/llm/build.rs
@@ -44,7 +44,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
}

fn build_protos() -> Result<(), Box<dyn std::error::Error>> {
tonic_build::compile_protos("src/grpc/protos/kserve.proto")?;
tonic_build::configure()
.type_attribute(".", "#[derive(serde::Serialize,serde::Deserialize)]")
.compile_protos(&["kserve.proto"], &["src/grpc/protos"])?;
Ok(())
}

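For context, the type_attribute(".", ...) call asks prost-build to attach the serde derives to every generated protobuf message. A minimal sketch of what that enables, assuming serde_json is available as a dependency and the generated code is exposed under an inference module as in kserve.rs:

fn serde_roundtrip_example() -> Result<(), serde_json::Error> {
    // Round-trip a prost-generated message through JSON; this compiles only
    // because the build script now derives Serialize/Deserialize for "." (all types).
    let config = inference::ModelConfig {
        name: "echo".to_string(),
        platform: "custom".to_string(),
        ..Default::default()
    };
    let json = serde_json::to_string(&config)?;
    let back: inference::ModelConfig = serde_json::from_str(&json)?;
    assert_eq!(config, back);
    Ok(())
}
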
60 changes: 60 additions & 0 deletions lib/llm/src/grpc/service/kserve.rs
@@ -39,6 +39,8 @@ use inference::{
ModelMetadataRequest, ModelMetadataResponse, ModelStreamInferResponse,
};

use prost::Message;

/// [gluo TODO] 'metrics' is for the HTTP service and there is an HTTP endpoint
/// for it as part of the HTTP service. Should we always start the HTTP service
/// for non-inference?
@@ -418,6 +420,50 @@ impl GrpcInferenceService for KserveService {
if card.model_type.supports_tensor() {
if let Some(tensor_model_config) = card.runtime_config.tensor_model_config.as_ref()
{
if let Some(triton_model_config) =
tensor_model_config.triton_model_config.as_ref()
{
let model_config = ModelConfig::decode(triton_model_config.as_slice())
.map_err(|e| {
Status::invalid_argument(format!(
"Failed to deserialize model config: {}",
e
))
})?;
return Ok(Response::new(ModelMetadataResponse {
name: model_config.name,
versions: vec!["1".to_string()],
platform: model_config.platform,
inputs: model_config
.input
.iter()
.map(|input| inference::model_metadata_response::TensorMetadata {
name: input.name.clone(),
datatype: match inference::DataType::try_from(input.data_type) {
Ok(dt) => dt.as_str_name().to_string(),
Err(_) => "TYPE_INVALID".to_string(),
},
shape: input.dims.clone(),
})
.collect(),
outputs: model_config
.output
.iter()
.map(
|output| inference::model_metadata_response::TensorMetadata {
name: output.name.clone(),
datatype: match inference::DataType::try_from(
output.data_type,
) {
Ok(dt) => dt.as_str_name().to_string(),
Err(_) => "TYPE_INVALID".to_string(),
},
shape: output.dims.clone(),
},
)
.collect(),
}));
}
return Ok(Response::new(ModelMetadataResponse {
name: tensor_model_config.name.clone(),
versions: vec!["1".to_string()],
@@ -499,6 +545,20 @@ impl GrpcInferenceService for KserveService {
if card.model_type.supports_tensor() {
if let Some(tensor_model_config) = card.runtime_config.tensor_model_config.as_ref()
{
if let Some(triton_model_config) =
tensor_model_config.triton_model_config.as_ref()
{
let model_config = ModelConfig::decode(triton_model_config.as_slice())
.map_err(|e| {
Status::invalid_argument(format!(
"Failed to deserialize model config: {}",
e
))
})?;
Comment on lines +551 to +557 (Contributor):

nit: Could extract the ModelConfig::decode(...).map_err(...) operation into a helper function (e.g. decode_triton_config()). I believe it is used below in async fn model_config() as well.

Additionally, within decode_triton_config(), would it make sense to perform any validation on the inputs to make sure the deserialized config is valid?
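
As a rough illustration of that suggestion (the helper name and the validation shown are hypothetical, not part of this PR), the shared decode-and-validate step could look roughly like this:

fn decode_triton_config(bytes: &[u8]) -> Result<ModelConfig, Status> {
    // Decode the serialized Triton model config, mapping failures to a gRPC status.
    let config = ModelConfig::decode(bytes).map_err(|e| {
        Status::invalid_argument(format!("Failed to deserialize model config: {}", e))
    })?;
    // Minimal sanity check; the exact validation rules are the open question above.
    if config.name.is_empty() {
        return Err(Status::invalid_argument(
            "Triton model config is missing a model name",
        ));
    }
    Ok(config)
}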

return Ok(Response::new(ModelConfigResponse {
config: Some(model_config),
}));
}
let model_config = ModelConfig {
name: tensor_model_config.name.clone(),
platform: "dynamo".to_string(),
15 changes: 15 additions & 0 deletions lib/llm/src/protocols/tensor.rs
@@ -129,6 +129,21 @@ pub struct TensorModelConfig {
pub name: String,
pub inputs: Vec<TensorMetadata>,
pub outputs: Vec<TensorMetadata>,
// Optional Triton model config as a serialized protobuf byte string;
// if provided, it supersedes the basic model config defined above.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub triton_model_config: Option<Vec<u8>>,
}

impl Default for TensorModelConfig {
fn default() -> Self {
Self {
name: "".to_string(),
inputs: vec![],
outputs: vec![],
triton_model_config: None,
}
}
}
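
As a rough sketch of how a caller might populate the new field (the inference::ModelConfig path and the surrounding setup are assumptions borrowed from kserve.rs, not prescribed by this file):

use prost::Message;

// Build a Triton ModelConfig, serialize it with prost, and attach the bytes.
// The basic name/inputs/outputs can stay at their defaults because the Triton
// config, when present, supersedes them.
let triton_config = inference::ModelConfig {
    name: "echo".to_string(),
    platform: "custom".to_string(),
    ..Default::default()
};
let tensor_config = TensorModelConfig {
    triton_model_config: Some(triton_config.encode_to_vec()),
    ..Default::default()
};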

#[derive(Serialize, Deserialize, Debug, Clone)]
96 changes: 96 additions & 0 deletions lib/llm/tests/kserve_service.rs
@@ -42,6 +42,7 @@ pub mod kserve_test {
use tonic::{Request, Response, transport::Channel};

use dynamo_async_openai::types::Prompt;
use prost::Message;

struct SplitEngine {}

@@ -361,6 +362,7 @@
ModelInfo = 8994,
TensorModel = 8995,
TensorModelTypes = 8996,
TritonModelConfig = 8997,
}

#[rstest]
@@ -1173,6 +1175,7 @@
shape: vec![-1],
parameters: Default::default(),
}],
triton_model_config: None,
}),
..Default::default()
};
@@ -1206,6 +1209,98 @@
);
}

#[rstest]
#[tokio::test]
async fn test_triton_model_config(
#[with(TestPort::TritonModelConfig as u16)] service_with_engines: (
KserveService,
Arc<SplitEngine>,
Arc<AlwaysFailEngine>,
Arc<LongRunningEngine>,
),
) {
// start server
let _running = RunningService::spawn(service_with_engines.0.clone());

let mut client = get_ready_client(TestPort::TritonModelConfig as u16, 5).await;

let model_name = "tensor";
let expected_model_config = inference::ModelConfig {
name: model_name.to_string(),
platform: "custom".to_string(),
backend: "custom".to_string(),
input: vec![
inference::ModelInput {
name: "input".to_string(),
data_type: DataType::TypeInt32 as i32,
dims: vec![1],
optional: false,
..Default::default()
},
inference::ModelInput {
name: "optional_input".to_string(),
data_type: DataType::TypeInt32 as i32,
dims: vec![1],
optional: true,
..Default::default()
},
],
output: vec![inference::ModelOutput {
name: "output".to_string(),
data_type: DataType::TypeBool as i32,
dims: vec![-1],
..Default::default()
}],
model_transaction_policy: Some(inference::ModelTransactionPolicy { decoupled: true }),
..Default::default()
};

let mut buf = vec![];
expected_model_config.encode(&mut buf).unwrap();

// Register a tensor model
let mut card = ModelDeploymentCard::with_name_only(model_name);
card.model_type = ModelType::TensorBased;
card.model_input = ModelInput::Tensor;
card.runtime_config = ModelRuntimeConfig {
tensor_model_config: Some(tensor::TensorModelConfig {
triton_model_config: Some(buf),
..Default::default()
}),
..Default::default()
};
let tensor = Arc::new(TensorEngine {});
service_with_engines
.0
.model_manager()
.add_tensor_model("tensor", card.mdcsum(), tensor.clone())
.unwrap();
let _ = service_with_engines
.0
.model_manager()
.save_model_card("key", card);

// success config
let request = tonic::Request::new(ModelConfigRequest {
name: model_name.into(),
version: "".into(),
});

let response = client
.model_config(request)
.await
.unwrap()
.into_inner()
.config;
let Some(config) = response else {
panic!("Expected Some(config), got None");
};
assert_eq!(
config, expected_model_config,
"Expected same model config to be returned",
);
}

#[rstest]
#[tokio::test]
async fn test_tensor_infer(
@@ -1305,6 +1400,7 @@
shape: vec![-1],
parameters: Default::default(),
}],
triton_model_config: None,
}),
..Default::default()
};
40 changes: 34 additions & 6 deletions tests/frontend/grpc/echo_tensor_worker.py
@@ -4,6 +4,9 @@
# Usage: `TEST_END_TO_END=1 python test_tensor.py` to run this worker as a tensor-based echo worker.


# The test is run in an environment that has tritonclient installed,
# which contains the generated module equivalent to model_config.proto.
import tritonclient.grpc.model_config_pb2 as mc
import uvloop

from dynamo.llm import ModelInput, ModelRuntimeConfig, ModelType, register_llm
@@ -17,17 +20,39 @@ async def echo_tensor_worker(runtime: DistributedRuntime):

endpoint = component.endpoint("generate")

triton_model_config = mc.ModelConfig()
triton_model_config.name = "echo"
triton_model_config.platform = "custom"
input_tensor = triton_model_config.input.add()
input_tensor.name = "input"
input_tensor.data_type = mc.TYPE_STRING
input_tensor.dims.extend([-1])
optional_input_tensor = triton_model_config.input.add()
optional_input_tensor.name = "optional_input"
optional_input_tensor.data_type = mc.TYPE_INT32
optional_input_tensor.dims.extend([-1])
optional_input_tensor.optional = True
output_tensor = triton_model_config.output.add()
output_tensor.name = "dummy_output"
output_tensor.data_type = mc.TYPE_STRING
output_tensor.dims.extend([-1])
triton_model_config.model_transaction_policy.decoupled = True

model_config = {
"name": "echo",
"inputs": [
{"name": "dummy_input", "data_type": "Bytes", "shape": [-1]},
],
"outputs": [{"name": "dummy_output", "data_type": "Bytes", "shape": [-1]}],
"name": "",
"inputs": [],
"outputs": [],
"triton_model_config": triton_model_config.SerializeToString(),
}
runtime_config = ModelRuntimeConfig()
runtime_config.set_tensor_model_config(model_config)

assert model_config == runtime_config.get_tensor_model_config()
    # Internally the bytes string is converted to a list of ints
retrieved_model_config = runtime_config.get_tensor_model_config()
retrieved_model_config["triton_model_config"] = bytes(
retrieved_model_config["triton_model_config"]
)
assert model_config == retrieved_model_config

# [gluo FIXME] register_llm will attempt to load a LLM model,
# which is not well-defined for Tensor yet. Currently provide
@@ -46,6 +71,9 @@ async def echo_tensor_worker(runtime: DistributedRuntime):

async def generate(request, context):
"""Echo tensors and parameters back to the client."""
        # [NOTE] gluo: currently there is no frontend-side validation of the
        # actual request against the model config, so any request will reach
        # here and be echoed back.
print(f"Echoing request: {request}")

params = {}
1 change: 1 addition & 0 deletions tests/frontend/grpc/test_tensor_mocker_engine.py
@@ -120,3 +120,4 @@ def start_services(request, runtime_services):
@pytest.mark.model(TEST_MODEL)
def test_echo() -> None:
triton_echo_client.run_infer()
triton_echo_client.get_config()
14 changes: 14 additions & 0 deletions tests/frontend/grpc/triton_echo_client.py
@@ -43,3 +43,17 @@ def run_infer():

assert np.array_equal(input0_data, output0_data)
assert np.array_equal(input1_data, output1_data)


def get_config():
server_url = "localhost:8000"
try:
triton_client = grpcclient.InferenceServerClient(url=server_url)
except Exception as e:
print("channel creation failed: " + str(e))
sys.exit()

model_name = "echo"
response = triton_client.get_model_config(model_name=model_name)
    # Check one of the fields that can only be set by providing a Triton model config
assert response.config.model_transaction_policy.decoupled