-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_assistant.py
322 lines (300 loc) · 11.6 KB
/
test_assistant.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
import json
from collections import Counter
import openai
from openai import OpenAI
from odds.common.config import config
from odds.common.vectordb import indexer
from odds.common.store import store
from odds.common.embedder import embedder
from odds.common.catalog_repo import catalog_repo
from odds.common.cost_collector import CostCollector
from odds.common.llm.openai.openai_llm_runner import OpenAILLMRunner
import sqlite3
import asyncio
import time
# Display name for the OpenAI assistant created in main().
ASSISTANT_NAME = 'Open data fact checker'
# System prompt for the assistant. It describes the four tools declared in
# TOOLS below; the tool names listed here must stay in sync with TOOLS and
# with the dispatch in loop().
ASSISTANT_INSTRUCTIONS = """
You are a fact checker for a news organization.
Your main focus is to verify the accuracy of claims made in news articles, by using _only_ public data from governmental open data portals.
You have these tools you might use:
- search_datasets: Search for relevant datasets using semantic search
- fetch_dataset: Retrieve full information about a dataset (based on the dataset's id), including its metadata and the names and ids of the resources it contains.
- fetch_resource: Retrieve full information about a resource (based on the resource's id), including its metadata and its DB schema (so you can query it)
- query_resource_database: Query a resource using an SQL query (you need to fetch the DB schema first in order to do a query)
Your goal is to provide an assessment of the accuracy of each relevant claim in the text, based on the data you find.
You should also provide a confidence score for your assessment.
You avoid asking for any additional data, and you avoid using any data that is not publicly available.
All your responses should be based on the data you find in the open data portals, and include references to the datasets you used.
"""
# Function-tool declarations passed to the Assistants API. Each entry follows
# the OpenAI tool schema: type='function' plus a JSON-Schema 'parameters'
# object. The names here are dispatched on in loop().
TOOLS = [
    dict(
        type='function',
        function=dict(
            name='search_datasets',
            description='Find the metadata of relevant datasets using semantic search',
            parameters=dict(
                type='object',
                properties=dict(
                    query=dict(
                        type='string',
                        description='The query string to use to search for datasets. Multiple query terms are allowed, separated by a comma.'
                    ),
                ),
                required=['query']
            )
        )
    ),
    dict(
        type='function',
        function=dict(
            name='fetch_dataset',
            description='Get the full metadata for a single dataset, including the list of its resources',
            parameters=dict(
                type='object',
                properties=dict(
                    dataset_id=dict(
                        type='string',
                        description='The dataset ID to fetch.'
                    ),
                ),
                required=['dataset_id']
            )
        )
    ),
    dict(
        type='function',
        function=dict(
            name='fetch_resource',
            description='Get the full metadata for a single resource in a single dataset ',
            parameters=dict(
                type='object',
                properties=dict(
                    dataset_id=dict(
                        type='string',
                        description='The dataset id containing this resource'
                    ),
                    resource_id=dict(
                        type='string',
                        description='The resource ID to fetch.'
                    ),
                ),
                required=['dataset_id', 'resource_id']
            )
        )
    ),
    dict(
        type='function',
        function=dict(
            name='query_resource_database',
            description='Perform an SQL query on a resource',
            parameters=dict(
                type='object',
                properties=dict(
                    resource_id=dict(
                        type='string',
                        description='The resource ID to query.'
                    ),
                    query=dict(
                        type='string',
                        description='SQLite compatible query to perform on the resource'
                    ),
                ),
                # Fix: was ['id', 'query'], but the declared property is
                # 'resource_id' — the model was never required to supply it.
                required=['resource_id', 'query']
            )
        )
    ),
]
async def search_datasets(query: str):
    """Semantic search over the dataset index.

    Embeds *query*, looks up matching datasets in the vector index and
    returns a list of summary dicts (id, title, description, publisher,
    catalog title) suitable for passing back to the assistant as tool output.
    """
    print('SEARCH DATASETS:', query)
    embedding = await embedder.embed(query)
    datasets = await indexer.findDatasets(embedding)
    # Resolve each dataset's owning catalog so we can report its title.
    catalogs = [catalog_repo.get_catalog(ds.catalogId) for ds in datasets]
    response = []
    for ds, cat in zip(datasets, catalogs):
        response.append(dict(
            id=ds.storeId(),
            title=ds.better_title or ds.title,
            description=ds.better_description or ds.description,
            publisher=ds.publisher,
            catalog=cat.title,
        ))
    print('RESPONSE:', response)
    return response
