diff --git a/wavefront/server/modules/plugins_module/plugins_module/utils/authenticator_helper.py b/wavefront/server/modules/plugins_module/plugins_module/utils/authenticator_helper.py index 93e31927..1a93a647 100644 --- a/wavefront/server/modules/plugins_module/plugins_module/utils/authenticator_helper.py +++ b/wavefront/server/modules/plugins_module/plugins_module/utils/authenticator_helper.py @@ -12,7 +12,13 @@ def validate_google_oauth_config(config: Dict[str, Any]) -> List[str]: """Validate Google OAuth configuration and return list of errors.""" errors = [] - required_fields = ['client_id', 'client_secret', 'redirect_uri'] + required_fields = [ + 'client_id', + 'client_secret', + 'redirect_uri', + 'client_redirect_success_url', + 'client_redirect_failure_url', + ] for field in required_fields: if not config.get(field): errors.append(f'Missing required field: {field}') @@ -24,6 +30,11 @@ def validate_google_oauth_config(config: Dict[str, Any]) -> List[str]: ): errors.append('redirect_uri must be a valid HTTP/HTTPS URL') + for field in ('client_redirect_success_url', 'client_redirect_failure_url'): + value = config.get(field) + if value and not (value.startswith('http://') or value.startswith('https://')): + errors.append(f'{field} must be a valid HTTP/HTTPS URL') + # Validate scopes scopes = config.get('scopes', []) if not isinstance(scopes, list) or len(scopes) == 0: @@ -61,6 +72,32 @@ def validate_microsoft_oauth_config(config: Dict[str, Any]) -> List[str]: return errors +def validate_microsoft_adfs_config(config: Dict[str, Any]) -> List[str]: + """Validate Microsoft ADFS configuration and return list of errors.""" + errors = [] + + required_fields = ['client_id', 'client_secret', 'authority', 'redirect_uri'] + for field in required_fields: + if not config.get(field): + errors.append(f'Missing required field: {field}') + + authority = config.get('authority', '') + if authority and not authority.startswith('https://'): + errors.append('authority must be a valid HTTPS URL') + + redirect_uri = config.get('redirect_uri') + if redirect_uri and not ( + redirect_uri.startswith('http://') or redirect_uri.startswith('https://') + ): + errors.append('redirect_uri must be a valid HTTP/HTTPS URL') + + scopes = config.get('scopes', []) + if not isinstance(scopes, list) or len(scopes) == 0: + errors.append('scopes must be a non-empty list') + + return errors + + def validate_email_password_config(config: Dict[str, Any]) -> List[str]: """Validate email/password configuration and return list of errors.""" errors = [] @@ -130,6 +167,23 @@ def get_config_template(auth_type: str) -> Dict[str, Any]: 'response_type': 'code', 'response_mode': 'query', }, + 'microsoft_adfs': { + 'client_id': 'YOUR_ADFS_CLIENT_ID', + 'client_secret': 'YOUR_ADFS_CLIENT_SECRET', + 'authority': 'https://fs.your-domain.com', + 'redirect_uri': 'https://your-domain.com/v1/oauth/adfs/callback', + 'client_redirect_success_url': 'https://your-domain.com/login/success', + 'client_redirect_failure_url': 'https://your-domain.com/login/failed', + 'scopes': ['openid', 'profile', 'email'], + 'response_type': 'code', + 'response_mode': 'query', + 'authorize_path': '/adfs/oauth2/authorize', + 'token_path': '/adfs/oauth2/token', + 'jwks_path': '/adfs/discovery/keys', + 'expected_issuer': 'https://fs.your-domain.com/adfs', + 'clock_skew_seconds': 60, + 'verify_ssl': True, + }, } return templates.get(auth_type, {}) diff --git a/wavefront/server/modules/user_management_module/user_management_module/authorization/require_auth.py b/wavefront/server/modules/user_management_module/user_management_module/authorization/require_auth.py index ff7629d3..8c73d3e4 100644 --- a/wavefront/server/modules/user_management_module/user_management_module/authorization/require_auth.py +++ b/wavefront/server/modules/user_management_module/user_management_module/authorization/require_auth.py @@ -45,6 +45,7 @@ '/floware/v1/plugin-auth/authenticate', '/floware/v1/oauth/google/callback', '/floware/v1/oauth/microsoft/callback', + '/floware/v1/oauth/adfs/callback', '/floware/v1/plugin-auth/oauth/init', '/floware/v1/settings/config', ] diff --git a/wavefront/server/modules/user_management_module/user_management_module/controllers/auth_plugin_controller.py b/wavefront/server/modules/user_management_module/user_management_module/controllers/auth_plugin_controller.py index f634ff00..a9cb712f 100644 --- a/wavefront/server/modules/user_management_module/user_management_module/controllers/auth_plugin_controller.py +++ b/wavefront/server/modules/user_management_module/user_management_module/controllers/auth_plugin_controller.py @@ -1,4 +1,6 @@ import json +import logging +import secrets from uuid import uuid4 from db_repo_module.models.resource import ResourceScope from dependency_injector.wiring import inject, Provide @@ -36,6 +38,53 @@ auth_plugin_router = APIRouter() +logger = logging.getLogger(__name__) + +# Per-flow OAuth state lives in Redis with a short TTL. Each authorize -> +# callback round-trip is a one-shot: stored on init, consumed (deleted) on +# successful callback. Replays after consumption miss the cache and are +# rejected. +OAUTH_FLOW_TTL_SECONDS = 600 +_OAUTH_FLOW_KEY_PREFIX = 'oauth:flow:' + + +def _store_oauth_flow(cache_manager: CacheManager, auth_id: str) -> tuple[str, str]: + """Mint and persist an opaque state+nonce pair bound to auth_id. + + `nx=True` prevents accidental overwrite if (astronomically) the same + 32-byte token is minted twice. + """ + state = secrets.token_urlsafe(32) + nonce = secrets.token_urlsafe(32) + cache_manager.add( + f'{_OAUTH_FLOW_KEY_PREFIX}{state}', + json.dumps({'auth_id': str(auth_id), 'nonce': nonce}), + expiry=OAUTH_FLOW_TTL_SECONDS, + nx=True, + ) + return state, nonce + + +def _consume_oauth_flow( + cache_manager: CacheManager, state: Optional[str] +) -> Optional[Dict[str, str]]: + """Single-use lookup of the flow record. Returns None on miss/parse error.""" + if not state: + return None + key = f'{_OAUTH_FLOW_KEY_PREFIX}{state}' + raw = cache_manager.get_str(key) + if not raw: + return None + try: + flow = json.loads(raw) + except (TypeError, ValueError): + cache_manager.remove(key) + return None + cache_manager.remove(key) + if not isinstance(flow, dict) or 'auth_id' not in flow: + return None + return flow + class UnifiedAuthRequest(BaseModel): auth_id: str @@ -204,10 +253,12 @@ async def init_oauth_flow( authenticator_repository: SQLAlchemyRepository[Authenticator] = Depends( Provide[PluginsContainer.authenticator_repository] ), + cache_manager: CacheManager = Depends(Provide[UserContainer.cache_manager]), ): """Initialize OAuth flow and return authorization URL.""" try: + logger.debug('OAuth init requested for auth_id=%s', oauth_request.auth_id) # Get authenticator instance by ID auth_id = UUID(oauth_request.auth_id) authenticator = await get_authenticator_instance( @@ -215,6 +266,10 @@ async def init_oauth_flow( ) if not authenticator: + logger.debug( + 'OAuth init: no enabled authenticator for auth_id=%s', + oauth_request.auth_id, + ) return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=response_formatter.buildErrorResponse( @@ -222,9 +277,17 @@ async def init_oauth_flow( ), ) - # Generate state and get authorization URL - state = json.dumps({'auth_id': oauth_request.auth_id}) - auth_url = authenticator.get_authorization_url(state) + # Mint opaque CSRF state + OIDC nonce, persist server-side, and pass + # both into the provider so they end up in the authorize URL. + state, nonce = _store_oauth_flow(cache_manager, oauth_request.auth_id) + logger.debug( + 'OAuth flow stored: auth_id=%s state=%s nonce=%s ttl=%ss', + oauth_request.auth_id, + state, + nonce, + OAUTH_FLOW_TTL_SECONDS, + ) + auth_url = authenticator.get_authorization_url(state, nonce=nonce) if not auth_url: return JSONResponse( @@ -274,11 +337,20 @@ async def google_oauth_callback( token_service: TokenService = Depends(Provide[AuthContainer.token_service]), ): """Handle Google OAuth callback.""" - state_obj = json.loads(state) - auth_id = state_obj['auth_id'] + logger.debug( + 'Google OAuth callback: has_code=%s has_state=%s has_error=%s state=%s', + bool(code), + bool(state), + bool(error), + state, + ) + flow = _consume_oauth_flow(cache_manager, state) + if flow is None: + logger.warning('Google OAuth callback received unknown/expired state') + return RedirectResponse(url='about:blank') return await _handle_oauth_callback( - auth_id, + flow['auth_id'], {'authorization_code': code, 'state': state, 'error': error}, request, response_formatter, @@ -288,6 +360,7 @@ async def google_oauth_callback( session_repository, cache_manager, token_service, + expected_nonce=flow.get('nonce'), ) @@ -317,11 +390,71 @@ async def microsoft_oauth_callback( token_service: TokenService = Depends(Provide[AuthContainer.token_service]), ): """Handle Microsoft OAuth callback.""" - state_obj = json.loads(state) - auth_id = state_obj['auth_id'] + logger.debug( + 'Microsoft OAuth callback: has_code=%s has_state=%s has_error=%s state=%s', + bool(code), + bool(state), + bool(error), + state, + ) + flow = _consume_oauth_flow(cache_manager, state) + if flow is None: + logger.warning('Microsoft OAuth callback received unknown/expired state') + return RedirectResponse(url='about:blank') + + return await _handle_oauth_callback( + flow['auth_id'], + {'authorization_code': code, 'state': state, 'error': error}, + request, + response_formatter, + authenticator_repository, + user_service, + user_repository, + session_repository, + cache_manager, + token_service, + expected_nonce=flow.get('nonce'), + ) + + +@auth_plugin_router.get('/v1/oauth/adfs/callback') +@inject +async def microsoft_adfs_oauth_callback( + request: Request, + state: str = Query(...), + code: Optional[str] = Query(None), + error: Optional[str] = Query(None), + response_formatter: ResponseFormatter = Depends( + Provide[CommonContainer.response_formatter] + ), + authenticator_repository: SQLAlchemyRepository[Authenticator] = Depends( + Provide[PluginsContainer.authenticator_repository] + ), + user_repository: SQLAlchemyRepository[User] = Depends( + Provide[UserContainer.user_repository] + ), + user_service: UserService = Depends(Provide[UserContainer.user_service]), + session_repository: SQLAlchemyRepository[Session] = Depends( + Provide[UserContainer.session_repository] + ), + cache_manager: CacheManager = Depends(Provide[UserContainer.cache_manager]), + token_service: TokenService = Depends(Provide[AuthContainer.token_service]), +): + """Handle Microsoft ADFS OAuth callback.""" + logger.debug( + 'Microsoft ADFS callback: has_code=%s has_state=%s has_error=%s state=%s', + bool(code), + bool(state), + bool(error), + state, + ) + flow = _consume_oauth_flow(cache_manager, state) + if flow is None: + logger.warning('Microsoft ADFS callback received unknown/expired state') + return RedirectResponse(url='about:blank') return await _handle_oauth_callback( - auth_id, + flow['auth_id'], {'authorization_code': code, 'state': state, 'error': error}, request, response_formatter, @@ -331,6 +464,7 @@ async def microsoft_oauth_callback( session_repository, cache_manager, token_service, + expected_nonce=flow.get('nonce'), ) @@ -345,10 +479,19 @@ async def _handle_oauth_callback( session_repository: SQLAlchemyRepository[Session], cache_manager: CacheManager, token_service: TokenService, + expected_nonce: Optional[str] = None, ) -> RedirectResponse: """Common OAuth callback handler.""" try: + logger.debug( + '_handle_oauth_callback: auth_id=%s has_code=%s has_error=%s ' + 'expected_nonce_set=%s', + auth_id, + bool(callback_data.get('authorization_code')), + bool(callback_data.get('error')), + expected_nonce is not None, + ) # Get authenticator instance and config auth_uuid = UUID(auth_id) authenticator, config_data = await get_authenticator_with_config( @@ -379,6 +522,12 @@ def get_failure_redirect(error_msg: str) -> RedirectResponse: provider = config_data.get('auth_type') success_url = config_data.get('config', {}).get('client_redirect_success_url') failure_url = config_data.get('config', {}).get('client_redirect_failure_url') + logger.debug( + '_handle_oauth_callback: provider=%s success_url=%s failure_url=%s', + provider, + success_url, + failure_url, + ) # Handle OAuth error from provider if callback_data.get('error'): @@ -393,7 +542,16 @@ def get_failure_redirect(error_msg: str) -> RedirectResponse: return RedirectResponse(url='about:blank') # Handle OAuth callback - auth_result = authenticator.handle_callback(callback_data) + auth_result = authenticator.handle_callback( + callback_data, expected_nonce=expected_nonce + ) + logger.debug( + '_handle_oauth_callback: provider auth_result success=%s error_code=%s ' + 'email=%s', + auth_result.success, + auth_result.error_code, + auth_result.user_info.email if auth_result.user_info else None, + ) if not auth_result.success: if failure_url: @@ -408,6 +566,12 @@ def get_failure_redirect(error_msg: str) -> RedirectResponse: # Create session from auth result user = await user_repository.find_one(email=auth_result.user_info.email) + logger.debug( + '_handle_oauth_callback: user lookup by email=%s found=%s deleted=%s', + auth_result.user_info.email, + user is not None, + getattr(user, 'deleted', None), + ) if user is None: if failure_url: params = urlencode( @@ -448,6 +612,11 @@ def get_failure_redirect(error_msg: str) -> RedirectResponse: role_id = await user_service.get_user_role_for_scope( user_id=str(user.id), scope=ResourceScope.CONSOLE ) + logger.debug( + '_handle_oauth_callback: console role lookup user_id=%s role_id=%s', + str(user.id), + role_id, + ) if not role_id: if failure_url: @@ -468,12 +637,24 @@ def get_failure_redirect(error_msg: str) -> RedirectResponse: # Success: redirect to success URL with access token if success_url: + logger.debug( + '_handle_oauth_callback: success redirect provider=%s user_id=%s ' + 'session_id=%s -> %s', + provider, + str(user.id), + str(session.id), + success_url, + ) params = urlencode({'provider': provider, 'access_token': token}) return RedirectResponse(url=f'{success_url}?{params}') + logger.debug( + '_handle_oauth_callback: no success_url configured, redirecting to about:blank' + ) return RedirectResponse(url='about:blank') except Exception as e: + logger.debug('_handle_oauth_callback raised: %s', e) # Try to get config for failure URL try: auth_uuid = UUID(auth_id) diff --git a/wavefront/server/plugins/authenticator/authenticator/__init__.py b/wavefront/server/plugins/authenticator/authenticator/__init__.py index d87f4e2d..62b9324b 100644 --- a/wavefront/server/plugins/authenticator/authenticator/__init__.py +++ b/wavefront/server/plugins/authenticator/authenticator/__init__.py @@ -11,6 +11,7 @@ from .email_password.config import EmailPasswordConfig from .google_oauth.config import GoogleOAuthConfig from .microsoft_oauth.config import MicrosoftOAuthConfig +from .microsoft_adfs.config import MicrosoftADFSConfig __all__ = [ 'AuthenticatorFactory', @@ -24,4 +25,5 @@ 'EmailPasswordConfig', 'GoogleOAuthConfig', 'MicrosoftOAuthConfig', + 'MicrosoftADFSConfig', ] diff --git a/wavefront/server/plugins/authenticator/authenticator/email_password/authenticator.py b/wavefront/server/plugins/authenticator/authenticator/email_password/authenticator.py index 76c77fbd..71c4a1d5 100644 --- a/wavefront/server/plugins/authenticator/authenticator/email_password/authenticator.py +++ b/wavefront/server/plugins/authenticator/authenticator/email_password/authenticator.py @@ -116,11 +116,15 @@ def validate_config(self) -> bool: except Exception: return False - def get_authorization_url(self, state: Optional[str] = None) -> Optional[str]: + def get_authorization_url( + self, state: Optional[str] = None, nonce: Optional[str] = None + ) -> Optional[str]: """Email/password doesn't need authorization URL.""" return None - def handle_callback(self, callback_data: Dict[str, Any]) -> AuthResult: + def handle_callback( + self, callback_data: Dict[str, Any], expected_nonce: Optional[str] = None + ) -> AuthResult: """Email/password doesn't use OAuth callbacks.""" return AuthResult( success=False, diff --git a/wavefront/server/plugins/authenticator/authenticator/factory.py b/wavefront/server/plugins/authenticator/authenticator/factory.py index b33bd3c3..d6c46c1b 100644 --- a/wavefront/server/plugins/authenticator/authenticator/factory.py +++ b/wavefront/server/plugins/authenticator/authenticator/factory.py @@ -5,9 +5,11 @@ from .email_password import EmailPasswordAuthenticator from .google_oauth import GoogleOAuthAuthenticator from .microsoft_oauth import MicrosoftOAuthAuthenticator +from .microsoft_adfs import MicrosoftADFSAuthenticator from .email_password.config import EmailPasswordConfig from .google_oauth.config import GoogleOAuthConfig from .microsoft_oauth.config import MicrosoftOAuthConfig +from .microsoft_adfs.config import MicrosoftADFSConfig class AuthenticatorFactory: @@ -28,6 +30,7 @@ def __init__(self): if not hasattr(self, '_initialized'): self._google_instances: Dict[str, GoogleOAuthAuthenticator] = {} self._microsoft_instances: Dict[str, MicrosoftOAuthAuthenticator] = {} + self._adfs_instances: Dict[str, MicrosoftADFSAuthenticator] = {} self._email_instances: Dict[str, EmailPasswordAuthenticator] = {} self._instances_lock = threading.Lock() self._initialized = True @@ -85,6 +88,8 @@ def validate_config( return GoogleOAuthAuthenticator.validate_config_static(config) elif auth_type == AuthenticatorType.MICROSOFT_OAUTH: return MicrosoftOAuthAuthenticator.validate_config_static(config) + elif auth_type == AuthenticatorType.MICROSOFT_ADFS: + return MicrosoftADFSAuthenticator.validate_config_static(config) else: raise ValueError(f'Unsupported authenticator type: {auth_type}') @@ -157,6 +162,7 @@ def get_cached_instance_count( return ( len(self._google_instances) + len(self._microsoft_instances) + + len(self._adfs_instances) + len(self._email_instances) ) @@ -165,6 +171,7 @@ def clear_all_instances(self) -> None: with self._instances_lock: self._google_instances.clear() self._microsoft_instances.clear() + self._adfs_instances.clear() self._email_instances.clear() def _get_cache_for_type( @@ -175,6 +182,8 @@ def _get_cache_for_type( return self._google_instances elif auth_type == AuthenticatorType.MICROSOFT_OAUTH: return self._microsoft_instances + elif auth_type == AuthenticatorType.MICROSOFT_ADFS: + return self._adfs_instances elif auth_type == AuthenticatorType.EMAIL_PASSWORD: return self._email_instances else: @@ -196,6 +205,10 @@ def _create_authenticator( typed_config = MicrosoftOAuthConfig(**config) return MicrosoftOAuthAuthenticator(typed_config) + elif auth_type == AuthenticatorType.MICROSOFT_ADFS: + typed_config = MicrosoftADFSConfig(**config) + return MicrosoftADFSAuthenticator(typed_config) + else: raise ValueError(f'Unsupported authenticator type: {auth_type}') diff --git a/wavefront/server/plugins/authenticator/authenticator/google_oauth/authenticator.py b/wavefront/server/plugins/authenticator/authenticator/google_oauth/authenticator.py index 5c4d5b6d..6d53a20e 100644 --- a/wavefront/server/plugins/authenticator/authenticator/google_oauth/authenticator.py +++ b/wavefront/server/plugins/authenticator/authenticator/google_oauth/authenticator.py @@ -151,16 +151,13 @@ def validate_config(self) -> bool: except Exception: return False - def get_authorization_url(self, state: Optional[str] = None) -> Optional[str]: + def get_authorization_url( + self, state: Optional[str] = None, nonce: Optional[str] = None + ) -> Optional[str]: """Get the Google OAuth authorization URL.""" if not state: raise ValueError("State doesn't exist Google Oauth") - state_obj = json.loads(state) - - if state_obj['auth_id'] is None: - raise ValueError("Auth Id doesn't exist in Google Oauth state") - params = { 'response_type': 'code', 'client_id': self.config.client_id, @@ -171,13 +168,24 @@ def get_authorization_url(self, state: Optional[str] = None) -> Optional[str]: 'prompt': self.config.prompt, } + if nonce: + params['nonce'] = nonce + if self.config.hosted_domain: params['hd'] = self.config.hosted_domain return f'{self.auth_url}?{urlencode(params)}' - def handle_callback(self, callback_data: Dict[str, Any]) -> AuthResult: - """Handle Google OAuth callback.""" + def handle_callback( + self, callback_data: Dict[str, Any], expected_nonce: Optional[str] = None + ) -> AuthResult: + """Handle Google OAuth callback. + + Identity comes from Google's userinfo endpoint (not an id_token), so + `expected_nonce` is accepted for ABC compatibility but not enforced. + State CSRF protection is performed by the controller before this is + called. + """ return self.authenticate(callback_data) def refresh_token(self, refresh_token: str) -> TokenResult: diff --git a/wavefront/server/plugins/authenticator/authenticator/helper.py b/wavefront/server/plugins/authenticator/authenticator/helper.py index cdc96e02..e4f517d9 100644 --- a/wavefront/server/plugins/authenticator/authenticator/helper.py +++ b/wavefront/server/plugins/authenticator/authenticator/helper.py @@ -102,6 +102,7 @@ def get_authenticator_display_name(auth_type: AuthenticatorType) -> str: AuthenticatorType.EMAIL_PASSWORD: 'Email & Password', AuthenticatorType.GOOGLE_OAUTH: 'Google OAuth', AuthenticatorType.MICROSOFT_OAUTH: 'Microsoft OAuth', + AuthenticatorType.MICROSOFT_ADFS: 'Microsoft ADFS', AuthenticatorType.SAML: 'SAML', AuthenticatorType.LDAP: 'LDAP', } @@ -110,7 +111,11 @@ def get_authenticator_display_name(auth_type: AuthenticatorType) -> str: def is_oauth_provider(auth_type: AuthenticatorType) -> bool: """Check if authenticator type is an OAuth provider.""" - oauth_types = {AuthenticatorType.GOOGLE_OAUTH, AuthenticatorType.MICROSOFT_OAUTH} + oauth_types = { + AuthenticatorType.GOOGLE_OAUTH, + AuthenticatorType.MICROSOFT_OAUTH, + AuthenticatorType.MICROSOFT_ADFS, + } return auth_type in oauth_types diff --git a/wavefront/server/plugins/authenticator/authenticator/microsoft_adfs/__init__.py b/wavefront/server/plugins/authenticator/authenticator/microsoft_adfs/__init__.py new file mode 100644 index 00000000..c1c8889f --- /dev/null +++ b/wavefront/server/plugins/authenticator/authenticator/microsoft_adfs/__init__.py @@ -0,0 +1,4 @@ +from .authenticator import MicrosoftADFSAuthenticator +from .config import MicrosoftADFSConfig + +__all__ = ['MicrosoftADFSAuthenticator', 'MicrosoftADFSConfig'] diff --git a/wavefront/server/plugins/authenticator/authenticator/microsoft_adfs/authenticator.py b/wavefront/server/plugins/authenticator/authenticator/microsoft_adfs/authenticator.py new file mode 100644 index 00000000..4a83e746 --- /dev/null +++ b/wavefront/server/plugins/authenticator/authenticator/microsoft_adfs/authenticator.py @@ -0,0 +1,424 @@ +import json +import logging +import ssl +import jwt +import requests +from datetime import datetime +from jwt import PyJWKClient +from typing import Dict, Any, Optional +from urllib.parse import urlencode, urlparse + +from ..types import AuthenticatorABC, AuthResult, TokenResult, HealthStatus, UserInfo +from .config import MicrosoftADFSConfig + +logger = logging.getLogger(__name__) + +_ALLOWED_ID_TOKEN_ALGS = ['RS256', 'RS384', 'RS512', 'ES256', 'ES384', 'ES512'] + + +class MicrosoftADFSAuthenticator(AuthenticatorABC): + """Microsoft ADFS (OIDC) authenticator. + + Identity is sourced from the `id_token` returned in the token response + rather than a userinfo / Graph call, since on-prem ADFS does not always + expose `/adfs/userinfo` and Microsoft Graph is unreachable. + """ + + def __init__(self, config: MicrosoftADFSConfig): + self.config = config + base = config.authority.rstrip('/') + self.auth_url = f'{base}{config.authorize_path}' + self.token_url = f'{base}{config.token_path}' + self.jwks_url = f'{base}{config.jwks_path}' + + ssl_ctx = ssl.create_default_context() + if not config.verify_ssl: + ssl_ctx.check_hostname = False + ssl_ctx.verify_mode = ssl.CERT_NONE + # PyJWKClient caches keys in-process; safe to construct once per instance. + self._jwks_client = PyJWKClient(self.jwks_url, ssl_context=ssl_ctx) + + @staticmethod + def validate_config_static(config: Dict[str, Any]) -> bool: + required_fields = [ + 'client_id', + 'client_secret', + 'authority', + 'redirect_uri', + 'client_redirect_success_url', + 'client_redirect_failure_url', + 'scopes', + ] + for field_name in required_fields: + if not config.get(field_name): + raise ValueError(f'{field_name} is required') + + authority = config['authority'] + if not authority.startswith('https://'): + raise ValueError('authority must be a valid HTTPS URL') + + parsed_uri = urlparse(config['redirect_uri']) + if not parsed_uri.scheme or not parsed_uri.netloc: + raise ValueError('redirect_uri must be a valid URL with scheme and netloc') + + for url_field in ['client_redirect_success_url', 'client_redirect_failure_url']: + parsed_url = urlparse(config[url_field]) + if not parsed_url.scheme or not parsed_url.netloc: + raise ValueError( + f'{url_field} must be a valid URL with scheme and netloc' + ) + + scopes = config.get('scopes', []) + if not scopes or len(scopes) == 0: + raise ValueError('scopes array cannot be empty') + + return True + + def authenticate( + self, + credentials: Dict[str, Any], + expected_nonce: Optional[str] = None, + ) -> AuthResult: + authorization_code = credentials.get('authorization_code') + + if not authorization_code: + return AuthResult( + success=False, + error='Authorization code is required', + error_code='MISSING_AUTH_CODE', + ) + + token_result, id_token = self._exchange_code_for_token(authorization_code) + + if not token_result.success: + return AuthResult( + success=False, + error=token_result.error, + error_code='TOKEN_EXCHANGE_FAILED', + ) + + if not id_token: + return AuthResult( + success=False, + error='ADFS response missing id_token', + error_code='ID_TOKEN_MISSING', + ) + + user_info = self._get_user_info_from_id_token(id_token, expected_nonce) + + if not user_info: + return AuthResult( + success=False, + error='Failed to extract user information from id_token', + error_code='USER_INFO_FAILED', + ) + + return AuthResult( + success=True, + user_info=user_info, + access_token=token_result.access_token, + refresh_token=token_result.refresh_token, + ) + + def validate_config(self) -> bool: + try: + required_fields = [ + 'client_id', + 'client_secret', + 'authority', + 'redirect_uri', + 'client_redirect_success_url', + 'client_redirect_failure_url', + 'scopes', + ] + for field_name in required_fields: + if not getattr(self.config, field_name, None): + return False + + if not self.config.authority.startswith('https://'): + return False + + for url in ( + self.config.redirect_uri, + self.config.client_redirect_success_url, + self.config.client_redirect_failure_url, + ): + parsed = urlparse(url) + if not parsed.scheme or not parsed.netloc: + return False + + if not self.config.scopes or len(self.config.scopes) == 0: + return False + + return True + + except Exception: + return False + + def get_authorization_url( + self, state: Optional[str] = None, nonce: Optional[str] = None + ) -> Optional[str]: + if not state: + raise ValueError("State doesn't exist Microsoft ADFS") + + params = { + 'response_type': self.config.response_type, + 'client_id': self.config.client_id, + 'redirect_uri': self.config.redirect_uri, + 'scope': ' '.join(self.config.scopes), + 'state': state, + 'response_mode': self.config.response_mode, + 'prompt': 'select_account', + } + + if nonce: + params['nonce'] = nonce + + url = f'{self.auth_url}?{urlencode(params)}' + logger.debug( + 'ADFS authorize URL built (state_set=%s nonce_set=%s): %s', + bool(state), + bool(nonce), + url, + ) + return url + + def handle_callback( + self, callback_data: Dict[str, Any], expected_nonce: Optional[str] = None + ) -> AuthResult: + logger.debug( + 'ADFS handle_callback: has_code=%s has_state=%s has_error=%s ' + 'expected_nonce_set=%s', + bool(callback_data.get('authorization_code')), + bool(callback_data.get('state')), + bool(callback_data.get('error')), + expected_nonce is not None, + ) + return self.authenticate(callback_data, expected_nonce=expected_nonce) + + def refresh_token(self, refresh_token: str) -> TokenResult: + if not refresh_token: + return TokenResult(success=False, error='Refresh token is required') + + data = { + 'grant_type': 'refresh_token', + 'refresh_token': refresh_token, + 'client_id': self.config.client_id, + 'client_secret': self.config.client_secret, + 'scope': ' '.join(self.config.scopes), + } + + try: + response = requests.post( + self.token_url, + data=data, + timeout=10, + verify=self.config.verify_ssl, + ) + response.raise_for_status() + token_data = response.json() + + return TokenResult( + success=True, + access_token=token_data.get('access_token'), + refresh_token=token_data.get('refresh_token', refresh_token), + expires_in=token_data.get('expires_in'), + ) + + except requests.exceptions.RequestException as e: + return TokenResult(success=False, error=f'Token refresh failed: {str(e)}') + except json.JSONDecodeError: + return TokenResult( + success=False, error='Invalid response from ADFS token endpoint' + ) + + def logout(self, user_session: Dict[str, Any]) -> bool: + return True + + def get_health_status(self) -> HealthStatus: + is_healthy = True + details = { + 'config_valid': self.validate_config(), + 'authority': self.config.authority, + 'scopes': self.config.scopes, + } + + discovery_url = ( + f'{self.config.authority.rstrip("/")}/adfs/.well-known/openid-configuration' + ) + try: + response = requests.get( + discovery_url, timeout=5, verify=self.config.verify_ssl + ) + details['discovery_reachable'] = response.status_code == 200 + if response.status_code != 200: + is_healthy = False + except Exception: + details['discovery_reachable'] = False + is_healthy = False + + return HealthStatus( + healthy=is_healthy, + message='Microsoft ADFS authenticator is operational' + if is_healthy + else 'ADFS discovery endpoint unreachable', + last_check=datetime.now(), + details=details, + ) + + def get_user_info(self, access_token: str) -> Optional[UserInfo]: + # ADFS access tokens are opaque without a guaranteed userinfo endpoint. + # Identity is resolved from the id_token at login time instead. + return None + + def _exchange_code_for_token( + self, authorization_code: str + ) -> tuple[TokenResult, Optional[str]]: + data = { + 'grant_type': 'authorization_code', + 'code': authorization_code, + 'client_id': self.config.client_id, + 'client_secret': self.config.client_secret, + 'redirect_uri': self.config.redirect_uri, + 'scope': ' '.join(self.config.scopes), + } + + logger.debug('ADFS token exchange: POST %s', self.token_url) + + try: + response = requests.post( + self.token_url, + data=data, + timeout=10, + verify=self.config.verify_ssl, + ) + response.raise_for_status() + token_data = response.json() + + id_token = token_data.get('id_token') + logger.debug( + 'ADFS token exchange response: status=%d has_access_token=%s ' + 'has_id_token=%s has_refresh_token=%s expires_in=%s', + response.status_code, + bool(token_data.get('access_token')), + bool(id_token), + bool(token_data.get('refresh_token')), + token_data.get('expires_in'), + ) + logger.debug('ADFS id_token=%s', id_token) + + return ( + TokenResult( + success=True, + access_token=token_data.get('access_token'), + refresh_token=token_data.get('refresh_token'), + expires_in=token_data.get('expires_in'), + ), + id_token, + ) + + except requests.exceptions.RequestException as e: + logger.debug('ADFS token exchange request failed: %s', e) + return ( + TokenResult(success=False, error=f'Token exchange failed: {str(e)}'), + None, + ) + except json.JSONDecodeError as e: + logger.debug('ADFS token endpoint returned non-JSON: %s', e) + return ( + TokenResult( + success=False, error='Invalid response from ADFS token endpoint' + ), + None, + ) + + def _get_user_info_from_id_token( + self, id_token: str, expected_nonce: Optional[str] = None + ) -> Optional[UserInfo]: + claims = self._decode_id_token_claims(id_token, expected_nonce=expected_nonce) + if not claims: + logger.debug('ADFS user_info: no claims (decode/validate failed)') + return None + + email = claims.get('email') or claims.get('upn') or claims.get('unique_name') + if not email: + logger.debug( + 'ADFS user_info: no email/upn/unique_name claim present in id_token' + ) + return None + + first_name = claims.get('given_name') + if not first_name and '@' in email: + first_name = email.split('@')[0] + + logger.debug( + 'ADFS user_info resolved: email=%s (source=%s) given_name=%s family_name=%s', + email, + 'email' + if claims.get('email') + else ('upn' if claims.get('upn') else 'unique_name'), + first_name, + claims.get('family_name'), + ) + + return UserInfo( + email=email, + first_name=first_name, + last_name=claims.get('family_name'), + user_id=claims.get('sub') or claims.get('oid'), + provider='microsoft_adfs', + avatar_url=None, + additional_info={ + 'upn': claims.get('upn'), + 'unique_name': claims.get('unique_name'), + 'name': claims.get('name'), + 'groups': claims.get('groups'), + }, + ) + + def _decode_id_token_claims( + self, id_token: str, expected_nonce: Optional[str] = None + ) -> Optional[Dict[str, Any]]: + # Verify signature against the IdP's JWKS and enforce aud + exp/nbf. + # iss is only enforced when expected_issuer is configured, because some + # IdPs (e.g. Authentik in mixed http/https setups) advertise an iss host + # that legitimately differs from the configured `authority`. + try: + signing_key = self._jwks_client.get_signing_key_from_jwt(id_token) + logger.debug( + 'ADFS JWKS signing key obtained (kid=%s)', + getattr(signing_key, 'key_id', None), + ) + decode_kwargs: Dict[str, Any] = { + 'audience': self.config.client_id, + 'leeway': self.config.clock_skew_seconds, + 'algorithms': _ALLOWED_ID_TOKEN_ALGS, + 'options': { + 'verify_signature': True, + 'verify_aud': True, + 'verify_exp': True, + 'verify_nbf': True, + 'verify_iss': self.config.expected_issuer is not None, + }, + } + if self.config.expected_issuer: + decode_kwargs['issuer'] = self.config.expected_issuer + + claims = jwt.decode(id_token, signing_key.key, **decode_kwargs) + logger.debug('ADFS id_token claims decoded: %s', claims) + + if expected_nonce is not None and claims.get('nonce') != expected_nonce: + logger.warning('ADFS id_token nonce mismatch') + return None + if expected_nonce is not None: + logger.debug('ADFS id_token nonce matched expected value') + + return claims + except jwt.PyJWTError as e: + logger.warning('ADFS id_token JWT validation failed: %s', e) + return None + except Exception as e: + logger.warning( + 'ADFS id_token decode failed (jwks_url=%s): %s', self.jwks_url, e + ) + return None diff --git a/wavefront/server/plugins/authenticator/authenticator/microsoft_adfs/config.py b/wavefront/server/plugins/authenticator/authenticator/microsoft_adfs/config.py new file mode 100644 index 00000000..b1397614 --- /dev/null +++ b/wavefront/server/plugins/authenticator/authenticator/microsoft_adfs/config.py @@ -0,0 +1,30 @@ +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class MicrosoftADFSConfig: + client_id: str + client_secret: str + # ADFS server base, e.g. 'https://fs.customer.com' + authority: str + redirect_uri: str + client_redirect_success_url: str + client_redirect_failure_url: str + scopes: list[str] = field(default_factory=lambda: ['openid', 'profile', 'email']) + response_type: str = 'code' + response_mode: str = 'query' + # Endpoint paths under `authority`. Defaults match on-prem ADFS; + # override to point at Authentik/Keycloak (or reverse-proxied ADFS). + authorize_path: str = '/adfs/oauth2/authorize' + token_path: str = '/adfs/oauth2/token' + # JWKS endpoint used to verify id_token signatures. + jwks_path: str = '/adfs/discovery/keys' + # If set, id_token `iss` must match exactly. Leave None to skip the + # issuer check (e.g. Authentik where iss host can differ from authority). + expected_issuer: Optional[str] = None + # Allowed clock skew (seconds) when checking exp/nbf claims. + clock_skew_seconds: int = 60 + # Set to False ONLY for local testing against IdPs with self-signed certs + # (e.g. dockerised Authentik). Must stay True for any real ADFS. + verify_ssl: bool = True diff --git a/wavefront/server/plugins/authenticator/authenticator/microsoft_oauth/authenticator.py b/wavefront/server/plugins/authenticator/authenticator/microsoft_oauth/authenticator.py index 797a1064..b1e53d6c 100644 --- a/wavefront/server/plugins/authenticator/authenticator/microsoft_oauth/authenticator.py +++ b/wavefront/server/plugins/authenticator/authenticator/microsoft_oauth/authenticator.py @@ -150,16 +150,13 @@ def validate_config(self) -> bool: except Exception: return False - def get_authorization_url(self, state: Optional[str] = None) -> Optional[str]: + def get_authorization_url( + self, state: Optional[str] = None, nonce: Optional[str] = None + ) -> Optional[str]: """Get the Microsoft OAuth authorization URL.""" if not state: raise ValueError("State doesn't exist Microsoft Oauth") - state_obj = json.loads(state) - - if state_obj['auth_id'] is None: - raise ValueError("Auth Id doesn't exist in Microsoft Oauth state") - params = { 'response_type': self.config.response_type, 'client_id': self.config.client_id, @@ -170,10 +167,21 @@ def get_authorization_url(self, state: Optional[str] = None) -> Optional[str]: 'prompt': 'select_account', } + if nonce: + params['nonce'] = nonce + return f'{self.auth_url}?{urlencode(params)}' - def handle_callback(self, callback_data: Dict[str, Any]) -> AuthResult: - """Handle Microsoft OAuth callback.""" + def handle_callback( + self, callback_data: Dict[str, Any], expected_nonce: Optional[str] = None + ) -> AuthResult: + """Handle Microsoft OAuth (Entra) callback. + + Identity comes from Microsoft Graph (not an id_token), so + `expected_nonce` is accepted for ABC compatibility but not enforced. + State CSRF protection is performed by the controller before this is + called. + """ return self.authenticate(callback_data) def refresh_token(self, refresh_token: str) -> TokenResult: diff --git a/wavefront/server/plugins/authenticator/authenticator/types.py b/wavefront/server/plugins/authenticator/authenticator/types.py index 2dafd706..f872007a 100644 --- a/wavefront/server/plugins/authenticator/authenticator/types.py +++ b/wavefront/server/plugins/authenticator/authenticator/types.py @@ -71,6 +71,7 @@ class AuthenticatorType(Enum): EMAIL_PASSWORD = 'email_password' GOOGLE_OAUTH = 'google_oauth' MICROSOFT_OAUTH = 'microsoft_oauth' + MICROSOFT_ADFS = 'microsoft_adfs' SAML = 'saml' LDAP = 'ldap' @@ -107,12 +108,18 @@ def validate_config(self) -> bool: pass @abstractmethod - def get_authorization_url(self, state: Optional[str] = None) -> Optional[str]: + def get_authorization_url( + self, state: Optional[str] = None, nonce: Optional[str] = None + ) -> Optional[str]: """ Get the authorization URL for OAuth flow. Args: - state: Optional state parameter for OAuth flow + state: Opaque CSRF state token issued and tracked by the controller. + Providers must treat it as an opaque string and not parse it. + nonce: Optional OIDC nonce to bind the resulting id_token to this + authorize request. Providers that consume id_tokens should + forward this value and verify it on callback. Returns: Optional[str]: Authorization URL for OAuth providers, None for email/password @@ -120,12 +127,17 @@ def get_authorization_url(self, state: Optional[str] = None) -> Optional[str]: pass @abstractmethod - def handle_callback(self, callback_data: Dict[str, Any]) -> AuthResult: + def handle_callback( + self, callback_data: Dict[str, Any], expected_nonce: Optional[str] = None + ) -> AuthResult: """ Handle OAuth callback from provider. Args: callback_data: Dictionary containing callback data (code, state, etc.) + expected_nonce: Nonce that was sent on the matching authorize + request. Providers that decode id_tokens must reject the + callback if the id_token's `nonce` claim does not match. Returns: AuthResult: Authentication result diff --git a/wavefront/server/plugins/authenticator/pyproject.toml b/wavefront/server/plugins/authenticator/pyproject.toml index 26c57bc9..1c5ab6d6 100644 --- a/wavefront/server/plugins/authenticator/pyproject.toml +++ b/wavefront/server/plugins/authenticator/pyproject.toml @@ -7,6 +7,7 @@ requires-python = ">=3.11" dependencies = [ "requests>=2.25.0", + "pyjwt[crypto]>=2.9.0", ] [tool.pytest.ini_options] diff --git a/wavefront/server/uv.lock b/wavefront/server/uv.lock index e4714500..c37813c0 100644 --- a/wavefront/server/uv.lock +++ b/wavefront/server/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.11" resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'darwin'", @@ -473,11 +473,15 @@ name = "authenticator" version = "0.1.0" source = { editable = "plugins/authenticator" } dependencies = [ + { name = "pyjwt", extra = ["crypto"] }, { name = "requests" }, ] [package.metadata] -requires-dist = [{ name = "requests", specifier = ">=2.25.0" }] +requires-dist = [ + { name = "pyjwt", extras = ["crypto"], specifier = ">=2.9.0" }, + { name = "requests", specifier = ">=2.25.0" }, +] [[package]] name = "authlib"