refactor LLMAgent
elisalimli committed Apr 29, 2024
1 parent 325b00a commit 24a0904
Showing 1 changed file with 35 additions and 33 deletions.
68 changes: 35 additions & 33 deletions libs/superagent/app/agents/llm.py
@@ -74,6 +74,22 @@ async def call_tool(


class LLMAgent(AgentBase):
    _streaming_callback: CustomAsyncIteratorCallbackHandler

    @property
    def streaming_callback(self):
        return self._streaming_callback

    def _set_streaming_callback(
        self, callbacks: list[CustomAsyncIteratorCallbackHandler]
    ):
        for callback in callbacks:
            if isinstance(callback, CustomAsyncIteratorCallbackHandler):
                return callback

        # if still not found, raise an error
        raise Exception("Streaming Callback not found")

    @property
    def tools(self):
        tools = get_tools(
@@ -104,7 +120,7 @@ def prompt(self):
        return prompt

    @property
    def _is_tool_calling_supported(self):
    def _tool_calling_supported(self):
        (model, custom_llm_provider, _, _) = get_llm_provider(self.llm_data.model)
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider
@@ -122,7 +138,7 @@ async def _stream_by_lines(self, output: str):
        await self.streaming_callback.on_llm_new_token(output_by_lines[0])

    async def get_agent(self):
        if self._is_tool_calling_supported:
        if self._tool_calling_supported:
            logger.info("Using native function calling")
            return AgentExecutor(**self.__dict__)

@@ -193,20 +209,22 @@ def _transform_completion_to_streaming(self, res):
            choice.delta = choice.message
        return [res]

    @property
    def _stream_directly(self):
        return (
            self.enable_streaming
            and self.llm_data.llm.provider
            not in self.NOT_TOOLS_STREAMING_SUPPORTED_PROVIDERS
            and self.llm_data.llm.provider != LLMProvider.ANTHROPIC
        )

    async def _completion(self, **kwargs) -> Any:
        logger.info(f"Calling LLM with kwargs: {kwargs}")
        new_messages = self.messages

        if kwargs.get("stream"):
            await self.streaming_callback.on_llm_start()

        should_stream_directly = (
            self.enable_streaming
            and self.llm_data.llm.provider
            not in self.NOT_TOOLS_STREAMING_SUPPORTED_PROVIDERS
            and self.llm_data.llm.provider != LLMProvider.ANTHROPIC
        )

        # TODO: Remove this when Groq and Bedrock support streaming with tools
        if self.llm_data.llm.provider in self.NOT_TOOLS_STREAMING_SUPPORTED_PROVIDERS:
            logger.info(
@@ -239,7 +257,7 @@ async def _completion(self, **kwargs) -> Any:

            if content:
                output += content
                if should_stream_directly:
                if self._stream_directly:
                    await self.streaming_callback.on_llm_new_token(content)

        self.messages = new_messages
@@ -249,7 +267,7 @@ async def _completion(self, **kwargs) -> Any:

        output = self._cleanup_output(output)

        if not should_stream_directly:
        if not self._stream_directly:
            await self._stream_by_lines(output)

        if self.enable_streaming:
@@ -271,12 +289,7 @@ async def ainvoke(self, input, *_, **kwargs):
        ]

        if self.enable_streaming:
            for callback in kwargs["config"]["callbacks"]:
                if isinstance(callback, CustomAsyncIteratorCallbackHandler):
                    self.streaming_callback = callback

            if not self.streaming_callback:
                raise Exception("Streaming Callback not found")
            self._set_streaming_callback()

        output = await self._completion(
            model=self.llm_data.model,
@@ -327,13 +340,9 @@ def messages(self):
    async def ainvoke(self, input, *_, **kwargs):
        self.input = input
        tool_results = []
        if self.enable_streaming:
            for callback in kwargs["config"]["callbacks"]:
                if isinstance(callback, CustomAsyncIteratorCallbackHandler):
                    self.streaming_callback = callback

            if not self.streaming_callback:
                raise Exception("Streaming Callback not found")
        if self.enable_streaming:
            self._set_streaming_callback()

        if len(self.tools) > 0:
            openai_llm = await prisma.llm.find_first(
@@ -397,22 +406,15 @@ async def ainvoke(self, input, *_, **kwargs):

        output = ""
        if self.enable_streaming:
            streaming_callback = None
            for callback in kwargs["config"]["callbacks"]:
                if isinstance(callback, CustomAsyncIteratorCallbackHandler):
                    streaming_callback = callback

            if not streaming_callback:
                raise Exception("Streaming Callback not found")
            await streaming_callback.on_llm_start()
            await self.streaming_callback.on_llm_start()

            for chunk in res:
                token = chunk.choices[0].delta.content
                if token:
                    output += token
                    await streaming_callback.on_llm_new_token(token)
                    await self.streaming_callback.on_llm_new_token(token)

            streaming_callback.done.set()
            self.streaming_callback.done.set()
        else:
            output = res.choices[0].message.content

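The first hunk extracts the repeated scan over the callback list into a shared _set_streaming_callback helper, fronted by a streaming_callback property. As rendered here the helper returns the matching handler while both call sites invoke it without arguments, so the exact wiring is not visible in this diff; below is a minimal sketch (not the committed code) of how the pieces presumably fit together, assuming the handler ends up stored on the instance and that the callbacks still come from kwargs["config"]["callbacks"] as in the removed code:

# Illustrative sketch only; anything not named in the diff above is an assumption.
class CustomAsyncIteratorCallbackHandler:  # stand-in for superagent's real handler
    async def on_llm_start(self) -> None: ...
    async def on_llm_new_token(self, token: str) -> None: ...


class LLMAgentSketch:
    enable_streaming: bool = True
    _streaming_callback: CustomAsyncIteratorCallbackHandler

    @property
    def streaming_callback(self) -> CustomAsyncIteratorCallbackHandler:
        return self._streaming_callback

    def _set_streaming_callback(self, callbacks: list) -> None:
        # Keep the first streaming handler found; every later streaming call
        # goes through the streaming_callback property.
        for callback in callbacks:
            if isinstance(callback, CustomAsyncIteratorCallbackHandler):
                self._streaming_callback = callback
                return
        # if still not found, raise an error
        raise Exception("Streaming Callback not found")

    async def ainvoke(self, input, *_, **kwargs):
        if self.enable_streaming:
            # Hypothetical call site: forward the config callbacks to the helper.
            self._set_streaming_callback(kwargs["config"]["callbacks"])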

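The renamed _tool_calling_supported property builds on litellm's get_llm_provider and get_supported_openai_params helpers, but the hunk cuts off before its return statement. A standalone sketch of the check it presumably performs; the final "tools" membership test is an assumption, not taken from the diff:

# Sketch of the capability check; only the two litellm calls appear in the diff.
from litellm import get_llm_provider, get_supported_openai_params


def tool_calling_supported(model_name: str) -> bool:
    (model, custom_llm_provider, _, _) = get_llm_provider(model_name)
    supported_params = get_supported_openai_params(
        model=model, custom_llm_provider=custom_llm_provider
    )
    # Assumed: native function calling is used only when the provider accepts
    # a "tools" parameter.
    return "tools" in (supported_params or [])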
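Finally, the should_stream_directly flag that was local to _completion becomes the _stream_directly property, consulted in two places: token-by-token streaming while the response is generated, and the line-by-line replay via _stream_by_lines for providers that cannot stream when tools are involved. A condensed, self-contained sketch of that control flow, with the provider values, token source, and callback stub all standing in for the real ones:

# Condensed control-flow sketch; provider names and the token source are placeholders.
class _TokenRecorder:
    """Stand-in for the streaming callback: records whatever it is handed."""

    def __init__(self) -> None:
        self.emitted: list[str] = []

    async def on_llm_new_token(self, token: str) -> None:
        self.emitted.append(token)


class StreamingFlowSketch:
    # Placeholder values for the provider set referenced in the diff.
    NOT_TOOLS_STREAMING_SUPPORTED_PROVIDERS = {"GROQ", "BEDROCK"}

    def __init__(self, provider: str, enable_streaming: bool = True) -> None:
        self.provider = provider
        self.enable_streaming = enable_streaming
        self.streaming_callback = _TokenRecorder()

    @property
    def _stream_directly(self) -> bool:
        # True only when the provider can stream while tools are in play
        # (the diff additionally excludes LLMProvider.ANTHROPIC here).
        return (
            self.enable_streaming
            and self.provider not in self.NOT_TOOLS_STREAMING_SUPPORTED_PROVIDERS
            and self.provider != "ANTHROPIC"
        )

    async def _stream_by_lines(self, output: str) -> None:
        # Simplified replay; the real method splits the output and forwards it
        # line by line through the same callback.
        for line in output.splitlines():
            await self.streaming_callback.on_llm_new_token(line)

    async def _completion(self, chunks: list[str]) -> str:
        output = ""
        for content in chunks:
            output += content
            if self._stream_directly:
                # Providers that support it get each token as it arrives.
                await self.streaming_callback.on_llm_new_token(content)
        if not self._stream_directly:
            # Everyone else gets the finished output replayed afterwards.
            await self._stream_by_lines(output)
        return output

For a provider in the unsupported set the callback only fires after the completion finishes, which is the fallback path the diff keeps behind the new property.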
