diff --git a/src/google/adk/auth/oauth2_credential_fetcher.py b/src/google/adk/auth/oauth2_credential_fetcher.py index cbed70762..c9e838b25 100644 --- a/src/google/adk/auth/oauth2_credential_fetcher.py +++ b/src/google/adk/auth/oauth2_credential_fetcher.py @@ -15,19 +15,15 @@ from __future__ import annotations import logging -from typing import Optional -from typing import Tuple - -from fastapi.openapi.models import OAuth2 from ..utils.feature_decorator import experimental from .auth_credential import AuthCredential from .auth_schemes import AuthScheme from .auth_schemes import OAuthGrantType -from .auth_schemes import OpenIdConnectWithConfig +from .oauth2_credential_util import create_oauth2_session +from .oauth2_credential_util import update_credential_with_tokens try: - from authlib.integrations.requests_client import OAuth2Session from authlib.oauth2.rfc6749 import OAuth2Token AUTHLIB_AVIALABLE = True @@ -50,45 +46,6 @@ def __init__( self._auth_scheme = auth_scheme self._auth_credential = auth_credential - def _oauth2_session(self) -> Tuple[Optional[OAuth2Session], Optional[str]]: - auth_scheme = self._auth_scheme - auth_credential = self._auth_credential - - if isinstance(auth_scheme, OpenIdConnectWithConfig): - if not hasattr(auth_scheme, "token_endpoint"): - return None, None - token_endpoint = auth_scheme.token_endpoint - scopes = auth_scheme.scopes - elif isinstance(auth_scheme, OAuth2): - if ( - not auth_scheme.flows.authorizationCode - or not auth_scheme.flows.authorizationCode.tokenUrl - ): - return None, None - token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl - scopes = list(auth_scheme.flows.authorizationCode.scopes.keys()) - else: - return None, None - - if ( - not auth_credential - or not auth_credential.oauth2 - or not auth_credential.oauth2.client_id - or not auth_credential.oauth2.client_secret - ): - return None, None - - return ( - OAuth2Session( - auth_credential.oauth2.client_id, - auth_credential.oauth2.client_secret, - scope=" ".join(scopes), - redirect_uri=auth_credential.oauth2.redirect_uri, - state=auth_credential.oauth2.state, - ), - token_endpoint, - ) - def _update_credential(self, tokens: OAuth2Token) -> None: self._auth_credential.oauth2.access_token = tokens.get("access_token") self._auth_credential.oauth2.refresh_token = tokens.get("refresh_token") @@ -114,7 +71,9 @@ def exchange(self) -> AuthCredential: ): return self._auth_credential - client, token_endpoint = self._oauth2_session() + client, token_endpoint = create_oauth2_session( + self._auth_scheme, self._auth_credential + ) if not client: logger.warning("Could not create OAuth2 session for token exchange") return self._auth_credential @@ -126,7 +85,7 @@ def exchange(self) -> AuthCredential: code=self._auth_credential.oauth2.auth_code, grant_type=OAuthGrantType.AUTHORIZATION_CODE, ) - self._update_credential(tokens) + update_credential_with_tokens(self._auth_credential, tokens) logger.info("Successfully exchanged OAuth2 tokens") except Exception as e: logger.error("Failed to exchange OAuth2 tokens: %s", e) @@ -151,7 +110,9 @@ def refresh(self) -> AuthCredential: "expires_at": credential.oauth2.expires_at, "expires_in": credential.oauth2.expires_in, }).is_expired(): - client, token_endpoint = self._oauth2_session() + client, token_endpoint = create_oauth2_session( + self._auth_scheme, self._auth_credential + ) if not client: logger.warning("Could not create OAuth2 session for token refresh") return credential @@ -161,7 +122,7 @@ def refresh(self) -> AuthCredential: url=token_endpoint, refresh_token=credential.oauth2.refresh_token, ) - self._update_credential(tokens) + update_credential_with_tokens(self._auth_credential, tokens) logger.info("Successfully refreshed OAuth2 tokens") except Exception as e: logger.error("Failed to refresh OAuth2 tokens: %s", e) diff --git a/src/google/adk/auth/oauth2_credential_util.py b/src/google/adk/auth/oauth2_credential_util.py new file mode 100644 index 000000000..51ed4d29f --- /dev/null +++ b/src/google/adk/auth/oauth2_credential_util.py @@ -0,0 +1,107 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from typing import Optional +from typing import Tuple + +from fastapi.openapi.models import OAuth2 + +from ..utils.feature_decorator import experimental +from .auth_credential import AuthCredential +from .auth_schemes import AuthScheme +from .auth_schemes import OpenIdConnectWithConfig + +try: + from authlib.integrations.requests_client import OAuth2Session + from authlib.oauth2.rfc6749 import OAuth2Token + + AUTHLIB_AVIALABLE = True +except ImportError: + AUTHLIB_AVIALABLE = False + + +logger = logging.getLogger("google_adk." + __name__) + + +@experimental +def create_oauth2_session( + auth_scheme: AuthScheme, + auth_credential: AuthCredential, +) -> Tuple[Optional[OAuth2Session], Optional[str]]: + """Create an OAuth2 session for token operations. + + Args: + auth_scheme: The authentication scheme configuration. + auth_credential: The authentication credential. + + Returns: + Tuple of (OAuth2Session, token_endpoint) or (None, None) if cannot create session. + """ + if isinstance(auth_scheme, OpenIdConnectWithConfig): + if not hasattr(auth_scheme, "token_endpoint"): + return None, None + token_endpoint = auth_scheme.token_endpoint + scopes = auth_scheme.scopes + elif isinstance(auth_scheme, OAuth2): + if ( + not auth_scheme.flows.authorizationCode + or not auth_scheme.flows.authorizationCode.tokenUrl + ): + return None, None + token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl + scopes = list(auth_scheme.flows.authorizationCode.scopes.keys()) + else: + return None, None + + if ( + not auth_credential + or not auth_credential.oauth2 + or not auth_credential.oauth2.client_id + or not auth_credential.oauth2.client_secret + ): + return None, None + + return ( + OAuth2Session( + auth_credential.oauth2.client_id, + auth_credential.oauth2.client_secret, + scope=" ".join(scopes), + redirect_uri=auth_credential.oauth2.redirect_uri, + state=auth_credential.oauth2.state, + ), + token_endpoint, + ) + + +@experimental +def update_credential_with_tokens( + auth_credential: AuthCredential, tokens: OAuth2Token +) -> None: + """Update the credential with new tokens. + + Args: + auth_credential: The authentication credential to update. + tokens: The OAuth2Token object containing new token information. + """ + auth_credential.oauth2.access_token = tokens.get("access_token") + auth_credential.oauth2.refresh_token = tokens.get("refresh_token") + auth_credential.oauth2.expires_at = ( + int(tokens.get("expires_at")) if tokens.get("expires_at") else None + ) + auth_credential.oauth2.expires_in = ( + int(tokens.get("expires_in")) if tokens.get("expires_in") else None + ) diff --git a/tests/unittests/auth/test_auth_handler.py b/tests/unittests/auth/test_auth_handler.py index aaed35a19..2bfc7d4c9 100644 --- a/tests/unittests/auth/test_auth_handler.py +++ b/tests/unittests/auth/test_auth_handler.py @@ -538,7 +538,7 @@ def test_credentials_with_token( assert result == oauth2_credentials_with_token @patch( - "google.adk.auth.oauth2_credential_fetcher.OAuth2Session", + "google.adk.auth.oauth2_credential_util.OAuth2Session", MockOAuth2Session, ) def test_successful_token_exchange(self, auth_config_with_auth_code): diff --git a/tests/unittests/auth/test_oauth2_credential_fetcher.py b/tests/unittests/auth/test_oauth2_credential_fetcher.py index 0b9b5a3c1..aba6a9923 100644 --- a/tests/unittests/auth/test_oauth2_credential_fetcher.py +++ b/tests/unittests/auth/test_oauth2_credential_fetcher.py @@ -14,7 +14,6 @@ import time from unittest.mock import Mock -from unittest.mock import patch from authlib.oauth2.rfc6749 import OAuth2Token from fastapi.openapi.models import OAuth2 @@ -24,38 +23,15 @@ from google.adk.auth.auth_credential import AuthCredentialTypes from google.adk.auth.auth_credential import OAuth2Auth from google.adk.auth.auth_schemes import OpenIdConnectWithConfig -from google.adk.auth.oauth2_credential_fetcher import OAuth2CredentialFetcher +from google.adk.auth.oauth2_credential_util import create_oauth2_session +from google.adk.auth.oauth2_credential_util import update_credential_with_tokens -class TestOAuth2CredentialFetcher: - """Test suite for OAuth2CredentialFetcher.""" +class TestOAuth2CredentialUtil: + """Test suite for OAuth2 credential utility functions.""" - def test_init(self): - """Test OAuth2CredentialFetcher initialization.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid", "profile"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - redirect_uri="https://example.com/callback", - ), - ) - - fetcher = OAuth2CredentialFetcher(scheme, credential) - assert fetcher._auth_scheme == scheme - assert fetcher._auth_credential == credential - - def test_oauth2_session_openid_connect(self): - """Test _oauth2_session with OpenID Connect scheme.""" + def test_create_oauth2_session_openid_connect(self): + """Test create_oauth2_session with OpenID Connect scheme.""" scheme = OpenIdConnectWithConfig( type_="openIdConnect", openId_connect_url=( @@ -75,16 +51,15 @@ def test_oauth2_session_openid_connect(self): ), ) - fetcher = OAuth2CredentialFetcher(scheme, credential) - client, token_endpoint = fetcher._oauth2_session() + client, token_endpoint = create_oauth2_session(scheme, credential) assert client is not None assert token_endpoint == "https://example.com/token" assert client.client_id == "test_client_id" assert client.client_secret == "test_client_secret" - def test_oauth2_session_oauth2_scheme(self): - """Test _oauth2_session with OAuth2 scheme.""" + def test_create_oauth2_session_oauth2_scheme(self): + """Test create_oauth2_session with OAuth2 scheme.""" flows = OAuthFlows( authorizationCode=OAuthFlowAuthorizationCode( authorizationUrl="https://example.com/auth", @@ -102,14 +77,13 @@ def test_oauth2_session_oauth2_scheme(self): ), ) - fetcher = OAuth2CredentialFetcher(scheme, credential) - client, token_endpoint = fetcher._oauth2_session() + client, token_endpoint = create_oauth2_session(scheme, credential) assert client is not None assert token_endpoint == "https://example.com/token" - def test_oauth2_session_invalid_scheme(self): - """Test _oauth2_session with invalid scheme.""" + def test_create_oauth2_session_invalid_scheme(self): + """Test create_oauth2_session with invalid scheme.""" scheme = Mock() # Invalid scheme type credential = AuthCredential( auth_type=AuthCredentialTypes.OAUTH2, @@ -119,14 +93,13 @@ def test_oauth2_session_invalid_scheme(self): ), ) - fetcher = OAuth2CredentialFetcher(scheme, credential) - client, token_endpoint = fetcher._oauth2_session() + client, token_endpoint = create_oauth2_session(scheme, credential) assert client is None assert token_endpoint is None - def test_oauth2_session_missing_credentials(self): - """Test _oauth2_session with missing credentials.""" + def test_create_oauth2_session_missing_credentials(self): + """Test create_oauth2_session with missing credentials.""" scheme = OpenIdConnectWithConfig( type_="openIdConnect", openId_connect_url=( @@ -144,23 +117,13 @@ def test_oauth2_session_missing_credentials(self): ), ) - fetcher = OAuth2CredentialFetcher(scheme, credential) - client, token_endpoint = fetcher._oauth2_session() + client, token_endpoint = create_oauth2_session(scheme, credential) assert client is None assert token_endpoint is None - def test_update_credential(self): - """Test _update_credential method.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) + def test_update_credential_with_tokens(self): + """Test update_credential_with_tokens function.""" credential = AuthCredential( auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, oauth2=OAuth2Auth( @@ -169,7 +132,6 @@ def test_update_credential(self): ), ) - fetcher = OAuth2CredentialFetcher(scheme, credential) tokens = OAuth2Token({ "access_token": "new_access_token", "refresh_token": "new_refresh_token", @@ -177,265 +139,9 @@ def test_update_credential(self): "expires_in": 3600, }) - fetcher._update_credential(tokens) + update_credential_with_tokens(credential, tokens) assert credential.oauth2.access_token == "new_access_token" assert credential.oauth2.refresh_token == "new_refresh_token" assert credential.oauth2.expires_at == int(time.time()) + 3600 assert credential.oauth2.expires_in == 3600 - - def test_exchange_with_existing_token(self): - """Test exchange method when access token already exists.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - access_token="existing_token", - ), - ) - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.exchange() - - assert result == credential - assert result.oauth2.access_token == "existing_token" - - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Session") - def test_exchange_success(self, mock_oauth2_session): - """Test successful token exchange.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - auth_response_uri=( - "https://example.com/callback?code=auth_code&state=test_state" - ), - ), - ) - - # Mock the OAuth2Session - mock_client = Mock() - mock_oauth2_session.return_value = mock_client - mock_tokens = { - "access_token": "new_access_token", - "refresh_token": "new_refresh_token", - "expires_at": int(time.time()) + 3600, - "expires_in": 3600, - } - mock_client.fetch_token.return_value = mock_tokens - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.exchange() - - assert result.oauth2.access_token == "new_access_token" - assert result.oauth2.refresh_token == "new_refresh_token" - mock_client.fetch_token.assert_called_once() - - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Session") - def test_exchange_with_auth_code(self, mock_oauth2_session): - """Test token exchange with auth code.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - auth_code="test_auth_code", - ), - ) - - mock_client = Mock() - mock_oauth2_session.return_value = mock_client - mock_tokens = { - "access_token": "new_access_token", - "refresh_token": "new_refresh_token", - } - mock_client.fetch_token.return_value = mock_tokens - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.exchange() - - assert result.oauth2.access_token == "new_access_token" - mock_client.fetch_token.assert_called_once() - - def test_exchange_no_session(self): - """Test exchange when OAuth2Session cannot be created.""" - scheme = Mock() # Invalid scheme - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - auth_response_uri="https://example.com/callback?code=auth_code", - ), - ) - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.exchange() - - assert result == credential - assert result.oauth2.access_token is None - - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Token") - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Session") - def test_refresh_token_not_expired( - self, mock_oauth2_session, mock_oauth2_token - ): - """Test refresh when token is not expired.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - access_token="current_token", - refresh_token="refresh_token", - expires_at=int(time.time()) + 3600, - expires_in=3600, - ), - ) - - # Mock token not expired - mock_token_instance = Mock() - mock_token_instance.is_expired.return_value = False - mock_oauth2_token.return_value = mock_token_instance - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.refresh() - - assert result == credential - assert result.oauth2.access_token == "current_token" - mock_oauth2_session.assert_not_called() - - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Token") - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Session") - def test_refresh_token_expired_success( - self, mock_oauth2_session, mock_oauth2_token - ): - """Test successful token refresh when token is expired.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - access_token="expired_token", - refresh_token="refresh_token", - expires_at=int(time.time()) - 3600, # Expired - expires_in=3600, - ), - ) - - # Mock token expired - mock_token_instance = Mock() - mock_token_instance.is_expired.return_value = True - mock_oauth2_token.return_value = mock_token_instance - - # Mock refresh token response - mock_client = Mock() - mock_oauth2_session.return_value = mock_client - mock_tokens = { - "access_token": "refreshed_access_token", - "refresh_token": "new_refresh_token", - "expires_at": int(time.time()) + 3600, - "expires_in": 3600, - } - mock_client.refresh_token.return_value = mock_tokens - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.refresh() - - assert result.oauth2.access_token == "refreshed_access_token" - assert result.oauth2.refresh_token == "new_refresh_token" - mock_client.refresh_token.assert_called_once_with( - url="https://example.com/token", - refresh_token="refresh_token", - ) - - def test_refresh_no_oauth2_credential(self): - """Test refresh when oauth2 credential is missing.""" - scheme = OpenIdConnectWithConfig( - type_="openIdConnect", - openId_connect_url=( - "https://example.com/.well-known/openid_configuration" - ), - authorization_endpoint="https://example.com/auth", - token_endpoint="https://example.com/token", - scopes=["openid"], - ) - credential = AuthCredential(auth_type=AuthCredentialTypes.HTTP) # No oauth2 - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.refresh() - - assert result == credential - - @patch("google.adk.auth.oauth2_credential_fetcher.OAuth2Token") - def test_refresh_no_session(self, mock_oauth2_token): - """Test refresh when OAuth2Session cannot be created.""" - scheme = Mock() # Invalid scheme - credential = AuthCredential( - auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, - oauth2=OAuth2Auth( - client_id="test_client_id", - client_secret="test_client_secret", - access_token="expired_token", - refresh_token="refresh_token", - expires_at=int(time.time()) - 3600, - ), - ) - - # Mock token expired - mock_token_instance = Mock() - mock_token_instance.is_expired.return_value = True - mock_oauth2_token.return_value = mock_token_instance - - fetcher = OAuth2CredentialFetcher(scheme, credential) - result = fetcher.refresh() - - assert result == credential - assert result.oauth2.access_token == "expired_token" # Unchanged diff --git a/tests/unittests/auth/test_oauth2_credential_util.py b/tests/unittests/auth/test_oauth2_credential_util.py new file mode 100644 index 000000000..aba6a9923 --- /dev/null +++ b/tests/unittests/auth/test_oauth2_credential_util.py @@ -0,0 +1,147 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from unittest.mock import Mock + +from authlib.oauth2.rfc6749 import OAuth2Token +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlows +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_schemes import OpenIdConnectWithConfig +from google.adk.auth.oauth2_credential_util import create_oauth2_session +from google.adk.auth.oauth2_credential_util import update_credential_with_tokens + + +class TestOAuth2CredentialUtil: + """Test suite for OAuth2 credential utility functions.""" + + def test_create_oauth2_session_openid_connect(self): + """Test create_oauth2_session with OpenID Connect scheme.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid", "profile"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uri="https://example.com/callback", + state="test_state", + ), + ) + + client, token_endpoint = create_oauth2_session(scheme, credential) + + assert client is not None + assert token_endpoint == "https://example.com/token" + assert client.client_id == "test_client_id" + assert client.client_secret == "test_client_secret" + + def test_create_oauth2_session_oauth2_scheme(self): + """Test create_oauth2_session with OAuth2 scheme.""" + flows = OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://example.com/auth", + tokenUrl="https://example.com/token", + scopes={"read": "Read access", "write": "Write access"}, + ) + ) + scheme = OAuth2(type_="oauth2", flows=flows) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uri="https://example.com/callback", + ), + ) + + client, token_endpoint = create_oauth2_session(scheme, credential) + + assert client is not None + assert token_endpoint == "https://example.com/token" + + def test_create_oauth2_session_invalid_scheme(self): + """Test create_oauth2_session with invalid scheme.""" + scheme = Mock() # Invalid scheme type + credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + ), + ) + + client, token_endpoint = create_oauth2_session(scheme, credential) + + assert client is None + assert token_endpoint is None + + def test_create_oauth2_session_missing_credentials(self): + """Test create_oauth2_session with missing credentials.""" + scheme = OpenIdConnectWithConfig( + type_="openIdConnect", + openId_connect_url=( + "https://example.com/.well-known/openid_configuration" + ), + authorization_endpoint="https://example.com/auth", + token_endpoint="https://example.com/token", + scopes=["openid"], + ) + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + # Missing client_secret + ), + ) + + client, token_endpoint = create_oauth2_session(scheme, credential) + + assert client is None + assert token_endpoint is None + + def test_update_credential_with_tokens(self): + """Test update_credential_with_tokens function.""" + credential = AuthCredential( + auth_type=AuthCredentialTypes.OPEN_ID_CONNECT, + oauth2=OAuth2Auth( + client_id="test_client_id", + client_secret="test_client_secret", + ), + ) + + tokens = OAuth2Token({ + "access_token": "new_access_token", + "refresh_token": "new_refresh_token", + "expires_at": int(time.time()) + 3600, + "expires_in": 3600, + }) + + update_credential_with_tokens(credential, tokens) + + assert credential.oauth2.access_token == "new_access_token" + assert credential.oauth2.refresh_token == "new_refresh_token" + assert credential.oauth2.expires_at == int(time.time()) + 3600 + assert credential.oauth2.expires_in == 3600