Skip to content

Add RBAC info to process and subscription detail endpoints #990

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions orchestrator/api/api_v1/endpoints/subscriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,12 @@
subscription_workflows,
)
from orchestrator.settings import app_settings
from orchestrator.targets import Target
from orchestrator.types import SubscriptionLifecycle
from orchestrator.utils.deprecation_logger import deprecated_endpoint
from orchestrator.utils.get_subscription_dict import get_subscription_dict
from orchestrator.websocket import sync_invalidate_subscription_cache
from orchestrator.workflows import get_workflow

router = APIRouter()

Expand Down Expand Up @@ -100,6 +102,25 @@
return statuses


def _authorized_subscription_workflows(
subscription: SubscriptionTable, current_user: OIDCUserModel | None
) -> dict[str, list[dict[str, list[Any] | str]]]:
subscription_workflows_dict = subscription_workflows(subscription)

for workflow_target in Target.values():
for workflow_dict in subscription_workflows_dict[workflow_target.lower()]:
workflow = get_workflow(workflow_dict["name"])
if not workflow:
continue

Check warning on line 114 in orchestrator/api/api_v1/endpoints/subscriptions.py

View check run for this annotation

Codecov / codecov/patch

orchestrator/api/api_v1/endpoints/subscriptions.py#L114

Added line #L114 was not covered by tests
if (
not workflow.authorize_callback(current_user) # The current user isn't allowed to run this workflow
and "reason" not in workflow_dict # and there isn't already a reason why this workflow cannot run
):
workflow_dict["reason"] = "subscription.insufficient_workflow_permissions"

return subscription_workflows_dict


@router.get(
"/domain-model/{subscription_id}",
response_model=SubscriptionDomainModelSchema | None,
Expand Down Expand Up @@ -169,7 +190,9 @@
description="This endpoint is deprecated and will be removed in a future release. Please use the GraphQL query",
dependencies=[Depends(deprecated_endpoint)],
)
def subscription_workflows_by_id(subscription_id: UUID) -> dict[str, list[dict[str, list[Any] | str]]]:
def subscription_workflows_by_id(
subscription_id: UUID, current_user: OIDCUserModel | None = Depends(authenticate)
) -> dict[str, list[dict[str, list[Any] | str]]]:
subscription = db.session.get(
SubscriptionTable,
subscription_id,
Expand All @@ -181,7 +204,7 @@
if not subscription:
raise_status(HTTPStatus.NOT_FOUND)

return subscription_workflows(subscription)
return _authorized_subscription_workflows(subscription, current_user)


@router.put("/{subscription_id}/set_in_sync", response_model=None, status_code=HTTPStatus.OK)
Expand Down
17 changes: 16 additions & 1 deletion orchestrator/graphql/schemas/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@
from strawberry.scalars import JSON

from oauth2_lib.strawberry import authenticated_field
from orchestrator.api.api_v1.endpoints.processes import get_auth_callbacks, get_current_steps
from orchestrator.db import ProcessTable, ProductTable, db
from orchestrator.graphql.pagination import EMPTY_PAGE, Connection
from orchestrator.graphql.schemas.customer import CustomerType
from orchestrator.graphql.schemas.helpers import get_original_model
from orchestrator.graphql.schemas.product import ProductType
from orchestrator.graphql.types import GraphqlFilter, GraphqlSort, OrchestratorInfo
from orchestrator.graphql.types import FormUserPermissionsType, GraphqlFilter, GraphqlSort, OrchestratorInfo
from orchestrator.schemas.process import ProcessSchema, ProcessStepSchema
from orchestrator.services.processes import load_process
from orchestrator.settings import app_settings
from orchestrator.workflows import get_workflow

if TYPE_CHECKING:
from orchestrator.graphql.schemas.subscription import SubscriptionInterface
Expand Down Expand Up @@ -74,6 +77,18 @@
shortcode=app_settings.DEFAULT_CUSTOMER_SHORTCODE,
)

@strawberry.field(description="Returns user permissions for operations on this process") # type: ignore
def user_permissions(self, info: OrchestratorInfo) -> FormUserPermissionsType:
oidc_user = info.context.get_current_user
workflow = get_workflow(self.workflow_name)
process = load_process(db.session.get(ProcessTable, self.process_id)) # type: ignore[arg-type]
auth_resume, auth_retry = get_auth_callbacks(get_current_steps(process), workflow) # type: ignore[arg-type]

Check warning on line 85 in orchestrator/graphql/schemas/process.py

View check run for this annotation

Codecov / codecov/patch

orchestrator/graphql/schemas/process.py#L82-L85

Added lines #L82 - L85 were not covered by tests