async def fetch_dataset(id):
    """Fetch one dataset's metadata plus a listing of its resources.

    Resource ids are synthesized as '<dataset_id>##<index>' — the format
    expected by fetch_sesource and query_db. Returns None when the dataset
    is not found in the store.
    """
    print('FETCH DATASET:', id)
    dataset = await store.getDataset(id)
    if not dataset:
        print('RESPONSE:', None)
        return None
    resources = []
    for i, resource in enumerate(dataset.resources):
        resources.append(dict(
            id=f'{id}##{i}',
            name=resource.title,
            num_rows=resource.row_count,
        ))
    response = dict(
        title=dataset.better_title or dataset.title,
        description=dataset.better_description or dataset.description,
        publisher=dataset.publisher,
        publisher_description=dataset.publisher_description,
        resources=resources,
    )
    print('RESPONSE:', response)
    return response
async def fetch_resource(id):
    """Fetch one resource's metadata, field statistics and DB schema.

    *id* is a composite '<dataset_id>##<resource_index>' string as produced
    by fetch_dataset. Returns None when the dataset or resource is missing.
    """
    print('FETCH RESOURCE:', id)
    datasetId, resourceIdx = id.split('##')
    resourceIdx = int(resourceIdx)
    dataset = await store.getDataset(datasetId)
    response = None
    if dataset:
        resource = dataset.resources[resourceIdx]
        if resource:
            response = dict(
                name=resource.title,
                fields=[
                    dict(
                        name=field.name,
                        type=field.data_type,
                        max=field.max_value,
                        min=field.min_value,
                        sample_values=field.sample_values,
                    )
                    for field in resource.fields
                ],
                # db_schema lets the model write SQL for query_resource_database.
                db_schema=resource.db_schema
            )
    print('RESPONSE:', response)
    return response
# Backward-compatible alias: the function was originally misspelled
# 'fetch_sesource'; keep the old name so existing call sites still work.
fetch_sesource = fetch_resource
async def query_db(resource_id, query):
    """Run an SQL query against a resource's SQLite database.

    *resource_id* is a '<dataset_id>##<resource_index>' composite id.
    Returns dict(success=True, data=[row dicts]) on success,
    dict(success=False, error=...) on a query error, or None when the
    dataset/resource/DB file cannot be found.
    """
    print('QUERY DB:', resource_id, query)
    datasetId, resourceIdx = resource_id.split('##')
    resourceIdx = int(resourceIdx)
    dataset = await store.getDataset(datasetId)
    if dataset:
        resource = dataset.resources[resourceIdx]
        if resource:
            dbFile = await store.getDB(resource, dataset)
            if dbFile is None:
                print('FAILED TO FIND DB', dbFile)
                return None
            con = None
            try:
                con = sqlite3.connect(dbFile)
                cur = con.cursor()
                cur.execute(query)
                # Convert rows to dicts keyed by the cursor's column names.
                headers = [col[0] for col in cur.description]
                data = [dict(zip(headers, row)) for row in cur.fetchall()]
                print('GOT DATA', data)
                return dict(success=True, data=data)
            except Exception as e:
                print('FAILED TO QUERY DB', dbFile, repr(e))
                return dict(success=False, error=str(e))
            finally:
                # Fix: the connection was previously never closed (leaked
                # on every call); close it on success and failure alike.
                if con is not None:
                    con.close()
    return None
