Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions backend/lib/github/triggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
import typing
import warnings

import aiostream.stream as aiostream_stream
import pydantic

import lib.github.clients as github_clients
import lib.github.models as github_models
import lib.task.base as task_base
import lib.task.protocols
import lib.utils.asyncio as asyncio_utils
import lib.utils.pydantic as pydantic_utils

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -307,7 +307,7 @@ async def produce_events(self) -> typing.AsyncGenerator[task_base.Event, None]:
repositories = await self._get_repositories()

async with self._acquire_state() as state:
iterators = (
event_iterators = (
self._process_subtrigger_factory(
config=subtrigger_config,
state=state,
Expand All @@ -316,7 +316,7 @@ async def produce_events(self) -> typing.AsyncGenerator[task_base.Event, None]:
for subtrigger_config in self.config.subtriggers
)

async for event in asyncio_utils.GatherIterators(iterators):
async for event in aiostream_stream.merge(*event_iterators):
yield event

def _process_subtrigger_factory(
Expand Down Expand Up @@ -357,14 +357,16 @@ async def _process_all_repository_issue_created(
config.exclude_author |= await self._resolve_author_groups(config.exclude_author_group)
config.exclude_author_group = set()

async for event in asyncio_utils.GatherIterators(
event_iterators = (
self._process_repository_issue_created(
state=state,
config=config,
repository=repository.name,
)
for repository in repositories
):
)

async for event in aiostream_stream.merge(*event_iterators):
yield event

async def _process_repository_issue_created(
Expand Down Expand Up @@ -413,14 +415,16 @@ async def _process_all_repository_pr_created(
config.exclude_author |= await self._resolve_author_groups(config.exclude_author_group)
config.exclude_author_group = set()

async for event in asyncio_utils.GatherIterators(
event_iterators = (
self._process_repository_pr_created(
state=state,
config=config,
repository=repository.name,
)
for repository in repositories
):
)

async for event in aiostream_stream.merge(*event_iterators):
yield event

async def _process_repository_pr_created(
Expand Down Expand Up @@ -464,14 +468,16 @@ async def _process_all_repository_failed_workflow_run(
config: RepositoryFailedWorkflowRunSubtriggerConfig,
repositories: list[github_models.Repository],
) -> typing.AsyncGenerator[task_base.Event, None]:
async for event in asyncio_utils.GatherIterators(
event_iterators = (
self._process_repository_failed_workflow_run(
state=state,
config=config,
repository=repository.name,
)
for repository in repositories
):
)

async for event in aiostream_stream.merge(*event_iterators):
yield event

async def _process_repository_failed_workflow_run(
Expand Down
36 changes: 0 additions & 36 deletions backend/lib/utils/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,41 +6,6 @@
import typing


class GatherIterators[T]:
def __init__(self, iterators: typing.Iterable[typing.AsyncIterator[T]]) -> None:
self._iterators: dict[asyncio.Task[T], typing.AsyncIterator[T]] = {} # {task: iterator}

for iterator in iterators:
self._add_iterator(iterator)

def _add_iterator(self, iterator: typing.AsyncIterator[T]) -> None:
coroutine = typing.cast(typing.Coroutine[None, None, T], iterator.__anext__())
task = asyncio.create_task(coroutine)
self._iterators[task] = iterator

def _delete_iterator_by_task(self, task: asyncio.Task[T]) -> None:
del self._iterators[task]

def _create_next_task(self, task: asyncio.Task[T]) -> None:
iterator = self._iterators.pop(task)
self._add_iterator(iterator)

@property
def _tasks(self) -> typing.Collection[asyncio.Task[T]]:
return self._iterators.keys()

async def __aiter__(self) -> typing.AsyncIterator[T]:
while self._iterators:
done, _ = await asyncio.wait(self._tasks, return_when=asyncio.FIRST_COMPLETED)
for task in done:
try:
yield task.result()
except StopAsyncIteration:
self._delete_iterator_by_task(task)
else:
self._create_next_task(task)


class TimeoutTimer:
def __init__(self, timeout: float = 0):
self._timeout = timeout
Expand Down Expand Up @@ -73,7 +38,6 @@ async def acquire_file_lock(path: str) -> typing.AsyncIterator[None]:


__all__ = [
"GatherIterators",
"TimeoutTimer",
"acquire_file_lock",
]
19 changes: 18 additions & 1 deletion backend/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ version = "0.0.1"
aiofile = "^3.8.8"
aiohttp = "^3.9.3"
aiojobs = "^1.2.1"
aiostream = "^0.7.0"
cron-converter = "^1.1.0"
gql = {extras = ["aiohttp"], version = "^3.5.0"}
graphql-core = "^3.2.3"
Expand Down