Skip to content

Commit fcacbf5

Browse files
committed
Formatting
1 parent 6a1eb77 commit fcacbf5

File tree

3 files changed

+155
-113
lines changed

3 files changed

+155
-113
lines changed

deep_research/design/test_exa_cost_tracking.py

Lines changed: 81 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -3,122 +3,144 @@
33

44
import os
55
import sys
6+
67
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
78

8-
from utils.search_utils import exa_search, extract_search_results, search_and_extract_results
99
from utils.pydantic_models import ResearchState
10-
from utils.tracing_metadata_utils import get_costs_by_prompt_type
10+
from utils.search_utils import (
11+
exa_search,
12+
extract_search_results,
13+
search_and_extract_results,
14+
)
1115

1216

1317
def test_exa_cost_extraction():
1418
"""Test that Exa costs are properly extracted from API responses."""
1519
print("=== Testing Exa Cost Extraction ===")
16-
20+
1721
# Test with a simple query
1822
query = "What is quantum computing?"
1923
print(f"\nSearching for: {query}")
20-
24+
2125
# Test direct exa_search
2226
results = exa_search(query, max_results=2)
23-
print(f"Direct exa_search returned exa_cost: ${results.get('exa_cost', 0.0):.4f}")
24-
27+
print(
28+
f"Direct exa_search returned exa_cost: ${results.get('exa_cost', 0.0):.4f}"
29+
)
30+
2531
# Test extract_search_results
2632
extracted, cost = extract_search_results(results, provider="exa")
2733
print(f"extract_search_results returned cost: ${cost:.4f}")
2834
print(f"Number of results extracted: {len(extracted)}")
29-
35+
3036
# Test search_and_extract_results
31-
results2, cost2 = search_and_extract_results(query, max_results=2, provider="exa")
37+
results2, cost2 = search_and_extract_results(
38+
query, max_results=2, provider="exa"
39+
)
3240
print(f"search_and_extract_results returned cost: ${cost2:.4f}")
3341
print(f"Number of results: {len(results2)}")
34-
42+
3543
return cost2 > 0
3644

3745

3846
def test_research_state_cost_tracking():
3947
"""Test that ResearchState properly tracks costs."""
4048
print("\n=== Testing ResearchState Cost Tracking ===")
41-
49+
4250
state = ResearchState(main_query="Test query")
43-
51+
4452
# Simulate adding search costs
4553
state.search_costs["exa"] = 0.05
46-
state.search_cost_details.append({
47-
"provider": "exa",
48-
"query": "test query 1",
49-
"cost": 0.02,
50-
"timestamp": 1234567890.0,
51-
"step": "test_step"
52-
})
53-
state.search_cost_details.append({
54-
"provider": "exa",
55-
"query": "test query 2",
56-
"cost": 0.03,
57-
"timestamp": 1234567891.0,
58-
"step": "test_step"
59-
})
60-
54+
state.search_cost_details.append(
55+
{
56+
"provider": "exa",
57+
"query": "test query 1",
58+
"cost": 0.02,
59+
"timestamp": 1234567890.0,
60+
"step": "test_step",
61+
}
62+
)
63+
state.search_cost_details.append(
64+
{
65+
"provider": "exa",
66+
"query": "test query 2",
67+
"cost": 0.03,
68+
"timestamp": 1234567891.0,
69+
"step": "test_step",
70+
}
71+
)
72+
6173
print(f"Total Exa cost: ${state.search_costs.get('exa', 0.0):.4f}")
6274
print(f"Number of search details: {len(state.search_cost_details)}")
63-
75+
6476
return True
6577

6678

