diff --git a/components/src/dynamo/vllm/handlers.py b/components/src/dynamo/vllm/handlers.py index 04955398aca3..e9d21cdd5166 100644 --- a/components/src/dynamo/vllm/handlers.py +++ b/components/src/dynamo/vllm/handlers.py @@ -1925,6 +1925,11 @@ def _build_prompt_from_request( prompt_token_ids=request["token_ids"], multi_modal_data=multi_modal_data, ) + nvext_args = extra_args.get("nvext") if isinstance(extra_args, dict) else None + if isinstance(nvext_args, dict): + cache_salt = nvext_args.get("cache_salt") + if cache_salt is not None: + prompt_kwargs["cache_salt"] = cache_salt if mm_uuids is not None: prompt_kwargs["multi_modal_uuids"] = mm_uuids if mm_processor_kwargs is not None: diff --git a/docs/components/frontend/nvext.md b/docs/components/frontend/nvext.md index 65144aad632c..800bbb528d5c 100644 --- a/docs/components/frontend/nvext.md +++ b/docs/components/frontend/nvext.md @@ -33,20 +33,23 @@ Include `nvext` as a top-level field alongside standard OpenAI-compatible fields | `use_raw_prompt` | `bool` | `None` | Preprocessor | Bypasses the prompt template and passes the prompt directly to the tokenizer. | | `annotations` | `string[]` | `None` | Preprocessor | Triggers out-of-band information in the SSE stream via the `event:` field. | | `backend_instance_id` | `u64` | `None` | Router | Routes the request to a specific backend instance. | -| `token_data` | `u32[]` | `None` | Preprocessor | Pre-tokenized prompt tokens. When provided with `backend_instance_id`, tokenization is skipped. | +| `token_data` | `u32[]` | `None` | Preprocessor | Pre-tokenized prompt tokens. When provided, tokenization is skipped. `backend_instance_id` remains an independent routing hint. | | `max_thinking_tokens` | `u32` | `None` | Backend | Maximum thinking tokens allowed (passed through to backends). | -| `extra_fields` | `string[]` | `None` | Response builder | Fields to include in the response `nvext`. Supported: `"worker_id"`, `"timing"`, `"routed_experts"`, `"engine_data"`, `"stop_reason"`. | +| `cache_salt` | `string` | `None` | Backend | Prefix-cache isolation hint for token-in clients. The top-level `cache_salt` request field is also accepted for renderer compatibility. | +| `extra_fields` | `string[]` | `None` | Response builder | Fields to include in the response `nvext`. Supported: `"worker_id"`, `"timing"`, `"routed_experts"`, `"engine_data"`, `"stop_reason"`, `"completion_token_ids"`. | | `prefill_worker_id` | `u64` | `None` | Router | Routes the request to a specific prefill worker (disaggregated serving). | | `decode_worker_id` | `u64` | `None` | Router | Routes the request to a specific decode worker (disaggregated serving). | | `agent_context` | object | `None` | Preprocessor | Passive session and trajectory identity for agent traces. See [Agent Context](#agent-context) below and [Agent Tracing](../../agents/agent-tracing.md). | | `agent_hints` | object | `None` | Router | Per-request hints for scheduling and load balancing. See [Agent Hints](#agent-hints). | | `session_control` | object | `None` | Router | Session lifecycle and sticky routing for subagent KV isolation. See [Session Control](#session-control). | -Related root-level Dynamo output option: +Related root-level Dynamo compatibility fields: | Field | Type | Default | Consumed By | Description | |-------|------|---------|-------------|-------------| | `return_tokens_as_token_ids` | `bool` | `false` | Response builder | Formats logprob token strings as `token_id:` instead of decoded text. | +| `cache_salt` | `string` | `None` | Backend | Compatibility alias for `nvext.cache_salt`; `nvext.cache_salt` takes precedence when both are present. | +| `stop_token_ids` | `u32[]` | `None` | Preprocessor | Compatibility alias for integer token stop IDs, equivalent to passing token IDs in the normal `stop` array. | `return_tokens_as_token_ids` only changes returned logprob token display. To stop on token IDs, pass integer IDs in the normal `stop` array, for example @@ -206,8 +209,9 @@ When the client requests response metadata via `extra_fields`, the response incl | `worker_id` | `extra_fields: ["worker_id"]` | Prefill/decode worker IDs and data parallel ranks that processed the request. | | `timing` | `extra_fields: ["timing"]` | Per-request timing information (TTFT, ITL, queue time, etc.). | | `routed_experts` | `extra_fields: ["routed_experts"]` | Routed expert capture payload returned by SGLang-backed requests. | -| `engine_data` | `extra_fields: ["engine_data"]` | Opaque backend-provided engine metadata. | +| `engine_data` | `extra_fields: ["engine_data"]` | Opaque backend-provided engine metadata. For chat token-in requests, Dynamo also includes generated `completion_token_ids` and, when available, `completion_logprobs` under this object for compatibility with rl-sdk token clients. | | `stop_reason` | `extra_fields: ["stop_reason"]` | Backend-specific matched stop condition, returned under `nvext` because it is not part of the OpenAI completions schema. Dynamo currently serves this as a response-level field for single-choice requests; supporting `n > 1` will require an indexed per-choice shape. | +| `completion_token_ids` | `extra_fields: ["completion_token_ids"]` | Generated token IDs accumulated across the chat-completions response and emitted on the final chunk. Supported only for single-choice requests (`n <= 1`). | | `token_ids` | Automatic (GAIE Stage 1) | Tokenized prompt for reuse in Stage 2 query-only mode. | ### Example response `nvext` diff --git a/lib/llm/src/preprocessor.rs b/lib/llm/src/preprocessor.rs index ea850db67c3b..5c1b6965edee 100644 --- a/lib/llm/src/preprocessor.rs +++ b/lib/llm/src/preprocessor.rs @@ -556,12 +556,36 @@ impl OpenAIPreprocessor { })); } + if let Some(extra_args) = Self::nvext_passthrough_extra_args(request) { + builder.extra_args(Some(extra_args)); + } + // Forward mm_processor_kwargs (e.g. use_audio_in_video) to the backend. builder.mm_processor_kwargs(request.mm_processor_kwargs().cloned()); Ok(builder) } + fn nvext_passthrough_extra_args(request: &R) -> Option { + let mut nvext_args = serde_json::Map::new(); + + if let Some(fields) = request.nvext_extra_fields() + && !fields.is_empty() + { + nvext_args.insert("extra_fields".to_string(), serde_json::json!(fields)); + } + + if let Some(cache_salt) = request.cache_salt() { + nvext_args.insert("cache_salt".to_string(), serde_json::json!(cache_salt)); + } + + if nvext_args.is_empty() { + None + } else { + Some(serde_json::json!({ "nvext": serde_json::Value::Object(nvext_args) })) + } + } + pub fn apply_template< R: OAIChatLikeRequest + AnnotationsProvider @@ -623,7 +647,7 @@ impl OpenAIPreprocessor { } } - pub async fn gather_multi_modal_data( + pub async fn gather_multi_modal_data( &self, request: &R, builder: &mut PreprocessedRequestBuilder, @@ -847,6 +871,11 @@ impl OpenAIPreprocessor { let mut extra_args = serde_json::json!({ "messages": messages_json }); + if let Some(nvext_passthrough) = Self::nvext_passthrough_extra_args(request) + && let Some(nvext) = nvext_passthrough.get("nvext") + { + extra_args["nvext"] = nvext.clone(); + } // Strip redundant inline data: URLs only when frontend decoding is active // (media_loader decoded the images into RDMA descriptors). TRT-LLM and diff --git a/lib/llm/src/protocols/openai/chat_completions.rs b/lib/llm/src/protocols/openai/chat_completions.rs index 291b837d38df..981554762ad4 100644 --- a/lib/llm/src/protocols/openai/chat_completions.rs +++ b/lib/llm/src/protocols/openai/chat_completions.rs @@ -101,6 +101,17 @@ impl NvExtProvider for NvCreateChatCompletionRequest { fn raw_prompt(&self) -> Option { None } + + fn cache_salt(&self) -> Option<&str> { + self.nvext + .as_ref() + .and_then(|nvext| nvext.cache_salt.as_deref()) + .or_else(|| { + self.unsupported_fields + .get("cache_salt") + .and_then(|value| value.as_str()) + }) + } } /// Implements `AnnotationsProvider` for `NvCreateChatCompletionRequest`, @@ -288,7 +299,14 @@ impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest { } fn get_stop_token_ids(&self) -> Option> { - self.inner.stop.as_ref().and_then(|stop| stop.token_ids()) + if let Some(ids) = self.inner.stop.as_ref().and_then(|stop| stop.token_ids()) { + return Some(ids); + } + self.unsupported_fields + .get("stop_token_ids") + .and_then(|value| { + serde_json::from_value::>(value.clone()).ok() + }) } /// Returns a reference to the optional `NvExt` extension, if available. @@ -353,6 +371,15 @@ impl ValidateRequest for NvCreateChatCompletionRequest { // validate::validate_max_tokens(self.inner.max_tokens)?; // warning depricated field validate::validate_max_completion_tokens(self.inner.max_completion_tokens)?; validate::validate_n(self.inner.n)?; + if self.inner.n.unwrap_or(1) > 1 + && self + .nvext + .as_ref() + .and_then(|nvext| nvext.extra_fields.as_ref()) + .is_some_and(|fields| fields.iter().any(|field| field == "completion_token_ids")) + { + anyhow::bail!("`nvext.extra_fields=[\"completion_token_ids\"]` requires `n <= 1`"); + } // none for modalities // none for prediction // none for audio @@ -504,14 +531,81 @@ mod tests { serde_json::from_value(scalar_token_id_stop); assert!(result.is_err()); - let unsupported_stop_token_ids = json!({ + let passthrough_stop_token_ids = json!({ "model": "test-model", "messages": [{"role": "user", "content": "Hello"}], "stop_token_ids": [576] }); let request: NvCreateChatCompletionRequest = - serde_json::from_value(unsupported_stop_token_ids) + serde_json::from_value(passthrough_stop_token_ids) .expect("Failed to deserialize request"); - assert!(ValidateRequest::validate(&request).is_err()); + ValidateRequest::validate(&request).expect("stop_token_ids should be accepted"); + assert_eq!(request.get_stop_token_ids(), Some(vec![576])); + + let stop_conditions = request + .extract_stop_conditions() + .expect("extract stop conditions"); + assert_eq!(stop_conditions.stop, None); + assert_eq!(stop_conditions.stop_token_ids, Some(vec![576])); + } + + #[test] + fn test_cache_salt_accepts_renderer_top_level_shape() { + let request_json = json!({ + "model": "test-model", + "messages": [{"role": "user", "content": "(token-in mode)"}], + "nvext": { + "token_data": [1, 2, 3], + "extra_fields": ["completion_token_ids"] + }, + "cache_salt": "ckpt-42" + }); + let request: NvCreateChatCompletionRequest = + serde_json::from_value(request_json).expect("Failed to deserialize request"); + + ValidateRequest::validate(&request).expect("cache_salt should be accepted"); + assert_eq!( + ::cache_salt(&request), + Some("ckpt-42") + ); + } + + #[test] + fn test_nvext_cache_salt_takes_precedence() { + let request_json = json!({ + "model": "test-model", + "messages": [{"role": "user", "content": "(token-in mode)"}], + "cache_salt": "top-level", + "nvext": { + "cache_salt": "nvext-level" + } + }); + let request: NvCreateChatCompletionRequest = + serde_json::from_value(request_json).expect("Failed to deserialize request"); + + assert_eq!( + ::cache_salt(&request), + Some("nvext-level") + ); + } + + #[test] + fn test_completion_token_ids_rejects_multiple_choices() { + let request_json = json!({ + "model": "test-model", + "messages": [{"role": "user", "content": "(token-in mode)"}], + "n": 2, + "nvext": { + "extra_fields": ["completion_token_ids"] + } + }); + let request: NvCreateChatCompletionRequest = + serde_json::from_value(request_json).expect("Failed to deserialize request"); + + let err = ValidateRequest::validate(&request).expect_err("n > 1 should be rejected"); + assert!( + err.to_string().contains("completion_token_ids"), + "unexpected error: {err}" + ); } } diff --git a/lib/llm/src/protocols/openai/chat_completions/delta.rs b/lib/llm/src/protocols/openai/chat_completions/delta.rs index 5d3cfc75cb50..10960719f6cb 100644 --- a/lib/llm/src/protocols/openai/chat_completions/delta.rs +++ b/lib/llm/src/protocols/openai/chat_completions/delta.rs @@ -10,7 +10,7 @@ use crate::{ openai::{ convert_backend_top_logprobs, delta_common::{self, DeltaGeneratorOptions}, - nvext::NvExtProvider, + nvext::{NvExtProvider, NvExtResponse}, token_to_utf8_bytes, }, }, @@ -59,6 +59,12 @@ pub struct DeltaGenerator { options: DeltaGeneratorOptions, /// Request tracker for per-request metrics (shared with PreprocessedRequest). tracker: Arc, + /// Accumulated output token IDs across chunks, emitted on the final chunk + /// when `nvext.extra_fields` includes `completion_token_ids` or `engine_data`. + accumulated_completion_token_ids: Vec, + /// Accumulated per-token logprobs across chunks, emitted under + /// `nvext.engine_data.completion_logprobs` on the final chunk. + accumulated_completion_logprobs: Vec, } impl DeltaGenerator { @@ -75,6 +81,8 @@ impl DeltaGenerator { msg_counter: 0, options, tracker, + accumulated_completion_token_ids: Vec::new(), + accumulated_completion_logprobs: Vec::new(), } } @@ -257,6 +265,20 @@ impl crate::protocols::openai::DeltaGeneratorExt map, + Some(value) => { + let mut map = serde_json::Map::new(); + map.insert("backend".to_string(), value); + map + } + None => serde_json::Map::new(), + }; + engine_data.insert( + "completion_token_ids".to_string(), + serde_json::json!(self.accumulated_completion_token_ids.clone()), ); + if !self.accumulated_completion_logprobs.is_empty() { + engine_data.insert( + "completion_logprobs".to_string(), + serde_json::json!(self.accumulated_completion_logprobs.clone()), + ); + } + nvext_response.engine_data = Some(serde_json::Value::Object(engine_data)); + } + + if let Ok(nvext_json) = serde_json::to_value(&nvext_response) { + stream_response.nvext = Some(nvext_json); + if let Some(ref info) = nvext_response.worker_id { + tracing::debug!( + "Injected worker_id into chat completion nvext: prefill={:?}, decode={:?}", + info.prefill_worker_id, + info.decode_worker_id + ); + } + if let Some(ref tokens) = nvext_response.token_ids { + tracing::debug!( + "Injected token_ids into chat completion nvext: {} tokens", + tokens.len() + ); + } + if let Some(ref tokens) = nvext_response.completion_token_ids { + tracing::debug!( + "Injected completion_token_ids into chat completion nvext: {} tokens", + tokens.len() + ); + } } } @@ -615,6 +693,48 @@ mod tests { assert!(nvext_json.get("routed_experts").is_none()); } + #[test] + fn test_completion_token_ids_extra_field_emits_accumulated_ids_on_final_chunk() { + let request = + create_test_request_with_extra_fields(vec!["completion_token_ids".to_string()]); + let mut generator = request.response_generator("req-completion-ids".to_string()); + + let mut first_output = final_backend_output(); + first_output.token_ids = vec![7]; + first_output.tokens = vec![Some("A".to_string())]; + first_output.text = Some("A".to_string()); + first_output.finish_reason = None; + first_output.disaggregated_params = None; + + let first_response = generator + .choice_from_postprocessor(first_output) + .expect("first choice generation"); + assert!( + first_response.nvext.is_none(), + "completion_token_ids should be emitted only on the final chunk" + ); + + let mut final_output = final_backend_output(); + final_output.token_ids = vec![8, 9]; + final_output.tokens = vec![Some("B".to_string()), Some("C".to_string())]; + final_output.text = Some("BC".to_string()); + final_output.disaggregated_params = None; + + let final_response = generator + .choice_from_postprocessor(final_output) + .expect("final choice generation"); + + let nvext_json = final_response + .nvext + .expect("nvext present for completion_token_ids request"); + assert_eq!( + nvext_json.get("completion_token_ids"), + Some(&serde_json::json!([7, 8, 9])) + ); + assert!(nvext_json.get("token_ids").is_none()); + assert!(nvext_json.get("routed_experts").is_none()); + } + #[test] fn test_routed_experts_extra_field_emits_routed_experts() { use crate::protocols::openai::nvext::NvExt; @@ -657,6 +777,7 @@ mod tests { .expect("engine_data should be present"); assert_eq!(engine_data["kv_transfer_time_ms"], 12.3); assert_eq!(engine_data["prefill_compute_time_ms"], 45.6); + assert_eq!(engine_data["completion_token_ids"], serde_json::json!([42])); } #[test] @@ -721,12 +842,60 @@ mod tests { .choice_from_postprocessor(backend_output) .expect("should produce a response"); - // engine_data is None from backend, so nvext.engine_data should be absent - if let Some(nvext) = &response.nvext { - assert!( - nvext.get("engine_data").is_none() || nvext.get("engine_data").unwrap().is_null(), - "engine_data should not appear when backend provides None" - ); - } + let nvext = response + .nvext + .expect("nvext present for engine_data request with generated tokens"); + let engine_data = nvext + .get("engine_data") + .expect("engine_data should include generated token IDs"); + assert_eq!(engine_data["completion_token_ids"], serde_json::json!([42])); + assert!(engine_data.get("completion_logprobs").is_none()); + } + + #[test] + fn test_engine_data_accumulates_completion_token_ids_and_logprobs() { + let request = create_test_request_with_extra_fields(vec!["engine_data".to_string()]); + let mut generator = request.response_generator("req-engine-5".to_string()); + + let mut first_output = final_backend_output(); + first_output.token_ids = vec![7]; + first_output.tokens = vec![Some("A".to_string())]; + first_output.text = Some("A".to_string()); + first_output.log_probs = Some(vec![-0.1]); + first_output.finish_reason = None; + first_output.disaggregated_params = None; + + let first_response = generator + .choice_from_postprocessor(first_output) + .expect("first choice generation"); + assert!( + first_response.nvext.is_none(), + "engine_data token IDs should be emitted only on the final chunk" + ); + + let mut final_output = final_backend_output(); + final_output.token_ids = vec![8, 9]; + final_output.tokens = vec![Some("B".to_string()), Some("C".to_string())]; + final_output.text = Some("BC".to_string()); + final_output.log_probs = Some(vec![-0.2, -0.3]); + final_output.disaggregated_params = None; + + let final_response = generator + .choice_from_postprocessor(final_output) + .expect("final choice generation"); + let nvext = final_response + .nvext + .expect("nvext present for engine_data request"); + let engine_data = nvext + .get("engine_data") + .expect("engine_data should include generated token metadata"); + assert_eq!( + engine_data["completion_token_ids"], + serde_json::json!([7, 8, 9]) + ); + assert_eq!( + engine_data["completion_logprobs"], + serde_json::json!([-0.1, -0.2, -0.3]) + ); } } diff --git a/lib/llm/src/protocols/openai/completions.rs b/lib/llm/src/protocols/openai/completions.rs index e60890214bb9..aa69138113da 100644 --- a/lib/llm/src/protocols/openai/completions.rs +++ b/lib/llm/src/protocols/openai/completions.rs @@ -252,7 +252,14 @@ impl OpenAIStopConditionsProvider for NvCreateCompletionRequest { } fn get_stop_token_ids(&self) -> Option> { - self.inner.stop.as_ref().and_then(|stop| stop.token_ids()) + if let Some(ids) = self.inner.stop.as_ref().and_then(|stop| stop.token_ids()) { + return Some(ids); + } + self.unsupported_fields + .get("stop_token_ids") + .and_then(|value| { + serde_json::from_value::>(value.clone()).ok() + }) } fn nvext(&self) -> Option<&NvExt> { @@ -733,13 +740,20 @@ mod tests { assert_eq!(request.get_stop(), Some(vec!["token_id:576".to_string()])); assert_eq!(request.get_stop_token_ids(), None); - let unsupported_stop_token_ids = json!({ + let passthrough_stop_token_ids = json!({ "model": "test-model", "prompt": [1, 2, 3], "stop_token_ids": [576] }); - let request: NvCreateCompletionRequest = serde_json::from_value(unsupported_stop_token_ids) + let request: NvCreateCompletionRequest = serde_json::from_value(passthrough_stop_token_ids) .expect("Failed to deserialize request"); - assert!(ValidateRequest::validate(&request).is_err()); + ValidateRequest::validate(&request).expect("stop_token_ids should be accepted"); + assert_eq!(request.get_stop_token_ids(), Some(vec![576])); + + let stop_conditions = request + .extract_stop_conditions() + .expect("extract stop conditions"); + assert_eq!(stop_conditions.stop, None); + assert_eq!(stop_conditions.stop_token_ids, Some(vec![576])); } } diff --git a/lib/llm/src/protocols/openai/nvext.rs b/lib/llm/src/protocols/openai/nvext.rs index 7c561dd4a59d..aeabee15fa1c 100644 --- a/lib/llm/src/protocols/openai/nvext.rs +++ b/lib/llm/src/protocols/openai/nvext.rs @@ -74,6 +74,14 @@ pub fn apply_header_routing_overrides(nvext: Option, headers: &HeaderMap) pub trait NvExtProvider { fn nvext(&self) -> Option<&NvExt>; fn raw_prompt(&self) -> Option; + + fn nvext_extra_fields(&self) -> Option<&[String]> { + self.nvext().and_then(|nvext| nvext.extra_fields.as_deref()) + } + + fn cache_salt(&self) -> Option<&str> { + self.nvext().and_then(|nvext| nvext.cache_salt.as_deref()) + } } /// Worker ID information for disaggregated serving @@ -129,6 +137,11 @@ pub struct NvExtResponse { /// If `n > 1` is supported here, this needs an indexed/per-choice shape. #[serde(skip_serializing_if = "Option::is_none")] pub stop_reason: Option, + + /// Output token IDs generated by the engine. + /// Populated when the client requests `extra_fields: ["completion_token_ids"]`. + #[serde(skip_serializing_if = "Option::is_none")] + pub completion_token_ids: Option>, } pub(crate) fn merge_response_nvext( @@ -165,6 +178,7 @@ pub struct NvExtResponseFieldSelection { pub routed_experts: bool, pub engine_data: bool, pub stop_reason: bool, + pub completion_token_ids: bool, } impl NvExtResponseFieldSelection { @@ -182,6 +196,7 @@ impl NvExtResponseFieldSelection { "routed_experts" => selection.routed_experts = true, "engine_data" => selection.engine_data = true, "stop_reason" => selection.stop_reason = true, + "completion_token_ids" => selection.completion_token_ids = true, _ => {} } } @@ -215,6 +230,8 @@ impl NvExtResponseFieldSelection { /// - `timing` requires the selection flag, `finish_reason_present == true`, **and** a tracker. /// - `engine_data` requires the selection flag **and** a non-`None` `engine_data_from_backend`. /// - `stop_reason` requires the selection flag **and** a non-`None` `stop_reason_from_backend`. + /// - `completion_token_ids` is accumulated by the chat-completions delta generator + /// and attached to the final chunk after this helper returns. pub fn build_response_nvext( &self, tracker: Option<&std::sync::Arc>, @@ -280,6 +297,7 @@ impl NvExtResponseFieldSelection { routed_experts, engine_data, stop_reason, + completion_token_ids: None, }) } } @@ -328,10 +346,18 @@ pub struct NvExt { #[builder(default, setter(strip_option))] pub max_thinking_tokens: Option, + /// KV prefix-cache isolation hint from token-in clients. + /// + /// A changed salt lets backends isolate prompt cache entries for identical + /// token sequences generated under different checkpoint or rollout state. + #[serde(default, skip_serializing_if = "Option::is_none")] + #[builder(default, setter(strip_option))] + pub cache_salt: Option, + /// Extra fields to be included in the response's nvext /// This is a list of field names that should be populated in the response /// Supported fields include "worker_id", "timing", "routed_experts", "engine_data", - /// "stop_reason", which map to fields in NvExtResponse. + /// "stop_reason", and "completion_token_ids", which map to fields in NvExtResponse. #[serde(default, skip_serializing_if = "Option::is_none")] #[builder(default, setter(strip_option))] pub extra_fields: Option>, @@ -507,6 +533,7 @@ mod tests { assert_eq!(nv_ext.backend_instance_id, None); assert_eq!(nv_ext.token_data, None); assert_eq!(nv_ext.max_thinking_tokens, None); + assert_eq!(nv_ext.cache_salt, None); assert_eq!(nv_ext.extra_fields, None); assert_eq!(nv_ext.prefill_worker_id, None); assert_eq!(nv_ext.decode_worker_id, None); @@ -525,6 +552,7 @@ mod tests { .backend_instance_id(42) .token_data(vec![1, 2, 3, 4]) .max_thinking_tokens(1024) + .cache_salt("ckpt-42".to_string()) .extra_fields(vec!["worker_id".to_string()]) .build() .unwrap(); @@ -534,6 +562,7 @@ mod tests { assert_eq!(nv_ext.backend_instance_id, Some(42)); assert_eq!(nv_ext.token_data, Some(vec![1, 2, 3, 4])); assert_eq!(nv_ext.max_thinking_tokens, Some(1024)); + assert_eq!(nv_ext.cache_salt, Some("ckpt-42".to_string())); assert_eq!(nv_ext.extra_fields, Some(vec!["worker_id".to_string()])); // Validate the built struct assert!(nv_ext.validate().is_ok()); @@ -764,6 +793,22 @@ mod tests { ); } + #[test] + fn test_nvext_response_field_selection_completion_token_ids_only() { + let nvext = NvExt::builder() + .extra_fields(vec!["completion_token_ids".to_string()]) + .build() + .unwrap(); + + assert_eq!( + NvExtResponseFieldSelection::from_nvext(Some(&nvext)), + NvExtResponseFieldSelection { + completion_token_ids: true, + ..Default::default() + } + ); + } + // Helpers for build_response_nvext tests ----------------------------- fn sel_all_false() -> NvExtResponseFieldSelection { @@ -966,6 +1011,7 @@ mod tests { routed_experts: true, engine_data: false, stop_reason: false, + completion_token_ids: false, }; let tracker = tracker_with_prefill_worker(); let params = disagg_params_full(); @@ -1003,6 +1049,7 @@ mod tests { routed_experts: true, engine_data: false, stop_reason: false, + completion_token_ids: false, } ); } diff --git a/lib/llm/src/protocols/openai/validate.rs b/lib/llm/src/protocols/openai/validate.rs index 559dd109ac1b..5d5dc9263929 100644 --- a/lib/llm/src/protocols/openai/validate.rs +++ b/lib/llm/src/protocols/openai/validate.rs @@ -97,15 +97,21 @@ pub const MAX_REPETITION_PENALTY: f32 = 2.0; // Shared Fields // +/// Root-level fields accepted for compatibility with token-in clients even +/// though Dynamo has not modeled them as first-class OpenAI fields. +pub const PASSTHROUGH_EXTRA_FIELDS: &[&str] = &["cache_salt", "stop_token_ids"]; + /// Validates that no unsupported fields are present in the request pub fn validate_no_unsupported_fields( unsupported_fields: &std::collections::HashMap, ) -> Result<(), anyhow::Error> { - if !unsupported_fields.is_empty() { - let fields: Vec<_> = unsupported_fields - .keys() - .map(|s| format!("`{}`", s)) - .collect(); + let fields: Vec<_> = unsupported_fields + .keys() + .filter(|field| !PASSTHROUGH_EXTRA_FIELDS.contains(&field.as_str())) + .map(|s| format!("`{}`", s)) + .collect(); + + if !fields.is_empty() { anyhow::bail!("Unsupported parameter(s): {}", fields.join(", ")); } Ok(())