Skip to content

SEA: Reduce network calls for synchronous commands #633

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: sea-migration
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 55 additions & 52 deletions src/databricks/sql/backend/sea/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
DeleteSessionRequest,
StatementParameter,
ExecuteStatementResponse,
GetStatementResponse,
CreateSessionResponse,
)

Expand Down Expand Up @@ -324,7 +323,7 @@ def _extract_description_from_manifest(
return columns

def _results_message_to_execute_response(
self, response: GetStatementResponse
self, response: ExecuteStatementResponse
) -> ExecuteResponse:
"""
Convert a SEA response to an ExecuteResponse and extract result data.
Expand Down Expand Up @@ -358,6 +357,28 @@ def _results_message_to_execute_response(

return execute_response

def _response_to_result_set(
    self, response: ExecuteStatementResponse, cursor: Cursor
) -> SeaResultSet:
    """
    Build a SeaResultSet from a SEA statement response.

    Args:
        response: Statement response; its ``result`` and ``manifest`` are
            handed to the result set unchanged.
        cursor: Cursor that supplies the connection, ``buffer_size_bytes``
            and ``arraysize`` for the result set.

    Returns:
        SeaResultSet: Result set that uses this client as its ``sea_client``.
    """

    # NOTE(review): function-scope import; presumably this avoids a circular
    # import between the backend and result_set modules. Confirm.
    from databricks.sql.backend.sea.result_set import SeaResultSet

    # Convert first so a conversion failure surfaces before any cursor
    # attribute is read.
    converted = self._results_message_to_execute_response(response)

    result_set_kwargs = dict(
        connection=cursor.connection,
        execute_response=converted,
        sea_client=self,
        result_data=response.result,
        manifest=response.manifest,
        buffer_size_bytes=cursor.buffer_size_bytes,
        arraysize=cursor.arraysize,
    )
    return SeaResultSet(**result_set_kwargs)

def _check_command_not_in_failed_or_closed_state(
self, state: CommandState, command_id: CommandId
) -> None:
Expand All @@ -378,7 +399,7 @@ def _check_command_not_in_failed_or_closed_state(

def _wait_until_command_done(
self, response: ExecuteStatementResponse
) -> CommandState:
) -> ExecuteStatementResponse:
"""
Wait until a command is done.
"""
Expand All @@ -388,11 +409,12 @@ def _wait_until_command_done(

while state in [CommandState.PENDING, CommandState.RUNNING]:
time.sleep(self.POLL_INTERVAL_SECONDS)
state = self.get_query_state(command_id)
response = self._poll_query(command_id)
state = response.status.state

self._check_command_not_in_failed_or_closed_state(state, command_id)

return state
return response

def execute_command(
self,
Expand Down Expand Up @@ -494,8 +516,12 @@ def execute_command(
if async_op:
return None

self._wait_until_command_done(response)
return self.get_execution_result(command_id, cursor)
if response.status.state == CommandState.SUCCEEDED:
# if the response succeeded within the wait_timeout, return the results immediately
return self._response_to_result_set(response, cursor)

response = self._wait_until_command_done(response)
return self._response_to_result_set(response, cursor)

def cancel_command(self, command_id: CommandId) -> None:
"""
Expand Down Expand Up @@ -547,18 +573,9 @@ def close_command(self, command_id: CommandId) -> None:
data=request.to_dict(),
)

def get_query_state(self, command_id: CommandId) -> CommandState:
def _poll_query(self, command_id: CommandId) -> ExecuteStatementResponse:
"""
Get the state of a running query.

Args:
command_id: Command identifier

Returns:
CommandState: The current state of the command

Raises:
ProgrammingError: If the command ID is invalid
Poll for the current command info.
"""

if command_id.backend_type != BackendType.SEA:
Expand All @@ -574,9 +591,25 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
data=request.to_dict(),
)
response = ExecuteStatementResponse.from_dict(response_data)

# Parse the response
response = GetStatementResponse.from_dict(response_data)
return response

def get_query_state(self, command_id: CommandId) -> CommandState:
    """
    Report the current state of a command.

    Args:
        command_id: Identifier of the command to look up

    Returns:
        CommandState: The ``status.state`` of the response returned by
            ``_poll_query``

    Raises:
        Any exception raised by ``_poll_query`` propagates unchanged.
        NOTE(review): this docstring previously named ProgrammingError for
        an invalid command ID; confirm which exception the validation in
        ``_poll_query`` actually raises.
    """

    return self._poll_query(command_id).status.state

