diff --git a/src/google/adk/tools/authenticated_function_tool.py b/src/google/adk/tools/authenticated_function_tool.py new file mode 100644 index 000000000..67cc5885f --- /dev/null +++ b/src/google/adk/tools/authenticated_function_tool.py @@ -0,0 +1,107 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import inspect +import logging +from typing import Any +from typing import Callable +from typing import Dict +from typing import Optional +from typing import Union + +from typing_extensions import override + +from ..auth.auth_credential import AuthCredential +from ..auth.auth_tool import AuthConfig +from ..auth.credential_manager import CredentialManager +from ..utils.feature_decorator import experimental +from .function_tool import FunctionTool +from .tool_context import ToolContext + +logger = logging.getLogger("google_adk." + __name__) + + +@experimental +class AuthenticatedFunctionTool(FunctionTool): + """A FunctionTool that handles authentication before the actual tool logic + gets called. Functions can accept a special `credential` argument which is the + credential ready for use.(Experimental) + """ + + def __init__( + self, + *, + func: Callable[..., Any], + auth_config: AuthConfig = None, + response_for_auth_required: Optional[Union[dict[str, Any], str]] = None, + ): + """Initializes the AuthenticatedFunctionTool. + + Args: + func: The function to be called. + auth_config: The authentication configuration. + response_for_auth_required: The response to return when the tool is + requesting auth credential from the client. There could be two case, + the tool doesn't configure any credentials + (auth_config.raw_auth_credential is missing) or the credentials + configured is not enough to authenticate the tool (e.g. an OAuth + client id and client secrect is configured.) and needs client input + (e.g. client need to involve the end user in an oauth flow and get + back the oauth response.) + """ + super().__init__(func=func) + self._ignore_params.append("credential") + + if auth_config and auth_config.auth_scheme: + self._credentials_manager = CredentialManager(auth_config=auth_config) + else: + logger.warning( + "auth_config or auth_config.auth_scheme is missing. Will skip" + " authentication.Using FunctionTool instead if authentication is not" + " required." + ) + self._credentials_manager = None + self._response_for_auth_required = response_for_auth_required + + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + credential = None + if self._credentials_manager: + credential = await self._credentials_manager.get_auth_credential( + tool_context + ) + if not credential: + await self._credentials_manager.request_credential(tool_context) + return self._response_for_auth_required or "Pending User Authorization." + + return await self._run_async_impl( + args=args, tool_context=tool_context, credential=credential + ) + + async def _run_async_impl( + self, + *, + args: dict[str, Any], + tool_context: ToolContext, + credential: AuthCredential, + ) -> Any: + args_to_call = args.copy() + signature = inspect.signature(self.func) + if "credential" in signature.parameters: + args_to_call["credential"] = credential + return await super().run_async(args=args_to_call, tool_context=tool_context) diff --git a/src/google/adk/tools/base_authenticated_tool.py b/src/google/adk/tools/base_authenticated_tool.py new file mode 100644 index 000000000..4858e4953 --- /dev/null +++ b/src/google/adk/tools/base_authenticated_tool.py @@ -0,0 +1,107 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import abstractmethod +import logging +from typing import Any +from typing import Optional +from typing import Union + +from typing_extensions import override + +from ..auth.auth_credential import AuthCredential +from ..auth.auth_tool import AuthConfig +from ..auth.credential_manager import CredentialManager +from ..utils.feature_decorator import experimental +from .base_tool import BaseTool +from .tool_context import ToolContext + +logger = logging.getLogger("google_adk." + __name__) + + +@experimental +class BaseAuthenticatedTool(BaseTool): + """A base tool class that handles authentication before the actual tool logic + gets called. Functions can accept a special `credential` argument which is the + credential ready for use.(Experimental) + """ + + def __init__( + self, + *, + name, + description, + auth_config: AuthConfig = None, + response_for_auth_required: Optional[Union[dict[str, Any], str]] = None, + ): + """ + Args: + name: The name of the tool. + description: The description of the tool. + auth_config: The auth configuration of the tool. + response_for_auth_required: The response to return when the tool is + requesting auth credential from the client. There could be two case, + the tool doesn't configure any credentials + (auth_config.raw_auth_credential is missing) or the credentials + configured is not enough to authenticate the tool (e.g. an OAuth + client id and client secrect is configured.) and needs client input + (e.g. client need to involve the end user in an oauth flow and get + back the oauth response.) + """ + super().__init__( + name=name, + description=description, + ) + + if auth_config and auth_config.auth_scheme: + self._credentials_manager = CredentialManager(auth_config=auth_config) + else: + logger.warning( + "auth_config or auth_config.auth_scheme is missing. Will skip" + " authentication.Using FunctionTool instead if authentication is not" + " required." + ) + self._credentials_manager = None + self._response_for_auth_required = response_for_auth_required + + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + credential = None + if self._credentials_manager: + credential = await self._credentials_manager.get_auth_credential( + tool_context + ) + if not credential: + await self._credentials_manager.request_credential(tool_context) + return self._response_for_auth_required or "Pending User Authorization." + + return await self._run_async_impl( + args=args, + tool_context=tool_context, + credential=credential, + ) + + @abstractmethod + async def _run_async_impl( + self, + *, + args: dict[str, Any], + tool_context: ToolContext, + credential: AuthCredential, + ) -> Any: + pass diff --git a/tests/unittests/tools/test_authenticated_function_tool.py b/tests/unittests/tools/test_authenticated_function_tool.py new file mode 100644 index 000000000..88454032a --- /dev/null +++ b/tests/unittests/tools/test_authenticated_function_tool.py @@ -0,0 +1,541 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from unittest.mock import AsyncMock +from unittest.mock import Mock + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.auth_schemes import AuthSchemeType +from google.adk.auth.auth_tool import AuthConfig +from google.adk.tools.authenticated_function_tool import AuthenticatedFunctionTool +from google.adk.tools.tool_context import ToolContext +import pytest + +# Test functions for different scenarios + + +def sync_function_no_credential(arg1: str, arg2: int) -> str: + """Test sync function without credential parameter.""" + return f"sync_result_{arg1}_{arg2}" + + +async def async_function_no_credential(arg1: str, arg2: int) -> str: + """Test async function without credential parameter.""" + return f"async_result_{arg1}_{arg2}" + + +def sync_function_with_credential(arg1: str, credential: AuthCredential) -> str: + """Test sync function with credential parameter.""" + return f"sync_cred_result_{arg1}_{credential.auth_type.value}" + + +async def async_function_with_credential( + arg1: str, credential: AuthCredential +) -> str: + """Test async function with credential parameter.""" + return f"async_cred_result_{arg1}_{credential.auth_type.value}" + + +def sync_function_with_tool_context( + arg1: str, tool_context: ToolContext +) -> str: + """Test sync function with tool_context parameter.""" + return f"sync_context_result_{arg1}" + + +async def async_function_with_both( + arg1: str, tool_context: ToolContext, credential: AuthCredential +) -> str: + """Test async function with both tool_context and credential parameters.""" + return f"async_both_result_{arg1}_{credential.auth_type.value}" + + +def function_with_optional_args( + arg1: str, arg2: str = "default", credential: AuthCredential = None +) -> str: + """Test function with optional arguments.""" + cred_type = credential.auth_type.value if credential else "none" + return f"optional_result_{arg1}_{arg2}_{cred_type}" + + +class MockCallable: + """Test callable class for testing.""" + + def __init__(self): + self.__name__ = "MockCallable" + self.__doc__ = "Test callable documentation" + + def __call__(self, arg1: str, credential: AuthCredential) -> str: + return f"callable_result_{arg1}_{credential.auth_type.value}" + + +def _create_mock_auth_config(): + """Creates a mock AuthConfig with proper structure.""" + auth_scheme = Mock(spec=AuthScheme) + auth_scheme.type_ = AuthSchemeType.oauth2 + + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = auth_scheme + + return auth_config + + +def _create_mock_auth_credential(): + """Creates a mock AuthCredential.""" + credential = Mock(spec=AuthCredential) + # Create a mock auth_type that returns the expected value + mock_auth_type = Mock() + mock_auth_type.value = "oauth2" + credential.auth_type = mock_auth_type + return credential + + +class TestAuthenticatedFunctionTool: + """Test suite for AuthenticatedFunctionTool.""" + + def test_init_with_sync_function(self): + """Test initialization with synchronous function.""" + auth_config = _create_mock_auth_config() + + tool = AuthenticatedFunctionTool( + func=sync_function_no_credential, + auth_config=auth_config, + response_for_auth_required="Please authenticate", + ) + + assert tool.name == "sync_function_no_credential" + assert ( + tool.description == "Test sync function without credential parameter." + ) + assert tool.func == sync_function_no_credential + assert tool._credentials_manager is not None + assert tool._response_for_auth_required == "Please authenticate" + assert "credential" in tool._ignore_params + + def test_init_with_async_function(self): + """Test initialization with asynchronous function.""" + auth_config = _create_mock_auth_config() + + tool = AuthenticatedFunctionTool( + func=async_function_no_credential, auth_config=auth_config + ) + + assert tool.name == "async_function_no_credential" + assert ( + tool.description == "Test async function without credential parameter." + ) + assert tool.func == async_function_no_credential + assert tool._response_for_auth_required is None + + def test_init_with_callable(self): + """Test initialization with callable object.""" + auth_config = _create_mock_auth_config() + test_callable = MockCallable() + + tool = AuthenticatedFunctionTool( + func=test_callable, auth_config=auth_config + ) + + assert tool.name == "MockCallable" + assert tool.description == "Test callable documentation" + assert tool.func == test_callable + + def test_init_no_auth_config(self): + """Test initialization without auth_config.""" + tool = AuthenticatedFunctionTool(func=sync_function_no_credential) + + assert tool._credentials_manager is None + assert tool._response_for_auth_required is None + + def test_init_with_empty_auth_scheme(self): + """Test initialization with auth_config but no auth_scheme.""" + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = None + + tool = AuthenticatedFunctionTool( + func=sync_function_no_credential, auth_config=auth_config + ) + + assert tool._credentials_manager is None + + @pytest.mark.asyncio + async def test_run_async_sync_function_no_credential_manager(self): + """Test run_async with sync function when no credential manager is configured.""" + tool = AuthenticatedFunctionTool(func=sync_function_no_credential) + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test", "arg2": 42} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "sync_result_test_42" + + @pytest.mark.asyncio + async def test_run_async_async_function_no_credential_manager(self): + """Test run_async with async function when no credential manager is configured.""" + tool = AuthenticatedFunctionTool(func=async_function_no_credential) + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test", "arg2": 42} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "async_result_test_42" + + @pytest.mark.asyncio + async def test_run_async_with_valid_credential(self): + """Test run_async when valid credential is available.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=sync_function_with_credential, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == f"sync_cred_result_test_{credential.auth_type.value}" + mock_credentials_manager.get_auth_credential.assert_called_once_with( + tool_context + ) + + @pytest.mark.asyncio + async def test_run_async_async_function_with_credential(self): + """Test run_async with async function that expects credential.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=async_function_with_credential, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == f"async_cred_result_test_{credential.auth_type.value}" + + @pytest.mark.asyncio + async def test_run_async_no_credential_available(self): + """Test run_async when no credential is available.""" + auth_config = _create_mock_auth_config() + + # Mock the credentials manager to return None + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock(return_value=None) + mock_credentials_manager.request_credential = AsyncMock() + + tool = AuthenticatedFunctionTool( + func=sync_function_with_credential, + auth_config=auth_config, + response_for_auth_required="Custom auth required", + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "Custom auth required" + mock_credentials_manager.get_auth_credential.assert_called_once_with( + tool_context + ) + mock_credentials_manager.request_credential.assert_called_once_with( + tool_context + ) + + @pytest.mark.asyncio + async def test_run_async_no_credential_default_message(self): + """Test run_async when no credential is available with default message.""" + auth_config = _create_mock_auth_config() + + # Mock the credentials manager to return None + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock(return_value=None) + mock_credentials_manager.request_credential = AsyncMock() + + tool = AuthenticatedFunctionTool( + func=sync_function_with_credential, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "Pending User Authorization." + + @pytest.mark.asyncio + async def test_run_async_function_without_credential_param(self): + """Test run_async with function that doesn't have credential parameter.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=sync_function_no_credential, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test", "arg2": 42} + + result = await tool.run_async(args=args, tool_context=tool_context) + + # Credential should not be passed to function since it doesn't have the parameter + assert result == "sync_result_test_42" + + @pytest.mark.asyncio + async def test_run_async_function_with_tool_context(self): + """Test run_async with function that has tool_context parameter.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=sync_function_with_tool_context, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "sync_context_result_test" + + @pytest.mark.asyncio + async def test_run_async_function_with_both_params(self): + """Test run_async with function that has both tool_context and credential parameters.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=async_function_with_both, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == f"async_both_result_test_{credential.auth_type.value}" + + @pytest.mark.asyncio + async def test_run_async_function_with_optional_credential(self): + """Test run_async with function that has optional credential parameter.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=function_with_optional_args, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert ( + result == f"optional_result_test_default_{credential.auth_type.value}" + ) + + @pytest.mark.asyncio + async def test_run_async_callable_object(self): + """Test run_async with callable object.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + test_callable = MockCallable() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=test_callable, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == f"callable_result_test_{credential.auth_type.value}" + + @pytest.mark.asyncio + async def test_run_async_propagates_function_exception(self): + """Test that run_async propagates exceptions from the wrapped function.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + def failing_function(arg1: str, credential: AuthCredential) -> str: + raise ValueError("Function failed") + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = AuthenticatedFunctionTool( + func=failing_function, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + with pytest.raises(ValueError, match="Function failed"): + await tool.run_async(args=args, tool_context=tool_context) + + @pytest.mark.asyncio + async def test_run_async_missing_required_args(self): + """Test run_async with missing required arguments.""" + tool = AuthenticatedFunctionTool(func=sync_function_no_credential) + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} # Missing arg2 + + result = await tool.run_async(args=args, tool_context=tool_context) + + # Should return error dict indicating missing parameters + assert isinstance(result, dict) + assert "error" in result + assert "arg2" in result["error"] + + @pytest.mark.asyncio + async def test_run_async_credentials_manager_exception(self): + """Test run_async when credentials manager raises an exception.""" + auth_config = _create_mock_auth_config() + + # Mock the credentials manager to raise an exception + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + side_effect=RuntimeError("Credential service error") + ) + + tool = AuthenticatedFunctionTool( + func=sync_function_with_credential, auth_config=auth_config + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + with pytest.raises(RuntimeError, match="Credential service error"): + await tool.run_async(args=args, tool_context=tool_context) + + def test_credential_in_ignore_params(self): + """Test that 'credential' is added to ignore_params during initialization.""" + tool = AuthenticatedFunctionTool(func=sync_function_with_credential) + + assert "credential" in tool._ignore_params + + @pytest.mark.asyncio + async def test_run_async_with_none_credential(self): + """Test run_async when credential is None but function expects it.""" + tool = AuthenticatedFunctionTool(func=function_with_optional_args) + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "optional_result_test_default_none" + + def test_signature_inspection(self): + """Test that the tool correctly inspects function signatures.""" + tool = AuthenticatedFunctionTool(func=sync_function_with_credential) + + signature = inspect.signature(tool.func) + assert "credential" in signature.parameters + assert "arg1" in signature.parameters + + @pytest.mark.asyncio + async def test_args_to_call_modification(self): + """Test that args_to_call is properly modified with credential.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + # Create a spy function to check what arguments are passed + original_args = {} + + def spy_function(arg1: str, credential: AuthCredential) -> str: + nonlocal original_args + original_args = {"arg1": arg1, "credential": credential} + return "spy_result" + + tool = AuthenticatedFunctionTool(func=spy_function, auth_config=auth_config) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"arg1": "test"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "spy_result" + assert original_args is not None + assert original_args["arg1"] == "test" + assert original_args["credential"] == credential diff --git a/tests/unittests/tools/test_base_authenticated_tool.py b/tests/unittests/tools/test_base_authenticated_tool.py new file mode 100644 index 000000000..55454224d --- /dev/null +++ b/tests/unittests/tools/test_base_authenticated_tool.py @@ -0,0 +1,343 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock +from unittest.mock import Mock + +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_schemes import AuthScheme +from google.adk.auth.auth_schemes import AuthSchemeType +from google.adk.auth.auth_tool import AuthConfig +from google.adk.tools.base_authenticated_tool import BaseAuthenticatedTool +from google.adk.tools.tool_context import ToolContext +import pytest + + +class _TestAuthenticatedTool(BaseAuthenticatedTool): + """Test implementation of BaseAuthenticatedTool for testing purposes.""" + + def __init__( + self, + name="test_auth_tool", + description="Test authenticated tool", + auth_config=None, + unauthenticated_response=None, + ): + super().__init__( + name=name, + description=description, + auth_config=auth_config, + response_for_auth_required=unauthenticated_response, + ) + self.run_impl_called = False + self.run_impl_result = "test_result" + + async def _run_async_impl(self, *, args, tool_context, credential): + """Test implementation of the abstract method.""" + self.run_impl_called = True + self.last_args = args + self.last_tool_context = tool_context + self.last_credential = credential + return self.run_impl_result + + +def _create_mock_auth_config(): + """Creates a mock AuthConfig with proper structure.""" + auth_scheme = Mock(spec=AuthScheme) + auth_scheme.type_ = AuthSchemeType.oauth2 + + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = auth_scheme + + return auth_config + + +def _create_mock_auth_credential(): + """Creates a mock AuthCredential.""" + credential = Mock(spec=AuthCredential) + credential.auth_type = AuthCredentialTypes.OAUTH2 + return credential + + +class TestBaseAuthenticatedTool: + """Test suite for BaseAuthenticatedTool.""" + + def test_init_with_auth_config(self): + """Test initialization with auth_config.""" + auth_config = _create_mock_auth_config() + unauthenticated_response = {"error": "Not authenticated"} + + tool = _TestAuthenticatedTool( + name="test_tool", + description="Test description", + auth_config=auth_config, + unauthenticated_response=unauthenticated_response, + ) + + assert tool.name == "test_tool" + assert tool.description == "Test description" + assert tool._credentials_manager is not None + assert tool._response_for_auth_required == unauthenticated_response + + def test_init_with_no_auth_config(self): + """Test initialization without auth_config.""" + tool = _TestAuthenticatedTool() + + assert tool.name == "test_auth_tool" + assert tool.description == "Test authenticated tool" + assert tool._credentials_manager is None + assert tool._response_for_auth_required is None + + def test_init_with_empty_auth_scheme(self): + """Test initialization with auth_config but no auth_scheme.""" + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = None + + tool = _TestAuthenticatedTool(auth_config=auth_config) + + assert tool._credentials_manager is None + + def test_init_with_default_unauthenticated_response(self): + """Test initialization with default unauthenticated response.""" + auth_config = _create_mock_auth_config() + + tool = _TestAuthenticatedTool(auth_config=auth_config) + + assert tool._response_for_auth_required is None + + @pytest.mark.asyncio + async def test_run_async_no_credentials_manager(self): + """Test run_async when no credentials manager is configured.""" + tool = _TestAuthenticatedTool() + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "test_result" + assert tool.run_impl_called + assert tool.last_args == args + assert tool.last_tool_context == tool_context + assert tool.last_credential is None + + @pytest.mark.asyncio + async def test_run_async_with_valid_credential(self): + """Test run_async when valid credential is available.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = _TestAuthenticatedTool(auth_config=auth_config) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "test_result" + assert tool.run_impl_called + assert tool.last_args == args + assert tool.last_tool_context == tool_context + assert tool.last_credential == credential + mock_credentials_manager.get_auth_credential.assert_called_once_with( + tool_context + ) + + @pytest.mark.asyncio + async def test_run_async_no_credential_available(self): + """Test run_async when no credential is available.""" + auth_config = _create_mock_auth_config() + + # Mock the credentials manager to return None + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock(return_value=None) + mock_credentials_manager.request_credential = AsyncMock() + + tool = _TestAuthenticatedTool(auth_config=auth_config) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == "Pending User Authorization." + assert not tool.run_impl_called + mock_credentials_manager.get_auth_credential.assert_called_once_with( + tool_context + ) + mock_credentials_manager.request_credential.assert_called_once_with( + tool_context + ) + + @pytest.mark.asyncio + async def test_run_async_no_credential_with_custom_response(self): + """Test run_async when no credential is available with custom response.""" + auth_config = _create_mock_auth_config() + custom_response = { + "status": "authentication_required", + "message": "Please login", + } + + # Mock the credentials manager to return None + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock(return_value=None) + mock_credentials_manager.request_credential = AsyncMock() + + tool = _TestAuthenticatedTool( + auth_config=auth_config, unauthenticated_response=custom_response + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == custom_response + assert not tool.run_impl_called + mock_credentials_manager.get_auth_credential.assert_called_once_with( + tool_context + ) + mock_credentials_manager.request_credential.assert_called_once_with( + tool_context + ) + + @pytest.mark.asyncio + async def test_run_async_no_credential_with_string_response(self): + """Test run_async when no credential is available with string response.""" + auth_config = _create_mock_auth_config() + custom_response = "Custom authentication required message" + + # Mock the credentials manager to return None + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock(return_value=None) + mock_credentials_manager.request_credential = AsyncMock() + + tool = _TestAuthenticatedTool( + auth_config=auth_config, unauthenticated_response=custom_response + ) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == custom_response + assert not tool.run_impl_called + + @pytest.mark.asyncio + async def test_run_async_propagates_impl_exception(self): + """Test that run_async propagates exceptions from _run_async_impl.""" + auth_config = _create_mock_auth_config() + credential = _create_mock_auth_credential() + + # Mock the credentials manager + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + return_value=credential + ) + + tool = _TestAuthenticatedTool(auth_config=auth_config) + tool._credentials_manager = mock_credentials_manager + + # Make the implementation raise an exception + async def failing_impl(*, args, tool_context, credential): + raise ValueError("Implementation failed") + + tool._run_async_impl = failing_impl + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + with pytest.raises(ValueError, match="Implementation failed"): + await tool.run_async(args=args, tool_context=tool_context) + + @pytest.mark.asyncio + async def test_run_async_with_different_args_types(self): + """Test run_async with different argument types.""" + tool = _TestAuthenticatedTool() + tool_context = Mock(spec=ToolContext) + + # Test with empty args + result = await tool.run_async(args={}, tool_context=tool_context) + assert result == "test_result" + assert tool.last_args == {} + + # Test with complex args + complex_args = { + "string_param": "test", + "number_param": 42, + "list_param": [1, 2, 3], + "dict_param": {"nested": "value"}, + } + result = await tool.run_async(args=complex_args, tool_context=tool_context) + assert result == "test_result" + assert tool.last_args == complex_args + + @pytest.mark.asyncio + async def test_run_async_credentials_manager_exception(self): + """Test run_async when credentials manager raises an exception.""" + auth_config = _create_mock_auth_config() + + # Mock the credentials manager to raise an exception + mock_credentials_manager = AsyncMock() + mock_credentials_manager.get_auth_credential = AsyncMock( + side_effect=RuntimeError("Credential service error") + ) + + tool = _TestAuthenticatedTool(auth_config=auth_config) + tool._credentials_manager = mock_credentials_manager + + tool_context = Mock(spec=ToolContext) + args = {"param1": "value1"} + + with pytest.raises(RuntimeError, match="Credential service error"): + await tool.run_async(args=args, tool_context=tool_context) + + def test_abstract_nature(self): + """Test that BaseAuthenticatedTool cannot be instantiated directly.""" + with pytest.raises(TypeError): + # This should fail because _run_async_impl is abstract + BaseAuthenticatedTool(name="test", description="test") + + @pytest.mark.asyncio + async def test_run_async_return_values(self): + """Test run_async with different return value types.""" + tool = _TestAuthenticatedTool() + tool_context = Mock(spec=ToolContext) + args = {} + + # Test with None return + tool.run_impl_result = None + result = await tool.run_async(args=args, tool_context=tool_context) + assert result is None + + # Test with dict return + tool.run_impl_result = {"key": "value"} + result = await tool.run_async(args=args, tool_context=tool_context) + assert result == {"key": "value"} + + # Test with list return + tool.run_impl_result = [1, 2, 3] + result = await tool.run_async(args=args, tool_context=tool_context) + assert result == [1, 2, 3]