Skip to content

Commit

Permalink
Fix issue with tracing in threadpool executors (#639)
Browse files Browse the repository at this point in the history
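* Combining `contextvars.ContextVar` with `ThreadPoolExecutor` is error-prone,
because a worker thread does not start with the context of the thread that
submitted the work. In BAML we treat thread-pool worker threads as top-level
threads with no prior context (i.e. tags and similar state are reset and must
be re-configured inside the worker).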
* See:
https://kobybass.medium.com/python-contextvars-and-multithreading-faa33dbe953d
  • Loading branch information
hellovai authored Jun 3, 2024
1 parent a6aed34 commit 171cbfe
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/vscode_ext.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
working-directory: engine/baml-schema-wasm

- name: Install Bindgen
run: cargo install -f wasm-bindgen-cli@0.2.87
run: cargo install -f wasm-bindgen-cli@0.2.92
working-directory: engine/baml-schema-wasm

- name: Install Rust
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,41 +8,59 @@
import typing
from .baml_py import RuntimeContextManager, BamlRuntime, BamlSpan
import atexit
import threading

# Type variable bound to any callable.
# NOTE(review): not referenced in the visible part of this file; presumably
# used to type decorators elsewhere -- confirm.
F = typing.TypeVar("F", bound=typing.Callable[..., typing.Any])


# Context managers are tracked per thread. For the reasoning behind keeping a
# separate entry for every thread when ContextVars meet multithreading, see:
# https://kobybass.medium.com/python-contextvars-and-multithreading-faa33dbe953d
# Alias for a ContextVar whose value maps an integer thread id to that
# thread's RuntimeContextManager.
RTContextVar = contextvars.ContextVar[typing.Dict[int, RuntimeContextManager]]


def current_thread_id() -> int:
    """Return an integer that identifies the calling thread.

    The kernel-assigned native thread id is used when it is available.
    Otherwise the interpreter-level identifier from ``threading.get_ident()``
    is returned, so that distinct threads are never collapsed onto one
    shared key.

    Returns:
        The calling thread's native id, or its ``get_ident()`` value when no
        native id is available.
    """
    thread = threading.current_thread()
    # `native_id` is missing on platforms without native-id support and is
    # None for a thread that has not been started, hence the guarded lookup.
    native_id = getattr(thread, "native_id", None)
    if native_id is not None:
        return native_id
    # A constant fallback would make every such thread share one entry in
    # any per-thread mapping; get_ident() is unique among live threads.
    return threading.get_ident()


# NOTE(review): this listing is a rendered diff whose +/- markers and
# indentation were stripped. Where two near-identical statements appear back
# to back, the first uses `self.ctx.get()` / a bare RuntimeContextManager and
# the second goes through `self.__ctx()` / a dict keyed by thread id --
# presumably the removed line followed by its replacement; confirm against
# the repository before treating this as runnable code.
#
# CtxManager wraps a BamlRuntime and hands out the RuntimeContextManager used
# when starting and finishing trace spans. In the dict-based statements the
# ContextVar value maps current_thread_id() to a RuntimeContextManager.
class CtxManager:
# Stores the runtime, creates the "baml_ctx" ContextVar with a default value
# built from rt.create_context_manager(), and registers rt.flush with atexit.
# NOTE(review): in the dict-based form the default is one dict object that
# __ctx() mutates in place, so entries added there are visible wherever the
# default is still in effect, and nothing in this class removes entries --
# presumably intentional for thread-pool workers; confirm.
def __init__(self, rt: BamlRuntime):
self.rt = rt
self.ctx = contextvars.ContextVar[RuntimeContextManager](
"baml_ctx", default=rt.create_context_manager()
self.ctx = contextvars.ContextVar[typing.Dict[int, RuntimeContextManager]](
"baml_ctx", default={current_thread_id(): rt.create_context_manager()}
)
atexit.register(self.rt.flush)

# Returns the calling thread's entry from the dict held by the ContextVar,
# creating it with rt.create_context_manager() when the thread has none yet.
def __ctx(self) -> RuntimeContextManager:
ctx = self.ctx.get()
thread_id = current_thread_id()
if thread_id not in ctx:
ctx[thread_id] = self.rt.create_context_manager()
return ctx[thread_id]

# Passes the keyword arguments, as a dict, to the current context manager's
# upsert_tags().
def upsert_tags(self, **tags: str) -> None:
mngr = self.ctx.get()
mngr = self.__ctx()
mngr.upsert_tags(tags)

# Returns the current context manager.
def get(self) -> RuntimeContextManager:
return self.ctx.get()
return self.__ctx()

# Opens a span with BamlSpan.new() using the current context manager as-is.
def start_trace_sync(
self, name: str, args: typing.Dict[str, typing.Any]
) -> BamlSpan:
mng = self.ctx.get()
mng = self.__ctx()
return BamlSpan.new(self.rt, name, args, mng)

# Deep-clones the current context manager, stores the clone in the ContextVar
# and opens the span against the clone.
# NOTE(review): the token returned by ctx.set() is discarded, so nothing in
# this class restores the previous value afterwards -- confirm that is the
# intended lifetime.
def start_trace_async(
self, name: str, args: typing.Dict[str, typing.Any]
) -> BamlSpan:
mng = self.ctx.get()
mng = self.__ctx()
cln = mng.deep_clone()
self.ctx.set(cln)
self.ctx.set({current_thread_id(): cln})
return BamlSpan.new(self.rt, name, args, cln)

# Finishes the span, handing it the response and the current context manager.
def end_trace(self, span: BamlSpan, response: typing.Any) -> None:
span.finish(response, self.ctx.get())
span.finish(response, self.__ctx())

# Delegates to the runtime's flush().
def flush(self) -> None:
self.rt.flush()
Expand Down
51 changes: 51 additions & 0 deletions integ-tests/python/app/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,48 @@ def test_tracing_sync():
res2 = sync_dummy_func("second-dummycall-arg")


def test_tracing_thread_pool():
    """Smoke test: run the traced thread-pool scenario.

    There are no assertions; the test passes when the call does not raise.
    """
    trace_thread_pool()


@pytest.mark.asyncio
async def test_tracing_thread_pool_async():
    """Smoke test: await the traced async thread-pool scenario.

    There are no assertions; the test passes when the await does not raise.
    """
    await trace_thread_pool_async()


@pytest.mark.asyncio
async def test_tracing_async_gather():
    """Smoke test: await the traced ``asyncio.gather`` scenario.

    There are no assertions; the test passes when the await does not raise.
    """
    await trace_async_gather()


import concurrent.futures


@trace
def trace_thread_pool():
    """Run ``parent_sync`` ten times on a thread pool and wait for every call."""
    with concurrent.futures.ThreadPoolExecutor() as pool:
        # Submit all ten tasks up front so they can run concurrently.
        pending = [
            pool.submit(parent_sync, "second-dummycall-arg") for _ in range(10)
        ]
        # result() re-raises anything a worker raised, so a failing task
        # fails the caller instead of being silently dropped.
        for finished in concurrent.futures.as_completed(pending):
            finished.result()


@trace
async def trace_thread_pool_async():
    """Call ``trace_async_gather`` from ten pool workers and await each result."""
    with concurrent.futures.ThreadPoolExecutor() as pool:
        # Each worker only *calls* trace_async_gather; whatever awaitable
        # that call returns is awaited below, on this coroutine's event loop.
        pending = [pool.submit(trace_async_gather) for _ in range(10)]
        for finished in concurrent.futures.as_completed(pending):
            await finished.result()


@trace
async def trace_async_gather():
    """Await ten ``async_dummy_func`` calls concurrently via ``asyncio.gather``."""
    calls = [async_dummy_func("second-dummycall-arg") for _ in range(10)]
    await asyncio.gather(*calls)


@trace
async def parent_async(myStr: str):
set_tags(myKey="myVal")
Expand All @@ -203,12 +245,21 @@ async def parent_async2(myStr: str):

@trace
def parent_sync(myStr: str):
    """Sleep for a random 0.5-1.5 s, call ``sync_dummy_func``, return a fixed string.

    Args:
        myStr: Value forwarded unchanged to ``sync_dummy_func``.

    Returns:
        The literal ``"hello world parentsync"``.
    """
    import random
    import time

    # The random delay makes concurrently running calls finish at different
    # times.
    delay = 0.5 + random.random()
    time.sleep(delay)
    sync_dummy_func(myStr)
    return "hello world parentsync"


import asyncio
import random


@trace
async def async_dummy_func(myArgggg: str):
    """Sleep asynchronously for a random 0.5-1.5 s, then return a fixed string.

    Args:
        myArgggg: Not used in the body.

    Returns:
        The literal ``"asyncDummyFuncOutput"``.
    """
    pause = 0.5 + random.random()
    await asyncio.sleep(pause)
    return "asyncDummyFuncOutput"


Expand Down
2 changes: 1 addition & 1 deletion typescript/fiddle-frontend/vercel-build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ cd ../../engine/baml-schema-wasm
# cargo install
rustup target add wasm32-unknown-unknown
cargo update -p wasm-bindgen
cargo install -f wasm-bindgen-cli@0.2.87
cargo install -f wasm-bindgen-cli@0.2.92

# cargo build
cd ../../typescript/fiddle-frontend
Expand Down

0 comments on commit 171cbfe

Please sign in to comment.