diff --git a/libs/partners/perplexity/langchain_perplexity/chat_models.py b/libs/partners/perplexity/langchain_perplexity/chat_models.py
index e46cdd879cf08..4321c05755225 100644
--- a/libs/partners/perplexity/langchain_perplexity/chat_models.py
+++ b/libs/partners/perplexity/langchain_perplexity/chat_models.py
@@ -28,7 +28,11 @@
     SystemMessageChunk,
     ToolMessageChunk,
 )
-from langchain_core.messages.ai import UsageMetadata, subtract_usage
+from langchain_core.messages.ai import (
+    OutputTokenDetails,
+    UsageMetadata,
+    subtract_usage,
+)
 from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
@@ -49,13 +53,28 @@ def _is_pydantic_class(obj: Any) -> bool:


 def _create_usage_metadata(token_usage: dict) -> UsageMetadata:
+    """Create UsageMetadata from Perplexity token usage data.
+
+    Args:
+        token_usage: Dictionary containing token usage information from Perplexity API.
+
+    Returns:
+        UsageMetadata with properly structured token counts and details.
+    """
     input_tokens = token_usage.get("prompt_tokens", 0)
     output_tokens = token_usage.get("completion_tokens", 0)
     total_tokens = token_usage.get("total_tokens", input_tokens + output_tokens)
+
+    # Build output_token_details for Perplexity-specific fields
+    output_token_details: OutputTokenDetails = {}
+    output_token_details["reasoning"] = token_usage.get("reasoning_tokens", 0)
+    output_token_details["citation_tokens"] = token_usage.get("citation_tokens", 0)  # type: ignore[typeddict-unknown-key]
+
     return UsageMetadata(
         input_tokens=input_tokens,
         output_tokens=output_tokens,
         total_tokens=total_tokens,
+        output_token_details=output_token_details,
     )


@@ -301,6 +320,7 @@ def _stream(

         prev_total_usage: UsageMetadata | None = None
         added_model_name: bool = False
+        added_search_queries: bool = False
         for chunk in stream_resp:
             if not isinstance(chunk, dict):
                 chunk = chunk.model_dump()
@@ -332,6 +352,13 @@
                 generation_info["model_name"] = model_name
                 added_model_name = True

+            # Add num_search_queries to generation_info if present
+            if total_usage := chunk.get("usage"):
+                if num_search_queries := total_usage.get("num_search_queries"):
+                    if not added_search_queries:
+                        generation_info["num_search_queries"] = num_search_queries
+                        added_search_queries = True
+
             chunk = self._convert_delta_to_message_chunk(
                 choice["delta"], default_chunk_class
             )
@@ -369,20 +396,29 @@ def _generate(
         params = {**params, **kwargs}
         response = self.client.chat.completions.create(messages=message_dicts, **params)
         if usage := getattr(response, "usage", None):
-            usage_metadata = _create_usage_metadata(usage.model_dump())
+            usage_dict = usage.model_dump()
+            usage_metadata = _create_usage_metadata(usage_dict)
         else:
             usage_metadata = None
+            usage_dict = {}

         additional_kwargs = {}
         for attr in ["citations", "images", "related_questions", "search_results"]:
             if hasattr(response, attr):
                 additional_kwargs[attr] = getattr(response, attr)

+        # Build response_metadata with model_name and num_search_queries
+        response_metadata: dict[str, Any] = {
+            "model_name": getattr(response, "model", self.model)
+        }
+        if num_search_queries := usage_dict.get("num_search_queries"):
+            response_metadata["num_search_queries"] = num_search_queries
+
         message = AIMessage(
             content=response.choices[0].message.content,
             additional_kwargs=additional_kwargs,
             usage_metadata=usage_metadata,
-            response_metadata={"model_name": getattr(response, "model", self.model)},
+            response_metadata=response_metadata,
         )
         return ChatResult(generations=[ChatGeneration(message=message)])

diff --git a/libs/partners/perplexity/tests/unit_tests/test_chat_models.py b/libs/partners/perplexity/tests/unit_tests/test_chat_models.py
index a05f80d6c3f6e..4fc7f08611bfc 100644
--- a/libs/partners/perplexity/tests/unit_tests/test_chat_models.py
+++ b/libs/partners/perplexity/tests/unit_tests/test_chat_models.py
@@ -5,6 +5,7 @@
 from pytest_mock import MockerFixture

 from langchain_perplexity import ChatPerplexity
+from langchain_perplexity.chat_models import _create_usage_metadata


 def test_perplexity_model_name_param() -> None:
@@ -295,3 +296,230 @@ def test_perplexity_stream_includes_citations_and_search_results(
     }

     patcher.assert_called_once()
+
+
+def test_create_usage_metadata_basic() -> None:
+    """Test _create_usage_metadata with basic token counts."""
+    token_usage = {
+        "prompt_tokens": 10,
+        "completion_tokens": 20,
+        "total_tokens": 30,
+    }
+
+    usage_metadata = _create_usage_metadata(token_usage)
+
+    assert usage_metadata["input_tokens"] == 10
+    assert usage_metadata["output_tokens"] == 20
+    assert usage_metadata["total_tokens"] == 30
+    assert usage_metadata["output_token_details"]["reasoning"] == 0
+    assert usage_metadata["output_token_details"]["citation_tokens"] == 0  # type: ignore[typeddict-item]
+
+
+def test_create_usage_metadata_with_reasoning_tokens() -> None:
+    """Test _create_usage_metadata with reasoning tokens."""
+    token_usage = {
+        "prompt_tokens": 50,
+        "completion_tokens": 100,
+        "total_tokens": 150,
+        "reasoning_tokens": 25,
+    }
+
+    usage_metadata = _create_usage_metadata(token_usage)
+
+    assert usage_metadata["input_tokens"] == 50
+    assert usage_metadata["output_tokens"] == 100
+    assert usage_metadata["total_tokens"] == 150
+    assert usage_metadata["output_token_details"]["reasoning"] == 25
+    assert usage_metadata["output_token_details"]["citation_tokens"] == 0  # type: ignore[typeddict-item]
+
+
+def test_create_usage_metadata_with_citation_tokens() -> None:
+    """Test _create_usage_metadata with citation tokens."""
+    token_usage = {
+        "prompt_tokens": 100,
+        "completion_tokens": 200,
+        "total_tokens": 300,
+        "citation_tokens": 15,
+    }
+
+    usage_metadata = _create_usage_metadata(token_usage)
+
+    assert usage_metadata["input_tokens"] == 100
+    assert usage_metadata["output_tokens"] == 200
+    assert usage_metadata["total_tokens"] == 300
+    assert usage_metadata["output_token_details"]["reasoning"] == 0
+    assert usage_metadata["output_token_details"]["citation_tokens"] == 15  # type: ignore[typeddict-item]
+
+
+def test_create_usage_metadata_with_all_token_types() -> None:
+    """Test _create_usage_metadata with all token types.
+
+    Tests reasoning tokens and citation tokens together.
+    """
+    token_usage = {
+        "prompt_tokens": 75,
+        "completion_tokens": 150,
+        "total_tokens": 225,
+        "reasoning_tokens": 30,
+        "citation_tokens": 20,
+    }
+
+    usage_metadata = _create_usage_metadata(token_usage)
+
+    assert usage_metadata["input_tokens"] == 75
+    assert usage_metadata["output_tokens"] == 150
+    assert usage_metadata["total_tokens"] == 225
+    assert usage_metadata["output_token_details"]["reasoning"] == 30
+    assert usage_metadata["output_token_details"]["citation_tokens"] == 20  # type: ignore[typeddict-item]
+
+
+def test_create_usage_metadata_missing_optional_fields() -> None:
+    """Test _create_usage_metadata with missing optional fields defaults to 0."""
+    token_usage = {
+        "prompt_tokens": 25,
+        "completion_tokens": 50,
+    }
+
+    usage_metadata = _create_usage_metadata(token_usage)
+
+    assert usage_metadata["input_tokens"] == 25
+    assert usage_metadata["output_tokens"] == 50
+    # Total tokens should be calculated if not provided
+    assert usage_metadata["total_tokens"] == 75
+    assert usage_metadata["output_token_details"]["reasoning"] == 0
+    assert usage_metadata["output_token_details"]["citation_tokens"] == 0  # type: ignore[typeddict-item]
+
+
+def test_create_usage_metadata_empty_dict() -> None:
+    """Test _create_usage_metadata with empty token usage dict."""
+    token_usage: dict = {}
+
+    usage_metadata = _create_usage_metadata(token_usage)
+
+    assert usage_metadata["input_tokens"] == 0
+    assert usage_metadata["output_tokens"] == 0
+    assert usage_metadata["total_tokens"] == 0
+    assert usage_metadata["output_token_details"]["reasoning"] == 0
+    assert usage_metadata["output_token_details"]["citation_tokens"] == 0  # type: ignore[typeddict-item]
+
+
+def test_perplexity_invoke_includes_num_search_queries(mocker: MockerFixture) -> None:
+    """Test that invoke includes num_search_queries in response_metadata."""
+    llm = ChatPerplexity(model="test", timeout=30, verbose=True)
+
+    mock_usage = MagicMock()
+    mock_usage.model_dump.return_value = {
+        "prompt_tokens": 10,
+        "completion_tokens": 20,
+        "total_tokens": 30,
+        "num_search_queries": 3,
+    }
+
+    mock_response = MagicMock()
+    mock_response.choices = [
+        MagicMock(
+            message=MagicMock(
+                content="Test response",
+                tool_calls=None,
+            ),
+            finish_reason="stop",
+        )
+    ]
+    mock_response.model = "test-model"
+    mock_response.usage = mock_usage
+
+    patcher = mocker.patch.object(
+        llm.client.chat.completions, "create", return_value=mock_response
+    )
+
+    result = llm.invoke("Test query")
+
+    assert result.response_metadata["num_search_queries"] == 3
+    assert result.response_metadata["model_name"] == "test-model"
+    patcher.assert_called_once()
+
+
+def test_perplexity_invoke_without_num_search_queries(mocker: MockerFixture) -> None:
+    """Test that invoke works when num_search_queries is not provided."""
+    llm = ChatPerplexity(model="test", timeout=30, verbose=True)
+
+    mock_usage = MagicMock()
+    mock_usage.model_dump.return_value = {
+        "prompt_tokens": 10,
+        "completion_tokens": 20,
+        "total_tokens": 30,
+    }
+
+    mock_response = MagicMock()
+    mock_response.choices = [
+        MagicMock(
+            message=MagicMock(
+                content="Test response",
+                tool_calls=None,
+            ),
+            finish_reason="stop",
+        )
+    ]
+    mock_response.model = "test-model"
+    mock_response.usage = mock_usage
+
+    patcher = mocker.patch.object(
+        llm.client.chat.completions, "create", return_value=mock_response
+    )
+
+    result = llm.invoke("Test query")
+
+    assert "num_search_queries" not in result.response_metadata
+    assert result.response_metadata["model_name"] == "test-model"
+    patcher.assert_called_once()
+
+
+def test_perplexity_stream_includes_num_search_queries(mocker: MockerFixture) -> None:
+    """Test that stream properly handles num_search_queries in usage data."""
+    llm = ChatPerplexity(model="test", timeout=30, verbose=True)
+
+    mock_chunk_0 = {
+        "choices": [{"delta": {"content": "Hello "}, "finish_reason": None}],
+    }
+    mock_chunk_1 = {
+        "choices": [{"delta": {"content": "world"}, "finish_reason": None}],
+    }
+    mock_chunk_2 = {
+        "choices": [{"delta": {}, "finish_reason": "stop"}],
+        "usage": {
+            "prompt_tokens": 5,
+            "completion_tokens": 10,
+            "total_tokens": 15,
+            "num_search_queries": 2,
+            "reasoning_tokens": 1,
+            "citation_tokens": 3,
+        },
+    }
+    mock_chunks: list[dict[str, Any]] = [mock_chunk_0, mock_chunk_1, mock_chunk_2]
+    mock_stream = MagicMock()
+    mock_stream.__iter__.return_value = mock_chunks
+
+    patcher = mocker.patch.object(
+        llm.client.chat.completions, "create", return_value=mock_stream
+    )
+
+    chunks_list = list(llm.stream("Test query"))
+
+    # Find the chunk with usage metadata
+    usage_chunk = None
+    for chunk in chunks_list:
+        if chunk.usage_metadata:
+            usage_chunk = chunk
+            break
+
+    # Verify usage metadata is properly set
+    assert usage_chunk is not None
+    assert usage_chunk.usage_metadata is not None
+    assert usage_chunk.usage_metadata["input_tokens"] == 5
+    assert usage_chunk.usage_metadata["output_tokens"] == 10
+    assert usage_chunk.usage_metadata["total_tokens"] == 15
+    # Verify reasoning and citation tokens are included
+    assert usage_chunk.usage_metadata["output_token_details"]["reasoning"] == 1
+    assert usage_chunk.usage_metadata["output_token_details"]["citation_tokens"] == 3  # type: ignore[typeddict-item]
+
+    patcher.assert_called_once()
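
Illustrative usage sketch (not part of the patch): how a caller might read the fields this change surfaces. The model name "sonar" and a configured PPLX_API_KEY are assumptions for the example only; the accessors (response_metadata, usage_metadata) are standard AIMessage attributes.

    # Sketch only: assumes PPLX_API_KEY is set and "sonar" is a model available to your account.
    from langchain_perplexity import ChatPerplexity

    llm = ChatPerplexity(model="sonar")
    msg = llm.invoke("How many moons does Jupiter have?")

    # response_metadata now carries num_search_queries when the API reports it.
    print(msg.response_metadata.get("model_name"))
    print(msg.response_metadata.get("num_search_queries"))

    # usage_metadata["output_token_details"] carries the Perplexity-specific counts
    # populated by _create_usage_metadata above (reasoning, citation_tokens).
    details = (msg.usage_metadata or {}).get("output_token_details", {})
    print(details.get("reasoning"), details.get("citation_tokens"))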