Skip to content

Commit 94d5cb3

Browse files
committed
graceful token fed
1 parent 30286ad commit 94d5cb3

File tree

4 files changed

+274
-35
lines changed

4 files changed

+274
-35
lines changed

src/databricks/sql/auth/auth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ def get_auth_provider(cfg: ClientContext, http_client):
6565
else:
6666
raise RuntimeError("No valid authentication settings!")
6767

68-
# Always wrap with token federation (falls back gracefully if not needed)
69-
if base_provider:
68+
# Wrap with token federation only if explicitly enabled via identity_federation_client_id
69+
if base_provider and cfg.identity_federation_client_id:
7070
return TokenFederationProvider(
7171
hostname=cfg.hostname,
7272
external_provider=base_provider,

src/databricks/sql/auth/token_federation.py

Lines changed: 120 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,24 @@
1515
logger = logging.getLogger(__name__)
1616

1717

18+
class TokenFederationError(Exception):
19+
"""Base exception for token federation errors."""
20+
21+
pass
22+
23+
24+
class TokenExchangeNotAvailableError(TokenFederationError):
25+
"""Raised when token exchange endpoint is not available (404)."""
26+
27+
pass
28+
29+
30+
class TokenExchangeAuthenticationError(TokenFederationError):
31+
"""Raised when token exchange fails due to authentication issues (401/403)."""
32+
33+
pass
34+
35+
1836
class Token:
1937
"""
2038
Represents an OAuth token with expiration management.
@@ -72,8 +90,20 @@ class TokenFederationProvider(AuthProvider):
7290
"""
7391
Implementation of Token Federation for Databricks SQL Python driver.
7492
75-
This provider exchanges third-party access tokens for Databricks in-house tokens
76-
when the token issuer is different from the Databricks host.
93+
This provider exchanges third-party access tokens (e.g., Azure AD, AWS IAM) for
94+
Databricks-native tokens when the token issuer differs from the Databricks host.
95+
96+
Token federation is useful for:
97+
- Cross-cloud authentication scenarios
98+
- Unity Catalog access across Azure subscriptions
99+
- Service principal authentication with external identity providers
100+
101+
The provider automatically detects when token exchange is needed by comparing the
102+
token issuer with the Databricks workspace hostname. If exchange fails, it gracefully
103+
falls back to using the external token directly.
104+
105+
Note: Token federation must be explicitly enabled by providing the
106+
identity_federation_client_id parameter during connection setup.
77107
"""
78108

79109
TOKEN_EXCHANGE_ENDPOINT = "/oidc/v1/token"
@@ -92,9 +122,17 @@ def __init__(
92122
93123
Args:
94124
hostname: The Databricks workspace hostname
95-
external_provider: The external authentication provider
125+
external_provider: The external authentication provider that provides the initial token
96126
http_client: HTTP client for making requests (required)
97-
identity_federation_client_id: Optional client ID for token federation
127+
identity_federation_client_id: Client ID for identity federation (required for token exchange).
128+
This parameter enables token federation and should be provided when:
129+
- Using Service Principal authentication across Azure subscriptions
130+
- Accessing Unity Catalog resources in different Azure tenants
131+
- Configured with your workspace administrator
132+
133+
Without this parameter, the external token will be used directly without exchange.
134+
Contact your Databricks workspace administrator to obtain the appropriate client ID
135+
for your authentication scenario.
98136
"""
99137
if not http_client:
100138
raise ValueError("http_client is required for TokenFederationProvider")
@@ -143,9 +181,33 @@ def _get_token(self) -> Token:
143181
try:
144182
token = self._exchange_token(access_token)
145183
self._cached_token = token
184+
logger.info(
185+
"Successfully exchanged external token for Databricks token"
186+
)
146187
return token
188+
except TokenExchangeNotAvailableError:
189+
logger.debug(
190+
"Token exchange endpoint not available. Using external token directly. "
191+
"This is expected when token federation is not configured for this workspace."
192+
)
193+
except TokenExchangeAuthenticationError as e:
194+
logger.warning(
195+
"Token exchange failed due to authentication error. Using external token directly. "
196+
"Error: %s. If this persists, verify your identity_federation_client_id configuration.",
197+
e,
198+
)
199+
except TokenFederationError as e:
200+
logger.info(
201+
"Token exchange not performed, using external token directly. "
202+
"Error: %s",
203+
e,
204+
)
147205
except Exception as e:
148-
logger.warning("Token exchange failed, using external token: %s", e)
206+
logger.debug(
207+
"Token exchange failed with unexpected error, using external token directly. "
208+
"Error: %s",
209+
e,
210+
)
149211

150212
# Use external token directly
151213
token = Token(access_token, token_type)
@@ -163,7 +225,20 @@ def _should_exchange_token(self, access_token: str) -> bool:
163225
return not is_same_host(issuer, self.hostname)
164226

165227
def _exchange_token(self, access_token: str) -> Token:
166-
"""Exchange the external token for a Databricks token."""
228+
"""
229+
Exchange the external token for a Databricks token.
230+
231+
Args:
232+
access_token: The external access token to exchange
233+
234+
Returns:
235+
Token: The exchanged Databricks token
236+
237+
Raises:
238+
TokenExchangeNotAvailableError: If the endpoint is not available (404)
239+
TokenExchangeAuthenticationError: If authentication fails (401/403)
240+
TokenFederationError: For other token exchange errors
241+
"""
167242
token_url = f"{self.hostname.rstrip('/')}{self.TOKEN_EXCHANGE_ENDPOINT}"
168243

169244
data = {
@@ -184,15 +259,45 @@ def _exchange_token(self, access_token: str) -> Token:
184259

185260
body = urlencode(data)
186261

187-
response = self.http_client.request(
188-
HttpMethod.POST, url=token_url, body=body, headers=headers
189-
)
190-
191-
token_response = json.loads(response.data.decode())
192-
193-
return Token(
194-
token_response["access_token"], token_response.get("token_type", "Bearer")
195-
)
262+
try:
263+
response = self.http_client.request(
264+
HttpMethod.POST, url=token_url, body=body, headers=headers
265+
)
266+
267+
# Check response status code
268+
if response.status == 404:
269+
raise TokenExchangeNotAvailableError(
270+
"Token exchange endpoint not found. Token federation may not be enabled for this workspace."
271+
)
272+
elif response.status in (401, 403):
273+
error_detail = (
274+
response.data.decode() if response.data else "No error details"
275+
)
276+
raise TokenExchangeAuthenticationError(
277+
f"Authentication failed during token exchange (HTTP {response.status}): {error_detail}"
278+
)
279+
elif response.status != 200:
280+
error_detail = (
281+
response.data.decode() if response.data else "No error details"
282+
)
283+
raise TokenFederationError(
284+
f"Token exchange failed with HTTP {response.status}: {error_detail}"
285+
)
286+
287+
token_response = json.loads(response.data.decode())
288+
289+
return Token(
290+
token_response["access_token"],
291+
token_response.get("token_type", "Bearer"),
292+
)
293+
except TokenFederationError:
294+
# Re-raise our custom exceptions
295+
raise
296+
except Exception as e:
297+
# Handle unexpected errors (network errors, JSON parsing errors, etc.)
298+
raise TokenFederationError(
299+
f"Unexpected error during token exchange: {str(e)}"
300+
) from e
196301

197302
def _extract_token_from_header(self, auth_header: str) -> Tuple[str, str]:
198303
"""Extract token type and access token from Authorization header."""

tests/unit/test_auth.py

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,9 @@ def test_get_python_sql_connector_auth_provider_access_token(self):
145145
hostname = "moderakh-test.cloud.databricks.com"
146146
kwargs = {"access_token": "dpi123"}
147147
mock_http_client = MagicMock()
148-
auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs)
148+
auth_provider = get_python_sql_connector_auth_provider(
149+
hostname, mock_http_client, **kwargs
150+
)
149151
self.assertTrue(type(auth_provider).__name__, "AccessTokenAuthProvider")
150152

151153
headers = {}
@@ -163,10 +165,41 @@ def __call__(self, *args, **kwargs) -> HeaderFactory:
163165
hostname = "moderakh-test.cloud.databricks.com"
164166
kwargs = {"credentials_provider": MyProvider()}
165167
mock_http_client = MagicMock()
166-
auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs)
168+
auth_provider = get_python_sql_connector_auth_provider(
169+
hostname, mock_http_client, **kwargs
170+
)
171+
172+
# Without identity_federation_client_id, should return ExternalAuthProvider directly
173+
self.assertEqual(type(auth_provider).__name__, "ExternalAuthProvider")
174+
175+
headers = {}
176+
auth_provider.add_headers(headers)
177+
self.assertEqual(headers["foo"], "bar")
178+
179+
def test_get_python_sql_connector_auth_provider_with_token_federation(self):
180+
class MyProvider(CredentialsProvider):
181+
def auth_type(self) -> str:
182+
return "mine"
183+
184+
def __call__(self, *args, **kwargs) -> HeaderFactory:
185+
return lambda: {"foo": "bar"}
186+
187+
hostname = "moderakh-test.cloud.databricks.com"
188+
kwargs = {
189+
"credentials_provider": MyProvider(),
190+
"identity_federation_client_id": "test-client-id",
191+
}
192+
mock_http_client = MagicMock()
193+
auth_provider = get_python_sql_connector_auth_provider(
194+
hostname, mock_http_client, **kwargs
195+
)
167196

