Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions orchestrator/graphql/resolvers/helpers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from collections.abc import Sequence
from functools import wraps
from typing import Callable, Coroutine

import structlog
from sqlalchemy import CompoundSelect, Select, select
from sqlalchemy.orm.strategy_options import _AbstractLoad
from starlette.concurrency import run_in_threadpool

from orchestrator.db import db
from orchestrator.db.database import BaseModel

logger = structlog.get_logger(__name__)


def rows_from_statement(
stmt: Select | CompoundSelect,
Expand All @@ -19,3 +25,12 @@ def rows_from_statement(
result = db.session.scalars(from_stmt)
uresult = result.unique() if unique else result
return uresult.all()


def make_async(f: Callable): # type: ignore
@wraps(f)
async def wrapper(*args, **kwargs) -> Coroutine: # type: ignore
logger.debug(f"**async, calling fn {f.__name__}")
return await run_in_threadpool(f, *args, **kwargs)

return wrapper
8 changes: 5 additions & 3 deletions orchestrator/graphql/resolvers/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from orchestrator.db.sorting import Sort
from orchestrator.db.sorting.process import process_sort_fields, sort_processes
from orchestrator.graphql.pagination import Connection
from orchestrator.graphql.resolvers.helpers import rows_from_statement
from orchestrator.graphql.resolvers.helpers import make_async, rows_from_statement
from orchestrator.graphql.schemas.process import ProcessType
from orchestrator.graphql.types import GraphqlFilter, GraphqlSort, OrchestratorInfo
from orchestrator.graphql.utils import (
Expand Down Expand Up @@ -55,7 +55,8 @@ def _enrich_process(process: ProcessTable, with_details: bool = False) -> Proces
return ProcessSchema(**process_data)


async def resolve_process(info: OrchestratorInfo, process_id: UUID) -> ProcessType | None:
@make_async
def resolve_process(info: OrchestratorInfo, process_id: UUID) -> ProcessType | None:
query_loaders = get_query_loaders_for_gql_fields(ProcessTable, info)
stmt = select(ProcessTable).options(*query_loaders).where(ProcessTable.process_id == process_id)
if process := db.session.scalar(stmt):
Expand All @@ -64,7 +65,8 @@ async def resolve_process(info: OrchestratorInfo, process_id: UUID) -> ProcessTy
return None


async def resolve_processes(
@make_async
def resolve_processes(
info: OrchestratorInfo,
filter_by: list[GraphqlFilter] | None = None,
sort_by: list[GraphqlSort] | None = None,
Expand Down
5 changes: 3 additions & 2 deletions orchestrator/graphql/resolvers/product.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from orchestrator.db.sorting import Sort
from orchestrator.db.sorting.product import product_sort_fields, sort_products
from orchestrator.graphql.pagination import Connection
from orchestrator.graphql.resolvers.helpers import rows_from_statement
from orchestrator.graphql.resolvers.helpers import make_async, rows_from_statement
from orchestrator.graphql.schemas.product import ProductType
from orchestrator.graphql.types import GraphqlFilter, GraphqlSort, OrchestratorInfo
from orchestrator.graphql.utils import create_resolver_error_handler, is_querying_page_data, to_graphql_result_page
Expand All @@ -19,7 +19,8 @@
logger = structlog.get_logger(__name__)


async def resolve_products(
@make_async
def resolve_products(
info: OrchestratorInfo,
filter_by: list[GraphqlFilter] | None = None,
sort_by: list[GraphqlSort] | None = None,
Expand Down
5 changes: 3 additions & 2 deletions orchestrator/graphql/resolvers/product_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from orchestrator.db.sorting import Sort
from orchestrator.db.sorting.product_block import product_block_sort_fields, sort_product_blocks
from orchestrator.graphql.pagination import Connection
from orchestrator.graphql.resolvers.helpers import rows_from_statement
from orchestrator.graphql.resolvers.helpers import make_async, rows_from_statement
from orchestrator.graphql.schemas.product_block import ProductBlock
from orchestrator.graphql.types import GraphqlFilter, GraphqlSort, OrchestratorInfo
from orchestrator.graphql.utils import create_resolver_error_handler, is_querying_page_data, to_graphql_result_page
Expand All @@ -23,7 +23,8 @@
logger = structlog.get_logger(__name__)


async def resolve_product_blocks(
@make_async
def resolve_product_blocks(
info: OrchestratorInfo,
filter_by: list[GraphqlFilter] | None = None,
sort_by: list[GraphqlSort] | None = None,
Expand Down
5 changes: 3 additions & 2 deletions orchestrator/graphql/resolvers/resource_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from orchestrator.db.sorting import Sort
from orchestrator.db.sorting.resource_type import resource_type_sort_fields, sort_resource_types
from orchestrator.graphql.pagination import Connection
from orchestrator.graphql.resolvers.helpers import rows_from_statement
from orchestrator.graphql.resolvers.helpers import make_async, rows_from_statement
from orchestrator.graphql.schemas.resource_type import ResourceType
from orchestrator.graphql.types import GraphqlFilter, GraphqlSort, OrchestratorInfo
from orchestrator.graphql.utils import create_resolver_error_handler, is_querying_page_data, to_graphql_result_page
Expand All @@ -23,7 +23,8 @@
logger = structlog.get_logger(__name__)


async def resolve_resource_types(
@make_async
def resolve_resource_types(
info: OrchestratorInfo,
filter_by: list[GraphqlFilter] | None = None,
sort_by: list[GraphqlSort] | None = None,
Expand Down
4 changes: 3 additions & 1 deletion orchestrator/graphql/resolvers/scheduled_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from orchestrator.db.filters import Filter
from orchestrator.db.sorting import Sort
from orchestrator.graphql.pagination import Connection
from orchestrator.graphql.resolvers.helpers import make_async
from orchestrator.graphql.schemas.scheduled_task import ScheduledTaskGraphql
from orchestrator.graphql.types import GraphqlFilter, GraphqlSort, OrchestratorInfo
from orchestrator.graphql.utils import create_resolver_error_handler, to_graphql_result_page
Expand All @@ -12,7 +13,8 @@
logger = structlog.get_logger(__name__)


async def resolve_scheduled_tasks(
@make_async
def resolve_scheduled_tasks(
info: OrchestratorInfo,
filter_by: list[GraphqlFilter] | None = None,
sort_by: list[GraphqlSort] | None = None,
Expand Down
2 changes: 2 additions & 0 deletions orchestrator/graphql/resolvers/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from oauth2_lib.strawberry import authenticated_mutation_field
from orchestrator.api.api_v1.endpoints.settings import generate_engine_status_response
from orchestrator.graphql.resolvers.helpers import make_async
from orchestrator.graphql.schemas.errors import Error
from orchestrator.graphql.schemas.settings import (
CACHE_FLUSH_OPTIONS,
Expand All @@ -27,6 +28,7 @@


# Queries
@make_async
def resolve_settings(info: OrchestratorInfo) -> StatusType:
selected_fields = get_selected_fields(info)

Expand Down
8 changes: 5 additions & 3 deletions orchestrator/graphql/resolvers/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pydantic.alias_generators import to_camel as to_lower_camel
from sqlalchemy import Select, func, select
from sqlalchemy.orm import contains_eager
from starlette.concurrency import run_in_threadpool
from strawberry.experimental.pydantic.conversion_types import StrawberryTypeFromPydantic

from nwastdlib.asyncio import gather_nice
Expand Down Expand Up @@ -101,7 +102,7 @@ async def format_subscription(info: OrchestratorInfo, subscription: Subscription
async def resolve_subscription(info: OrchestratorInfo, id: UUID) -> SubscriptionInterface | None:
stmt = select(SubscriptionTable).where(SubscriptionTable.subscription_id == id)

if subscription := db.session.scalar(stmt):
if subscription := await run_in_threadpool(db.session.scalar, stmt):
return await format_subscription(info, subscription)
return None

Expand Down Expand Up @@ -141,12 +142,13 @@ async def resolve_subscriptions(
stmt = filter_by_query_string(stmt, query)

stmt = cast(Select, sort_subscriptions(stmt, pydantic_sort_by, _error_handler))
total = db.session.scalar(select(func.count()).select_from(stmt.subquery()))
total = await run_in_threadpool(db.session.scalar, select(func.count()).select_from(stmt.subquery()))
stmt = apply_range_to_statement(stmt, after, after + first + 1)

graphql_subscriptions: list[SubscriptionInterface] = []
if is_querying_page_data(info):
subscriptions = db.session.scalars(stmt).all()
scalars = await run_in_threadpool(db.session.scalars, stmt)
subscriptions = scalars.all()
graphql_subscriptions = list(await gather_nice((format_subscription(info, p) for p in subscriptions))) # type: ignore
logger.info("Resolve subscriptions", filter_by=filter_by, total=total)

Expand Down
2 changes: 2 additions & 0 deletions orchestrator/graphql/resolvers/version.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from structlog import get_logger

from orchestrator import __version__
from orchestrator.graphql.resolvers.helpers import make_async
from orchestrator.graphql.schemas.version import VersionType
from orchestrator.graphql.types import OrchestratorInfo
from orchestrator.graphql.utils import create_resolver_error_handler
Expand All @@ -11,6 +12,7 @@
VERSIONS = [f"orchestrator-core: {__version__}"]


@make_async
def resolve_version(info: OrchestratorInfo) -> VersionType | None:
logger.debug("resolve_version() called")
_error_handler = create_resolver_error_handler(info)
Expand Down
5 changes: 3 additions & 2 deletions orchestrator/graphql/resolvers/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from orchestrator.db.sorting import Sort
from orchestrator.db.sorting.workflow import sort_workflows, workflow_sort_fields
from orchestrator.graphql.pagination import Connection
from orchestrator.graphql.resolvers.helpers import rows_from_statement
from orchestrator.graphql.resolvers.helpers import make_async, rows_from_statement
from orchestrator.graphql.schemas.workflow import Workflow
from orchestrator.graphql.types import GraphqlFilter, GraphqlSort, OrchestratorInfo
from orchestrator.graphql.utils import create_resolver_error_handler, is_querying_page_data, to_graphql_result_page
Expand All @@ -19,7 +19,8 @@
logger = structlog.get_logger(__name__)


async def resolve_workflows(
@make_async
def resolve_workflows(
info: OrchestratorInfo,
filter_by: list[GraphqlFilter] | None = None,
sort_by: list[GraphqlSort] | None = None,
Expand Down