diff --git a/.github/workflows/build_python_connect.yml b/.github/workflows/build_python_connect.yml
index bf247db613dba..cfcd6a2a5ec89 100644
--- a/.github/workflows/build_python_connect.yml
+++ b/.github/workflows/build_python_connect.yml
@@ -79,6 +79,8 @@ jobs:
         env:
           SPARK_TESTING: 1
           SPARK_CONNECT_TESTING_REMOTE: sc://localhost
+          # Increase socket timeout for CI environment reliability
+          SPARK_AUTH_SOCKET_TIMEOUT: 60
         run: |
           # Make less noisy
           cp conf/log4j2.properties.template conf/log4j2.properties
diff --git a/.github/workflows/build_python_connect35.yml b/.github/workflows/build_python_connect35.yml
index 6c37091afcb4f..6c0bf433505ba 100644
--- a/.github/workflows/build_python_connect35.yml
+++ b/.github/workflows/build_python_connect35.yml
@@ -22,6 +22,8 @@ name: Build / Python-only, Connect-only (master-server, branch-3.5-client, Python 3.11)
 on:
   schedule:
     - cron: '0 21 * * *'
+  pull_request:
+    branches: [ master, main ]
   workflow_dispatch:
 
 jobs:
@@ -82,6 +84,8 @@ jobs:
           SPARK_TESTING: 1
           SPARK_SKIP_CONNECT_COMPAT_TESTS: 1
           SPARK_CONNECT_TESTING_REMOTE: sc://localhost
+          # Increase socket timeout for CI environment reliability
+          SPARK_AUTH_SOCKET_TIMEOUT: 60
         run: |
           # Make less noisy
           cp conf/log4j2.properties.template conf/log4j2.properties
diff --git a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
index b819634adb5a6..13aafe17a1c2e 100644
--- a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
@@ -55,6 +55,28 @@ def process(df_id, batch_id):  # type: ignore[no-untyped-def]
             f"{log_name} Completed batch {batch_id} with DF id {df_id} and session id {session_id}"
         )
 
+    def create_spark_session_with_retry(connect_url, session_id, max_retries=3):
+        """Create Spark Connect session with retry logic for better reliability in CI environments."""
+        import time
+
+        for attempt in range(max_retries):
+            try:
+                print(f"{log_name} Attempting to connect (attempt {attempt + 1}/{max_retries})")
+                spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate()
+                assert spark_connect_session.session_id == session_id
+                print(f"{log_name} Successfully connected to Spark Connect server")
+                return spark_connect_session
+            except Exception as e:
+                print(f"{log_name} Connection attempt {attempt + 1} failed: {str(e)}")
+                if attempt < max_retries - 1:
+                    # Exponential backoff: 2 ** attempt seconds (1s, then 2s with the default max_retries=3)
+                    wait_time = 2 ** attempt
+                    print(f"{log_name} Retrying in {wait_time} seconds...")
+                    time.sleep(wait_time)
+                else:
+                    print(f"{log_name} All connection attempts failed")
+                    raise
+
     try:
         check_python_version(infile)
 
@@ -68,8 +90,9 @@ def process(df_id, batch_id):  # type: ignore[no-untyped-def]
 
         # To attach to the existing SparkSession, we're setting the session_id in the URL.
         connect_url = connect_url + ";session_id=" + session_id
-        spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate()
-        assert spark_connect_session.session_id == session_id
+
+        # Use retry logic for better reliability in CI environments
+        spark_connect_session = create_spark_session_with_retry(connect_url, session_id)
         spark = spark_connect_session
 
         func = worker.read_command(pickle_ser, infile)
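
Note (not part of the patch): a minimal standalone sketch of the retry-with-exponential-backoff
pattern that create_spark_session_with_retry applies above. retry_with_backoff and flaky_connect
are hypothetical names used only for illustration; flaky_connect stands in for the
SparkSession.builder.remote(connect_url).getOrCreate() call, which can fail transiently while the
Spark Connect server is still starting up in CI. Doubling the wait avoids hammering a server that
is not ready yet while still failing fast overall.

import time


def retry_with_backoff(action, max_retries=3):
    # Run `action` up to `max_retries` times, sleeping 2 ** attempt seconds
    # between consecutive failures, and re-raise after the final attempt.
    for attempt in range(max_retries):
        try:
            return action()
        except Exception as e:
            print(f"Attempt {attempt + 1}/{max_retries} failed: {e}")
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)  # 1s after the 1st failure, 2s after the 2nd
            else:
                raise


def flaky_connect():
    # Always fails here; a real caller would return a connected session instead.
    raise ConnectionError("server not ready")


try:
    retry_with_backoff(flaky_connect)
except ConnectionError:
    print("All attempts failed, as expected for this always-failing stub")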