Skip to content

Commit fd5be9e

Browse files
iatkinsonIan Atkinson
andauthored
Make ExecutorFactory multi-process safe (#2546)
# Rationale for this change `ExecutorFactor` can be left in an invalid state, which causes new `submit` calls to hang indefinitely in multiprocessing. Closes #2545 This PR detects when the executor is used in a new process and creates a new instance ## Are these changes tested? Yes, new tests added to `test_concurrent.py` ## Are there any user-facing changes? No --------- Co-authored-by: Ian Atkinson <[email protected]>
1 parent 2ad3280 commit fd5be9e

File tree

2 files changed

+82
-2
lines changed

2 files changed

+82
-2
lines changed

pyiceberg/utils/concurrent.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# under the License.
1717
"""Concurrency concepts that support efficient multi-threading."""
1818

19+
import os
1920
from concurrent.futures import Executor, ThreadPoolExecutor
2021
from typing import Optional
2122

@@ -24,6 +25,7 @@
2425

2526
class ExecutorFactory:
2627
_instance: Optional[Executor] = None
28+
_instance_pid: Optional[int] = None
2729

2830
@staticmethod
2931
def max_workers() -> Optional[int]:
@@ -33,6 +35,13 @@ def max_workers() -> Optional[int]:
3335
@staticmethod
3436
def get_or_create() -> Executor:
3537
"""Return the same executor in each call."""
38+
# ThreadPoolExecutor cannot be shared across processes. If a new pid is found it means
39+
# there is a new process so a new executor is needed. Otherwise, the executor may be in
40+
# an invalid state and tasks submitted will not be started.
41+
if ExecutorFactory._instance_pid != os.getpid():
42+
ExecutorFactory._instance_pid = os.getpid()
43+
ExecutorFactory._instance = None
44+
3645
if ExecutorFactory._instance is None:
3746
max_workers = ExecutorFactory.max_workers()
3847
ExecutorFactory._instance = ThreadPoolExecutor(max_workers=max_workers)

tests/utils/test_concurrent.py

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
import multiprocessing
1819
import os
19-
from concurrent.futures import ThreadPoolExecutor
20-
from typing import Dict, Optional
20+
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
21+
from typing import Dict, Generator, Optional
2122
from unittest import mock
2223

2324
import pytest
@@ -29,6 +30,41 @@
2930
INVALID_ENV = {"PYICEBERG_MAX_WORKERS": "invalid"}
3031

3132

33+
@pytest.fixture
34+
def fork_process() -> Generator[None, None, None]:
35+
original = multiprocessing.get_start_method()
36+
allowed = multiprocessing.get_all_start_methods()
37+
38+
assert "fork" in allowed
39+
40+
multiprocessing.set_start_method("fork", force=True)
41+
42+
yield
43+
44+
multiprocessing.set_start_method(original, force=True)
45+
46+
47+
@pytest.fixture
48+
def spawn_process() -> Generator[None, None, None]:
49+
original = multiprocessing.get_start_method()
50+
allowed = multiprocessing.get_all_start_methods()
51+
52+
assert "spawn" in allowed
53+
54+
multiprocessing.set_start_method("spawn", force=True)
55+
56+
yield
57+
58+
multiprocessing.set_start_method(original, force=True)
59+
60+
61+
def _use_executor_to_return(value: int) -> int:
62+
# Module level function to enabling pickling for use with ProcessPoolExecutor.
63+
executor = ExecutorFactory.get_or_create()
64+
future = executor.submit(lambda: value)
65+
return future.result()
66+
67+
3268
def test_create_reused() -> None:
3369
first = ExecutorFactory.get_or_create()
3470
second = ExecutorFactory.get_or_create()
@@ -50,3 +86,38 @@ def test_max_workers() -> None:
5086
def test_max_workers_invalid() -> None:
5187
with pytest.raises(ValueError):
5288
ExecutorFactory.max_workers()
89+
90+
91+
@pytest.mark.parametrize(
92+
"fixture_name",
93+
[
94+
pytest.param(
95+
"fork_process",
96+
marks=pytest.mark.skipif(
97+
"fork" not in multiprocessing.get_all_start_methods(), reason="Fork start method is not available"
98+
),
99+
),
100+
pytest.param(
101+
"spawn_process",
102+
marks=pytest.mark.skipif(
103+
"spawn" not in multiprocessing.get_all_start_methods(), reason="Spawn start method is not available"
104+
),
105+
),
106+
],
107+
)
108+
def test_use_executor_in_different_process(fixture_name: str, request: pytest.FixtureRequest) -> None:
109+
# Use the fixture, which sets up fork or spawn process start method.
110+
request.getfixturevalue(fixture_name)
111+
112+
# Use executor in main process to ensure the singleton is initialized.
113+
main_value = _use_executor_to_return(10)
114+
115+
# Use two separate ProcessPoolExecutors to ensure different processes are used.
116+
with ProcessPoolExecutor() as process_executor:
117+
future1 = process_executor.submit(_use_executor_to_return, 20)
118+
with ProcessPoolExecutor() as process_executor:
119+
future2 = process_executor.submit(_use_executor_to_return, 30)
120+
121+
assert main_value == 10
122+
assert future1.result() == 20
123+
assert future2.result() == 30

0 commit comments

Comments
 (0)