Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make FunctionTools Serializable (Declarative) #5052

Merged
merged 14 commits into from
Jan 24, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,20 @@
"source": [
"# Serializing Components \n",
"\n",
"AutoGen provides a {py:class}`~autogen_core.Component` configuration class that defines behaviours for to serialize/deserialize component into declarative specifications. This is useful for debugging, visualizing, and even for sharing your work with others. In this notebook, we will demonstrate how to serialize multiple components to a declarative specification like a JSON file. \n",
"AutoGen provides a {py:class}`~autogen_core.Component` configuration class that defines behaviours to serialize/deserialize component into declarative specifications. We can accomplish this by calling `.dump_component()` and `.load_component()` respectively. This is useful for debugging, visualizing, and even for sharing your work with others. In this notebook, we will demonstrate how to serialize multiple components to a declarative specification like a JSON file. \n",
"\n",
"\n",
"```{note}\n",
"This is work in progress\n",
"``` \n",
"```{warning}\n",
"\n",
"We will be implementing declarative support for the following components:\n",
"ONLY LOAD COMPONENTS FROM TRUSTED SOURCES.\n",
"\n",
"- Termination conditions ✔️\n",
"- Tools \n",
"- Agents \n",
"- Teams \n",
"With serilized components, each component implements the logic for how it is serialized and deserialized - i.e., how the declarative specification is generated and how it is converted back to an object. \n",
"\n",
"In some cases, creating an object may include executing code (e.g., a serialized function). ONLY LOAD COMPONENTS FROM TRUSTED SOURCES. \n",
" \n",
"```\n",
"\n",
" \n",
"### Termination Condition Example \n",
"\n",
"In the example below, we will define termination conditions (a part of an agent team) in python, export this to a dictionary/json and also demonstrate how the termination condition object can be loaded from the dictionary/json. \n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class ImportFromModule:
module: str
imports: Tuple[Union[str, Alias], ...]

## backward compatibility
# backward compatibility
def __init__(
self,
module: str,
Expand Down Expand Up @@ -214,3 +214,11 @@ def to_stub(func: Union[Callable[..., Any], FunctionWithRequirementsStr]) -> str

content += " ..."
return content


def to_code(func: Union[FunctionWithRequirements[T, P], Callable[P, T], FunctionWithRequirementsStr]) -> str:
return _to_code(func)


def import_to_str(im: Import) -> str:
return _import_to_str(im)
9 changes: 7 additions & 2 deletions python/packages/autogen-core/src/autogen_core/tools/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing_extensions import NotRequired

from .. import CancellationToken
from .._component_config import ComponentBase
from .._function_utils import normalize_annotated_type

T = TypeVar("T", bound=BaseModel, contravariant=True)
Expand Down Expand Up @@ -56,7 +57,9 @@ def load_state_json(self, state: Mapping[str, Any]) -> None: ...
StateT = TypeVar("StateT", bound=BaseModel)


class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT]):
class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT], ComponentBase[BaseModel]):
component_type = "tool"

