Skip to content

Commit 3beb0fc

Browse files
authored
feat: support sharable session termination by adding terminate option to stop() function (#162)
1 parent c693b63 commit 3beb0fc

File tree

2 files changed

+136
-9
lines changed

2 files changed

+136
-9
lines changed

google/cloud/dataproc_spark_connect/session.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,27 +1179,64 @@ def addArtifacts(
11791179
def _get_active_session_file_path():
11801180
return os.getenv("DATAPROC_SPARK_CONNECT_ACTIVE_SESSION_FILE_PATH")
11811181

1182-
def stop(self) -> None:
1182+
def stop(self, terminate: Optional[bool] = None) -> None:
1183+
"""
1184+
Stop the Spark session and optionally terminate the server-side session.
1185+
1186+
Parameters
1187+
----------
1188+
terminate : bool, optional
1189+
Control server-side termination behavior.
1190+
1191+
- None (default): Auto-detect based on session type
1192+
1193+
- Managed sessions (auto-generated ID): terminate server
1194+
- Named sessions (custom ID): client-side cleanup only
1195+
1196+
- True: Always terminate the server-side session
1197+
- False: Never terminate the server-side session (client cleanup only)
1198+
1199+
Examples
1200+
--------
1201+
Auto-detect termination behavior (existing behavior):
1202+
1203+
>>> spark.stop()
1204+
1205+
Force terminate a named session:
1206+
1207+
>>> spark.stop(terminate=True)
1208+
1209+
Prevent termination of a managed session:
1210+
1211+
>>> spark.stop(terminate=False)
1212+
"""
11831213
with DataprocSparkSession._lock:
11841214
if DataprocSparkSession._active_s8s_session_id is not None:
1185-
# Check if this is a managed session (auto-generated ID) or unmanaged session (custom ID)
1186-
if DataprocSparkSession._active_session_uses_custom_id:
1187-
# Unmanaged session (custom ID): Only clean up client-side state
1188-
# Don't terminate as it might be in use by other notebooks or clients
1189-
logger.debug(
1190-
f"Stopping unmanaged session {DataprocSparkSession._active_s8s_session_id} without termination"
1215+
# Determine if we should terminate the server-side session
1216+
if terminate is None:
1217+
# Auto-detect: managed sessions terminate, named sessions don't
1218+
should_terminate = (
1219+
not DataprocSparkSession._active_session_uses_custom_id
11911220
)
11921221
else:
1193-
# Managed session (auto-generated ID): Use original behavior and terminate
1222+
should_terminate = terminate
1223+
1224+
if should_terminate:
1225+
# Terminate the server-side session
11941226
logger.debug(
1195-
f"Terminating managed session {DataprocSparkSession._active_s8s_session_id}"
1227+
f"Terminating session {DataprocSparkSession._active_s8s_session_id}"
11961228
)
11971229
terminate_s8s_session(
11981230
DataprocSparkSession._project_id,
11991231
DataprocSparkSession._region,
12001232
DataprocSparkSession._active_s8s_session_id,
12011233
self._client_options,
12021234
)
1235+
else:
1236+
# Client-side cleanup only
1237+
logger.debug(
1238+
f"Stopping session {DataprocSparkSession._active_s8s_session_id} without termination"
1239+
)
12031240

12041241
self._remove_stopped_session_from_file()
12051242

tests/integration/test_session.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,96 @@ def test_sparksql_magic_with_dataproc_session(connect_session):
666666
assert row["joined_string"] == "Dataproc-Spark"
667667

668668

669+
def test_stop_named_session_with_terminate_true(
    auth_type,
    test_project,
    test_region,
    session_controller_client,
    os_environment,
):
    """Test that stop(terminate=True) terminates a named session on the server.

    Creates a session with a custom (named) ID, runs a trivial query, stops it
    with ``terminate=True``, and then checks both the client-side state and the
    server-side session state.
    """
    # Use a randomized session ID to avoid conflicts
    custom_session_id = f"test-terminate-true-{uuid.uuid4().hex[:8]}"
    session_name = f"projects/{test_project}/locations/{test_region}/sessions/{custom_session_id}"

    # Create a session with custom ID
    spark = (
        DataprocSparkSession.builder.dataprocSessionId(custom_session_id)
        .projectId(test_project)
        .location(test_region)
        .getOrCreate()
    )

    try:
        # Verify session is created
        assert DataprocSparkSession._active_s8s_session_id == custom_session_id

        # Test basic functionality
        df = spark.createDataFrame([(1, "test")], ["id", "value"])
        assert df.count() == 1
    finally:
        # Stop with terminate=True. This lives in `finally` because a session
        # with a custom ID is not terminated by the default stop() path, so a
        # failed precondition above would otherwise leave it running on the
        # server.
        spark.stop(terminate=True)

    # Verify client-side cleanup
    assert DataprocSparkSession._active_s8s_session_id is None

    # Verify server-side session is terminating or terminated
    get_session_request = GetSessionRequest()
    get_session_request.name = session_name
    session = session_controller_client.get_session(get_session_request)

    assert session.state in [
        Session.State.TERMINATING,
        Session.State.TERMINATED,
    ]
711+
712+
713+
def test_stop_managed_session_with_terminate_false(
    auth_type,
    test_project,
    test_region,
    session_controller_client,
    os_environment,
):
    """Test that stop(terminate=False) does NOT terminate a managed session on the server.

    Creates a managed session (auto-generated ID), stops it with
    ``terminate=False``, and checks that the client-side state is cleared while
    the server-side session is still ACTIVE. The server-side session is then
    terminated explicitly, whether or not the checks passed.
    """
    # Create a managed session (auto-generated ID)
    spark = (
        DataprocSparkSession.builder.projectId(test_project)
        .location(test_region)
        .getOrCreate()
    )

    # Capture the ID before stop() clears the client-side state; without an ID
    # there is nothing to clean up, so this check stays outside the try block.
    session_id = DataprocSparkSession._active_s8s_session_id
    assert session_id is not None
    session_name = (
        f"projects/{test_project}/locations/{test_region}/sessions/{session_id}"
    )

    try:
        # Verify it's a managed session (auto-generated ID)
        assert DataprocSparkSession._active_session_uses_custom_id is False

        # Test basic functionality
        df = spark.createDataFrame([(1, "test")], ["id", "value"])
        assert df.count() == 1

        # Stop with terminate=False (prevent server-side termination)
        spark.stop(terminate=False)

        # Verify client-side cleanup
        assert DataprocSparkSession._active_s8s_session_id is None

        # Verify server-side session is still ACTIVE (not terminated)
        get_session_request = GetSessionRequest()
        get_session_request.name = session_name
        session = session_controller_client.get_session(get_session_request)

        assert session.state == Session.State.ACTIVE
    finally:
        # Clean up: terminate the session manually. Done in `finally` so that a
        # failed assertion above cannot leave the server-side session running.
        # NOTE(review): if the session is already terminating this call may
        # itself raise; the earlier failure is then reported as the chained
        # exception context.
        terminate_session_request = TerminateSessionRequest()
        terminate_session_request.name = session_name
        session_controller_client.terminate_session(terminate_session_request)
757+
758+
669759
@pytest.fixture
670760
def batch_workload_env(monkeypatch):
671761
"""Sets DATAPROC_WORKLOAD_TYPE to 'batch' for a test."""

0 commit comments

Comments
 (0)