From 171cbfe96c2f8cadf96987cad4d349d9b0273448 Mon Sep 17 00:00:00 2001 From: hellovai Date: Mon, 3 Jun 2024 16:34:45 -0400 Subject: [PATCH] Fix issue with tracing in threadpool executors (#639) * When using AsyncContextVars + ThreadPoolExecutors, things can get quite hairy. In BAML we treat threadpool executors as top level threads with no prior context (i.e. things like tags and such are reset and must be re-configured) * See: https://kobybass.medium.com/python-contextvars-and-multithreading-faa33dbe953d --- .github/workflows/vscode_ext.yml | 2 +- .../python_src/baml_py/async_context_vars.py | 34 ++++++++++--- integ-tests/python/app/test_functions.py | 51 +++++++++++++++++++ typescript/fiddle-frontend/vercel-build.sh | 2 +- 4 files changed, 79 insertions(+), 10 deletions(-) diff --git a/.github/workflows/vscode_ext.yml b/.github/workflows/vscode_ext.yml index 6a8d4a35b..6fdc24cd2 100644 --- a/.github/workflows/vscode_ext.yml +++ b/.github/workflows/vscode_ext.yml @@ -36,7 +36,7 @@ jobs: working-directory: engine/baml-schema-wasm - name: Install Bindgen - run: cargo install -f wasm-bindgen-cli@0.2.87 + run: cargo install -f wasm-bindgen-cli@0.2.92 working-directory: engine/baml-schema-wasm - name: Install Rust diff --git a/engine/language_client_python/python_src/baml_py/async_context_vars.py b/engine/language_client_python/python_src/baml_py/async_context_vars.py index 34b1c9253..2230c23e9 100644 --- a/engine/language_client_python/python_src/baml_py/async_context_vars.py +++ b/engine/language_client_python/python_src/baml_py/async_context_vars.py @@ -8,41 +8,59 @@ import typing from .baml_py import RuntimeContextManager, BamlRuntime, BamlSpan import atexit +import threading F = typing.TypeVar("F", bound=typing.Callable[..., typing.Any]) +# See this article about why we need to track for every thread: +# https://kobybass.medium.com/python-contextvars-and-multithreading-faa33dbe953d +RTContextVar = contextvars.ContextVar[typing.Dict[int, RuntimeContextManager]] + + +def current_thread_id() -> int: + current_thread = threading.current_thread() + return current_thread.native_id or 0 + + class CtxManager: def __init__(self, rt: BamlRuntime): self.rt = rt - self.ctx = contextvars.ContextVar[RuntimeContextManager]( - "baml_ctx", default=rt.create_context_manager() + self.ctx = contextvars.ContextVar[typing.Dict[int, RuntimeContextManager]]( + "baml_ctx", default={current_thread_id(): rt.create_context_manager()} ) atexit.register(self.rt.flush) + def __ctx(self) -> RuntimeContextManager: + ctx = self.ctx.get() + thread_id = current_thread_id() + if thread_id not in ctx: + ctx[thread_id] = self.rt.create_context_manager() + return ctx[thread_id] + def upsert_tags(self, **tags: str) -> None: - mngr = self.ctx.get() + mngr = self.__ctx() mngr.upsert_tags(tags) def get(self) -> RuntimeContextManager: - return self.ctx.get() + return self.__ctx() def start_trace_sync( self, name: str, args: typing.Dict[str, typing.Any] ) -> BamlSpan: - mng = self.ctx.get() + mng = self.__ctx() return BamlSpan.new(self.rt, name, args, mng) def start_trace_async( self, name: str, args: typing.Dict[str, typing.Any] ) -> BamlSpan: - mng = self.ctx.get() + mng = self.__ctx() cln = mng.deep_clone() - self.ctx.set(cln) + self.ctx.set({current_thread_id(): cln}) return BamlSpan.new(self.rt, name, args, cln) def end_trace(self, span: BamlSpan, response: typing.Any) -> None: - span.finish(response, self.ctx.get()) + span.finish(response, self.__ctx()) def flush(self) -> None: self.rt.flush() diff --git a/integ-tests/python/app/test_functions.py b/integ-tests/python/app/test_functions.py index 8f87f0c17..54bc58d16 100644 --- a/integ-tests/python/app/test_functions.py +++ b/integ-tests/python/app/test_functions.py @@ -187,6 +187,48 @@ def test_tracing_sync(): res2 = sync_dummy_func("second-dummycall-arg") +def test_tracing_thread_pool(): + trace_thread_pool() + + +@pytest.mark.asyncio +async def test_tracing_thread_pool_async(): + await trace_thread_pool_async() + + +@pytest.mark.asyncio +async def test_tracing_async_gather(): + await trace_async_gather() + + +import concurrent.futures + + +@trace +def trace_thread_pool(): + with concurrent.futures.ThreadPoolExecutor() as executor: + # Create 10 tasks and execute them + futures = [ + executor.submit(parent_sync, "second-dummycall-arg") for _ in range(10) + ] + for future in concurrent.futures.as_completed(futures): + future.result() + + +@trace +async def trace_thread_pool_async(): + with concurrent.futures.ThreadPoolExecutor() as executor: + # Create 10 tasks and execute them + futures = [executor.submit(trace_async_gather) for _ in range(10)] + for future in concurrent.futures.as_completed(futures): + res = await future.result() + + +@trace +async def trace_async_gather(): + await asyncio.gather(*[async_dummy_func("second-dummycall-arg") for _ in range(10)]) + + @trace async def parent_async(myStr: str): set_tags(myKey="myVal") @@ -203,12 +245,21 @@ async def parent_async2(myStr: str): @trace def parent_sync(myStr: str): + import time + import random + + time.sleep(0.5 + random.random()) sync_dummy_func(myStr) return "hello world parentsync" +import asyncio +import random + + @trace async def async_dummy_func(myArgggg: str): + await asyncio.sleep(0.5 + random.random()) return "asyncDummyFuncOutput" diff --git a/typescript/fiddle-frontend/vercel-build.sh b/typescript/fiddle-frontend/vercel-build.sh index 594929dbf..cb95f4251 100644 --- a/typescript/fiddle-frontend/vercel-build.sh +++ b/typescript/fiddle-frontend/vercel-build.sh @@ -11,7 +11,7 @@ cd ../../engine/baml-schema-wasm # cargo install rustup target add wasm32-unknown-unknown cargo update -p wasm-bindgen -cargo install -f wasm-bindgen-cli@0.2.87 +cargo install -f wasm-bindgen-cli@0.2.92 # cargo build cd ../../typescript/fiddle-frontend