Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify local's implementation #481

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 28 additions & 83 deletions asgiref/local.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,46 @@
import asyncio
import contextlib
import contextvars
import threading
from typing import Any, Dict, Union
from typing import Any, Union


class _CVar:
"""Storage utility for Local."""

def __init__(self) -> None:
self._data: "contextvars.ContextVar[Dict[str, Any]]" = contextvars.ContextVar(
"asgiref.local"
)
self._thread_lock = threading.RLock()
self._data: dict[str, contextvars.ContextVar[Any]] = {}

def __getattr__(self, key: str) -> Any:
with self._thread_lock:
try:
var = self._data[key]
except KeyError:
raise AttributeError(f"{self!r} object has no attribute {key!r}")

def __getattr__(self, key):
storage_object = self._data.get({})
try:
return storage_object[key]
except KeyError:
return var.get()
except LookupError:
raise AttributeError(f"{self!r} object has no attribute {key!r}")

def __setattr__(self, key: str, value: Any) -> None:
if key == "_data":
if key in ("_data", "_thread_lock"):
return super().__setattr__(key, value)

storage_object = self._data.get({})
storage_object[key] = value
self._data.set(storage_object)
with self._thread_lock:
var = self._data.get(key)
if var is None:
self._data[key] = var = contextvars.ContextVar(key)
var.set(value)

def __delattr__(self, key: str) -> None:
storage_object = self._data.get({})
if key in storage_object:
del storage_object[key]
self._data.set(storage_object)
else:
raise AttributeError(f"{self!r} object has no attribute {key!r}")
with self._thread_lock:
if key in self._data:
del self._data[key]
else:
raise AttributeError(f"{self!r} object has no attribute {key!r}")


class Local:
def Local(thread_critical: bool = False) -> Union[threading.local, _CVar]:
"""Local storage for async tasks.

This is a namespace object (similar to `threading.local`) where data is
Expand All @@ -64,65 +67,7 @@ class Local:

Unlike plain `contextvars` objects, this utility is threadsafe.
"""

def __init__(self, thread_critical: bool = False) -> None:
self._thread_critical = thread_critical
self._thread_lock = threading.RLock()

self._storage: "Union[threading.local, _CVar]"

if thread_critical:
# Thread-local storage
self._storage = threading.local()
else:
# Contextvar storage
self._storage = _CVar()

@contextlib.contextmanager
def _lock_storage(self):
# Thread safe access to storage
if self._thread_critical:
try:
# this is a test for are we in a async or sync
# thread - will raise RuntimeError if there is
# no current loop
asyncio.get_running_loop()
except RuntimeError:
# We are in a sync thread, the storage is
# just the plain thread local (i.e, "global within
# this thread" - it doesn't matter where you are
# in a call stack you see the same storage)
yield self._storage
else:
# We are in an async thread - storage is still
# local to this thread, but additionally should
# behave like a context var (is only visible with
# the same async call stack)

# Ensure context exists in the current thread
if not hasattr(self._storage, "cvar"):
self._storage.cvar = _CVar()

# self._storage is a thread local, so the members
# can't be accessed in another thread (we don't
# need any locks)
yield self._storage.cvar
else:
# Lock for thread_critical=False as other threads
# can access the exact same storage object
with self._thread_lock:
yield self._storage

def __getattr__(self, key):
with self._lock_storage() as storage:
return getattr(storage, key)

def __setattr__(self, key, value):
if key in ("_local", "_storage", "_thread_critical", "_thread_lock"):
return super().__setattr__(key, value)
with self._lock_storage() as storage:
setattr(storage, key, value)

def __delattr__(self, key):
with self._lock_storage() as storage:
delattr(storage, key)
if thread_critical:
return threading.local()
else:
return _CVar()
37 changes: 37 additions & 0 deletions tests/test_local.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import gc
import threading
from threading import Thread

import pytest

Expand Down Expand Up @@ -338,3 +339,39 @@ async def async_function():
# inner value was set inside a new async context, meaning that
# we do not see it, as context vars don't propagate up the stack
assert not hasattr(test_local_not_tc, "test_value")


def test_visibility_thread_asgiref() -> None:
"""Check visibility with subthreads."""
test_local = Local()
test_local.value = 0

def _test() -> None:
# Local() is cleared when changing thread
assert not hasattr(test_local, "value")
setattr(test_local, "value", 1)
assert test_local.value == 1

thread = Thread(target=_test)
thread.start()
thread.join()

assert test_local.value == 0


@pytest.mark.asyncio
async def test_visibility_task() -> None:
"""Check visibility with asyncio tasks."""
test_local = Local()
test_local.value = 0

async def _test() -> None:
# Local is inherited when changing task
assert test_local.value == 0
test_local.value = 1
assert test_local.value == 1

await asyncio.create_task(_test())

# Changes should not leak to the caller
assert test_local.value == 0
Loading