Skip to content

Commit d8a8b8d

Browse files
seanzhougooglecopybara-github
authored andcommitted
chore: Add a base credential refresher interface
PiperOrigin-RevId: 771625930
1 parent a4d432a commit d8a8b8d

17 files changed

+1093
-81
lines changed

src/google/adk/agents/invocation_context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from pydantic import ConfigDict
2323

2424
from ..artifacts.base_artifact_service import BaseArtifactService
25+
from ..auth.credential_service.base_credential_service import BaseCredentialService
2526
from ..memory.base_memory_service import BaseMemoryService
2627
from ..sessions.base_session_service import BaseSessionService
2728
from ..sessions.session import Session
@@ -115,6 +116,7 @@ class InvocationContext(BaseModel):
115116
artifact_service: Optional[BaseArtifactService] = None
116117
session_service: BaseSessionService
117118
memory_service: Optional[BaseMemoryService] = None
119+
credential_service: Optional[BaseCredentialService] = None
118120

119121
invocation_id: str
120122
"""The id of this invocation context. Readonly."""

src/google/adk/auth/credential_service/base_credential_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919
from typing import Optional
2020

2121
from ...tools.tool_context import ToolContext
22-
from ...utils.feature_decorator import working_in_progress
22+
from ...utils.feature_decorator import experimental
2323
from ..auth_credential import AuthCredential
2424
from ..auth_tool import AuthConfig
2525

2626

