-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtest_chatbot.py
130 lines (97 loc) · 3.95 KB
/
test_chatbot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import yaml
# import mocks.openai as openai
import openai
import backend
from openai_chatbot import OpenAIChatbot
from grace_chatbot import GRACEChatbot
from datetime import datetime
from dotenv import load_dotenv
import pytest
# Load variables (e.g. OPENAI_API_KEY) from a .env file into os.environ.
load_dotenv()

# Suppress the tokenizers parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Raises KeyError at import/collection time if the key is not configured.
openai.api_key = os.environ["OPENAI_API_KEY"]

# Test configuration (read below as config["openai"]["model"] / ["endpoint"])
# and the chatbot domain description. Both paths are relative to the current
# working directory. An explicit encoding keeps decoding independent of the
# platform's locale default.
with open("config.yaml", "r", encoding="utf-8") as stream:
    config = yaml.safe_load(stream)

with open("domain.yaml", "r", encoding="utf-8") as stream:
    domain = yaml.safe_load(stream)
@pytest.fixture
def customer_prompt_template() -> str:
    """Return the prompt template used to drive the simulated customer bot.

    ``{business_name}`` and ``{business_description}`` are filled from the
    module-level ``domain`` mapping. The returned string still contains a
    ``{task_description}`` placeholder that each test fills with ``str.format``.
    """
    # The result is formatted a second time by the tests. Escape braces in the
    # domain's string values so text such as "{...}" in a business description
    # is not mistaken for a placeholder on that second pass.
    escaped_domain = {
        key: (value.replace("{", "{{").replace("}", "}}")
              if isinstance(value, str) else value)
        for key, value in domain.items()
    }
    return """You are a customer of {business_name}, {business_description}. You are chatting to the restaurant's AI assistant. {{task_description}}
A transcript of your chat session with the AI assistant follows.
""".format(**escaped_domain)
def test_book_table(customer_prompt_template):
    """Starting from no bookings, a customer books a table for 3 people.

    After the simulated conversation, the backend must hold exactly one
    booking with the requested name, party size and date/time.
    """
    backend.bookings = {}
    # Plain string: the former f-prefix had no placeholders (ruff F541).
    task_description = (
        "You are looking to book a table on the name of Jeremiah Biggs, "
        "for 3 people at 8 pm on June 23, 2023. "
        "You don't provide all of this information at once but rather "
        "respond to the AI assistant's prompts."
    )
    customer_prompt = customer_prompt_template.format(
        task_description=task_description)
    _run_session(customer_prompt)
    assert list(backend.bookings.values()) == [{
        "full_name": "Jeremiah Biggs",
        "num_people": 3,
        "time": datetime(2023, 6, 23, 20, 0, 0),
    }]
def test_change_booking(customer_prompt_template):
    """An existing booking is amended to 3 people at 17:30 on the same day.

    The backend is seeded with a single booking. After the conversation it must
    still contain exactly one booking, carrying the updated party size and time.
    """
    booking_ref = "S8W308"
    seeded_booking = {
        "full_name": "Ann Hicks",
        "num_people": 4,
        "time": datetime(2023, 7, 14, 18, 0, 0),
    }
    backend.bookings = {booking_ref: seeded_booking}

    prompt = customer_prompt_template.format(task_description=(
        f"You'd like to change a table booking with reference {booking_ref} "
        "that you made earlier. "
        "You're looking to change it from 4 people to 3 people and from 6 PM to 5:30 PM. "
        "You don't provide all of this information at once but rather "
        "respond to the AI assistant's prompts."
    ))
    _run_session(prompt)

    expected_booking = {
        "full_name": "Ann Hicks",
        "num_people": 3,
        "time": datetime(2023, 7, 14, 17, 30, 0),
    }
    assert list(backend.bookings.values()) == [expected_booking]
def test_cancel_booking(customer_prompt_template):
    """A booking cancelled using only its reference is removed from the backend."""
    booking_ref = "ZBA4HB"
    seeded_booking = {
        "full_name": "Mary Ashcroft",
        "num_people": 5,
        "time": datetime(2023, 6, 2, 18, 15, 0),
    }
    backend.bookings = {booking_ref: seeded_booking}

    prompt = customer_prompt_template.format(task_description=(
        f"You are looking to cancel your booking with reference {booking_ref}. "
        "The reference is all information you have about the booking."
    ))
    _run_session(prompt)

    assert booking_ref not in backend.bookings
def _drain(buffer: list) -> list:
    """Return a copy of *buffer*'s items and empty *buffer* in place."""
    items = list(buffer)
    buffer.clear()
    return items


def _run_session(customer_prompt: str, max_turns: int = 30):
    """Run a simulated chat between the GRACE assistant and a customer bot.

    The AI side is a ``GRACEChatbot`` wired to ``backend.backend`` and the
    module-level ``domain``. The customer side is an ``OpenAIChatbot`` primed
    with ``customer_prompt``. Utterances are passed back and forth until the
    AI chatbot reports that its session has ended. Each batch is printed to
    help debug failing runs.

    Args:
        customer_prompt: Initial prompt for the simulated customer. The AI's
            opening utterances are appended to it, each on a new line
            prefixed with ``AI: ``.
        max_turns: Maximum number of AI/customer exchanges allowed before
            the session is treated as stuck.

    Raises:
        AssertionError: If the AI chatbot has not ended the session after
            ``max_turns`` exchanges.
    """
    # Each bot's output_callback appends to its own buffer. The buffers are
    # never rebound; _drain hands their contents to the other bot and empties them.
    ai_utterances = []
    customer_utterances = []

    ai_chatbot = GRACEChatbot(
        openai=openai,
        backend=backend.backend,
        domain=domain,
        output_callback=lambda u: ai_utterances.append(u),
        openai_model=config["openai"]["model"],
        openai_endpoint=config["openai"]["endpoint"],
    )
    ai_chatbot.start_session()

    # The AI's opening utterances become part of the customer's initial
    # transcript instead of being sent via send_responses.
    greeting = _drain(ai_utterances)
    customer_prompt += "".join("\nAI: " + u for u in greeting)
    print(greeting)

    customer_chatbot = OpenAIChatbot(
        openai=openai,
        initial_prompt=customer_prompt,
        output_callback=lambda u: customer_utterances.append(u),
        names=("Customer", "AI"),
        openai_model=config["openai"]["model"],
        openai_endpoint=config["openai"]["endpoint"],
    )
    customer_chatbot.start_session()

    turns = 0
    while not ai_chatbot.session_ended():
        # Guard against a conversation that never terminates, which would
        # otherwise hang the test run while issuing API calls indefinitely.
        if turns >= max_turns:
            raise AssertionError(
                f"AI chatbot did not end the session within {max_turns} turns"
            )
        turns += 1

        customer_batch = _drain(customer_utterances)
        ai_chatbot.send_responses(customer_batch)
        print(customer_batch)

        ai_batch = _drain(ai_utterances)
        customer_chatbot.send_responses(ai_batch)
        print(ai_batch)