197+
# With identity_federation_client_id, should wrap with TokenFederationProvider
168198
self.assertEqual(type(auth_provider).__name__, "TokenFederationProvider")
169-
self.assertEqual(type(auth_provider.external_provider).__name__, "ExternalAuthProvider")
199+
self.assertEqual(
200+
type(auth_provider.external_provider).__name__, "ExternalAuthProvider"
201+
)
202+
self.assertEqual(auth_provider.identity_federation_client_id, "test-client-id")
170203

171204
headers = {}
172205
auth_provider.add_headers(headers)
@@ -181,7 +214,9 @@ def test_get_python_sql_connector_auth_provider_noop(self):
181214
"_use_cert_as_auth": use_cert_as_auth,
182215
}
183216
mock_http_client = MagicMock()
184-
auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs)
217+
auth_provider = get_python_sql_connector_auth_provider(
218+
hostname, mock_http_client, **kwargs
219+
)
185220
self.assertTrue(type(auth_provider).__name__, "CredentialProvider")
186221

187222
def test_get_python_sql_connector_basic_auth(self):
@@ -191,7 +226,9 @@ def test_get_python_sql_connector_basic_auth(self):
191226
}
192227
mock_http_client = MagicMock()
193228
with self.assertRaises(ValueError) as e:
194-
get_python_sql_connector_auth_provider("foo.cloud.databricks.com", mock_http_client, **kwargs)
229+
get_python_sql_connector_auth_provider(
230+
"foo.cloud.databricks.com", mock_http_client, **kwargs
231+
)
195232
self.assertIn(
196233
"Username/password authentication is no longer supported", str(e.exception)
197234
)
@@ -200,12 +237,13 @@ def test_get_python_sql_connector_basic_auth(self):
200237
def test_get_python_sql_connector_default_auth(self, mock__initial_get_token):
201238
hostname = "foo.cloud.databricks.com"
202239
mock_http_client = MagicMock()
203-
auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client)
204-
205-
self.assertEqual(type(auth_provider).__name__, "TokenFederationProvider")
206-
self.assertEqual(type(auth_provider.external_provider).__name__, "DatabricksOAuthProvider")
240+
auth_provider = get_python_sql_connector_auth_provider(
241+
hostname, mock_http_client
242+
)
207243

