diff --git a/src/databricks/sql/common/http.py b/src/databricks/sql/common/http.py index ec4e3341..0cd2919c 100644 --- a/src/databricks/sql/common/http.py +++ b/src/databricks/sql/common/http.py @@ -5,8 +5,10 @@ import threading from dataclasses import dataclass from contextlib import contextmanager -from typing import Generator +from typing import Generator, Optional import logging +from requests.adapters import HTTPAdapter +from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType logger = logging.getLogger(__name__) @@ -81,3 +83,70 @@ def execute( def close(self): self.session.close() + + +class TelemetryHTTPAdapter(HTTPAdapter): + """ + Custom HTTP adapter to prepare our DatabricksRetryPolicy before each request. + This ensures the retry timer is started and the command type is set correctly, + allowing the policy to manage its state for the duration of the request retries. + """ + + def send(self, request, **kwargs): + self.max_retries.command_type = CommandType.OTHER + self.max_retries.start_retry_timer() + return super().send(request, **kwargs) + + +class TelemetryHttpClient: # TODO: Unify all the http clients in the PySQL Connector + """Singleton HTTP client for sending telemetry data.""" + + _instance: Optional["TelemetryHttpClient"] = None + _lock = threading.Lock() + + TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT = 3 + TELEMETRY_RETRY_DELAY_MIN = 1.0 + TELEMETRY_RETRY_DELAY_MAX = 10.0 + TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION = 30.0 + + def __init__(self): + """Initializes the session and mounts the custom retry adapter.""" + retry_policy = DatabricksRetryPolicy( + delay_min=self.TELEMETRY_RETRY_DELAY_MIN, + delay_max=self.TELEMETRY_RETRY_DELAY_MAX, + stop_after_attempts_count=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT, + stop_after_attempts_duration=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION, + delay_default=1.0, + force_dangerous_codes=[], + ) + adapter = TelemetryHTTPAdapter(max_retries=retry_policy) + self.session = requests.Session() + self.session.mount("https://", adapter) + self.session.mount("http://", adapter) + + @classmethod + def get_instance(cls) -> "TelemetryHttpClient": + """Get the singleton instance of the TelemetryHttpClient.""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: + logger.debug("Initializing singleton TelemetryHttpClient") + cls._instance = TelemetryHttpClient() + return cls._instance + + def post(self, url: str, **kwargs) -> requests.Response: + """ + Executes a POST request using the configured session. + + This is a blocking call intended to be run in a background thread. + """ + logger.debug("Executing telemetry POST request to: %s", url) + return self.session.post(url, **kwargs) + + def close(self): + """Closes the underlying requests.Session.""" + logger.debug("Closing TelemetryHttpClient session.") + self.session.close() + # Clear the instance to allow for re-initialization if needed + with TelemetryHttpClient._lock: + TelemetryHttpClient._instance = None diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 65235f63..4a772c49 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -2,8 +2,6 @@ import logging logger = logging.getLogger(__name__) -from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory - ### PEP-249 Mandated ### # https://peps.python.org/pep-0249/#exceptions @@ -22,6 +20,8 @@ def __init__( error_name = self.__class__.__name__ if session_id_hex: + from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory + telemetry_client = TelemetryClientFactory.get_telemetry_client( session_id_hex ) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 5eb8c6ed..a5884c01 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -4,6 +4,7 @@ import logging from concurrent.futures import ThreadPoolExecutor from typing import Dict, Optional +from databricks.sql.common.http import TelemetryHttpClient from databricks.sql.telemetry.models.event import ( TelemetryEvent, DriverSystemConfiguration, @@ -153,6 +154,7 @@ def __init__( self._driver_connection_params = None self._host_url = host_url self._executor = executor + self._http_client = TelemetryHttpClient.get_instance() def _export_event(self, event): """Add an event to the batch queue and flush if batch is full""" @@ -201,7 +203,7 @@ def _send_telemetry(self, events): try: logger.debug("Submitting telemetry request to thread pool") future = self._executor.submit( - requests.post, + self._http_client.post, url, data=request.to_json(), headers=headers, @@ -427,6 +429,7 @@ def close(session_id_hex): ) try: TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryHttpClient.get_instance().close() except Exception as e: logger.debug("Failed to shutdown thread pool executor: %s", e) TelemetryClientFactory._executor = None diff --git a/tests/e2e/test_telemetry_retry.py b/tests/e2e/test_telemetry_retry.py new file mode 100644 index 00000000..70089b7d --- /dev/null +++ b/tests/e2e/test_telemetry_retry.py @@ -0,0 +1,107 @@ +import pytest +from unittest.mock import patch, MagicMock +import io +import time + +from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory +from databricks.sql.auth.retry import DatabricksRetryPolicy + +PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn' + +def create_mock_conn(responses): + """Creates a mock connection object whose getresponse() method yields a series of responses.""" + mock_conn = MagicMock() + mock_http_responses = [] + for resp in responses: + mock_http_response = MagicMock() + mock_http_response.status = resp.get("status") + mock_http_response.headers = resp.get("headers", {}) + body = resp.get("body", b'{}') + mock_http_response.fp = io.BytesIO(body) + def release(): + mock_http_response.fp.close() + mock_http_response.release_conn = release + mock_http_responses.append(mock_http_response) + mock_conn.getresponse.side_effect = mock_http_responses + return mock_conn + +class TestTelemetryClientRetries: + @pytest.fixture(autouse=True) + def setup_and_teardown(self): + TelemetryClientFactory._initialized = False + TelemetryClientFactory._clients = {} + TelemetryClientFactory._executor = None + yield + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._initialized = False + TelemetryClientFactory._clients = {} + TelemetryClientFactory._executor = None + + def get_client(self, session_id, num_retries=3): + """ + Configures a client with a specific number of retries. + """ + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + session_id_hex=session_id, + auth_provider=None, + host_url="test.databricks.com", + ) + client = TelemetryClientFactory.get_telemetry_client(session_id) + + retry_policy = DatabricksRetryPolicy( + delay_min=0.01, + delay_max=0.02, + stop_after_attempts_duration=2.0, + stop_after_attempts_count=num_retries, + delay_default=0.1, + force_dangerous_codes=[], + urllib3_kwargs={'total': num_retries} + ) + adapter = client._session.adapters.get("https://") + adapter.max_retries = retry_policy + return client, adapter + + @pytest.mark.parametrize( + "status_code, description", + [ + (401, "Unauthorized"), + (403, "Forbidden"), + (501, "Not Implemented"), + (200, "Success"), + ], + ) + def test_non_retryable_status_codes_are_not_retried(self, status_code, description): + """ + Verifies that terminal error codes (401, 403, 501) and success codes (200) are not retried. + """ + # Use the status code in the session ID for easier debugging if it fails + client, _ = self.get_client(f"session-{status_code}") + mock_responses = [{"status": status_code}] + + with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn: + client.export_failure_log("TestError", "Test message") + TelemetryClientFactory.close(client._session_id_hex) + + mock_get_conn.return_value.getresponse.assert_called_once() + + def test_exceeds_retry_count_limit(self): + """ + Verifies that the client retries up to the specified number of times before giving up. + Verifies that the client respects the Retry-After header and retries on 429, 502, 503. + """ + num_retries = 3 + expected_total_calls = num_retries + 1 + retry_after = 1 + client, _ = self.get_client("session-exceed-limit", num_retries=num_retries) + mock_responses = [{"status": 503, "headers": {"Retry-After": str(retry_after)}}, {"status": 429}, {"status": 502}, {"status": 503}] + + with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn: + start_time = time.time() + client.export_failure_log("TestError", "Test message") + TelemetryClientFactory.close(client._session_id_hex) + end_time = time.time() + + assert mock_get_conn.return_value.getresponse.call_count == expected_total_calls + assert end_time - start_time > retry_after \ No newline at end of file diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index dc1c7d63..acf15969 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -91,7 +91,7 @@ def test_network_request_flow(self, mock_post, mock_telemetry_client): args, kwargs = client._executor.submit.call_args # Verify correct function and URL - assert args[0] == requests.post + assert args[0] == client._http_client.post assert args[1] == "https://test-host.com/telemetry-ext" assert kwargs["headers"]["Authorization"] == "Bearer test-token"