Skip to content

Commit 96b9de6

Browse files
seanzhougooglecopybara-github
authored andcommitted
refactor: Extract util method from OAuth2 credential fetcher for reuse
PiperOrigin-RevId: 771635844
1 parent a4d432a commit 96b9de6

25 files changed

+1689
-485
lines changed

contributing/samples/oauth_calendar_agent/agent.py

Lines changed: 79 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@
2727
from google.adk.auth import AuthCredentialTypes
2828
from google.adk.auth import OAuth2Auth
2929
from google.adk.tools import ToolContext
30-
from google.adk.tools.authenticated_tool.base_authenticated_tool import AuthenticatedFunctionTool
31-
from google.adk.tools.authenticated_tool.credentials_store import ToolContextCredentialsStore
3230
from google.adk.tools.google_api_tool import CalendarToolset
3331
from google.auth.transport.requests import Request
3432
from google.oauth2.credentials import Credentials
@@ -58,7 +56,6 @@ def list_calendar_events(
5856
end_time: str,
5957
limit: int,
6058
tool_context: ToolContext,
61-
credential: AuthCredential,
6259
) -> list[dict]:
6360
"""Search for calendar events.
6461
@@ -83,11 +80,84 @@ def list_calendar_events(
8380
Returns:
8481
list[dict]: A list of events that match the search criteria.
8582
"""
86-
87-
creds = Credentials(
88-
token=credential.oauth2.access_token,
89-
refresh_token=credential.oauth2.refresh_token,
90-
)
83+
creds = None
84+
85+
# Check if the tokes were already in the session state, which means the user
86+
# has already gone through the OAuth flow and successfully authenticated and
87+
# authorized the tool to access their calendar.
88+
if "calendar_tool_tokens" in tool_context.state:
89+
creds = Credentials.from_authorized_user_info(
90+
tool_context.state["calendar_tool_tokens"], SCOPES
91+
)
92+
if not creds or not creds.valid:
93+
# If the access token is expired, refresh it with the refresh token.
94+
if creds and creds.expired and creds.refresh_token:
95+
creds.refresh(Request())
96+
else:
97+
auth_scheme = OAuth2(
98+
flows=OAuthFlows(
99+
authorizationCode=OAuthFlowAuthorizationCode(
100+
authorizationUrl="https://accounts.google.com/o/oauth2/auth",
101+
tokenUrl="https://oauth2.googleapis.com/token",
102+
scopes={
103+
"https://www.googleapis.com/auth/calendar": (
104+
"See, edit, share, and permanently delete all the"
105+
" calendars you can access using Google Calendar"
106+
)
107+
},
108+
)
109+
)
110+
)
111+
auth_credential = AuthCredential(
112+
auth_type=AuthCredentialTypes.OAUTH2,
113+
oauth2=OAuth2Auth(
114+
client_id=oauth_client_id, client_secret=oauth_client_secret
115+
),
116+
)
117+
# If the user has not gone through the OAuth flow before, or the refresh
118+
# token also expired, we need to ask users to go through the OAuth flow.
119+
# First we check whether the user has just gone through the OAuth flow and
120+
# Oauth response is just passed back.
121+
auth_response = tool_context.get_auth_response(
122+
AuthConfig(
123+
auth_scheme=auth_scheme, raw_auth_credential=auth_credential
124+
)
125+
)
126+
if auth_response:
127+
# ADK exchanged the access token already for us
128+
access_token = auth_response.oauth2.access_token
129+
refresh_token = auth_response.oauth2.refresh_token
130+
131+
creds = Credentials(
132+
token=access_token,
133+
refresh_token=refresh_token,
134+
token_uri=auth_scheme.flows.authorizationCode.tokenUrl,
135+
client_id=oauth_client_id,
136+
client_secret=oauth_client_secret,
137+
scopes=list(auth_scheme.flows.authorizationCode.scopes.keys()),
138+
)
139+
else:
140+
# If there are no auth response which means the user has not gone
141+
# through the OAuth flow yet, we need to ask users to go through the
142+
# OAuth flow.
143+
tool_context.request_credential(
144+
AuthConfig(
145+
auth_scheme=auth_scheme,
146+
raw_auth_credential=auth_credential,
147+
)
148+
)
149+
# The return value is optional and could be any dict object. It will be
150+
# wrapped in a dict with key as 'result' and value as the return value
151+
# if the object returned is not a dict. This response will be passed
152+
# to LLM to generate a user friendly message. e.g. LLM will tell user:
153+
# "I need your authorization to access your calendar. Please authorize
154+
# me so I can check your meetings for today."
155+
return "Need User Authorization to access their calendar."
156+
# We store the access token and refresh token in the session state for the
157+
# next runs. This is just an example. On production, a tool should store
158+
# those credentials in some secure store or properly encrypt it before store
159+
# it in the session state.
160+
tool_context.state["calendar_tool_tokens"] = json.loads(creds.to_json())
91161

