diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 6f39e264..42677b90 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -18,7 +18,8 @@ if TYPE_CHECKING: from databricks.sql.client import Cursor - from databricks.sql.backend.sea.result_set import SeaResultSet + +from databricks.sql.backend.sea.result_set import SeaResultSet from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import ( @@ -332,7 +333,7 @@ def _extract_description_from_manifest( return columns def _results_message_to_execute_response( - self, response: GetStatementResponse + self, response: Union[ExecuteStatementResponse, GetStatementResponse] ) -> ExecuteResponse: """ Convert a SEA response to an ExecuteResponse and extract result data. @@ -366,6 +367,27 @@ def _results_message_to_execute_response( return execute_response + def _response_to_result_set( + self, + response: Union[ExecuteStatementResponse, GetStatementResponse], + cursor: Cursor, + ) -> SeaResultSet: + """ + Convert a SEA response to a SeaResultSet. + """ + + execute_response = self._results_message_to_execute_response(response) + + return SeaResultSet( + connection=cursor.connection, + execute_response=execute_response, + sea_client=self, + result_data=response.result, + manifest=response.manifest, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, + ) + def _check_command_not_in_failed_or_closed_state( self, state: CommandState, command_id: CommandId ) -> None: @@ -386,21 +408,24 @@ def _check_command_not_in_failed_or_closed_state( def _wait_until_command_done( self, response: ExecuteStatementResponse - ) -> CommandState: + ) -> Union[ExecuteStatementResponse, GetStatementResponse]: """ Wait until a command is done. 
""" - state = response.status.state - command_id = CommandId.from_sea_statement_id(response.statement_id) + final_response: Union[ExecuteStatementResponse, GetStatementResponse] = response + + state = final_response.status.state + command_id = CommandId.from_sea_statement_id(final_response.statement_id) while state in [CommandState.PENDING, CommandState.RUNNING]: time.sleep(self.POLL_INTERVAL_SECONDS) - state = self.get_query_state(command_id) + final_response = self._poll_query(command_id) + state = final_response.status.state self._check_command_not_in_failed_or_closed_state(state, command_id) - return state + return final_response def execute_command( self, @@ -506,8 +531,11 @@ def execute_command( if async_op: return None - self._wait_until_command_done(response) - return self.get_execution_result(command_id, cursor) + final_response: Union[ExecuteStatementResponse, GetStatementResponse] = response + if response.status.state != CommandState.SUCCEEDED: + final_response = self._wait_until_command_done(response) + + return self._response_to_result_set(final_response, cursor) def cancel_command(self, command_id: CommandId) -> None: """ @@ -559,18 +587,9 @@ def close_command(self, command_id: CommandId) -> None: data=request.to_dict(), ) - def get_query_state(self, command_id: CommandId) -> CommandState: + def _poll_query(self, command_id: CommandId) -> GetStatementResponse: """ - Get the state of a running query. - - Args: - command_id: Command identifier - - Returns: - CommandState: The current state of the command - - Raises: - ValueError: If the command ID is invalid + Poll for the current command info. 
""" if command_id.backend_type != BackendType.SEA: @@ -586,9 +605,25 @@ def get_query_state(self, command_id: CommandId) -> CommandState: path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), data=request.to_dict(), ) - - # Parse the response response = GetStatementResponse.from_dict(response_data) + + return response + + def get_query_state(self, command_id: CommandId) -> CommandState: + """ + Get the state of a running query. + + Args: + command_id: Command identifier + + Returns: + CommandState: The current state of the command + + Raises: + ValueError: If the command ID is invalid + """ + + response = self._poll_query(command_id) return response.status.state def get_execution_result( @@ -610,38 +645,8 @@ def get_execution_result( ValueError: If the command ID is invalid """ - if command_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA command ID") - - sea_statement_id = command_id.to_sea_statement_id() - if sea_statement_id is None: - raise ValueError("Not a valid SEA command ID") - - # Create the request model - request = GetStatementRequest(statement_id=sea_statement_id) - - # Get the statement result - response_data = self._http_client._make_request( - method="GET", - path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), - data=request.to_dict(), - ) - response = GetStatementResponse.from_dict(response_data) - - # Create and return a SeaResultSet - from databricks.sql.backend.sea.result_set import SeaResultSet - - execute_response = self._results_message_to_execute_response(response) - - return SeaResultSet( - connection=cursor.connection, - execute_response=execute_response, - sea_client=self, - result_data=response.result, - manifest=response.manifest, - buffer_size_bytes=cursor.buffer_size_bytes, - arraysize=cursor.arraysize, - ) + response = self._poll_query(command_id) + return self._response_to_result_set(response, cursor) def get_chunk_links( self, statement_id: str, chunk_index: int diff --git
a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index b67fc74d..a6a0a298 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -4,7 +4,6 @@ import logging -from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.backend.sea.models.base import ResultData, ResultManifest from databricks.sql.backend.sea.utils.conversion import SqlTypeConverter @@ -15,6 +14,7 @@ if TYPE_CHECKING: from databricks.sql.client import Connection + from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.types import Row from databricks.sql.backend.sea.queue import JsonQueue, SeaResultSetQueueFactory from databricks.sql.backend.types import ExecuteResponse diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 877136cf..5f920e24 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -227,7 +227,7 @@ def test_command_execution_sync( mock_http_client._make_request.return_value = execute_response with patch.object( - sea_client, "get_execution_result", return_value="mock_result_set" + sea_client, "_response_to_result_set", return_value="mock_result_set" ) as mock_get_result: result = sea_client.execute_command( operation="SELECT 1", @@ -242,9 +242,6 @@ def test_command_execution_sync( enforce_embedded_schema_correctness=False, ) assert result == "mock_result_set" - cmd_id_arg = mock_get_result.call_args[0][0] - assert isinstance(cmd_id_arg, CommandId) - assert cmd_id_arg.guid == "test-statement-123" # Test with invalid session ID with pytest.raises(ValueError) as excinfo: @@ -332,7 +329,7 @@ def test_command_execution_advanced( mock_http_client._make_request.side_effect = [initial_response, poll_response] with patch.object( - sea_client, "get_execution_result", return_value="mock_result_set" + sea_client, "_response_to_result_set", return_value="mock_result_set" ) as 
mock_get_result: with patch("time.sleep"): result = sea_client.execute_command( @@ -360,7 +357,7 @@ def test_command_execution_advanced( dbsql_param = IntegerParameter(name="param1", value=1) param = dbsql_param.as_tspark_param(named=True) - with patch.object(sea_client, "get_execution_result"): + with patch.object(sea_client, "_response_to_result_set"): sea_client.execute_command( operation="SELECT * FROM table WHERE col = :param1", session_id=sea_session_id,