208-
self.assertEqual(auth_provider.external_provider._client_id, PYSQL_OAUTH_CLIENT_ID)
244+
# Without identity_federation_client_id, should return DatabricksOAuthProvider directly
245+
self.assertEqual(type(auth_provider).__name__, "DatabricksOAuthProvider")
246+
self.assertEqual(auth_provider._client_id, PYSQL_OAUTH_CLIENT_ID)
209247

210248

211249
class TestClientCredentialsTokenSource:
@@ -264,16 +302,16 @@ def test_no_token_refresh__when_token_is_not_expired(
264302

265303
def test_get_token_success(self, token_source, http_response):
266304
mock_http_client = MagicMock()
267-
305+
268306
with patch.object(token_source, "_http_client", mock_http_client):
269307
# Create a mock response with the expected format
270308
mock_response = MagicMock()
271309
mock_response.status = 200
272310
mock_response.data.decode.return_value = '{"access_token": "abc123", "token_type": "Bearer", "refresh_token": null}'
273-
311+
274312
# Mock the request method to return the response directly
275313
mock_http_client.request.return_value = mock_response
276-
314+
277315
token = token_source.get_token()
278316

279317
# Assert
@@ -284,16 +322,16 @@ def test_get_token_success(self, token_source, http_response):
284322

285323
def test_get_token_failure(self, token_source, http_response):
286324
mock_http_client = MagicMock()
287-
325+
288326
with patch.object(token_source, "_http_client", mock_http_client):
289327
# Create a mock response with error
290328
mock_response = MagicMock()
291329
mock_response.status = 400
292330
mock_response.data.decode.return_value = "Bad Request"
293-
331+
294332
# Mock the request method to return the response directly
295333
mock_http_client.request.return_value = mock_response
296-
334+
297335
with pytest.raises(Exception) as e:
298336
token_source.get_token()
299337
assert "Failed to get token: 400" in str(e.value)

0 commit comments

Comments
 (0)