From e83ea4c86cb4f1a46940648e17aa599598cf4227 Mon Sep 17 00:00:00 2001 From: Nikhil Kulkarni Date: Tue, 9 Dec 2025 09:00:07 +0530 Subject: [PATCH 1/5] Add support for LLM token usage data streaming --- src/nvidia_rag/rag_server/main.py | 31 ++++++++++--- .../rag_server/response_generator.py | 46 +++++++++++++++++++ src/nvidia_rag/utils/llm.py | 25 ++++++++++ 3 files changed, 96 insertions(+), 6 deletions(-) diff --git a/src/nvidia_rag/rag_server/main.py b/src/nvidia_rag/rag_server/main.py index 206caf6d..b4e38b78 100644 --- a/src/nvidia_rag/rag_server/main.py +++ b/src/nvidia_rag/rag_server/main.py @@ -48,7 +48,7 @@ from langchain_core.output_parsers.string import StrOutputParser from langchain_core.prompts import MessagesPlaceholder from langchain_core.prompts.chat import ChatPromptTemplate -from langchain_core.runnables import RunnableAssign +from langchain_core.runnables import RunnableAssign, RunnableGenerator from opentelemetry import context as otel_context from requests import ConnectTimeout @@ -90,9 +90,11 @@ ) from nvidia_rag.utils.health_models import RAGHealthResponse from nvidia_rag.utils.llm import ( + USAGE_SENTINEL_PREFIX, get_llm, get_prompts, get_streaming_filter_think_parser_async, + stream_with_usage_sentinel, ) from nvidia_rag.utils.observability.otel_metrics import OtelMetrics from nvidia_rag.utils.reranker import get_ranking_model @@ -248,6 +250,9 @@ def __init__( self.prompts = get_prompts(prompts) self.vdb_top_k = int(self.config.retriever.vdb_top_k) self.StreamingFilterThinkParser = get_streaming_filter_think_parser_async() + # Runnable that injects a final sentinel chunk carrying usage metadata + # as a special string; used only for streaming chains. + self.UsageSentinelParser = RunnableGenerator(stream_with_usage_sentinel) if self._init_errors: logger.warning( @@ -1408,14 +1413,17 @@ async def _llm_chain( prompt_template = ChatPromptTemplate.from_messages(message) llm = get_llm(config=self.config, **llm_settings) - chain = ( + # Chain for streaming: add usage-sentinel parser between LLM and + # think-token filter so we can surface token usage in the final chunk. + stream_chain = ( prompt_template | llm + | self.UsageSentinelParser | self.StreamingFilterThinkParser | StrOutputParser() ) # Create async stream generator - stream_gen = chain.astream( + stream_gen = stream_chain.astream( {"question": query_text}, config={"run_name": "llm-stream"} ) # Eagerly fetch first chunk to trigger any errors before returning response @@ -2437,7 +2445,18 @@ def generate_filter_for_collection(collection_name): self._print_conversation_history(message) prompt = ChatPromptTemplate.from_messages(message) - chain = prompt | llm | self.StreamingFilterThinkParser | StrOutputParser() + # Base chain (no usage sentinel) used for non-streaming reflection path. + base_chain = prompt | llm | self.StreamingFilterThinkParser | StrOutputParser() + + # Streaming chain adds usage-sentinel parser between LLM and think-token + # filter so we can surface token usage in the final streamed chunk. 
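+        # The sentinel is emitted by stream_with_usage_sentinel as a final
+        # AIMessageChunk whose content is USAGE_SENTINEL_PREFIX followed by the
+        # JSON-encoded usage_metadata; generate_answer/generate_answer_async in
+        # response_generator.py strip the prefix, parse the JSON and attach a
+        # Usage object to the final ChainResponse.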
+ stream_chain = ( + prompt + | llm + | self.UsageSentinelParser + | self.StreamingFilterThinkParser + | StrOutputParser() + ) # Check response groundedness if we still have reflection # iterations available @@ -2496,8 +2515,8 @@ def generate_filter_for_collection(collection_name): status_code=ErrorCodeMapping.SUCCESS, ) else: - # Create async stream generator - stream_gen = chain.astream( + # Create async stream generator using the streaming chain + stream_gen = stream_chain.astream( {"question": query, "context": docs}, config={"run_name": "llm-stream"}, ) diff --git a/src/nvidia_rag/rag_server/response_generator.py b/src/nvidia_rag/rag_server/response_generator.py index 971f9a39..f1f9eb77 100644 --- a/src/nvidia_rag/rag_server/response_generator.py +++ b/src/nvidia_rag/rag_server/response_generator.py @@ -38,6 +38,7 @@ from pydantic import BaseModel, Field, validator from pymilvus.exceptions import MilvusException, MilvusUnavailableException +from nvidia_rag.utils.llm import USAGE_SENTINEL_PREFIX from nvidia_rag.utils.minio_operator import ( get_minio_operator, get_unique_thumbnail_id, @@ -443,7 +444,28 @@ def generate_answer( llm_ttft_ms: float | None = None rag_ttft_ms: float | None = None llm_generation_time_ms: float | None = None + usage: Usage | None = None for chunk in generator: + # Sentinel from LangChain runnable carrying token-usage JSON. + if isinstance(chunk, str) and chunk.startswith(USAGE_SENTINEL_PREFIX): + usage_json = chunk[len(USAGE_SENTINEL_PREFIX) :] + try: + usage_dict = json.loads(usage_json) or {} + prompt_tokens = int(usage_dict.get("input_tokens", 0)) + completion_tokens = int(usage_dict.get("output_tokens", 0)) + total_tokens = int( + usage_dict.get( + "total_tokens", prompt_tokens + completion_tokens + ) + ) + usage = Usage( + total_tokens=total_tokens, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + except Exception as e: + logger.debug("Failed to parse usage sentinel: %s", e) + continue # TODO: This is a hack to clear contexts if we get an error # response from nemoguardrails if chunk == "I'm sorry, I can't respond to that.": @@ -519,6 +541,8 @@ def generate_answer( # Create response first, then attach metrics for clarity chain_response = ChainResponse() chain_response.metrics = final_metrics + if usage is not None: + chain_response.usage = usage # [DONE] indicate end of response from server response_choice = ChainResponseChoices( @@ -591,7 +615,27 @@ async def generate_answer_async( llm_ttft_ms: float | None = None rag_ttft_ms: float | None = None llm_generation_time_ms: float | None = None + usage: Usage | None = None async for chunk in generator: + if isinstance(chunk, str) and chunk.startswith(USAGE_SENTINEL_PREFIX): + usage_json = chunk[len(USAGE_SENTINEL_PREFIX) :] + try: + usage_dict = json.loads(usage_json) or {} + prompt_tokens = int(usage_dict.get("input_tokens", 0)) + completion_tokens = int(usage_dict.get("output_tokens", 0)) + total_tokens = int( + usage_dict.get( + "total_tokens", prompt_tokens + completion_tokens + ) + ) + usage = Usage( + total_tokens=total_tokens, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + except Exception as e: + logger.debug("Failed to parse usage sentinel: %s", e) + continue # TODO: This is a hack to clear contexts if we get an error # response from nemoguardrails if chunk == "I'm sorry, I can't respond to that.": @@ -667,6 +711,8 @@ async def generate_answer_async( # Create response first, then attach metrics for clarity chain_response = ChainResponse() 
chain_response.metrics = final_metrics + if usage is not None: + chain_response.usage = usage # [DONE] indicate end of response from server response_choice = ChainResponseChoices( diff --git a/src/nvidia_rag/utils/llm.py b/src/nvidia_rag/utils/llm.py index cd2d3f08..29a69fdd 100644 --- a/src/nvidia_rag/utils/llm.py +++ b/src/nvidia_rag/utils/llm.py @@ -22,6 +22,7 @@ 6. get_streaming_filter_think_parser_async: Get the parser for filtering the think tokens (async). """ +import json import logging import os from collections.abc import Iterable @@ -32,6 +33,7 @@ import yaml from langchain.llms.base import LLM from langchain_core.language_models.chat_models import SimpleChatModel +from langchain_core.messages import AIMessageChunk from langchain_nvidia_ai_endpoints import ChatNVIDIA from nvidia_rag.rag_server.response_generator import APIError, ErrorCodeMapping @@ -603,3 +605,26 @@ def get_streaming_filter_think_parser_async(): logger.info("Think token filtering is disabled (async)") # If filtering is disabled, use a passthrough that passes content as-is return RunnablePassthrough() + + +USAGE_SENTINEL_PREFIX = "__RAG_USAGE_SENTINEL__:" + + +async def stream_with_usage_sentinel(chunks): + """ + Pass through model chunks and, at the end, emit a synthetic chunk whose + content encodes token-usage metadata. + """ + last_usage = None + + async for chunk in chunks: + if hasattr(chunk, "usage_metadata") and getattr(chunk, "usage_metadata", None): + last_usage = getattr(chunk, "usage_metadata", None) + yield chunk + + if last_usage is not None: + try: + payload = json.dumps(last_usage) + yield AIMessageChunk(content=f"{USAGE_SENTINEL_PREFIX}{payload}") + except Exception as e: + logger.debug("Failed to emit usage sentinel chunk: %s", e) From 8b9e3b44c4912648b650c66e1820c8cf15e72291 Mon Sep 17 00:00:00 2001 From: Nikhil Kulkarni Date: Tue, 9 Dec 2025 20:24:51 +0530 Subject: [PATCH 2/5] Log the LLM token usage --- src/nvidia_rag/rag_server/response_generator.py | 16 ++++++++++++++++ src/nvidia_rag/utils/llm.py | 3 +++ 2 files changed, 19 insertions(+) diff --git a/src/nvidia_rag/rag_server/response_generator.py b/src/nvidia_rag/rag_server/response_generator.py index f1f9eb77..0433d58a 100644 --- a/src/nvidia_rag/rag_server/response_generator.py +++ b/src/nvidia_rag/rag_server/response_generator.py @@ -463,6 +463,14 @@ def generate_answer( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, ) + logger.info( + "LLM usage for model %s (sync): prompt_tokens=%d, completion_tokens=%d, total_tokens=%d, raw=%s", + model, + prompt_tokens, + completion_tokens, + total_tokens, + usage_dict, + ) except Exception as e: logger.debug("Failed to parse usage sentinel: %s", e) continue @@ -633,6 +641,14 @@ async def generate_answer_async( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, ) + logger.info( + "LLM usage for model %s (async): prompt_tokens=%d, completion_tokens=%d, total_tokens=%d, raw=%s", + model, + prompt_tokens, + completion_tokens, + total_tokens, + usage_dict, + ) except Exception as e: logger.debug("Failed to parse usage sentinel: %s", e) continue diff --git a/src/nvidia_rag/utils/llm.py b/src/nvidia_rag/utils/llm.py index 29a69fdd..14a4ddc4 100644 --- a/src/nvidia_rag/utils/llm.py +++ b/src/nvidia_rag/utils/llm.py @@ -618,13 +618,16 @@ async def stream_with_usage_sentinel(chunks): last_usage = None async for chunk in chunks: + logger.info("Chunk: %s", chunk.usage_metadata) if hasattr(chunk, "usage_metadata") and getattr(chunk, "usage_metadata", None): last_usage = 
getattr(chunk, "usage_metadata", None) + logger.info("Usage metadata: %s", last_usage) yield chunk if last_usage is not None: try: payload = json.dumps(last_usage) + logger.info("Usage sentinel chunk: %s", payload) yield AIMessageChunk(content=f"{USAGE_SENTINEL_PREFIX}{payload}") except Exception as e: logger.debug("Failed to emit usage sentinel chunk: %s", e) From d66d0de7b524241a7cb0294500bb76c6c1a70a14 Mon Sep 17 00:00:00 2001 From: Nikhil Kulkarni Date: Mon, 5 Jan 2026 15:13:11 +0530 Subject: [PATCH 3/5] Update to LLM API tokens in ChatNVIDIA --- src/nvidia_rag/rag_server/main.py | 17 +--- .../rag_server/response_generator.py | 48 ++++++---- src/nvidia_rag/utils/llm.py | 96 +++++++++++-------- 3 files changed, 92 insertions(+), 69 deletions(-) diff --git a/src/nvidia_rag/rag_server/main.py b/src/nvidia_rag/rag_server/main.py index b4e38b78..80453386 100644 --- a/src/nvidia_rag/rag_server/main.py +++ b/src/nvidia_rag/rag_server/main.py @@ -90,11 +90,9 @@ ) from nvidia_rag.utils.health_models import RAGHealthResponse from nvidia_rag.utils.llm import ( - USAGE_SENTINEL_PREFIX, get_llm, get_prompts, get_streaming_filter_think_parser_async, - stream_with_usage_sentinel, ) from nvidia_rag.utils.observability.otel_metrics import OtelMetrics from nvidia_rag.utils.reranker import get_ranking_model @@ -250,9 +248,6 @@ def __init__( self.prompts = get_prompts(prompts) self.vdb_top_k = int(self.config.retriever.vdb_top_k) self.StreamingFilterThinkParser = get_streaming_filter_think_parser_async() - # Runnable that injects a final sentinel chunk carrying usage metadata - # as a special string; used only for streaming chains. - self.UsageSentinelParser = RunnableGenerator(stream_with_usage_sentinel) if self._init_errors: logger.warning( @@ -1413,14 +1408,12 @@ async def _llm_chain( prompt_template = ChatPromptTemplate.from_messages(message) llm = get_llm(config=self.config, **llm_settings) - # Chain for streaming: add usage-sentinel parser between LLM and - # think-token filter so we can surface token usage in the final chunk. + # Chain for streaming: we remove StrOutputParser so we yield AIMessageChunks, + # allowing us to access .usage_metadata in the response generator. stream_chain = ( prompt_template | llm - | self.UsageSentinelParser | self.StreamingFilterThinkParser - | StrOutputParser() ) # Create async stream generator stream_gen = stream_chain.astream( @@ -2446,16 +2439,14 @@ def generate_filter_for_collection(collection_name): prompt = ChatPromptTemplate.from_messages(message) # Base chain (no usage sentinel) used for non-streaming reflection path. + # We keep StrOutputParser here because we want the full string response for logic. base_chain = prompt | llm | self.StreamingFilterThinkParser | StrOutputParser() - # Streaming chain adds usage-sentinel parser between LLM and think-token - # filter so we can surface token usage in the final streamed chunk. + # Streaming chain: yields AIMessageChunks to preserve usage metadata. 
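+        # StrOutputParser is intentionally dropped here: it would collapse each
+        # AIMessageChunk to a plain string and discard usage_metadata. The
+        # response generator now reads chunk.content itself and skips empty,
+        # usage-only chunks.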
stream_chain = ( prompt | llm - | self.UsageSentinelParser | self.StreamingFilterThinkParser - | StrOutputParser() ) # Check response groundedness if we still have reflection diff --git a/src/nvidia_rag/rag_server/response_generator.py b/src/nvidia_rag/rag_server/response_generator.py index 0433d58a..a7a8d254 100644 --- a/src/nvidia_rag/rag_server/response_generator.py +++ b/src/nvidia_rag/rag_server/response_generator.py @@ -38,7 +38,6 @@ from pydantic import BaseModel, Field, validator from pymilvus.exceptions import MilvusException, MilvusUnavailableException -from nvidia_rag.utils.llm import USAGE_SENTINEL_PREFIX from nvidia_rag.utils.minio_operator import ( get_minio_operator, get_unique_thumbnail_id, @@ -446,11 +445,10 @@ def generate_answer( llm_generation_time_ms: float | None = None usage: Usage | None = None for chunk in generator: - # Sentinel from LangChain runnable carrying token-usage JSON. - if isinstance(chunk, str) and chunk.startswith(USAGE_SENTINEL_PREFIX): - usage_json = chunk[len(USAGE_SENTINEL_PREFIX) :] + # Handle usage metadata if present (AIMessageChunk) + if hasattr(chunk, "usage_metadata") and chunk.usage_metadata: + usage_dict = chunk.usage_metadata try: - usage_dict = json.loads(usage_json) or {} prompt_tokens = int(usage_dict.get("input_tokens", 0)) completion_tokens = int(usage_dict.get("output_tokens", 0)) total_tokens = int( @@ -472,18 +470,27 @@ def generate_answer( usage_dict, ) except Exception as e: - logger.debug("Failed to parse usage sentinel: %s", e) + logger.debug("Failed to parse usage metadata: %s", e) + + # Extract content + content = chunk + if hasattr(chunk, "content"): + content = chunk.content + + # Skip empty content (e.g. usage-only chunks) + if not content: continue + # TODO: This is a hack to clear contexts if we get an error # response from nemoguardrails - if chunk == "I'm sorry, I can't respond to that.": + if content == "I'm sorry, I can't respond to that.": # Clear contexts if we get an error response contexts = [] chain_response = ChainResponse() response_choice = ChainResponseChoices( index=0, - message=Message(role="assistant", content=chunk), - delta=Message(role=None, content=chunk), + message=Message(role="assistant", content=content), + delta=Message(role=None, content=content), finish_reason=None, ) chain_response.id = resp_id @@ -625,10 +632,10 @@ async def generate_answer_async( llm_generation_time_ms: float | None = None usage: Usage | None = None async for chunk in generator: - if isinstance(chunk, str) and chunk.startswith(USAGE_SENTINEL_PREFIX): - usage_json = chunk[len(USAGE_SENTINEL_PREFIX) :] + # Handle usage metadata if present (AIMessageChunk) + if hasattr(chunk, "usage_metadata") and chunk.usage_metadata: + usage_dict = chunk.usage_metadata try: - usage_dict = json.loads(usage_json) or {} prompt_tokens = int(usage_dict.get("input_tokens", 0)) completion_tokens = int(usage_dict.get("output_tokens", 0)) total_tokens = int( @@ -650,18 +657,27 @@ async def generate_answer_async( usage_dict, ) except Exception as e: - logger.debug("Failed to parse usage sentinel: %s", e) + logger.debug("Failed to parse usage metadata: %s", e) + + # Extract content + content = chunk + if hasattr(chunk, "content"): + content = chunk.content + + # Skip empty content (e.g. 
usage-only chunks) + if not content: continue + # TODO: This is a hack to clear contexts if we get an error # response from nemoguardrails - if chunk == "I'm sorry, I can't respond to that.": + if content == "I'm sorry, I can't respond to that.": # Clear contexts if we get an error response contexts = [] chain_response = ChainResponse() response_choice = ChainResponseChoices( index=0, - message=Message(role="assistant", content=chunk), - delta=Message(role=None, content=chunk), + message=Message(role="assistant", content=content), + delta=Message(role=None, content=content), finish_reason=None, ) chain_response.id = resp_id diff --git a/src/nvidia_rag/utils/llm.py b/src/nvidia_rag/utils/llm.py index 14a4ddc4..cfcb6d02 100644 --- a/src/nvidia_rag/utils/llm.py +++ b/src/nvidia_rag/utils/llm.py @@ -240,6 +240,8 @@ def get_llm(config: NvidiaRAGConfig | None = None, **kwargs) -> LLM | SimpleChat chat_nvidia_kwargs["top_p"] = kwargs["top_p"] if kwargs.get("max_tokens") is not None: chat_nvidia_kwargs["max_tokens"] = kwargs["max_tokens"] + # Also set max_completion_tokens as max_tokens is deprecated in newer libs + chat_nvidia_kwargs["max_completion_tokens"] = kwargs["max_tokens"] # Only include NVIDIA-specific parameters for NVIDIA endpoints if is_nvidia: if kwargs.get("min_tokens") is not None: @@ -256,16 +258,25 @@ def get_llm(config: NvidiaRAGConfig | None = None, **kwargs) -> LLM | SimpleChat logger.info("Using llm model %s from api catalog", kwargs.get("model")) api_key = kwargs.get("api_key") or config.llm.get_api_key() - llm = ChatNVIDIA( - model=kwargs.get("model"), - api_key=api_key, - temperature=kwargs.get("temperature", None), - top_p=kwargs.get("top_p", None), - max_tokens=kwargs.get("max_tokens", None), - min_tokens=kwargs.get("min_tokens", None), - ignore_eos=kwargs.get("ignore_eos", False), - stop=kwargs.get("stop", []), - ) + + chat_nvidia_kwargs = { + "model": kwargs.get("model"), + "api_key": api_key, + "stop": kwargs.get("stop", []), + } + if kwargs.get("temperature") is not None: + chat_nvidia_kwargs["temperature"] = kwargs["temperature"] + if kwargs.get("top_p") is not None: + chat_nvidia_kwargs["top_p"] = kwargs["top_p"] + if kwargs.get("max_tokens") is not None: + chat_nvidia_kwargs["max_tokens"] = kwargs["max_tokens"] + chat_nvidia_kwargs["max_completion_tokens"] = kwargs["max_tokens"] + if kwargs.get("min_tokens") is not None: + chat_nvidia_kwargs["min_tokens"] = kwargs["min_tokens"] + if kwargs.get("ignore_eos") is not None: + chat_nvidia_kwargs["ignore_eos"] = kwargs["ignore_eos"] + + llm = ChatNVIDIA(**chat_nvidia_kwargs) llm = _bind_thinking_tokens_if_configured(llm, **kwargs) return llm @@ -450,7 +461,7 @@ async def streaming_filter_think_async(chunks): chunks: Async iterable of chunks from a streaming LLM response Yields: - str: Filtered content with think blocks removed + AIMessageChunk: Filtered content with think blocks removed """ # Complete tags FULL_START_TAG = "" @@ -471,10 +482,16 @@ async def streaming_filter_think_async(chunks): buffer = "" output_buffer = "" chunk_count = 0 + last_usage = None async for chunk in chunks: content = chunk.content chunk_count += 1 + + # Capture usage metadata if present + if hasattr(chunk, "usage_metadata") and chunk.usage_metadata: + last_usage = chunk.usage_metadata + logger.info(f"Usage found in streaming chunk: {last_usage}") # Let's first check for full tags - this is the most reliable approach buffer += content @@ -567,15 +584,40 @@ async def streaming_filter_think_async(chunks): # Yield accumulated output before 
processing next chunk if output_buffer: - yield output_buffer + msg = AIMessageChunk(content=output_buffer) + if last_usage: + msg.usage_metadata = last_usage + yield msg output_buffer = "" + # Handle partial matches at EOF - treat as content + if state == MATCHING_START or state == MATCHING_END: + output_buffer += buffer + buffer = "" + state = NORMAL + + emitted_content = False # Yield any remaining content if not in a think block if state == NORMAL: if buffer: - yield buffer + msg = AIMessageChunk(content=buffer) + if last_usage: + msg.usage_metadata = last_usage + yield msg + emitted_content = True if output_buffer: - yield output_buffer + msg = AIMessageChunk(content=output_buffer) + if last_usage: + msg.usage_metadata = last_usage + yield msg + emitted_content = True + + if last_usage and not emitted_content: + # If we have usage but didn't emit it above (either because everything was filtered + # or just end of stream with empty buffer), emit an empty chunk with usage + msg = AIMessageChunk(content="") + msg.usage_metadata = last_usage + yield msg logger.info( "Finished streaming_filter_think_async processing after %d chunks", chunk_count @@ -605,29 +647,3 @@ def get_streaming_filter_think_parser_async(): logger.info("Think token filtering is disabled (async)") # If filtering is disabled, use a passthrough that passes content as-is return RunnablePassthrough() - - -USAGE_SENTINEL_PREFIX = "__RAG_USAGE_SENTINEL__:" - - -async def stream_with_usage_sentinel(chunks): - """ - Pass through model chunks and, at the end, emit a synthetic chunk whose - content encodes token-usage metadata. - """ - last_usage = None - - async for chunk in chunks: - logger.info("Chunk: %s", chunk.usage_metadata) - if hasattr(chunk, "usage_metadata") and getattr(chunk, "usage_metadata", None): - last_usage = getattr(chunk, "usage_metadata", None) - logger.info("Usage metadata: %s", last_usage) - yield chunk - - if last_usage is not None: - try: - payload = json.dumps(last_usage) - logger.info("Usage sentinel chunk: %s", payload) - yield AIMessageChunk(content=f"{USAGE_SENTINEL_PREFIX}{payload}") - except Exception as e: - logger.debug("Failed to emit usage sentinel chunk: %s", e) From 626d011de43bae737028d5dbce2b67bd0cdc00db Mon Sep 17 00:00:00 2001 From: Nikhil Kulkarni Date: Fri, 9 Jan 2026 09:14:48 +0530 Subject: [PATCH 4/5] (llm): switch to ChatOpenAI and fix token usage reporting Replaces ChatNVIDIA with ChatOpenAI to resolve token usage reporting issues. Moves non-standard parameters to extra_body and enables stream usage reporting. 
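
Illustrative sketch of the resulting behaviour (token counts below are
placeholders, not taken from a real run): with
stream_options={"include_usage": True} the final streamed chunk carries only
usage data, which langchain-openai surfaces as AIMessageChunk.usage_metadata,
e.g.

    {"input_tokens": 512, "output_tokens": 128, "total_tokens": 640}

generate_answer/generate_answer_async map these keys onto the response's
usage field as prompt_tokens, completion_tokens and total_tokens.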
--- pyproject.toml | 4 +- src/nvidia_rag/utils/llm.py | 115 ++++++++++++++++++------------------ 2 files changed, 57 insertions(+), 62 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cacc5c3d..fb37f4ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "langchain-community>=0.3.27", "langchain-milvus>=0.2.1", "langchain-nvidia-ai-endpoints>=0.3.18", + "langchain-openai>=0.2,<1.0", "minio>=7.2,<8.0", "pdfplumber>=0.6", "pydantic>=2.11,<3.0", @@ -44,7 +45,6 @@ dependencies = [ [project.optional-dependencies] rag = [ - "langchain-openai>=0.2,<1.0", "opentelemetry-api>=1.29,<2.0", "opentelemetry-exporter-otlp>=1.29,<2.0", "opentelemetry-exporter-prometheus>=0.50b0,<1.0", @@ -65,7 +65,6 @@ ingest = [ "nv-ingest-client==26.1.0rc5", "tritonclient==2.57.0", # Other ingest dependencies - "langchain-openai>=0.2,<1.0", "overrides>=7.7,<8.0", "tqdm>=4.67,<5.0", "opentelemetry-api>=1.29,<2.0", @@ -86,7 +85,6 @@ all = [ "nv-ingest-client==26.1.0rc5", "tritonclient==2.57.0", # RAG + Ingest dependencies - "langchain-openai>=0.2,<1.0", "overrides>=7.7,<8.0", "tqdm>=4.67,<5.0", "opentelemetry-api>=1.29,<2.0", diff --git a/src/nvidia_rag/utils/llm.py b/src/nvidia_rag/utils/llm.py index cfcb6d02..9a021d2b 100644 --- a/src/nvidia_rag/utils/llm.py +++ b/src/nvidia_rag/utils/llm.py @@ -34,7 +34,7 @@ from langchain.llms.base import LLM from langchain_core.language_models.chat_models import SimpleChatModel from langchain_core.messages import AIMessageChunk -from langchain_nvidia_ai_endpoints import ChatNVIDIA +from langchain_openai import ChatOpenAI from nvidia_rag.rag_server.response_generator import APIError, ErrorCodeMapping from nvidia_rag.utils.common import ( @@ -46,12 +46,6 @@ logger = logging.getLogger(__name__) -try: - from langchain_openai import ChatOpenAI -except ImportError: - logger.info("Langchain OpenAI is not installed.") - pass - def get_prompts(source: str | dict | None = None) -> dict: """Retrieves prompt configurations from source or YAML file and return a dict. 
@@ -218,66 +212,68 @@ def get_llm(config: NvidiaRAGConfig | None = None, **kwargs) -> LLM | SimpleChat error_msg, ErrorCodeMapping.SERVICE_UNAVAILABLE ) from e - if url: - logger.debug(f"Length of llm endpoint url string {url}") - logger.info("Using llm model %s hosted at %s", kwargs.get("model"), url) - - api_key = kwargs.get("api_key") or config.llm.get_api_key() - # Detect endpoint type using URL patterns only - is_nvidia = _is_nvidia_endpoint(url) - - # Build kwargs dict, only including parameters that are set - # For non-NVIDIA endpoints, exclude NVIDIA-specific parameters - chat_nvidia_kwargs = { - "base_url": url, - "model": kwargs.get("model"), - "api_key": api_key, - "stop": kwargs.get("stop", []), - } - if kwargs.get("temperature") is not None: - chat_nvidia_kwargs["temperature"] = kwargs["temperature"] - if kwargs.get("top_p") is not None: - chat_nvidia_kwargs["top_p"] = kwargs["top_p"] - if kwargs.get("max_tokens") is not None: - chat_nvidia_kwargs["max_tokens"] = kwargs["max_tokens"] - # Also set max_completion_tokens as max_tokens is deprecated in newer libs - chat_nvidia_kwargs["max_completion_tokens"] = kwargs["max_tokens"] - # Only include NVIDIA-specific parameters for NVIDIA endpoints - if is_nvidia: - if kwargs.get("min_tokens") is not None: - chat_nvidia_kwargs["min_tokens"] = kwargs["min_tokens"] - if kwargs.get("ignore_eos") is not None: - chat_nvidia_kwargs["ignore_eos"] = kwargs["ignore_eos"] - - llm = ChatNVIDIA(**chat_nvidia_kwargs) - # Only bind thinking tokens for NVIDIA endpoints - if is_nvidia: - llm = _bind_thinking_tokens_if_configured(llm, **kwargs) - return llm - - logger.info("Using llm model %s from api catalog", kwargs.get("model")) + # Consolidate logic for both NVIDIA endpoints and API Catalog using ChatOpenAI + # to avoid token usage reporting issues with ChatNVIDIA + base_url = url + if not base_url: + logger.info("Using llm model %s from api catalog (via ChatOpenAI)", kwargs.get("model")) + # Default to NVIDIA API Catalog URL for OpenAI client + base_url = "https://integrate.api.nvidia.com/v1" + else: + logger.debug(f"Length of llm endpoint url string {base_url}") + logger.info("Using llm model %s hosted at %s", kwargs.get("model"), base_url) api_key = kwargs.get("api_key") or config.llm.get_api_key() - chat_nvidia_kwargs = { + # Detect endpoint type (still useful for logic branching) + is_nvidia = _is_nvidia_endpoint(base_url) + + # Prepare kwargs for ChatOpenAI + chat_openai_kwargs = { "model": kwargs.get("model"), "api_key": api_key, + "base_url": base_url, "stop": kwargs.get("stop", []), } + + # Optional standard parameters if kwargs.get("temperature") is not None: - chat_nvidia_kwargs["temperature"] = kwargs["temperature"] + chat_openai_kwargs["temperature"] = kwargs["temperature"] if kwargs.get("top_p") is not None: - chat_nvidia_kwargs["top_p"] = kwargs["top_p"] + chat_openai_kwargs["top_p"] = kwargs["top_p"] if kwargs.get("max_tokens") is not None: - chat_nvidia_kwargs["max_tokens"] = kwargs["max_tokens"] - chat_nvidia_kwargs["max_completion_tokens"] = kwargs["max_tokens"] - if kwargs.get("min_tokens") is not None: - chat_nvidia_kwargs["min_tokens"] = kwargs["min_tokens"] - if kwargs.get("ignore_eos") is not None: - chat_nvidia_kwargs["ignore_eos"] = kwargs["ignore_eos"] - - llm = ChatNVIDIA(**chat_nvidia_kwargs) - llm = _bind_thinking_tokens_if_configured(llm, **kwargs) + chat_openai_kwargs["max_tokens"] = kwargs["max_tokens"] + + # Prepare extra parameters (NVIDIA specific) for model_kwargs via extra_body + # OpenAI API client requires 
non-standard parameters to be in 'extra_body' + # Also request stream options to ensure usage is returned + model_kwargs = { + "stream_options": {"include_usage": True} + } + extra_body = {} + + if is_nvidia: + if kwargs.get("min_tokens") is not None: + extra_body["min_tokens"] = kwargs["min_tokens"] + if kwargs.get("ignore_eos") is not None: + extra_body["ignore_eos"] = kwargs["ignore_eos"] + + # Handle thinking tokens + min_think = kwargs.get("min_thinking_tokens", None) + max_think = kwargs.get("max_thinking_tokens", None) + if min_think is not None and min_think > 0: + extra_body["min_thinking_tokens"] = min_think + if max_think is not None and max_think > 0: + extra_body["max_thinking_tokens"] = max_think + + if extra_body: + model_kwargs["extra_body"] = extra_body + + if model_kwargs: + chat_openai_kwargs["model_kwargs"] = model_kwargs + + llm = ChatOpenAI(**chat_openai_kwargs) + return llm raise RuntimeError( @@ -432,7 +428,7 @@ def get_streaming_filter_think_parser(): If FILTER_THINK_TOKENS environment variable is set to "true" (case-insensitive), returns a parser that filters out content between and tags. - Otherwise, returns a pass-through parser that doesn't modify the content. + Otherwise, returns a parser that passes content as-is. Returns: RunnableGenerator: A parser for filtering (or not filtering) think tokens @@ -630,7 +626,7 @@ def get_streaming_filter_think_parser_async(): If FILTER_THINK_TOKENS environment variable is set to "true" (case-insensitive), returns a parser that filters out content between and tags. - Otherwise, returns a pass-through parser that doesn't modify the content. + Otherwise, returns a parser that passes content as-is. Returns: RunnableGenerator: An async parser for filtering (or not filtering) think tokens @@ -647,3 +643,4 @@ def get_streaming_filter_think_parser_async(): logger.info("Think token filtering is disabled (async)") # If filtering is disabled, use a passthrough that passes content as-is return RunnablePassthrough() + From 9442c177c2436dff2afde61756ad369a744ffc8d Mon Sep 17 00:00:00 2001 From: Nikhil Kulkarni Date: Tue, 13 Jan 2026 10:33:36 +0530 Subject: [PATCH 5/5] Fix failing init/integration tests --- tests/unit/conftest.py | 5 +++++ uv.lock | 10 +++------- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 65dba580..d7cd969d 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -19,12 +19,17 @@ import atexit import logging +import os import sys import types from unittest.mock import MagicMock, patch import pytest +# Set a dummy OpenAI API key to prevent initialization errors during test collection +# This must happen before any imports that might instantiate OpenAI clients +os.environ.setdefault("OPENAI_API_KEY", "dummy-test-api-key") + # OpenTelemetry imports (optional - may not be available in all environments) try: from opentelemetry import metrics, trace diff --git a/uv.lock b/uv.lock index 28949274..3cdb4ca8 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.11, <3.14" resolution-markers = [ "python_full_version >= '3.13'", @@ -1920,6 +1920,7 @@ dependencies = [ { name = "langchain-elasticsearch" }, { name = "langchain-milvus" }, { name = "langchain-nvidia-ai-endpoints" }, + { name = "langchain-openai" }, { name = "lark" }, { name = "mcp" }, { name = "minio" }, @@ -1939,7 +1940,6 @@ dependencies = [ all = [ { name = "azure-core" }, { name = "azure-storage-blob" }, - { name = 
"langchain-openai" }, { name = "nv-ingest-api" }, { name = "nv-ingest-client" }, { name = "opentelemetry-api" }, @@ -1958,7 +1958,6 @@ all = [ ingest = [ { name = "azure-core" }, { name = "azure-storage-blob" }, - { name = "langchain-openai" }, { name = "nv-ingest-api" }, { name = "nv-ingest-client" }, { name = "opentelemetry-api" }, @@ -1977,7 +1976,6 @@ ingest = [ rag = [ { name = "azure-core" }, { name = "azure-storage-blob" }, - { name = "langchain-openai" }, { name = "opentelemetry-api" }, { name = "opentelemetry-exporter-otlp" }, { name = "opentelemetry-exporter-prometheus" }, @@ -2023,9 +2021,7 @@ requires-dist = [ { name = "langchain-elasticsearch", specifier = ">=0.3,<1.0" }, { name = "langchain-milvus", specifier = ">=0.2.1" }, { name = "langchain-nvidia-ai-endpoints", specifier = ">=0.3.18" }, - { name = "langchain-openai", marker = "extra == 'all'", specifier = ">=0.2,<1.0" }, - { name = "langchain-openai", marker = "extra == 'ingest'", specifier = ">=0.2,<1.0" }, - { name = "langchain-openai", marker = "extra == 'rag'", specifier = ">=0.2,<1.0" }, + { name = "langchain-openai", specifier = ">=0.2,<1.0" }, { name = "lark", specifier = ">=1.2.2" }, { name = "mcp", specifier = ">=1.23.1" }, { name = "minio", specifier = ">=7.2,<8.0" },