def get_execution_result(
Expand All @@ -598,38 +631,8 @@ def get_execution_result(
ValueError: If the command ID is invalid
"""

if command_id.backend_type != BackendType.SEA:
raise ValueError("Not a valid SEA command ID")

sea_statement_id = command_id.to_sea_statement_id()
if sea_statement_id is None:
raise ValueError("Not a valid SEA command ID")

# Create the request model
request = GetStatementRequest(statement_id=sea_statement_id)

# Get the statement result
response_data = self.http_client._make_request(
method="GET",
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
data=request.to_dict(),
)
response = GetStatementResponse.from_dict(response_data)

# Create and return a SeaResultSet
from databricks.sql.backend.sea.result_set import SeaResultSet

execute_response = self._results_message_to_execute_response(response)

return SeaResultSet(
connection=cursor.connection,
execute_response=execute_response,
sea_client=self,
result_data=response.result,
manifest=response.manifest,
buffer_size_bytes=cursor.buffer_size_bytes,
arraysize=cursor.arraysize,
)
response = self._poll_query(command_id)
return self._response_to_result_set(response, cursor)

# == Metadata Operations ==

Expand Down
2 changes: 0 additions & 2 deletions src/databricks/sql/backend/sea/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

from databricks.sql.backend.sea.models.responses import (
ExecuteStatementResponse,
GetStatementResponse,
CreateSessionResponse,
)

Expand All @@ -47,6 +46,5 @@
"DeleteSessionRequest",
# Response models
"ExecuteStatementResponse",
"GetStatementResponse",
"CreateSessionResponse",
]
20 changes: 0 additions & 20 deletions src/databricks/sql/backend/sea/models/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,26 +124,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse":
)


@dataclass
class GetStatementResponse:
    """Response returned when fetching information about a statement."""

    statement_id: str
    status: StatementStatus
    manifest: ResultManifest
    result: ResultData

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse":
        """
        Build a GetStatementResponse from a response dictionary.

        ``statement_id`` falls back to an empty string when the key is
        absent; the other fields come from the module's ``_parse_status``,
        ``_parse_manifest`` and ``_parse_result`` helpers applied to the
        same dictionary.
        """
        statement_id = data.get("statement_id", "")
        status = _parse_status(data)
        manifest = _parse_manifest(data)
        result = _parse_result(data)
        # Positional order matches the field declaration order above.
        return cls(statement_id, status, manifest, result)


@dataclass
class CreateSessionResponse:
"""Representation of the response from creating a new session."""
Expand Down
9 changes: 3 additions & 6 deletions tests/unit/test_sea_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def test_command_execution_sync(
mock_http_client._make_request.return_value = execute_response

with patch.object(
sea_client, "get_execution_result", return_value="mock_result_set"
sea_client, "_response_to_result_set", return_value="mock_result_set"
) as mock_get_result:
result = sea_client.execute_command(
operation="SELECT 1",
Expand All @@ -242,9 +242,6 @@ def test_command_execution_sync(
enforce_embedded_schema_correctness=False,
)
assert result == "mock_result_set"
cmd_id_arg = mock_get_result.call_args[0][0]
assert isinstance(cmd_id_arg, CommandId)
assert cmd_id_arg.guid == "test-statement-123"

# Test with invalid session ID
with pytest.raises(ValueError) as excinfo:
Expand Down Expand Up @@ -332,7 +329,7 @@ def test_command_execution_advanced(
mock_http_client._make_request.side_effect = [initial_response, poll_response]

with patch.object(
sea_client, "get_execution_result", return_value="mock_result_set"
sea_client, "_response_to_result_set", return_value="mock_result_set"
) as mock_get_result:
with patch("time.sleep"):
result = sea_client.execute_command(
Expand Down Expand Up @@ -360,7 +357,7 @@ def test_command_execution_advanced(
dbsql_param = IntegerParameter(name="param1", value=1)
param = dbsql_param.as_tspark_param(named=True)

with patch.object(sea_client, "get_execution_result"):
with patch.object(sea_client, "_response_to_result_set"):
sea_client.execute_command(
operation="SELECT * FROM table WHERE col = :param1",
session_id=sea_session_id,
Expand Down
Loading