Skip to content

SEA: add support for Hybrid disposition #631

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 11 commits into
base: ext-links-sea
Choose a base branch
from
1 change: 1 addition & 0 deletions examples/experimental/tests/test_sea_async_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def test_sea_async_query_with_cloud_fetch():
use_sea=True,
user_agent_entry="SEA-Test-Client",
use_cloud_fetch=True,
enable_query_result_lz4_compression=False,
)

logger.info(
Expand Down
1 change: 1 addition & 0 deletions examples/experimental/tests/test_sea_sync_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def test_sea_sync_query_with_cloud_fetch():
use_sea=True,
user_agent_entry="SEA-Test-Client",
use_cloud_fetch=True,
enable_query_result_lz4_compression=False,
)

logger.info(
Expand Down
26 changes: 12 additions & 14 deletions src/databricks/sql/backend/sea/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ def __init__(
"_use_arrow_native_complex_types", True
)

self.use_hybrid_disposition = kwargs.get("use_hybrid_disposition", True)

# Extract warehouse ID from http_path
self.warehouse_id = self._extract_warehouse_id(http_path)

Expand Down Expand Up @@ -456,7 +458,11 @@ def execute_command(
ResultFormat.ARROW_STREAM if use_cloud_fetch else ResultFormat.JSON_ARRAY
).value
disposition = (
ResultDisposition.EXTERNAL_LINKS
(
ResultDisposition.HYBRID
if self.use_hybrid_disposition
else ResultDisposition.EXTERNAL_LINKS
)
if use_cloud_fetch
else ResultDisposition.INLINE
).value
Expand Down Expand Up @@ -637,7 +643,9 @@ def get_execution_result(
arraysize=cursor.arraysize,
)

def get_chunk_link(self, statement_id: str, chunk_index: int) -> ExternalLink:
def get_chunk_links(
self, statement_id: str, chunk_index: int
) -> List[ExternalLink]:
"""
Get links for chunks starting from the specified index.
Args:
Expand All @@ -653,18 +661,8 @@ def get_chunk_link(self, statement_id: str, chunk_index: int) -> ExternalLink:
)
response = GetChunksResponse.from_dict(response_data)

links = response.external_links or []
link = next((l for l in links if l.chunk_index == chunk_index), None)
if not link:
raise ServerOperationError(
f"No link found for chunk index {chunk_index}",
{
"operation-id": statement_id,
"diagnostic-info": None,
},
)

return link
links = response.external_links
return links

# == Metadata Operations ==

Expand Down
9 changes: 8 additions & 1 deletion src/databricks/sql/backend/sea/models/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
These models define the structures used in SEA API responses.
"""

import base64
from typing import Dict, Any, List, Optional
from dataclasses import dataclass

Expand Down Expand Up @@ -91,6 +92,12 @@ def _parse_result(data: Dict[str, Any]) -> ResultData:
)
)

# Handle attachment field - decode from base64 if present
attachment = result_data.get("attachment")
if attachment is not None and isinstance(attachment, str):
# Decode base64 string to bytes
attachment = base64.b64decode(attachment)

return ResultData(
data=result_data.get("data_array"),
external_links=external_links,
Expand All @@ -100,7 +107,7 @@ def _parse_result(data: Dict[str, Any]) -> ResultData:
next_chunk_internal_link=result_data.get("next_chunk_internal_link"),
row_count=result_data.get("row_count"),
row_offset=result_data.get("row_offset"),
attachment=result_data.get("attachment"),
attachment=attachment,
)


Expand Down
66 changes: 51 additions & 15 deletions src/databricks/sql/backend/sea/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager

from databricks.sql.cloudfetch.downloader import ResultSetDownloadHandler

try:
import pyarrow
except ImportError:
Expand All @@ -22,7 +24,12 @@
from databricks.sql.exc import ProgrammingError, ServerOperationError
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
from databricks.sql.types import SSLOptions
from databricks.sql.utils import CloudFetchQueue, ResultSetQueue
from databricks.sql.utils import (
ArrowQueue,
CloudFetchQueue,
ResultSetQueue,
create_arrow_table_from_arrow_file,
)

import logging

