@@ -145,7 +145,9 @@ def test_get_python_sql_connector_auth_provider_access_token(self):
145145 hostname = "moderakh-test.cloud.databricks.com"
146146 kwargs = {"access_token" : "dpi123" }
147147 mock_http_client = MagicMock ()
148- auth_provider = get_python_sql_connector_auth_provider (hostname , mock_http_client , ** kwargs )
148+ auth_provider = get_python_sql_connector_auth_provider (
149+ hostname , mock_http_client , ** kwargs
150+ )
149151 self .assertTrue (type (auth_provider ).__name__ , "AccessTokenAuthProvider" )
150152
151153 headers = {}
@@ -163,10 +165,41 @@ def __call__(self, *args, **kwargs) -> HeaderFactory:
163165 hostname = "moderakh-test.cloud.databricks.com"
164166 kwargs = {"credentials_provider" : MyProvider ()}
165167 mock_http_client = MagicMock ()
166- auth_provider = get_python_sql_connector_auth_provider (hostname , mock_http_client , ** kwargs )
168+ auth_provider = get_python_sql_connector_auth_provider (
169+ hostname , mock_http_client , ** kwargs
170+ )
171+
172+ # Without identity_federation_client_id, should return ExternalAuthProvider directly
173+ self .assertEqual (type (auth_provider ).__name__ , "ExternalAuthProvider" )
174+
175+ headers = {}
176+ auth_provider .add_headers (headers )
177+ self .assertEqual (headers ["foo" ], "bar" )
178+
179+ def test_get_python_sql_connector_auth_provider_with_token_federation (self ):
180+ class MyProvider (CredentialsProvider ):
181+ def auth_type (self ) -> str :
182+ return "mine"
183+
184+ def __call__ (self , * args , ** kwargs ) -> HeaderFactory :
185+ return lambda : {"foo" : "bar" }
186+
187+ hostname = "moderakh-test.cloud.databricks.com"
188+ kwargs = {
189+ "credentials_provider" : MyProvider (),
190+ "identity_federation_client_id" : "test-client-id" ,
191+ }
192+ mock_http_client = MagicMock ()
193+ auth_provider = get_python_sql_connector_auth_provider (
194+ hostname , mock_http_client , ** kwargs
195+ )
167196
197+ # With identity_federation_client_id, should wrap with TokenFederationProvider
168198 self .assertEqual (type (auth_provider ).__name__ , "TokenFederationProvider" )
169- self .assertEqual (type (auth_provider .external_provider ).__name__ , "ExternalAuthProvider" )
199+ self .assertEqual (
200+ type (auth_provider .external_provider ).__name__ , "ExternalAuthProvider"
201+ )
202+ self .assertEqual (auth_provider .identity_federation_client_id , "test-client-id" )
170203
171204 headers = {}
172205 auth_provider .add_headers (headers )
@@ -181,7 +214,9 @@ def test_get_python_sql_connector_auth_provider_noop(self):
181214 "_use_cert_as_auth" : use_cert_as_auth ,
182215 }
183216 mock_http_client = MagicMock ()
184- auth_provider = get_python_sql_connector_auth_provider (hostname , mock_http_client , ** kwargs )
217+ auth_provider = get_python_sql_connector_auth_provider (
218+ hostname , mock_http_client , ** kwargs
219+ )
185220 self .assertTrue (type (auth_provider ).__name__ , "CredentialProvider" )
186221
187222 def test_get_python_sql_connector_basic_auth (self ):
@@ -191,7 +226,9 @@ def test_get_python_sql_connector_basic_auth(self):
191226 }
192227 mock_http_client = MagicMock ()
193228 with self .assertRaises (ValueError ) as e :
194- get_python_sql_connector_auth_provider ("foo.cloud.databricks.com" , mock_http_client , ** kwargs )
229+ get_python_sql_connector_auth_provider (
230+ "foo.cloud.databricks.com" , mock_http_client , ** kwargs
231+ )
195232 self .assertIn (
196233 "Username/password authentication is no longer supported" , str (e .exception )
197234 )
@@ -200,12 +237,13 @@ def test_get_python_sql_connector_basic_auth(self):
200237 def test_get_python_sql_connector_default_auth (self , mock__initial_get_token ):
201238 hostname = "foo.cloud.databricks.com"
202239 mock_http_client = MagicMock ()
203- auth_provider = get_python_sql_connector_auth_provider (hostname , mock_http_client )
204-
205- self .assertEqual (type (auth_provider ).__name__ , "TokenFederationProvider" )
206- self .assertEqual (type (auth_provider .external_provider ).__name__ , "DatabricksOAuthProvider" )
240+ auth_provider = get_python_sql_connector_auth_provider (
241+ hostname , mock_http_client
242+ )
207243
208- self .assertEqual (auth_provider .external_provider ._client_id , PYSQL_OAUTH_CLIENT_ID )
244+ # Without identity_federation_client_id, should return DatabricksOAuthProvider directly
245+ self .assertEqual (type (auth_provider ).__name__ , "DatabricksOAuthProvider" )
246+ self .assertEqual (auth_provider ._client_id , PYSQL_OAUTH_CLIENT_ID )
209247
210248
211249class TestClientCredentialsTokenSource :
@@ -264,16 +302,16 @@ def test_no_token_refresh__when_token_is_not_expired(
264302
265303 def test_get_token_success (self , token_source , http_response ):
266304 mock_http_client = MagicMock ()
267-
305+
268306 with patch .object (token_source , "_http_client" , mock_http_client ):
269307 # Create a mock response with the expected format
270308 mock_response = MagicMock ()
271309 mock_response .status = 200
272310 mock_response .data .decode .return_value = '{"access_token": "abc123", "token_type": "Bearer", "refresh_token": null}'
273-
311+
274312 # Mock the request method to return the response directly
275313 mock_http_client .request .return_value = mock_response
276-
314+
277315 token = token_source .get_token ()
278316
279317 # Assert
@@ -284,16 +322,16 @@ def test_get_token_success(self, token_source, http_response):
284322
285323 def test_get_token_failure (self , token_source , http_response ):
286324 mock_http_client = MagicMock ()
287-
325+
288326 with patch .object (token_source , "_http_client" , mock_http_client ):
289327 # Create a mock response with error
290328 mock_response = MagicMock ()
291329 mock_response .status = 400
292330 mock_response .data .decode .return_value = "Bad Request"
293-
331+
294332 # Mock the request method to return the response directly
295333 mock_http_client .request .return_value = mock_response
296-
334+
297335 with pytest .raises (Exception ) as e :
298336 token_source .get_token ()
299337 assert "Failed to get token: 400" in str (e .value )
0 commit comments