feat: allow Triton model config specification in TensorModelConfig #3874
base: main
Conversation
Signed-off-by: Guan Luo <[email protected]>
Signed-off-by: Guan Luo <[email protected]>
Signed-off-by: Guan Luo <[email protected]>
Signed-off-by: Guan Luo <[email protected]>
Signed-off-by: GuanLuo <[email protected]>
Walkthrough

This pull request adds Triton model configuration support to tensor-based models. The build script is updated to add serde derives to generated protobuf types. TensorModelConfig gains an optional triton_model_config field to embed serialized Triton configurations. The KServe service decodes Triton configs when present and returns them in model metadata and config responses. Frontend and test infrastructure are updated to generate, register, and verify Triton model configurations end-to-end.
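For orientation, here is a minimal sketch of the shape of that change (sketch only: `TensorMetadata` is a placeholder for the real input/output metadata type, and the exact derives and serde attributes are discussed in the review comments below):

```rust
// Sketch only: approximate shape of the new field on TensorModelConfig.
// `TensorMetadata` is a placeholder element type, not the crate's real one.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, Default)]
pub struct TensorMetadata {
    pub name: String,
    pub shape: Vec<i64>,
}

#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, Default)]
pub struct TensorModelConfig {
    pub name: String,
    pub inputs: Vec<TensorMetadata>,
    pub outputs: Vec<TensorMetadata>,
    // Serialized Triton `ModelConfig` protobuf; when present it supersedes
    // the basic config above (see the bytes-vs-base64 discussion below).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub triton_model_config: Option<Vec<u8>>,
}
```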
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

The changes span multiple layers with consistent Triton config decoding patterns, but require careful review of the service implementation's conditional logic, protobuf message handling, and integration across build, protocol, and test layers.
Pre-merge checks: ❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
lib/llm/src/protocols/tensor.rs (1)
165-172: Shape validation rejects -1 wildcard; fix boundary check.

Current check errors on any negative dim, contradicting the message “only -1 is allowed”. This blocks valid shapes.
Apply:
```diff
-        for &d in &self.metadata.shape {
-            if d < 0 {
-                let mut e = ValidationError::new("negative_dim");
-                e.message = Some("only -1 is allowed as a wildcard dimension".into());
-                errs.add("shape", e);
-                break;
-            }
-            product = product.saturating_mul(d as usize);
-        }
+        for &d in &self.metadata.shape {
+            if d < -1 {
+                let mut e = ValidationError::new("negative_dim");
+                e.message = Some("only -1 is allowed as a wildcard dimension".into());
+                errs.add("shape", e);
+                break;
+            }
+            if d != -1 {
+                product = product.saturating_mul(d as usize);
+            }
+        }
```

Also consider skipping the element-count check when any dim is -1, or computing a compatible count instead.
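One way to make the element-count side wildcard-aware is a small helper along these lines (a sketch only; `expected_elements` is a hypothetical name, assuming the shape is a `Vec<i64>` as in the loop above):

```rust
// Hypothetical helper: the exact element count is only meaningful when the
// shape is fully specified; any negative dim (the -1 wildcard, or an invalid
// value already reported by the validation above) makes it unknowable.
fn expected_elements(shape: &[i64]) -> Option<usize> {
    let mut product: usize = 1;
    for &d in shape {
        if d < 0 {
            return None; // wildcard or invalid dim: skip the exact-count check
        }
        product = product.saturating_mul(d as usize);
    }
    Some(product)
}
```

Callers would then compare the flattened data length only when `Some(n)` is returned.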
🧹 Nitpick comments (7)
tests/frontend/grpc/test_tensor_mocker_engine.py (1)
123-123: Post-infer config check is good; avoid hard exits in helper.

Calling get_config() here is sensible. Ensure the helper raises (AssertionError/RuntimeError) instead of sys.exit on connection issues, and consider a brief retry to deflake CI.
lib/llm/build.rs (1)
47-49: Derive serde for enums too for consistency.

`type_attribute` applies to messages only. If you ever serialize enums (e.g., DataType), add `enum_attribute` as well.
```diff
-    tonic_build::configure()
-        .type_attribute(".", "#[derive(serde::Serialize,serde::Deserialize)]")
-        .compile_protos(&["kserve.proto"], &["src/grpc/protos"])?;
+    tonic_build::configure()
+        .type_attribute(".", "#[derive(serde::Serialize,serde::Deserialize)]")
+        .enum_attribute(".", "#[derive(serde::Serialize,serde::Deserialize)]")
+        .compile_protos(&["kserve.proto"], &["src/grpc/protos"])?;
```

lib/llm/Cargo.toml (1)
119-119: Consolidate base64 to workspace.dependencies.

The change is good, but base64 is duplicated across workspace members: it appears at both `lib/async-openai/Cargo.toml:39` and `lib/llm/Cargo.toml:119` with identical versions. Promote base64 to `[workspace.dependencies]` in the root `Cargo.toml` and remove the duplicate declarations from both crates to maintain consistency.

lib/llm/src/protocols/tensor.rs (1)
134-142: Prefer derive(Default) over a manual impl.

No custom defaults here; derive reduces boilerplate and keeps defaults in one place.
Within this block, remove the manual Default:
```diff
-impl Default for TensorModelConfig {
-    fn default() -> Self {
-        Self {
-            name: "".to_string(),
-            inputs: vec![],
-            outputs: vec![],
-            triton_model_config: None,
-        }
-    }
-}
```

And update the struct derive (outside this range) to include Default:
```diff
-#[derive(Serialize, Deserialize, Validate, Debug, Clone, Eq, PartialEq)]
+#[derive(Serialize, Deserialize, Validate, Debug, Clone, Eq, PartialEq, Default)]
 pub struct TensorModelConfig {
```

lib/llm/src/grpc/service/kserve.rs (2)
530-536: Mirror the same fixes in model_config.

Keep decode zero-copy and use from_i32 consistently (here, the decoded config is returned directly).
```diff
-        if let Some(triton_model_config) = tensor_model_config.triton_model_config.as_ref() {
-            let bytes = general_purpose::STANDARD.decode(triton_model_config.clone()).map_err(|e| Status::invalid_argument(format!("Failed to decode base64 model config: {}", e)))?;
-            let model_config = ModelConfig::decode(&*bytes).map_err(|e| Status::invalid_argument(format!("Failed to deserialize model config: {}", e)))?;
+        if let Some(triton_model_config) = tensor_model_config.triton_model_config.as_ref() {
+            let bytes = general_purpose::STANDARD
+                .decode(triton_model_config.as_bytes())
+                .map_err(|e| Status::invalid_argument(format!("Failed to decode base64 model config: {}", e)))?;
+            let model_config = ModelConfig::decode(bytes.as_slice())
+                .map_err(|e| Status::invalid_argument(format!("Failed to deserialize model config: {}", e)))?;
             return Ok(Response::new(ModelConfigResponse {
                 config: Some(model_config),
             }));
         }
```
424-448: DRY: extract a helper to decode Triton config once.

Both branches duplicate base64+prost decode. Extract a small helper:
```rust
fn decode_triton_config_b64(b64: &str) -> Result<inference::ModelConfig, Status> {
    let bytes = general_purpose::STANDARD
        .decode(b64.as_bytes())
        .map_err(|e| Status::invalid_argument(format!("Failed to decode base64 model config: {}", e)))?;
    inference::ModelConfig::decode(bytes.as_slice())
        .map_err(|e| Status::invalid_argument(format!("Failed to deserialize model config: {}", e)))
}
```

Then:
```rust
let model_config = decode_triton_config_b64(triton_model_config)?;
```

Also applies to: 530-536
lib/llm/tests/kserve_service.rs (1)
1209-1301: Great positive-path test; add error-paths and metadata coverage.
- Add a test with an invalid base64 string to assert Code::InvalidArgument with a clear message.
- Add a test with valid base64 but invalid protobuf bytes to assert Code::InvalidArgument on deserialize.
- Add a companion test for model_metadata using the same Triton payload (verifying datatype strings and shapes).
I can sketch these tests if you want them in this PR.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (2)
- `Cargo.lock` is excluded by `!**/*.lock`
- `lib/bindings/python/Cargo.lock` is excluded by `!**/*.lock`
📒 Files selected for processing (8)
- `lib/llm/Cargo.toml` (1 hunks)
- `lib/llm/build.rs` (1 hunks)
- `lib/llm/src/grpc/service/kserve.rs` (3 hunks)
- `lib/llm/src/protocols/tensor.rs` (1 hunks)
- `lib/llm/tests/kserve_service.rs` (5 hunks)
- `tests/frontend/grpc/echo_tensor_worker.py` (3 hunks)
- `tests/frontend/grpc/test_tensor_mocker_engine.py` (1 hunks)
- `tests/frontend/grpc/triton_echo_client.py` (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
lib/llm/src/protocols/tensor.rs (1)
lib/llm/src/local_model/runtime_config.rs (1)
- `default` (46-57)
lib/llm/tests/kserve_service.rs (4)
lib/llm/src/model_card.rs (2)
- `name` (271-273)
- `with_name_only` (238-244)

lib/llm/src/protocols/tensor.rs (1)
- `default` (135-142)

lib/llm/src/discovery/model_manager.rs (3)
- `default` (64-66)
- `default` (478-484)
- `new` (70-81)

lib/llm/src/local_model/runtime_config.rs (2)
- `default` (46-57)
- `new` (61-63)
tests/frontend/grpc/echo_tensor_worker.py (1)
lib/llm/src/grpc/service/kserve.rs (1)
- `model_config` (517-616)
tests/frontend/grpc/test_tensor_mocker_engine.py (1)
tests/frontend/grpc/triton_echo_client.py (1)
- `get_config` (48-59)
🪛 Ruff (0.14.1)
tests/frontend/grpc/triton_echo_client.py
52-52: Do not catch blind exception: Exception
(BLE001)
🔇 Additional comments (2)
lib/llm/tests/kserve_service.rs (2)
44-46: Imports OK. `prost::Message` and base64 are used in the new Triton test; no issues.
364-365: Port allocation looks safe. New TritonModelConfig = 8997 doesn’t collide with existing cases.
Signed-off-by: Guan Luo <[email protected]>
Signed-off-by: Guan Luo <[email protected]>
@CodeRabbit full review
✅ Actions performed: Full review triggered.
Actionable comments posted: 1
♻️ Duplicate comments (1)
tests/frontend/grpc/triton_echo_client.py (1)
48-59: Tighten exception handling, close the client, and return the response.

Avoid sys.exit in tests; raise instead so pytest reports failures, and ensure the client is closed.
Apply:
```diff
+from tritonclient.utils import InferenceServerException
@@
-def get_config():
+def get_config():
     server_url = "localhost:8000"
     try:
-        triton_client = grpcclient.InferenceServerClient(url=server_url)
-    except Exception as e:
-        print("channel creation failed: " + str(e))
-        sys.exit()
-
-    model_name = "echo"
-    response = triton_client.get_model_config(model_name=model_name)
-    # Check one of the field that can only be set by providing Triton model config
-    assert response.config.model_transaction_policy.decoupled
+        triton_client = grpcclient.InferenceServerClient(url=server_url)
+    except InferenceServerException as e:
+        raise RuntimeError(f"channel creation failed: {e}") from e
+    try:
+        model_name = "echo"
+        response = triton_client.get_model_config(model_name=model_name)
+    finally:
+        try:
+            triton_client.close()
+        except Exception:
+            pass
+    assert response.config.model_transaction_policy.decoupled is True
+    return response
```

Note: You may similarly update run_infer() later for consistency.
🧹 Nitpick comments (7)
lib/llm/src/protocols/tensor.rs (2)
127-131: Doc mismatch: field is bytes, not string; optionally clarify encoding.

Update the comment to reflect the `Vec<u8>` payload and what it encodes.
Apply:
```diff
-    // Optional Triton model config in serialized protobuf string,
-    // if provided, it supersedes the basic model config defined above.
+    // Optional Triton model config as raw prost-serialized bytes of `inference::ModelConfig`;
+    // if provided, it supersedes the basic model config defined above.
```

Optional (if you want JSON as base64 instead of list-of-ints):
```diff
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub triton_model_config: Option<Vec<u8>>,
+    #[serde(
+        default,
+        skip_serializing_if = "Option::is_none",
+        with = "serde_bytes::option"
+    )]
+    pub triton_model_config: Option<Vec<u8>>,
```

This keeps the internal type as bytes but renders JSON as a base64 string. Based on learnings.
133-141: Derive Default instead of manual impl.

No custom init logic; deriving keeps it concise.
Apply:
```diff
-#[derive(Serialize, Deserialize, Validate, Debug, Clone, Eq, PartialEq)]
+#[derive(Serialize, Deserialize, Validate, Debug, Clone, Eq, PartialEq, Default)]
 pub struct TensorModelConfig {
@@
-impl Default for TensorModelConfig {
-    fn default() -> Self {
-        Self {
-            name: "".to_string(),
-            inputs: vec![],
-            outputs: vec![],
-            triton_model_config: None,
-        }
-    }
-}
+// Manual Default no longer needed; using #[derive(Default)] above.
```

tests/frontend/grpc/test_tensor_mocker_engine.py (1)
122-123: Use the returned config for explicit assertions (after get_config returns it).

Once get_config() returns the response (see suggested change in triton_echo_client.py), capture and assert here for clearer failures.
Example:
```diff
-    triton_echo_client.get_config()
+    cfg = triton_echo_client.get_config()
+    assert cfg.config.model_transaction_policy.decoupled is True
```

tests/frontend/grpc/echo_tensor_worker.py (2)
41-46: Provide a non-empty wrapper model name to keep fallback paths sane.

If Triton bytes are absent in future runs, an empty name can surface confusing metadata.
Apply:
```diff
-    model_config = {
-        "name": "",
+    model_config = {
+        "name": "echo",
         "inputs": [],
         "outputs": [],
         "triton_model_config": triton_model_config.SerializeToString(),
     }
```
50-55: Consider base64 JSON encoding to avoid manual bytes<->list[int] conversion.

You currently coerce the stored list[int] back to bytes for equality. If you switch the Rust field to serialize via serde_bytes (base64 string), this extra conversion goes away.
No change required now; see suggested serde_bytes attribute on TensorModelConfig.triton_model_config.
lib/llm/src/grpc/service/kserve.rs (2)
423-466: Deduplicate Triton decode and add mismatch warning.

Same decode logic appears here and in model_config(); extract a small helper and warn if the requested model name differs from the decoded config.name.
Apply:
```diff
@@
-                if let Some(triton_model_config) =
-                    tensor_model_config.triton_model_config.as_ref()
-                {
-                    let model_config = ModelConfig::decode(triton_model_config.as_slice())
-                        .map_err(|e| {
-                            Status::invalid_argument(format!(
-                                "Failed to deserialize model config: {}",
-                                e
-                            ))
-                        })?;
+                if let Some(triton_model_config) = tensor_model_config.triton_model_config.as_ref() {
+                    let model_config = decode_triton_cfg(triton_model_config)?;
+                    if model_config.name != *request_model_name {
+                        tracing::warn!(
+                            "Requested model '{}' but Triton config name is '{}'",
+                            request_model_name, model_config.name
+                        );
+                    }
                     return Ok(Response::new(ModelMetadataResponse {
```

Add once (outside the impl block):
```rust
fn decode_triton_cfg(bytes: &[u8]) -> Result<ModelConfig, Status> {
    ModelConfig::decode(bytes).map_err(|e| {
        Status::invalid_argument(format!("Failed to deserialize model config: {}", e))
    })
}
```
548-561: Use the shared decode helper here too.

Avoid duplicating the Prost decode and error mapping.
Apply:
```diff
-                if let Some(triton_model_config) =
-                    tensor_model_config.triton_model_config.as_ref()
-                {
-                    let model_config = ModelConfig::decode(triton_model_config.as_slice())
-                        .map_err(|e| {
-                            Status::invalid_argument(format!(
-                                "Failed to deserialize model config: {}",
-                                e
-                            ))
-                        })?;
+                if let Some(triton_model_config) = tensor_model_config.triton_model_config.as_ref() {
+                    let model_config = decode_triton_cfg(triton_model_config)?;
                     return Ok(Response::new(ModelConfigResponse {
                         config: Some(model_config),
                     }));
                 }
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- `lib/llm/build.rs` (1 hunks)
- `lib/llm/src/grpc/service/kserve.rs` (3 hunks)
- `lib/llm/src/protocols/tensor.rs` (1 hunks)
- `lib/llm/tests/kserve_service.rs` (5 hunks)
- `tests/frontend/grpc/echo_tensor_worker.py` (3 hunks)
- `tests/frontend/grpc/test_tensor_mocker_engine.py` (1 hunks)
- `tests/frontend/grpc/triton_echo_client.py` (1 hunks)
🧰 Additional context used
🪛 Ruff (0.14.1)
tests/frontend/grpc/triton_echo_client.py
52-52: Do not catch blind exception: Exception
(BLE001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (10)
- GitHub Check: Build and Test - dynamo
- GitHub Check: tests (launch/dynamo-run)
- GitHub Check: tests (.)
- GitHub Check: tests (lib/bindings/python)
- GitHub Check: sglang
- GitHub Check: vllm (amd64)
- GitHub Check: operator (amd64)
- GitHub Check: trtllm (arm64)
- GitHub Check: trtllm (amd64)
- GitHub Check: vllm (arm64)
🔇 Additional comments (1)
lib/llm/tests/kserve_service.rs (1)
1208-1299: Good E2E for Triton payload round‑trip.

Encoding the expected ModelConfig and asserting exact equality against the server response is strong coverage. Nice.
Optional follow‑up: add a companion metadata test to validate dtype/name/shape mapping derived from the same Triton payload.
Signed-off-by: GuanLuo <[email protected]>
Overview:
Details:
Where should the reviewer start?
Related Issues: (use one of the action keywords Closes / Fixes / Resolves / Relates to)
Summary by CodeRabbit
New Features
Tests