Expand Down Expand Up @@ -61,6 +68,18 @@ def build_queue(
# INLINE disposition with JSON_ARRAY format
return JsonQueue(result_data.data)
elif manifest.format == ResultFormat.ARROW_STREAM.value:
if result_data.attachment is not None:
arrow_file = (
ResultSetDownloadHandler._decompress_data(result_data.attachment)
if lz4_compressed
else result_data.attachment
)
arrow_table = create_arrow_table_from_arrow_file(
arrow_file, description
)
logger.debug(f"Created arrow table with {arrow_table.num_rows} rows")
return ArrowQueue(arrow_table, manifest.total_row_count)

# EXTERNAL_LINKS disposition
return SeaCloudFetchQueue(
result_data=result_data,
Expand Down Expand Up @@ -140,18 +159,28 @@ def __init__(

self._sea_client = sea_client
self._statement_id = statement_id
self._total_chunk_count = total_chunk_count

logger.debug(
"SeaCloudFetchQueue: Initialize CloudFetch loader for statement {}, total chunks: {}".format(
statement_id, total_chunk_count
)
)

initial_links = result_data.external_links or []
first_link = next((l for l in initial_links if l.chunk_index == 0), None)
initial_links = result_data.external_links
self._chunk_index_to_link = {link.chunk_index: link for link in initial_links}

first_link = self._chunk_index_to_link.get(0, None)
if not first_link:
# possibly an empty response
return None
return

self.download_manager = ResultFileDownloadManager(
links=[],
max_download_threads=max_download_threads,
lz4_compressed=lz4_compressed,
ssl_options=ssl_options,
)

# Track the current chunk we're processing
self._current_chunk_link = first_link
Expand All @@ -172,6 +201,17 @@ def _convert_to_thrift_link(self, link: "ExternalLink") -> TSparkArrowResultLink
httpHeaders=link.http_headers or {},
)

def _get_chunk_link(self, chunk_index: int) -> Optional["ExternalLink"]:
    """
    Return the external link for ``chunk_index``, fetching lazily on a cache miss.

    Known links are served from the ``self._chunk_index_to_link`` cache. On a
    miss, the SEA client is asked for links starting at ``chunk_index`` and
    every link it returns is memoized, so one backend round-trip may populate
    several future chunk indices at once.

    Args:
        chunk_index: Zero-based index of the chunk whose link is needed.

    Returns:
        The matching ``ExternalLink``, or ``None`` if the backend response
        did not contain a link for this index.

    Raises:
        ValueError: If ``chunk_index`` is at or beyond ``self._total_chunk_count``.
    """
    # NOTE(review): the guard fires on >= but the message says "greater
    # than" — message is slightly off for the equality case; confirm intent.
    if chunk_index >= self._total_chunk_count:
        raise ValueError(
            f"Chunk index {chunk_index} is greater than total chunk count {self._total_chunk_count}"
        )

    # Cache miss: fetch links from this index onward and memoize all of them.
    if chunk_index not in self._chunk_index_to_link:
        links = self._sea_client.get_chunk_links(self._statement_id, chunk_index)
        self._chunk_index_to_link.update({link.chunk_index: link for link in links})
    return self._chunk_index_to_link.get(chunk_index, None)

def _progress_chunk_link(self):
"""Progress to the next chunk link."""
if not self._current_chunk_link:
Expand All @@ -183,18 +223,14 @@ def _progress_chunk_link(self):
self._current_chunk_link = None
return None

try:
self._current_chunk_link = self._sea_client.get_chunk_link(
self._statement_id, next_chunk_index
)
except Exception as e:
raise ServerOperationError(
f"Error fetching link for chunk {next_chunk_index}: {e}",
{
"operation-id": self._statement_id,
"diagnostic-info": None,
},
self._current_chunk_link = self._get_chunk_link(next_chunk_index)
if not self._current_chunk_link:
logger.error(
"SeaCloudFetchQueue: unable to retrieve link for chunk {}".format(
next_chunk_index
)
)
return None

logger.debug(
f"SeaCloudFetchQueue: Progressed to link for chunk {next_chunk_index}: {self._current_chunk_link}"
Expand Down
2 changes: 1 addition & 1 deletion src/databricks/sql/backend/sea/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class ResultFormat(Enum):
class ResultDisposition(Enum):
"""Enum for result disposition values."""

# TODO: add support for hybrid disposition
HYBRID = "INLINE_OR_EXTERNAL_LINKS"
EXTERNAL_LINKS = "EXTERNAL_LINKS"
INLINE = "INLINE"

Expand Down
4 changes: 4 additions & 0 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ def __init__(
Connect to a Databricks SQL endpoint or a Databricks cluster.

Parameters:
:param use_sea: `bool`, optional (default is False)
Use the SEA backend instead of the Thrift backend.
:param use_hybrid_disposition: `bool`, optional (default is True)
    Use the hybrid disposition (inline or external links) instead of the
    plain external-links disposition when cloud fetch is enabled.
:param server_hostname: Databricks instance host name.
:param http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef)
or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123)
Expand Down
73 changes: 0 additions & 73 deletions tests/unit/test_sea_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,76 +893,3 @@ def test_get_columns(self, sea_client, sea_session_id, mock_cursor):
cursor=mock_cursor,
)
assert "Catalog name is required for get_columns" in str(excinfo.value)

def test_get_chunk_link(self, sea_client, mock_http_client, sea_command_id):
"""Test get_chunk_link method."""
# Setup mock response
mock_response = {
"external_links": [
{
"external_link": "https://example.com/data/chunk0",
"expiration": "2025-07-03T05:51:18.118009",
"row_count": 100,
"byte_count": 1024,
"row_offset": 0,
"chunk_index": 0,
"next_chunk_index": 1,
"http_headers": {"Authorization": "Bearer token123"},
}
]
}
mock_http_client._make_request.return_value = mock_response

# Call the method
result = sea_client.get_chunk_link("test-statement-123", 0)

# Verify the HTTP client was called correctly
mock_http_client._make_request.assert_called_once_with(
method="GET",
path=sea_client.CHUNK_PATH_WITH_ID_AND_INDEX.format(
"test-statement-123", 0
),
)

# Verify the result
assert result.external_link == "https://example.com/data/chunk0"
assert result.expiration == "2025-07-03T05:51:18.118009"
assert result.row_count == 100
assert result.byte_count == 1024
assert result.row_offset == 0
assert result.chunk_index == 0
assert result.next_chunk_index == 1
assert result.http_headers == {"Authorization": "Bearer token123"}

def test_get_chunk_link_not_found(self, sea_client, mock_http_client):
"""Test get_chunk_link when the requested chunk is not found."""
# Setup mock response with no matching chunk
mock_response = {
"external_links": [
{
"external_link": "https://example.com/data/chunk1",
"expiration": "2025-07-03T05:51:18.118009",
"row_count": 100,
"byte_count": 1024,
"row_offset": 100,
"chunk_index": 1, # Different chunk index
"next_chunk_index": 2,
"http_headers": {"Authorization": "Bearer token123"},
}
]
}
mock_http_client._make_request.return_value = mock_response

# Call the method and expect an exception
with pytest.raises(
ServerOperationError, match="No link found for chunk index 0"
):
sea_client.get_chunk_link("test-statement-123", 0)

# Verify the HTTP client was called correctly
mock_http_client._make_request.assert_called_once_with(
method="GET",
path=sea_client.CHUNK_PATH_WITH_ID_AND_INDEX.format(
"test-statement-123", 0
),
)
71 changes: 0 additions & 71 deletions tests/unit/test_sea_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,77 +426,6 @@ def test_progress_chunk_link_no_next_chunk(self, mock_logger):
assert result is None
assert queue._current_chunk_link is None

@patch("databricks.sql.backend.sea.queue.logger")
def test_progress_chunk_link_success(self, mock_logger, mock_sea_client):
"""Test _progress_chunk_link with successful progression."""
# Create a queue instance without initializing
queue = Mock(spec=SeaCloudFetchQueue)
queue._current_chunk_link = ExternalLink(
external_link="https://example.com/data/chunk0",
expiration="2025-07-03T05:51:18.118009",
row_count=100,
byte_count=1024,
row_offset=0,
chunk_index=0,
next_chunk_index=1,
http_headers={"Authorization": "Bearer token123"},
)
queue._sea_client = mock_sea_client
queue._statement_id = "test-statement-123"

# Setup the mock client to return a new link
next_link = ExternalLink(
external_link="https://example.com/data/chunk1",
expiration="2025-07-03T05:51:18.235843",
row_count=50,
byte_count=512,
row_offset=100,
chunk_index=1,
next_chunk_index=None,
http_headers={"Authorization": "Bearer token123"},
)
mock_sea_client.get_chunk_link.return_value = next_link

# Call the method directly
SeaCloudFetchQueue._progress_chunk_link(queue)

# Verify the client was called
mock_sea_client.get_chunk_link.assert_called_once_with("test-statement-123", 1)

# Verify debug message was logged
mock_logger.debug.assert_called_with(
f"SeaCloudFetchQueue: Progressed to link for chunk 1: {next_link}"
)

@patch("databricks.sql.backend.sea.queue.logger")
def test_progress_chunk_link_error(self, mock_logger, mock_sea_client):
"""Test _progress_chunk_link with error during chunk fetch."""
# Create a queue instance without initializing
queue = Mock(spec=SeaCloudFetchQueue)
queue._current_chunk_link = ExternalLink(
external_link="https://example.com/data/chunk0",
expiration="2025-07-03T05:51:18.118009",
row_count=100,
byte_count=1024,
row_offset=0,
chunk_index=0,
next_chunk_index=1,
http_headers={"Authorization": "Bearer token123"},
)
queue._sea_client = mock_sea_client
queue._statement_id = "test-statement-123"

# Setup the mock client to raise an error
error_message = "Network error"
mock_sea_client.get_chunk_link.side_effect = Exception(error_message)

# Call the method directly
with pytest.raises(ServerOperationError, match=error_message):
SeaCloudFetchQueue._progress_chunk_link(queue)

# Verify the client was called
mock_sea_client.get_chunk_link.assert_called_once_with("test-statement-123", 1)

@patch("databricks.sql.backend.sea.queue.logger")
def test_create_next_table_no_current_link(self, mock_logger):
"""Test _create_next_table with no current link."""
Expand Down
Loading