diff --git a/src/ares/code_agents/mini_swe_agent_test.py b/src/ares/code_agents/mini_swe_agent_test.py new file mode 100644 index 0000000..bac9e1f --- /dev/null +++ b/src/ares/code_agents/mini_swe_agent_test.py @@ -0,0 +1,374 @@ +"""Unit tests for mini_swe_agent.py""" + +import sys +from unittest import mock + +import pytest + +# Mock minisweagent before importing the module under test +mock_minisweagent = mock.MagicMock() +mock_minisweagent.config.builtin_config_dir = "/fake/path" +mock_default_agent = mock.MagicMock() +mock_default_agent.AgentConfig = mock.MagicMock +sys.modules['minisweagent'] = mock_minisweagent +sys.modules['minisweagent.agents'] = mock.MagicMock() +sys.modules['minisweagent.agents.default'] = mock_default_agent +sys.modules['minisweagent.config'] = mock_minisweagent.config + +from ares.code_agents import mini_swe_agent +from ares.containers import containers + + +# Test configuration to mock yaml.safe_load +TEST_CONFIG = { + "agent": { + "system_template": "You are a helpful assistant. System: {{ system }}", + "instance_template": "Task: {{ task }}\nSystem: {{ system }}\nRelease: {{ release }}\nVersion: {{ version }}\nMachine: {{ machine }}", + "action_observation_template": "Output: {{ output.output }}\nReturn code: {{ output.returncode }}", + "format_error_template": "Format error. Actions found: {{ actions }}", + }, + "environment": { + "timeout": 30, + "env": {"TEST_VAR": "test_value"}, + }, +} + + +@pytest.fixture +def mock_container(): + """Create a mock container.""" + container = mock.AsyncMock(spec=containers.Container) + return container + + +@pytest.fixture +def mock_llm_client(): + """Create a mock LLM client.""" + return mock.AsyncMock() + + +@pytest.fixture +def mock_yaml_config(): + """Mock yaml.safe_load to return test configuration.""" + with mock.patch('yaml.safe_load', return_value=TEST_CONFIG): + yield + + +@pytest.fixture +def agent(mock_container, mock_llm_client, mock_yaml_config): + """Create a MiniSWECodeAgent instance with mocked dependencies.""" + with mock.patch('pathlib.Path.read_text', return_value=""): + agent = mini_swe_agent.MiniSWECodeAgent( + container=mock_container, + llm_client=mock_llm_client, + ) + return agent + + +class TestParseAction: + """Tests for the parse_action method.""" + + def test_single_block_success(self, agent): + """Test parsing a single bash block successfully.""" + response_text = "Let me run this command:\n```bash\necho 'hello world'\n```\nThis should work." + + action = agent.parse_action(response_text) + + assert action == "echo 'hello world'" + + def test_single_block_with_whitespace(self, agent): + """Test parsing a single bash block with extra whitespace.""" + response_text = "```bash\n ls -la \n```" + + action = agent.parse_action(response_text) + + assert action == "ls -la" + + def test_multiple_blocks_error(self, agent): + """Test that multiple bash blocks raise a FormatError.""" + response_text = """ + First command: + ```bash + echo 'first' + ``` + Second command: + ```bash + echo 'second' + ``` + """ + + with pytest.raises(mini_swe_agent._FormatError) as exc_info: + agent.parse_action(response_text) + + # Verify the error message includes information about the actions + assert "Format error" in str(exc_info.value) + + def test_no_blocks_error(self, agent): + """Test that no bash blocks raise a FormatError.""" + response_text = "I will run a command but forgot to use code blocks." + + with pytest.raises(mini_swe_agent._FormatError) as exc_info: + agent.parse_action(response_text) + + assert "Format error" in str(exc_info.value) + + def test_multiline_command(self, agent): + """Test parsing a multiline bash command.""" + response_text = """```bash +for i in 1 2 3; do + echo $i +done +```""" + + action = agent.parse_action(response_text) + + assert "for i in 1 2 3; do" in action + assert "echo $i" in action + assert "done" in action + + +class TestRaiseIfFinished: + """Tests for the _raise_if_finished method.""" + + def test_mini_swe_agent_final_output(self, agent): + """Test that MINI_SWE_AGENT_FINAL_OUTPUT raises _SubmittedError.""" + output = containers.ExecResult( + exit_code=0, + output="MINI_SWE_AGENT_FINAL_OUTPUT\nThis is the final output\nwith multiple lines" + ) + + with pytest.raises(mini_swe_agent._SubmittedError) as exc_info: + agent._raise_if_finished(output) + + # Verify the error message contains the lines after the marker + assert "This is the final output" in str(exc_info.value) + assert "with multiple lines" in str(exc_info.value) + + def test_complete_task_and_submit_final_output(self, agent): + """Test that COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT raises _SubmittedError.""" + output = containers.ExecResult( + exit_code=0, + output="COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT\nTask completed successfully" + ) + + with pytest.raises(mini_swe_agent._SubmittedError) as exc_info: + agent._raise_if_finished(output) + + assert "Task completed successfully" in str(exc_info.value) + + def test_final_output_with_leading_whitespace(self, agent): + """Test that markers with leading whitespace are recognized.""" + output = containers.ExecResult( + exit_code=0, + output=" \n MINI_SWE_AGENT_FINAL_OUTPUT\nFinal result" + ) + + with pytest.raises(mini_swe_agent._SubmittedError) as exc_info: + agent._raise_if_finished(output) + + assert "Final result" in str(exc_info.value) + + def test_normal_output_no_exception(self, agent): + """Test that normal output does not raise an exception.""" + output = containers.ExecResult( + exit_code=0, + output="This is normal output\nNothing special here" + ) + + # Should not raise any exception + agent._raise_if_finished(output) + + def test_marker_not_at_start(self, agent): + """Test that markers not at the start of output do not trigger.""" + output = containers.ExecResult( + exit_code=0, + output="Some output\nMINI_SWE_AGENT_FINAL_OUTPUT\nThis should not trigger" + ) + + # Should not raise any exception + agent._raise_if_finished(output) + + def test_empty_output(self, agent): + """Test that empty output does not raise an exception.""" + output = containers.ExecResult( + exit_code=0, + output="" + ) + + # Should not raise any exception + agent._raise_if_finished(output) + + +class TestExecuteAction: + """Tests for the execute_action method.""" + + @pytest.mark.anyio + async def test_execute_action_success(self, agent, mock_container): + """Test successful action execution.""" + # Mock LLM response + mock_response = mock.MagicMock() + mock_response.chat_completion_response.choices = [ + mock.MagicMock(message=mock.MagicMock(content="```bash\necho 'test'\n```")) + ] + + # Mock container execution + mock_container.exec_run.return_value = containers.ExecResult( + exit_code=0, + output="test" + ) + + # Execute action + await agent.execute_action(mock_response) + + # Verify container.exec_run was called with correct parameters + mock_container.exec_run.assert_called_once_with( + "echo 'test'", + timeout_s=30, + env={"TEST_VAR": "test_value"} + ) + + # Verify message was added to history + assert len(agent._messages) > 0 + + @pytest.mark.anyio + async def test_execute_action_timeout(self, agent, mock_container): + """Test that timeout raises _ExecutionTimeoutError.""" + # Mock LLM response + mock_response = mock.MagicMock() + mock_response.chat_completion_response.choices = [ + mock.MagicMock(message=mock.MagicMock(content="```bash\nsleep 100\n```")) + ] + + # Mock container execution to raise TimeoutError + mock_container.exec_run.side_effect = TimeoutError("Command timed out") + + # Execute action and verify exception + with pytest.raises(mini_swe_agent._ExecutionTimeoutError) as exc_info: + await agent.execute_action(mock_response) + + # Verify the error message mentions timeout + assert "timed out" in str(exc_info.value) + + @pytest.mark.anyio + async def test_execute_action_with_final_output(self, agent, mock_container): + """Test that final output marker triggers _SubmittedError.""" + # Mock LLM response + mock_response = mock.MagicMock() + mock_response.chat_completion_response.choices = [ + mock.MagicMock(message=mock.MagicMock(content="```bash\necho 'done'\n```")) + ] + + # Mock container execution with final output + mock_container.exec_run.return_value = containers.ExecResult( + exit_code=0, + output="MINI_SWE_AGENT_FINAL_OUTPUT\nTask complete" + ) + + # Execute action and verify exception + with pytest.raises(mini_swe_agent._SubmittedError) as exc_info: + await agent.execute_action(mock_response) + + assert "Task complete" in str(exc_info.value) + + @pytest.mark.anyio + async def test_execute_action_format_error(self, agent, mock_container): + """Test that invalid format raises _FormatError.""" + # Mock LLM response with invalid format + mock_response = mock.MagicMock() + mock_response.chat_completion_response.choices = [ + mock.MagicMock(message=mock.MagicMock(content="No bash blocks here")) + ] + + # Execute action and verify exception + with pytest.raises(mini_swe_agent._FormatError): + await agent.execute_action(mock_response) + + @pytest.mark.anyio + async def test_execute_action_non_zero_exit_code(self, agent, mock_container): + """Test execution with non-zero exit code.""" + # Mock LLM response + mock_response = mock.MagicMock() + mock_response.chat_completion_response.choices = [ + mock.MagicMock(message=mock.MagicMock(content="```bash\nfalse\n```")) + ] + + # Mock container execution with non-zero exit code + mock_container.exec_run.return_value = containers.ExecResult( + exit_code=1, + output="Command failed" + ) + + # Execute action - should not raise exception for non-zero exit code + await agent.execute_action(mock_response) + + # Verify execution happened + mock_container.exec_run.assert_called_once() + + +class TestHelperFunctions: + """Tests for helper/render functions.""" + + def test_render_system_template(self): + """Test rendering system template.""" + template = "System info" + result = mini_swe_agent._render_system_template(template) + assert result == "System info" + + def test_render_instance_template(self): + """Test rendering instance template with all variables.""" + template = "Task: {{ task }}, System: {{ system }}" + result = mini_swe_agent._render_instance_template( + template, + task="Fix bug", + system="Linux", + release="5.15", + version="Ubuntu", + machine="x86_64" + ) + assert "Task: Fix bug" in result + assert "System: Linux" in result + + def test_render_action_observation_template(self): + """Test rendering action observation template.""" + template = "Exit: {{ output.returncode }}, Output: {{ output.output }}" + output = mini_swe_agent._MiniSWEAgentOutput(returncode=0, output="success") + result = mini_swe_agent._render_action_observation_template(template, output) + assert "Exit: 0" in result + assert "Output: success" in result + + def test_render_format_error_template(self): + """Test rendering format error template.""" + template = "Found {{ actions|length }} actions" + actions = ["action1", "action2"] + result = mini_swe_agent._render_format_error_template(template, actions) + assert "Found 2 actions" in result + + def test_render_timeout_template(self): + """Test rendering timeout template.""" + result = mini_swe_agent._render_timeout_template("sleep 100", "partial output") + assert "sleep 100" in result + assert "timed out" in result.lower() + assert "partial output" in result + + +class TestMessageManagement: + """Tests for message management.""" + + def test_add_message(self, agent): + """Test adding messages to the message list.""" + initial_length = len(agent._messages) + + agent._add_message("user", "Test message") + + assert len(agent._messages) == initial_length + 1 + assert agent._messages[-1]["role"] == "user" + assert agent._messages[-1]["content"] == "Test message" + + def test_add_empty_message(self, agent): + """Test that empty messages are handled with a placeholder.""" + initial_length = len(agent._messages) + + agent._add_message("user", " ") + + assert len(agent._messages) == initial_length + 1 + assert agent._messages[-1]["content"] == "[Empty content]" diff --git a/src/ares/llms/accounting_test.py b/src/ares/llms/accounting_test.py new file mode 100644 index 0000000..3025d1d --- /dev/null +++ b/src/ares/llms/accounting_test.py @@ -0,0 +1,534 @@ +"""Unit tests for the accounting module.""" + +import decimal +from unittest import mock + +import frozendict +import httpx +import pytest +from openai.types.chat import chat_completion as chat_completion_types + +from ares.llms.accounting import ( + ModelCost, + ModelPricing, + ModelsResponse, + get_llm_cost, + martian_cost_list, +) + + +class TestMartianCostList: + """Tests for the martian_cost_list function.""" + + def test_martian_cost_list_success(self, monkeypatch): + """Test successful fetching and parsing of model costs.""" + # Clear the cache before testing + martian_cost_list.cache_clear() + + mock_response_data = { + "data": [ + { + "id": "gpt-4", + "pricing": { + "prompt": "0.03", + "completion": "0.06", + "image": None, + "request": None, + "web_search": None, + "internal_reasoning": None, + }, + }, + { + "id": "gpt-3.5-turbo", + "pricing": { + "prompt": "0.0015", + "completion": "0.002", + "image": None, + "request": "0.0001", + "web_search": None, + "internal_reasoning": None, + }, + }, + ] + } + + mock_response = mock.Mock(spec=httpx.Response) + mock_response.json.return_value = mock_response_data + mock_response.raise_for_status.return_value = None + + mock_client = mock.MagicMock(spec=httpx.Client) + mock_client.get.return_value = mock_response + mock_client.__enter__.return_value = mock_client + mock_client.__exit__.return_value = None + + result = martian_cost_list(client=mock_client) + + # Verify the result is a frozendict + assert isinstance(result, frozendict.frozendict) + + # Verify the models are present + assert "gpt-4" in result + assert "gpt-3.5-turbo" in result + + # Verify the pricing information + gpt4_cost = result["gpt-4"] + assert gpt4_cost.id == "gpt-4" + assert gpt4_cost.pricing.prompt == decimal.Decimal("0.03") + assert gpt4_cost.pricing.completion == decimal.Decimal("0.06") + assert gpt4_cost.pricing.image is None + assert gpt4_cost.pricing.request is None + + gpt35_cost = result["gpt-3.5-turbo"] + assert gpt35_cost.id == "gpt-3.5-turbo" + assert gpt35_cost.pricing.prompt == decimal.Decimal("0.0015") + assert gpt35_cost.pricing.completion == decimal.Decimal("0.002") + assert gpt35_cost.pricing.request == decimal.Decimal("0.0001") + + def test_martian_cost_list_caching(self, monkeypatch): + """Test that martian_cost_list results are cached.""" + # Clear the cache before testing + martian_cost_list.cache_clear() + + mock_response_data = { + "data": [ + { + "id": "test-model", + "pricing": { + "prompt": "0.01", + "completion": "0.02", + "image": None, + "request": None, + "web_search": None, + "internal_reasoning": None, + }, + } + ] + } + + mock_response = mock.Mock(spec=httpx.Response) + mock_response.json.return_value = mock_response_data + mock_response.raise_for_status.return_value = None + + mock_client = mock.MagicMock(spec=httpx.Client) + mock_client.get.return_value = mock_response + mock_client.__enter__.return_value = mock_client + mock_client.__exit__.return_value = None + + # First call + result1 = martian_cost_list(client=mock_client) + + # Second call should use cache + result2 = martian_cost_list(client=mock_client) + + # Verify client.get was only called once (cached) + assert mock_client.get.call_count == 1 + + # Verify results are the same + assert result1 is result2 + + def test_martian_cost_list_http_error(self, monkeypatch): + """Test handling of HTTP errors.""" + # Clear the cache before testing + martian_cost_list.cache_clear() + + mock_response = mock.Mock(spec=httpx.Response) + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "404 Not Found", + request=mock.Mock(), + response=mock.Mock(status_code=404), + ) + + mock_client = mock.MagicMock(spec=httpx.Client) + mock_client.get.return_value = mock_response + mock_client.__enter__.return_value = mock_client + mock_client.__exit__.return_value = None + + with pytest.raises(httpx.HTTPStatusError): + martian_cost_list(client=mock_client) + + def test_martian_cost_list_network_error(self, monkeypatch): + """Test handling of network errors.""" + # Clear the cache before testing + martian_cost_list.cache_clear() + + mock_client = mock.MagicMock(spec=httpx.Client) + mock_client.get.side_effect = httpx.NetworkError("Connection failed") + mock_client.__enter__.return_value = mock_client + mock_client.__exit__.return_value = None + + with pytest.raises(httpx.NetworkError): + martian_cost_list(client=mock_client) + + def test_martian_cost_list_invalid_json(self, monkeypatch): + """Test handling of invalid JSON responses.""" + # Clear the cache before testing + martian_cost_list.cache_clear() + + mock_response = mock.Mock(spec=httpx.Response) + mock_response.json.side_effect = ValueError("Invalid JSON") + mock_response.raise_for_status.return_value = None + + mock_client = mock.MagicMock(spec=httpx.Client) + mock_client.get.return_value = mock_response + mock_client.__enter__.return_value = mock_client + mock_client.__exit__.return_value = None + + with pytest.raises(ValueError): + martian_cost_list(client=mock_client) + + def test_martian_cost_list_creates_default_client(self, monkeypatch): + """Test that martian_cost_list creates a default client when none provided.""" + # Clear the cache before testing + martian_cost_list.cache_clear() + + mock_response_data = { + "data": [ + { + "id": "test-model", + "pricing": { + "prompt": "0.01", + "completion": "0.02", + "image": None, + "request": None, + "web_search": None, + "internal_reasoning": None, + }, + } + ] + } + + mock_response = mock.Mock(spec=httpx.Response) + mock_response.json.return_value = mock_response_data + mock_response.raise_for_status.return_value = None + + mock_client_instance = mock.MagicMock(spec=httpx.Client) + mock_client_instance.get.return_value = mock_response + mock_client_instance.__enter__.return_value = mock_client_instance + mock_client_instance.__exit__.return_value = None + + with mock.patch("httpx.Client", return_value=mock_client_instance): + result = martian_cost_list() + + assert isinstance(result, frozendict.frozendict) + assert "test-model" in result + + +class TestGetLlmCost: + """Tests for the get_llm_cost function.""" + + def test_get_llm_cost_basic(self): + """Test basic cost calculation with prompt and completion tokens.""" + cost_mapping = frozendict.frozendict( + { + "test-model": ModelCost( + id="test-model", + pricing=ModelPricing( + prompt=decimal.Decimal("0.01"), + completion=decimal.Decimal("0.02"), + image=None, + request=None, + web_search=None, + internal_reasoning=None, + ), + ) + } + ) + + completion = chat_completion_types.ChatCompletion( + id="test-completion", + choices=[], + created=1234567890, + model="test-model", + object="chat.completion", + usage=chat_completion_types.CompletionUsage( + prompt_tokens=100, + completion_tokens=50, + total_tokens=150, + ), + ) + + cost = get_llm_cost("test-model", completion, cost_mapping=cost_mapping) + + # 100 * 0.01 + 50 * 0.02 = 1.0 + 1.0 = 2.0 + assert cost == decimal.Decimal("2.0") + + def test_get_llm_cost_with_request_charge(self): + """Test cost calculation with request charge.""" + cost_mapping = frozendict.frozendict( + { + "test-model": ModelCost( + id="test-model", + pricing=ModelPricing( + prompt=decimal.Decimal("0.01"), + completion=decimal.Decimal("0.02"), + image=None, + request=decimal.Decimal("0.001"), + web_search=None, + internal_reasoning=None, + ), + ) + } + ) + + completion = chat_completion_types.ChatCompletion( + id="test-completion", + choices=[], + created=1234567890, + model="test-model", + object="chat.completion", + usage=chat_completion_types.CompletionUsage( + prompt_tokens=100, + completion_tokens=50, + total_tokens=150, + ), + ) + + cost = get_llm_cost("test-model", completion, cost_mapping=cost_mapping) + + # 0.001 + 100 * 0.01 + 50 * 0.02 = 0.001 + 1.0 + 1.0 = 2.001 + assert cost == decimal.Decimal("2.001") + + def test_get_llm_cost_decimal_precision(self): + """Test that decimal precision is maintained correctly.""" + cost_mapping = frozendict.frozendict( + { + "test-model": ModelCost( + id="test-model", + pricing=ModelPricing( + prompt=decimal.Decimal("0.000001"), + completion=decimal.Decimal("0.000002"), + image=None, + request=decimal.Decimal("0.0000001"), + web_search=None, + internal_reasoning=None, + ), + ) + } + ) + + completion = chat_completion_types.ChatCompletion( + id="test-completion", + choices=[], + created=1234567890, + model="test-model", + object="chat.completion", + usage=chat_completion_types.CompletionUsage( + prompt_tokens=1000, + completion_tokens=500, + total_tokens=1500, + ), + ) + + cost = get_llm_cost("test-model", completion, cost_mapping=cost_mapping) + + # 0.0000001 + 1000 * 0.000001 + 500 * 0.000002 + # = 0.0000001 + 0.001 + 0.001 = 0.0020001 + assert cost == decimal.Decimal("0.0020001") + + def test_get_llm_cost_missing_model(self): + """Test error handling when model is not in cost mapping.""" + cost_mapping = frozendict.frozendict( + { + "other-model": ModelCost( + id="other-model", + pricing=ModelPricing( + prompt=decimal.Decimal("0.01"), + completion=decimal.Decimal("0.02"), + image=None, + request=None, + web_search=None, + internal_reasoning=None, + ), + ) + } + ) + + completion = chat_completion_types.ChatCompletion( + id="test-completion", + choices=[], + created=1234567890, + model="test-model", + object="chat.completion", + usage=chat_completion_types.CompletionUsage( + prompt_tokens=100, + completion_tokens=50, + total_tokens=150, + ), + ) + + with pytest.raises(ValueError, match="Model test-model not found in cost mapping"): + get_llm_cost("test-model", completion, cost_mapping=cost_mapping) + + def test_get_llm_cost_missing_usage(self): + """Test error handling when completion has no usage information.""" + cost_mapping = frozendict.frozendict( + { + "test-model": ModelCost( + id="test-model", + pricing=ModelPricing( + prompt=decimal.Decimal("0.01"), + completion=decimal.Decimal("0.02"), + image=None, + request=None, + web_search=None, + internal_reasoning=None, + ), + ) + } + ) + + completion = chat_completion_types.ChatCompletion( + id="test-completion", + choices=[], + created=1234567890, + model="test-model", + object="chat.completion", + usage=None, + ) + + with pytest.raises(ValueError, match="Cannot compute cost of a completion with no usage"): + get_llm_cost("test-model", completion, cost_mapping=cost_mapping) + + def test_get_llm_cost_none_pricing_fields(self): + """Test that None pricing fields are treated as zero.""" + cost_mapping = frozendict.frozendict( + { + "test-model": ModelCost( + id="test-model", + pricing=ModelPricing( + prompt=None, + completion=None, + image=None, + request=None, + web_search=None, + internal_reasoning=None, + ), + ) + } + ) + + completion = chat_completion_types.ChatCompletion( + id="test-completion", + choices=[], + created=1234567890, + model="test-model", + object="chat.completion", + usage=chat_completion_types.CompletionUsage( + prompt_tokens=100, + completion_tokens=50, + total_tokens=150, + ), + ) + + cost = get_llm_cost("test-model", completion, cost_mapping=cost_mapping) + + # All pricing fields are None, so cost should be 0 + assert cost == decimal.Decimal("0") + + def test_get_llm_cost_partial_none_pricing_fields(self): + """Test cost calculation with some None pricing fields.""" + cost_mapping = frozendict.frozendict( + { + "test-model": ModelCost( + id="test-model", + pricing=ModelPricing( + prompt=decimal.Decimal("0.01"), + completion=None, + image=None, + request=None, + web_search=None, + internal_reasoning=None, + ), + ) + } + ) + + completion = chat_completion_types.ChatCompletion( + id="test-completion", + choices=[], + created=1234567890, + model="test-model", + object="chat.completion", + usage=chat_completion_types.CompletionUsage( + prompt_tokens=100, + completion_tokens=50, + total_tokens=150, + ), + ) + + cost = get_llm_cost("test-model", completion, cost_mapping=cost_mapping) + + # Only prompt tokens are charged: 100 * 0.01 = 1.0 + assert cost == decimal.Decimal("1.0") + + def test_get_llm_cost_zero_tokens(self): + """Test cost calculation with zero tokens.""" + cost_mapping = frozendict.frozendict( + { + "test-model": ModelCost( + id="test-model", + pricing=ModelPricing( + prompt=decimal.Decimal("0.01"), + completion=decimal.Decimal("0.02"), + image=None, + request=decimal.Decimal("0.001"), + web_search=None, + internal_reasoning=None, + ), + ) + } + ) + + completion = chat_completion_types.ChatCompletion( + id="test-completion", + choices=[], + created=1234567890, + model="test-model", + object="chat.completion", + usage=chat_completion_types.CompletionUsage( + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + ), + ) + + cost = get_llm_cost("test-model", completion, cost_mapping=cost_mapping) + + # Only request charge: 0.001 + assert cost == decimal.Decimal("0.001") + + def test_get_llm_cost_large_token_counts(self): + """Test cost calculation with very large token counts.""" + cost_mapping = frozendict.frozendict( + { + "test-model": ModelCost( + id="test-model", + pricing=ModelPricing( + prompt=decimal.Decimal("0.00001"), + completion=decimal.Decimal("0.00002"), + image=None, + request=None, + web_search=None, + internal_reasoning=None, + ), + ) + } + ) + + completion = chat_completion_types.ChatCompletion( + id="test-completion", + choices=[], + created=1234567890, + model="test-model", + object="chat.completion", + usage=chat_completion_types.CompletionUsage( + prompt_tokens=1000000, + completion_tokens=500000, + total_tokens=1500000, + ), + ) + + cost = get_llm_cost("test-model", completion, cost_mapping=cost_mapping) + + # 1000000 * 0.00001 + 500000 * 0.00002 = 10.0 + 10.0 = 20.0 + assert cost == decimal.Decimal("20.0")