Merged
42 changes: 39 additions & 3 deletions libs/partners/perplexity/langchain_perplexity/chat_models.py
@@ -28,7 +28,11 @@
SystemMessageChunk,
ToolMessageChunk,
)
from langchain_core.messages.ai import UsageMetadata, subtract_usage
from langchain_core.messages.ai import (
OutputTokenDetails,
UsageMetadata,
subtract_usage,
)
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
@@ -49,13 +53,28 @@ def _is_pydantic_class(obj: Any) -> bool:


def _create_usage_metadata(token_usage: dict) -> UsageMetadata:
"""Create UsageMetadata from Perplexity token usage data.

Args:
token_usage: Dictionary containing token usage information from Perplexity API.

Returns:
UsageMetadata with properly structured token counts and details.
"""
input_tokens = token_usage.get("prompt_tokens", 0)
output_tokens = token_usage.get("completion_tokens", 0)
total_tokens = token_usage.get("total_tokens", input_tokens + output_tokens)

# Build output_token_details for Perplexity-specific fields
output_token_details: OutputTokenDetails = {}
output_token_details["reasoning"] = token_usage.get("reasoning_tokens", 0)
output_token_details["citation_tokens"] = token_usage.get("citation_tokens", 0) # type: ignore[typeddict-unknown-key]
Collaborator comment on lines +69 to +71: (nit) might make sense to only populate the keys if they are present in token_usage. (A sketch of that variant follows this hunk.)


return UsageMetadata(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=total_tokens,
output_token_details=output_token_details,
)
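
Taking up the reviewer's nit, here is a minimal sketch of what conditional population could look like. This is a hypothetical variant for illustration only, not the merged code; the name _create_usage_metadata_sparse is made up.

from langchain_core.messages.ai import OutputTokenDetails, UsageMetadata


def _create_usage_metadata_sparse(token_usage: dict) -> UsageMetadata:
    """Hypothetical variant: only set detail keys the API actually returned."""
    input_tokens = token_usage.get("prompt_tokens", 0)
    output_tokens = token_usage.get("completion_tokens", 0)
    total_tokens = token_usage.get("total_tokens", input_tokens + output_tokens)

    output_token_details: OutputTokenDetails = {}
    if "reasoning_tokens" in token_usage:
        output_token_details["reasoning"] = token_usage["reasoning_tokens"]
    if "citation_tokens" in token_usage:
        # citation_tokens is Perplexity-specific, hence the typeddict ignore
        output_token_details["citation_tokens"] = token_usage["citation_tokens"]  # type: ignore[typeddict-unknown-key]

    return UsageMetadata(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=total_tokens,
        output_token_details=output_token_details,
    )

The tradeoff: with this variant, downstream code reading output_token_details would need .get() lookups, since the detail keys may be absent; the merged code instead defaults them to 0.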


@@ -301,6 +320,7 @@ def _stream(
prev_total_usage: UsageMetadata | None = None

added_model_name: bool = False
added_search_queries: bool = False
for chunk in stream_resp:
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
@@ -332,6 +352,13 @@
generation_info["model_name"] = model_name
added_model_name = True

# Add num_search_queries to generation_info if present
if total_usage := chunk.get("usage"):
if num_search_queries := total_usage.get("num_search_queries"):
if not added_search_queries:
generation_info["num_search_queries"] = num_search_queries
added_search_queries = True

chunk = self._convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
@@ -369,20 +396,29 @@ def _generate(
params = {**params, **kwargs}
response = self.client.chat.completions.create(messages=message_dicts, **params)
if usage := getattr(response, "usage", None):
usage_metadata = _create_usage_metadata(usage.model_dump())
usage_dict = usage.model_dump()
usage_metadata = _create_usage_metadata(usage_dict)
else:
usage_metadata = None
usage_dict = {}

additional_kwargs = {}
for attr in ["citations", "images", "related_questions", "search_results"]:
if hasattr(response, attr):
additional_kwargs[attr] = getattr(response, attr)

# Build response_metadata with model_name and num_search_queries
response_metadata: dict[str, Any] = {
"model_name": getattr(response, "model", self.model)
}
if num_search_queries := usage_dict.get("num_search_queries"):
response_metadata["num_search_queries"] = num_search_queries

message = AIMessage(
content=response.choices[0].message.content,
additional_kwargs=additional_kwargs,
usage_metadata=usage_metadata,
response_metadata={"model_name": getattr(response, "model", self.model)},
response_metadata=response_metadata,
)
return ChatResult(generations=[ChatGeneration(message=message)])
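
With these _generate changes in place, callers can read the search-query count and token details off the returned message. A minimal usage sketch, assuming a Perplexity API key is configured in the environment and that "sonar" is an available model name (both assumptions, not part of this diff):

from langchain_perplexity import ChatPerplexity

llm = ChatPerplexity(model="sonar", timeout=30)  # "sonar" is an assumed model name
result = llm.invoke("How many moons does Jupiter have?")

# model_name is always set; num_search_queries appears only when the API reports it
print(result.response_metadata["model_name"])
print(result.response_metadata.get("num_search_queries"))

# usage_metadata now carries reasoning/citation token details as well
print(result.usage_metadata)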

228 changes: 228 additions & 0 deletions libs/partners/perplexity/tests/unit_tests/test_chat_models.py
@@ -5,6 +5,7 @@
from pytest_mock import MockerFixture

from langchain_perplexity import ChatPerplexity
from langchain_perplexity.chat_models import _create_usage_metadata


def test_perplexity_model_name_param() -> None:
@@ -295,3 +296,230 @@ def test_perplexity_stream_includes_citations_and_search_results(
}

patcher.assert_called_once()


def test_create_usage_metadata_basic() -> None:
"""Test _create_usage_metadata with basic token counts."""
token_usage = {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30,
}

usage_metadata = _create_usage_metadata(token_usage)

assert usage_metadata["input_tokens"] == 10
assert usage_metadata["output_tokens"] == 20
assert usage_metadata["total_tokens"] == 30
assert usage_metadata["output_token_details"]["reasoning"] == 0
assert usage_metadata["output_token_details"]["citation_tokens"] == 0 # type: ignore[typeddict-item]


def test_create_usage_metadata_with_reasoning_tokens() -> None:
"""Test _create_usage_metadata with reasoning tokens."""
token_usage = {
"prompt_tokens": 50,
"completion_tokens": 100,
"total_tokens": 150,
"reasoning_tokens": 25,
}

usage_metadata = _create_usage_metadata(token_usage)

assert usage_metadata["input_tokens"] == 50
assert usage_metadata["output_tokens"] == 100
assert usage_metadata["total_tokens"] == 150
assert usage_metadata["output_token_details"]["reasoning"] == 25
assert usage_metadata["output_token_details"]["citation_tokens"] == 0 # type: ignore[typeddict-item]


def test_create_usage_metadata_with_citation_tokens() -> None:
"""Test _create_usage_metadata with citation tokens."""
token_usage = {
"prompt_tokens": 100,
"completion_tokens": 200,
"total_tokens": 300,
"citation_tokens": 15,
}

usage_metadata = _create_usage_metadata(token_usage)

assert usage_metadata["input_tokens"] == 100
assert usage_metadata["output_tokens"] == 200
assert usage_metadata["total_tokens"] == 300
assert usage_metadata["output_token_details"]["reasoning"] == 0
assert usage_metadata["output_token_details"]["citation_tokens"] == 15 # type: ignore[typeddict-item]


def test_create_usage_metadata_with_all_token_types() -> None:
"""Test _create_usage_metadata with all token types.

Tests reasoning tokens and citation tokens together.
"""
token_usage = {
"prompt_tokens": 75,
"completion_tokens": 150,
"total_tokens": 225,
"reasoning_tokens": 30,
"citation_tokens": 20,
}

usage_metadata = _create_usage_metadata(token_usage)

assert usage_metadata["input_tokens"] == 75
assert usage_metadata["output_tokens"] == 150
assert usage_metadata["total_tokens"] == 225
assert usage_metadata["output_token_details"]["reasoning"] == 30
assert usage_metadata["output_token_details"]["citation_tokens"] == 20 # type: ignore[typeddict-item]


def test_create_usage_metadata_missing_optional_fields() -> None:
"""Test _create_usage_metadata with missing optional fields defaults to 0."""
token_usage = {
"prompt_tokens": 25,
"completion_tokens": 50,
}

usage_metadata = _create_usage_metadata(token_usage)

assert usage_metadata["input_tokens"] == 25
assert usage_metadata["output_tokens"] == 50
# Total tokens should be calculated if not provided
assert usage_metadata["total_tokens"] == 75
assert usage_metadata["output_token_details"]["reasoning"] == 0
assert usage_metadata["output_token_details"]["citation_tokens"] == 0 # type: ignore[typeddict-item]


def test_create_usage_metadata_empty_dict() -> None:
"""Test _create_usage_metadata with empty token usage dict."""
token_usage: dict = {}

usage_metadata = _create_usage_metadata(token_usage)

assert usage_metadata["input_tokens"] == 0
assert usage_metadata["output_tokens"] == 0
assert usage_metadata["total_tokens"] == 0
assert usage_metadata["output_token_details"]["reasoning"] == 0
assert usage_metadata["output_token_details"]["citation_tokens"] == 0 # type: ignore[typeddict-item]


def test_perplexity_invoke_includes_num_search_queries(mocker: MockerFixture) -> None:
"""Test that invoke includes num_search_queries in response_metadata."""
llm = ChatPerplexity(model="test", timeout=30, verbose=True)

mock_usage = MagicMock()
mock_usage.model_dump.return_value = {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30,
"num_search_queries": 3,
}

mock_response = MagicMock()
mock_response.choices = [
MagicMock(
message=MagicMock(
content="Test response",
tool_calls=None,
),
finish_reason="stop",
)
]
mock_response.model = "test-model"
mock_response.usage = mock_usage

patcher = mocker.patch.object(
llm.client.chat.completions, "create", return_value=mock_response
)

result = llm.invoke("Test query")

assert result.response_metadata["num_search_queries"] == 3
assert result.response_metadata["model_name"] == "test-model"
patcher.assert_called_once()


def test_perplexity_invoke_without_num_search_queries(mocker: MockerFixture) -> None:
"""Test that invoke works when num_search_queries is not provided."""
llm = ChatPerplexity(model="test", timeout=30, verbose=True)

mock_usage = MagicMock()
mock_usage.model_dump.return_value = {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30,
}

mock_response = MagicMock()
mock_response.choices = [
MagicMock(
message=MagicMock(
content="Test response",
tool_calls=None,
),
finish_reason="stop",
)
]
mock_response.model = "test-model"
mock_response.usage = mock_usage

patcher = mocker.patch.object(
llm.client.chat.completions, "create", return_value=mock_response
)

result = llm.invoke("Test query")

assert "num_search_queries" not in result.response_metadata
assert result.response_metadata["model_name"] == "test-model"
patcher.assert_called_once()


def test_perplexity_stream_includes_num_search_queries(mocker: MockerFixture) -> None:
"""Test that stream properly handles num_search_queries in usage data."""
llm = ChatPerplexity(model="test", timeout=30, verbose=True)

mock_chunk_0 = {
"choices": [{"delta": {"content": "Hello "}, "finish_reason": None}],
}
mock_chunk_1 = {
"choices": [{"delta": {"content": "world"}, "finish_reason": None}],
}
mock_chunk_2 = {
"choices": [{"delta": {}, "finish_reason": "stop"}],
"usage": {
"prompt_tokens": 5,
"completion_tokens": 10,
"total_tokens": 15,
"num_search_queries": 2,
"reasoning_tokens": 1,
"citation_tokens": 3,
},
}
mock_chunks: list[dict[str, Any]] = [mock_chunk_0, mock_chunk_1, mock_chunk_2]
mock_stream = MagicMock()
mock_stream.__iter__.return_value = mock_chunks

patcher = mocker.patch.object(
llm.client.chat.completions, "create", return_value=mock_stream
)

chunks_list = list(llm.stream("Test query"))

# Find the chunk with usage metadata
usage_chunk = None
for chunk in chunks_list:
if chunk.usage_metadata:
usage_chunk = chunk
break

# Verify usage metadata is properly set
assert usage_chunk is not None
assert usage_chunk.usage_metadata is not None
assert usage_chunk.usage_metadata["input_tokens"] == 5
assert usage_chunk.usage_metadata["output_tokens"] == 10
assert usage_chunk.usage_metadata["total_tokens"] == 15
# Verify reasoning and citation tokens are included
assert usage_chunk.usage_metadata["output_token_details"]["reasoning"] == 1
assert usage_chunk.usage_metadata["output_token_details"]["citation_tokens"] == 3 # type: ignore[typeddict-item]

patcher.assert_called_once()