return FormUserPermissionsType(

Check warning on line 87 in orchestrator/graphql/schemas/process.py

View check run for this annotation

Codecov / codecov/patch

orchestrator/graphql/schemas/process.py#L87

Added line #L87 was not covered by tests
retryAllowed=auth_retry and auth_retry(oidc_user), # type: ignore[arg-type]
resumeAllowed=auth_resume and auth_resume(oidc_user), # type: ignore[arg-type]
)

@authenticated_field(description="Returns list of subscriptions of the process") # type: ignore
async def subscriptions(
self,
Expand Down
8 changes: 7 additions & 1 deletion orchestrator/graphql/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022-2023 SURF, GÉANT.
# Copyright 2022-2025 SURF, GÉANT.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -132,6 +132,12 @@ class GraphqlFilter:
}


@strawberry.type(description="User permissions on a specific process")
class FormUserPermissionsType:
retryAllowed: bool
resumeAllowed: bool


@strawberry.type(description="Generic class to capture errors")
class MutationError:
message: str = strawberry.field(description="Error message")
Expand Down
43 changes: 39 additions & 4 deletions test/unit_tests/api/test_subscriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from nwastdlib.url import URL
from oauth2_lib.fastapi import OIDCUserModel
from orchestrator.api.helpers import product_block_paths
from orchestrator.db import (
FixedInputTable,
Expand All @@ -29,7 +30,7 @@
unsync,
)
from orchestrator.targets import Target
from orchestrator.workflow import ProcessStatus
from orchestrator.workflow import ProcessStatus, done, init, workflow
from test.unit_tests.config import (
IMS_CIRCUIT_ID,
INTERNETPINNEN_PREFIX_SUBSCRIPTION_ID,
Expand All @@ -39,6 +40,7 @@
PORT_SUBSCRIPTION_ID,
)
from test.unit_tests.conftest import do_refresh_subscriptions_search_view
from test.unit_tests.workflows import WorkflowInstanceForTests

SERVICE_SUBSCRIPTION_ID = str(uuid4())
PORT_A_SUBSCRIPTION_ID = str(uuid4())
Expand All @@ -50,7 +52,8 @@
INVALID_SUBSCRIPTION_ID = str(uuid4())
INVALID_PORT_SUBSCRIPTION_ID = str(uuid4())

PRODUCT_ID = str(uuid4())
PORT_A_PRODUCT_ID = str(uuid4())
PORT_B_PRODUCT_ID = str(uuid4())
CUSTOMER_ID = str(uuid4())


Expand All @@ -77,7 +80,7 @@ def seed():
fixed_inputs=fixed_inputs,
)
port_a_product = ProductTable(
product_id=PRODUCT_ID,
product_id=PORT_A_PRODUCT_ID,
name="PortAProduct",
description="Port A description",
product_type="Port",
Expand All @@ -87,6 +90,7 @@ def seed():
fixed_inputs=fixed_inputs,
)
port_b_product = ProductTable(
product_id=PORT_B_PRODUCT_ID,
name="PortBProduct",
description="Port B description",
product_type="Port",
Expand Down Expand Up @@ -280,7 +284,7 @@ def seed_with_direct_relations():
fixed_inputs=fixed_inputs,
)
port_a_product = ProductTable(
product_id=PRODUCT_ID,
product_id=PORT_A_PRODUCT_ID,
name="PortAProduct",
description="Port A description",
product_type="Port",
Expand Down Expand Up @@ -844,3 +848,34 @@ def test_subscription_detail_with_in_use_by_ids_not_filtered_self(test_client, p
)
assert response.status_code == HTTPStatus.OK
assert response.json()["block"]["sub_block"]["in_use_by_ids"]


@pytest.mark.parametrize(
"test_input",
[
(PORT_A_PRODUCT_ID, PORT_A_SUBSCRIPTION_ID, "subscription.no_modify_invalid_status"),
(PORT_B_PRODUCT_ID, SSP_SUBSCRIPTION_ID, "subscription.insufficient_workflow_permissions"),
],
)
def test_subscription_detail_with_forbidden_workflow_without_override(seed, test_client, test_input):
product_id, subscription_id, expected_error = test_input

def disallow(_: OIDCUserModel | None = None) -> bool:
return False

@workflow("unauthorized_workflow", target=Target.MODIFY, authorize_callback=disallow)
def unauthorized_workflow():
return init >> done

with WorkflowInstanceForTests(unauthorized_workflow, "unauthorized_workflow") as wf:
product = db.session.get(ProductTable, product_id)
product.workflows.append(wf)
db.session.commit()

response = test_client.get(f"/api/subscriptions/workflows/{subscription_id}")
assert response.status_code == HTTPStatus.OK

subscription_workflows = response.json()
assert len(subscription_workflows["modify"]) == 1
assert "reason" in subscription_workflows["modify"][0]
assert subscription_workflows["modify"][0]["reason"] == expected_error
Loading