Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@ def validate_google_oauth_config(config: Dict[str, Any]) -> List[str]:
"""Validate Google OAuth configuration and return list of errors."""
errors = []

required_fields = ['client_id', 'client_secret', 'redirect_uri']
required_fields = [
'client_id',
'client_secret',
'redirect_uri',
'client_redirect_success_url',
'client_redirect_failure_url',
]
for field in required_fields:
if not config.get(field):
errors.append(f'Missing required field: {field}')
Expand All @@ -24,6 +30,11 @@ def validate_google_oauth_config(config: Dict[str, Any]) -> List[str]:
):
errors.append('redirect_uri must be a valid HTTP/HTTPS URL')

for field in ('client_redirect_success_url', 'client_redirect_failure_url'):
value = config.get(field)
if value and not (value.startswith('http://') or value.startswith('https://')):
errors.append(f'{field} must be a valid HTTP/HTTPS URL')

# Validate scopes
scopes = config.get('scopes', [])
if not isinstance(scopes, list) or len(scopes) == 0:
Expand Down Expand Up @@ -61,6 +72,32 @@ def validate_microsoft_oauth_config(config: Dict[str, Any]) -> List[str]:
return errors


def validate_microsoft_adfs_config(config: Dict[str, Any]) -> List[str]:
"""Validate Microsoft ADFS configuration and return list of errors."""
errors = []

required_fields = ['client_id', 'client_secret', 'authority', 'redirect_uri']
for field in required_fields:
if not config.get(field):
errors.append(f'Missing required field: {field}')

authority = config.get('authority', '')
if authority and not authority.startswith('https://'):
errors.append('authority must be a valid HTTPS URL')

redirect_uri = config.get('redirect_uri')
if redirect_uri and not (
redirect_uri.startswith('http://') or redirect_uri.startswith('https://')
):
errors.append('redirect_uri must be a valid HTTP/HTTPS URL')

scopes = config.get('scopes', [])
if not isinstance(scopes, list) or len(scopes) == 0:
errors.append('scopes must be a non-empty list')

return errors
Comment thread
rootflo-hardik marked this conversation as resolved.


def validate_email_password_config(config: Dict[str, Any]) -> List[str]:
"""Validate email/password configuration and return list of errors."""
errors = []
Expand Down Expand Up @@ -130,6 +167,23 @@ def get_config_template(auth_type: str) -> Dict[str, Any]:
'response_type': 'code',
'response_mode': 'query',
},
'microsoft_adfs': {
'client_id': 'YOUR_ADFS_CLIENT_ID',
'client_secret': 'YOUR_ADFS_CLIENT_SECRET',
'authority': 'https://fs.your-domain.com',
'redirect_uri': 'https://your-domain.com/v1/oauth/adfs/callback',
'client_redirect_success_url': 'https://your-domain.com/login/success',
'client_redirect_failure_url': 'https://your-domain.com/login/failed',
'scopes': ['openid', 'profile', 'email'],
'response_type': 'code',
'response_mode': 'query',
'authorize_path': '/adfs/oauth2/authorize',
'token_path': '/adfs/oauth2/token',
'jwks_path': '/adfs/discovery/keys',
'expected_issuer': 'https://fs.your-domain.com/adfs',
'clock_skew_seconds': 60,
'verify_ssl': True,
},
}

return templates.get(auth_type, {})
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
'/floware/v1/plugin-auth/authenticate',
'/floware/v1/oauth/google/callback',
'/floware/v1/oauth/microsoft/callback',
'/floware/v1/oauth/adfs/callback',
'/floware/v1/plugin-auth/oauth/init',
'/floware/v1/settings/config',
]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
import logging
import secrets
from uuid import uuid4
from db_repo_module.models.resource import ResourceScope
from dependency_injector.wiring import inject, Provide
Expand Down Expand Up @@ -36,6 +38,53 @@

auth_plugin_router = APIRouter()

logger = logging.getLogger(__name__)

# Per-flow OAuth state lives in Redis with a short TTL. Each authorize ->
# callback round-trip is a one-shot: stored on init, consumed (deleted) on
# successful callback. Replays after consumption miss the cache and are
# rejected.
OAUTH_FLOW_TTL_SECONDS = 600
_OAUTH_FLOW_KEY_PREFIX = 'oauth:flow:'


def _store_oauth_flow(cache_manager: CacheManager, auth_id: str) -> tuple[str, str]:
"""Mint and persist an opaque state+nonce pair bound to auth_id.

`nx=True` prevents accidental overwrite if (astronomically) the same
32-byte token is minted twice.
"""
state = secrets.token_urlsafe(32)
nonce = secrets.token_urlsafe(32)
cache_manager.add(
f'{_OAUTH_FLOW_KEY_PREFIX}{state}',
json.dumps({'auth_id': str(auth_id), 'nonce': nonce}),
expiry=OAUTH_FLOW_TTL_SECONDS,
nx=True,
)
return state, nonce


