Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added .DS_Store
Binary file not shown.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,5 @@ cython_debug/

# Node stuff
node_modules/

.DS_Store
54 changes: 39 additions & 15 deletions agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
import time
from rich.console import Console
from rich.table import Table
from pathlib import Path

from computers import EnvState, Computer
from function_registry import FunctionRegistry

MAX_RECENT_TURN_WITH_SCREENSHOTS = 3
PREDEFINED_COMPUTER_USE_FUNCTIONS = [
Expand All @@ -55,11 +57,6 @@
FunctionResponseT = Union[EnvState, dict]


def multiply_numbers(x: float, y: float) -> dict:
    """Multiplies two numbers."""
    # NOTE: the docstring above is kept verbatim; this function is passed to
    # FunctionDeclaration.from_callable, which presumably uses it as the
    # description shown to the model.
    product = x * y
    return {"result": product}


class BrowserAgent:
def __init__(
self,
Expand All @@ -79,6 +76,14 @@ def __init__(
project=os.environ.get("VERTEXAI_PROJECT"),
location=os.environ.get("VERTEXAI_LOCATION"),
)
config_path = os.environ.get(
"FUNCTION_CONFIG_PATH",
str(Path(__file__).parent / "config" / "functions.json"),
)
self._function_registry = FunctionRegistry(
config_path=config_path,
client=self._client,
)
self._contents: list[Content] = [
Content(
role="user",
Expand All @@ -91,13 +96,7 @@ def __init__(
# Exclude any predefined functions here.
excluded_predefined_functions = []

# Add your own custom functions here.
custom_functions = [
# For example:
types.FunctionDeclaration.from_callable(
client=self._client, callable=multiply_numbers
)
]
custom_functions = self._function_registry.function_declarations()

self._generate_content_config = GenerateContentConfig(
temperature=1,
Expand Down Expand Up @@ -190,9 +189,15 @@ def handle_action(self, action: types.FunctionCall) -> FunctionResponseT:
destination_x=destination_x,
destination_y=destination_y,
)
# Handle the custom function declarations here.
elif action.name == multiply_numbers.__name__:
return multiply_numbers(x=action.args["x"], y=action.args["y"])
elif self._function_registry.has_function(action.name):
if not self._function_registry.is_whitelisted(action.name):
if not self._confirm_custom_function(action):
termcolor.cprint(
f"Custom function {action.name} denied by user.",
color="yellow",
)
return {"status": "rejected", "reason": "user_denied"}
return self._function_registry.execute(action.name, action.args)
else:
raise ValueError(f"Unsupported function: {action}")

Expand Down Expand Up @@ -389,6 +394,25 @@ def run_one_iteration(self) -> Literal["COMPLETE", "CONTINUE"]:

return "CONTINUE"

def _confirm_custom_function(
    self, action: types.FunctionCall
) -> bool:
    """Ask the user to approve a custom function that is not whitelisted.

    Prints the function name, its arguments and, when the registry has one,
    the configured risk note, then keeps prompting until a recognised
    yes/no answer is entered.

    Args:
        action: The function call the model requested.

    Returns:
        True if the user answered "y", "ye" or "yes" (case-insensitive);
        False if the user answered "n" or "no".
    """
    termcolor.cprint(
        "Custom function requires confirmation!",
        color="yellow",
        attrs=["bold"],
    )
    print(f"Function: {action.name}")
    print(f"Args: {action.args}")

    note = self._function_registry.risk_note(action.name)
    if note:
        print(f"Risk: {note}")

    approvals = ("y", "ye", "yes")
    refusals = ("n", "no")
    # Re-prompt until the answer is one of the accepted spellings.
    while True:
        answer = input("Do you wish to execute? [Yes]/[No]\n").lower()
        if answer in approvals:
            return True
        if answer in refusals:
            return False

def _get_safety_confirmation(
self, safety: dict[str, Any]
) -> Literal["CONTINUE", "TERMINATE"]:
Expand Down
13 changes: 13 additions & 0 deletions config/functions.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"functions": [
{
"name": "multiply_numbers",
"module": "custom_functions.math",
"attribute": "multiply_numbers",
"description": "Multiply two numbers and return the product.",
"whitelist": true,
"risk_note": "Safe arithmetic operation."
}
]
}

2 changes: 2 additions & 0 deletions custom_functions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Package of custom function plugins; modules here are referenced from config/functions.json.

4 changes: 4 additions & 0 deletions custom_functions/math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def multiply_numbers(x: float, y: float) -> dict:
    """Multiplies two numbers."""
    # NOTE: docstring kept verbatim; the registry builds the function
    # declaration from this callable, which presumably reads the docstring.
    return dict(result=x * y)

152 changes: 152 additions & 0 deletions function_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import importlib
import inspect
import json
import os
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional

import termcolor
from google.genai import types


@dataclass
class FunctionSpec:
    """Configuration for a single custom function.

    The three identifying fields are required. The remaining fields default
    to the same values the config loader uses when a key is absent, so a spec
    built with only the required fields is never whitelisted.
    """

    # Name the function is registered and looked up under.
    name: str
    # Dotted path of the module passed to importlib.import_module.
    module: str
    # Attribute fetched from the imported module to obtain the callable.
    attribute: str
    # Optional description; used as the callable's docstring when it has none.
    description: Optional[str] = None
    # True when the function may run without asking the user first.
    whitelist: bool = False
    # Optional note printed to the user in the confirmation prompt.
    risk_note: Optional[str] = None


class FunctionRegistry:
    """Loads custom functions from configuration and exposes declarations/execution.

    The config file is JSON shaped like ``{"functions": [{...}, ...]}``. Each
    entry is turned into a ``FunctionSpec`` and the callable it points to is
    imported. Problems with the file or with a single entry are printed and
    skipped, so constructing the registry does not fail on a bad config.
    """

    def __init__(self, config_path: str, client: Any):
        """Create the registry and immediately load ``config_path``.

        Args:
            config_path: Path of the JSON function config.
            client: Client object forwarded to
                ``types.FunctionDeclaration.from_callable``.
        """
        self._config_path = config_path
        self._client = client
        self._specs: Dict[str, "FunctionSpec"] = {}
        self._callables: Dict[str, Callable] = {}
        self._load_config()

    @staticmethod
    def _parse_whitelist(value: Any) -> bool:
        """Return True only for the boolean ``True``.

        Whitelisting skips the user confirmation prompt, so anything that is
        not literally ``true`` in the config (for example the string
        ``"false"``, which is truthy) must not enable it.
        """
        return value is True

    def _build_spec(self, entry: Any) -> Optional["FunctionSpec"]:
        """Validate one config entry and convert it to a ``FunctionSpec``.

        Returns:
            The spec, or None (after printing the reason) when the entry is
            not an object, lacks ``name``/``module``, or those are not strings.
        """
        if not isinstance(entry, dict):
            termcolor.cprint(
                f"Invalid function config entry (expected an object): {entry}",
                color="red",
            )
            return None

        try:
            name = entry["name"]
            module = entry["module"]
        except KeyError as exc:
            termcolor.cprint(
                f"Invalid function config entry missing required key {exc}: {entry}",
                color="red",
            )
            return None

        if not isinstance(name, str) or not isinstance(module, str):
            termcolor.cprint(
                f"Invalid function config entry (name and module must be strings): {entry}",
                color="red",
            )
            return None

        raw_whitelist = entry.get("whitelist", False)
        if not isinstance(raw_whitelist, bool):
            # Fail closed: the function stays usable but needs confirmation.
            termcolor.cprint(
                f"Ignoring non-boolean whitelist value {raw_whitelist!r} for {name}; "
                "confirmation will be required.",
                color="yellow",
            )

        return FunctionSpec(
            name=name,
            module=module,
            attribute=entry.get("attribute", name),
            description=entry.get("description"),
            whitelist=self._parse_whitelist(raw_whitelist),
            risk_note=entry.get("risk_note"),
        )

    def _load_config(self) -> None:
        """Load function specs and resolve callables."""
        if not os.path.exists(self._config_path):
            termcolor.cprint(
                f"Function config not found at {self._config_path}; no custom functions loaded.",
                color="yellow",
            )
            return

        try:
            with open(self._config_path, "r", encoding="utf-8") as config_file:
                config = json.load(config_file) or {}
        except (OSError, ValueError) as exc:
            # ValueError covers both malformed JSON and undecodable bytes.
            termcolor.cprint(
                f"Failed to read function config {self._config_path}: {exc}",
                color="red",
            )
            return

        if not isinstance(config, dict):
            termcolor.cprint(
                f"Function config {self._config_path} must contain a JSON object; "
                "no custom functions loaded.",
                color="red",
            )
            return

        # ``"functions": null`` is treated the same as a missing key.
        entries = config.get("functions") or []
        if not isinstance(entries, list):
            termcolor.cprint(
                f"'functions' in {self._config_path} must be a list; "
                "no custom functions loaded.",
                color="red",
            )
            return

        for entry in entries:
            spec = self._build_spec(entry)
            if spec is None:
                continue

            resolved = self._import_callable(spec)
            if resolved is not None:
                self._specs[spec.name] = spec
                self._callables[spec.name] = resolved

    def _import_callable(self, spec: "FunctionSpec") -> Optional[Callable]:
        """Import callable from module according to spec.

        Returns:
            The callable, or None (after printing the reason) when the module
            cannot be imported, the attribute is missing, or it is not callable.
        """
        try:
            module = importlib.import_module(spec.module)
        except Exception as exc:
            # Importing runs arbitrary module code, so any exception type is
            # possible here; report it and skip this function.
            termcolor.cprint(
                f"Failed to import module {spec.module} for {spec.name}: {exc}",
                color="red",
            )
            return None

        try:
            func = getattr(module, spec.attribute)
        except AttributeError:
            termcolor.cprint(
                f"Attribute {spec.attribute} not found in module {spec.module}",
                color="red",
            )
            return None

        if not callable(func):
            termcolor.cprint(
                f"{spec.attribute} in module {spec.module} is not callable",
                color="red",
            )
            return None

        # Only fill in a docstring when the callable does not have one.
        if spec.description and not (func.__doc__ and func.__doc__.strip()):
            func.__doc__ = spec.description
        return func

    def function_declarations(self) -> "List[types.FunctionDeclaration]":
        """Create function declarations for all loaded functions.

        Functions whose declaration cannot be built are reported and left out
        of the returned list.
        """
        declarations = []
        for name, func in self._callables.items():
            try:
                declarations.append(
                    types.FunctionDeclaration.from_callable(
                        client=self._client,
                        callable=func,
                    )
                )
            except Exception as exc:
                termcolor.cprint(
                    f"Failed to build declaration for {name}: {exc}",
                    color="red",
                )
        return declarations

    def has_function(self, name: str) -> bool:
        """Return True when ``name`` was loaded successfully."""
        return name in self._callables

    def is_whitelisted(self, name: str) -> bool:
        """Return True when ``name`` is registered and marked as whitelisted."""
        spec = self._specs.get(name)
        return bool(spec is not None and spec.whitelist)

    def risk_note(self, name: str) -> Optional[str]:
        """Return the configured risk note for ``name``, or None if absent."""
        spec = self._specs.get(name)
        if not spec:
            return None
        return spec.risk_note

    def execute(self, name: str, args: Optional[dict]) -> dict:
        """Call the registered function ``name`` with keyword arguments ``args``.

        Args:
            name: Registered function name.
            args: Keyword arguments for the call; ``None`` means no arguments.

        Returns:
            Whatever the underlying callable returns.

        Raises:
            ValueError: If ``name`` is not registered.
            TypeError: If ``args`` do not match the callable's signature.
        """
        if name not in self._callables:
            raise ValueError(f"Function {name} is not registered.")

        func = self._callables[name]
        # A call without arguments may arrive with args=None.
        call_args = {} if args is None else args
        signature = inspect.signature(func)
        try:
            bound_args = signature.bind(**call_args)
            bound_args.apply_defaults()
        except TypeError as exc:
            termcolor.cprint(
                f"Invalid arguments for {name}: {exc}",
                color="red",
            )
            raise
        return func(*bound_args.args, **bound_args.kwargs)

44 changes: 43 additions & 1 deletion test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import tempfile
import unittest
from unittest.mock import MagicMock, patch
from google.genai import types
from agent import BrowserAgent, multiply_numbers
from agent import BrowserAgent
from computers import EnvState
from function_registry import FunctionRegistry
from custom_functions.math import multiply_numbers

class TestBrowserAgent(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -62,6 +66,44 @@ def test_handle_action_navigate(self):
self.agent.handle_action(action)
self.mock_browser_computer.navigate.assert_called_once_with("https://example.com")

def test_function_registry_load_and_execute(self):
    """Registry loads a function from a JSON config file and executes it."""
    config_payload = {
        "functions": [
            {
                "name": "multiply_numbers",
                "module": "custom_functions.math",
                "attribute": "multiply_numbers",
                "description": "Multiply two numbers.",
                "whitelist": True,
            }
        ]
    }
    with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as temp_config:
        json.dump(config_payload, temp_config)
        temp_path = temp_config.name
    # Register removal right away so the file is deleted even when an
    # assertion below fails; the file is already closed (and flushed) here.
    self.addCleanup(os.remove, temp_path)

    registry = FunctionRegistry(config_path=temp_path, client=MagicMock())
    self.assertTrue(registry.has_function("multiply_numbers"))
    self.assertTrue(registry.is_whitelisted("multiply_numbers"))
    self.assertEqual(
        registry.execute("multiply_numbers", {"x": 2, "y": 3}),
        {"result": 6},
    )

@patch("agent.input", return_value="yes")
def test_handle_action_custom_function_requires_confirmation(self, mock_input):
    """A non-whitelisted custom function runs only after the user is prompted."""
    registry = MagicMock()
    registry.configure_mock(
        **{
            "has_function.return_value": True,
            "is_whitelisted.return_value": False,
            "risk_note.return_value": "Risky operation",
            "execute.return_value": {"status": "ok"},
        }
    )
    self.agent._function_registry = registry

    call = types.FunctionCall(name="custom_fn", args={"x": 1})
    outcome = self.agent.handle_action(call)

    self.assertEqual(outcome, {"status": "ok"})
    registry.execute.assert_called_once_with("custom_fn", {"x": 1})
    # The patched input() answered "yes", so it must have been consulted.
    mock_input.assert_called()

def test_handle_action_unknown_function(self):
action = types.FunctionCall(name="unknown_function", args={})
with self.assertRaises(ValueError):
Expand Down