def __init__(
self,
args_type: Type[ArgsT],
Expand Down Expand Up @@ -132,7 +135,7 @@ def load_state_json(self, state: Mapping[str, Any]) -> None:
pass


class BaseToolWithState(BaseTool[ArgsT, ReturnT], ABC, Generic[ArgsT, ReturnT, StateT]):
class BaseToolWithState(BaseTool[ArgsT, ReturnT], ABC, Generic[ArgsT, ReturnT, StateT], ComponentBase[BaseModel]):
def __init__(
self,
args_type: Type[ArgsT],
Expand All @@ -144,6 +147,8 @@ def __init__(
super().__init__(args_type, return_type, name, description)
self._state_type = state_type

component_type = "tool"

@abstractmethod
def save_state(self) -> StateT: ...

Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,33 @@
import asyncio
import functools
from typing import Any, Callable
from textwrap import dedent
from typing import Any, Callable, Sequence
import warnings

from pydantic import BaseModel
from typing_extensions import Self

from .. import CancellationToken
from .._component_config import Component
from .._function_utils import (
args_base_model_from_signature,
get_typed_signature,
)
from ..code_executor._func_with_reqs import Import, import_to_str, to_code
from ._base import BaseTool


class FunctionTool(BaseTool[BaseModel, BaseModel]):
class FunctionToolConfig(BaseModel):
"""Configuration for a function tool."""

source_code: str
name: str
description: str
global_imports: Sequence[Import]
ekzhu marked this conversation as resolved.
Show resolved Hide resolved
has_cancellation_support: bool


class FunctionTool(BaseTool[BaseModel, BaseModel], Component[FunctionToolConfig]):
"""
Create custom tools by wrapping standard Python functions.

Expand Down Expand Up @@ -64,8 +79,14 @@ async def example():
asyncio.run(example())
"""

def __init__(self, func: Callable[..., Any], description: str, name: str | None = None) -> None:
component_provider_override = "autogen_core.tools.FunctionTool"
component_config_schema = FunctionToolConfig

def __init__(
self, func: Callable[..., Any], description: str, name: str | None = None, global_imports: Sequence[Import] = []
) -> None:
self._func = func
self._global_imports = global_imports
signature = get_typed_signature(func)
func_name = name or func.__name__
args_model = args_base_model_from_signature(func_name + "args", signature)
Expand Down Expand Up @@ -98,3 +119,52 @@ async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> A
result = await future

return result

def _to_config(self) -> FunctionToolConfig:
return FunctionToolConfig(
source_code=dedent(to_code(self._func)),
global_imports=self._global_imports,
name=self.name,
description=self.description,
has_cancellation_support=self._has_cancellation_support,
)

@classmethod
def _from_config(cls, config: FunctionToolConfig) -> Self:
warnings.warn(
"\n⚠️ SECURITY WARNING ⚠️\n"
"Loading a FunctionTool from config will execute code to import the provided global imports and and function code.\n"
"Only load configs from TRUSTED sources to prevent arbitrary code execution.",
UserWarning,
stacklevel=2,
)

exec_globals: dict[str, Any] = {}

# Execute imports first
for import_stmt in config.global_imports:
import_code = import_to_str(import_stmt)
try:
exec(import_code, exec_globals)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Failed to import {import_code}: Module not found. Please ensure the module is installed."
) from e
except ImportError as e:
raise ImportError(f"Failed to import {import_code}: {str(e)}") from e
except Exception as e:
raise RuntimeError(f"Unexpected error while importing {import_code}: {str(e)}") from e

# Execute function code
try:
exec(config.source_code, exec_globals)
func_name = config.source_code.split("def ")[1].split("(")[0]
except Exception as e:
raise ValueError(f"Could not compile and load function: {e}") from e

# Get function and verify it's callable
func: Callable[..., Any] = exec_globals[func_name]
if not callable(func):
raise TypeError(f"Expected function but got {type(func)}")

return cls(func, "", None)
69 changes: 68 additions & 1 deletion python/packages/autogen-core/tests/test_component_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from typing import Any, Dict

import pytest
from autogen_core import Component, ComponentBase, ComponentLoader, ComponentModel
from autogen_core import CancellationToken, Component, ComponentBase, ComponentLoader, ComponentModel
from autogen_core._component_config import _type_to_provider_str # type: ignore
from autogen_core.code_executor import ImportFromModule
from autogen_core.models import ChatCompletionClient
from autogen_core.tools import FunctionTool
from autogen_test_utils import MyInnerComponent, MyOuterComponent
from pydantic import BaseModel, ValidationError
from typing_extensions import Self
Expand Down Expand Up @@ -283,3 +285,68 @@ def test_component_version_from_dict() -> None:
assert comp.info == "test"
assert comp.__class__ == ComponentNonOneVersionWithUpgrade
assert comp.dump_component().version == 2


@pytest.mark.asyncio
async def test_function_tool() -> None:
"""Test FunctionTool with different function types and features."""

# Test sync and async functions
def sync_func(x: int, y: str) -> str:
return y * x

async def async_func(x: float, y: float, cancellation_token: CancellationToken) -> float:
if cancellation_token.is_cancelled():
raise Exception("Cancelled")
return x + y

# Create tools with different configurations
sync_tool = FunctionTool(
func=sync_func, description="Multiply string", global_imports=[ImportFromModule("typing", ("Dict",))]
)
invalid_import_sync_tool = FunctionTool(
func=sync_func, description="Multiply string", global_imports=[ImportFromModule("invalid_module (", ("Dict",))]
)

invalid_import_config = invalid_import_sync_tool.dump_component()
# check that invalid import raises an error
with pytest.raises(RuntimeError):
_ = FunctionTool.load_component(invalid_import_config, FunctionTool)

async_tool = FunctionTool(
func=async_func,
description="Add numbers",
name="custom_adder",
global_imports=[ImportFromModule("autogen_core", ("CancellationToken",))],
)

# Test serialization and config

sync_config = sync_tool.dump_component()
assert isinstance(sync_config, ComponentModel)
assert sync_config.config["name"] == "sync_func"
assert len(sync_config.config["global_imports"]) == 1
assert not sync_config.config["has_cancellation_support"]

async_config = async_tool.dump_component()
assert async_config.config["name"] == "custom_adder"
assert async_config.config["has_cancellation_support"]

# Test deserialization and execution
loaded_sync = FunctionTool.load_component(sync_config, FunctionTool)
loaded_async = FunctionTool.load_component(async_config, FunctionTool)

# Test execution and validation
token = CancellationToken()
assert await loaded_sync.run_json({"x": 2, "y": "test"}, token) == "testtest"
assert await loaded_async.run_json({"x": 1.5, "y": 2.5}, token) == 4.0

# Test error cases
with pytest.raises(ValueError):
# Type error
await loaded_sync.run_json({"x": "invalid", "y": "test"}, token)

cancelled_token = CancellationToken()
cancelled_token.cancel()
with pytest.raises(Exception, match="Cancelled"):
await loaded_async.run_json({"x": 1.0, "y": 2.0}, cancelled_token)
Loading