Skip to content

refactor: Extract util method from OAuth2 credential fetcher for reuse #1424

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 10 additions & 49 deletions src/google/adk/auth/oauth2_credential_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,15 @@
from __future__ import annotations

import logging
from typing import Optional
from typing import Tuple

from fastapi.openapi.models import OAuth2

from ..utils.feature_decorator import experimental
from .auth_credential import AuthCredential
from .auth_schemes import AuthScheme
from .auth_schemes import OAuthGrantType
from .auth_schemes import OpenIdConnectWithConfig
from .oauth2_credential_util import create_oauth2_session
from .oauth2_credential_util import update_credential_with_tokens

try:
from authlib.integrations.requests_client import OAuth2Session
from authlib.oauth2.rfc6749 import OAuth2Token

AUTHLIB_AVIALABLE = True
Expand All @@ -50,45 +46,6 @@ def __init__(
self._auth_scheme = auth_scheme
self._auth_credential = auth_credential

def _oauth2_session(self) -> Tuple[Optional[OAuth2Session], Optional[str]]:
auth_scheme = self._auth_scheme
auth_credential = self._auth_credential

if isinstance(auth_scheme, OpenIdConnectWithConfig):
if not hasattr(auth_scheme, "token_endpoint"):
return None, None
token_endpoint = auth_scheme.token_endpoint
scopes = auth_scheme.scopes
elif isinstance(auth_scheme, OAuth2):
if (
not auth_scheme.flows.authorizationCode
or not auth_scheme.flows.authorizationCode.tokenUrl
):
return None, None
token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl
scopes = list(auth_scheme.flows.authorizationCode.scopes.keys())
else:
return None, None

if (
not auth_credential
or not auth_credential.oauth2
or not auth_credential.oauth2.client_id
or not auth_credential.oauth2.client_secret
):
return None, None

return (
OAuth2Session(
auth_credential.oauth2.client_id,
auth_credential.oauth2.client_secret,
scope=" ".join(scopes),
redirect_uri=auth_credential.oauth2.redirect_uri,
state=auth_credential.oauth2.state,
),
token_endpoint,
)

def _update_credential(self, tokens: OAuth2Token) -> None:
self._auth_credential.oauth2.access_token = tokens.get("access_token")
self._auth_credential.oauth2.refresh_token = tokens.get("refresh_token")
Expand All @@ -114,7 +71,9 @@ def exchange(self) -> AuthCredential:
):
return self._auth_credential

client, token_endpoint = self._oauth2_session()
client, token_endpoint = create_oauth2_session(
self._auth_scheme, self._auth_credential
)
if not client:
logger.warning("Could not create OAuth2 session for token exchange")
return self._auth_credential
Expand All @@ -126,7 +85,7 @@ def exchange(self) -> AuthCredential:
code=self._auth_credential.oauth2.auth_code,
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
)
self._update_credential(tokens)
update_credential_with_tokens(self._auth_credential, tokens)
logger.info("Successfully exchanged OAuth2 tokens")
except Exception as e:
logger.error("Failed to exchange OAuth2 tokens: %s", e)
Expand All @@ -151,7 +110,9 @@ def refresh(self) -> AuthCredential:
"expires_at": credential.oauth2.expires_at,
"expires_in": credential.oauth2.expires_in,
}).is_expired():
client, token_endpoint = self._oauth2_session()
client, token_endpoint = create_oauth2_session(
self._auth_scheme, self._auth_credential
)
if not client:
logger.warning("Could not create OAuth2 session for token refresh")
return credential
Expand All @@ -161,7 +122,7 @@ def refresh(self) -> AuthCredential:
url=token_endpoint,
refresh_token=credential.oauth2.refresh_token,
)
self._update_credential(tokens)
update_credential_with_tokens(self._auth_credential, tokens)
logger.info("Successfully refreshed OAuth2 tokens")
except Exception as e:
logger.error("Failed to refresh OAuth2 tokens: %s", e)
Expand Down
107 changes: 107 additions & 0 deletions src/google/adk/auth/oauth2_credential_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import logging
from typing import Optional
from typing import Tuple

from fastapi.openapi.models import OAuth2

from ..utils.feature_decorator import experimental
from .auth_credential import AuthCredential
from .auth_schemes import AuthScheme
from .auth_schemes import OpenIdConnectWithConfig

try:
from authlib.integrations.requests_client import OAuth2Session
from authlib.oauth2.rfc6749 import OAuth2Token

AUTHLIB_AVIALABLE = True
except ImportError:
AUTHLIB_AVIALABLE = False


logger = logging.getLogger("google_adk." + __name__)


@experimental
def create_oauth2_session(
auth_scheme: AuthScheme,
auth_credential: AuthCredential,
) -> Tuple[Optional[OAuth2Session], Optional[str]]:
"""Create an OAuth2 session for token operations.

Args:
auth_scheme: The authentication scheme configuration.
auth_credential: The authentication credential.

Returns:
Tuple of (OAuth2Session, token_endpoint) or (None, None) if cannot create session.
"""
if isinstance(auth_scheme, OpenIdConnectWithConfig):
if not hasattr(auth_scheme, "token_endpoint"):
return None, None
token_endpoint = auth_scheme.token_endpoint
scopes = auth_scheme.scopes
elif isinstance(auth_scheme, OAuth2):
if (
not auth_scheme.flows.authorizationCode
or not auth_scheme.flows.authorizationCode.tokenUrl
):
return None, None
token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl
scopes = list(auth_scheme.flows.authorizationCode.scopes.keys())
else:
return None, None

if (
not auth_credential
or not auth_credential.oauth2
or not auth_credential.oauth2.client_id
or not auth_credential.oauth2.client_secret
):
return None, None

return (
OAuth2Session(
auth_credential.oauth2.client_id,
auth_credential.oauth2.client_secret,
scope=" ".join(scopes),
redirect_uri=auth_credential.oauth2.redirect_uri,
state=auth_credential.oauth2.state,
),
token_endpoint,
)


@experimental
def update_credential_with_tokens(
auth_credential: AuthCredential, tokens: OAuth2Token
) -> None:
"""Update the credential with new tokens.

Args:
auth_credential: The authentication credential to update.
tokens: The OAuth2Token object containing new token information.
"""
auth_credential.oauth2.access_token = tokens.get("access_token")
auth_credential.oauth2.refresh_token = tokens.get("refresh_token")
auth_credential.oauth2.expires_at = (
int(tokens.get("expires_at")) if tokens.get("expires_at") else None
)
auth_credential.oauth2.expires_in = (
int(tokens.get("expires_in")) if tokens.get("expires_in") else None
)
2 changes: 1 addition & 1 deletion tests/unittests/auth/test_auth_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ def test_credentials_with_token(
assert result == oauth2_credentials_with_token

@patch(
"google.adk.auth.oauth2_credential_fetcher.OAuth2Session",
"google.adk.auth.oauth2_credential_util.OAuth2Session",
MockOAuth2Session,
)
def test_successful_token_exchange(self, auth_config_with_auth_code):
Expand Down
Loading