diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8a3f8510..8f8098cb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,12 @@ Changes are grouped as follows
 - `Fixed` for any bug fixes.
 - `Security` in case of vulnerabilities.
 
+## 7.12.0
+
+### Added
+* In the `unstable` package: Add TaskThrottle helper class for limiting concurrent task execution with decorator and context manager support
+* Add support for environment variable interpolation in keyvault config
+
 ## 7.11.5
 
 ### Fixed
diff --git a/cognite/extractorutils/__init__.py b/cognite/extractorutils/__init__.py
index d2c2cd8c..a8cec2e1 100644
--- a/cognite/extractorutils/__init__.py
+++ b/cognite/extractorutils/__init__.py
@@ -16,7 +16,7 @@
 Cognite extractor utils is a Python package that simplifies the development of new extractors.
 """
 
-__version__ = "7.11.5"
+__version__ = "7.12.0"
 from .base import Extractor
 
 __all__ = ["Extractor"]
diff --git a/cognite/extractorutils/unstable/core/throttle.py b/cognite/extractorutils/unstable/core/throttle.py
new file mode 100644
index 00000000..f35d85b1
--- /dev/null
+++ b/cognite/extractorutils/unstable/core/throttle.py
@@ -0,0 +1,72 @@
+"""
+Module containing the TaskThrottle helper for limiting concurrent task execution.
+"""
+
+from collections.abc import Callable, Generator
+from contextlib import contextmanager
+from functools import wraps
+from threading import Semaphore
+from typing import ParamSpec, TypeVar
+
+P = ParamSpec("P")
+T = TypeVar("T")
+
+
+class TaskThrottle:
+    """
+    A throttle that limits the number of concurrently running tasks using a semaphore.
+
+    Usage:
+        As a decorator:
+        >>> throttle = TaskThrottle(max_concurrent=5)
+        >>> @throttle.limit
+        ... def my_task(data):
+        ...     # Process data
+        ...     pass
+
+        As a context manager:
+        >>> throttle = TaskThrottle(max_concurrent=5)
+        >>> with throttle.lease():
+        ...     # Protected code block
+        ...     pass
+    """
+
+    def __init__(self, max_concurrent: int) -> None:
+        """
+        Create a throttle with specified concurrency limit.
+
+        Args:
+            max_concurrent: Maximum number of tasks that can run concurrently; must be at least 1
+        """
+        if max_concurrent < 1:
+            raise ValueError("max_concurrent must be at least 1")
+        self._semaphore: Semaphore = Semaphore(max_concurrent)
+        self._max_concurrent: int = max_concurrent
+
+    def limit(self, func: Callable[P, T]) -> Callable[P, T]:
+        """
+        Decorator that wraps ``func`` so every call runs while holding a slot from ``lease()``.
+        """
+
+        @wraps(func)
+        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
+            with self.lease():
+                return func(*args, **kwargs)
+
+        return wrapper
+
+    @contextmanager
+    def lease(self) -> Generator[None, None, None]:
+        """
+        Context manager that acquires a throttle slot on entry and always releases it on exit.
+        """
+        self._semaphore.acquire()
+        try:
+            yield
+        finally:
+            self._semaphore.release()
+
+    @property
+    def max_concurrent(self) -> int:
+        """Get the configured concurrency limit."""
+        return self._max_concurrent
diff --git a/pyproject.toml b/pyproject.toml
index 7566ee58..b9b821bc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "cognite-extractor-utils"
-version = "7.11.5"
+version = "7.12.0"
 description = "Utilities for easier development of extractors for CDF"
 authors = [
     {name = "Mathias Lohne", email = "mathias.lohne@cognite.com"}
diff --git a/tests/test_unstable/test_throttle.py b/tests/test_unstable/test_throttle.py
new file mode 100644
index 00000000..18fd69a8
--- /dev/null
+++ b/tests/test_unstable/test_throttle.py
@@ -0,0 +1,95 @@
+import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from threading import Lock
+
+import pytest
+
+from cognite.extractorutils.unstable.core.throttle import TaskThrottle
+
+
+def test_throttle_initialization() -> None:
+    """Test throttle initialization with valid and invalid parameters."""
+
+    throttle = TaskThrottle(max_concurrent=5)
+    assert throttle.max_concurrent == 5
+
+    with pytest.raises(ValueError, match="max_concurrent must be at least 1"):
+        TaskThrottle(max_concurrent=0)
+
+    with pytest.raises(ValueError, match="max_concurrent must be at least 1"):
+        TaskThrottle(max_concurrent=-1)
+
+
+def test_throttle_concurrency_limits() -> None:
+    max_concurrent = 3
+    throttle = TaskThrottle(max_concurrent=max_concurrent)
+
+    concurrent_count = 0
+    max_observed = 0
+    lock = Lock()
+
+    def task(task_id: int) -> int:
+        nonlocal concurrent_count, max_observed
+
+        with throttle.lease():
+            with lock:
+                concurrent_count += 1
+                max_observed = max(max_observed, concurrent_count)
+
+            time.sleep(0.1)
+
+            with lock:
+                concurrent_count -= 1
+
+        return task_id
+
+    with ThreadPoolExecutor(max_workers=10) as executor:
+        futures = [executor.submit(task, i) for i in range(10)]
+        results = [f.result() for f in as_completed(futures)]
+
+    assert len(results) == 10
+    assert max_observed <= max_concurrent
+
+
+def test_throttle_serial_execution() -> None:
+    lock = Lock()
+    throttle_serial = TaskThrottle(max_concurrent=1)
+    execution_order = []
+
+    def serial_task(task_id: int) -> None:
+        with throttle_serial.lease():
+            with lock:
+                execution_order.append(task_id)
+            time.sleep(0.05)
+            with lock:
+                execution_order.append(task_id)
+
+    with ThreadPoolExecutor(max_workers=3) as executor:
+        futures = [executor.submit(serial_task, i) for i in range(3)]
+        for f in as_completed(futures):
+            f.result()
+
+    for i in range(0, len(execution_order) - 1, 2):
+        task_id = execution_order[i]
+        assert execution_order[i + 1] == task_id
+
+
+def test_throttle_high_concurrency() -> None:
+    lock = Lock()
+    throttle_high = TaskThrottle(max_concurrent=50)
+    completed = []
+
+    def fast_task(task_id: int) -> int:
+        with throttle_high.lease():
+            time.sleep(0.01)
+            with lock:
+                completed.append(task_id)
+        return task_id
+
+    num_tasks = 100
+    with ThreadPoolExecutor(max_workers=num_tasks) as executor:
+        futures = [executor.submit(fast_task, i) for i in range(num_tasks)]
+        for f in as_completed(futures):
+            f.result()
+
+    assert len(completed) == num_tasks