4 changes: 1 addition & 3 deletions pyproject.toml
@@ -26,6 +26,7 @@ dependencies = [
"langchain-community>=0.3.27",
"langchain-milvus>=0.2.1",
"langchain-nvidia-ai-endpoints>=0.3.18",
"langchain-openai>=0.2,<1.0",
"minio>=7.2,<8.0",
"pdfplumber>=0.6",
"pydantic>=2.11,<3.0",
@@ -44,7 +45,6 @@ dependencies = [

[project.optional-dependencies]
rag = [
"langchain-openai>=0.2,<1.0",
"opentelemetry-api>=1.29,<2.0",
"opentelemetry-exporter-otlp>=1.29,<2.0",
"opentelemetry-exporter-prometheus>=0.50b0,<1.0",
@@ -65,7 +65,6 @@ ingest = [
"nv-ingest-client==26.1.0rc5",
"tritonclient==2.57.0",
# Other ingest dependencies
"langchain-openai>=0.2,<1.0",
"overrides>=7.7,<8.0",
"tqdm>=4.67,<5.0",
"opentelemetry-api>=1.29,<2.0",
@@ -86,7 +85,6 @@ all = [
"nv-ingest-client==26.1.0rc5",
"tritonclient==2.57.0",
# RAG + Ingest dependencies
"langchain-openai>=0.2,<1.0",
"overrides>=7.7,<8.0",
"tqdm>=4.67,<5.0",
"opentelemetry-api>=1.29,<2.0",
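Moving `langchain-openai` from the `rag`/`ingest`/`all` extras into the core dependency list means every install profile can construct a chat model whose streamed chunks carry token usage. A minimal sketch of that capability (placeholder model name, an OpenAI-compatible endpoint, and an API key in the environment are all assumptions, not part of this PR):

```python
# Minimal sketch, not taken from this PR: recent langchain-openai releases can
# attach token usage to streamed chunks when stream_usage is enabled.
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini", stream_usage=True)  # placeholder model; needs OPENAI_API_KEY
for chunk in llm.stream("Hello"):
    # Most chunks carry only text; usage_metadata typically arrives on the final chunk.
    if chunk.usage_metadata:
        print(chunk.usage_metadata)  # {'input_tokens': ..., 'output_tokens': ..., 'total_tokens': ...}
```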
24 changes: 17 additions & 7 deletions src/nvidia_rag/rag_server/main.py
@@ -48,7 +48,7 @@
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.prompts import MessagesPlaceholder
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.runnables import RunnableAssign
from langchain_core.runnables import RunnableAssign, RunnableGenerator
Review comment (Collaborator, Author): Check again if still valid

from opentelemetry import context as otel_context
from requests import ConnectTimeout

@@ -1408,14 +1408,15 @@ async def _llm_chain(
prompt_template = ChatPromptTemplate.from_messages(message)
llm = get_llm(config=self.config, **llm_settings)

chain = (
# Chain for streaming: we remove StrOutputParser so we yield AIMessageChunks,
# allowing us to access .usage_metadata in the response generator.
stream_chain = (
prompt_template
| llm
| self.StreamingFilterThinkParser
| StrOutputParser()
)
# Create async stream generator
stream_gen = chain.astream(
stream_gen = stream_chain.astream(
{"question": query_text}, config={"run_name": "llm-stream"}
)
# Eagerly fetch first chunk to trigger any errors before returning response
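For context on why `StrOutputParser` is removed from the streaming chain above: the parser flattens each `AIMessageChunk` into a plain string, and the attached `usage_metadata` is lost in the process. A small self-contained illustration:

```python
# Self-contained illustration: StrOutputParser flattens an AIMessageChunk to a
# plain str, so anything downstream loses the attached usage_metadata.
from langchain_core.messages import AIMessageChunk
from langchain_core.output_parsers.string import StrOutputParser

chunk = AIMessageChunk(
    content="Hello",
    usage_metadata={"input_tokens": 12, "output_tokens": 3, "total_tokens": 15},
)
as_text = StrOutputParser().invoke(chunk)
print(as_text)                              # 'Hello' -- a bare string, usage is gone
print(chunk.content, chunk.usage_metadata)  # the raw chunk keeps both content and usage
```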
@@ -2437,7 +2438,16 @@ def generate_filter_for_collection(collection_name):
self._print_conversation_history(message)
prompt = ChatPromptTemplate.from_messages(message)

chain = prompt | llm | self.StreamingFilterThinkParser | StrOutputParser()
# Base chain (no usage sentinel) used for non-streaming reflection path.
# We keep StrOutputParser here because the downstream logic needs the full string response.
base_chain = prompt | llm | self.StreamingFilterThinkParser | StrOutputParser()
Review comment (Collaborator, Author): Check again if still valid

# Streaming chain: yields AIMessageChunks to preserve usage metadata.
stream_chain = (
prompt
| llm
| self.StreamingFilterThinkParser
)

# Check response groundedness if we still have reflection
# iterations available
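The split above gives the handler two paths: `base_chain` produces the complete string that the reflection/groundedness logic inspects, while `stream_chain` yields raw message chunks so usage metadata survives. A self-contained sketch of the same idea, with a `RunnableLambda` standing in for the real LLM (everything here is illustrative, not the project's code):

```python
# Self-contained sketch of the two-path idea; a RunnableLambda stands in for
# the real LLM, so everything here is illustrative.
import asyncio

from langchain_core.messages import AIMessage
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda

prompt = ChatPromptTemplate.from_messages([("user", "{question}")])
fake_llm = RunnableLambda(lambda _prompt_value: AIMessage(content="grounded answer"))

base_chain = prompt | fake_llm | StrOutputParser()  # full string for the reflection logic
stream_chain = prompt | fake_llm                    # raw messages/chunks for streaming

async def demo() -> None:
    full_answer = await base_chain.ainvoke({"question": "hi"})
    print(type(full_answer).__name__, full_answer)  # str grounded answer
    async for chunk in stream_chain.astream({"question": "hi"}):
        # A real chat model yields AIMessageChunks here, with usage_metadata attached.
        print(type(chunk).__name__, chunk.content)

asyncio.run(demo())
```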
Expand Down Expand Up @@ -2496,8 +2506,8 @@ def generate_filter_for_collection(collection_name):
status_code=ErrorCodeMapping.SUCCESS,
)
else:
# Create async stream generator
stream_gen = chain.astream(
# Create async stream generator using the streaming chain
stream_gen = stream_chain.astream(
{"question": query, "context": docs},
config={"run_name": "llm-stream"},
)
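Both streaming branches in this file hand the generator off after eagerly pulling the first chunk (see the "Eagerly fetch first chunk" comment above); the project's exact code is collapsed in this diff, so the following is only an assumed sketch of that pattern:

```python
# Assumed sketch of the eager-first-chunk pattern: awaiting the first chunk up
# front surfaces connection/auth errors before the HTTP response is returned,
# instead of failing mid-stream.
import asyncio
from collections.abc import AsyncGenerator

async def eager_stream(stream_gen: AsyncGenerator) -> AsyncGenerator:
    first_chunk = await anext(stream_gen)  # raises here if the LLM call fails

    async def _rejoined() -> AsyncGenerator:
        yield first_chunk
        async for chunk in stream_gen:
            yield chunk

    return _rejoined()

async def _demo() -> None:
    async def numbers():
        for i in range(3):
            yield i

    gen = await eager_stream(numbers())
    print([x async for x in gen])  # [0, 1, 2]

asyncio.run(_demo())
```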
90 changes: 84 additions & 6 deletions src/nvidia_rag/rag_server/response_generator.py
@@ -443,17 +443,54 @@ def generate_answer(
llm_ttft_ms: float | None = None
rag_ttft_ms: float | None = None
llm_generation_time_ms: float | None = None
usage: Usage | None = None
for chunk in generator:
# Handle usage metadata if present (AIMessageChunk)
if hasattr(chunk, "usage_metadata") and chunk.usage_metadata:
usage_dict = chunk.usage_metadata
try:
prompt_tokens = int(usage_dict.get("input_tokens", 0))
completion_tokens = int(usage_dict.get("output_tokens", 0))
total_tokens = int(
usage_dict.get(
"total_tokens", prompt_tokens + completion_tokens
)
)
usage = Usage(
total_tokens=total_tokens,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
logger.info(
"LLM usage for model %s (sync): prompt_tokens=%d, completion_tokens=%d, total_tokens=%d, raw=%s",
model,
prompt_tokens,
completion_tokens,
total_tokens,
usage_dict,
)
except Exception as e:
logger.debug("Failed to parse usage metadata: %s", e)

# Extract content
content = chunk
if hasattr(chunk, "content"):
content = chunk.content

# Skip empty content (e.g. usage-only chunks)
if not content:
continue

# TODO: This is a hack to clear contexts if we get an error
# response from nemoguardrails
if chunk == "I'm sorry, I can't respond to that.":
if content == "I'm sorry, I can't respond to that.":
# Clear contexts if we get an error response
contexts = []
chain_response = ChainResponse()
response_choice = ChainResponseChoices(
index=0,
message=Message(role="assistant", content=chunk),
delta=Message(role=None, content=chunk),
message=Message(role="assistant", content=content),
delta=Message(role=None, content=content),
finish_reason=None,
)
chain_response.id = resp_id
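The parsing above maps langchain-core's `UsageMetadata` keys onto the OpenAI-style names used by the server's `Usage` model; plain string chunks (for example from guardrails) carry no `usage_metadata` and skip the branch entirely. The key mapping in isolation, with illustrative values:

```python
# Illustrative values; UsageMetadata comes from langchain-core, while Usage is
# the server's own OpenAI-style response model (prompt/completion naming).
from langchain_core.messages.ai import UsageMetadata

usage_dict: UsageMetadata = {"input_tokens": 412, "output_tokens": 96, "total_tokens": 508}
prompt_tokens = int(usage_dict.get("input_tokens", 0))       # -> Usage.prompt_tokens
completion_tokens = int(usage_dict.get("output_tokens", 0))  # -> Usage.completion_tokens
total_tokens = int(usage_dict.get("total_tokens", prompt_tokens + completion_tokens))
print(prompt_tokens, completion_tokens, total_tokens)        # 412 96 508
```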
@@ -519,6 +556,8 @@ def generate_answer(
# Create response first, then attach metrics for clarity
chain_response = ChainResponse()
chain_response.metrics = final_metrics
if usage is not None:
chain_response.usage = usage

# [DONE] indicate end of response from server
response_choice = ChainResponseChoices(
@@ -591,17 +630,54 @@ async def generate_answer_async(
llm_ttft_ms: float | None = None
rag_ttft_ms: float | None = None
llm_generation_time_ms: float | None = None
usage: Usage | None = None
async for chunk in generator:
# Handle usage metadata if present (AIMessageChunk)
if hasattr(chunk, "usage_metadata") and chunk.usage_metadata:
usage_dict = chunk.usage_metadata
try:
prompt_tokens = int(usage_dict.get("input_tokens", 0))
completion_tokens = int(usage_dict.get("output_tokens", 0))
total_tokens = int(
usage_dict.get(
"total_tokens", prompt_tokens + completion_tokens
)
)
usage = Usage(
total_tokens=total_tokens,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
logger.info(
"LLM usage for model %s (async): prompt_tokens=%d, completion_tokens=%d, total_tokens=%d, raw=%s",
model,
prompt_tokens,
completion_tokens,
total_tokens,
usage_dict,
)
except Exception as e:
logger.debug("Failed to parse usage metadata: %s", e)

# Extract content
content = chunk
if hasattr(chunk, "content"):
content = chunk.content

# Skip empty content (e.g. usage-only chunks)
if not content:
continue

# TODO: This is a hack to clear contexts if we get an error
# response from nemoguardrails
if chunk == "I'm sorry, I can't respond to that.":
if content == "I'm sorry, I can't respond to that.":
# Clear contexts if we get an error response
contexts = []
chain_response = ChainResponse()
response_choice = ChainResponseChoices(
index=0,
message=Message(role="assistant", content=chunk),
delta=Message(role=None, content=chunk),
message=Message(role="assistant", content=content),
delta=Message(role=None, content=content),
finish_reason=None,
)
chain_response.id = resp_id
@@ -667,6 +743,8 @@ async def generate_answer_async(
# Create response first, then attach metrics for clarity
chain_response = ChainResponse()
chain_response.metrics = final_metrics
if usage is not None:
chain_response.usage = usage

# [DONE] indicate end of response from server
response_choice = ChainResponseChoices(
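The sync and async loops parse usage identically, so the shared logic can be exercised on its own. A standalone sketch (not part of this PR); `UsageStub` stands in for the server's `Usage` model:

```python
# Standalone sketch; UsageStub stands in for the server's Usage model so the
# shared parsing logic can run in isolation.
from dataclasses import dataclass
from typing import Any

from langchain_core.messages import AIMessageChunk

@dataclass
class UsageStub:
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

def parse_usage(chunk: Any) -> UsageStub | None:
    """Return token usage if the chunk carries usage_metadata, else None."""
    meta = getattr(chunk, "usage_metadata", None)
    if not meta:
        return None
    prompt = int(meta.get("input_tokens", 0))
    completion = int(meta.get("output_tokens", 0))
    return UsageStub(
        prompt_tokens=prompt,
        completion_tokens=completion,
        total_tokens=int(meta.get("total_tokens", prompt + completion)),
    )

chunk = AIMessageChunk(
    content="",
    usage_metadata={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
)
print(parse_usage(chunk))         # UsageStub(prompt_tokens=10, completion_tokens=5, total_tokens=15)
print(parse_usage("plain text"))  # None -- string chunks have no usage_metadata
```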