DOG-6551 : Add throttling helper class #505
New file: throttle module (`@@ -0,0 +1,72 @@`):

```python
"""
Module containing the helper class for Throttling.
"""

from collections.abc import Callable, Generator
from contextlib import contextmanager
from functools import wraps
from threading import Semaphore
from typing import ParamSpec, TypeVar

P = ParamSpec("P")
T = TypeVar("T")


class TaskThrottle:
    """
    A throttle to limit the number of concurrent tasks using semaphores.

    Usage:
        As a decorator:
        >>> throttle = TaskThrottle(max_concurrent=5)
        >>> @throttle.limit
        ... def my_task(data):
        ...     # Process data
        ...     pass

        As a context manager:
        >>> throttle = TaskThrottle(max_concurrent=5)
        >>> with throttle.lease():
        ...     # Protected code block
        ...     pass
    """

    def __init__(self, max_concurrent: int) -> None:
        """
        Create a throttle with specified concurrency limit.

        Args:
            max_concurrent: Maximum number of tasks that can run concurrently
        """
        if max_concurrent < 1:
            raise ValueError("max_concurrent must be at least 1")
        self._semaphore: Semaphore = Semaphore(max_concurrent)
        self._max_concurrent: int = max_concurrent

    def limit(self, func: Callable[P, T]) -> Callable[P, T]:
        """
        Decorator to throttle a task function.
        """

        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            with self.lease():
                return func(*args, **kwargs)
```
Comment on lines +46 to +55 (Contributor):

In light of the previous comment, I would suggest you only implement this:

```python
def limit_concurrency(max_concurrent):
    semaphore = Semaphore(max_concurrent)

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            with semaphore:
                return func(*args, **kwargs)

        return wrapper

    return decorator
```

Which would allow both single-use:

```python
@limit_concurrency(3)
def foo(data):
    ...
```

...and a shared-pool limit:

```python
throttle = limit_concurrency(5)

@throttle
def foo(user_id):
    ...

@throttle
def bar(entry):
    ...
```
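To make the suggestion concrete, here is a minimal sketch (not part of the PR) that exercises the shared-pool form under a thread pool; the function names `fetch_user`/`fetch_entry` and the peak-tracking counters are illustrative placeholders only:

```python
import time
from concurrent.futures import ThreadPoolExecutor
from functools import wraps
from threading import Lock, Semaphore


def limit_concurrency(max_concurrent):
    # The factory suggested in the comment above, repeated so the sketch is self-contained.
    semaphore = Semaphore(max_concurrent)

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            with semaphore:
                return func(*args, **kwargs)

        return wrapper

    return decorator


throttle = limit_concurrency(5)  # shared pool: both decorated functions draw from the same 5 slots

running = 0
peak = 0
lock = Lock()


def _track() -> None:
    # Count how many decorated calls are inside the shared semaphore at once.
    global running, peak
    with lock:
        running += 1
        peak = max(peak, running)
    time.sleep(0.05)
    with lock:
        running -= 1


@throttle
def fetch_user(user_id: int) -> None:
    _track()


@throttle
def fetch_entry(entry: int) -> None:
    _track()


with ThreadPoolExecutor(max_workers=20) as pool:
    for i in range(10):
        pool.submit(fetch_user, i)
        pool.submit(fetch_entry, i)

assert peak <= 5  # the shared limit is never exceeded across both functions
```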
||
| return wrapper | ||
|
|
||
| @contextmanager | ||
| def lease(self) -> Generator[None, None, None]: | ||
| """ | ||
| Context manager that acquires/releases a throttle slot. | ||
| """ | ||
| self._semaphore.acquire() | ||
| try: | ||
| yield | ||
| finally: | ||
| self._semaphore.release() | ||
|
Comment on lines +58 to +67 (Contributor):

By using the `@contextmanager` decorator, you can get the effect of a context manager without writing a full class. You, however, have written a full class, so I see no point in not implementing the `__enter__` and `__exit__` dunder methods:

```python
def __enter__(self):
    self._semaphore.acquire()

def __exit__(self, exc_type, exc_val, exc_tb):
    self._semaphore.release()
```

However, taking a step back, this is exactly the interface the semaphore already provides you, which leads me to question why you need this in the first place:

```python
throttle = TaskThrottle(5)
with throttle.lease():
    ...
```

```python
throttle = Semaphore(5)
with throttle:
    ...
```
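For comparison, a small sketch of the dunder-method variant the comment describes, shown next to the bare `Semaphore` it wraps; this is not the code in the PR, just an illustration of the two interfaces:

```python
from threading import Semaphore


class TaskThrottle:
    """Illustrative variant with __enter__/__exit__, as the comment describes."""

    def __init__(self, max_concurrent: int) -> None:
        if max_concurrent < 1:
            raise ValueError("max_concurrent must be at least 1")
        self._semaphore = Semaphore(max_concurrent)

    def __enter__(self) -> None:
        self._semaphore.acquire()

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self._semaphore.release()


throttle = TaskThrottle(5)
with throttle:  # the object itself is the context manager, no lease() needed
    ...

plain = Semaphore(5)
with plain:  # the same behaviour comes straight from the stdlib semaphore
    ...
```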
```python
    @property
    def max_concurrent(self) -> int:
        """Get the configured concurrency limit."""
        return self._max_concurrent
```
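For reference, a brief usage sketch of the new `limit` decorator under a thread pool; the import path is the one the tests below use, and `process_item` with its sleep is a hypothetical placeholder, not code from this PR:

```python
import time
from concurrent.futures import ThreadPoolExecutor

from cognite.extractorutils.unstable.core.throttle import TaskThrottle

throttle = TaskThrottle(max_concurrent=3)


@throttle.limit
def process_item(item: int) -> int:
    # Stand-in for real work; at most three calls run this body concurrently.
    time.sleep(0.1)
    return item


with ThreadPoolExecutor(max_workers=10) as pool:
    # map() preserves input order, so the results come back as 0..9.
    results = list(pool.map(process_item, range(10)))

assert results == list(range(10))
```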
New file: tests for the throttle module (`@@ -0,0 +1,95 @@`):

```python
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock

import pytest

from cognite.extractorutils.unstable.core.throttle import TaskThrottle


def test_throttle_initialization() -> None:
    """Test throttle initialization with valid and invalid parameters."""
    throttle = TaskThrottle(max_concurrent=5)
    assert throttle.max_concurrent == 5

    with pytest.raises(ValueError, match="max_concurrent must be at least 1"):
        TaskThrottle(max_concurrent=0)

    with pytest.raises(ValueError, match="max_concurrent must be at least 1"):
        TaskThrottle(max_concurrent=-1)


def test_throttle_concurrency_limits() -> None:
    max_concurrent = 3
    throttle = TaskThrottle(max_concurrent=max_concurrent)

    concurrent_count = 0
    max_observed = 0
    lock = Lock()

    def task(task_id: int) -> int:
        nonlocal concurrent_count, max_observed

        with throttle.lease():
            with lock:
                concurrent_count += 1
                max_observed = max(max_observed, concurrent_count)

            time.sleep(0.1)

            with lock:
                concurrent_count -= 1

        return task_id

    with ThreadPoolExecutor(max_workers=10) as executor:
        futures = [executor.submit(task, i) for i in range(10)]
        results = [f.result() for f in as_completed(futures)]

    assert len(results) == 10
    assert max_observed <= max_concurrent


def test_throttle_serial_execution() -> None:
    lock = Lock()
    throttle_serial = TaskThrottle(max_concurrent=1)
    execution_order = []

    def serial_task(task_id: int) -> None:
        with throttle_serial.lease():
            with lock:
                execution_order.append(task_id)
            time.sleep(0.05)
            with lock:
                execution_order.append(task_id)

    with ThreadPoolExecutor(max_workers=3) as executor:
        futures = [executor.submit(serial_task, i) for i in range(3)]
        for f in as_completed(futures):
            f.result()

    for i in range(0, len(execution_order) - 1, 2):
        task_id = execution_order[i]
        assert execution_order[i + 1] == task_id


def test_throttle_high_concurrency() -> None:
    lock = Lock()
    throttle_high = TaskThrottle(max_concurrent=50)
    completed = []

    def fast_task(task_id: int) -> int:
        with throttle_high.lease():
            time.sleep(0.01)
            with lock:
                completed.append(task_id)
        return task_id

    num_tasks = 100
    with ThreadPoolExecutor(max_workers=num_tasks) as executor:
        futures = [executor.submit(fast_task, i) for i in range(num_tasks)]
        for f in as_completed(futures):
            f.result()

    assert len(completed) == num_tasks
```
Comment: Where was this added?