diff --git a/examples/experimental/tests/test_sea_async_query.py b/examples/experimental/tests/test_sea_async_query.py index 3c0e325f..5bc6c679 100644 --- a/examples/experimental/tests/test_sea_async_query.py +++ b/examples/experimental/tests/test_sea_async_query.py @@ -45,6 +45,7 @@ def test_sea_async_query_with_cloud_fetch(): use_sea=True, user_agent_entry="SEA-Test-Client", use_cloud_fetch=True, + enable_query_result_lz4_compression=False, ) logger.info( diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py index 76941e2d..16ee80a7 100644 --- a/examples/experimental/tests/test_sea_sync_query.py +++ b/examples/experimental/tests/test_sea_sync_query.py @@ -43,6 +43,7 @@ def test_sea_sync_query_with_cloud_fetch(): use_sea=True, user_agent_entry="SEA-Test-Client", use_cloud_fetch=True, + enable_query_result_lz4_compression=False, ) logger.info( diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index dd3ace9e..01f97992 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -130,6 +130,8 @@ def __init__( "_use_arrow_native_complex_types", True ) + self.use_hybrid_disposition = kwargs.get("use_hybrid_disposition", True) + # Extract warehouse ID from http_path self.warehouse_id = self._extract_warehouse_id(http_path) @@ -456,7 +458,11 @@ def execute_command( ResultFormat.ARROW_STREAM if use_cloud_fetch else ResultFormat.JSON_ARRAY ).value disposition = ( - ResultDisposition.EXTERNAL_LINKS + ( + ResultDisposition.HYBRID + if self.use_hybrid_disposition + else ResultDisposition.EXTERNAL_LINKS + ) if use_cloud_fetch else ResultDisposition.INLINE ).value @@ -637,7 +643,9 @@ def get_execution_result( arraysize=cursor.arraysize, ) - def get_chunk_link(self, statement_id: str, chunk_index: int) -> ExternalLink: + def get_chunk_links( + self, statement_id: str, chunk_index: int + ) -> List[ExternalLink]: """ Get links for chunks starting from the specified index. Args: @@ -653,18 +661,8 @@ def get_chunk_link(self, statement_id: str, chunk_index: int) -> ExternalLink: ) response = GetChunksResponse.from_dict(response_data) - links = response.external_links or [] - link = next((l for l in links if l.chunk_index == chunk_index), None) - if not link: - raise ServerOperationError( - f"No link found for chunk index {chunk_index}", - { - "operation-id": statement_id, - "diagnostic-info": None, - }, - ) - - return link + links = response.external_links + return links # == Metadata Operations == diff --git a/src/databricks/sql/backend/sea/models/responses.py b/src/databricks/sql/backend/sea/models/responses.py index 6bd28c9b..111c1d7f 100644 --- a/src/databricks/sql/backend/sea/models/responses.py +++ b/src/databricks/sql/backend/sea/models/responses.py @@ -4,6 +4,7 @@ These models define the structures used in SEA API responses. """ +import base64 from typing import Dict, Any, List, Optional from dataclasses import dataclass @@ -91,6 +92,12 @@ def _parse_result(data: Dict[str, Any]) -> ResultData: ) ) + # Handle attachment field - decode from base64 if present + attachment = result_data.get("attachment") + if attachment is not None and isinstance(attachment, str): + # Decode base64 string to bytes + attachment = base64.b64decode(attachment) + return ResultData( data=result_data.get("data_array"), external_links=external_links, @@ -100,7 +107,7 @@ def _parse_result(data: Dict[str, Any]) -> ResultData: next_chunk_internal_link=result_data.get("next_chunk_internal_link"), row_count=result_data.get("row_count"), row_offset=result_data.get("row_offset"), - attachment=result_data.get("attachment"), + attachment=attachment, ) diff --git a/src/databricks/sql/backend/sea/queue.py b/src/databricks/sql/backend/sea/queue.py index 1e1c41c3..ac9c8499 100644 --- a/src/databricks/sql/backend/sea/queue.py +++ b/src/databricks/sql/backend/sea/queue.py @@ -5,6 +5,8 @@ from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager +from databricks.sql.cloudfetch.downloader import ResultSetDownloadHandler + try: import pyarrow except ImportError: @@ -22,7 +24,12 @@ from databricks.sql.exc import ProgrammingError, ServerOperationError from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink from databricks.sql.types import SSLOptions -from databricks.sql.utils import CloudFetchQueue, ResultSetQueue +from databricks.sql.utils import ( + ArrowQueue, + CloudFetchQueue, + ResultSetQueue, + create_arrow_table_from_arrow_file, +) import logging @@ -61,6 +68,18 @@ def build_queue( # INLINE disposition with JSON_ARRAY format return JsonQueue(result_data.data) elif manifest.format == ResultFormat.ARROW_STREAM.value: + if result_data.attachment is not None: + arrow_file = ( + ResultSetDownloadHandler._decompress_data(result_data.attachment) + if lz4_compressed + else result_data.attachment + ) + arrow_table = create_arrow_table_from_arrow_file( + arrow_file, description + ) + logger.debug(f"Created arrow table with {arrow_table.num_rows} rows") + return ArrowQueue(arrow_table, manifest.total_row_count) + # EXTERNAL_LINKS disposition return SeaCloudFetchQueue( result_data=result_data, @@ -140,6 +159,7 @@ def __init__( self._sea_client = sea_client self._statement_id = statement_id + self._total_chunk_count = total_chunk_count logger.debug( "SeaCloudFetchQueue: Initialize CloudFetch loader for statement {}, total chunks: {}".format( @@ -147,11 +167,20 @@ def __init__( ) ) - initial_links = result_data.external_links or [] - first_link = next((l for l in initial_links if l.chunk_index == 0), None) + initial_links = result_data.external_links + self._chunk_index_to_link = {link.chunk_index: link for link in initial_links} + + first_link = self._chunk_index_to_link.get(0, None) if not first_link: # possibly an empty response - return None + return + + self.download_manager = ResultFileDownloadManager( + links=[], + max_download_threads=max_download_threads, + lz4_compressed=lz4_compressed, + ssl_options=ssl_options, + ) # Track the current chunk we're processing self._current_chunk_link = first_link @@ -172,6 +201,17 @@ def _convert_to_thrift_link(self, link: "ExternalLink") -> TSparkArrowResultLink httpHeaders=link.http_headers or {}, ) + def _get_chunk_link(self, chunk_index: int) -> Optional["ExternalLink"]: + if chunk_index >= self._total_chunk_count: + raise ValueError( + f"Chunk index {chunk_index} is greater than total chunk count {self._total_chunk_count}" + ) + + if chunk_index not in self._chunk_index_to_link: + links = self._sea_client.get_chunk_links(self._statement_id, chunk_index) + self._chunk_index_to_link.update({link.chunk_index: link for link in links}) + return self._chunk_index_to_link.get(chunk_index, None) + def _progress_chunk_link(self): """Progress to the next chunk link.""" if not self._current_chunk_link: @@ -183,18 +223,14 @@ def _progress_chunk_link(self): self._current_chunk_link = None return None - try: - self._current_chunk_link = self._sea_client.get_chunk_link( - self._statement_id, next_chunk_index - ) - except Exception as e: - raise ServerOperationError( - f"Error fetching link for chunk {next_chunk_index}: {e}", - { - "operation-id": self._statement_id, - "diagnostic-info": None, - }, + self._current_chunk_link = self._get_chunk_link(next_chunk_index) + if not self._current_chunk_link: + logger.error( + "SeaCloudFetchQueue: unable to retrieve link for chunk {}".format( + next_chunk_index + ) ) + return None logger.debug( f"SeaCloudFetchQueue: Progressed to link for chunk {next_chunk_index}: {self._current_chunk_link}" diff --git a/src/databricks/sql/backend/sea/utils/constants.py b/src/databricks/sql/backend/sea/utils/constants.py index 402da0de..46ce8c98 100644 --- a/src/databricks/sql/backend/sea/utils/constants.py +++ b/src/databricks/sql/backend/sea/utils/constants.py @@ -28,7 +28,7 @@ class ResultFormat(Enum): class ResultDisposition(Enum): """Enum for result disposition values.""" - # TODO: add support for hybrid disposition + HYBRID = "INLINE_OR_EXTERNAL_LINKS" EXTERNAL_LINKS = "EXTERNAL_LINKS" INLINE = "INLINE" diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 75e89d92..dfa732c2 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -99,6 +99,10 @@ def __init__( Connect to a Databricks SQL endpoint or a Databricks cluster. Parameters: + :param use_sea: `bool`, optional (default is False) + Use the SEA backend instead of the Thrift backend. + :param use_hybrid_disposition: `bool`, optional (default is False) + Use the hybrid disposition instead of the inline disposition. :param server_hostname: Databricks instance host name. :param http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef) or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 493b8dc1..f771ec98 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -893,76 +893,3 @@ def test_get_columns(self, sea_client, sea_session_id, mock_cursor): cursor=mock_cursor, ) assert "Catalog name is required for get_columns" in str(excinfo.value) - - def test_get_chunk_link(self, sea_client, mock_http_client, sea_command_id): - """Test get_chunk_link method.""" - # Setup mock response - mock_response = { - "external_links": [ - { - "external_link": "https://example.com/data/chunk0", - "expiration": "2025-07-03T05:51:18.118009", - "row_count": 100, - "byte_count": 1024, - "row_offset": 0, - "chunk_index": 0, - "next_chunk_index": 1, - "http_headers": {"Authorization": "Bearer token123"}, - } - ] - } - mock_http_client._make_request.return_value = mock_response - - # Call the method - result = sea_client.get_chunk_link("test-statement-123", 0) - - # Verify the HTTP client was called correctly - mock_http_client._make_request.assert_called_once_with( - method="GET", - path=sea_client.CHUNK_PATH_WITH_ID_AND_INDEX.format( - "test-statement-123", 0 - ), - ) - - # Verify the result - assert result.external_link == "https://example.com/data/chunk0" - assert result.expiration == "2025-07-03T05:51:18.118009" - assert result.row_count == 100 - assert result.byte_count == 1024 - assert result.row_offset == 0 - assert result.chunk_index == 0 - assert result.next_chunk_index == 1 - assert result.http_headers == {"Authorization": "Bearer token123"} - - def test_get_chunk_link_not_found(self, sea_client, mock_http_client): - """Test get_chunk_link when the requested chunk is not found.""" - # Setup mock response with no matching chunk - mock_response = { - "external_links": [ - { - "external_link": "https://example.com/data/chunk1", - "expiration": "2025-07-03T05:51:18.118009", - "row_count": 100, - "byte_count": 1024, - "row_offset": 100, - "chunk_index": 1, # Different chunk index - "next_chunk_index": 2, - "http_headers": {"Authorization": "Bearer token123"}, - } - ] - } - mock_http_client._make_request.return_value = mock_response - - # Call the method and expect an exception - with pytest.raises( - ServerOperationError, match="No link found for chunk index 0" - ): - sea_client.get_chunk_link("test-statement-123", 0) - - # Verify the HTTP client was called correctly - mock_http_client._make_request.assert_called_once_with( - method="GET", - path=sea_client.CHUNK_PATH_WITH_ID_AND_INDEX.format( - "test-statement-123", 0 - ), - ) diff --git a/tests/unit/test_sea_queue.py b/tests/unit/test_sea_queue.py index 748f5d26..12c9448b 100644 --- a/tests/unit/test_sea_queue.py +++ b/tests/unit/test_sea_queue.py @@ -426,77 +426,6 @@ def test_progress_chunk_link_no_next_chunk(self, mock_logger): assert result is None assert queue._current_chunk_link is None - @patch("databricks.sql.backend.sea.queue.logger") - def test_progress_chunk_link_success(self, mock_logger, mock_sea_client): - """Test _progress_chunk_link with successful progression.""" - # Create a queue instance without initializing - queue = Mock(spec=SeaCloudFetchQueue) - queue._current_chunk_link = ExternalLink( - external_link="https://example.com/data/chunk0", - expiration="2025-07-03T05:51:18.118009", - row_count=100, - byte_count=1024, - row_offset=0, - chunk_index=0, - next_chunk_index=1, - http_headers={"Authorization": "Bearer token123"}, - ) - queue._sea_client = mock_sea_client - queue._statement_id = "test-statement-123" - - # Setup the mock client to return a new link - next_link = ExternalLink( - external_link="https://example.com/data/chunk1", - expiration="2025-07-03T05:51:18.235843", - row_count=50, - byte_count=512, - row_offset=100, - chunk_index=1, - next_chunk_index=None, - http_headers={"Authorization": "Bearer token123"}, - ) - mock_sea_client.get_chunk_link.return_value = next_link - - # Call the method directly - SeaCloudFetchQueue._progress_chunk_link(queue) - - # Verify the client was called - mock_sea_client.get_chunk_link.assert_called_once_with("test-statement-123", 1) - - # Verify debug message was logged - mock_logger.debug.assert_called_with( - f"SeaCloudFetchQueue: Progressed to link for chunk 1: {next_link}" - ) - - @patch("databricks.sql.backend.sea.queue.logger") - def test_progress_chunk_link_error(self, mock_logger, mock_sea_client): - """Test _progress_chunk_link with error during chunk fetch.""" - # Create a queue instance without initializing - queue = Mock(spec=SeaCloudFetchQueue) - queue._current_chunk_link = ExternalLink( - external_link="https://example.com/data/chunk0", - expiration="2025-07-03T05:51:18.118009", - row_count=100, - byte_count=1024, - row_offset=0, - chunk_index=0, - next_chunk_index=1, - http_headers={"Authorization": "Bearer token123"}, - ) - queue._sea_client = mock_sea_client - queue._statement_id = "test-statement-123" - - # Setup the mock client to raise an error - error_message = "Network error" - mock_sea_client.get_chunk_link.side_effect = Exception(error_message) - - # Call the method directly - with pytest.raises(ServerOperationError, match=error_message): - SeaCloudFetchQueue._progress_chunk_link(queue) - - # Verify the client was called - mock_sea_client.get_chunk_link.assert_called_once_with("test-statement-123", 1) - @patch("databricks.sql.backend.sea.queue.logger") def test_create_next_table_no_current_link(self, mock_logger): """Test _create_next_table with no current link."""