@@ -666,6 +666,96 @@ def test_sparksql_magic_with_dataproc_session(connect_session):
666666 assert row ["joined_string" ] == "Dataproc-Spark"
667667
668668
669+ def test_stop_named_session_with_terminate_true (
670+ auth_type ,
671+ test_project ,
672+ test_region ,
673+ session_controller_client ,
674+ os_environment ,
675+ ):
676+ """Test that stop(terminate=True) terminates a named session on the server."""
677+ # Use a randomized session ID to avoid conflicts
678+ custom_session_id = f"test-terminate-true-{ uuid .uuid4 ().hex [:8 ]} "
679+
680+ # Create a session with custom ID
681+ spark = (
682+ DataprocSparkSession .builder .dataprocSessionId (custom_session_id )
683+ .projectId (test_project )
684+ .location (test_region )
685+ .getOrCreate ()
686+ )
687+
688+ # Verify session is created
689+ assert DataprocSparkSession ._active_s8s_session_id == custom_session_id
690+ session_name = f"projects/{ test_project } /locations/{ test_region } /sessions/{ custom_session_id } "
691+
692+ # Test basic functionality
693+ df = spark .createDataFrame ([(1 , "test" )], ["id" , "value" ])
694+ assert df .count () == 1
695+
696+ # Stop with terminate=True
697+ spark .stop (terminate = True )
698+
699+ # Verify client-side cleanup
700+ assert DataprocSparkSession ._active_s8s_session_id is None
701+
702+ # Verify server-side session is terminating or terminated
703+ get_session_request = GetSessionRequest ()
704+ get_session_request .name = session_name
705+ session = session_controller_client .get_session (get_session_request )
706+
707+ assert session .state in [
708+ Session .State .TERMINATING ,
709+ Session .State .TERMINATED ,
710+ ]
711+
712+
713+ def test_stop_managed_session_with_terminate_false (
714+ auth_type ,
715+ test_project ,
716+ test_region ,
717+ session_controller_client ,
718+ os_environment ,
719+ ):
720+ """Test that stop(terminate=False) does NOT terminate a managed session on the server."""
721+ # Create a managed session (auto-generated ID)
722+ spark = (
723+ DataprocSparkSession .builder .projectId (test_project )
724+ .location (test_region )
725+ .getOrCreate ()
726+ )
727+
728+ # Verify it's a managed session (auto-generated ID)
729+ assert DataprocSparkSession ._active_s8s_session_id is not None
730+ assert DataprocSparkSession ._active_session_uses_custom_id is False
731+ session_id = DataprocSparkSession ._active_s8s_session_id
732+ session_name = (
733+ f"projects/{ test_project } /locations/{ test_region } /sessions/{ session_id } "
734+ )
735+
736+ # Test basic functionality
737+ df = spark .createDataFrame ([(1 , "test" )], ["id" , "value" ])
738+ assert df .count () == 1
739+
740+ # Stop with terminate=False (prevent server-side termination)
741+ spark .stop (terminate = False )
742+
743+ # Verify client-side cleanup
744+ assert DataprocSparkSession ._active_s8s_session_id is None
745+
746+ # Verify server-side session is still ACTIVE (not terminated)
747+ get_session_request = GetSessionRequest ()
748+ get_session_request .name = session_name
749+ session = session_controller_client .get_session (get_session_request )
750+
751+ assert session .state == Session .State .ACTIVE
752+
753+ # Clean up: terminate the session manually
754+ terminate_session_request = TerminateSessionRequest ()
755+ terminate_session_request .name = session_name
756+ session_controller_client .terminate_session (terminate_session_request )
757+
758+
669759@pytest .fixture
670760def batch_workload_env (monkeypatch ):
671761 """Sets DATAPROC_WORKLOAD_TYPE to 'batch' for a test."""
0 commit comments