Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions src/authsome/server/credential_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import importlib.resources
import json
import os
from datetime import timedelta
from typing import Any
from urllib.parse import urlparse
Expand Down Expand Up @@ -62,6 +63,15 @@
}


def is_admin_principal(principal_id: str | None) -> bool:
"""Return whether a principal is listed in AUTHSOME_ADMIN_PRINCIPALS."""
if not principal_id:
return False
raw = os.environ.get("AUTHSOME_ADMIN_PRINCIPALS", "")
principals = {item.strip() for item in raw.split(",") if item.strip()}
return principal_id in principals


class AuthService:
"""
Authentication and credential lifecycle service.
Expand Down Expand Up @@ -203,6 +213,8 @@ async def remove_provider(self, name: str) -> bool:
return await self._vault.delete(name, collection="providers")

def _ensure_local_provider_admin_operation_allowed(self, operation: str, provider: str) -> None:
if is_admin_principal(self._principal_id):
return
if self._deployment_mode == "hosted":
raise OperationNotAllowedError(
operation,
Expand All @@ -211,6 +223,8 @@ def _ensure_local_provider_admin_operation_allowed(self, operation: str, provide
)

def _ensure_provider_client_mutation_allowed(self, provider: str) -> None:
if is_admin_principal(self._principal_id):
return
if self._deployment_mode == "hosted":
raise OperationNotAllowedError(
"login",
Expand Down Expand Up @@ -377,7 +391,13 @@ async def update_provider_configuration(
updated.base_url = existing.base_url if existing else None
updated.api_url = existing.api_url if existing else None

updated.scopes = list(existing.scopes) if existing and existing.scopes is not None else None
if "scopes" in inputs:
scopes_input = inputs["scopes"].strip()
updated.scopes = [s.strip() for s in scopes_input.split(",") if s.strip()] if scopes_input else []
elif existing and existing.scopes is not None:
updated.scopes = list(existing.scopes)
else:
updated.scopes = list(definition.oauth.scopes or []) if definition.oauth else []
updated.metadata = dict(existing.metadata) if existing else {}

changed = existing is None or any(
Expand All @@ -386,6 +406,7 @@ async def update_provider_configuration(
existing.client_secret != updated.client_secret,
existing.base_url != updated.base_url,
existing.api_url != updated.api_url,
existing.scopes != updated.scopes,
)
)
if not changed:
Expand Down Expand Up @@ -468,7 +489,7 @@ async def get_required_inputs(
default=flow_client_id or "",
)
)
fields.append(InputField(name="client_secret", label="Client Secret (Optional)", secret=True, default=""))
fields.append(InputField(name="client_secret", label="Client Secret", secret=True, default=""))
elif flow_type == FlowType.DEVICE_CODE and (provider_config_only or not flow_client_id):
fields.append(
InputField(
Expand Down
43 changes: 30 additions & 13 deletions src/authsome/server/routes/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from authsome.auth.models.enums import AuthType, FlowType
from authsome.auth.models.provider import ProviderDefinition
from authsome.auth.sessions import AuthSession, AuthSessionStore
from authsome.server.credential_service import AuthService
from authsome.server.credential_service import AuthService, is_admin_principal
from authsome.server.dependencies import (
create_principal_vault_binding_registry,
create_vault_registry,
Expand Down Expand Up @@ -79,12 +79,18 @@ def _ui_cookie_secure(server_base_url: str) -> bool:
return server_base_url.startswith("https://")


def _ui_policy() -> dict[str, Any]:
def _ui_policy(request: Request, auth: AuthService | None = None) -> dict[str, Any]:
hosted = _is_hosted_ui()
principal_id = auth.principal_id if auth is not None else getattr(request.state, "ui_principal_id", None)
show_provider_client_details = not hosted or is_admin_principal(principal_id)
return {
"ui_mode": "hosted" if hosted else "local",
"show_provider_client_details": not hosted,
"provider_management_label": "OAuth application managed by Authsome" if hosted else "OAuth Application",
"show_provider_client_details": show_provider_client_details,
"provider_management_label": (
"OAuth application managed by Authsome"
if hosted and not show_provider_client_details
else "OAuth Application"
),
"show_hosted_identity": hosted,
}

Expand Down Expand Up @@ -188,13 +194,13 @@ def _hosted_auth_page_response(
return HTMLResponse(page, status_code=400 if error else 200)


def _page_context(request: Request, page: str, **kwargs: Any) -> dict[str, Any]:
def _page_context(request: Request, page: str, *, auth: AuthService | None = None, **kwargs: Any) -> dict[str, Any]:
return {
"page": page,
"version": __version__,
"ui_identity": getattr(request.state, "ui_identity", None),
"ui_email": getattr(request.state, "ui_email", None),
**_ui_policy(),
**_ui_policy(request, auth),
**kwargs,
}

Expand Down Expand Up @@ -259,7 +265,7 @@ def _build_provider_view(
async def _provider_connection_groups(
request: Request,
*,
identity: str,
identity: str | None,
principal_id: str | None,
provider_name: str,
) -> list[dict[str, Any]]:
Expand Down Expand Up @@ -310,6 +316,7 @@ async def _provider_connection_groups(

def _provider_page_context(
request: Request,
auth: AuthService,
provider: ProviderDefinition,
api_url: str,
*,
Expand All @@ -319,10 +326,11 @@ def _provider_page_context(
auth_url: str | None,
token_url: str | None,
) -> dict[str, Any]:
policy = _ui_policy()
policy = _ui_policy(request, auth)
return _page_context(
request,
"applications",
auth=auth,
provider=provider,
connection=None,
grouped_connections=grouped_connections,
Expand All @@ -346,13 +354,15 @@ def _provider_page_context(

def _connection_detail_context(
request: Request,
auth: AuthService,
provider: ProviderDefinition,
connection_record: ConnectionRecord,
api_url: str,
) -> dict[str, Any]:
return _page_context(
request,
"connections",
auth=auth,
provider=provider,
connection=connection_record,
logo_initial=_logo_initial(provider.display_name or provider.name),
Expand Down Expand Up @@ -425,6 +435,7 @@ async def overview(
_page_context(
request,
"overview",
auth=auth,
stats={
"connected": len(connected),
"available": available_count,
Expand Down Expand Up @@ -454,7 +465,7 @@ async def applications(
return templates.TemplateResponse(
request,
"applications.html",
_page_context(request, "applications", providers=providers),
_page_context(request, "applications", auth=auth, providers=providers),
)


Expand All @@ -471,6 +482,7 @@ async def connections(
_page_context(
request,
"connections",
auth=auth,
connection_rows=rows,
total_connections=len(rows),
),
Expand All @@ -493,6 +505,7 @@ async def identity_page(
_page_context(
request,
"identity",
auth=auth,
identities=identities,
principal_id=auth.principal_id,
),
Expand All @@ -509,13 +522,15 @@ async def app_detail(
provider = await auth.get_provider(provider_name)
redirect_uri = build_callback_url(server_base_url)
api_url = provider.api_url or (provider.oauth.base_url if provider.oauth else None) or provider.name
if _is_hosted_ui():
policy = _ui_policy(request, auth)
if not policy["show_provider_client_details"] and _is_hosted_ui():
return templates.TemplateResponse(
request,
"app_detail_managed.html",
_page_context(
request,
"applications",
auth=auth,
provider=provider,
logo_initial=_logo_initial(provider.display_name or provider.name),
),
Expand All @@ -524,7 +539,7 @@ async def app_detail(
client_record = await auth.get_provider_client(provider_name)
grouped_connections = await _provider_connection_groups(
request,
identity=auth.require_identity(),
identity=auth.identity,
principal_id=auth.principal_id,
provider_name=provider_name,
)
Expand All @@ -533,6 +548,7 @@ async def app_detail(
"app_provider.html",
_provider_page_context(
request,
auth,
provider,
api_url,
grouped_connections=grouped_connections,
Expand All @@ -554,7 +570,7 @@ async def connection_detail(
provider = await auth.get_provider(provider_name)
connection_record = await auth.get_connection(provider_name, connection_name)
api_url = provider.api_url or (provider.oauth.base_url if provider.oauth else None) or provider.name
common = _connection_detail_context(request, provider, connection_record, api_url)
common = _connection_detail_context(request, auth, provider, connection_record, api_url)

if provider.auth_type == AuthType.OAUTH2:
return templates.TemplateResponse(
Expand Down Expand Up @@ -664,7 +680,8 @@ async def configure_provider(
) -> Response:
"""Open the provider configuration flow for deployment-scoped credentials."""
provider = await auth.get_provider(provider_name)
if provider.auth_type != AuthType.OAUTH2 or _is_hosted_ui():
policy = _ui_policy(request, auth)
if provider.auth_type != AuthType.OAUTH2 or (not policy["show_provider_client_details"] and _is_hosted_ui()):
return _redirect(request, f"/ui/apps/{provider_name}")

session = await sessions.create(
Expand Down
2 changes: 1 addition & 1 deletion src/authsome/server/ui/pages.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def input_page(
optional_rows = []
for field in fields:
row = _field_row(field)
if field.get("default") is None or field.get("name") == "client_secret":
if field.get("default") is None or field.get("name") in {"client_id", "client_secret"}:
required_rows.append(row)
else:
optional_rows.append(row)
Expand Down
93 changes: 93 additions & 0 deletions tests/auth/test_service_provider_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,99 @@ async def test_get_required_inputs_skips_scope_prompt_when_server_scopes_exist()
assert all(field.name != "scopes" for field in fields)


@pytest.mark.asyncio
async def test_pkce_client_credentials_prompt_id_then_secret() -> None:
vault = mock.AsyncMock()
vault.get.return_value = None
service = AuthService(vault, identity="steady-wisely-boldly-0042")
session = _make_session(flow_type=FlowType.PKCE)

with mock.patch.object(service, "get_provider", new=mock.AsyncMock(return_value=_make_provider())):
fields = await service.get_required_inputs(session)

credential_fields = [field for field in fields if field.name in {"client_id", "client_secret"}]
assert [field.name for field in credential_fields] == ["client_id", "client_secret"]
assert credential_fields[1].label == "Client Secret"


@pytest.mark.asyncio
async def test_update_provider_configuration_persists_default_scopes_when_omitted() -> None:
vault = mock.AsyncMock()
vault.get.return_value = None
service = AuthService(vault, identity="steady-wisely-boldly-0042")

with mock.patch.object(service, "get_provider", new=mock.AsyncMock(return_value=_make_provider())):
changed = await service.update_provider_configuration(
"github",
{"client_id": "cid", "client_secret": "secret"},
)

assert changed is True
vault.put.assert_awaited_once()
saved = ProviderClientRecord.model_validate_json(vault.put.await_args.args[1])
assert saved.client_id == "cid"
assert saved.client_secret == "secret"
assert saved.scopes == ["repo"]


@pytest.mark.asyncio
async def test_update_provider_configuration_persists_submitted_scopes() -> None:
vault = mock.AsyncMock()
vault.get.return_value = None
service = AuthService(vault, identity="steady-wisely-boldly-0042")

with mock.patch.object(service, "get_provider", new=mock.AsyncMock(return_value=_make_provider())):
changed = await service.update_provider_configuration(
"github",
{"client_id": "cid", "client_secret": "secret", "scopes": "repo,read:user"},
)

assert changed is True
saved = ProviderClientRecord.model_validate_json(vault.put.await_args.args[1])
assert saved.scopes == ["repo", "read:user"]


@pytest.mark.asyncio
async def test_hosted_admin_provider_config_satisfies_next_identity_login(monkeypatch: pytest.MonkeyPatch) -> None:
store: dict[tuple[str, str], str] = {}
vault = mock.AsyncMock()

async def get_value(key: str, *, collection: str) -> str | None:
return store.get((collection, key))

async def put_value(key: str, value: str, *, collection: str) -> None:
store[(collection, key)] = value

vault.get.side_effect = get_value
vault.put.side_effect = put_value
monkeypatch.setenv("AUTHSOME_ADMIN_PRINCIPALS", "principal_admin")

admin_service = AuthService(
vault,
identity=None,
principal_id="principal_admin",
deployment_mode="hosted",
)
identity_service = AuthService(
vault,
identity="steady-wisely-boldly-0042",
principal_id="principal_user",
deployment_mode="hosted",
)
provider = _make_provider()

with mock.patch.object(admin_service, "get_provider", new=mock.AsyncMock(return_value=provider)):
await admin_service.update_provider_configuration(
"github",
{"client_id": "cid", "client_secret": "secret"},
)

with mock.patch.object(identity_service, "get_provider", new=mock.AsyncMock(return_value=provider)):
fields = await identity_service.get_required_inputs(_make_session(flow_type=FlowType.PKCE))

assert fields == []


@pytest.mark.asyncio
async def test_begin_login_flow_reuses_server_scopes() -> None:
vault = mock.AsyncMock()
Expand Down
2 changes: 2 additions & 0 deletions tests/server/test_pop_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def test_health_and_ready_report_encryption_details(monkeypatch, tmp_path: Path)

def test_rekey_rotates_local_vault(monkeypatch, tmp_path: Path) -> None:
monkeypatch.setenv("AUTHSOME_HOME", str(tmp_path))
monkeypatch.delenv("AUTHSOME_DEPLOYMENT_MODE", raising=False)
identity = create_identity(tmp_path, "steady-wisely-boldly-0042")

with TestClient(create_app()) as client:
Expand All @@ -110,6 +111,7 @@ def test_rekey_rotates_local_vault(monkeypatch, tmp_path: Path) -> None:
def test_rekey_rejects_env_master_key(monkeypatch, tmp_path: Path) -> None:
monkeypatch.setenv("AUTHSOME_HOME", str(tmp_path))
monkeypatch.setenv("AUTHSOME_MASTER_KEY", base64.b64encode(b"\x03" * 32).decode("ascii"))
monkeypatch.delenv("AUTHSOME_DEPLOYMENT_MODE", raising=False)
identity = create_identity(tmp_path, "steady-wisely-boldly-0042")

with TestClient(create_app()) as client:
Expand Down
Loading
Loading