|
| 1 | +# Copyright 2020 Google Inc. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Firebase auth providers management sub module.""" |
| 16 | + |
| 17 | +from urllib import parse |
| 18 | + |
| 19 | +import requests |
| 20 | + |
| 21 | +from firebase_admin import _auth_utils |
| 22 | +from firebase_admin import _user_mgt |
| 23 | + |
| 24 | + |
| 25 | +MAX_LIST_CONFIGS_RESULTS = 100 |
| 26 | + |
| 27 | + |
| 28 | +class ProviderConfig: |
| 29 | + """Parent type for all authentication provider config types.""" |
| 30 | + |
| 31 | + def __init__(self, data): |
| 32 | + self._data = data |
| 33 | + |
| 34 | + @property |
| 35 | + def provider_id(self): |
| 36 | + name = self._data['name'] |
| 37 | + return name.split('/')[-1] |
| 38 | + |
| 39 | + @property |
| 40 | + def display_name(self): |
| 41 | + return self._data.get('displayName') |
| 42 | + |
| 43 | + @property |
| 44 | + def enabled(self): |
| 45 | + return self._data.get('enabled', False) |
| 46 | + |
| 47 | + |
| 48 | +class OIDCProviderConfig(ProviderConfig): |
| 49 | + """Represents the OIDC auth provider configuration. |
| 50 | +
|
| 51 | + See https://openid.net/specs/openid-connect-core-1_0-final.html. |
| 52 | + """ |
| 53 | + |
| 54 | + @property |
| 55 | + def issuer(self): |
| 56 | + return self._data['issuer'] |
| 57 | + |
| 58 | + @property |
| 59 | + def client_id(self): |
| 60 | + return self._data['clientId'] |
| 61 | + |
| 62 | + |
| 63 | +class SAMLProviderConfig(ProviderConfig): |
| 64 | + """Represents he SAML auth provider configuration. |
| 65 | +
|
| 66 | + See http://docs.oasis-open.org/security/saml/Post2.0/sstc-saml-tech-overview-2.0.html. |
| 67 | + """ |
| 68 | + |
| 69 | + @property |
| 70 | + def idp_entity_id(self): |
| 71 | + return self._data.get('idpConfig', {})['idpEntityId'] |
| 72 | + |
| 73 | + @property |
| 74 | + def sso_url(self): |
| 75 | + return self._data.get('idpConfig', {})['ssoUrl'] |
| 76 | + |
| 77 | + @property |
| 78 | + def x509_certificates(self): |
| 79 | + certs = self._data.get('idpConfig', {})['idpCertificates'] |
| 80 | + return [c['x509Certificate'] for c in certs] |
| 81 | + |
| 82 | + @property |
| 83 | + def callback_url(self): |
| 84 | + return self._data.get('spConfig', {})['callbackUri'] |
| 85 | + |
| 86 | + @property |
| 87 | + def rp_entity_id(self): |
| 88 | + return self._data.get('spConfig', {})['spEntityId'] |
| 89 | + |
| 90 | + |
| 91 | +class ListProviderConfigsPage: |
| 92 | + """Represents a page of AuthProviderConfig instances retrieved from a Firebase project. |
| 93 | +
|
| 94 | + Provides methods for traversing the provider configs included in this page, as well as |
| 95 | + retrieving subsequent pages. The iterator returned by ``iterate_all()`` can be used to iterate |
| 96 | + through all provider configs in the Firebase project starting from this page. |
| 97 | + """ |
| 98 | + |
| 99 | + def __init__(self, download, page_token, max_results): |
| 100 | + self._download = download |
| 101 | + self._max_results = max_results |
| 102 | + self._current = download(page_token, max_results) |
| 103 | + |
| 104 | + @property |
| 105 | + def provider_configs(self): |
| 106 | + """A list of ``AuthProviderConfig`` instances available in this page.""" |
| 107 | + raise NotImplementedError |
| 108 | + |
| 109 | + @property |
| 110 | + def next_page_token(self): |
| 111 | + """Page token string for the next page (empty string indicates no more pages).""" |
| 112 | + return self._current.get('nextPageToken', '') |
| 113 | + |
| 114 | + @property |
| 115 | + def has_next_page(self): |
| 116 | + """A boolean indicating whether more pages are available.""" |
| 117 | + return bool(self.next_page_token) |
| 118 | + |
| 119 | + def get_next_page(self): |
| 120 | + """Retrieves the next page of provider configs, if available. |
| 121 | +
|
| 122 | + Returns: |
| 123 | + ListProviderConfigsPage: Next page of provider configs, or None if this is the last |
| 124 | + page. |
| 125 | + """ |
| 126 | + if self.has_next_page: |
| 127 | + return self.__class__(self._download, self.next_page_token, self._max_results) |
| 128 | + return None |
| 129 | + |
| 130 | + def iterate_all(self): |
| 131 | + """Retrieves an iterator for provider configs. |
| 132 | +
|
| 133 | + Returned iterator will iterate through all the provider configs in the Firebase project |
| 134 | + starting from this page. The iterator will never buffer more than one page of configs |
| 135 | + in memory at a time. |
| 136 | +
|
| 137 | + Returns: |
| 138 | + iterator: An iterator of AuthProviderConfig instances. |
| 139 | + """ |
| 140 | + return _ProviderConfigIterator(self) |
| 141 | + |
| 142 | + |
| 143 | +class _ListOIDCProviderConfigsPage(ListProviderConfigsPage): |
| 144 | + |
| 145 | + @property |
| 146 | + def provider_configs(self): |
| 147 | + return [OIDCProviderConfig(data) for data in self._current.get('oauthIdpConfigs', [])] |
| 148 | + |
| 149 | + |
| 150 | +class _ListSAMLProviderConfigsPage(ListProviderConfigsPage): |
| 151 | + |
| 152 | + @property |
| 153 | + def provider_configs(self): |
| 154 | + return [SAMLProviderConfig(data) for data in self._current.get('inboundSamlConfigs', [])] |
| 155 | + |
| 156 | + |
| 157 | +class _ProviderConfigIterator(_auth_utils.PageIterator): |
| 158 | + |
| 159 | + @property |
| 160 | + def items(self): |
| 161 | + return self._current_page.provider_configs |
| 162 | + |
| 163 | + |
| 164 | +class ProviderConfigClient: |
| 165 | + """Client for managing Auth provider configurations.""" |
| 166 | + |
| 167 | + PROVIDER_CONFIG_URL = 'https://identitytoolkit.googleapis.com/v2beta1' |
| 168 | + |
| 169 | + def __init__(self, http_client, project_id, tenant_id=None): |
| 170 | + self.http_client = http_client |
| 171 | + self.base_url = '{0}/projects/{1}'.format(self.PROVIDER_CONFIG_URL, project_id) |
| 172 | + if tenant_id: |
| 173 | + self.base_url += '/tenants/{0}'.format(tenant_id) |
| 174 | + |
| 175 | + def get_oidc_provider_config(self, provider_id): |
| 176 | + _validate_oidc_provider_id(provider_id) |
| 177 | + body = self._make_request('get', '/oauthIdpConfigs/{0}'.format(provider_id)) |
| 178 | + return OIDCProviderConfig(body) |
| 179 | + |
| 180 | + def create_oidc_provider_config( |
| 181 | + self, provider_id, client_id, issuer, display_name=None, enabled=None): |
| 182 | + """Creates a new OIDC provider config from the given parameters.""" |
| 183 | + _validate_oidc_provider_id(provider_id) |
| 184 | + req = { |
| 185 | + 'clientId': _validate_non_empty_string(client_id, 'client_id'), |
| 186 | + 'issuer': _validate_url(issuer, 'issuer'), |
| 187 | + } |
| 188 | + if display_name is not None: |
| 189 | + req['displayName'] = _auth_utils.validate_string(display_name, 'display_name') |
| 190 | + if enabled is not None: |
| 191 | + req['enabled'] = _auth_utils.validate_boolean(enabled, 'enabled') |
| 192 | + |
| 193 | + params = 'oauthIdpConfigId={0}'.format(provider_id) |
| 194 | + body = self._make_request('post', '/oauthIdpConfigs', json=req, params=params) |
| 195 | + return OIDCProviderConfig(body) |
| 196 | + |
| 197 | + def update_oidc_provider_config( |
| 198 | + self, provider_id, client_id=None, issuer=None, display_name=None, enabled=None): |
| 199 | + """Updates an existing OIDC provider config with the given parameters.""" |
| 200 | + _validate_oidc_provider_id(provider_id) |
| 201 | + req = {} |
| 202 | + if display_name is not None: |
| 203 | + if display_name == _user_mgt.DELETE_ATTRIBUTE: |
| 204 | + req['displayName'] = None |
| 205 | + else: |
| 206 | + req['displayName'] = _auth_utils.validate_string(display_name, 'display_name') |
| 207 | + if enabled is not None: |
| 208 | + req['enabled'] = _auth_utils.validate_boolean(enabled, 'enabled') |
| 209 | + if client_id: |
| 210 | + req['clientId'] = _validate_non_empty_string(client_id, 'client_id') |
| 211 | + if issuer: |
| 212 | + req['issuer'] = _validate_url(issuer, 'issuer') |
| 213 | + |
| 214 | + if not req: |
| 215 | + raise ValueError('At least one parameter must be specified for update.') |
| 216 | + |
| 217 | + update_mask = _auth_utils.build_update_mask(req) |
| 218 | + params = 'updateMask={0}'.format(','.join(update_mask)) |
| 219 | + url = '/oauthIdpConfigs/{0}'.format(provider_id) |
| 220 | + body = self._make_request('patch', url, json=req, params=params) |
| 221 | + return OIDCProviderConfig(body) |
| 222 | + |
| 223 | + def delete_oidc_provider_config(self, provider_id): |
| 224 | + _validate_oidc_provider_id(provider_id) |
| 225 | + self._make_request('delete', '/oauthIdpConfigs/{0}'.format(provider_id)) |
| 226 | + |
| 227 | + def list_oidc_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): |
| 228 | + return _ListOIDCProviderConfigsPage( |
| 229 | + self._fetch_oidc_provider_configs, page_token, max_results) |
| 230 | + |
| 231 | + def _fetch_oidc_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): |
| 232 | + return self._fetch_provider_configs('/oauthIdpConfigs', page_token, max_results) |
| 233 | + |
| 234 | + def get_saml_provider_config(self, provider_id): |
| 235 | + _validate_saml_provider_id(provider_id) |
| 236 | + body = self._make_request('get', '/inboundSamlConfigs/{0}'.format(provider_id)) |
| 237 | + return SAMLProviderConfig(body) |
| 238 | + |
| 239 | + def create_saml_provider_config( |
| 240 | + self, provider_id, idp_entity_id, sso_url, x509_certificates, |
| 241 | + rp_entity_id, callback_url, display_name=None, enabled=None): |
| 242 | + """Creates a new SAML provider config from the given parameters.""" |
| 243 | + _validate_saml_provider_id(provider_id) |
| 244 | + req = { |
| 245 | + 'idpConfig': { |
| 246 | + 'idpEntityId': _validate_non_empty_string(idp_entity_id, 'idp_entity_id'), |
| 247 | + 'ssoUrl': _validate_url(sso_url, 'sso_url'), |
| 248 | + 'idpCertificates': _validate_x509_certificates(x509_certificates), |
| 249 | + }, |
| 250 | + 'spConfig': { |
| 251 | + 'spEntityId': _validate_non_empty_string(rp_entity_id, 'rp_entity_id'), |
| 252 | + 'callbackUri': _validate_url(callback_url, 'callback_url'), |
| 253 | + }, |
| 254 | + } |
| 255 | + if display_name is not None: |
| 256 | + req['displayName'] = _auth_utils.validate_string(display_name, 'display_name') |
| 257 | + if enabled is not None: |
| 258 | + req['enabled'] = _auth_utils.validate_boolean(enabled, 'enabled') |
| 259 | + |
| 260 | + params = 'inboundSamlConfigId={0}'.format(provider_id) |
| 261 | + body = self._make_request('post', '/inboundSamlConfigs', json=req, params=params) |
| 262 | + return SAMLProviderConfig(body) |
| 263 | + |
| 264 | + def update_saml_provider_config( |
| 265 | + self, provider_id, idp_entity_id=None, sso_url=None, x509_certificates=None, |
| 266 | + rp_entity_id=None, callback_url=None, display_name=None, enabled=None): |
| 267 | + """Updates an existing SAML provider config with the given parameters.""" |
| 268 | + _validate_saml_provider_id(provider_id) |
| 269 | + idp_config = {} |
| 270 | + if idp_entity_id is not None: |
| 271 | + idp_config['idpEntityId'] = _validate_non_empty_string(idp_entity_id, 'idp_entity_id') |
| 272 | + if sso_url is not None: |
| 273 | + idp_config['ssoUrl'] = _validate_url(sso_url, 'sso_url') |
| 274 | + if x509_certificates is not None: |
| 275 | + idp_config['idpCertificates'] = _validate_x509_certificates(x509_certificates) |
| 276 | + |
| 277 | + sp_config = {} |
| 278 | + if rp_entity_id is not None: |
| 279 | + sp_config['spEntityId'] = _validate_non_empty_string(rp_entity_id, 'rp_entity_id') |
| 280 | + if callback_url is not None: |
| 281 | + sp_config['callbackUri'] = _validate_url(callback_url, 'callback_url') |
| 282 | + |
| 283 | + req = {} |
| 284 | + if display_name is not None: |
| 285 | + if display_name == _user_mgt.DELETE_ATTRIBUTE: |
| 286 | + req['displayName'] = None |
| 287 | + else: |
| 288 | + req['displayName'] = _auth_utils.validate_string(display_name, 'display_name') |
| 289 | + if enabled is not None: |
| 290 | + req['enabled'] = _auth_utils.validate_boolean(enabled, 'enabled') |
| 291 | + if idp_config: |
| 292 | + req['idpConfig'] = idp_config |
| 293 | + if sp_config: |
| 294 | + req['spConfig'] = sp_config |
| 295 | + |
| 296 | + if not req: |
| 297 | + raise ValueError('At least one parameter must be specified for update.') |
| 298 | + |
| 299 | + update_mask = _auth_utils.build_update_mask(req) |
| 300 | + params = 'updateMask={0}'.format(','.join(update_mask)) |
| 301 | + url = '/inboundSamlConfigs/{0}'.format(provider_id) |
| 302 | + body = self._make_request('patch', url, json=req, params=params) |
| 303 | + return SAMLProviderConfig(body) |
| 304 | + |
| 305 | + def delete_saml_provider_config(self, provider_id): |
| 306 | + _validate_saml_provider_id(provider_id) |
| 307 | + self._make_request('delete', '/inboundSamlConfigs/{0}'.format(provider_id)) |
| 308 | + |
| 309 | + def list_saml_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): |
| 310 | + return _ListSAMLProviderConfigsPage( |
| 311 | + self._fetch_saml_provider_configs, page_token, max_results) |
| 312 | + |
| 313 | + def _fetch_saml_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): |
| 314 | + return self._fetch_provider_configs('/inboundSamlConfigs', page_token, max_results) |
| 315 | + |
| 316 | + def _fetch_provider_configs(self, path, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): |
| 317 | + """Fetches a page of auth provider configs""" |
| 318 | + if page_token is not None: |
| 319 | + if not isinstance(page_token, str) or not page_token: |
| 320 | + raise ValueError('Page token must be a non-empty string.') |
| 321 | + if not isinstance(max_results, int): |
| 322 | + raise ValueError('Max results must be an integer.') |
| 323 | + if max_results < 1 or max_results > MAX_LIST_CONFIGS_RESULTS: |
| 324 | + raise ValueError( |
| 325 | + 'Max results must be a positive integer less than or equal to ' |
| 326 | + '{0}.'.format(MAX_LIST_CONFIGS_RESULTS)) |
| 327 | + |
| 328 | + params = 'pageSize={0}'.format(max_results) |
| 329 | + if page_token: |
| 330 | + params += '&pageToken={0}'.format(page_token) |
| 331 | + return self._make_request('get', path, params=params) |
| 332 | + |
| 333 | + def _make_request(self, method, path, **kwargs): |
| 334 | + url = '{0}{1}'.format(self.base_url, path) |
| 335 | + try: |
| 336 | + return self.http_client.body(method, url, **kwargs) |
| 337 | + except requests.exceptions.RequestException as error: |
| 338 | + raise _auth_utils.handle_auth_backend_error(error) |
| 339 | + |
| 340 | + |
| 341 | +def _validate_oidc_provider_id(provider_id): |
| 342 | + if not isinstance(provider_id, str): |
| 343 | + raise ValueError( |
| 344 | + 'Invalid OIDC provider ID: {0}. Provider ID must be a non-empty string.'.format( |
| 345 | + provider_id)) |
| 346 | + if not provider_id.startswith('oidc.'): |
| 347 | + raise ValueError('Invalid OIDC provider ID: {0}.'.format(provider_id)) |
| 348 | + return provider_id |
| 349 | + |
| 350 | + |
| 351 | +def _validate_saml_provider_id(provider_id): |
| 352 | + if not isinstance(provider_id, str): |
| 353 | + raise ValueError( |
| 354 | + 'Invalid SAML provider ID: {0}. Provider ID must be a non-empty string.'.format( |
| 355 | + provider_id)) |
| 356 | + if not provider_id.startswith('saml.'): |
| 357 | + raise ValueError('Invalid SAML provider ID: {0}.'.format(provider_id)) |
| 358 | + return provider_id |
| 359 | + |
| 360 | + |
| 361 | +def _validate_non_empty_string(value, label): |
| 362 | + """Validates that the given value is a non-empty string.""" |
| 363 | + if not isinstance(value, str): |
| 364 | + raise ValueError('Invalid type for {0}: {1}.'.format(label, value)) |
| 365 | + if not value: |
| 366 | + raise ValueError('{0} must not be empty.'.format(label)) |
| 367 | + return value |
| 368 | + |
| 369 | + |
| 370 | +def _validate_url(url, label): |
| 371 | + """Validates that the given value is a well-formed URL string.""" |
| 372 | + if not isinstance(url, str) or not url: |
| 373 | + raise ValueError( |
| 374 | + 'Invalid photo URL: "{0}". {1} must be a non-empty ' |
| 375 | + 'string.'.format(url, label)) |
| 376 | + try: |
| 377 | + parsed = parse.urlparse(url) |
| 378 | + if not parsed.netloc: |
| 379 | + raise ValueError('Malformed {0}: "{1}".'.format(label, url)) |
| 380 | + return url |
| 381 | + except Exception: |
| 382 | + raise ValueError('Malformed {0}: "{1}".'.format(label, url)) |
| 383 | + |
| 384 | + |
| 385 | +def _validate_x509_certificates(x509_certificates): |
| 386 | + if not isinstance(x509_certificates, list) or not x509_certificates: |
| 387 | + raise ValueError('x509_certificates must be a non-empty list.') |
| 388 | + if not all([isinstance(cert, str) and cert for cert in x509_certificates]): |
| 389 | + raise ValueError('x509_certificates must only contain non-empty strings.') |
| 390 | + return [{'x509Certificate': cert} for cert in x509_certificates] |
0 commit comments