Skip to content

Commit 1df5327

Browse files
seanzhougooglecopybara-github
authored andcommitted
chore: Add in memory credential service (Experimental)
PiperOrigin-RevId: 771468462
1 parent 1ae176a commit 1df5327

File tree

2 files changed

+387
-0
lines changed

2 files changed

+387
-0
lines changed
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import Optional
18+
19+
from typing_extensions import override
20+
21+
from ...tools.tool_context import ToolContext
22+
from ...utils.feature_decorator import experimental
23+
from ..auth_credential import AuthCredential
24+
from ..auth_tool import AuthConfig
25+
from .base_credential_service import BaseCredentialService
26+
27+
28+
@experimental
29+
class InMemoryCredentialService(BaseCredentialService):
30+
"""Class for in memory implementation of credential service(Experimental)"""
31+
32+
def __init__(self):
33+
super().__init__()
34+
self._credentials = {}
35+
36+
@override
37+
async def load_credential(
38+
self,
39+
auth_config: AuthConfig,
40+
tool_context: ToolContext,
41+
) -> Optional[AuthCredential]:
42+
credential_bucket = self._get_bucket_for_current_context(tool_context)
43+
return credential_bucket.get(auth_config.credential_key)
44+
45+
@override
46+
async def save_credential(
47+
self,
48+
auth_config: AuthConfig,
49+
tool_context: ToolContext,
50+
) -> None:
51+
credential_bucket = self._get_bucket_for_current_context(tool_context)
52+
credential_bucket[auth_config.credential_key] = (
53+
auth_config.exchanged_auth_credential
54+
)
55+
56+
def _get_bucket_for_current_context(self, tool_context: ToolContext) -> str:
57+
app_name = tool_context._invocation_context.app_name
58+
user_id = tool_context._invocation_context.user_id
59+
60+
if app_name not in self._credentials:
61+
self._credentials[app_name] = {}
62+
if user_id not in self._credentials[app_name]:
63+
self._credentials[app_name][user_id] = {}
64+
return self._credentials[app_name][user_id]
Lines changed: 323 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,323 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from unittest.mock import Mock
16+
17+
from fastapi.openapi.models import OAuth2
18+
from fastapi.openapi.models import OAuthFlowAuthorizationCode
19+
from fastapi.openapi.models import OAuthFlows
20+
from google.adk.auth.auth_credential import AuthCredential
21+
from google.adk.auth.auth_credential import AuthCredentialTypes
22+
from google.adk.auth.auth_credential import OAuth2Auth
23+
from google.adk.auth.auth_tool import AuthConfig
24+
from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService
25+
from google.adk.tools.tool_context import ToolContext
26+
import pytest
27+
28+
29+
class TestInMemoryCredentialService:
30+
"""Tests for the InMemoryCredentialService class."""
31+
32+
@pytest.fixture
33+
def credential_service(self):
34+
"""Create an InMemoryCredentialService instance for testing."""
35+
return InMemoryCredentialService()
36+
37+
@pytest.fixture
38+
def oauth2_auth_scheme(self):
39+
"""Create an OAuth2 auth scheme for testing."""
40+
flows = OAuthFlows(
41+
authorizationCode=OAuthFlowAuthorizationCode(
42+
authorizationUrl="https://example.com/oauth2/authorize",
43+
tokenUrl="https://example.com/oauth2/token",
44+
scopes={"read": "Read access", "write": "Write access"},
45+
)
46+
)
47+
return OAuth2(flows=flows)
48+
49+
@pytest.fixture
50+
def oauth2_credentials(self):
51+
"""Create OAuth2 credentials for testing."""
52+
return AuthCredential(
53+
auth_type=AuthCredentialTypes.OAUTH2,
54+
oauth2=OAuth2Auth(
55+
client_id="mock_client_id",
56+
client_secret="mock_client_secret",
57+
redirect_uri="https://example.com/callback",
58+
),
59+
)
60+
61+
@pytest.fixture
62+
def auth_config(self, oauth2_auth_scheme, oauth2_credentials):
63+
"""Create an AuthConfig for testing."""
64+
exchanged_credential = oauth2_credentials.model_copy(deep=True)
65+
return AuthConfig(
66+
auth_scheme=oauth2_auth_scheme,
67+
raw_auth_credential=oauth2_credentials,
68+
exchanged_auth_credential=exchanged_credential,
69+
)
70+
71+
@pytest.fixture
72+
def tool_context(self):
73+
"""Create a mock ToolContext for testing."""
74+
mock_context = Mock(spec=ToolContext)
75+
mock_invocation_context = Mock()
76+
mock_invocation_context.app_name = "test_app"
77+
mock_invocation_context.user_id = "test_user"
78+
mock_context._invocation_context = mock_invocation_context
79+
return mock_context
80+
81+
@pytest.fixture
82+
def another_tool_context(self):
83+
"""Create another mock ToolContext with different app/user for testing isolation."""
84+
mock_context = Mock(spec=ToolContext)
85+
mock_invocation_context = Mock()
86+
mock_invocation_context.app_name = "another_app"
87+
mock_invocation_context.user_id = "another_user"
88+
mock_context._invocation_context = mock_invocation_context
89+
return mock_context
90+
91+
def test_init(self, credential_service):
92+
"""Test that the service initializes with an empty store."""
93+
assert isinstance(credential_service._credentials, dict)
94+
assert len(credential_service._credentials) == 0
95+
96+
@pytest.mark.asyncio
97+
async def test_load_credential_not_found(
98+
self, credential_service, auth_config, tool_context
99+
):
100+
"""Test loading a credential that doesn't exist returns None."""
101+
result = await credential_service.load_credential(auth_config, tool_context)
102+
assert result is None
103+
104+
@pytest.mark.asyncio
105+
async def test_save_and_load_credential(
106+
self, credential_service, auth_config, tool_context
107+
):
108+
"""Test saving and then loading a credential."""
109+
# Save the credential
110+
await credential_service.save_credential(auth_config, tool_context)
111+
112+
# Load the credential
113+
result = await credential_service.load_credential(auth_config, tool_context)
114+
115+
# Verify the credential was saved and loaded correctly
116+
assert result is not None
117+
assert result == auth_config.exchanged_auth_credential
118+
assert result.auth_type == AuthCredentialTypes.OAUTH2
119+
assert result.oauth2.client_id == "mock_client_id"
120+
121+
@pytest.mark.asyncio
122+
async def test_save_credential_updates_existing(
123+
self, credential_service, auth_config, tool_context, oauth2_credentials
124+
):
125+
"""Test that saving a credential updates an existing one."""
126+
# Save initial credential
127+
await credential_service.save_credential(auth_config, tool_context)
128+
129+
# Create a new credential and update the auth_config
130+
new_credential = AuthCredential(
131+
auth_type=AuthCredentialTypes.OAUTH2,
132+
oauth2=OAuth2Auth(
133+
client_id="updated_client_id",
134+
client_secret="updated_client_secret",
135+
redirect_uri="https://updated.com/callback",
136+
),
137+
)
138+
auth_config.exchanged_auth_credential = new_credential
139+
140+
# Save the updated credential
141+
await credential_service.save_credential(auth_config, tool_context)
142+
143+
# Load and verify the credential was updated
144+
result = await credential_service.load_credential(auth_config, tool_context)
145+
assert result is not None
146+
assert result.oauth2.client_id == "updated_client_id"
147+
assert result.oauth2.client_secret == "updated_client_secret"
148+
149+
@pytest.mark.asyncio
150+
async def test_credentials_isolated_by_context(
151+
self, credential_service, auth_config, tool_context, another_tool_context
152+
):
153+
"""Test that credentials are isolated between different app/user contexts."""
154+
# Save credential in first context
155+
await credential_service.save_credential(auth_config, tool_context)
156+
157+
# Try to load from another context
158+
result = await credential_service.load_credential(
159+
auth_config, another_tool_context
160+
)
161+
assert result is None
162+
163+
# Verify original context still has the credential
164+
result = await credential_service.load_credential(auth_config, tool_context)
165+
assert result is not None
166+
167+
@pytest.mark.asyncio
168+
async def test_multiple_credentials_same_context(
169+
self, credential_service, tool_context, oauth2_auth_scheme
170+
):
171+
"""Test storing multiple credentials in the same context with different keys."""
172+
# Create two different auth configs with different credential keys
173+
cred1 = AuthCredential(
174+
auth_type=AuthCredentialTypes.OAUTH2,
175+
oauth2=OAuth2Auth(
176+
client_id="client1",
177+
client_secret="secret1",
178+
redirect_uri="https://example1.com/callback",
179+
),
180+
)
181+
182+
cred2 = AuthCredential(
183+
auth_type=AuthCredentialTypes.OAUTH2,
184+
oauth2=OAuth2Auth(
185+
client_id="client2",
186+
client_secret="secret2",
187+
redirect_uri="https://example2.com/callback",
188+
),
189+
)
190+
191+
auth_config1 = AuthConfig(
192+
auth_scheme=oauth2_auth_scheme,
193+
raw_auth_credential=cred1,
194+
exchanged_auth_credential=cred1,
195+
credential_key="key1",
196+
)
197+
198+
auth_config2 = AuthConfig(
199+
auth_scheme=oauth2_auth_scheme,
200+
raw_auth_credential=cred2,
201+
exchanged_auth_credential=cred2,
202+
credential_key="key2",
203+
)
204+
205+
# Save both credentials
206+
await credential_service.save_credential(auth_config1, tool_context)
207+
await credential_service.save_credential(auth_config2, tool_context)
208+
209+
# Load and verify both credentials
210+
result1 = await credential_service.load_credential(
211+
auth_config1, tool_context
212+
)
213+
result2 = await credential_service.load_credential(
214+
auth_config2, tool_context
215+
)
216+
217+
assert result1 is not None
218+
assert result2 is not None
219+
assert result1.oauth2.client_id == "client1"
220+
assert result2.oauth2.client_id == "client2"
221+
222+
def test_get_bucket_for_current_context_creates_nested_structure(
223+
self, credential_service, tool_context
224+
):
225+
"""Test that _get_bucket_for_current_context creates the proper nested structure."""
226+
storage = credential_service._get_bucket_for_current_context(tool_context)
227+
228+
# Verify the nested structure was created
229+
assert "test_app" in credential_service._credentials
230+
assert "test_user" in credential_service._credentials["test_app"]
231+
assert isinstance(storage, dict)
232+
assert storage is credential_service._credentials["test_app"]["test_user"]
233+
234+
def test_get_bucket_for_current_context_reuses_existing(
235+
self, credential_service, tool_context
236+
):
237+
"""Test that _get_bucket_for_current_context reuses existing structure."""
238+
# Create initial structure
239+
storage1 = credential_service._get_bucket_for_current_context(tool_context)
240+
storage1["test_key"] = "test_value"
241+
242+
# Get storage again
243+
storage2 = credential_service._get_bucket_for_current_context(tool_context)
244+
245+
# Verify it's the same storage instance
246+
assert storage1 is storage2
247+
assert storage2["test_key"] == "test_value"
248+
249+
def test_get_storage_different_apps(
250+
self, credential_service, tool_context, another_tool_context
251+
):
252+
"""Test that different apps get different storage instances."""
253+
storage1 = credential_service._get_bucket_for_current_context(tool_context)
254+
storage2 = credential_service._get_bucket_for_current_context(
255+
another_tool_context
256+
)
257+
258+
# Verify they are different storage instances
259+
assert storage1 is not storage2
260+
261+
# Verify the structure
262+
assert "test_app" in credential_service._credentials
263+
assert "another_app" in credential_service._credentials
264+
assert "test_user" in credential_service._credentials["test_app"]
265+
assert "another_user" in credential_service._credentials["another_app"]
266+
267+
@pytest.mark.asyncio
268+
async def test_same_user_different_apps(
269+
self, credential_service, auth_config
270+
):
271+
"""Test that the same user in different apps get isolated storage."""
272+
# Create two contexts with same user but different apps
273+
context1 = Mock(spec=ToolContext)
274+
mock_invocation_context1 = Mock()
275+
mock_invocation_context1.app_name = "app1"
276+
mock_invocation_context1.user_id = "same_user"
277+
context1._invocation_context = mock_invocation_context1
278+
279+
context2 = Mock(spec=ToolContext)
280+
mock_invocation_context2 = Mock()
281+
mock_invocation_context2.app_name = "app2"
282+
mock_invocation_context2.user_id = "same_user"
283+
context2._invocation_context = mock_invocation_context2
284+
285+
# Save credential in app1
286+
await credential_service.save_credential(auth_config, context1)
287+
288+
# Try to load from app2 (should not find it)
289+
result = await credential_service.load_credential(auth_config, context2)
290+
assert result is None
291+
292+
# Verify app1 still has the credential
293+
result = await credential_service.load_credential(auth_config, context1)
294+
assert result is not None
295+
296+
@pytest.mark.asyncio
297+
async def test_same_app_different_users(
298+
self, credential_service, auth_config
299+
):
300+
"""Test that different users in the same app get isolated storage."""
301+
# Create two contexts with same app but different users
302+
context1 = Mock(spec=ToolContext)
303+
mock_invocation_context1 = Mock()
304+
mock_invocation_context1.app_name = "same_app"
305+
mock_invocation_context1.user_id = "user1"
306+
context1._invocation_context = mock_invocation_context1
307+
308+
context2 = Mock(spec=ToolContext)
309+
mock_invocation_context2 = Mock()
310+
mock_invocation_context2.app_name = "same_app"
311+
mock_invocation_context2.user_id = "user2"
312+
context2._invocation_context = mock_invocation_context2
313+
314+
# Save credential for user1
315+
await credential_service.save_credential(auth_config, context1)
316+
317+
# Try to load for user2 (should not find it)
318+
result = await credential_service.load_credential(auth_config, context2)
319+
assert result is None
320+
321+
# Verify user1 still has the credential
322+
result = await credential_service.load_credential(auth_config, context1)
323+
assert result is not None

0 commit comments

Comments
 (0)