6779
def test_cost_aggregation():
6880
"""Test cost aggregation from multiple states."""
6981
print("\n=== Testing Cost Aggregation ===")
70-
82+
7183
# Create multiple sub-states
7284
state1 = ResearchState(main_query="Test")
7385
state1.search_costs["exa"] = 0.02
74-
state1.search_cost_details.append({
75-
"provider": "exa",
76-
"query": "query1",
77-
"cost": 0.02,
78-
"timestamp": 1234567890.0,
79-
"step": "sub_step_1"
80-
})
81-
86+
state1.search_cost_details.append(
87+
{
88+
"provider": "exa",
89+
"query": "query1",
90+
"cost": 0.02,
91+
"timestamp": 1234567890.0,
92+
"step": "sub_step_1",
93+
}
94+
)
95+
8296
state2 = ResearchState(main_query="Test")
8397
state2.search_costs["exa"] = 0.03
84-
state2.search_cost_details.append({
85-
"provider": "exa",
86-
"query": "query2",
87-
"cost": 0.03,
88-
"timestamp": 1234567891.0,
89-
"step": "sub_step_2"
90-
})
91-
98+
state2.search_cost_details.append(
99+
{
100+
"provider": "exa",
101+
"query": "query2",
102+
"cost": 0.03,
103+
"timestamp": 1234567891.0,
104+
"step": "sub_step_2",
105+
}
106+
)
107+
92108
# Simulate merge
93109
merged_state = ResearchState(main_query="Test")
94110
merged_state.search_costs = {}
95111
merged_state.search_cost_details = []
96-
112+
97113
for state in [state1, state2]:
98114
for provider, cost in state.search_costs.items():
99-
merged_state.search_costs[provider] = merged_state.search_costs.get(provider, 0.0) + cost
115+
merged_state.search_costs[provider] = (
116+
merged_state.search_costs.get(provider, 0.0) + cost
117+
)
100118
merged_state.search_cost_details.extend(state.search_cost_details)
101-
102-
print(f"Merged total cost: ${merged_state.search_costs.get('exa', 0.0):.4f}")
103-
print(f"Merged search details count: {len(merged_state.search_cost_details)}")
104-
105-
return merged_state.search_costs.get('exa', 0.0) == 0.05
119+
120+
print(
121+
f"Merged total cost: ${merged_state.search_costs.get('exa', 0.0):.4f}"
122+
)
123+
print(
124+
f"Merged search details count: {len(merged_state.search_cost_details)}"
125+
)
126+
127+
return merged_state.search_costs.get("exa", 0.0) == 0.05
106128

107129

108130
def main():
109131
"""Run all tests."""
110132
print("Testing Exa Cost Tracking Implementation\n")
111-
133+
112134
# Check if Exa API key is set
113135
if not os.getenv("EXA_API_KEY"):
114136
print("WARNING: EXA_API_KEY not set. Skipping real API tests.")
115137
test_api = False
116138
else:
117139
test_api = True
118-
140+
119141
tests_passed = 0
120142
tests_total = 0
121-
143+
122144
# Test 1: Exa cost extraction (only if API key is available)
123145
if test_api:
124146
tests_total += 1
@@ -130,7 +152,7 @@ def main():
130152
print("✗ Exa cost extraction test failed")
131153
except Exception as e:
132154
print(f"✗ Exa cost extraction test failed with error: {e}")
133-
155+
134156
# Test 2: ResearchState cost tracking
135157
tests_total += 1
136158
try:
@@ -141,7 +163,7 @@ def main():
141163
print("✗ ResearchState cost tracking test failed")
142164
except Exception as e:
143165
print(f"✗ ResearchState cost tracking test failed with error: {e}")
144-
166+
145167
# Test 3: Cost aggregation
146168
tests_total += 1
147169
try:
@@ -152,9 +174,9 @@ def main():
152174
print("✗ Cost aggregation test failed")
153175
except Exception as e:
154176
print(f"✗ Cost aggregation test failed with error: {e}")
155-
177+
156178
print(f"\nTests passed: {tests_passed}/{tests_total}")
157-
179+
158180
if tests_passed == tests_total:
159181
print("\n✅ All tests passed!")
160182
return 0
@@ -164,4 +186,4 @@ def main():
164186

165187

166188
if __name__ == "__main__":
167-
sys.exit(main())
189+
sys.exit(main())

0 commit comments

Comments
 (0)