6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,12 @@ Changes are grouped as follows
- `Fixed` for any bug fixes.
- `Security` in case of vulnerabilities.

## 7.12.0

### Added
* In the `unstable` package: Add TaskThrottle helper class for limiting concurrent task execution with decorator and context manager support
* Add support for environment variable interpolation in keyvault config
Contributor comment on the line above: Where was this added?


## 7.11.5

### Fixed
2 changes: 1 addition & 1 deletion cognite/extractorutils/__init__.py
@@ -16,7 +16,7 @@
Cognite extractor utils is a Python package that simplifies the development of new extractors.
"""

__version__ = "7.11.5"
__version__ = "7.12.0"
from .base import Extractor

__all__ = ["Extractor"]
72 changes: 72 additions & 0 deletions cognite/extractorutils/unstable/core/throttle.py
@@ -0,0 +1,72 @@
"""
Module containing a helper class for throttling concurrent task execution.
"""

from collections.abc import Callable, Generator
from contextlib import contextmanager
from functools import wraps
from threading import Semaphore
from typing import ParamSpec, TypeVar

P = ParamSpec("P")
T = TypeVar("T")


class TaskThrottle:
"""
A throttle to limit the number of concurrent tasks using semaphores.

Usage:
As a decorator:
>>> throttle = TaskThrottle(max_concurrent=5)
>>> @throttle.limit
... def my_task(data):
... # Process data
... pass

As a context manager:
>>> throttle = TaskThrottle(max_concurrent=5)
>>> with throttle.lease():
... # Protected code block
... pass
"""

def __init__(self, max_concurrent: int) -> None:
"""
Create a throttle with specified concurrency limit.

Args:
max_concurrent: Maximum number of tasks that can run concurrently
"""
if max_concurrent < 1:
raise ValueError("max_concurrent must be at least 1")
self._semaphore: Semaphore = Semaphore(max_concurrent)
self._max_concurrent: int = max_concurrent

def limit(self, func: Callable[P, T]) -> Callable[P, T]:
"""
Decorator to throttle a task function.
"""

@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
with self.lease():
return func(*args, **kwargs)

Contributor comment on lines +46 to +55:

In light of the previous comment, I would suggest you only implement this limit functionality:

def limit_concurrency(max_concurrent):
    semaphore = Semaphore(max_concurrent)

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            with semaphore:
                return func(*args, **kwargs)
        return wrapper

    return decorator

Which would allow both single-use:

@limit_concurrency(3)
def foo(data):
    ...

...and shared-pool limit:

throttle = limit_concurrency(5)

@throttle
def foo(user_id):
    ...

@throttle
def bar(entry):
    ...
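
If that route is taken, here is a small sketch of how the suggested factory could keep the ParamSpec/TypeVar annotations already used in throttle.py above (an illustration of the reviewer's suggestion only, not code from the PR):

from collections.abc import Callable
from functools import wraps
from threading import Semaphore
from typing import ParamSpec, TypeVar

P = ParamSpec("P")
T = TypeVar("T")


def limit_concurrency(max_concurrent: int) -> Callable[[Callable[P, T]], Callable[P, T]]:
    # One shared semaphore per factory call: every function decorated with the
    # returned decorator draws from the same pool of max_concurrent slots.
    semaphore = Semaphore(max_concurrent)

    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            with semaphore:
                return func(*args, **kwargs)

        return wrapper

    return decorator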

        return wrapper

    @contextmanager
    def lease(self) -> Generator[None, None, None]:
        """
        Context manager that acquires/releases a throttle slot.
        """
        self._semaphore.acquire()
        try:
            yield
        finally:
            self._semaphore.release()
Contributor comment on lines +58 to +67:

By using the @contextmanager decorator, you can get the effect of a context manager without writing a full class. You have, however, written a full class, so I see no reason not to implement the __enter__ and __exit__ dunder methods:

def __enter__(self):
    self._semaphore.acquire()

def __exit__(self, exc_type, exc_val, exc_tb):
    self._semaphore.release()

However, taking a step back, this is exactly the interface the semaphore already provides, which makes me question why you need this class in the first place:

throttle = TaskThrottle(5)
with throttle.lease():
    ...

throttle = Semaphore(5)
with throttle:
    ...
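
To make that comparison concrete, here is a minimal self-contained sketch (standard library only, not part of the PR) of the same throttling pattern built on a bare Semaphore, mirroring the thread-pool usage in the tests below:

import time
from concurrent.futures import ThreadPoolExecutor
from threading import Semaphore

throttle = Semaphore(3)  # at most 3 tasks inside the guarded block at a time


def task(task_id: int) -> int:
    with throttle:  # a Semaphore is already a context manager
        time.sleep(0.01)
    return task_id


with ThreadPoolExecutor(max_workers=10) as executor:
    results = list(executor.map(task, range(10)))

assert sorted(results) == list(range(10))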


    @property
    def max_concurrent(self) -> int:
        """Get the configured concurrency limit."""
        return self._max_concurrent
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "cognite-extractor-utils"
version = "7.11.5"
version = "7.12.0"
description = "Utilities for easier development of extractors for CDF"
authors = [
{name = "Mathias Lohne", email = "mathias.lohne@cognite.com"}
95 changes: 95 additions & 0 deletions tests/test_unstable/test_throttle.py
@@ -0,0 +1,95 @@
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock

import pytest

from cognite.extractorutils.unstable.core.throttle import TaskThrottle


def test_throttle_initialization() -> None:
"""Test throttle initialization with valid and invalid parameters."""

throttle = TaskThrottle(max_concurrent=5)
assert throttle.max_concurrent == 5

with pytest.raises(ValueError, match="max_concurrent must be at least 1"):
TaskThrottle(max_concurrent=0)

with pytest.raises(ValueError, match="max_concurrent must be at least 1"):
TaskThrottle(max_concurrent=-1)


def test_throttle_concurrency_limits() -> None:
    max_concurrent = 3
    throttle = TaskThrottle(max_concurrent=max_concurrent)

    concurrent_count = 0
    max_observed = 0
    lock = Lock()

    def task(task_id: int) -> int:
        nonlocal concurrent_count, max_observed

        with throttle.lease():
            with lock:
                concurrent_count += 1
                max_observed = max(max_observed, concurrent_count)

            time.sleep(0.1)

            with lock:
                concurrent_count -= 1

        return task_id

    with ThreadPoolExecutor(max_workers=10) as executor:
        futures = [executor.submit(task, i) for i in range(10)]
        results = [f.result() for f in as_completed(futures)]

    assert len(results) == 10
    assert max_observed <= max_concurrent


def test_throttle_serial_execution() -> None:
    lock = Lock()
    throttle_serial = TaskThrottle(max_concurrent=1)
    execution_order = []

    def serial_task(task_id: int) -> None:
        with throttle_serial.lease():
            with lock:
                execution_order.append(task_id)
            time.sleep(0.05)
            with lock:
                execution_order.append(task_id)

    with ThreadPoolExecutor(max_workers=3) as executor:
        futures = [executor.submit(serial_task, i) for i in range(3)]
        for f in as_completed(futures):
            f.result()

    # With max_concurrent=1, each task's start/end appends must be adjacent,
    # i.e. no two tasks ever interleave.
    for i in range(0, len(execution_order) - 1, 2):
        task_id = execution_order[i]
        assert execution_order[i + 1] == task_id


def test_throttle_high_concurrency() -> None:
    lock = Lock()
    throttle_high = TaskThrottle(max_concurrent=50)
    completed = []

    def fast_task(task_id: int) -> int:
        with throttle_high.lease():
            time.sleep(0.01)
            with lock:
                completed.append(task_id)
            return task_id

    num_tasks = 100
    with ThreadPoolExecutor(max_workers=num_tasks) as executor:
        futures = [executor.submit(fast_task, i) for i in range(num_tasks)]
        for f in as_completed(futures):
            f.result()

    assert len(completed) == num_tasks