diff --git a/pyproject.toml b/pyproject.toml
index cacc5c3d..fb37f4ca 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,6 +26,7 @@ dependencies = [
     "langchain-community>=0.3.27",
     "langchain-milvus>=0.2.1",
     "langchain-nvidia-ai-endpoints>=0.3.18",
+    "langchain-openai>=0.2,<1.0",
     "minio>=7.2,<8.0",
     "pdfplumber>=0.6",
     "pydantic>=2.11,<3.0",
@@ -44,7 +45,6 @@ dependencies = [

 [project.optional-dependencies]
 rag = [
-    "langchain-openai>=0.2,<1.0",
     "opentelemetry-api>=1.29,<2.0",
     "opentelemetry-exporter-otlp>=1.29,<2.0",
     "opentelemetry-exporter-prometheus>=0.50b0,<1.0",
@@ -65,7 +65,6 @@ ingest = [
     "nv-ingest-client==26.1.0rc5",
     "tritonclient==2.57.0",
     # Other ingest dependencies
-    "langchain-openai>=0.2,<1.0",
     "overrides>=7.7,<8.0",
     "tqdm>=4.67,<5.0",
     "opentelemetry-api>=1.29,<2.0",
@@ -86,7 +85,6 @@ all = [
     "nv-ingest-client==26.1.0rc5",
     "tritonclient==2.57.0",
     # RAG + Ingest dependencies
-    "langchain-openai>=0.2,<1.0",
     "overrides>=7.7,<8.0",
     "tqdm>=4.67,<5.0",
     "opentelemetry-api>=1.29,<2.0",
diff --git a/src/nvidia_rag/rag_server/main.py b/src/nvidia_rag/rag_server/main.py
index 206caf6d..80453386 100644
--- a/src/nvidia_rag/rag_server/main.py
+++ b/src/nvidia_rag/rag_server/main.py
@@ -48,7 +48,7 @@ from langchain_core.output_parsers.string import StrOutputParser
 from langchain_core.prompts import MessagesPlaceholder
 from langchain_core.prompts.chat import ChatPromptTemplate
-from langchain_core.runnables import RunnableAssign
+from langchain_core.runnables import RunnableAssign, RunnableGenerator
 from opentelemetry import context as otel_context
 from requests import ConnectTimeout
@@ -1408,14 +1408,15 @@ async def _llm_chain(
         prompt_template = ChatPromptTemplate.from_messages(message)
         llm = get_llm(config=self.config, **llm_settings)
-        chain = (
+        # Chain for streaming: we remove StrOutputParser so we yield AIMessageChunks,
+        # allowing us to access .usage_metadata in the response generator.
+        stream_chain = (
             prompt_template
             | llm
             | self.StreamingFilterThinkParser
-            | StrOutputParser()
         )

         # Create async stream generator
-        stream_gen = chain.astream(
+        stream_gen = stream_chain.astream(
             {"question": query_text}, config={"run_name": "llm-stream"}
         )
         # Eagerly fetch first chunk to trigger any errors before returning response
@@ -2437,7 +2438,16 @@ def generate_filter_for_collection(collection_name):
             self._print_conversation_history(message)

             prompt = ChatPromptTemplate.from_messages(message)
-            chain = prompt | llm | self.StreamingFilterThinkParser | StrOutputParser()
+            # Base chain (no usage sentinel) used for non-streaming reflection path.
+            # We keep StrOutputParser here because we want the full string response for logic.
+            base_chain = prompt | llm | self.StreamingFilterThinkParser | StrOutputParser()
+
+            # Streaming chain: yields AIMessageChunks to preserve usage metadata.
+            stream_chain = (
+                prompt
+                | llm
+                | self.StreamingFilterThinkParser
+            )

             # Check response groundedness if we still have reflection
             # iterations available
@@ -2496,8 +2506,8 @@ def generate_filter_for_collection(collection_name):
                         status_code=ErrorCodeMapping.SUCCESS,
                     )
             else:
-                # Create async stream generator
-                stream_gen = chain.astream(
+                # Create async stream generator using the streaming chain
+                stream_gen = stream_chain.astream(
                     {"question": query, "context": docs},
                     config={"run_name": "llm-stream"},
                 )
diff --git a/src/nvidia_rag/rag_server/response_generator.py b/src/nvidia_rag/rag_server/response_generator.py
index 971f9a39..a7a8d254 100644
--- a/src/nvidia_rag/rag_server/response_generator.py
+++ b/src/nvidia_rag/rag_server/response_generator.py
@@ -443,17 +443,54 @@ def generate_answer(
     llm_ttft_ms: float | None = None
     rag_ttft_ms: float | None = None
     llm_generation_time_ms: float | None = None
+    usage: Usage | None = None
     for chunk in generator:
+        # Handle usage metadata if present (AIMessageChunk)
+        if hasattr(chunk, "usage_metadata") and chunk.usage_metadata:
+            usage_dict = chunk.usage_metadata
+            try:
+                prompt_tokens = int(usage_dict.get("input_tokens", 0))
+                completion_tokens = int(usage_dict.get("output_tokens", 0))
+                total_tokens = int(
+                    usage_dict.get(
+                        "total_tokens", prompt_tokens + completion_tokens
+                    )
+                )
+                usage = Usage(
+                    total_tokens=total_tokens,
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                )
+                logger.info(
+                    "LLM usage for model %s (sync): prompt_tokens=%d, completion_tokens=%d, total_tokens=%d, raw=%s",
+                    model,
+                    prompt_tokens,
+                    completion_tokens,
+                    total_tokens,
+                    usage_dict,
+                )
+            except Exception as e:
+                logger.debug("Failed to parse usage metadata: %s", e)
+
+        # Extract content
+        content = chunk
+        if hasattr(chunk, "content"):
+            content = chunk.content
+
+        # Skip empty content (e.g. usage-only chunks)
+        if not content:
+            continue
+
         # TODO: This is a hack to clear contexts if we get an error
         # response from nemoguardrails
-        if chunk == "I'm sorry, I can't respond to that.":
+        if content == "I'm sorry, I can't respond to that.":
             # Clear contexts if we get an error response
             contexts = []
             chain_response = ChainResponse()
             response_choice = ChainResponseChoices(
                 index=0,
-                message=Message(role="assistant", content=chunk),
-                delta=Message(role=None, content=chunk),
+                message=Message(role="assistant", content=content),
+                delta=Message(role=None, content=content),
                 finish_reason=None,
             )
             chain_response.id = resp_id
@@ -519,6 +556,8 @@
     # Create response first, then attach metrics for clarity
     chain_response = ChainResponse()
     chain_response.metrics = final_metrics
+    if usage is not None:
+        chain_response.usage = usage

     # [DONE] indicate end of response from server
     response_choice = ChainResponseChoices(
@@ -591,17 +630,54 @@ async def generate_answer_async(
     llm_ttft_ms: float | None = None
     rag_ttft_ms: float | None = None
     llm_generation_time_ms: float | None = None
+    usage: Usage | None = None
     async for chunk in generator:
+        # Handle usage metadata if present (AIMessageChunk)
+        if hasattr(chunk, "usage_metadata") and chunk.usage_metadata:
+            usage_dict = chunk.usage_metadata
+            try:
+                prompt_tokens = int(usage_dict.get("input_tokens", 0))
+                completion_tokens = int(usage_dict.get("output_tokens", 0))
+                total_tokens = int(
+                    usage_dict.get(
+                        "total_tokens", prompt_tokens + completion_tokens
+                    )
+                )
+                usage = Usage(
+                    total_tokens=total_tokens,
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                )
+                logger.info(
+                    "LLM usage for model %s (async): prompt_tokens=%d, completion_tokens=%d, total_tokens=%d, raw=%s",
+                    model,
+                    prompt_tokens,
+                    completion_tokens,
+                    total_tokens,
+                    usage_dict,
+                )
+            except Exception as e:
+                logger.debug("Failed to parse usage metadata: %s", e)
+
+        # Extract content
+        content = chunk
+        if hasattr(chunk, "content"):
+            content = chunk.content
+
+        # Skip empty content (e.g. usage-only chunks)
+        if not content:
+            continue
+
         # TODO: This is a hack to clear contexts if we get an error
         # response from nemoguardrails
-        if chunk == "I'm sorry, I can't respond to that.":
+        if content == "I'm sorry, I can't respond to that.":
             # Clear contexts if we get an error response
             contexts = []
             chain_response = ChainResponse()
             response_choice = ChainResponseChoices(
                 index=0,
-                message=Message(role="assistant", content=chunk),
-                delta=Message(role=None, content=chunk),
+                message=Message(role="assistant", content=content),
+                delta=Message(role=None, content=content),
                 finish_reason=None,
             )
             chain_response.id = resp_id
@@ -667,6 +743,8 @@
     # Create response first, then attach metrics for clarity
     chain_response = ChainResponse()
     chain_response.metrics = final_metrics
+    if usage is not None:
+        chain_response.usage = usage

     # [DONE] indicate end of response from server
     response_choice = ChainResponseChoices(
diff --git a/src/nvidia_rag/utils/llm.py b/src/nvidia_rag/utils/llm.py
index cd2d3f08..9a021d2b 100644
--- a/src/nvidia_rag/utils/llm.py
+++ b/src/nvidia_rag/utils/llm.py
@@ -22,6 +22,7 @@
 6. get_streaming_filter_think_parser_async: Get the parser for filtering the think tokens (async).
 """

+import json
 import logging
 import os
 from collections.abc import Iterable
@@ -32,7 +33,8 @@
 import yaml
 from langchain.llms.base import LLM
 from langchain_core.language_models.chat_models import SimpleChatModel
-from langchain_nvidia_ai_endpoints import ChatNVIDIA
+from langchain_core.messages import AIMessageChunk
+from langchain_openai import ChatOpenAI

 from nvidia_rag.rag_server.response_generator import APIError, ErrorCodeMapping
 from nvidia_rag.utils.common import (
@@ -44,12 +46,6 @@

 logger = logging.getLogger(__name__)

-try:
-    from langchain_openai import ChatOpenAI
-except ImportError:
-    logger.info("Langchain OpenAI is not installed.")
-    pass
-

 def get_prompts(source: str | dict | None = None) -> dict:
     """Retrieves prompt configurations from source or YAML file and return a dict.
@@ -216,55 +212,68 @@ def get_llm(config: NvidiaRAGConfig | None = None, **kwargs) -> LLM | SimpleChat
                 error_msg, ErrorCodeMapping.SERVICE_UNAVAILABLE
             ) from e

-    if url:
-        logger.debug(f"Length of llm endpoint url string {url}")
-        logger.info("Using llm model %s hosted at %s", kwargs.get("model"), url)
-
-        api_key = kwargs.get("api_key") or config.llm.get_api_key()
-        # Detect endpoint type using URL patterns only
-        is_nvidia = _is_nvidia_endpoint(url)
-
-        # Build kwargs dict, only including parameters that are set
-        # For non-NVIDIA endpoints, exclude NVIDIA-specific parameters
-        chat_nvidia_kwargs = {
-            "base_url": url,
-            "model": kwargs.get("model"),
-            "api_key": api_key,
-            "stop": kwargs.get("stop", []),
-        }
-        if kwargs.get("temperature") is not None:
-            chat_nvidia_kwargs["temperature"] = kwargs["temperature"]
-        if kwargs.get("top_p") is not None:
-            chat_nvidia_kwargs["top_p"] = kwargs["top_p"]
-        if kwargs.get("max_tokens") is not None:
-            chat_nvidia_kwargs["max_tokens"] = kwargs["max_tokens"]
-        # Only include NVIDIA-specific parameters for NVIDIA endpoints
-        if is_nvidia:
-            if kwargs.get("min_tokens") is not None:
-                chat_nvidia_kwargs["min_tokens"] = kwargs["min_tokens"]
-            if kwargs.get("ignore_eos") is not None:
-                chat_nvidia_kwargs["ignore_eos"] = kwargs["ignore_eos"]
-
-        llm = ChatNVIDIA(**chat_nvidia_kwargs)
-        # Only bind thinking tokens for NVIDIA endpoints
-        if is_nvidia:
-            llm = _bind_thinking_tokens_if_configured(llm, **kwargs)
-        return llm
-
-    logger.info("Using llm model %s from api catalog", kwargs.get("model"))
+    # Consolidate logic for both NVIDIA endpoints and API Catalog using ChatOpenAI
+    # to avoid token usage reporting issues with ChatNVIDIA
+    base_url = url
+    if not base_url:
+        logger.info("Using llm model %s from api catalog (via ChatOpenAI)", kwargs.get("model"))
+        # Default to NVIDIA API Catalog URL for OpenAI client
+        base_url = "https://integrate.api.nvidia.com/v1"
+    else:
+        logger.debug(f"Length of llm endpoint url string {base_url}")
+        logger.info("Using llm model %s hosted at %s", kwargs.get("model"), base_url)

     api_key = kwargs.get("api_key") or config.llm.get_api_key()
-    llm = ChatNVIDIA(
-        model=kwargs.get("model"),
-        api_key=api_key,
-        temperature=kwargs.get("temperature", None),
-        top_p=kwargs.get("top_p", None),
-        max_tokens=kwargs.get("max_tokens", None),
-        min_tokens=kwargs.get("min_tokens", None),
-        ignore_eos=kwargs.get("ignore_eos", False),
-        stop=kwargs.get("stop", []),
-    )
-    llm = _bind_thinking_tokens_if_configured(llm, **kwargs)
+
+    # Detect endpoint type (still useful for logic branching)
+    is_nvidia = _is_nvidia_endpoint(base_url)
+
+    # Prepare kwargs for ChatOpenAI
+    chat_openai_kwargs = {
+        "model": kwargs.get("model"),
+        "api_key": api_key,
+        "base_url": base_url,
+        "stop": kwargs.get("stop", []),
+    }
+
+    # Optional standard parameters
+    if kwargs.get("temperature") is not None:
+        chat_openai_kwargs["temperature"] = kwargs["temperature"]
+    if kwargs.get("top_p") is not None:
+        chat_openai_kwargs["top_p"] = kwargs["top_p"]
+    if kwargs.get("max_tokens") is not None:
+        chat_openai_kwargs["max_tokens"] = kwargs["max_tokens"]
+
+    # Prepare extra parameters (NVIDIA specific) for model_kwargs via extra_body
+    # OpenAI API client requires non-standard parameters to be in 'extra_body'
+    # Also request stream options to ensure usage is returned
+    model_kwargs = {
+        "stream_options": {"include_usage": True}
+    }
+    extra_body = {}
+
+    if is_nvidia:
+        if kwargs.get("min_tokens") is not None:
+            extra_body["min_tokens"] = kwargs["min_tokens"]
+        if kwargs.get("ignore_eos") is not None:
+            extra_body["ignore_eos"] = kwargs["ignore_eos"]
+
+        # Handle thinking tokens
+        min_think = kwargs.get("min_thinking_tokens", None)
+        max_think = kwargs.get("max_thinking_tokens", None)
+        if min_think is not None and min_think > 0:
+            extra_body["min_thinking_tokens"] = min_think
+        if max_think is not None and max_think > 0:
+            extra_body["max_thinking_tokens"] = max_think
+
+    if extra_body:
+        model_kwargs["extra_body"] = extra_body
+
+    if model_kwargs:
+        chat_openai_kwargs["model_kwargs"] = model_kwargs
+
+    llm = ChatOpenAI(**chat_openai_kwargs)
+    return llm

     raise RuntimeError(
@@ -419,7 +428,7 @@ def get_streaming_filter_think_parser():

     If FILTER_THINK_TOKENS environment variable is set to "true" (case-insensitive),
     returns a parser that filters out content between <think> and </think> tags.
-    Otherwise, returns a pass-through parser that doesn't modify the content.
+    Otherwise, returns a parser that passes content as-is.

     Returns:
         RunnableGenerator: A parser for filtering (or not filtering) think tokens
@@ -448,7 +457,7 @@ async def streaming_filter_think_async(chunks):
         chunks: Async iterable of chunks from a streaming LLM response

     Yields:
-        str: Filtered content with think blocks removed
+        AIMessageChunk: Filtered content with think blocks removed
     """
     # Complete tags
     FULL_START_TAG = "<think>"
@@ -469,10 +478,16 @@ async def streaming_filter_think_async(chunks):
     buffer = ""
     output_buffer = ""
     chunk_count = 0
+    last_usage = None

     async for chunk in chunks:
         content = chunk.content
         chunk_count += 1
+
+        # Capture usage metadata if present
+        if hasattr(chunk, "usage_metadata") and chunk.usage_metadata:
+            last_usage = chunk.usage_metadata
+            logger.info(f"Usage found in streaming chunk: {last_usage}")

         # Let's first check for full tags - this is the most reliable approach
         buffer += content
@@ -565,15 +580,40 @@ async def streaming_filter_think_async(chunks):

         # Yield accumulated output before processing next chunk
         if output_buffer:
-            yield output_buffer
+            msg = AIMessageChunk(content=output_buffer)
+            if last_usage:
+                msg.usage_metadata = last_usage
+            yield msg
             output_buffer = ""

+    # Handle partial matches at EOF - treat as content
+    if state == MATCHING_START or state == MATCHING_END:
+        output_buffer += buffer
+        buffer = ""
+        state = NORMAL
+
+    emitted_content = False
     # Yield any remaining content if not in a think block
     if state == NORMAL:
         if buffer:
-            yield buffer
+            msg = AIMessageChunk(content=buffer)
+            if last_usage:
+                msg.usage_metadata = last_usage
+            yield msg
+            emitted_content = True
         if output_buffer:
-            yield output_buffer
+            msg = AIMessageChunk(content=output_buffer)
+            if last_usage:
+                msg.usage_metadata = last_usage
+            yield msg
+            emitted_content = True
+
+    if last_usage and not emitted_content:
+        # If we have usage but didn't emit it above (either because everything was filtered
+        # or just end of stream with empty buffer), emit an empty chunk with usage
+        msg = AIMessageChunk(content="")
+        msg.usage_metadata = last_usage
+        yield msg

     logger.info(
         "Finished streaming_filter_think_async processing after %d chunks", chunk_count
@@ -586,7 +626,7 @@ def get_streaming_filter_think_parser_async():

     If FILTER_THINK_TOKENS environment variable is set to "true" (case-insensitive),
     returns a parser that filters out content between <think> and </think> tags.
-    Otherwise, returns a pass-through parser that doesn't modify the content.
+    Otherwise, returns a parser that passes content as-is.

     Returns:
         RunnableGenerator: An async parser for filtering (or not filtering) think tokens
@@ -603,3 +643,4 @@ def get_streaming_filter_think_parser_async():
         logger.info("Think token filtering is disabled (async)")
     # If filtering is disabled, use a passthrough that passes content as-is
     return RunnablePassthrough()
+
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index 65dba580..d7cd969d 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -19,12 +19,17 @@

 import atexit
 import logging
+import os
 import sys
 import types
 from unittest.mock import MagicMock, patch

 import pytest

+# Set a dummy OpenAI API key to prevent initialization errors during test collection
+# This must happen before any imports that might instantiate OpenAI clients
+os.environ.setdefault("OPENAI_API_KEY", "dummy-test-api-key")
+
 # OpenTelemetry imports (optional - may not be available in all environments)
 try:
     from opentelemetry import metrics, trace
diff --git a/uv.lock b/uv.lock
index 28949274..3cdb4ca8 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,5 +1,5 @@
 version = 1
-revision = 2
+revision = 3
 requires-python = ">=3.11, <3.14"
 resolution-markers = [
     "python_full_version >= '3.13'",
@@ -1920,6 +1920,7 @@ dependencies = [
     { name = "langchain-elasticsearch" },
     { name = "langchain-milvus" },
     { name = "langchain-nvidia-ai-endpoints" },
+    { name = "langchain-openai" },
     { name = "lark" },
     { name = "mcp" },
     { name = "minio" },
@@ -1939,7 +1940,6 @@ dependencies = [
 all = [
     { name = "azure-core" },
     { name = "azure-storage-blob" },
-    { name = "langchain-openai" },
     { name = "nv-ingest-api" },
     { name = "nv-ingest-client" },
     { name = "opentelemetry-api" },
@@ -1958,7 +1958,6 @@ all = [
 ingest = [
     { name = "azure-core" },
     { name = "azure-storage-blob" },
-    { name = "langchain-openai" },
     { name = "nv-ingest-api" },
     { name = "nv-ingest-client" },
     { name = "opentelemetry-api" },
@@ -1977,7 +1976,6 @@ ingest = [
 rag = [
     { name = "azure-core" },
     { name = "azure-storage-blob" },
-    { name = "langchain-openai" },
     { name = "opentelemetry-api" },
     { name = "opentelemetry-exporter-otlp" },
     { name = "opentelemetry-exporter-prometheus" },
@@ -2023,9 +2021,7 @@ requires-dist = [
     { name = "langchain-elasticsearch", specifier = ">=0.3,<1.0" },
     { name = "langchain-milvus", specifier = ">=0.2.1" },
     { name = "langchain-nvidia-ai-endpoints", specifier = ">=0.3.18" },
-    { name = "langchain-openai", marker = "extra == 'all'", specifier = ">=0.2,<1.0" },
-    { name = "langchain-openai", marker = "extra == 'ingest'", specifier = ">=0.2,<1.0" },
-    { name = "langchain-openai", marker = "extra == 'rag'", specifier = ">=0.2,<1.0" },
+    { name = "langchain-openai", specifier = ">=0.2,<1.0" },
     { name = "lark", specifier = ">=1.2.2" },
     { name = "mcp", specifier = ">=1.23.1" },
     { name = "minio", specifier = ">=7.2,<8.0" },
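
Illustrative note (not part of the patch): the changes above depend on langchain-openai streaming AIMessageChunk objects that carry usage_metadata once stream_options include_usage is requested. The sketch below shows that flow in isolation, under assumptions — the model name, endpoint URL, and NVIDIA_API_KEY environment variable are placeholders, not values taken from this change.

```python
# Minimal sketch of streamed token-usage capture with langchain-openai.
# Assumes an OpenAI-compatible endpoint; model/base_url/env var are placeholders.
import asyncio
import os

from langchain_openai import ChatOpenAI


async def main() -> None:
    llm = ChatOpenAI(
        model="meta/llama-3.1-8b-instruct",              # placeholder model name
        base_url="https://integrate.api.nvidia.com/v1",  # OpenAI-compatible endpoint
        api_key=os.environ["NVIDIA_API_KEY"],            # assumed env var
        # Ask the server to attach token usage to the stream, mirroring get_llm() above.
        model_kwargs={"stream_options": {"include_usage": True}},
    )

    parts: list[str] = []
    usage = None
    async for chunk in llm.astream("Say hello in one sentence."):
        if chunk.content:
            # Regular content chunks carry the generated text.
            parts.append(chunk.content)
        if getattr(chunk, "usage_metadata", None):
            # Typically the final, content-less chunk:
            # {"input_tokens": ..., "output_tokens": ..., "total_tokens": ...}
            usage = chunk.usage_metadata

    print("".join(parts))
    print(usage)


if __name__ == "__main__":
    asyncio.run(main())
```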