27-
@working_in_progress("Implementation are in progress. Don't use it for now.")
27+
@experimental
2828
class BaseCredentialService(ABC):
2929
"""Abstract class for Service that loads / saves tool credentials from / to
3030
the backend credential store."""
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import Optional
18+
19+
from typing_extensions import override
20+
21+
from ...tools.tool_context import ToolContext
22+
from ...utils.feature_decorator import experimental
23+
from ..auth_credential import AuthCredential
24+
from ..auth_tool import AuthConfig
25+
from .base_credential_service import BaseCredentialService
26+
27+
28+
@experimental
29+
class InMemoryCredentialService(BaseCredentialService):
30+
"""Class for in memory implementation of credential service(Experimental)"""
31+
32+
def __init__(self):
33+
super().__init__()
34+
self._store: dict[str, AuthCredential] = {}
35+
36+
@override
37+
async def load_credential(
38+
self,
39+
auth_config: AuthConfig,
40+
tool_context: ToolContext,
41+
) -> Optional[AuthCredential]:
42+
"""
43+
Loads the credential by auth config and current tool context from the
44+
backend credential store.
45+
46+
Args:
47+
auth_config: The auth config which contains the auth scheme and auth
48+
credential information. auth_config.get_credential_key will be used to
49+
build the key to load the credential.
50+
51+
tool_context: The context of the current invocation when the tool is
52+
trying to load the credential.
53+
54+
Returns:
55+
Optional[AuthCredential]: the credential saved in the store.
56+
57+
"""
58+
storage = self._get_storage_for_current_context(tool_context)
59+
return storage.get(auth_config.credential_key)
60+
61+
@override
62+
async def save_credential(
63+
self,
64+
auth_config: AuthConfig,
65+
tool_context: ToolContext,
66+
) -> None:
67+
"""
68+
Saves the exchanged_auth_credential in auth config to the backend credential
69+
store.
70+
71+
Args:
72+
auth_config: The auth config which contains the auth scheme and auth
73+
credential information. auth_config.get_credential_key will be used to
74+
build the key to save the credential.
75+
76+
tool_context: The context of the current invocation when the tool is
77+
trying to save the credential.
78+
79+
Returns:
80+
None
81+
"""
82+
storage = self._get_storage_for_current_context(tool_context)
83+
storage[auth_config.credential_key] = auth_config.exchanged_auth_credential
84+
85+
def _get_storage_for_current_context(self, tool_context: ToolContext) -> str:
86+
app_name = tool_context._invocation_context.app_name
87+
user_id = tool_context._invocation_context.user_id
88+
89+
if app_name not in self._store:
90+
self._store[app_name] = {}
91+
if user_id not in self._store[app_name]:
92+
self._store[app_name][user_id] = {}
93+
return self._store[app_name][user_id]
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Credential exchanger module."""
16+
17+
from .base_credential_exchanger import BaseCredentialExchanger
18+
from .service_account_credential_exchanger import ServiceAccountCredentialExchanger
19+
20+
__all__ = [
21+
"BaseCredentialExchanger",
22+
"ServiceAccountCredentialExchanger",
23+
]
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Base credential exchanger interface."""
16+
17+
from __future__ import annotations
18+
19+
import abc
20+
from typing import Optional
21+
22+
from ...utils.feature_decorator import experimental
23+
from ..auth_credential import AuthCredential
24+
from ..auth_schemes import AuthScheme
25+
26+
27+
@experimental
28+
class BaseCredentialExchanger(abc.ABC):
29+
"""Base interface for credential exchangers."""
30+
31+
@abc.abstractmethod
32+
def exchange(
33+
self,
34+
auth_credential: AuthCredential,
35+
auth_scheme: Optional[AuthScheme] = None,
36+
) -> AuthCredential:
37+
"""Exchange credential if needed.
38+
39+
Args:
40+
auth_credential: The credential to exchange.
41+
auth_scheme: The authentication scheme (optional, some exchangers don't need it).
42+
43+
Returns:
44+
The exchanged credential.
45+
46+
Raises:
47+
ValueError: If credential exchange fails.
48+
"""
49+
pass
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Credential exchanger registry."""
16+
17+
from __future__ import annotations
18+
19+
from typing import Dict
20+
from typing import Optional
21+
22+
from ...utils.feature_decorator import experimental
23+
from ..auth_credential import AuthCredentialTypes
24+
from .base_credential_exchanger import BaseCredentialExchanger
25+
26+
27+
@experimental
28+
class CredentialExchangerRegistry:
29+
"""Registry for credential exchanger instances."""
30+
31+
def __init__(self):
32+
self._exchangers: Dict[AuthCredentialTypes, BaseCredentialExchanger] = {}
33+
34+
def register(
35+
self,
36+
credential_type: AuthCredentialTypes,
37+
exchanger_instance: BaseCredentialExchanger,
38+
) -> None:
39+
"""Register an exchanger instance for a credential type.
40+
41+
Args:
42+
credential_type: The credential type to register for.
43+
exchanger_instance: The exchanger instance to register.
44+
"""
45+
self._exchangers[credential_type] = exchanger_instance
46+
47+
def get_exchanger(
48+
self, credential_type: AuthCredentialTypes
49+
) -> Optional[BaseCredentialExchanger]:
50+
"""Get the exchanger instance for a credential type.
51+
52+
Args:
53+
credential_type: The credential type to get exchanger for.
54+
55+
Returns:
56+
The exchanger instance if registered, None otherwise.
57+
"""
58+
return self._exchangers.get(credential_type)

src/google/adk/auth/service_account_credential_exchanger.py renamed to src/google/adk/auth/exchanger/service_account_credential_exchanger.py

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,64 +16,82 @@
1616

1717
from __future__ import annotations
1818

19+
from typing import Optional
20+
1921
import google.auth
2022
from google.auth.transport.requests import Request
2123
from google.oauth2 import service_account
2224

23-
from ..utils.feature_decorator import experimental
24-
from .auth_credential import AuthCredential
25-
from .auth_credential import AuthCredentialTypes
26-
from .auth_credential import HttpAuth
27-
from .auth_credential import HttpCredentials
25+
from ...utils.feature_decorator import experimental
26+
from ..auth_credential import AuthCredential
27+
from ..auth_credential import AuthCredentialTypes
28+
from ..auth_credential import ServiceAccount
29+
from ..auth_schemes import AuthScheme
30+
from .base_credential_exchanger import BaseCredentialExchanger
2831

2932

3033
@experimental
31-
class ServiceAccountCredentialExchanger:
34+
class ServiceAccountCredentialExchanger(BaseCredentialExchanger):
3235
"""Exchanges Google Service Account credentials for an access token.
3336
3437
Uses the default service credential if `use_default_credential = True`.
3538
Otherwise, uses the service account credential provided in the auth
3639
credential.
3740
"""
3841

39-
def __init__(self, credential: AuthCredential):
40-
if credential.auth_type != AuthCredentialTypes.SERVICE_ACCOUNT:
41-
raise ValueError("Credential is not a service account credential.")
42-
self._credential = credential
42+
def __init__(self):
43+
"""Initialize the service account credential exchanger."""
44+
pass
4345

44-
def exchange(self) -> AuthCredential:
46+
def exchange(
47+
self,
48+
auth_credential: AuthCredential,
49+
auth_scheme: Optional[AuthScheme] = None,
50+
) -> AuthCredential:
4551
"""Exchanges the service account auth credential for an access token.
4652
4753
If the AuthCredential contains a service account credential, it will be used
4854
to exchange for an access token. Otherwise, if use_default_credential is True,
4955
the default application credential will be used for exchanging an access token.
5056
57+
Args:
58+
auth_scheme: The authentication scheme.
59+
auth_credential: The credential to exchange.
60+
5161
Returns:
52-
An AuthCredential in HTTP Bearer format, containing the access token.
62+
An AuthCredential in OAUTH2 format, containing the exchanged credential JSON.
5363
5464
Raises:
5565
ValueError: If service account credentials are missing or invalid.
5666
Exception: If credential exchange or refresh fails.
5767
"""
68+
if auth_credential is None:
69+
raise ValueError("Credential cannot be None.")
70+
71+
if auth_credential.auth_type != AuthCredentialTypes.SERVICE_ACCOUNT:
72+
raise ValueError("Credential is not a service account credential.")
73+
74+
if auth_credential.service_account is None:
75+
raise ValueError(
76+
"Service account credentials are missing. Please provide them."
77+
)
78+
5879
if (
59-
self._credential is None
60-
or self._credential.service_account is None
61-
or (
62-
self._credential.service_account.service_account_credential is None
63-
and not self._credential.service_account.use_default_credential
64-
)
80+
auth_credential.service_account.service_account_credential is None
81+
and not auth_credential.service_account.use_default_credential
6582
):
6683
raise ValueError(
67-
"Service account credentials are missing. Please provide them, or set"
68-
" `use_default_credential = True` to use application default"
69-
" credential in a hosted service like Google Cloud Run."
84+
"Service account credentials are invalid. Please set the"
85+
" service_account_credential field or set `use_default_credential ="
86+
" True` to use application default credential in a hosted service"
87+
" like Google Cloud Run."
7088
)
7189

7290
try:
73-
if self._credential.service_account.use_default_credential:
91+
if auth_credential.service_account.use_default_credential:
7492
credentials, _ = google.auth.default()
7593
else:
76-
config = self._credential.service_account
94+
config = auth_credential.service_account
7795
credentials = service_account.Credentials.from_service_account_info(
7896
config.service_account_credential.model_dump(), scopes=config.scopes
7997
)
@@ -82,11 +100,8 @@ def exchange(self) -> AuthCredential:
82100
credentials.refresh(Request())
83101

84102
return AuthCredential(
85-
auth_type=AuthCredentialTypes.HTTP,
86-
http=HttpAuth(
87-
scheme="bearer",
88-
credentials=HttpCredentials(token=credentials.token),
89-
),
103+
auth_type=AuthCredentialTypes.OAUTH2,
104+
google_oauth2_json=credentials.to_json(),
90105
)
91106
except Exception as e:
92107
raise ValueError(f"Failed to exchange service account token: {e}") from e
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Credential refresher module."""
16+
17+
from .base_credential_refresher import BaseCredentialRefresher
18+
19+
__all__ = [
20+
"BaseCredentialRefresher",
21+
]

0 commit comments

Comments
 (0)