-
Notifications
You must be signed in to change notification settings - Fork 10
/
retrieve_html_proxy_agent.py
181 lines (159 loc) · 7.14 KB
/
retrieve_html_proxy_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import asyncio
import json
from typing import Any, Callable, Dict, List, Optional, Union
import autogen
from dotenv import load_dotenv
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from autogen.agentchat.agent import Agent
from token_count import num_tokens_from_string
import websockets
# Maximum chunk size, measured in tokens. 15_000 is used so that a chunk
# comfortably fits in the OpenAI API limit of 16_000 tokens.
CHUNK_SIZE = 15_000
def get_html_chunks(html: str):
    """Split an HTML string into chunks sized against ``CHUNK_SIZE`` tokens.

    The text is cut on the ``>`` character and consecutive pieces are merged
    by ``CharacterTextSplitter`` with an overlap setting of 10.

    Args:
        html: the HTML source to split.

    Returns:
        The chunks produced by ``CharacterTextSplitter.split_text``.
    """
    splitter_options = {
        "separator": ">",
        "chunk_size": CHUNK_SIZE,
        "chunk_overlap": 10,
        # Lengths are measured in tokens (OpenAI's tokenizer), not characters.
        "length_function": num_tokens_from_string,
    }
    return CharacterTextSplitter(**splitter_options).split_text(html)
# TODO: avoid recomputing embeddings for HTML chunks that have not changed
def build_vectorstore(html_chunks: List[str]):
    """Embed HTML chunks with OpenAI embeddings and index them in FAISS.

    Args:
        html_chunks: the text chunks to embed and index.

    Returns:
        The FAISS vector store built from ``html_chunks``.

    Raises:
        ValueError: if ``html_chunks`` is empty, since there is nothing to
            index and the embedding dimension cannot be determined.
    """
    if not html_chunks:
        # Fail early with a clear message instead of an obscure error raised
        # from inside the vector store construction.
        raise ValueError("Cannot build a vector store from an empty list of HTML chunks")
    embeddings = OpenAIEmbeddings()
    return FAISS.from_texts(texts=html_chunks, embedding=embeddings)