def _consume_oauth_flow(
cache_manager: CacheManager, state: Optional[str]
) -> Optional[Dict[str, str]]:
"""Single-use lookup of the flow record. Returns None on miss/parse error."""
if not state:
return None
key = f'{_OAUTH_FLOW_KEY_PREFIX}{state}'
raw = cache_manager.get_str(key)
if not raw:
return None
try:
flow = json.loads(raw)
except (TypeError, ValueError):
cache_manager.remove(key)
return None
cache_manager.remove(key)
if not isinstance(flow, dict) or 'auth_id' not in flow:
return None
return flow


class UnifiedAuthRequest(BaseModel):
auth_id: str
Expand Down Expand Up @@ -204,6 +253,7 @@ async def init_oauth_flow(
authenticator_repository: SQLAlchemyRepository[Authenticator] = Depends(
Provide[PluginsContainer.authenticator_repository]
),
cache_manager: CacheManager = Depends(Provide[UserContainer.cache_manager]),
):
"""Initialize OAuth flow and return authorization URL."""

Expand All @@ -222,9 +272,10 @@ async def init_oauth_flow(
),
)

# Generate state and get authorization URL
state = json.dumps({'auth_id': oauth_request.auth_id})
auth_url = authenticator.get_authorization_url(state)
# Mint opaque CSRF state + OIDC nonce, persist server-side, and pass
# both into the provider so they end up in the authorize URL.
state, nonce = _store_oauth_flow(cache_manager, oauth_request.auth_id)
auth_url = authenticator.get_authorization_url(state, nonce=nonce)

