Skip to content

Commit 3cae6af

Browse files
authored
new functions added to conversations (#126)
* doing something! * fix mypy * fix * fix * onward * wip * killing it * nicely failing test! * better * passing tests * onward * try this * next * use standardized types * getting there! * next * onward * onward * killing it
1 parent 26e7a60 commit 3cae6af

File tree

13 files changed

+390
-111
lines changed

13 files changed

+390
-111
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,4 @@ lerna-debug.log*
4141

4242
# pycache
4343
__pycache__
44+
.mypy_cache/
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
-- CreateTable
2+
CREATE TABLE "FunctionDefined" (
3+
"id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
4+
"message_id" TEXT NOT NULL,
5+
"function_id" INTEGER NOT NULL,
6+
CONSTRAINT "FunctionDefined_message_id_fkey" FOREIGN KEY ("message_id") REFERENCES "conversation_message" ("id") ON DELETE RESTRICT ON UPDATE CASCADE
7+
);
8+
9+
-- CreateTable
10+
CREATE TABLE "WebhookDefined" (
11+
"id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
12+
"message_id" TEXT NOT NULL,
13+
"webhook_id" TEXT NOT NULL,
14+
CONSTRAINT "WebhookDefined_message_id_fkey" FOREIGN KEY ("message_id") REFERENCES "conversation_message" ("id") ON DELETE RESTRICT ON UPDATE CASCADE
15+
);
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
Warnings:
3+
4+
- You are about to drop the column `function_id` on the `FunctionDefined` table. All the data in the column will be lost.
5+
- You are about to drop the column `webhook_id` on the `WebhookDefined` table. All the data in the column will be lost.
6+
- Added the required column `functionPublicId` to the `FunctionDefined` table without a default value. This is not possible if the table is not empty.
7+
- Added the required column `webhookPublicId` to the `WebhookDefined` table without a default value. This is not possible if the table is not empty.
8+
9+
*/
10+
-- RedefineTables
11+
PRAGMA foreign_keys=OFF;
12+
CREATE TABLE "new_FunctionDefined" (
13+
"id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
14+
"message_id" TEXT NOT NULL,
15+
"functionPublicId" TEXT NOT NULL,
16+
CONSTRAINT "FunctionDefined_message_id_fkey" FOREIGN KEY ("message_id") REFERENCES "conversation_message" ("id") ON DELETE RESTRICT ON UPDATE CASCADE
17+
);
18+
INSERT INTO "new_FunctionDefined" ("id", "message_id") SELECT "id", "message_id" FROM "FunctionDefined";
19+
DROP TABLE "FunctionDefined";
20+
ALTER TABLE "new_FunctionDefined" RENAME TO "FunctionDefined";
21+
CREATE TABLE "new_WebhookDefined" (
22+
"id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
23+
"message_id" TEXT NOT NULL,
24+
"webhookPublicId" TEXT NOT NULL,
25+
CONSTRAINT "WebhookDefined_message_id_fkey" FOREIGN KEY ("message_id") REFERENCES "conversation_message" ("id") ON DELETE RESTRICT ON UPDATE CASCADE
26+
);
27+
INSERT INTO "new_WebhookDefined" ("id", "message_id") SELECT "id", "message_id" FROM "WebhookDefined";
28+
DROP TABLE "WebhookDefined";
29+
ALTER TABLE "new_WebhookDefined" RENAME TO "WebhookDefined";
30+
PRAGMA foreign_key_check;
31+
PRAGMA foreign_keys=ON;
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
Warnings:
3+
4+
- You are about to drop the `FunctionDefined` table. If the table is not empty, all the data it contains will be lost.
5+
- You are about to drop the `WebhookDefined` table. If the table is not empty, all the data it contains will be lost.
6+
7+
*/
8+
-- DropTable
9+
PRAGMA foreign_keys=off;
10+
DROP TABLE "FunctionDefined";
11+
PRAGMA foreign_keys=on;
12+
13+
-- DropTable
14+
PRAGMA foreign_keys=off;
15+
DROP TABLE "WebhookDefined";
16+
PRAGMA foreign_keys=on;
17+
18+
-- CreateTable
19+
CREATE TABLE "function_defined" (
20+
"id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
21+
"message_id" TEXT NOT NULL,
22+
"functionPublicId" TEXT NOT NULL,
23+
CONSTRAINT "function_defined_message_id_fkey" FOREIGN KEY ("message_id") REFERENCES "conversation_message" ("id") ON DELETE RESTRICT ON UPDATE CASCADE
24+
);
25+
26+
-- CreateTable
27+
CREATE TABLE "webhook_defined" (
28+
"id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
29+
"message_id" TEXT NOT NULL,
30+
"webhookPublicId" TEXT NOT NULL,
31+
CONSTRAINT "webhook_defined_message_id_fkey" FOREIGN KEY ("message_id") REFERENCES "conversation_message" ("id") ON DELETE RESTRICT ON UPDATE CASCADE
32+
);

prisma/schema.prisma

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,36 @@ model WebhookHandle {
8080
model ConversationMessage {
8181
id String @id @default(uuid())
8282
createdAt DateTime @default(now())
83-
user User @relation(fields: [userId], references: [id])
84-
userId Int @map("user_id")
85-
name String @default("") // not used for now. in future, will allow the user to have multiple conversations
86-
role String // assistant or user
87-
content String
83+
user User @relation(fields: [userId], references: [id])
84+
userId Int @map("user_id")
85+
name String @default("") // not used for now. in future, will allow the user to have multiple conversations
86+
role String // assistant or user
87+
content String
88+
functions FunctionDefined[]
89+
webhooks WebhookDefined[]
8890
8991
@@map("conversation_message")
9092
}
9193

94+
model FunctionDefined {
95+
id Int @id @default(autoincrement())
96+
message ConversationMessage @relation(fields: [messageId], references: [id])
97+
messageId String @map("message_id")
98+
functionPublicId String // could be a relation but isn't currently
99+
100+
@@map("function_defined")
101+
}
102+
103+
104+
model WebhookDefined {
105+
id Int @id @default(autoincrement())
106+
message ConversationMessage @relation(fields: [messageId], references: [id])
107+
messageId String @map("message_id")
108+
webhookPublicId String // could be a relation but isn't currently
109+
110+
@@map("webhook_defined")
111+
}
112+
92113
// customizable prompt per user
93114
model SystemPrompt {
94115
id String @id @default(uuid())

science/completion.py

Lines changed: 131 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import copy
12
import requests
23
import openai
3-
from typing import List, Dict, Optional
4+
import flask
5+
from typing import List, Dict, Optional, Set
46
from prisma import get_client
57
from prisma.models import ConversationMessage, SystemPrompt
68

@@ -9,7 +11,9 @@
911

1012
from utils import (
1113
FunctionDto,
14+
MessageDict,
1215
WebhookDto,
16+
clear_conversation,
1317
func_path_with_args,
1418
func_path,
1519
store_message,
@@ -37,7 +41,9 @@
3741
"the poly api library does not provide",
3842
"the poly api library doesn't provide",
3943
}
40-
NO_FUNCTION_ANSWER = "Poly doesn't know any functions to do that yet. But Poly would love to be taught!"
44+
NO_FUNCTION_ANSWER = (
45+
"Poly doesn't know any functions to do that yet. But Poly would love to be taught!"
46+
)
4147

4248

4349
ADVERBS = {
@@ -50,7 +56,7 @@ def question_processing(question: str) -> str:
5056
# return question
5157

5258

53-
def answer_processing(from_openai: str):
59+
def answer_processing(from_openai: str) -> str:
5460
lowered = from_openai.strip().lower()
5561

5662
# first strip off any common adverbs
@@ -77,9 +83,7 @@ def get_function_completion_answer(user_id: Optional[int], question: str) -> str
7783
if messages:
7884
return get_conversation_answer(user_id, messages, question)
7985
else:
80-
functions = get_function_prompt()
81-
webhooks = get_webhook_prompt()
82-
return get_completion_answer(user_id, functions, webhooks, question)
86+
return get_completion_answer(user_id, question)
8387

8488

8589
def get_conversations_for_user(user_id: Optional[int]) -> List[ConversationMessage]:
@@ -94,10 +98,16 @@ def get_conversations_for_user(user_id: Optional[int]) -> List[ConversationMessa
9498
)
9599

96100

97-
def get_function_prompt() -> str:
98-
how_to_import = "To import the Poly API library, use `import poly from 'polyapi';`"
99-
preface = "Here are the functions in the Poly API library,"
100-
parts: List[str] = [how_to_import, preface]
101+
def get_function_message_dict(
102+
already_defined: Optional[Set[str]] = None,
103+
) -> Optional[MessageDict]:
104+
"""get all functions (if any) that are not yet defined in the prompt
105+
:param already_defined: a list of function public ids that are already defined
106+
"""
107+
already_defined = already_defined or set()
108+
109+
preface = "Here are some functions in the Poly API library,"
110+
parts: List[str] = [preface]
101111

102112
db = get_client()
103113
user = db.user.find_first(where={"role": "ADMIN"})
@@ -108,14 +118,28 @@ def get_function_prompt() -> str:
108118
resp = requests.get(f"{NODE_API_URL}/functions", headers=headers)
109119
assert resp.status_code == 200, resp.content
110120
funcs: List[FunctionDto] = resp.json()
121+
122+
public_ids = []
111123
for func in funcs:
124+
if func["id"] in already_defined:
125+
continue
126+
112127
parts.append(f"// {func['description']}\n{func_path_with_args(func)}")
128+
public_ids.append(func["id"])
129+
130+
content = "\n\n".join(parts)
131+
return {"role": "assistant", "content": content, "function_ids": public_ids}
113132

114-
return "\n\n".join(parts)
115133

134+
def get_webhook_message_dict(
135+
already_defined: Optional[Set[str]] = None,
136+
) -> Optional[MessageDict]:
137+
"""get all webhooks (if any) that are not yet defined in the prompt
138+
:param already_defined: a list of webhook public ids that are already defined
139+
"""
140+
already_defined = already_defined or set()
116141

117-
def get_webhook_prompt() -> str:
118-
preface = "Here are the event handlers in the Poly API library,"
142+
preface = "Here are some event handlers in the Poly API library,"
119143
parts: List[str] = [preface]
120144

121145
db = get_client()
@@ -126,11 +150,27 @@ def get_webhook_prompt() -> str:
126150
headers = {"Content-Type": "application/json", "X-PolyApiKey": user.apiKey}
127151
resp = requests.get(f"{NODE_API_URL}/webhooks/", headers=headers)
128152
assert resp.status_code == 200, resp.content
153+
154+
public_ids = []
129155
webhooks: List[WebhookDto] = resp.json()
156+
130157
for webhook in webhooks:
158+
if webhook["id"] in already_defined:
159+
continue
160+
131161
parts.append(webhook_prompt(webhook))
162+
public_ids.append(webhook["id"])
163+
164+
if not public_ids:
165+
# all the webhooks are already defined!
166+
# let's go ahead and skip
167+
return None
132168

133-
return "\n\n".join(parts)
169+
return {
170+
"role": "assistant",
171+
"content": "\n\n".join(parts),
172+
"webhook_ids": public_ids,
173+
}
134174

135175

136176
def get_fine_tune_answer(question: str):
@@ -159,33 +199,84 @@ def webhook_prompt(hook: WebhookDto) -> str:
159199
def get_conversation_answer(
160200
user_id: Optional[int], messages: List[ConversationMessage], question: str
161201
):
162-
priors: List[Dict[str, str]] = []
202+
# prepare payload
203+
priors: List[MessageDict] = []
163204
for message in messages:
164205
priors.append({"role": message.role, "content": message.content})
165206

166-
question_message = {"role": "user", "content": question}
167-
resp = openai.ChatCompletion.create(
168-
model="gpt-3.5-turbo",
169-
messages=priors + [question_message],
170-
)
207+
new_messages = get_new_conversation_messages(messages, question)
208+
209+
# get
210+
try:
211+
resp = get_chat_completion(priors + new_messages)
212+
except openai.InvalidRequestError as e:
213+
# our conversation is probably too long! let's transparently nuke it and start again
214+
flask.current_app.log_exception(e) # type: ignore
215+
clear_conversation(user_id)
216+
return get_completion_answer(user_id, question)
171217

172218
answer = answer_processing(resp["choices"][0]["message"]["content"])
173-
store_message(user_id, question_message)
219+
220+
# store
221+
for message in new_messages:
222+
store_message(user_id, message)
174223
store_message(user_id, {"role": "assistant", "content": answer})
224+
175225
return answer
176226

177227

178-
def get_completion_prompt_messages(
179-
functions: str, webhooks: str, question: str
180-
) -> List[Dict]:
228+
def get_new_conversation_messages(
    old_messages: List[ConversationMessage], question: str
) -> List[MessageDict]:
    """Get all the new messages that should be added to an existing conversation.

    Looks up which function and webhook public ids were already attached to
    the prior messages, then builds messages describing only the ones not yet
    defined, followed by the user's question.

    :param old_messages: the messages already stored for this conversation
    :param question: the new question asked by the user
    :return: zero or one function message, zero or one webhook message,
        and finally the user question message (always last)
    """
    rv: List[MessageDict] = []

    old_msg_ids = [m.id for m in old_messages]

    db = get_client()
    old_function_ids = {
        f.functionPublicId
        for f in db.functiondefined.find_many(
            where={"messageId": {"in": old_msg_ids}}
        )
    }
    old_webhook_ids = {
        w.webhookPublicId
        for w in db.webhookdefined.find_many(
            where={"messageId": {"in": old_msg_ids}}
        )
    }

    # functions must be resolved with the *function* helper; previously the
    # webhook helper was called here, so new functions were never added and
    # webhooks could be appended twice
    new_functions = get_function_message_dict(old_function_ids)
    # skip a message that carries only the preface and no actual functions
    if new_functions and new_functions.get("function_ids"):
        rv.append(new_functions)

    new_webhooks = get_webhook_message_dict(old_webhook_ids)
    if new_webhooks:
        rv.append(new_webhooks)

    question_msg = MessageDict(role="user", content=question)
    rv.append(question_msg)
    return rv
250+
251+
252+
def get_chat_completion(messages: List[MessageDict]) -> Dict:
    """Send the messages to OpenAI and return the chat completion response.

    The caller's messages are left untouched: a deep copy is made and the
    bookkeeping keys we keep for internal use are removed from the copy
    before it is sent.

    :param messages: the conversation messages, possibly carrying the
        internal "function_ids" / "webhook_ids" keys
    :return: the response from `openai.ChatCompletion.create`
    """
    internal_keys = ("function_ids", "webhook_ids")

    outbound = copy.deepcopy(messages)
    for msg in outbound:
        # drop the data we only use internally; absent keys are ignored
        for key in internal_keys:
            msg.pop(key, None)

    return openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=outbound)
265+
266+
267+
def get_completion_prompt_messages(question: str) -> List[MessageDict]:
268+
function_message = get_function_message_dict()
269+
webhook_message = get_webhook_message_dict()
270+
181271
rv = [
182-
{"role": "system", "content": "Include argument types. Be concise."},
183-
{"role": "assistant", "content": functions},
184-
{"role": "assistant", "content": webhooks},
185-
# HACK To try to prevent Poly hallucinating functions we don't have
186-
# https://github.com/polyapi/poly-alpha/issues/96
187-
# {"role": "assistant", "content": "Only respond with functions and event handlers explicitly listed as part of the Poly API library. Do not use external APIs."},
188-
{"role": "user", "content": question},
272+
MessageDict(role="system", content="Include argument types. Be concise."),
273+
MessageDict(
274+
role="assistant",
275+
content="To import the Poly API library, use `import poly from 'polyapi';`",
276+
),
277+
function_message,
278+
webhook_message,
279+
MessageDict(role="user", content=question),
189280
]
190281

191282
system_prompt = get_system_prompt()
@@ -202,18 +293,18 @@ def get_system_prompt() -> Optional[SystemPrompt]:
202293
return system_prompt
203294

204295

205-
def get_completion_answer(
206-
user_id: Optional[int], functions: str, webhooks: str, question: str
207-
) -> str:
208-
messages = get_completion_prompt_messages(functions, webhooks, question)
209-
210-
model = "gpt-3.5-turbo"
211-
# print(f"Using model: {model}")
212-
resp = openai.ChatCompletion.create(model=model, messages=messages)
296+
def get_completion_answer(user_id: Optional[int], question: str) -> str:
297+
messages = get_completion_prompt_messages(question)
298+
resp = get_chat_completion(messages)
213299
answer = answer_processing(resp["choices"][0]["message"]["content"])
214300

215301
for message in messages:
216-
store_message(user_id, message)
302+
store_message(
303+
user_id,
304+
message,
305+
function_ids=message.get("function_ids", []),
306+
webhook_ids=message.get("webhook_ids", []),
307+
)
217308
store_message(user_id, {"role": "assistant", "content": answer})
218309

219310
return answer

0 commit comments

Comments
 (0)