92162
service = build("calendar", "v3", credentials=creds)
93163
events_result = (
@@ -138,38 +208,6 @@ def update_time(callback_context: CallbackContext):
138208
139209
Currnet time: {_time}
140210
""",
141-
tools=[
142-
AuthenticatedFunctionTool(
143-
func=list_calendar_events,
144-
auth_config=AuthConfig(
145-
auth_scheme=OAuth2(
146-
flows=OAuthFlows(
147-
authorizationCode=OAuthFlowAuthorizationCode(
148-
authorizationUrl=(
149-
"https://accounts.google.com/o/oauth2/auth"
150-
),
151-
tokenUrl="https://oauth2.googleapis.com/token",
152-
scopes={
153-
"https://www.googleapis.com/auth/calendar": (
154-
"See, edit, share, and permanently delete"
155-
" all the calendars you can access using"
156-
" Google Calendar"
157-
)
158-
},
159-
)
160-
)
161-
),
162-
raw_auth_credential=AuthCredential(
163-
auth_type=AuthCredentialTypes.OAUTH2,
164-
oauth2=OAuth2Auth(
165-
client_id=oauth_client_id,
166-
client_secret=oauth_client_secret,
167-
),
168-
),
169-
),
170-
credential_store=ToolContextCredentialsStore(),
171-
),
172-
calendar_toolset,
173-
],
211+
tools=[list_calendar_events, calendar_toolset],
174212
before_agent_callback=update_time,
175213
)

src/google/adk/agents/invocation_context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from pydantic import ConfigDict
2323

2424
from ..artifacts.base_artifact_service import BaseArtifactService
25+
from ..auth.credential_service.base_credential_service import BaseCredentialService
2526
from ..memory.base_memory_service import BaseMemoryService
2627
from ..sessions.base_session_service import BaseSessionService
2728
from ..sessions.session import Session
@@ -115,6 +116,7 @@ class InvocationContext(BaseModel):
115116
artifact_service: Optional[BaseArtifactService] = None
116117
session_service: BaseSessionService
117118
memory_service: Optional[BaseMemoryService] = None
119+
credential_service: Optional[BaseCredentialService] = None
118120

119121
invocation_id: str
120122
"""The id of this invocation context. Readonly."""

src/google/adk/auth/credential_service/base_credential_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919
from typing import Optional
2020

2121
from ...tools.tool_context import ToolContext
22-
from ...utils.feature_decorator import working_in_progress
22+
from ...utils.feature_decorator import experimental
2323
from ..auth_credential import AuthCredential
2424
from ..auth_tool import AuthConfig
2525

2626

27-
@working_in_progress("Implementation are in progress. Don't use it for now.")
27+
@experimental
2828
class BaseCredentialService(ABC):
2929
"""Abstract class for Service that loads / saves tool credentials from / to
3030
the backend credential store."""
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import Optional
18+
19+
from typing_extensions import override
20+
21+
from ...tools.tool_context import ToolContext
22+
from ...utils.feature_decorator import experimental
23+
from ..auth_credential import AuthCredential
24+
from ..auth_tool import AuthConfig
25+
from .base_credential_service import BaseCredentialService
26+
27+
28+
@experimental
29+
class InMemoryCredentialService(BaseCredentialService):
30+
"""Class for in memory implementation of credential service(Experimental)"""
31+
32+
def __init__(self):
33+
super().__init__()
34+
self._store: dict[str, AuthCredential] = {}
35+
36+
@override
37+
async def load_credential(
38+
self,
39+
auth_config: AuthConfig,
40+
tool_context: ToolContext,
41+
) -> Optional[AuthCredential]:
42+
"""
43+
Loads the credential by auth config and current tool context from the
44+
backend credential store.
45+
46+
Args:
47+
auth_config: The auth config which contains the auth scheme and auth
48+
credential information. auth_config.get_credential_key will be used to
49+
build the key to load the credential.
50+
51+
tool_context: The context of the current invocation when the tool is
52+
trying to load the credential.
53+
54+
Returns:
55+
Optional[AuthCredential]: the credential saved in the store.
56+
57+
"""
58+
storage = self._get_storage_for_current_context(tool_context)
59+
return storage.get(auth_config.credential_key)
60+
61+
@override
62+
async def save_credential(
63+
self,
64+
auth_config: AuthConfig,
65+
tool_context: ToolContext,
66+
) -> None:
67+
"""
68+
Saves the exchanged_auth_credential in auth config to the backend credential
69+
store.
70+
71+
Args:
72+
auth_config: The auth config which contains the auth scheme and auth
73+
credential information. auth_config.get_credential_key will be used to
74+
build the key to save the credential.
75+
76+
tool_context: The context of the current invocation when the tool is
77+
trying to save the credential.
78+
79+
Returns:
80+
None
81+
"""
82+
storage = self._get_storage_for_current_context(tool_context)
83+
storage[auth_config.credential_key] = auth_config.exchanged_auth_credential
84+
85+
def _get_storage_for_current_context(self, tool_context: ToolContext) -> str:
86+
app_name = tool_context._invocation_context.app_name
87+
user_id = tool_context._invocation_context.user_id
88+
89+
if app_name not in self._store:
90+
self._store[app_name] = {}
91+
if user_id not in self._store[app_name]:
92+
self._store[app_name][user_id] = {}
93+
return self._store[app_name][user_id]
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Credential exchanger module."""
16+
17+
from .base_credential_exchanger import BaseCredentialExchanger
18+
from .service_account_credential_exchanger import ServiceAccountCredentialExchanger
19+
20+
__all__ = [
21+
"BaseCredentialExchanger",
22+
"ServiceAccountCredentialExchanger",
23+
]
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Base credential exchanger interface."""
16+
17+
from __future__ import annotations
18+
19+
import abc
20+
from typing import Optional
21+
22+
from ...utils.feature_decorator import experimental
23+
from ..auth_credential import AuthCredential
24+
from ..auth_schemes import AuthScheme
25+
26+
27+
@experimental
28+
class BaseCredentialExchanger(abc.ABC):
29+
"""Base interface for credential exchangers."""
30+
31+
@abc.abstractmethod
32+
def exchange(
33+
self,
34+
auth_credential: AuthCredential,
35+
auth_scheme: Optional[AuthScheme] = None,
36+
) -> AuthCredential:
37+
"""Exchange credential if needed.
38+
39+
Args:
40+
auth_credential: The credential to exchange.
41+
auth_scheme: The authentication scheme (optional, some exchangers don't need it).
42+
43+
Returns:
44+
The exchanged credential.
45+
46+
Raises:
47+
ValueError: If credential exchange fails.
48+
"""
49+
pass

0 commit comments

Comments
 (0)