if not auth_url:
return JSONResponse(
Expand Down Expand Up @@ -274,11 +325,13 @@ async def google_oauth_callback(
token_service: TokenService = Depends(Provide[AuthContainer.token_service]),
):
"""Handle Google OAuth callback."""
state_obj = json.loads(state)
auth_id = state_obj['auth_id']
flow = _consume_oauth_flow(cache_manager, state)
if flow is None:
logger.warning('Google OAuth callback received unknown/expired state')
return RedirectResponse(url='about:blank')

return await _handle_oauth_callback(
auth_id,
flow['auth_id'],
{'authorization_code': code, 'state': state, 'error': error},
request,
response_formatter,
Expand All @@ -288,6 +341,7 @@ async def google_oauth_callback(
session_repository,
cache_manager,
token_service,
expected_nonce=flow.get('nonce'),
)


Expand Down Expand Up @@ -317,11 +371,57 @@ async def microsoft_oauth_callback(
token_service: TokenService = Depends(Provide[AuthContainer.token_service]),
):
"""Handle Microsoft OAuth callback."""
state_obj = json.loads(state)
auth_id = state_obj['auth_id']
flow = _consume_oauth_flow(cache_manager, state)
if flow is None:
logger.warning('Microsoft OAuth callback received unknown/expired state')
return RedirectResponse(url='about:blank')

return await _handle_oauth_callback(
flow['auth_id'],
{'authorization_code': code, 'state': state, 'error': error},
request,
response_formatter,
authenticator_repository,
user_service,
user_repository,
session_repository,
cache_manager,
token_service,
expected_nonce=flow.get('nonce'),
)


@auth_plugin_router.get('/v1/oauth/adfs/callback')
@inject
async def microsoft_adfs_oauth_callback(
request: Request,
state: str = Query(...),
code: Optional[str] = Query(None),
error: Optional[str] = Query(None),
response_formatter: ResponseFormatter = Depends(
Provide[CommonContainer.response_formatter]
),
authenticator_repository: SQLAlchemyRepository[Authenticator] = Depends(
Provide[PluginsContainer.authenticator_repository]
),
user_repository: SQLAlchemyRepository[User] = Depends(
Provide[UserContainer.user_repository]
),
user_service: UserService = Depends(Provide[UserContainer.user_service]),
session_repository: SQLAlchemyRepository[Session] = Depends(
Provide[UserContainer.session_repository]
),
cache_manager: CacheManager = Depends(Provide[UserContainer.cache_manager]),
token_service: TokenService = Depends(Provide[AuthContainer.token_service]),
):
"""Handle Microsoft ADFS OAuth callback."""
flow = _consume_oauth_flow(cache_manager, state)
if flow is None:
logger.warning('Microsoft ADFS callback received unknown/expired state')
return RedirectResponse(url='about:blank')

return await _handle_oauth_callback(
auth_id,
flow['auth_id'],
{'authorization_code': code, 'state': state, 'error': error},
request,
response_formatter,
Expand All @@ -331,6 +431,7 @@ async def microsoft_oauth_callback(
session_repository,
cache_manager,
token_service,
expected_nonce=flow.get('nonce'),
)


Expand All @@ -345,6 +446,7 @@ async def _handle_oauth_callback(
session_repository: SQLAlchemyRepository[Session],
cache_manager: CacheManager,
token_service: TokenService,
expected_nonce: Optional[str] = None,
) -> RedirectResponse:
"""Common OAuth callback handler."""

Expand Down Expand Up @@ -393,7 +495,9 @@ def get_failure_redirect(error_msg: str) -> RedirectResponse:
return RedirectResponse(url='about:blank')

# Handle OAuth callback
auth_result = authenticator.handle_callback(callback_data)
auth_result = authenticator.handle_callback(
callback_data, expected_nonce=expected_nonce
)

if not auth_result.success:
if failure_url:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .email_password.config import EmailPasswordConfig
from .google_oauth.config import GoogleOAuthConfig
from .microsoft_oauth.config import MicrosoftOAuthConfig
from .microsoft_adfs.config import MicrosoftADFSConfig

__all__ = [
'AuthenticatorFactory',
Expand All @@ -24,4 +25,5 @@
'EmailPasswordConfig',
'GoogleOAuthConfig',
'MicrosoftOAuthConfig',
'MicrosoftADFSConfig',
]
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,15 @@ def validate_config(self) -> bool:
except Exception:
return False

def get_authorization_url(self, state: Optional[str] = None) -> Optional[str]:
def get_authorization_url(
self, state: Optional[str] = None, nonce: Optional[str] = None
) -> Optional[str]:
"""Email/password doesn't need authorization URL."""
return None

def handle_callback(self, callback_data: Dict[str, Any]) -> AuthResult:
def handle_callback(
self, callback_data: Dict[str, Any], expected_nonce: Optional[str] = None
) -> AuthResult:
"""Email/password doesn't use OAuth callbacks."""
return AuthResult(
success=False,
Expand Down
13 changes: 13 additions & 0 deletions wavefront/server/plugins/authenticator/authenticator/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from .email_password import EmailPasswordAuthenticator
from .google_oauth import GoogleOAuthAuthenticator
from .microsoft_oauth import MicrosoftOAuthAuthenticator
from .microsoft_adfs import MicrosoftADFSAuthenticator
from .email_password.config import EmailPasswordConfig
from .google_oauth.config import GoogleOAuthConfig
from .microsoft_oauth.config import MicrosoftOAuthConfig
from .microsoft_adfs.config import MicrosoftADFSConfig


class AuthenticatorFactory:
Expand All @@ -28,6 +30,7 @@ def __init__(self):
if not hasattr(self, '_initialized'):
self._google_instances: Dict[str, GoogleOAuthAuthenticator] = {}
self._microsoft_instances: Dict[str, MicrosoftOAuthAuthenticator] = {}
self._adfs_instances: Dict[str, MicrosoftADFSAuthenticator] = {}
self._email_instances: Dict[str, EmailPasswordAuthenticator] = {}
self._instances_lock = threading.Lock()
self._initialized = True
Expand Down Expand Up @@ -85,6 +88,8 @@ def validate_config(
return GoogleOAuthAuthenticator.validate_config_static(config)
elif auth_type == AuthenticatorType.MICROSOFT_OAUTH:
return MicrosoftOAuthAuthenticator.validate_config_static(config)
elif auth_type == AuthenticatorType.MICROSOFT_ADFS:
return MicrosoftADFSAuthenticator.validate_config_static(config)
else:
raise ValueError(f'Unsupported authenticator type: {auth_type}')

Expand Down Expand Up @@ -157,6 +162,7 @@ def get_cached_instance_count(
return (
len(self._google_instances)
+ len(self._microsoft_instances)
+ len(self._adfs_instances)
+ len(self._email_instances)
)

Expand All @@ -165,6 +171,7 @@ def clear_all_instances(self) -> None:
with self._instances_lock:
self._google_instances.clear()
self._microsoft_instances.clear()
self._adfs_instances.clear()
self._email_instances.clear()

def _get_cache_for_type(
Expand All @@ -175,6 +182,8 @@ def _get_cache_for_type(
return self._google_instances
elif auth_type == AuthenticatorType.MICROSOFT_OAUTH:
return self._microsoft_instances
elif auth_type == AuthenticatorType.MICROSOFT_ADFS:
return self._adfs_instances
elif auth_type == AuthenticatorType.EMAIL_PASSWORD:
return self._email_instances
else:
Expand All @@ -196,6 +205,10 @@ def _create_authenticator(
typed_config = MicrosoftOAuthConfig(**config)
return MicrosoftOAuthAuthenticator(typed_config)

elif auth_type == AuthenticatorType.MICROSOFT_ADFS:
typed_config = MicrosoftADFSConfig(**config)
return MicrosoftADFSAuthenticator(typed_config)

else:
raise ValueError(f'Unsupported authenticator type: {auth_type}')

Expand Down
Loading
Loading