# Load environment variables from a .env file (presumably the OpenAI
# credentials used by the embeddings and the agents -- confirm).
load_dotenv()
# Prompt template sent to the recipient agent. It is filled with
# str.format() using the `input_question` and `input_context` fields.
PROMPT_QA = """You're a retrieve augmented chatbot. You answer user's questions based on the HTML
context provided by the user. You must answer as concisely as possible.
User's question is: {input_question}
Context is: {input_context}
"""
class RetrieveHTMLProxyAgent(autogen.ConversableAgent):
    """
    An agent that fetches the relevant HTML content from a user query based
    on the HTML content of the current page on the browser.

    The page HTML is requested over a websocket connected to
    ``browser_console_uri``. When the HTML is shorter than ``CHUNK_SIZE``
    tokens it is used as context directly; otherwise it is chunked, embedded
    into a vector store, and the chunk most similar to the question is used.

    Attributes:
        browser_console_uri: websocket URI the HTML is requested from.
        websocket: the open websocket connection, or None when not connected.
        html: the HTML from which ``vectorstore`` was last built.
        vectorstore: vector store built from ``html``, or None if none yet.
    """

    def __init__(
        self,
        name: str,
        is_termination_msg: Optional[Callable[[Dict], bool]] = None,
        max_consecutive_auto_reply: Optional[int] = None,
        human_input_mode: Optional[str] = "ALWAYS",
        function_map: Optional[Dict[str, Callable]] = None,
        code_execution_config: Optional[Union[Dict, bool]] = None,
        default_auto_reply: Optional[Union[str, Dict, None]] = "",
        llm_config: Optional[Union[Dict, bool]] = False,
        system_message: Optional[str] = "",
        browser_console_uri: Optional[str] = "ws://localhost:3000",
    ):
        """Create the agent and connect to the browser console websocket.

        All arguments except ``browser_console_uri`` are forwarded unchanged
        to ``autogen.ConversableAgent``.

        Args:
            browser_console_uri: websocket URI used to request the HTML of
                the current page.
        """
        super().__init__(
            name=name,
            is_termination_msg=is_termination_msg,
            max_consecutive_auto_reply=max_consecutive_auto_reply,
            human_input_mode=human_input_mode,
            function_map=function_map,
            code_execution_config=code_execution_config,
            default_auto_reply=default_auto_reply,
            llm_config=llm_config,
            system_message=system_message,
        )
        self.browser_console_uri = browser_console_uri
        # Cache: the HTML that was last indexed and its vector store.
        self.html = ""
        self.vectorstore = None
        # Defined before connecting so the attribute always exists.
        self.websocket = None
        self.connect_websocket()

    @staticmethod
    def _get_event_loop():
        """Return the current event loop, creating and installing one if needed."""
        try:
            return asyncio.get_event_loop()
        except RuntimeError:
            # Newer Python versions raise instead of implicitly creating a loop.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            return loop

    def connect_websocket(self):
        """Open the websocket to ``browser_console_uri`` and store it on ``self.websocket``.

        Blocks until the connection is established.
        """
        self.websocket = self._get_event_loop().run_until_complete(
            websockets.connect(self.browser_console_uri)
        )

    async def fetch_html(self, **kwargs) -> str:
        """
        Fetch the HTML of the current page from the browser console

        Sends a ``fetchHTML`` action over the websocket and waits for a JSON
        response. Extra keyword arguments are accepted and ignored.

        Returns:
            The ``result`` field of the response.

        Raises:
            RuntimeError: if the response does not report success.
        """
        if getattr(self, "websocket", None) is None:
            # connect() must be awaited to obtain a usable connection.
            self.websocket = await websockets.connect(self.browser_console_uri)
        print("Fetching HTML of current page...")
        message = json.dumps({
            'action': "fetchHTML",
        })
        await self.websocket.send(message)
        response_data = json.loads(await self.websocket.recv())
        if not response_data.get('success'):
            raise RuntimeError("Failed to fetch HTML")
        return response_data["result"]

    def _retrieve_context(self, vectorstore, query: str) -> str:
        """
        Get the most relevant chunk using the user's question as a query

        Args:
            vectorstore: the store to search; must provide ``similarity_search``.
            query: the user's question.

        Returns:
            The text of the retrieved chunk(s), joined by blank lines.
        """
        relevant_chunks = vectorstore.similarity_search(query, k=1)
        print("Relevant chunks retrieved")
        return "\n\n".join(chunk.page_content for chunk in relevant_chunks)

    def _build_message_with_context(self, question: str) -> str:
        """
        Build a message with the context retrieved from the HTML using RAG

        Args:
            question: the user's question.

        Returns:
            ``PROMPT_QA`` filled with the question and the selected context.
        """
        html = self._get_event_loop().run_until_complete(self.fetch_html())

        if num_tokens_from_string(html) < CHUNK_SIZE:
            # Small page: the whole HTML fits, no retrieval needed.
            context = html
        else:
            # Rebuild the index when the page changed, or when no usable
            # index exists (e.g. a previous build attempt failed).
            if html != self.html or self.vectorstore is None:
                html_chunks = get_html_chunks(html)
                print("HTML chunked")
                print("n_chunks = ", len(html_chunks))
                self.vectorstore = build_vectorstore(html_chunks)
                print("Vectorstore built")
                # Record the HTML only once its index exists, so the cache
                # never claims an index that was not actually built.
                self.html = html
            context = self._retrieve_context(self.vectorstore, question)

        return PROMPT_QA.format(input_question=question, input_context=context)

    def _add_html_context(self, message: Union[Dict, str]) -> Union[Dict, str]:
        """Return ``message`` with its textual content wrapped in the HTML-context prompt.

        A string is treated as the question. For a dict, only a string
        ``content`` field is rewritten (on a copy, all other fields are kept).
        Any other message, e.g. one carrying only a ``function_call`` or a
        callable ``content``, is returned unchanged.
        """
        if isinstance(message, str):
            return self._build_message_with_context(message)
        if isinstance(message, dict) and isinstance(message.get("content"), str):
            return {**message, "content": self._build_message_with_context(message["content"])}
        return message

    def send(
        self,
        message: Union[Dict, str],
        recipient: Agent,
        request_reply: Optional[bool] = None,
        silent: Optional[bool] = False,
    ) -> bool:
        """Send a message to another agent.

        Before sending, the textual content of the message is wrapped in a
        prompt that also carries the relevant HTML context of the current page.

        Args:
            message (dict or str): message to be sent.
                The message could contain the following fields (either content or function_call must be provided):
                - content (str): the content of the message.
                - function_call (str): the name of the function to be called.
                - name (str): the name of the function to be called.
                - role (str): the role of the message, any role that is not "function"
                    will be modified to "assistant".
                - context (dict): the context of the message, which will be passed to
                    [Completion.create](../oai/Completion#create).
                    For example, one agent can send a message A as:
            ```python
            {
                "content": lambda context: context["use_tool_msg"],
                "context": {
                    "use_tool_msg": "Use tool X if they are relevant."
                }
            }
            ```
                    Next time, one agent can send a message B with a different "use_tool_msg".
                    Then the content of message A will be refreshed to the new "use_tool_msg".
                    So effectively, this provides a way for an agent to send a "link" and modify
                    the content of the "link" later.
            recipient (Agent): the recipient of the message.
            request_reply (bool or None): whether to request a reply from the recipient.
            silent (bool or None): (Experimental) whether to print the message sent.

        Returns:
            Nothing is returned explicitly; the ``bool`` annotation is kept
            unchanged for signature compatibility.

        Raises:
            ValueError: if the message can't be converted into a valid ChatCompletion message.
        """
        # add relevant HTML context to the message first
        message = self._add_html_context(message)
        valid = self._append_oai_message(message, "assistant", recipient)
        if not valid:
            raise ValueError(
                "Message can't be converted into a valid ChatCompletion message. Either content or function_call must be provided."
            )
        recipient.receive(message, self, request_reply, silent)