async def loop(client, thread, run, usage):
    """Drive an assistant run to completion, servicing tool calls.

    Polls the run state: on 'requires_action', executes each requested
    function tool locally and submits the JSON-encoded outputs back;
    accumulates token costs into *usage*. Returns True when the run
    completes, False on any failure or unexpected terminal status.
    """
    while True:
        print('RUN:', run.status)
        if run.status == 'completed':
            return True
        elif run.status == 'requires_action':
            tool_outputs = []
            for tool in run.required_action.submit_tool_outputs.tool_calls:
                if not tool.type == 'function':
                    continue
                print(f'TOOL - {tool.type}: {tool.function.name}({tool.function.arguments})')
                arguments = json.loads(tool.function.arguments)
                function_name = tool.function.name
                if function_name == 'search_datasets':
                    output = await search_datasets(arguments['query'])
                elif function_name == 'fetch_dataset':
                    output = await fetch_dataset(arguments['dataset_id'])
                elif function_name == 'fetch_resource':
                    output = await fetch_sesource(arguments['resource_id'])
                elif function_name == 'query_resource_database':
                    output = await query_db(arguments['resource_id'], arguments['query'])
                else:
                    # Fix: an unknown tool name previously left `output`
                    # unbound (UnboundLocalError) or reused a stale value.
                    output = dict(success=False, error=f'unknown function: {function_name}')
                tool_outputs.append(dict(
                    tool_call_id=tool.id,
                    output=json.dumps(output, ensure_ascii=False),
                ))
            # Submit all tool outputs at once after collecting them in a list
            if tool_outputs:
                try:
                    run = client.beta.threads.runs.submit_tool_outputs_and_poll(
                        thread_id=thread.id,
                        run_id=run.id,
                        tool_outputs=tool_outputs,
                    )
                    if run.usage:
                        usage.update_cost('expensive', 'prompt', run.usage.prompt_tokens)
                        usage.update_cost('expensive', 'completion', run.usage.completion_tokens)
                    print("Tool outputs submitted successfully.")
                    continue
                except Exception as e:
                    print("Failed to submit tool outputs:", e)
            else:
                return False
        else:
            # Terminal or unexpected status (failed, cancelled, expired, ...).
            print(run.status)
        return False
def main():
    """End-to-end fact-checking run.

    Creates the assistant and a thread, posts the article from
    'article2.txt' as the user message, drives the run via loop(), prints
    the full conversation, then deletes the thread and assistant and
    reports total token cost. Cleanup runs in the finally block even if
    the run loop raises.
    """
    client = OpenAI(
        api_key=config.credentials.openai.key,
        organization=config.credentials.openai.org,
        project=config.credentials.openai.proj,
    )
    assistant = client.beta.assistants.create(
        name=ASSISTANT_NAME,
        instructions=ASSISTANT_INSTRUCTIONS,
        model="gpt-4o",
        tools=TOOLS,
        temperature=0,
    )
    thread = client.beta.threads.create()
    # Fix: read via a context manager instead of leaking the file handle.
    with open('article2.txt') as f:
        article = f.read()
    client.beta.threads.messages.create(
        thread_id=thread.id,
        role='user',
        content='Please verify the claims in this article:\n\n' + article,
    )
    usage = CostCollector('assistant', OpenAILLMRunner.COSTS)
    run = client.beta.threads.runs.create_and_poll(
        thread_id=thread.id,
        assistant_id=assistant.id,
        temperature=0,
    )
    if run.usage:
        usage.update_cost('expensive', 'prompt', run.usage.prompt_tokens)
        usage.update_cost('expensive', 'completion', run.usage.completion_tokens)
    success = False
    try:
        success = asyncio.run(loop(client, thread, run, usage))
    finally:
        print('SUCCESS:', success)
        messages = client.beta.threads.messages.list(
            thread_id=thread.id, order='asc'
        )
        for message in messages:
            print(f'{message.role}:')
            for block in message.content:
                if block.type == 'text':
                    print(block.text.value)
                else:
                    print(block.type)
                    print(block)
        print('CLEANUP')
        client.beta.threads.delete(thread.id)
        client.beta.assistants.delete(assistant.id)
        usage.print_total_usage()
if __name__ == '__main__':
    main()