3
3
4
4
import os
5
5
import sys
6
+
6
7
sys .path .append (os .path .dirname (os .path .dirname (os .path .abspath (__file__ ))))
7
8
8
- from utils .search_utils import exa_search , extract_search_results , search_and_extract_results
9
9
from utils .pydantic_models import ResearchState
10
- from utils .tracing_metadata_utils import get_costs_by_prompt_type
10
+ from utils .search_utils import (
11
+ exa_search ,
12
+ extract_search_results ,
13
+ search_and_extract_results ,
14
+ )
11
15
12
16
13
17
def test_exa_cost_extraction ():
14
18
"""Test that Exa costs are properly extracted from API responses."""
15
19
print ("=== Testing Exa Cost Extraction ===" )
16
-
20
+
17
21
# Test with a simple query
18
22
query = "What is quantum computing?"
19
23
print (f"\n Searching for: { query } " )
20
-
24
+
21
25
# Test direct exa_search
22
26
results = exa_search (query , max_results = 2 )
23
- print (f"Direct exa_search returned exa_cost: ${ results .get ('exa_cost' , 0.0 ):.4f} " )
24
-
27
+ print (
28
+ f"Direct exa_search returned exa_cost: ${ results .get ('exa_cost' , 0.0 ):.4f} "
29
+ )
30
+
25
31
# Test extract_search_results
26
32
extracted , cost = extract_search_results (results , provider = "exa" )
27
33
print (f"extract_search_results returned cost: ${ cost :.4f} " )
28
34
print (f"Number of results extracted: { len (extracted )} " )
29
-
35
+
30
36
# Test search_and_extract_results
31
- results2 , cost2 = search_and_extract_results (query , max_results = 2 , provider = "exa" )
37
+ results2 , cost2 = search_and_extract_results (
38
+ query , max_results = 2 , provider = "exa"
39
+ )
32
40
print (f"search_and_extract_results returned cost: ${ cost2 :.4f} " )
33
41
print (f"Number of results: { len (results2 )} " )
34
-
42
+
35
43
return cost2 > 0
36
44
37
45
38
46
def test_research_state_cost_tracking ():
39
47
"""Test that ResearchState properly tracks costs."""
40
48
print ("\n === Testing ResearchState Cost Tracking ===" )
41
-
49
+
42
50
state = ResearchState (main_query = "Test query" )
43
-
51
+
44
52
# Simulate adding search costs
45
53
state .search_costs ["exa" ] = 0.05
46
- state .search_cost_details .append ({
47
- "provider" : "exa" ,
48
- "query" : "test query 1" ,
49
- "cost" : 0.02 ,
50
- "timestamp" : 1234567890.0 ,
51
- "step" : "test_step"
52
- })
53
- state .search_cost_details .append ({
54
- "provider" : "exa" ,
55
- "query" : "test query 2" ,
56
- "cost" : 0.03 ,
57
- "timestamp" : 1234567891.0 ,
58
- "step" : "test_step"
59
- })
60
-
54
+ state .search_cost_details .append (
55
+ {
56
+ "provider" : "exa" ,
57
+ "query" : "test query 1" ,
58
+ "cost" : 0.02 ,
59
+ "timestamp" : 1234567890.0 ,
60
+ "step" : "test_step" ,
61
+ }
62
+ )
63
+ state .search_cost_details .append (
64
+ {
65
+ "provider" : "exa" ,
66
+ "query" : "test query 2" ,
67
+ "cost" : 0.03 ,
68
+ "timestamp" : 1234567891.0 ,
69
+ "step" : "test_step" ,
70
+ }
71
+ )
72
+
61
73
print (f"Total Exa cost: ${ state .search_costs .get ('exa' , 0.0 ):.4f} " )
62
74
print (f"Number of search details: { len (state .search_cost_details )} " )
63
-
75
+
64
76
return True
65
77
66
78
67
79
def test_cost_aggregation ():
68
80
"""Test cost aggregation from multiple states."""
69
81
print ("\n === Testing Cost Aggregation ===" )
70
-
82
+
71
83
# Create multiple sub-states
72
84
state1 = ResearchState (main_query = "Test" )
73
85
state1 .search_costs ["exa" ] = 0.02
74
- state1 .search_cost_details .append ({
75
- "provider" : "exa" ,
76
- "query" : "query1" ,
77
- "cost" : 0.02 ,
78
- "timestamp" : 1234567890.0 ,
79
- "step" : "sub_step_1"
80
- })
81
-
86
+ state1 .search_cost_details .append (
87
+ {
88
+ "provider" : "exa" ,
89
+ "query" : "query1" ,
90
+ "cost" : 0.02 ,
91
+ "timestamp" : 1234567890.0 ,
92
+ "step" : "sub_step_1" ,
93
+ }
94
+ )
95
+
82
96
state2 = ResearchState (main_query = "Test" )
83
97
state2 .search_costs ["exa" ] = 0.03
84
- state2 .search_cost_details .append ({
85
- "provider" : "exa" ,
86
- "query" : "query2" ,
87
- "cost" : 0.03 ,
88
- "timestamp" : 1234567891.0 ,
89
- "step" : "sub_step_2"
90
- })
91
-
98
+ state2 .search_cost_details .append (
99
+ {
100
+ "provider" : "exa" ,
101
+ "query" : "query2" ,
102
+ "cost" : 0.03 ,
103
+ "timestamp" : 1234567891.0 ,
104
+ "step" : "sub_step_2" ,
105
+ }
106
+ )
107
+
92
108
# Simulate merge
93
109
merged_state = ResearchState (main_query = "Test" )
94
110
merged_state .search_costs = {}
95
111
merged_state .search_cost_details = []
96
-
112
+
97
113
for state in [state1 , state2 ]:
98
114
for provider , cost in state .search_costs .items ():
99
- merged_state .search_costs [provider ] = merged_state .search_costs .get (provider , 0.0 ) + cost
115
+ merged_state .search_costs [provider ] = (
116
+ merged_state .search_costs .get (provider , 0.0 ) + cost
117
+ )
100
118
merged_state .search_cost_details .extend (state .search_cost_details )
101
-
102
- print (f"Merged total cost: ${ merged_state .search_costs .get ('exa' , 0.0 ):.4f} " )
103
- print (f"Merged search details count: { len (merged_state .search_cost_details )} " )
104
-
105
- return merged_state .search_costs .get ('exa' , 0.0 ) == 0.05
119
+
120
+ print (
121
+ f"Merged total cost: ${ merged_state .search_costs .get ('exa' , 0.0 ):.4f} "
122
+ )
123
+ print (
124
+ f"Merged search details count: { len (merged_state .search_cost_details )} "
125
+ )
126
+
127
+ return merged_state .search_costs .get ("exa" , 0.0 ) == 0.05
106
128
107
129
108
130
def main ():
109
131
"""Run all tests."""
110
132
print ("Testing Exa Cost Tracking Implementation\n " )
111
-
133
+
112
134
# Check if Exa API key is set
113
135
if not os .getenv ("EXA_API_KEY" ):
114
136
print ("WARNING: EXA_API_KEY not set. Skipping real API tests." )
115
137
test_api = False
116
138
else :
117
139
test_api = True
118
-
140
+
119
141
tests_passed = 0
120
142
tests_total = 0
121
-
143
+
122
144
# Test 1: Exa cost extraction (only if API key is available)
123
145
if test_api :
124
146
tests_total += 1
@@ -130,7 +152,7 @@ def main():
130
152
print ("✗ Exa cost extraction test failed" )
131
153
except Exception as e :
132
154
print (f"✗ Exa cost extraction test failed with error: { e } " )
133
-
155
+
134
156
# Test 2: ResearchState cost tracking
135
157
tests_total += 1
136
158
try :
@@ -141,7 +163,7 @@ def main():
141
163
print ("✗ ResearchState cost tracking test failed" )
142
164
except Exception as e :
143
165
print (f"✗ ResearchState cost tracking test failed with error: { e } " )
144
-
166
+
145
167
# Test 3: Cost aggregation
146
168
tests_total += 1
147
169
try :
@@ -152,9 +174,9 @@ def main():
152
174
print ("✗ Cost aggregation test failed" )
153
175
except Exception as e :
154
176
print (f"✗ Cost aggregation test failed with error: { e } " )
155
-
177
+
156
178
print (f"\n Tests passed: { tests_passed } /{ tests_total } " )
157
-
179
+
158
180
if tests_passed == tests_total :
159
181
print ("\n ✅ All tests passed!" )
160
182
return 0
@@ -164,4 +186,4 @@ def main():
164
186
165
187
166
188
if __name__ == "__main__" :
167
- sys .exit (main ())
189
+ sys .exit (main ())
0 commit comments