diff --git a/pyproject.toml b/pyproject.toml
index cacc5c3d..fb37f4ca 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,6 +26,7 @@ dependencies = [
"langchain-community>=0.3.27",
"langchain-milvus>=0.2.1",
"langchain-nvidia-ai-endpoints>=0.3.18",
+ "langchain-openai>=0.2,<1.0",
"minio>=7.2,<8.0",
"pdfplumber>=0.6",
"pydantic>=2.11,<3.0",
@@ -44,7 +45,6 @@ dependencies = [
[project.optional-dependencies]
rag = [
- "langchain-openai>=0.2,<1.0",
"opentelemetry-api>=1.29,<2.0",
"opentelemetry-exporter-otlp>=1.29,<2.0",
"opentelemetry-exporter-prometheus>=0.50b0,<1.0",
@@ -65,7 +65,6 @@ ingest = [
"nv-ingest-client==26.1.0rc5",
"tritonclient==2.57.0",
# Other ingest dependencies
- "langchain-openai>=0.2,<1.0",
"overrides>=7.7,<8.0",
"tqdm>=4.67,<5.0",
"opentelemetry-api>=1.29,<2.0",
@@ -86,7 +85,6 @@ all = [
"nv-ingest-client==26.1.0rc5",
"tritonclient==2.57.0",
# RAG + Ingest dependencies
- "langchain-openai>=0.2,<1.0",
"overrides>=7.7,<8.0",
"tqdm>=4.67,<5.0",
"opentelemetry-api>=1.29,<2.0",
diff --git a/src/nvidia_rag/rag_server/main.py b/src/nvidia_rag/rag_server/main.py
index 206caf6d..80453386 100644
--- a/src/nvidia_rag/rag_server/main.py
+++ b/src/nvidia_rag/rag_server/main.py
@@ -48,7 +48,7 @@
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.prompts import MessagesPlaceholder
from langchain_core.prompts.chat import ChatPromptTemplate
-from langchain_core.runnables import RunnableAssign
+from langchain_core.runnables import RunnableAssign, RunnableGenerator
from opentelemetry import context as otel_context
from requests import ConnectTimeout
@@ -1408,14 +1408,15 @@ async def _llm_chain(
prompt_template = ChatPromptTemplate.from_messages(message)
llm = get_llm(config=self.config, **llm_settings)
- chain = (
+ # Chain for streaming: drop StrOutputParser so the chain yields AIMessageChunks,
+ # letting the response generator read .usage_metadata.
+ stream_chain = (
prompt_template
| llm
| self.StreamingFilterThinkParser
- | StrOutputParser()
)
# Create async stream generator
- stream_gen = chain.astream(
+ stream_gen = stream_chain.astream(
{"question": query_text}, config={"run_name": "llm-stream"}
)
# Eagerly fetch first chunk to trigger any errors before returning response
@@ -2437,7 +2438,16 @@ def generate_filter_for_collection(collection_name):
self._print_conversation_history(message)
prompt = ChatPromptTemplate.from_messages(message)
- chain = prompt | llm | self.StreamingFilterThinkParser | StrOutputParser()
+ # Base chain used for the non-streaming reflection path; StrOutputParser stays
+ # because that path needs the full response as a plain string.
+ base_chain = prompt | llm | self.StreamingFilterThinkParser | StrOutputParser()
+
+ # Streaming chain: yields AIMessageChunks to preserve usage metadata.
+ stream_chain = (
+ prompt
+ | llm
+ | self.StreamingFilterThinkParser
+ )
# Check response groundedness if we still have reflection
# iterations available
@@ -2496,8 +2506,8 @@ def generate_filter_for_collection(collection_name):
status_code=ErrorCodeMapping.SUCCESS,
)
else:
- # Create async stream generator
- stream_gen = chain.astream(
+ # Create async stream generator using the streaming chain
+ stream_gen = stream_chain.astream(
{"question": query, "context": docs},
config={"run_name": "llm-stream"},
)
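# Illustrative sketch: why the streaming chains above drop StrOutputParser. With the
# parser, .astream() yields plain strings; without it, each chunk is an AIMessageChunk
# whose usage_metadata survives to the consumer. The prompt and model below are
# placeholders, assuming a ChatOpenAI-compatible endpoint that reports usage on the
# final streamed chunk.
import asyncio

from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

async def demo_stream_usage() -> None:
    prompt = ChatPromptTemplate.from_messages([("human", "{question}")])
    llm = ChatOpenAI(model="gpt-4o-mini", stream_usage=True)  # placeholder model name
    stream_chain = prompt | llm  # no StrOutputParser, so chunk metadata is preserved
    async for chunk in stream_chain.astream({"question": "What is RAG?"}):
        if chunk.content:
            print(chunk.content, end="", flush=True)
        if getattr(chunk, "usage_metadata", None):
            print("\nusage:", chunk.usage_metadata)

# asyncio.run(demo_stream_usage())  # requires an API key / reachable endpoint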
diff --git a/src/nvidia_rag/rag_server/response_generator.py b/src/nvidia_rag/rag_server/response_generator.py
index 971f9a39..a7a8d254 100644
--- a/src/nvidia_rag/rag_server/response_generator.py
+++ b/src/nvidia_rag/rag_server/response_generator.py
@@ -443,17 +443,54 @@ def generate_answer(
llm_ttft_ms: float | None = None
rag_ttft_ms: float | None = None
llm_generation_time_ms: float | None = None
+ usage: Usage | None = None
for chunk in generator:
+ # Handle usage metadata if present (AIMessageChunk)
+ if hasattr(chunk, "usage_metadata") and chunk.usage_metadata:
+ usage_dict = chunk.usage_metadata
+ try:
+ prompt_tokens = int(usage_dict.get("input_tokens", 0))
+ completion_tokens = int(usage_dict.get("output_tokens", 0))
+ total_tokens = int(
+ usage_dict.get(
+ "total_tokens", prompt_tokens + completion_tokens
+ )
+ )
+ usage = Usage(
+ total_tokens=total_tokens,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ )
+ logger.info(
+ "LLM usage for model %s (sync): prompt_tokens=%d, completion_tokens=%d, total_tokens=%d, raw=%s",
+ model,
+ prompt_tokens,
+ completion_tokens,
+ total_tokens,
+ usage_dict,
+ )
+ except Exception as e:
+ logger.debug("Failed to parse usage metadata: %s", e)
+
+ # Extract content
+ content = chunk
+ if hasattr(chunk, "content"):
+ content = chunk.content
+
+ # Skip empty content (e.g. usage-only chunks)
+ if not content:
+ continue
+
# TODO: This is a hack to clear contexts if we get an error
# response from nemoguardrails
- if chunk == "I'm sorry, I can't respond to that.":
+ if content == "I'm sorry, I can't respond to that.":
# Clear contexts if we get an error response
contexts = []
chain_response = ChainResponse()
response_choice = ChainResponseChoices(
index=0,
- message=Message(role="assistant", content=chunk),
- delta=Message(role=None, content=chunk),
+ message=Message(role="assistant", content=content),
+ delta=Message(role=None, content=content),
finish_reason=None,
)
chain_response.id = resp_id
@@ -519,6 +556,8 @@ def generate_answer(
# Create response first, then attach metrics for clarity
chain_response = ChainResponse()
chain_response.metrics = final_metrics
+ if usage is not None:
+ chain_response.usage = usage
# [DONE] indicate end of response from server
response_choice = ChainResponseChoices(
@@ -591,17 +630,54 @@ async def generate_answer_async(
llm_ttft_ms: float | None = None
rag_ttft_ms: float | None = None
llm_generation_time_ms: float | None = None
+ usage: Usage | None = None
async for chunk in generator:
+ # Handle usage metadata if present (AIMessageChunk)
+ if hasattr(chunk, "usage_metadata") and chunk.usage_metadata:
+ usage_dict = chunk.usage_metadata
+ try:
+ prompt_tokens = int(usage_dict.get("input_tokens", 0))
+ completion_tokens = int(usage_dict.get("output_tokens", 0))
+ total_tokens = int(
+ usage_dict.get(
+ "total_tokens", prompt_tokens + completion_tokens
+ )
+ )
+ usage = Usage(
+ total_tokens=total_tokens,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ )
+ logger.info(
+ "LLM usage for model %s (async): prompt_tokens=%d, completion_tokens=%d, total_tokens=%d, raw=%s",
+ model,
+ prompt_tokens,
+ completion_tokens,
+ total_tokens,
+ usage_dict,
+ )
+ except Exception as e:
+ logger.debug("Failed to parse usage metadata: %s", e)
+
+ # Extract content
+ content = chunk
+ if hasattr(chunk, "content"):
+ content = chunk.content
+
+ # Skip empty content (e.g. usage-only chunks)
+ if not content:
+ continue
+
# TODO: This is a hack to clear contexts if we get an error
# response from nemoguardrails
- if chunk == "I'm sorry, I can't respond to that.":
+ if content == "I'm sorry, I can't respond to that.":
# Clear contexts if we get an error response
contexts = []
chain_response = ChainResponse()
response_choice = ChainResponseChoices(
index=0,
- message=Message(role="assistant", content=chunk),
- delta=Message(role=None, content=chunk),
+ message=Message(role="assistant", content=content),
+ delta=Message(role=None, content=content),
finish_reason=None,
)
chain_response.id = resp_id
@@ -667,6 +743,8 @@ async def generate_answer_async(
# Create response first, then attach metrics for clarity
chain_response = ChainResponse()
chain_response.metrics = final_metrics
+ if usage is not None:
+ chain_response.usage = usage
# [DONE] indicate end of response from server
response_choice = ChainResponseChoices(
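# Illustrative sketch: the usage-extraction loop added above, reduced to a standalone
# helper. Usage is a stand-in dataclass here; the real Usage model lives in the RAG
# server's response schemas. input_tokens / output_tokens / total_tokens are the keys
# LangChain uses in usage_metadata.
from dataclasses import dataclass

from langchain_core.messages import AIMessageChunk

@dataclass
class Usage:  # stand-in for the server's Usage model
    total_tokens: int
    prompt_tokens: int
    completion_tokens: int

def usage_from_chunk(chunk: AIMessageChunk) -> Usage | None:
    meta = getattr(chunk, "usage_metadata", None)
    if not meta:
        return None
    prompt = int(meta.get("input_tokens", 0))
    completion = int(meta.get("output_tokens", 0))
    total = int(meta.get("total_tokens", prompt + completion))
    return Usage(total_tokens=total, prompt_tokens=prompt, completion_tokens=completion)

# A usage-only final chunk (empty content), as the async think-filter now emits:
final = AIMessageChunk(
    content="",
    usage_metadata={"input_tokens": 12, "output_tokens": 34, "total_tokens": 46},
)
print(usage_from_chunk(final))  # Usage(total_tokens=46, prompt_tokens=12, completion_tokens=34)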
diff --git a/src/nvidia_rag/utils/llm.py b/src/nvidia_rag/utils/llm.py
index cd2d3f08..9a021d2b 100644
--- a/src/nvidia_rag/utils/llm.py
+++ b/src/nvidia_rag/utils/llm.py
@@ -22,6 +22,7 @@
6. get_streaming_filter_think_parser_async: Get the parser for filtering the think tokens (async).
"""
+import json
import logging
import os
from collections.abc import Iterable
@@ -32,7 +33,8 @@
import yaml
from langchain.llms.base import LLM
from langchain_core.language_models.chat_models import SimpleChatModel
-from langchain_nvidia_ai_endpoints import ChatNVIDIA
+from langchain_core.messages import AIMessageChunk
+from langchain_openai import ChatOpenAI
from nvidia_rag.rag_server.response_generator import APIError, ErrorCodeMapping
from nvidia_rag.utils.common import (
@@ -44,12 +46,6 @@
logger = logging.getLogger(__name__)
-try:
- from langchain_openai import ChatOpenAI
-except ImportError:
- logger.info("Langchain OpenAI is not installed.")
- pass
-
def get_prompts(source: str | dict | None = None) -> dict:
"""Retrieves prompt configurations from source or YAML file and return a dict.
@@ -216,55 +212,68 @@ def get_llm(config: NvidiaRAGConfig | None = None, **kwargs) -> LLM | SimpleChat
error_msg, ErrorCodeMapping.SERVICE_UNAVAILABLE
) from e
- if url:
- logger.debug(f"Length of llm endpoint url string {url}")
- logger.info("Using llm model %s hosted at %s", kwargs.get("model"), url)
-
- api_key = kwargs.get("api_key") or config.llm.get_api_key()
- # Detect endpoint type using URL patterns only
- is_nvidia = _is_nvidia_endpoint(url)
-
- # Build kwargs dict, only including parameters that are set
- # For non-NVIDIA endpoints, exclude NVIDIA-specific parameters
- chat_nvidia_kwargs = {
- "base_url": url,
- "model": kwargs.get("model"),
- "api_key": api_key,
- "stop": kwargs.get("stop", []),
- }
- if kwargs.get("temperature") is not None:
- chat_nvidia_kwargs["temperature"] = kwargs["temperature"]
- if kwargs.get("top_p") is not None:
- chat_nvidia_kwargs["top_p"] = kwargs["top_p"]
- if kwargs.get("max_tokens") is not None:
- chat_nvidia_kwargs["max_tokens"] = kwargs["max_tokens"]
- # Only include NVIDIA-specific parameters for NVIDIA endpoints
- if is_nvidia:
- if kwargs.get("min_tokens") is not None:
- chat_nvidia_kwargs["min_tokens"] = kwargs["min_tokens"]
- if kwargs.get("ignore_eos") is not None:
- chat_nvidia_kwargs["ignore_eos"] = kwargs["ignore_eos"]
-
- llm = ChatNVIDIA(**chat_nvidia_kwargs)
- # Only bind thinking tokens for NVIDIA endpoints
- if is_nvidia:
- llm = _bind_thinking_tokens_if_configured(llm, **kwargs)
- return llm
-
- logger.info("Using llm model %s from api catalog", kwargs.get("model"))
+ # Use ChatOpenAI for both self-hosted NVIDIA endpoints and the API Catalog,
+ # sidestepping the token-usage reporting issues seen with ChatNVIDIA
+ base_url = url
+ if not base_url:
+ logger.info("Using llm model %s from api catalog (via ChatOpenAI)", kwargs.get("model"))
+ # Default to NVIDIA API Catalog URL for OpenAI client
+ base_url = "https://integrate.api.nvidia.com/v1"
+ else:
+ logger.debug(f"Length of llm endpoint url string {base_url}")
+ logger.info("Using llm model %s hosted at %s", kwargs.get("model"), base_url)
api_key = kwargs.get("api_key") or config.llm.get_api_key()
- llm = ChatNVIDIA(
- model=kwargs.get("model"),
- api_key=api_key,
- temperature=kwargs.get("temperature", None),
- top_p=kwargs.get("top_p", None),
- max_tokens=kwargs.get("max_tokens", None),
- min_tokens=kwargs.get("min_tokens", None),
- ignore_eos=kwargs.get("ignore_eos", False),
- stop=kwargs.get("stop", []),
- )
- llm = _bind_thinking_tokens_if_configured(llm, **kwargs)
+
+ # Detect endpoint type (still useful for logic branching)
+ is_nvidia = _is_nvidia_endpoint(base_url)
+
+ # Prepare kwargs for ChatOpenAI
+ chat_openai_kwargs = {
+ "model": kwargs.get("model"),
+ "api_key": api_key,
+ "base_url": base_url,
+ "stop": kwargs.get("stop", []),
+ }
+
+ # Optional standard parameters
+ if kwargs.get("temperature") is not None:
+ chat_openai_kwargs["temperature"] = kwargs["temperature"]
+ if kwargs.get("top_p") is not None:
+ chat_openai_kwargs["top_p"] = kwargs["top_p"]
+ if kwargs.get("max_tokens") is not None:
+ chat_openai_kwargs["max_tokens"] = kwargs["max_tokens"]
+
+ # NVIDIA-specific (non-standard) parameters go into 'extra_body' under model_kwargs,
+ # which is how the OpenAI client forwards fields it does not recognize.
+ # Also request stream options so usage is included with the streamed response
+ model_kwargs = {
+ "stream_options": {"include_usage": True}
+ }
+ extra_body = {}
+
+ if is_nvidia:
+ if kwargs.get("min_tokens") is not None:
+ extra_body["min_tokens"] = kwargs["min_tokens"]
+ if kwargs.get("ignore_eos") is not None:
+ extra_body["ignore_eos"] = kwargs["ignore_eos"]
+
+ # Handle thinking tokens
+ min_think = kwargs.get("min_thinking_tokens", None)
+ max_think = kwargs.get("max_thinking_tokens", None)
+ if min_think is not None and min_think > 0:
+ extra_body["min_thinking_tokens"] = min_think
+ if max_think is not None and max_think > 0:
+ extra_body["max_thinking_tokens"] = max_think
+
+ if extra_body:
+ model_kwargs["extra_body"] = extra_body
+
+ if model_kwargs:
+ chat_openai_kwargs["model_kwargs"] = model_kwargs
+
+ llm = ChatOpenAI(**chat_openai_kwargs)
+
return llm
raise RuntimeError(
@@ -419,7 +428,7 @@ def get_streaming_filter_think_parser():
If FILTER_THINK_TOKENS environment variable is set to "true" (case-insensitive),
returns a parser that filters out content between <think> and </think> tags.
- Otherwise, returns a pass-through parser that doesn't modify the content.
+ Otherwise, returns a parser that passes content as-is.
Returns:
RunnableGenerator: A parser for filtering (or not filtering) think tokens
@@ -448,7 +457,7 @@ async def streaming_filter_think_async(chunks):
chunks: Async iterable of chunks from a streaming LLM response
Yields:
- str: Filtered content with think blocks removed
+ AIMessageChunk: Filtered content with think blocks removed
"""
# Complete tags
FULL_START_TAG = "<think>"
@@ -469,10 +478,16 @@ async def streaming_filter_think_async(chunks):
buffer = ""
output_buffer = ""
chunk_count = 0
+ last_usage = None
async for chunk in chunks:
content = chunk.content
chunk_count += 1
+
+ # Capture usage metadata if present
+ if hasattr(chunk, "usage_metadata") and chunk.usage_metadata:
+ last_usage = chunk.usage_metadata
+ logger.info(f"Usage found in streaming chunk: {last_usage}")
# Let's first check for full tags - this is the most reliable approach
buffer += content
@@ -565,15 +580,40 @@ async def streaming_filter_think_async(chunks):
# Yield accumulated output before processing next chunk
if output_buffer:
- yield output_buffer
+ msg = AIMessageChunk(content=output_buffer)
+ if last_usage:
+ msg.usage_metadata = last_usage
+ yield msg
output_buffer = ""
+ # Handle partial matches at EOF - treat as content
+ if state == MATCHING_START or state == MATCHING_END:
+ output_buffer += buffer
+ buffer = ""
+ state = NORMAL
+
+ emitted_content = False
# Yield any remaining content if not in a think block
if state == NORMAL:
if buffer:
- yield buffer
+ msg = AIMessageChunk(content=buffer)
+ if last_usage:
+ msg.usage_metadata = last_usage
+ yield msg
+ emitted_content = True
if output_buffer:
- yield output_buffer
+ msg = AIMessageChunk(content=output_buffer)
+ if last_usage:
+ msg.usage_metadata = last_usage
+ yield msg
+ emitted_content = True
+
+ if last_usage and not emitted_content:
+ # If we have usage but didn't emit it above (either because everything was filtered
+ # or just end of stream with empty buffer), emit an empty chunk with usage
+ msg = AIMessageChunk(content="")
+ msg.usage_metadata = last_usage
+ yield msg
logger.info(
"Finished streaming_filter_think_async processing after %d chunks", chunk_count
@@ -586,7 +626,7 @@ def get_streaming_filter_think_parser_async():
If FILTER_THINK_TOKENS environment variable is set to "true" (case-insensitive),
returns a parser that filters out content between <think> and </think> tags.
- Otherwise, returns a pass-through parser that doesn't modify the content.
+ Otherwise, returns a parser that passes content as-is.
Returns:
RunnableGenerator: An async parser for filtering (or not filtering) think tokens
@@ -603,3 +643,4 @@ def get_streaming_filter_think_parser_async():
logger.info("Think token filtering is disabled (async)")
# If filtering is disabled, use a passthrough that passes content as-is
return RunnablePassthrough()
+
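# Illustrative sketch: the consolidated get_llm() path in miniature, using a single
# ChatOpenAI client for both self-hosted NVIDIA endpoints and the API Catalog, with
# NVIDIA-only knobs passed through model_kwargs["extra_body"] so the OpenAI client
# forwards them as-is. The model id, key, and numeric values are placeholders, not
# repo defaults.
from langchain_openai import ChatOpenAI

chat_openai_kwargs = {
    "model": "meta/llama-3.1-70b-instruct",             # placeholder model id
    "api_key": "nvapi-xxxx",                             # placeholder key
    "base_url": "https://integrate.api.nvidia.com/v1",   # API Catalog default used above
    "temperature": 0.2,
    "max_tokens": 512,
    "model_kwargs": {
        # Ask the server to attach a usage block to the final streamed chunk.
        "stream_options": {"include_usage": True},
        # Non-standard (NVIDIA-specific) parameters ride along in extra_body.
        "extra_body": {"min_tokens": 16, "ignore_eos": False},
    },
}
llm = ChatOpenAI(**chat_openai_kwargs)
# llm.invoke("Hello")  # requires a valid key and a reachable endpoint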
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index 65dba580..d7cd969d 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -19,12 +19,17 @@
import atexit
import logging
+import os
import sys
import types
from unittest.mock import MagicMock, patch
import pytest
+# Set a dummy OpenAI API key to prevent initialization errors during test collection
+# This must happen before any imports that might instantiate OpenAI clients
+os.environ.setdefault("OPENAI_API_KEY", "dummy-test-api-key")
+
# OpenTelemetry imports (optional - may not be available in all environments)
try:
from opentelemetry import metrics, trace
diff --git a/uv.lock b/uv.lock
index 28949274..3cdb4ca8 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,5 +1,5 @@
version = 1
-revision = 2
+revision = 3
requires-python = ">=3.11, <3.14"
resolution-markers = [
"python_full_version >= '3.13'",
@@ -1920,6 +1920,7 @@ dependencies = [
{ name = "langchain-elasticsearch" },
{ name = "langchain-milvus" },
{ name = "langchain-nvidia-ai-endpoints" },
+ { name = "langchain-openai" },
{ name = "lark" },
{ name = "mcp" },
{ name = "minio" },
@@ -1939,7 +1940,6 @@ dependencies = [
all = [
{ name = "azure-core" },
{ name = "azure-storage-blob" },
- { name = "langchain-openai" },
{ name = "nv-ingest-api" },
{ name = "nv-ingest-client" },
{ name = "opentelemetry-api" },
@@ -1958,7 +1958,6 @@ all = [
ingest = [
{ name = "azure-core" },
{ name = "azure-storage-blob" },
- { name = "langchain-openai" },
{ name = "nv-ingest-api" },
{ name = "nv-ingest-client" },
{ name = "opentelemetry-api" },
@@ -1977,7 +1976,6 @@ ingest = [
rag = [
{ name = "azure-core" },
{ name = "azure-storage-blob" },
- { name = "langchain-openai" },
{ name = "opentelemetry-api" },
{ name = "opentelemetry-exporter-otlp" },
{ name = "opentelemetry-exporter-prometheus" },
@@ -2023,9 +2021,7 @@ requires-dist = [
{ name = "langchain-elasticsearch", specifier = ">=0.3,<1.0" },
{ name = "langchain-milvus", specifier = ">=0.2.1" },
{ name = "langchain-nvidia-ai-endpoints", specifier = ">=0.3.18" },
- { name = "langchain-openai", marker = "extra == 'all'", specifier = ">=0.2,<1.0" },
- { name = "langchain-openai", marker = "extra == 'ingest'", specifier = ">=0.2,<1.0" },
- { name = "langchain-openai", marker = "extra == 'rag'", specifier = ">=0.2,<1.0" },
+ { name = "langchain-openai", specifier = ">=0.2,<1.0" },
{ name = "lark", specifier = ">=1.2.2" },
{ name = "mcp", specifier = ">=1.23.1" },
{ name = "minio", specifier = ">=7.2,<8.0" },