
Commit 0666571

chore(perplexity): Added all keys for usage metadata (#33480)
1 parent: ef85161

2 files changed (+267 -3 lines)

2 files changed

+267
-3
lines changed

libs/partners/perplexity/langchain_perplexity/chat_models.py

Lines changed: 39 additions & 3 deletions
@@ -28,7 +28,11 @@
     SystemMessageChunk,
     ToolMessageChunk,
 )
-from langchain_core.messages.ai import UsageMetadata, subtract_usage
+from langchain_core.messages.ai import (
+    OutputTokenDetails,
+    UsageMetadata,
+    subtract_usage,
+)
 from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
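
The newly imported OutputTokenDetails is a TypedDict from langchain_core.messages.ai with all-optional keys, which is why the helper below can start from an empty dict. A small sketch of the shape (the all-optional/total=False detail is stated from knowledge of langchain_core, not from this diff):

from langchain_core.messages.ai import OutputTokenDetails

# All keys are optional, so an empty dict type-checks;
# "reasoning" is one of the standard keys.
details: OutputTokenDetails = {"reasoning": 5}
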
@@ -49,13 +53,28 @@ def _is_pydantic_class(obj: Any) -> bool:
 
 
 def _create_usage_metadata(token_usage: dict) -> UsageMetadata:
+    """Create UsageMetadata from Perplexity token usage data.
+
+    Args:
+        token_usage: Dictionary containing token usage information from Perplexity API.
+
+    Returns:
+        UsageMetadata with properly structured token counts and details.
+    """
     input_tokens = token_usage.get("prompt_tokens", 0)
     output_tokens = token_usage.get("completion_tokens", 0)
     total_tokens = token_usage.get("total_tokens", input_tokens + output_tokens)
+
+    # Build output_token_details for Perplexity-specific fields
+    output_token_details: OutputTokenDetails = {}
+    output_token_details["reasoning"] = token_usage.get("reasoning_tokens", 0)
+    output_token_details["citation_tokens"] = token_usage.get("citation_tokens", 0)  # type: ignore[typeddict-unknown-key]
+
     return UsageMetadata(
         input_tokens=input_tokens,
         output_tokens=output_tokens,
         total_tokens=total_tokens,
+        output_token_details=output_token_details,
     )
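
For orientation, a minimal sketch of what the reworked helper returns, derived from the hunk above; the payload values are illustrative, not taken from the diff:

from langchain_perplexity.chat_models import _create_usage_metadata

# Illustrative Perplexity-style usage payload (values made up for the example)
token_usage = {
    "prompt_tokens": 10,
    "completion_tokens": 20,
    "total_tokens": 30,
    "reasoning_tokens": 5,
    "citation_tokens": 2,
}

usage_metadata = _create_usage_metadata(token_usage)
# output_token_details now carries the standard "reasoning" key plus the
# Perplexity-specific "citation_tokens" key:
assert usage_metadata["output_token_details"] == {"reasoning": 5, "citation_tokens": 2}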

@@ -301,6 +320,7 @@ def _stream(
         prev_total_usage: UsageMetadata | None = None
 
         added_model_name: bool = False
+        added_search_queries: bool = False
         for chunk in stream_resp:
             if not isinstance(chunk, dict):
                 chunk = chunk.model_dump()
@@ -332,6 +352,13 @@
                     generation_info["model_name"] = model_name
                     added_model_name = True
 
+            # Add num_search_queries to generation_info if present
+            if total_usage := chunk.get("usage"):
+                if num_search_queries := total_usage.get("num_search_queries"):
+                    if not added_search_queries:
+                        generation_info["num_search_queries"] = num_search_queries
+                        added_search_queries = True
+
             chunk = self._convert_delta_to_message_chunk(
                 choice["delta"], default_chunk_class
             )
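
On the consumer side, these usage details arrive on the streamed message chunks. A hedged sketch of how a caller might read them, mirroring the new streaming test at the bottom of this commit (llm is assumed to be an already-configured ChatPerplexity instance):

# Scan the stream for the chunk carrying usage metadata (typically the last one)
usage_chunk = None
for chunk in llm.stream("Test query"):
    if chunk.usage_metadata:
        usage_chunk = chunk

if usage_chunk is not None:
    details = usage_chunk.usage_metadata["output_token_details"]
    print(details.get("reasoning"), details.get("citation_tokens"))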
@@ -369,20 +396,29 @@ def _generate(
         params = {**params, **kwargs}
         response = self.client.chat.completions.create(messages=message_dicts, **params)
         if usage := getattr(response, "usage", None):
-            usage_metadata = _create_usage_metadata(usage.model_dump())
+            usage_dict = usage.model_dump()
+            usage_metadata = _create_usage_metadata(usage_dict)
         else:
             usage_metadata = None
+            usage_dict = {}
 
         additional_kwargs = {}
         for attr in ["citations", "images", "related_questions", "search_results"]:
             if hasattr(response, attr):
                 additional_kwargs[attr] = getattr(response, attr)
 
+        # Build response_metadata with model_name and num_search_queries
+        response_metadata: dict[str, Any] = {
+            "model_name": getattr(response, "model", self.model)
+        }
+        if num_search_queries := usage_dict.get("num_search_queries"):
+            response_metadata["num_search_queries"] = num_search_queries
+
         message = AIMessage(
             content=response.choices[0].message.content,
             additional_kwargs=additional_kwargs,
             usage_metadata=usage_metadata,
-            response_metadata={"model_name": getattr(response, "model", self.model)},
+            response_metadata=response_metadata,
         )
         return ChatResult(generations=[ChatGeneration(message=message)])
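
The non-streaming counterpart, exercised by test_perplexity_invoke_includes_num_search_queries below: when the API's usage block reports search queries, they land in response_metadata next to the model name. A brief sketch (the model name and prompt are placeholders):

from langchain_perplexity import ChatPerplexity

llm = ChatPerplexity(model="sonar")  # placeholder model name
result = llm.invoke("Test query")

print(result.response_metadata["model_name"])
# Present only when the Perplexity response reported it:
print(result.response_metadata.get("num_search_queries"))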

libs/partners/perplexity/tests/unit_tests/test_chat_models.py

Lines changed: 228 additions & 0 deletions
@@ -5,6 +5,7 @@
 from pytest_mock import MockerFixture
 
 from langchain_perplexity import ChatPerplexity
+from langchain_perplexity.chat_models import _create_usage_metadata
 
 
 def test_perplexity_model_name_param() -> None:
@@ -295,3 +296,230 @@ def test_perplexity_stream_includes_citations_and_search_results(
     }
 
     patcher.assert_called_once()
+
+
+def test_create_usage_metadata_basic() -> None:
+    """Test _create_usage_metadata with basic token counts."""
+    token_usage = {
+        "prompt_tokens": 10,
+        "completion_tokens": 20,
+        "total_tokens": 30,
+    }
+
+    usage_metadata = _create_usage_metadata(token_usage)
+
+    assert usage_metadata["input_tokens"] == 10
+    assert usage_metadata["output_tokens"] == 20
+    assert usage_metadata["total_tokens"] == 30
+    assert usage_metadata["output_token_details"]["reasoning"] == 0
+    assert usage_metadata["output_token_details"]["citation_tokens"] == 0  # type: ignore[typeddict-item]
+
+
+def test_create_usage_metadata_with_reasoning_tokens() -> None:
+    """Test _create_usage_metadata with reasoning tokens."""
+    token_usage = {
+        "prompt_tokens": 50,
+        "completion_tokens": 100,
+        "total_tokens": 150,
+        "reasoning_tokens": 25,
+    }
+
+    usage_metadata = _create_usage_metadata(token_usage)
+
+    assert usage_metadata["input_tokens"] == 50
+    assert usage_metadata["output_tokens"] == 100
+    assert usage_metadata["total_tokens"] == 150
+    assert usage_metadata["output_token_details"]["reasoning"] == 25
+    assert usage_metadata["output_token_details"]["citation_tokens"] == 0  # type: ignore[typeddict-item]
+
+
+def test_create_usage_metadata_with_citation_tokens() -> None:
+    """Test _create_usage_metadata with citation tokens."""
+    token_usage = {
+        "prompt_tokens": 100,
+        "completion_tokens": 200,
+        "total_tokens": 300,
+        "citation_tokens": 15,
+    }
+
+    usage_metadata = _create_usage_metadata(token_usage)
+
+    assert usage_metadata["input_tokens"] == 100
+    assert usage_metadata["output_tokens"] == 200
+    assert usage_metadata["total_tokens"] == 300
+    assert usage_metadata["output_token_details"]["reasoning"] == 0
+    assert usage_metadata["output_token_details"]["citation_tokens"] == 15  # type: ignore[typeddict-item]
+
+
+def test_create_usage_metadata_with_all_token_types() -> None:
+    """Test _create_usage_metadata with all token types.
+
+    Tests reasoning tokens and citation tokens together.
+    """
+    token_usage = {
+        "prompt_tokens": 75,
+        "completion_tokens": 150,
+        "total_tokens": 225,
+        "reasoning_tokens": 30,
+        "citation_tokens": 20,
+    }
+
+    usage_metadata = _create_usage_metadata(token_usage)
+
+    assert usage_metadata["input_tokens"] == 75
+    assert usage_metadata["output_tokens"] == 150
+    assert usage_metadata["total_tokens"] == 225
+    assert usage_metadata["output_token_details"]["reasoning"] == 30
+    assert usage_metadata["output_token_details"]["citation_tokens"] == 20  # type: ignore[typeddict-item]
+
+
+def test_create_usage_metadata_missing_optional_fields() -> None:
+    """Test _create_usage_metadata with missing optional fields defaults to 0."""
+    token_usage = {
+        "prompt_tokens": 25,
+        "completion_tokens": 50,
+    }
+
+    usage_metadata = _create_usage_metadata(token_usage)
+
+    assert usage_metadata["input_tokens"] == 25
+    assert usage_metadata["output_tokens"] == 50
+    # Total tokens should be calculated if not provided
+    assert usage_metadata["total_tokens"] == 75
+    assert usage_metadata["output_token_details"]["reasoning"] == 0
+    assert usage_metadata["output_token_details"]["citation_tokens"] == 0  # type: ignore[typeddict-item]
+
+
+def test_create_usage_metadata_empty_dict() -> None:
+    """Test _create_usage_metadata with empty token usage dict."""
+    token_usage: dict = {}
+
+    usage_metadata = _create_usage_metadata(token_usage)
+
+    assert usage_metadata["input_tokens"] == 0
+    assert usage_metadata["output_tokens"] == 0
+    assert usage_metadata["total_tokens"] == 0
+    assert usage_metadata["output_token_details"]["reasoning"] == 0
+    assert usage_metadata["output_token_details"]["citation_tokens"] == 0  # type: ignore[typeddict-item]
+
+
+def test_perplexity_invoke_includes_num_search_queries(mocker: MockerFixture) -> None:
+    """Test that invoke includes num_search_queries in response_metadata."""
+    llm = ChatPerplexity(model="test", timeout=30, verbose=True)
+
+    mock_usage = MagicMock()
+    mock_usage.model_dump.return_value = {
+        "prompt_tokens": 10,
+        "completion_tokens": 20,
+        "total_tokens": 30,
+        "num_search_queries": 3,
+    }
+
+    mock_response = MagicMock()
+    mock_response.choices = [
+        MagicMock(
+            message=MagicMock(
+                content="Test response",
+                tool_calls=None,
+            ),
+            finish_reason="stop",
+        )
+    ]
+    mock_response.model = "test-model"
+    mock_response.usage = mock_usage
+
+    patcher = mocker.patch.object(
+        llm.client.chat.completions, "create", return_value=mock_response
+    )
+
+    result = llm.invoke("Test query")
+
+    assert result.response_metadata["num_search_queries"] == 3
+    assert result.response_metadata["model_name"] == "test-model"
+    patcher.assert_called_once()
+
+
+def test_perplexity_invoke_without_num_search_queries(mocker: MockerFixture) -> None:
+    """Test that invoke works when num_search_queries is not provided."""
+    llm = ChatPerplexity(model="test", timeout=30, verbose=True)
+
+    mock_usage = MagicMock()
+    mock_usage.model_dump.return_value = {
+        "prompt_tokens": 10,
+        "completion_tokens": 20,
+        "total_tokens": 30,
+    }
+
+    mock_response = MagicMock()
+    mock_response.choices = [
+        MagicMock(
+            message=MagicMock(
+                content="Test response",
+                tool_calls=None,
+            ),
+            finish_reason="stop",
+        )
+    ]
+    mock_response.model = "test-model"
+    mock_response.usage = mock_usage
+
+    patcher = mocker.patch.object(
+        llm.client.chat.completions, "create", return_value=mock_response
+    )
+
+    result = llm.invoke("Test query")
+
+    assert "num_search_queries" not in result.response_metadata
+    assert result.response_metadata["model_name"] == "test-model"
+    patcher.assert_called_once()
+
+
+def test_perplexity_stream_includes_num_search_queries(mocker: MockerFixture) -> None:
+    """Test that stream properly handles num_search_queries in usage data."""
+    llm = ChatPerplexity(model="test", timeout=30, verbose=True)
+
+    mock_chunk_0 = {
+        "choices": [{"delta": {"content": "Hello "}, "finish_reason": None}],
+    }
+    mock_chunk_1 = {
+        "choices": [{"delta": {"content": "world"}, "finish_reason": None}],
+    }
+    mock_chunk_2 = {
+        "choices": [{"delta": {}, "finish_reason": "stop"}],
+        "usage": {
+            "prompt_tokens": 5,
+            "completion_tokens": 10,
+            "total_tokens": 15,
+            "num_search_queries": 2,
+            "reasoning_tokens": 1,
+            "citation_tokens": 3,
+        },
+    }
+    mock_chunks: list[dict[str, Any]] = [mock_chunk_0, mock_chunk_1, mock_chunk_2]
+    mock_stream = MagicMock()
+    mock_stream.__iter__.return_value = mock_chunks
+
+    patcher = mocker.patch.object(
+        llm.client.chat.completions, "create", return_value=mock_stream
+    )
+
+    chunks_list = list(llm.stream("Test query"))
+
+    # Find the chunk with usage metadata
+    usage_chunk = None
+    for chunk in chunks_list:
+        if chunk.usage_metadata:
+            usage_chunk = chunk
+            break
+
+    # Verify usage metadata is properly set
+    assert usage_chunk is not None
+    assert usage_chunk.usage_metadata is not None
+    assert usage_chunk.usage_metadata["input_tokens"] == 5
+    assert usage_chunk.usage_metadata["output_tokens"] == 10
+    assert usage_chunk.usage_metadata["total_tokens"] == 15
+    # Verify reasoning and citation tokens are included
+    assert usage_chunk.usage_metadata["output_token_details"]["reasoning"] == 1
+    assert usage_chunk.usage_metadata["output_token_details"]["citation_tokens"] == 3  # type: ignore[typeddict-item]
+
+    patcher.assert_called_once()
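
To run just the new tests locally, an invocation along these lines should work from libs/partners/perplexity (assuming a standard pytest setup; the repo's own Makefile or uv targets may differ):

pytest tests/unit_tests/test_chat_models.py -k "create_usage_metadata or num_search_queries"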
