diff --git a/backend/lib/github/triggers.py b/backend/lib/github/triggers.py
index dbcc687..452f199 100644
--- a/backend/lib/github/triggers.py
+++ b/backend/lib/github/triggers.py
@@ -6,13 +6,13 @@
 import typing
 import warnings
 
+import aiostream.stream as aiostream_stream
 import pydantic
 
 import lib.github.clients as github_clients
 import lib.github.models as github_models
 import lib.task.base as task_base
 import lib.task.protocols
-import lib.utils.asyncio as asyncio_utils
 import lib.utils.pydantic as pydantic_utils
 
 logger = logging.getLogger(__name__)
@@ -307,7 +307,7 @@ async def produce_events(self) -> typing.AsyncGenerator[task_base.Event, None]:
         repositories = await self._get_repositories()
 
         async with self._acquire_state() as state:
-            iterators = (
+            event_iterators = (
                 self._process_subtrigger_factory(
                     config=subtrigger_config,
                     state=state,
@@ -316,7 +316,7 @@
                 for subtrigger_config in self.config.subtriggers
             )
 
-            async for event in asyncio_utils.GatherIterators(iterators):
+            async for event in aiostream_stream.merge(*event_iterators):
                 yield event
 
     def _process_subtrigger_factory(
@@ -357,14 +357,16 @@ async def _process_all_repository_issue_created(
         config.exclude_author |= await self._resolve_author_groups(config.exclude_author_group)
         config.exclude_author_group = set()
 
-        async for event in asyncio_utils.GatherIterators(
+        event_iterators = (
             self._process_repository_issue_created(
                 state=state,
                 config=config,
                 repository=repository.name,
             )
             for repository in repositories
-        ):
+        )
+
+        async for event in aiostream_stream.merge(*event_iterators):
             yield event
 
     async def _process_repository_issue_created(
@@ -413,14 +415,16 @@ async def _process_all_repository_pr_created(
         config.exclude_author |= await self._resolve_author_groups(config.exclude_author_group)
         config.exclude_author_group = set()
 
-        async for event in asyncio_utils.GatherIterators(
+        event_iterators = (
             self._process_repository_pr_created(
                 state=state,
                 config=config,
                 repository=repository.name,
             )
             for repository in repositories
-        ):
+        )
+
+        async for event in aiostream_stream.merge(*event_iterators):
             yield event
 
     async def _process_repository_pr_created(
@@ -464,14 +468,16 @@ async def _process_all_repository_failed_workflow_run(
         config: RepositoryFailedWorkflowRunSubtriggerConfig,
         repositories: list[github_models.Repository],
     ) -> typing.AsyncGenerator[task_base.Event, None]:
-        async for event in asyncio_utils.GatherIterators(
+        event_iterators = (
             self._process_repository_failed_workflow_run(
                 state=state,
                 config=config,
                 repository=repository.name,
             )
             for repository in repositories
-        ):
+        )
+
+        async for event in aiostream_stream.merge(*event_iterators):
             yield event
 
     async def _process_repository_failed_workflow_run(
diff --git a/backend/lib/utils/asyncio.py b/backend/lib/utils/asyncio.py
index 1db35be..f227af6 100644
--- a/backend/lib/utils/asyncio.py
+++ b/backend/lib/utils/asyncio.py
@@ -6,41 +6,6 @@
 import typing
 
 
-class GatherIterators[T]:
-    def __init__(self, iterators: typing.Iterable[typing.AsyncIterator[T]]) -> None:
-        self._iterators: dict[asyncio.Task[T], typing.AsyncIterator[T]] = {}  # {task: iterator}
-
-        for iterator in iterators:
-            self._add_iterator(iterator)
-
-    def _add_iterator(self, iterator: typing.AsyncIterator[T]) -> None:
-        coroutine = typing.cast(typing.Coroutine[None, None, T], iterator.__anext__())
-        task = asyncio.create_task(coroutine)
-        self._iterators[task] = iterator
-
-    def _delete_iterator_by_task(self, task: asyncio.Task[T]) -> None:
-        del self._iterators[task]
-
-    def _create_next_task(self, task: asyncio.Task[T]) -> None:
-        iterator = self._iterators.pop(task)
-        self._add_iterator(iterator)
-
-    @property
-    def _tasks(self) -> typing.Collection[asyncio.Task[T]]:
-        return self._iterators.keys()
-
-    async def __aiter__(self) -> typing.AsyncIterator[T]:
-        while self._iterators:
-            done, _ = await asyncio.wait(self._tasks, return_when=asyncio.FIRST_COMPLETED)
-            for task in done:
-                try:
-                    yield task.result()
-                except StopAsyncIteration:
-                    self._delete_iterator_by_task(task)
-                else:
-                    self._create_next_task(task)
-
-
 class TimeoutTimer:
     def __init__(self, timeout: float = 0):
         self._timeout = timeout
@@ -73,7 +38,6 @@ async def acquire_file_lock(path: str) -> typing.AsyncIterator[None]:
 
 
 __all__ = [
-    "GatherIterators",
     "TimeoutTimer",
     "acquire_file_lock",
 ]
diff --git a/backend/poetry.lock b/backend/poetry.lock
index c035cb3..9f255ed 100644
--- a/backend/poetry.lock
+++ b/backend/poetry.lock
@@ -161,6 +161,23 @@ files = [
 frozenlist = ">=1.1.0"
 typing-extensions = {version = ">=4.2", markers = "python_version < \"3.13\""}
 
+[[package]]
+name = "aiostream"
+version = "0.7.0"
+description = "Generator-based operators for asynchronous iteration"
+optional = false
+python-versions = ">=3.9"
+files = [
+    {file = "aiostream-0.7.0-py3-none-any.whl", hash = "sha256:17e52dc10fdf98c4b5296c7f3569511c13bfaf12da31c072bad580817e69d705"},
+    {file = "aiostream-0.7.0.tar.gz", hash = "sha256:5ab4acd44ef5f583b6488c32ade465f43c3d7b0df039f1ee49dfb1fd1e255e02"},
+]
+
+[package.dependencies]
+typing-extensions = "*"
+
+[package.extras]
+dev = ["pytest", "pytest-asyncio", "pytest-cov"]
+
 [[package]]
 name = "annotated-types"
 version = "0.7.0"
@@ -1604,4 +1621,4 @@ propcache = ">=0.2.1"
 [metadata]
 lock-version = "2.0"
 python-versions = "~3.12"
-content-hash = "c86250d710ded1033f154d07ae11da3ebe13d658582f0c023c86b94dbb314e30"
+content-hash = "869ddedfdbfdcb5d4fe1c7e8a25add9dafd00369335d4b12d2331d830d451986"
diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index 2609742..2740612 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -36,6 +36,7 @@ version = "0.0.1"
 aiofile = "^3.8.8"
 aiohttp = "^3.9.3"
 aiojobs = "^1.2.1"
+aiostream = "^0.7.0"
 cron-converter = "^1.1.0"
 gql = {extras = ["aiohttp"], version = "^3.5.0"}
 graphql-core = "^3.2.3"
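
Note on the replacement pattern: the diff swaps the in-house GatherIterators helper for aiostream.stream.merge, which likewise drives several async iterators concurrently and yields items as each one becomes ready. Below is a minimal, self-contained sketch of that usage outside the repository; the ticker generator and its delays are hypothetical stand-ins, not code from this project.

    import asyncio

    import aiostream.stream as aiostream_stream


    async def ticker(name: str, count: int, delay: float):
        # Hypothetical source: yields a few labelled items, pausing between them.
        for i in range(count):
            await asyncio.sleep(delay)
            yield f"{name}-{i}"


    async def main() -> None:
        # merge() consumes all source iterators concurrently and yields each item
        # as soon as any source produces one, mirroring the removed GatherIterators.
        sources = (ticker(name, 3, delay) for name, delay in [("fast", 0.1), ("slow", 0.25)])
        async for item in aiostream_stream.merge(*sources):
            print(item)  # interleaved: fast-0, fast-1, slow-0, ...


    if __name__ == "__main__":
        asyncio.run(main())

As in the diff, the star-unpacking in merge(*event_iterators) exhausts the generator expression up front, so every per-repository iterator is constructed before merging begins; only the iteration itself is interleaved.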