From b57af183fa89013417653790b7693c2628a03723 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Thu, 12 Jun 2025 13:16:19 -0400 Subject: [PATCH 01/13] change(fcm): Remove deprecated FCM APIs (#890) --- firebase_admin/_gapic_utils.py | 122 ------- firebase_admin/messaging.py | 118 ------- integration/test_messaging.py | 65 ---- requirements.txt | 1 - setup.py | 1 - snippets/messaging/cloud_messaging.py | 24 +- tests/test_exceptions.py | 161 --------- tests/test_messaging.py | 486 +------------------------- 8 files changed, 16 insertions(+), 962 deletions(-) delete mode 100644 firebase_admin/_gapic_utils.py diff --git a/firebase_admin/_gapic_utils.py b/firebase_admin/_gapic_utils.py deleted file mode 100644 index 3c975808c..000000000 --- a/firebase_admin/_gapic_utils.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright 2021 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Internal utilities for interacting with Google API client.""" - -import io -import socket - -import googleapiclient -import httplib2 -import requests - -from firebase_admin import exceptions -from firebase_admin import _utils - - -def handle_platform_error_from_googleapiclient(error, handle_func=None): - """Constructs a ``FirebaseError`` from the given googleapiclient error. - - This can be used to handle errors returned by Google Cloud Platform (GCP) APIs. - - Args: - error: An error raised by the googleapiclient while making an HTTP call to a GCP API. - handle_func: A function that can be used to handle platform errors in a custom way. When - specified, this function will be called with three arguments. It has the same - signature as ```_handle_func_googleapiclient``, but may return ``None``. - - Returns: - FirebaseError: A ``FirebaseError`` that can be raised to the user code. - """ - if not isinstance(error, googleapiclient.errors.HttpError): - return handle_googleapiclient_error(error) - - content = error.content.decode() - status_code = error.resp.status - error_dict, message = _utils._parse_platform_error(content, status_code) # pylint: disable=protected-access - http_response = _http_response_from_googleapiclient_error(error) - exc = None - if handle_func: - exc = handle_func(error, message, error_dict, http_response) - - return exc if exc else _handle_func_googleapiclient(error, message, error_dict, http_response) - - -def _handle_func_googleapiclient(error, message, error_dict, http_response): - """Constructs a ``FirebaseError`` from the given GCP error. - - Args: - error: An error raised by the googleapiclient module while making an HTTP call. - message: A message to be included in the resulting ``FirebaseError``. - error_dict: Parsed GCP error response. - http_response: A requests HTTP response object to associate with the exception. - - Returns: - FirebaseError: A ``FirebaseError`` that can be raised to the user code or None. - """ - code = error_dict.get('status') - return handle_googleapiclient_error(error, message, code, http_response) - - -def handle_googleapiclient_error(error, message=None, code=None, http_response=None): - """Constructs a ``FirebaseError`` from the given googleapiclient error. - - This method is agnostic of the remote service that produced the error, whether it is a GCP - service or otherwise. Therefore, this method does not attempt to parse the error response in - any way. - - Args: - error: An error raised by the googleapiclient module while making an HTTP call. - message: A message to be included in the resulting ``FirebaseError`` (optional). If not - specified the string representation of the ``error`` argument is used as the message. - code: A GCP error code that will be used to determine the resulting error type (optional). - If not specified the HTTP status code on the error response is used to determine a - suitable error code. - http_response: A requests HTTP response object to associate with the exception (optional). - If not specified, one will be created from the ``error``. - - Returns: - FirebaseError: A ``FirebaseError`` that can be raised to the user code. - """ - if isinstance(error, socket.timeout) or ( - isinstance(error, socket.error) and 'timed out' in str(error)): - return exceptions.DeadlineExceededError( - message='Timed out while making an API call: {0}'.format(error), - cause=error) - if isinstance(error, httplib2.ServerNotFoundError): - return exceptions.UnavailableError( - message='Failed to establish a connection: {0}'.format(error), - cause=error) - if not isinstance(error, googleapiclient.errors.HttpError): - return exceptions.UnknownError( - message='Unknown error while making a remote service call: {0}'.format(error), - cause=error) - - if not code: - code = _utils._http_status_to_error_code(error.resp.status) # pylint: disable=protected-access - if not message: - message = str(error) - if not http_response: - http_response = _http_response_from_googleapiclient_error(error) - - err_type = _utils._error_code_to_exception_type(code) # pylint: disable=protected-access - return err_type(message=message, cause=error, http_response=http_response) - - -def _http_response_from_googleapiclient_error(error): - """Creates a requests HTTP Response object from the given googleapiclient error.""" - resp = requests.models.Response() - resp.raw = io.BytesIO(error.content) - resp.status_code = error.resp.status - return resp diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 99dc93a67..0e3a55f49 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -18,21 +18,16 @@ from typing import Any, Callable, Dict, List, Optional, cast import concurrent.futures import json -import warnings import asyncio import logging import requests import httpx -from googleapiclient import http -from googleapiclient import _auth - import firebase_admin from firebase_admin import ( _http_client, _messaging_encoder, _messaging_utils, - _gapic_utils, _utils, exceptions, App @@ -72,8 +67,6 @@ 'WebpushNotificationAction', 'send', - 'send_all', - 'send_multicast', 'send_each', 'send_each_async', 'send_each_for_multicast', @@ -246,64 +239,6 @@ def send_each_for_multicast(multicast_message, dry_run=False, app=None): ) for token in multicast_message.tokens] return _get_messaging_service(app).send_each(messages, dry_run) -def send_all(messages, dry_run=False, app=None): - """Sends the given list of messages via Firebase Cloud Messaging as a single batch. - - If the ``dry_run`` mode is enabled, the message will not be actually delivered to the - recipients. Instead, FCM performs all the usual validations and emulates the send operation. - - Args: - messages: A list of ``messaging.Message`` instances. - dry_run: A boolean indicating whether to run the operation in dry run mode (optional). - app: An App instance (optional). - - Returns: - BatchResponse: A ``messaging.BatchResponse`` instance. - - Raises: - FirebaseError: If an error occurs while sending the message to the FCM service. - ValueError: If the input arguments are invalid. - - send_all() is deprecated. Use send_each() instead. - """ - warnings.warn('send_all() is deprecated. Use send_each() instead.', DeprecationWarning) - return _get_messaging_service(app).send_all(messages, dry_run) - -def send_multicast(multicast_message, dry_run=False, app=None): - """Sends the given mutlicast message to all tokens via Firebase Cloud Messaging (FCM). - - If the ``dry_run`` mode is enabled, the message will not be actually delivered to the - recipients. Instead, FCM performs all the usual validations and emulates the send operation. - - Args: - multicast_message: An instance of ``messaging.MulticastMessage``. - dry_run: A boolean indicating whether to run the operation in dry run mode (optional). - app: An App instance (optional). - - Returns: - BatchResponse: A ``messaging.BatchResponse`` instance. - - Raises: - FirebaseError: If an error occurs while sending the message to the FCM service. - ValueError: If the input arguments are invalid. - - send_multicast() is deprecated. Use send_each_for_multicast() instead. - """ - warnings.warn('send_multicast() is deprecated. Use send_each_for_multicast() instead.', - DeprecationWarning) - if not isinstance(multicast_message, MulticastMessage): - raise ValueError('Message must be an instance of messaging.MulticastMessage class.') - messages = [Message( - data=multicast_message.data, - notification=multicast_message.notification, - android=multicast_message.android, - webpush=multicast_message.webpush, - apns=multicast_message.apns, - fcm_options=multicast_message.fcm_options, - token=token - ) for token in multicast_message.tokens] - return _get_messaging_service(app).send_all(messages, dry_run) - def subscribe_to_topic(tokens, topic, app=None): """Subscribes a list of registration tokens to an FCM topic. @@ -472,7 +407,6 @@ def __init__(self, app: App) -> None: self._client = _http_client.JsonHttpClient(credential=self._credential, timeout=timeout) self._async_client = _http_client.HttpxAsyncClient( credential=self._credential, timeout=timeout) - self._build_transport = _auth.authorized_http @classmethod def encode_message(cls, message): @@ -555,45 +489,6 @@ async def send_data(data): message='Unknown error while making remote service calls: {0}'.format(error), cause=error) - - def send_all(self, messages, dry_run=False): - """Sends the given messages to FCM via the batch API.""" - if not isinstance(messages, list): - raise ValueError('messages must be a list of messaging.Message instances.') - if len(messages) > 500: - raise ValueError('messages must not contain more than 500 elements.') - - responses = [] - - def batch_callback(_, response, error): - exception = None - if error: - exception = self._handle_batch_error(error) - send_response = SendResponse(response, exception) - responses.append(send_response) - - batch = http.BatchHttpRequest( - callback=batch_callback, batch_uri=_MessagingService.FCM_BATCH_URL) - transport = self._build_transport(self._credential) - for message in messages: - body = json.dumps(self._message_data(message, dry_run)) - req = http.HttpRequest( - http=transport, - postproc=self._postproc, - uri=self._fcm_url, - method='POST', - body=body, - headers=self._fcm_headers - ) - batch.add(req) - - try: - batch.execute() - except Exception as error: - raise self._handle_batch_error(error) - else: - return BatchResponse(responses) - def make_topic_management_request(self, tokens, topic, operation): """Invokes the IID service for topic management functionality.""" if isinstance(tokens, str): @@ -670,11 +565,6 @@ def _handle_iid_error(self, error): return _utils.handle_requests_error(error, msg) - def _handle_batch_error(self, error): - """Handles errors received from the googleapiclient while making batch requests.""" - return _gapic_utils.handle_platform_error_from_googleapiclient( - error, _MessagingService._build_fcm_error_googleapiclient) - def close(self) -> None: asyncio.run(self._async_client.aclose()) @@ -700,14 +590,6 @@ def _build_fcm_error_httpx( message, cause=error, http_response=error.response) if exc_type else None return exc_type(message, cause=error) if exc_type else None - - @classmethod - def _build_fcm_error_googleapiclient(cls, error, message, error_dict, http_response): - """Parses an error response from the FCM API and creates a FCM-specific exception if - appropriate.""" - exc_type = cls._build_fcm_error(error_dict) - return exc_type(message, cause=error, http_response=http_response) if exc_type else None - @classmethod def _build_fcm_error( cls, diff --git a/integration/test_messaging.py b/integration/test_messaging.py index 296a4d338..804691962 100644 --- a/integration/test_messaging.py +++ b/integration/test_messaging.py @@ -149,71 +149,6 @@ def test_send_each_for_multicast(): assert response.exception is not None assert response.message_id is None -@pytest.mark.skip(reason="Replaced with test_send_each") -def test_send_all(): - messages = [ - messaging.Message( - topic='foo-bar', notification=messaging.Notification('Title', 'Body')), - messaging.Message( - topic='foo-bar', notification=messaging.Notification('Title', 'Body')), - messaging.Message( - token='not-a-token', notification=messaging.Notification('Title', 'Body')), - ] - - batch_response = messaging.send_all(messages, dry_run=True) - - assert batch_response.success_count == 2 - assert batch_response.failure_count == 1 - assert len(batch_response.responses) == 3 - - response = batch_response.responses[0] - assert response.success is True - assert response.exception is None - assert re.match('^projects/.*/messages/.*$', response.message_id) - - response = batch_response.responses[1] - assert response.success is True - assert response.exception is None - assert re.match('^projects/.*/messages/.*$', response.message_id) - - response = batch_response.responses[2] - assert response.success is False - assert isinstance(response.exception, exceptions.InvalidArgumentError) - assert response.message_id is None - -@pytest.mark.skip(reason="Replaced with test_send_each_500") -def test_send_all_500(): - messages = [] - for msg_number in range(500): - topic = 'foo-bar-{0}'.format(msg_number % 10) - messages.append(messaging.Message(topic=topic)) - - batch_response = messaging.send_all(messages, dry_run=True) - - assert batch_response.success_count == 500 - assert batch_response.failure_count == 0 - assert len(batch_response.responses) == 500 - for response in batch_response.responses: - assert response.success is True - assert response.exception is None - assert re.match('^projects/.*/messages/.*$', response.message_id) - -@pytest.mark.skip(reason="Replaced with test_send_each_for_multicast") -def test_send_multicast(): - multicast = messaging.MulticastMessage( - notification=messaging.Notification('Title', 'Body'), - tokens=['not-a-token', 'also-not-a-token']) - - batch_response = messaging.send_multicast(multicast) - - assert batch_response.success_count == 0 - assert batch_response.failure_count == 2 - assert len(batch_response.responses) == 2 - for response in batch_response.responses: - assert response.success is False - assert response.exception is not None - assert response.message_id is None - def test_subscribe(): resp = messaging.subscribe_to_topic(_REGISTRATION_TOKEN, 'mock-topic') assert resp.success_count + resp.failure_count == 1 diff --git a/requirements.txt b/requirements.txt index ba6f2f947..b5642b549 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,6 @@ respx == 0.22.0 cachecontrol >= 0.12.14 google-api-core[grpc] >= 1.22.1, < 3.0.0dev; platform.python_implementation != 'PyPy' -google-api-python-client >= 1.7.8 google-cloud-firestore >= 2.19.0; platform.python_implementation != 'PyPy' google-cloud-storage >= 1.37.1 pyjwt[crypto] >= 2.5.0 diff --git a/setup.py b/setup.py index e92d207aa..b9eb11806 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,6 @@ install_requires = [ 'cachecontrol>=0.12.14', 'google-api-core[grpc] >= 1.22.1, < 3.0.0dev; platform.python_implementation != "PyPy"', - 'google-api-python-client >= 1.7.8', 'google-cloud-firestore>=2.19.0; platform.python_implementation != "PyPy"', 'google-cloud-storage>=1.37.1', 'pyjwt[crypto] >= 2.5.0', diff --git a/snippets/messaging/cloud_messaging.py b/snippets/messaging/cloud_messaging.py index bb63db065..18a992dcc 100644 --- a/snippets/messaging/cloud_messaging.py +++ b/snippets/messaging/cloud_messaging.py @@ -222,9 +222,9 @@ def unsubscribe_from_topic(): # [END unsubscribe] -def send_all(): +def send_each(): registration_token = 'YOUR_REGISTRATION_TOKEN' - # [START send_all] + # [START send_each] # Create a list containing up to 500 messages. messages = [ messaging.Message( @@ -238,15 +238,15 @@ def send_all(): ), ] - response = messaging.send_all(messages) + response = messaging.send_each(messages) # See the BatchResponse reference documentation # for the contents of response. print('{0} messages were sent successfully'.format(response.success_count)) - # [END send_all] + # [END send_each] -def send_multicast(): - # [START send_multicast] +def send_each_for_multicast(): + # [START send_each_for_multicast] # Create a list containing up to 500 registration tokens. # These registration tokens come from the client FCM SDKs. registration_tokens = [ @@ -259,15 +259,15 @@ def send_multicast(): data={'score': '850', 'time': '2:45'}, tokens=registration_tokens, ) - response = messaging.send_multicast(message) + response = messaging.send_each_for_multicast(message) # See the BatchResponse reference documentation # for the contents of response. print('{0} messages were sent successfully'.format(response.success_count)) - # [END send_multicast] + # [END send_each_for_multicast] -def send_multicast_and_handle_errors(): - # [START send_multicast_error] +def send_each_for_multicast_and_handle_errors(): + # [START send_each_for_multicast_error] # These registration tokens come from the client FCM SDKs. registration_tokens = [ 'YOUR_REGISTRATION_TOKEN_1', @@ -279,7 +279,7 @@ def send_multicast_and_handle_errors(): data={'score': '850', 'time': '2:45'}, tokens=registration_tokens, ) - response = messaging.send_multicast(message) + response = messaging.send_each_for_multicast(message) if response.failure_count > 0: responses = response.responses failed_tokens = [] @@ -288,4 +288,4 @@ def send_multicast_and_handle_errors(): # The order of responses corresponds to the order of the registration tokens. failed_tokens.append(registration_tokens[idx]) print('List of tokens that caused failures: {0}'.format(failed_tokens)) - # [END send_multicast_error] + # [END send_each_for_multicast_error] diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 4347c838a..fa1276feb 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -14,17 +14,12 @@ import io import json -import socket -import httplib2 -import pytest import requests from requests import models -from googleapiclient import errors from firebase_admin import exceptions from firebase_admin import _utils -from firebase_admin import _gapic_utils _NOT_FOUND_ERROR_DICT = { @@ -178,159 +173,3 @@ def _create_response(self, status=500, payload=None): resp.raw = io.BytesIO(payload.encode()) exc = requests.exceptions.RequestException('Test error', response=resp) return resp, exc - - -class TestGoogleApiClient: - - @pytest.mark.parametrize('error', [ - socket.timeout('Test error'), - socket.error('Read timed out') - ]) - def test_googleapicleint_timeout_error(self, error): - firebase_error = _gapic_utils.handle_googleapiclient_error(error) - assert isinstance(firebase_error, exceptions.DeadlineExceededError) - assert str(firebase_error) == 'Timed out while making an API call: {0}'.format(error) - assert firebase_error.cause is error - assert firebase_error.http_response is None - - def test_googleapiclient_connection_error(self): - error = httplib2.ServerNotFoundError('Test error') - firebase_error = _gapic_utils.handle_googleapiclient_error(error) - assert isinstance(firebase_error, exceptions.UnavailableError) - assert str(firebase_error) == 'Failed to establish a connection: Test error' - assert firebase_error.cause is error - assert firebase_error.http_response is None - - def test_unknown_transport_error(self): - error = socket.error('Test error') - firebase_error = _gapic_utils.handle_googleapiclient_error(error) - assert isinstance(firebase_error, exceptions.UnknownError) - assert str(firebase_error) == 'Unknown error while making a remote service call: Test error' - assert firebase_error.cause is error - assert firebase_error.http_response is None - - def test_http_response(self): - error = self._create_http_error() - firebase_error = _gapic_utils.handle_googleapiclient_error(error) - assert isinstance(firebase_error, exceptions.InternalError) - assert str(firebase_error) == str(error) - assert firebase_error.cause is error - assert firebase_error.http_response.status_code == 500 - assert firebase_error.http_response.content.decode() == 'Body' - - def test_http_response_with_unknown_status(self): - error = self._create_http_error(status=501) - firebase_error = _gapic_utils.handle_googleapiclient_error(error) - assert isinstance(firebase_error, exceptions.UnknownError) - assert str(firebase_error) == str(error) - assert firebase_error.cause is error - assert firebase_error.http_response.status_code == 501 - assert firebase_error.http_response.content.decode() == 'Body' - - def test_http_response_with_message(self): - error = self._create_http_error() - firebase_error = _gapic_utils.handle_googleapiclient_error( - error, message='Explicit error message') - assert isinstance(firebase_error, exceptions.InternalError) - assert str(firebase_error) == 'Explicit error message' - assert firebase_error.cause is error - assert firebase_error.http_response.status_code == 500 - assert firebase_error.http_response.content.decode() == 'Body' - - def test_http_response_with_code(self): - error = self._create_http_error() - firebase_error = _gapic_utils.handle_googleapiclient_error( - error, code=exceptions.UNAVAILABLE) - assert isinstance(firebase_error, exceptions.UnavailableError) - assert str(firebase_error) == str(error) - assert firebase_error.cause is error - assert firebase_error.http_response.status_code == 500 - assert firebase_error.http_response.content.decode() == 'Body' - - def test_http_response_with_message_and_code(self): - error = self._create_http_error() - firebase_error = _gapic_utils.handle_googleapiclient_error( - error, message='Explicit error message', code=exceptions.UNAVAILABLE) - assert isinstance(firebase_error, exceptions.UnavailableError) - assert str(firebase_error) == 'Explicit error message' - assert firebase_error.cause is error - assert firebase_error.http_response.status_code == 500 - assert firebase_error.http_response.content.decode() == 'Body' - - def test_handle_platform_error(self): - error = self._create_http_error(payload=_NOT_FOUND_PAYLOAD) - firebase_error = _gapic_utils.handle_platform_error_from_googleapiclient(error) - assert isinstance(firebase_error, exceptions.NotFoundError) - assert str(firebase_error) == 'test error' - assert firebase_error.cause is error - assert firebase_error.http_response.status_code == 500 - assert firebase_error.http_response.content.decode() == _NOT_FOUND_PAYLOAD - - def test_handle_platform_error_with_no_response(self): - error = socket.error('Test error') - firebase_error = _gapic_utils.handle_platform_error_from_googleapiclient(error) - assert isinstance(firebase_error, exceptions.UnknownError) - assert str(firebase_error) == 'Unknown error while making a remote service call: Test error' - assert firebase_error.cause is error - assert firebase_error.http_response is None - - def test_handle_platform_error_with_no_error_code(self): - error = self._create_http_error(payload='no error code') - firebase_error = _gapic_utils.handle_platform_error_from_googleapiclient(error) - assert isinstance(firebase_error, exceptions.InternalError) - message = 'Unexpected HTTP response with status: 500; body: no error code' - assert str(firebase_error) == message - assert firebase_error.cause is error - assert firebase_error.http_response.status_code == 500 - assert firebase_error.http_response.content.decode() == 'no error code' - - def test_handle_platform_error_with_custom_handler(self): - error = self._create_http_error(payload=_NOT_FOUND_PAYLOAD) - invocations = [] - - def _custom_handler(cause, message, error_dict, http_response): - invocations.append((cause, message, error_dict, http_response)) - return exceptions.InvalidArgumentError('Custom message', cause, http_response) - - firebase_error = _gapic_utils.handle_platform_error_from_googleapiclient( - error, _custom_handler) - - assert isinstance(firebase_error, exceptions.InvalidArgumentError) - assert str(firebase_error) == 'Custom message' - assert firebase_error.cause is error - assert firebase_error.http_response.status_code == 500 - assert firebase_error.http_response.content.decode() == _NOT_FOUND_PAYLOAD - assert len(invocations) == 1 - args = invocations[0] - assert len(args) == 4 - assert args[0] is error - assert args[1] == 'test error' - assert args[2] == _NOT_FOUND_ERROR_DICT - assert args[3] is not None - - def test_handle_platform_error_with_custom_handler_ignore(self): - error = self._create_http_error(payload=_NOT_FOUND_PAYLOAD) - invocations = [] - - def _custom_handler(cause, message, error_dict, http_response): - invocations.append((cause, message, error_dict, http_response)) - - firebase_error = _gapic_utils.handle_platform_error_from_googleapiclient( - error, _custom_handler) - - assert isinstance(firebase_error, exceptions.NotFoundError) - assert str(firebase_error) == 'test error' - assert firebase_error.cause is error - assert firebase_error.http_response.status_code == 500 - assert firebase_error.http_response.content.decode() == _NOT_FOUND_PAYLOAD - assert len(invocations) == 1 - args = invocations[0] - assert len(args) == 4 - assert args[0] is error - assert args[1] == 'test error' - assert args[2] == _NOT_FOUND_ERROR_DICT - assert args[3] is not None - - def _create_http_error(self, status=500, payload='Body'): - resp = httplib2.Response({'status': status}) - return errors.HttpError(resp, payload.encode()) diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 76cee2a33..341fd9e07 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -20,8 +20,6 @@ import httpx import respx -from googleapiclient import http -from googleapiclient import _helpers import pytest import firebase_admin @@ -1826,17 +1824,7 @@ def test_send_unknown_fcm_error_code(self, status): self._assert_request(recorder[0], 'POST', self._get_url('explicit-project-id'), body) -class _HttpMockException: - - def __init__(self, exc): - self._exc = exc - - def request(self, url, **kwargs): - raise self._exc - - -class TestBatch: - +class TestSendEach(): @classmethod def setup_class(cls): cred = testutils.MockCredential() @@ -1856,40 +1844,6 @@ def _instrument_messaging_service(self, response_dict, app=None): testutils.MockRequestBasedMultiRequestAdapter(response_dict, recorder)) return fcm_service, recorder - def _instrument_batch_messaging_service(self, app=None, status=200, payload='', exc=None): - def build_mock_transport(_): - if exc: - return _HttpMockException(exc) - - if status == 200: - content_type = 'multipart/mixed; boundary=boundary' - else: - content_type = 'application/json' - return http.HttpMockSequence([ - ({'status': str(status), 'content-type': content_type}, payload), - ]) - - if not app: - app = firebase_admin.get_app() - - fcm_service = messaging._get_messaging_service(app) - fcm_service._build_transport = build_mock_transport - return fcm_service - - def _batch_payload(self, payloads): - # payloads should be a list of (status_code, content) tuples - payload = '' - _playload_format = """--boundary\r\nContent-Type: application/http\r\n\ -Content-ID: \r\n\r\nHTTP/1.1 {} Success\r\n\ -Content-Type: application/json; charset=UTF-8\r\n\r\n{}\r\n\r\n""" - for (index, (status_code, content)) in enumerate(payloads): - payload += _playload_format.format(str(index + 1), str(status_code), content) - payload += '--boundary--' - return payload - - -class TestSendEach(TestBatch): - def test_no_project_id(self): def evaluate(): app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') @@ -1948,12 +1902,6 @@ async def test_send_each_async(self): batch_response = await messaging.send_each_async([msg1, msg2, msg3], dry_run=True) - # try: - # batch_response = await messaging.send_each_async([msg1, msg2], dry_run=True) - # except Exception as error: - # if isinstance(error.cause.__cause__, StopIteration): - # raise Exception('Received more requests than mocks') - assert batch_response.success_count == 3 assert batch_response.failure_count == 0 assert len(batch_response.responses) == 3 @@ -2217,19 +2165,19 @@ def test_send_each_fcm_error_code(self, status, fcm_error_code, exc_type): check_exception(exception, 'test error', status) -class TestSendEachForMulticast(TestBatch): +class TestSendEachForMulticast(TestSendEach): def test_no_project_id(self): def evaluate(): app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') with pytest.raises(ValueError): - messaging.send_all([messaging.Message(topic='foo')], app=app) + messaging.send_each([messaging.Message(topic='foo')], app=app) testutils.run_without_project_id(evaluate) @pytest.mark.parametrize('msg', NON_LIST_ARGS) def test_invalid_send_each_for_multicast(self, msg): with pytest.raises(ValueError) as excinfo: - messaging.send_multicast(msg) + messaging.send_each_for_multicast(msg) expected = 'Message must be an instance of messaging.MulticastMessage class.' assert str(excinfo.value) == expected @@ -2338,432 +2286,6 @@ def test_send_each_for_multicast_fcm_error_code(self, status): check_exception(exception, 'test error', status) -class TestSendAll(TestBatch): - - def test_no_project_id(self): - def evaluate(): - app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') - with pytest.raises(ValueError): - messaging.send_all([messaging.Message(topic='foo')], app=app) - testutils.run_without_project_id(evaluate) - - @pytest.mark.parametrize('msg', NON_LIST_ARGS) - def test_invalid_send_all(self, msg): - with pytest.raises(ValueError) as excinfo: - messaging.send_all(msg) - if isinstance(msg, list): - expected = 'Message must be an instance of messaging.Message class.' - assert str(excinfo.value) == expected - else: - expected = 'messages must be a list of messaging.Message instances.' - assert str(excinfo.value) == expected - - def test_invalid_over_500(self): - msg = messaging.Message(topic='foo') - with pytest.raises(ValueError) as excinfo: - messaging.send_all([msg for _ in range(0, 501)]) - expected = 'messages must not contain more than 500 elements.' - assert str(excinfo.value) == expected - - def test_send_all(self): - payload = json.dumps({'name': 'message-id'}) - _ = self._instrument_batch_messaging_service( - payload=self._batch_payload([(200, payload), (200, payload)])) - msg = messaging.Message(topic='foo') - batch_response = messaging.send_all([msg, msg], dry_run=True) - assert batch_response.success_count == 2 - assert batch_response.failure_count == 0 - assert len(batch_response.responses) == 2 - assert [r.message_id for r in batch_response.responses] == ['message-id', 'message-id'] - assert all([r.success for r in batch_response.responses]) - assert not any([r.exception for r in batch_response.responses]) - - def test_send_all_with_positional_param_enforcement(self): - payload = json.dumps({'name': 'message-id'}) - _ = self._instrument_batch_messaging_service( - payload=self._batch_payload([(200, payload), (200, payload)])) - msg = messaging.Message(topic='foo') - - enforcement = _helpers.positional_parameters_enforcement - _helpers.positional_parameters_enforcement = _helpers.POSITIONAL_EXCEPTION - try: - batch_response = messaging.send_all([msg, msg], dry_run=True) - assert batch_response.success_count == 2 - finally: - _helpers.positional_parameters_enforcement = enforcement - - @pytest.mark.parametrize('status', HTTP_ERROR_CODES) - def test_send_all_detailed_error(self, status): - success_payload = json.dumps({'name': 'message-id'}) - error_payload = json.dumps({ - 'error': { - 'status': 'INVALID_ARGUMENT', - 'message': 'test error' - } - }) - _ = self._instrument_batch_messaging_service( - payload=self._batch_payload([(200, success_payload), (status, error_payload)])) - msg = messaging.Message(topic='foo') - batch_response = messaging.send_all([msg, msg]) - assert batch_response.success_count == 1 - assert batch_response.failure_count == 1 - assert len(batch_response.responses) == 2 - success_response = batch_response.responses[0] - assert success_response.message_id == 'message-id' - assert success_response.success is True - assert success_response.exception is None - error_response = batch_response.responses[1] - assert error_response.message_id is None - assert error_response.success is False - exception = error_response.exception - assert isinstance(exception, exceptions.InvalidArgumentError) - check_exception(exception, 'test error', status) - - @pytest.mark.parametrize('status', HTTP_ERROR_CODES) - def test_send_all_canonical_error_code(self, status): - success_payload = json.dumps({'name': 'message-id'}) - error_payload = json.dumps({ - 'error': { - 'status': 'NOT_FOUND', - 'message': 'test error' - } - }) - _ = self._instrument_batch_messaging_service( - payload=self._batch_payload([(200, success_payload), (status, error_payload)])) - msg = messaging.Message(topic='foo') - batch_response = messaging.send_all([msg, msg]) - assert batch_response.success_count == 1 - assert batch_response.failure_count == 1 - assert len(batch_response.responses) == 2 - success_response = batch_response.responses[0] - assert success_response.message_id == 'message-id' - assert success_response.success is True - assert success_response.exception is None - error_response = batch_response.responses[1] - assert error_response.message_id is None - assert error_response.success is False - exception = error_response.exception - assert isinstance(exception, exceptions.NotFoundError) - check_exception(exception, 'test error', status) - - @pytest.mark.parametrize('status', HTTP_ERROR_CODES) - @pytest.mark.parametrize('fcm_error_code, exc_type', FCM_ERROR_CODES.items()) - def test_send_all_fcm_error_code(self, status, fcm_error_code, exc_type): - success_payload = json.dumps({'name': 'message-id'}) - error_payload = json.dumps({ - 'error': { - 'status': 'INVALID_ARGUMENT', - 'message': 'test error', - 'details': [ - { - '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', - 'errorCode': fcm_error_code, - }, - ], - } - }) - _ = self._instrument_batch_messaging_service( - payload=self._batch_payload([(200, success_payload), (status, error_payload)])) - msg = messaging.Message(topic='foo') - batch_response = messaging.send_all([msg, msg]) - assert batch_response.success_count == 1 - assert batch_response.failure_count == 1 - assert len(batch_response.responses) == 2 - success_response = batch_response.responses[0] - assert success_response.message_id == 'message-id' - assert success_response.success is True - assert success_response.exception is None - error_response = batch_response.responses[1] - assert error_response.message_id is None - assert error_response.success is False - exception = error_response.exception - assert isinstance(exception, exc_type) - check_exception(exception, 'test error', status) - - @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) - def test_send_all_batch_error(self, status, exc_type): - _ = self._instrument_batch_messaging_service(status=status, payload='{}') - msg = messaging.Message(topic='foo') - with pytest.raises(exc_type) as excinfo: - messaging.send_all([msg]) - expected = 'Unexpected HTTP response with status: {0}; body: {{}}'.format(status) - check_exception(excinfo.value, expected, status) - - @pytest.mark.parametrize('status', HTTP_ERROR_CODES) - def test_send_all_batch_detailed_error(self, status): - payload = json.dumps({ - 'error': { - 'status': 'INVALID_ARGUMENT', - 'message': 'test error' - } - }) - _ = self._instrument_batch_messaging_service(status=status, payload=payload) - msg = messaging.Message(topic='foo') - with pytest.raises(exceptions.InvalidArgumentError) as excinfo: - messaging.send_all([msg]) - check_exception(excinfo.value, 'test error', status) - - @pytest.mark.parametrize('status', HTTP_ERROR_CODES) - def test_send_all_batch_canonical_error_code(self, status): - payload = json.dumps({ - 'error': { - 'status': 'NOT_FOUND', - 'message': 'test error' - } - }) - _ = self._instrument_batch_messaging_service(status=status, payload=payload) - msg = messaging.Message(topic='foo') - with pytest.raises(exceptions.NotFoundError) as excinfo: - messaging.send_all([msg]) - check_exception(excinfo.value, 'test error', status) - - @pytest.mark.parametrize('status', HTTP_ERROR_CODES) - def test_send_all_batch_fcm_error_code(self, status): - payload = json.dumps({ - 'error': { - 'status': 'INVALID_ARGUMENT', - 'message': 'test error', - 'details': [ - { - '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', - 'errorCode': 'UNREGISTERED', - }, - ], - } - }) - _ = self._instrument_batch_messaging_service(status=status, payload=payload) - msg = messaging.Message(topic='foo') - with pytest.raises(messaging.UnregisteredError) as excinfo: - messaging.send_all([msg]) - check_exception(excinfo.value, 'test error', status) - - def test_send_all_runtime_exception(self): - exc = BrokenPipeError('Test error') - _ = self._instrument_batch_messaging_service(exc=exc) - msg = messaging.Message(topic='foo') - - with pytest.raises(exceptions.UnknownError) as excinfo: - messaging.send_all([msg]) - - expected = 'Unknown error while making a remote service call: Test error' - assert str(excinfo.value) == expected - assert excinfo.value.cause is exc - assert excinfo.value.http_response is None - - def test_send_transport_init(self): - def track_call_count(build_transport): - def wrapper(credential): - wrapper.calls += 1 - return build_transport(credential) - wrapper.calls = 0 - return wrapper - - payload = json.dumps({'name': 'message-id'}) - fcm_service = self._instrument_batch_messaging_service( - payload=self._batch_payload([(200, payload), (200, payload)])) - build_mock_transport = fcm_service._build_transport - fcm_service._build_transport = track_call_count(build_mock_transport) - msg = messaging.Message(topic='foo') - - batch_response = messaging.send_all([msg, msg], dry_run=True) - assert batch_response.success_count == 2 - assert fcm_service._build_transport.calls == 1 - - batch_response = messaging.send_all([msg, msg], dry_run=True) - assert batch_response.success_count == 2 - assert fcm_service._build_transport.calls == 2 - - -class TestSendMulticast(TestBatch): - - def test_no_project_id(self): - def evaluate(): - app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id') - with pytest.raises(ValueError): - messaging.send_all([messaging.Message(topic='foo')], app=app) - testutils.run_without_project_id(evaluate) - - @pytest.mark.parametrize('msg', NON_LIST_ARGS) - def test_invalid_send_multicast(self, msg): - with pytest.raises(ValueError) as excinfo: - messaging.send_multicast(msg) - expected = 'Message must be an instance of messaging.MulticastMessage class.' - assert str(excinfo.value) == expected - - def test_send_multicast(self): - payload = json.dumps({'name': 'message-id'}) - _ = self._instrument_batch_messaging_service( - payload=self._batch_payload([(200, payload), (200, payload)])) - msg = messaging.MulticastMessage(tokens=['foo', 'foo']) - batch_response = messaging.send_multicast(msg, dry_run=True) - assert batch_response.success_count == 2 - assert batch_response.failure_count == 0 - assert len(batch_response.responses) == 2 - assert [r.message_id for r in batch_response.responses] == ['message-id', 'message-id'] - assert all([r.success for r in batch_response.responses]) - assert not any([r.exception for r in batch_response.responses]) - - @pytest.mark.parametrize('status', HTTP_ERROR_CODES) - def test_send_multicast_detailed_error(self, status): - success_payload = json.dumps({'name': 'message-id'}) - error_payload = json.dumps({ - 'error': { - 'status': 'INVALID_ARGUMENT', - 'message': 'test error' - } - }) - _ = self._instrument_batch_messaging_service( - payload=self._batch_payload([(200, success_payload), (status, error_payload)])) - msg = messaging.MulticastMessage(tokens=['foo', 'foo']) - batch_response = messaging.send_multicast(msg) - assert batch_response.success_count == 1 - assert batch_response.failure_count == 1 - assert len(batch_response.responses) == 2 - success_response = batch_response.responses[0] - assert success_response.message_id == 'message-id' - assert success_response.success is True - assert success_response.exception is None - error_response = batch_response.responses[1] - assert error_response.message_id is None - assert error_response.success is False - assert error_response.exception is not None - exception = error_response.exception - assert isinstance(exception, exceptions.InvalidArgumentError) - check_exception(exception, 'test error', status) - - @pytest.mark.parametrize('status', HTTP_ERROR_CODES) - def test_send_multicast_canonical_error_code(self, status): - success_payload = json.dumps({'name': 'message-id'}) - error_payload = json.dumps({ - 'error': { - 'status': 'NOT_FOUND', - 'message': 'test error' - } - }) - _ = self._instrument_batch_messaging_service( - payload=self._batch_payload([(200, success_payload), (status, error_payload)])) - msg = messaging.MulticastMessage(tokens=['foo', 'foo']) - batch_response = messaging.send_multicast(msg) - assert batch_response.success_count == 1 - assert batch_response.failure_count == 1 - assert len(batch_response.responses) == 2 - success_response = batch_response.responses[0] - assert success_response.message_id == 'message-id' - assert success_response.success is True - assert success_response.exception is None - error_response = batch_response.responses[1] - assert error_response.message_id is None - assert error_response.success is False - assert error_response.exception is not None - exception = error_response.exception - assert isinstance(exception, exceptions.NotFoundError) - check_exception(exception, 'test error', status) - - @pytest.mark.parametrize('status', HTTP_ERROR_CODES) - def test_send_multicast_fcm_error_code(self, status): - success_payload = json.dumps({'name': 'message-id'}) - error_payload = json.dumps({ - 'error': { - 'status': 'INVALID_ARGUMENT', - 'message': 'test error', - 'details': [ - { - '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', - 'errorCode': 'UNREGISTERED', - }, - ], - } - }) - _ = self._instrument_batch_messaging_service( - payload=self._batch_payload([(200, success_payload), (status, error_payload)])) - msg = messaging.MulticastMessage(tokens=['foo', 'foo']) - batch_response = messaging.send_multicast(msg) - assert batch_response.success_count == 1 - assert batch_response.failure_count == 1 - assert len(batch_response.responses) == 2 - success_response = batch_response.responses[0] - assert success_response.message_id == 'message-id' - assert success_response.success is True - assert success_response.exception is None - error_response = batch_response.responses[1] - assert error_response.message_id is None - assert error_response.success is False - assert error_response.exception is not None - exception = error_response.exception - assert isinstance(exception, messaging.UnregisteredError) - check_exception(exception, 'test error', status) - - @pytest.mark.parametrize('status, exc_type', HTTP_ERROR_CODES.items()) - def test_send_multicast_batch_error(self, status, exc_type): - _ = self._instrument_batch_messaging_service(status=status, payload='{}') - msg = messaging.MulticastMessage(tokens=['foo']) - with pytest.raises(exc_type) as excinfo: - messaging.send_multicast(msg) - expected = 'Unexpected HTTP response with status: {0}; body: {{}}'.format(status) - check_exception(excinfo.value, expected, status) - - @pytest.mark.parametrize('status', HTTP_ERROR_CODES) - def test_send_multicast_batch_detailed_error(self, status): - payload = json.dumps({ - 'error': { - 'status': 'INVALID_ARGUMENT', - 'message': 'test error' - } - }) - _ = self._instrument_batch_messaging_service(status=status, payload=payload) - msg = messaging.MulticastMessage(tokens=['foo']) - with pytest.raises(exceptions.InvalidArgumentError) as excinfo: - messaging.send_multicast(msg) - check_exception(excinfo.value, 'test error', status) - - @pytest.mark.parametrize('status', HTTP_ERROR_CODES) - def test_send_multicast_batch_canonical_error_code(self, status): - payload = json.dumps({ - 'error': { - 'status': 'NOT_FOUND', - 'message': 'test error' - } - }) - _ = self._instrument_batch_messaging_service(status=status, payload=payload) - msg = messaging.MulticastMessage(tokens=['foo']) - with pytest.raises(exceptions.NotFoundError) as excinfo: - messaging.send_multicast(msg) - check_exception(excinfo.value, 'test error', status) - - @pytest.mark.parametrize('status', HTTP_ERROR_CODES) - def test_send_multicast_batch_fcm_error_code(self, status): - payload = json.dumps({ - 'error': { - 'status': 'INVALID_ARGUMENT', - 'message': 'test error', - 'details': [ - { - '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', - 'errorCode': 'UNREGISTERED', - }, - ], - } - }) - _ = self._instrument_batch_messaging_service(status=status, payload=payload) - msg = messaging.MulticastMessage(tokens=['foo']) - with pytest.raises(messaging.UnregisteredError) as excinfo: - messaging.send_multicast(msg) - check_exception(excinfo.value, 'test error', status) - - def test_send_multicast_runtime_exception(self): - exc = BrokenPipeError('Test error') - _ = self._instrument_batch_messaging_service(exc=exc) - msg = messaging.MulticastMessage(tokens=['foo']) - - with pytest.raises(exceptions.UnknownError) as excinfo: - messaging.send_multicast(msg) - - expected = 'Unknown error while making a remote service call: Test error' - assert str(excinfo.value) == expected - assert excinfo.value.cause is exc - assert excinfo.value.http_response is None - - class TestTopicManagement: _DEFAULT_RESPONSE = json.dumps({'results': [{}, {'error': 'error_reason'}]}) From dae267c1f93450852de904627f364706718f8356 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Tue, 17 Jun 2025 14:59:53 -0400 Subject: [PATCH 02/13] chore(deps): Bump minimum supported Python version to 3.9 and add 3.13 to CIs (#892) * chore(deps): Bump minimum supported Python version to 3.9 and add 3.13 to CIs * fix deprecation warnings * fix GHA build status svg * fix: Correctly scope async eventloop * fix: Bump pylint to v2.7.4 and astroid to v2.5.8 to fix lint issues * fix ml tests * fix lint * fix: remove commented code --- .github/workflows/ci.yml | 6 +++--- .github/workflows/nightly.yml | 5 +++-- .github/workflows/release.yml | 5 +++-- CONTRIBUTING.md | 2 +- README.md | 6 +++--- firebase_admin/__init__.py | 5 +++-- firebase_admin/_auth_providers.py | 6 +++--- firebase_admin/_auth_utils.py | 16 ++++++++-------- firebase_admin/_sseclient.py | 2 +- firebase_admin/_token_gen.py | 6 +++--- firebase_admin/_user_import.py | 6 +++--- firebase_admin/_user_mgt.py | 16 ++++++++-------- firebase_admin/app_check.py | 20 ++++++++++---------- firebase_admin/credentials.py | 10 +++++----- firebase_admin/db.py | 2 +- firebase_admin/messaging.py | 5 ++++- firebase_admin/ml.py | 5 +---- firebase_admin/project_management.py | 8 ++++---- firebase_admin/storage.py | 4 ++-- firebase_admin/tenant_mgt.py | 5 +---- integration/conftest.py | 7 ------- integration/test_firestore_async.py | 8 ++++---- integration/test_messaging.py | 6 +++--- integration/test_ml.py | 14 +++++++++----- integration/test_storage.py | 2 +- requirements.txt | 8 ++++---- setup.cfg | 2 ++ setup.py | 7 +++---- tests/test_db.py | 2 +- tests/test_messaging.py | 23 ++++++++++++----------- tests/test_ml.py | 6 +++--- tests/test_remote_config.py | 2 +- tests/test_sseclient.py | 4 ++-- tests/test_tenant_mgt.py | 6 +++--- tests/testutils.py | 2 +- 35 files changed, 119 insertions(+), 120 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4cc8ec481..bfd29e2cc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,7 +8,7 @@ jobs: strategy: fail-fast: false matrix: - python: ['3.8', '3.9', '3.10', '3.11', '3.12', 'pypy3.9'] + python: ['3.9', '3.10', '3.11', '3.12', '3.13', 'pypy3.9'] steps: - uses: actions/checkout@v4 @@ -35,10 +35,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - name: Set up Python 3.8 + - name: Set up Python 3.9 uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: 3.9 - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 282cb1b91..3d5420537 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -36,7 +36,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: 3.9 - name: Install dependencies run: | @@ -45,6 +45,7 @@ jobs: pip install setuptools wheel pip install tensorflow pip install keras + pip install build - name: Run unit tests run: pytest @@ -57,7 +58,7 @@ jobs: # Build the Python Wheel and the source distribution. - name: Package release artifacts - run: python setup.py bdist_wheel sdist + run: python -m build # Attach the packaged artifacts to the workflow output. These can be manually # downloaded for later inspection if necessary. diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7a7986a5a..6cd1d3f07 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -47,7 +47,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: 3.9 - name: Install dependencies run: | @@ -56,6 +56,7 @@ jobs: pip install setuptools wheel pip install tensorflow pip install keras + pip install build - name: Run unit tests run: pytest @@ -68,7 +69,7 @@ jobs: # Build the Python Wheel and the source distribution. - name: Package release artifacts - run: python setup.py bdist_wheel sdist + run: python -m build # Attach the packaged artifacts to the workflow output. These can be manually # downloaded for later inspection if necessary. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index de5934866..72933a24f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -85,7 +85,7 @@ information on using pull requests. ### Initial Setup -You need Python 3.8+ to build and test the code in this repo. +You need Python 3.9+ to build and test the code in this repo. We recommend using [pip](https://pypi.python.org/pypi/pip) for installing the necessary tools and project dependencies. Most recent versions of Python ship with pip. If your development environment diff --git a/README.md b/README.md index 6e3ed6805..29303fd4f 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -[![Build Status](https://travis-ci.org/firebase/firebase-admin-python.svg?branch=master)](https://travis-ci.org/firebase/firebase-admin-python) +[![Nightly Builds](https://github.com/firebase/firebase-admin-python/actions/workflows/nightly.yml/badge.svg)](https://github.com/firebase/firebase-admin-python/actions/workflows/nightly.yml) [![Python](https://img.shields.io/pypi/pyversions/firebase-admin.svg)](https://pypi.org/project/firebase-admin/) [![Version](https://img.shields.io/pypi/v/firebase-admin.svg)](https://pypi.org/project/firebase-admin/) @@ -43,8 +43,8 @@ requests, code review feedback, and also pull requests. ## Supported Python Versions -We currently support Python 3.7+. However, Python 3.7 and Python 3.8 support is deprecated, -and developers are strongly advised to use Python 3.9 or higher. Firebase +We currently support Python 3.9+. However, Python 3.9 support is deprecated, +and developers are strongly advised to use Python 3.10 or higher. Firebase Admin Python SDK is also tested on PyPy and [Google App Engine](https://cloud.google.com/appengine/) environments. diff --git a/firebase_admin/__init__.py b/firebase_admin/__init__.py index 7bb9c59c2..597aaa6b6 100644 --- a/firebase_admin/__init__.py +++ b/firebase_admin/__init__.py @@ -178,11 +178,12 @@ def _load_from_environment(self): with open(config_file, 'r') as json_file: json_str = json_file.read() except Exception as err: - raise ValueError('Unable to read file {}. {}'.format(config_file, err)) + raise ValueError('Unable to read file {}. {}'.format(config_file, err)) from err try: json_data = json.loads(json_str) except Exception as err: - raise ValueError('JSON string "{0}" is not valid json. {1}'.format(json_str, err)) + raise ValueError( + 'JSON string "{0}" is not valid json. {1}'.format(json_str, err)) from err return {k: v for k, v in json_data.items() if k in _CONFIG_VALID_KEYS} diff --git a/firebase_admin/_auth_providers.py b/firebase_admin/_auth_providers.py index 31894a4dc..6512a4f7b 100644 --- a/firebase_admin/_auth_providers.py +++ b/firebase_admin/_auth_providers.py @@ -422,13 +422,13 @@ def _validate_url(url, label): if not parsed.netloc: raise ValueError('Malformed {0}: "{1}".'.format(label, url)) return url - except Exception: - raise ValueError('Malformed {0}: "{1}".'.format(label, url)) + except Exception as exception: + raise ValueError('Malformed {0}: "{1}".'.format(label, url)) from exception def _validate_x509_certificates(x509_certificates): if not isinstance(x509_certificates, list) or not x509_certificates: raise ValueError('x509_certificates must be a non-empty list.') - if not all([isinstance(cert, str) and cert for cert in x509_certificates]): + if not all(isinstance(cert, str) and cert for cert in x509_certificates): raise ValueError('x509_certificates must only contain non-empty strings.') return [{'x509Certificate': cert} for cert in x509_certificates] diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index ac7b322ff..0d56ca7fa 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -175,8 +175,8 @@ def validate_photo_url(photo_url, required=False): if not parsed.netloc: raise ValueError('Malformed photo URL: "{0}".'.format(photo_url)) return photo_url - except Exception: - raise ValueError('Malformed photo URL: "{0}".'.format(photo_url)) + except Exception as err: + raise ValueError('Malformed photo URL: "{0}".'.format(photo_url)) from err def validate_timestamp(timestamp, label, required=False): """Validates the given timestamp value. Timestamps must be positive integers.""" @@ -186,8 +186,8 @@ def validate_timestamp(timestamp, label, required=False): raise ValueError('Boolean value specified as timestamp.') try: timestamp_int = int(timestamp) - except TypeError: - raise ValueError('Invalid type for timestamp value: {0}.'.format(timestamp)) + except TypeError as err: + raise ValueError('Invalid type for timestamp value: {0}.'.format(timestamp)) from err else: if timestamp_int != timestamp: raise ValueError('{0} must be a numeric value and a whole number.'.format(label)) @@ -207,8 +207,8 @@ def validate_int(value, label, low=None, high=None): raise ValueError('Invalid type for integer value: {0}.'.format(value)) try: val_int = int(value) - except TypeError: - raise ValueError('Invalid type for integer value: {0}.'.format(value)) + except TypeError as err: + raise ValueError('Invalid type for integer value: {0}.'.format(value)) from err else: if val_int != value: # This will be True for non-numeric values like '2' and non-whole numbers like 2.5. @@ -246,8 +246,8 @@ def validate_custom_claims(custom_claims, required=False): MAX_CLAIMS_PAYLOAD_SIZE)) try: parsed = json.loads(claims_str) - except Exception: - raise ValueError('Failed to parse custom claims string as JSON.') + except Exception as err: + raise ValueError('Failed to parse custom claims string as JSON.') from err if not isinstance(parsed, dict): raise ValueError('Custom claims must be parseable as a JSON object.') diff --git a/firebase_admin/_sseclient.py b/firebase_admin/_sseclient.py index 6585dfc80..ec20cb45c 100644 --- a/firebase_admin/_sseclient.py +++ b/firebase_admin/_sseclient.py @@ -34,7 +34,7 @@ class KeepAuthSession(transport.requests.AuthorizedSession): """A session that does not drop authentication on redirects between domains.""" def __init__(self, credential): - super(KeepAuthSession, self).__init__(credential) + super().__init__(credential) def rebuild_auth(self, prepared_request, response): pass diff --git a/firebase_admin/_token_gen.py b/firebase_admin/_token_gen.py index a2fc725e8..6d82bf7a6 100644 --- a/firebase_admin/_token_gen.py +++ b/firebase_admin/_token_gen.py @@ -158,7 +158,7 @@ def signing_provider(self): 'Failed to determine service account: {0}. Make sure to initialize the SDK ' 'with service account credentials or specify a service account ID with ' 'iam.serviceAccounts.signBlob permission. Please refer to {1} for more ' - 'details on creating custom tokens.'.format(error, url)) + 'details on creating custom tokens.'.format(error, url)) from error return self._signing_provider def create_custom_token(self, uid, developer_claims=None, tenant_id=None): @@ -203,7 +203,7 @@ def create_custom_token(self, uid, developer_claims=None, tenant_id=None): return jwt.encode(signing_provider.signer, payload, header=header) except google.auth.exceptions.TransportError as error: msg = 'Failed to sign custom token. {0}'.format(error) - raise TokenSignError(msg, error) + raise TokenSignError(msg, error) from error def create_session_cookie(self, id_token, expires_in): @@ -403,7 +403,7 @@ def verify(self, token, request, clock_skew_seconds=0): verified_claims['uid'] = verified_claims['sub'] return verified_claims except google.auth.exceptions.TransportError as error: - raise CertificateFetchError(str(error), cause=error) + raise CertificateFetchError(str(error), cause=error) from error except ValueError as error: if 'Token expired' in str(error): raise self._expired_token_error(str(error), cause=error) diff --git a/firebase_admin/_user_import.py b/firebase_admin/_user_import.py index 659a68701..7c7a9e70b 100644 --- a/firebase_admin/_user_import.py +++ b/firebase_admin/_user_import.py @@ -216,10 +216,10 @@ def provider_data(self): def provider_data(self, provider_data): if provider_data is not None: try: - if any([not isinstance(p, UserProvider) for p in provider_data]): + if any(not isinstance(p, UserProvider) for p in provider_data): raise ValueError('One or more provider data instances are invalid.') - except TypeError: - raise ValueError('provider_data must be iterable.') + except TypeError as err: + raise ValueError('provider_data must be iterable.') from err self._provider_data = provider_data @property diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index aa0dfb0a4..957b749a6 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -128,7 +128,7 @@ class UserRecord(UserInfo): """Contains metadata associated with a Firebase user account.""" def __init__(self, data): - super(UserRecord, self).__init__() + super().__init__() if not isinstance(data, dict): raise ValueError('Invalid data argument: {0}. Must be a dictionary.'.format(data)) if not data.get('localId'): @@ -452,7 +452,7 @@ class ProviderUserInfo(UserInfo): """Contains metadata regarding how a user is known by a particular identity provider.""" def __init__(self, data): - super(ProviderUserInfo, self).__init__() + super().__init__() if not isinstance(data, dict): raise ValueError('Invalid data argument: {0}. Must be a dictionary.'.format(data)) if not data.get('rawId'): @@ -518,8 +518,8 @@ def encode_action_code_settings(settings): if not parsed.netloc: raise ValueError('Malformed dynamic action links url: "{0}".'.format(settings.url)) parameters['continueUrl'] = settings.url - except Exception: - raise ValueError('Malformed dynamic action links url: "{0}".'.format(settings.url)) + except Exception as err: + raise ValueError('Malformed dynamic action links url: "{0}".'.format(settings.url)) from err # handle_code_in_app if settings.handle_code_in_app is not None: @@ -788,13 +788,13 @@ def import_users(self, users, hash_alg=None): raise ValueError( 'Users must be a non-empty list with no more than {0} elements.'.format( MAX_IMPORT_USERS_SIZE)) - if any([not isinstance(u, _user_import.ImportUserRecord) for u in users]): + if any(not isinstance(u, _user_import.ImportUserRecord) for u in users): raise ValueError('One or more user objects are invalid.') - except TypeError: - raise ValueError('users must be iterable') + except TypeError as err: + raise ValueError('users must be iterable') from err payload = {'users': [u.to_dict() for u in users]} - if any(['passwordHash' in u for u in payload['users']]): + if any('passwordHash' in u for u in payload['users']): if not isinstance(hash_alg, _user_import.UserImportHash): raise ValueError('A UserImportHash is required to import users with passwords.') payload.update(hash_alg.to_dict()) diff --git a/firebase_admin/app_check.py b/firebase_admin/app_check.py index 53686db3d..1224f7d80 100644 --- a/firebase_admin/app_check.py +++ b/firebase_admin/app_check.py @@ -84,7 +84,7 @@ def verify_token(self, token: str) -> Dict[str, Any]: except (InvalidTokenError, DecodeError) as exception: raise ValueError( f'Verifying App Check token failed. Error: {exception}' - ) + ) from exception verified_claims['app_id'] = verified_claims.get('sub') return verified_claims @@ -112,28 +112,28 @@ def _decode_and_verify(self, token: str, signing_key: str): algorithms=["RS256"], audience=self._scoped_project_id ) - except InvalidSignatureError: + except InvalidSignatureError as exception: raise ValueError( 'The provided App Check token has an invalid signature.' - ) - except InvalidAudienceError: + ) from exception + except InvalidAudienceError as exception: raise ValueError( 'The provided App Check token has an incorrect "aud" (audience) claim. ' f'Expected payload to include {self._scoped_project_id}.' - ) - except InvalidIssuerError: + ) from exception + except InvalidIssuerError as exception: raise ValueError( 'The provided App Check token has an incorrect "iss" (issuer) claim. ' f'Expected claim to include {self._APP_CHECK_ISSUER}' - ) - except ExpiredSignatureError: + ) from exception + except ExpiredSignatureError as exception: raise ValueError( 'The provided App Check token has expired.' - ) + ) from exception except InvalidTokenError as exception: raise ValueError( f'Decoding App Check token failed. Error: {exception}' - ) + ) from exception audience = payload.get('aud') if not isinstance(audience, list) or self._scoped_project_id not in audience: diff --git a/firebase_admin/credentials.py b/firebase_admin/credentials.py index 750600280..8259c93b4 100644 --- a/firebase_admin/credentials.py +++ b/firebase_admin/credentials.py @@ -63,7 +63,7 @@ class _ExternalCredentials(Base): """A wrapper for google.auth.credentials.Credentials typed credential instances""" def __init__(self, credential: GoogleAuthCredentials): - super(_ExternalCredentials, self).__init__() + super().__init__() self._g_credential = credential def get_credential(self): @@ -92,7 +92,7 @@ def __init__(self, cert): IOError: If the specified certificate file doesn't exist or cannot be read. ValueError: If the specified certificate is invalid. """ - super(Certificate, self).__init__() + super().__init__() if _is_file_path(cert): with open(cert) as json_file: json_data = json.load(json_file) @@ -111,7 +111,7 @@ def __init__(self, cert): json_data, scopes=_scopes) except ValueError as error: raise ValueError('Failed to initialize a certificate credential. ' - 'Caused by: "{0}"'.format(error)) + 'Caused by: "{0}"'.format(error)) from error @property def project_id(self): @@ -142,7 +142,7 @@ def __init__(self): The credentials will be lazily initialized when get_credential() or project_id() is called. See those methods for possible errors raised. """ - super(ApplicationDefault, self).__init__() + super().__init__() self._g_credential = None # Will be lazily-loaded via _load_credential(). def get_credential(self): @@ -193,7 +193,7 @@ def __init__(self, refresh_token): IOError: If the specified file doesn't exist or cannot be read. ValueError: If the refresh token configuration is invalid. """ - super(RefreshToken, self).__init__() + super().__init__() if _is_file_path(refresh_token): with open(refresh_token) as json_file: json_data = json.load(json_file) diff --git a/firebase_admin/db.py b/firebase_admin/db.py index 1dec98653..fc69cbd83 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -926,7 +926,7 @@ def request(self, method, url, **kwargs): kwargs['params'] = query try: - return super(_Client, self).request(method, url, **kwargs) + return super().request(method, url, **kwargs) except requests.exceptions.RequestException as error: raise _Client.handle_rtdb_error(error) diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 0e3a55f49..5b2e48e80 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -451,7 +451,7 @@ def send_data(data): message_data = [self._message_data(message, dry_run) for message in messages] try: with concurrent.futures.ThreadPoolExecutor(max_workers=len(message_data)) as executor: - responses = [resp for resp in executor.map(send_data, message_data)] + responses = list(executor.map(send_data, message_data)) return BatchResponse(responses) except Exception as error: raise exceptions.UnknownError( @@ -573,6 +573,7 @@ def _build_fcm_error_requests(cls, error, message, error_dict): """Parses an error response from the FCM API and creates a FCM-specific exception if appropriate.""" exc_type = cls._build_fcm_error(error_dict) + # pylint: disable=not-callable return exc_type(message, cause=error, http_response=error.response) if exc_type else None @classmethod @@ -586,8 +587,10 @@ def _build_fcm_error_httpx( appropriate.""" exc_type = cls._build_fcm_error(error_dict) if isinstance(error, httpx.HTTPStatusError): + # pylint: disable=not-callable return exc_type( message, cause=error, http_response=error.response) if exc_type else None + # pylint: disable=not-callable return exc_type(message, cause=error) if exc_type else None @classmethod diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index 98bdbb56a..8cedc8482 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -721,7 +721,7 @@ def __init__(self, current_page): self._current_page = current_page self._index = 0 - def next(self): + def __next__(self): if self._index == len(self._current_page.models): if self._current_page.has_next_page: self._current_page = self._current_page.get_next_page() @@ -732,9 +732,6 @@ def next(self): return result raise StopIteration - def __next__(self): - return self.next() - def __iter__(self): return self diff --git a/firebase_admin/project_management.py b/firebase_admin/project_management.py index ed292b80f..9405c8318 100644 --- a/firebase_admin/project_management.py +++ b/firebase_admin/project_management.py @@ -338,7 +338,7 @@ class AndroidAppMetadata(_AppMetadata): def __init__(self, package_name, name, app_id, display_name, project_id): """Clients should not instantiate this class directly.""" - super(AndroidAppMetadata, self).__init__(name, app_id, display_name, project_id) + super().__init__(name, app_id, display_name, project_id) self._package_name = _check_is_nonempty_string(package_name, 'package_name') @property @@ -347,7 +347,7 @@ def package_name(self): return self._package_name def __eq__(self, other): - return (super(AndroidAppMetadata, self).__eq__(other) and + return (super().__eq__(other) and self.package_name == other.package_name) def __ne__(self, other): @@ -363,7 +363,7 @@ class IOSAppMetadata(_AppMetadata): def __init__(self, bundle_id, name, app_id, display_name, project_id): """Clients should not instantiate this class directly.""" - super(IOSAppMetadata, self).__init__(name, app_id, display_name, project_id) + super().__init__(name, app_id, display_name, project_id) self._bundle_id = _check_is_nonempty_string(bundle_id, 'bundle_id') @property @@ -372,7 +372,7 @@ def bundle_id(self): return self._bundle_id def __eq__(self, other): - return super(IOSAppMetadata, self).__eq__(other) and self.bundle_id == other.bundle_id + return super().__eq__(other) and self.bundle_id == other.bundle_id def __ne__(self, other): return not self.__eq__(other) diff --git a/firebase_admin/storage.py b/firebase_admin/storage.py index b6084842a..567a6abad 100644 --- a/firebase_admin/storage.py +++ b/firebase_admin/storage.py @@ -21,9 +21,9 @@ # pylint: disable=import-error,no-name-in-module try: from google.cloud import storage -except ImportError: +except ImportError as exception: raise ImportError('Failed to import the Cloud Storage library for Python. Make sure ' - 'to install the "google-cloud-storage" module.') + 'to install the "google-cloud-storage" module.') from exception from firebase_admin import _utils diff --git a/firebase_admin/tenant_mgt.py b/firebase_admin/tenant_mgt.py index 8c53e30a1..133e80b45 100644 --- a/firebase_admin/tenant_mgt.py +++ b/firebase_admin/tenant_mgt.py @@ -417,7 +417,7 @@ def __init__(self, current_page): self._current_page = current_page self._index = 0 - def next(self): + def __next__(self): if self._index == len(self._current_page.tenants): if self._current_page.has_next_page: self._current_page = self._current_page.get_next_page() @@ -428,9 +428,6 @@ def next(self): return result raise StopIteration - def __next__(self): - return self.next() - def __iter__(self): return self diff --git a/integration/conftest.py b/integration/conftest.py index efa45932d..169e02d5b 100644 --- a/integration/conftest.py +++ b/integration/conftest.py @@ -16,7 +16,6 @@ import json import pytest -from pytest_asyncio import is_async_test import firebase_admin from firebase_admin import credentials @@ -71,9 +70,3 @@ def api_key(request): 'command-line option.') with open(path) as keyfile: return keyfile.read().strip() - -def pytest_collection_modifyitems(items): - pytest_asyncio_tests = (item for item in items if is_async_test(item)) - session_scope_marker = pytest.mark.asyncio(loop_scope="session") - for async_test in pytest_asyncio_tests: - async_test.add_marker(session_scope_marker, append=False) diff --git a/integration/test_firestore_async.py b/integration/test_firestore_async.py index 8b73dda0f..584ef590a 100644 --- a/integration/test_firestore_async.py +++ b/integration/test_firestore_async.py @@ -34,7 +34,7 @@ } -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") async def test_firestore_async(): client = firestore_async.client() expected = _CITY @@ -48,7 +48,7 @@ async def test_firestore_async(): data = await doc.get() assert data.exists is False -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") async def test_firestore_async_explicit_database_id(): client = firestore_async.client(database_id='testing-database') expected = _CITY @@ -62,7 +62,7 @@ async def test_firestore_async_explicit_database_id(): data = await doc.get() assert data.exists is False -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") async def test_firestore_async_multi_db(): city_client = firestore_async.client() movie_client = firestore_async.client(database_id='testing-database') @@ -98,7 +98,7 @@ async def test_firestore_async_multi_db(): assert data[0].exists is False assert data[1].exists is False -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") async def test_server_timestamp(): client = firestore_async.client() expected = { diff --git a/integration/test_messaging.py b/integration/test_messaging.py index 804691962..7ab707c82 100644 --- a/integration/test_messaging.py +++ b/integration/test_messaging.py @@ -157,7 +157,7 @@ def test_unsubscribe(): resp = messaging.unsubscribe_from_topic(_REGISTRATION_TOKEN, 'mock-topic') assert resp.success_count + resp.failure_count == 1 -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") async def test_send_each_async(): messages = [ messaging.Message( @@ -189,7 +189,7 @@ async def test_send_each_async(): assert isinstance(response.exception, exceptions.InvalidArgumentError) assert response.message_id is None -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") async def test_send_each_async_500(): messages = [] for msg_number in range(500): @@ -206,7 +206,7 @@ async def test_send_each_async_500(): assert response.exception is None assert re.match('^projects/.*/messages/.*$', response.message_id) -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") async def test_send_each_for_multicast_async(): multicast = messaging.MulticastMessage( notification=messaging.Notification('Title', 'Body'), diff --git a/integration/test_ml.py b/integration/test_ml.py index 52cb1bb7e..f8dd6bb47 100644 --- a/integration/test_ml.py +++ b/integration/test_ml.py @@ -317,12 +317,16 @@ def _clean_up_directory(save_dir): @pytest.fixture def keras_model(): assert _TF_ENABLED - x_array = [-1, 0, 1, 2, 3, 4] - y_array = [-3, -1, 1, 3, 5, 7] - model = tf.keras.models.Sequential( - [tf.keras.layers.Dense(units=1, input_shape=[1])]) + x_list = [-1, 0, 1, 2, 3, 4] + y_list = [-3, -1, 1, 3, 5, 7] + x_tensor = tf.convert_to_tensor(x_list, dtype=tf.float32) + y_tensor = tf.convert_to_tensor(y_list, dtype=tf.float32) + model = tf.keras.models.Sequential([ + tf.keras.Input(shape=(1,)), + tf.keras.layers.Dense(units=1) + ]) model.compile(optimizer='sgd', loss='mean_squared_error') - model.fit(x_array, y_array, epochs=3) + model.fit(x_tensor, y_tensor, epochs=3) return model diff --git a/integration/test_storage.py b/integration/test_storage.py index 729190950..4f0faf76c 100644 --- a/integration/test_storage.py +++ b/integration/test_storage.py @@ -38,7 +38,7 @@ def _verify_bucket(bucket, expected_name): blob.upload_from_string('Hello World') blob = bucket.get_blob(file_name) - assert blob.download_as_string().decode() == 'Hello World' + assert blob.download_as_bytes().decode() == 'Hello World' bucket.delete_blob(file_name) assert not bucket.get_blob(file_name) diff --git a/requirements.txt b/requirements.txt index b5642b549..76eeb7582 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,9 @@ -astroid == 2.3.3 -pylint == 2.3.1 -pytest >= 6.2.0 +astroid == 2.5.8 +pylint == 2.7.4 +pytest >= 8.2.2 pytest-cov >= 2.4.0 pytest-localserver >= 0.4.1 -pytest-asyncio >= 0.16.0 +pytest-asyncio >= 0.26.0 pytest-mock >= 3.6.1 respx == 0.22.0 diff --git a/setup.cfg b/setup.cfg index 25c649748..32e00676b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,2 +1,4 @@ [tool:pytest] testpaths = tests +asyncio_default_test_loop_scope = class +asyncio_default_fixture_loop_scope = None diff --git a/setup.py b/setup.py index b9eb11806..25cf12672 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ (major, minor) = (sys.version_info.major, sys.version_info.minor) if major != 3 or minor < 7: - print('firebase_admin requires python >= 3.7', file=sys.stderr) + print('firebase_admin requires python >= 3.9', file=sys.stderr) sys.exit(1) # Read in the package metadata per recommendations from: @@ -60,18 +60,17 @@ keywords='firebase cloud development', install_requires=install_requires, packages=['firebase_admin'], - python_requires='>=3.7', + python_requires='>=3.9', classifiers=[ 'Development Status :: 5 - Production/Stable', 'Intended Audience :: Developers', 'Topic :: Software Development :: Build Tools', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', 'Programming Language :: Python :: 3.12', + 'Programming Language :: Python :: 3.13', 'License :: OSI Approved :: Apache Software License', ], ) diff --git a/tests/test_db.py b/tests/test_db.py index 00a0077cb..93f4672f1 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -45,7 +45,7 @@ def __init__(self, data, status, recorder, etag=ETAG): def send(self, request, **kwargs): if_match = request.headers.get('if-match') if_none_match = request.headers.get('if-none-match') - resp = super(MockAdapter, self).send(request, **kwargs) + resp = super().send(request, **kwargs) resp.headers = {'ETag': self._etag} if if_match and if_match != MockAdapter.ETAG: resp.status_code = 412 diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 341fd9e07..63b649485 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -1881,8 +1881,8 @@ def test_send_each(self): assert batch_response.failure_count == 0 assert len(batch_response.responses) == 2 assert [r.message_id for r in batch_response.responses] == ['message-id1', 'message-id2'] - assert all([r.success for r in batch_response.responses]) - assert not any([r.exception for r in batch_response.responses]) + assert all(r.success for r in batch_response.responses) + assert not any(r.exception for r in batch_response.responses) @respx.mock @pytest.mark.asyncio @@ -1907,8 +1907,8 @@ async def test_send_each_async(self): assert len(batch_response.responses) == 3 assert [r.message_id for r in batch_response.responses] \ == ['message-id1', 'message-id2', 'message-id3'] - assert all([r.success for r in batch_response.responses]) - assert not any([r.exception for r in batch_response.responses]) + assert all(r.success for r in batch_response.responses) + assert not any(r.exception for r in batch_response.responses) assert route.call_count == 3 @@ -1976,8 +1976,8 @@ async def test_send_each_async_error_401_pass_on_auth_retry(self): assert batch_response.failure_count == 0 assert len(batch_response.responses) == 1 assert [r.message_id for r in batch_response.responses] == ['message-id1'] - assert all([r.success for r in batch_response.responses]) - assert not any([r.exception for r in batch_response.responses]) + assert all(r.success for r in batch_response.responses) + assert not any(r.exception for r in batch_response.responses) @respx.mock @pytest.mark.asyncio @@ -2049,11 +2049,12 @@ async def test_send_each_async_error_500_pass_on_retry_config(self): assert batch_response.failure_count == 0 assert len(batch_response.responses) == 1 assert [r.message_id for r in batch_response.responses] == ['message-id1'] - assert all([r.success for r in batch_response.responses]) - assert not any([r.exception for r in batch_response.responses]) + assert all(r.success for r in batch_response.responses) + assert not any(r.exception for r in batch_response.responses) + - @respx.mock @pytest.mark.asyncio + @respx.mock async def test_send_each_async_request_error(self): responses = httpx.ConnectError("Test request error", request=httpx.Request( 'POST', @@ -2192,8 +2193,8 @@ def test_send_each_for_multicast(self): assert batch_response.failure_count == 0 assert len(batch_response.responses) == 2 assert [r.message_id for r in batch_response.responses] == ['message-id1', 'message-id2'] - assert all([r.success for r in batch_response.responses]) - assert not any([r.exception for r in batch_response.responses]) + assert all(r.success for r in batch_response.responses) + assert not any(r.exception for r in batch_response.responses) @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_each_for_multicast_detailed_error(self, status): diff --git a/tests/test_ml.py b/tests/test_ml.py index 18a9e2754..4aebdcab6 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -1094,7 +1094,7 @@ def test_list_single_page(self): assert models_page.next_page_token == '' assert models_page.has_next_page is False assert models_page.get_next_page() is None - models = [model for model in models_page.iterate_all()] + models = list(models_page.iterate_all()) assert len(models) == 1 def test_list_multiple_pages(self): @@ -1140,7 +1140,7 @@ def test_list_models_stop_iteration(self): assert len(recorder) == 1 assert len(page.models) == 3 iterator = page.iterate_all() - models = [model for model in iterator] + models = list(iterator) assert len(page.models) == 3 with pytest.raises(StopIteration): next(iterator) @@ -1151,5 +1151,5 @@ def test_list_models_no_models(self): page = ml.list_models() assert len(recorder) == 1 assert len(page.models) == 0 - models = [model for model in page.iterate_all()] + models = list(page.iterate_all()) assert len(models) == 0 diff --git a/tests/test_remote_config.py b/tests/test_remote_config.py index 8c6248e18..14b54838f 100644 --- a/tests/test_remote_config.py +++ b/tests/test_remote_config.py @@ -830,7 +830,7 @@ def __init__(self, data, status, recorder, etag=ETAG): self._etag = etag def send(self, request, **kwargs): - resp = super(MockAdapter, self).send(request, **kwargs) + resp = super().send(request, **kwargs) resp.headers = {'etag': self._etag} return resp diff --git a/tests/test_sseclient.py b/tests/test_sseclient.py index 70edcf0d0..2c523e36f 100644 --- a/tests/test_sseclient.py +++ b/tests/test_sseclient.py @@ -25,10 +25,10 @@ class MockSSEClientAdapter(testutils.MockAdapter): def __init__(self, payload, recorder): - super(MockSSEClientAdapter, self).__init__(payload, 200, recorder) + super().__init__(payload, 200, recorder) def send(self, request, **kwargs): - resp = super(MockSSEClientAdapter, self).send(request, **kwargs) + resp = super().send(request, **kwargs) resp.url = request.url resp.status_code = self.status resp.raw = io.BytesIO(self.data.encode()) diff --git a/tests/test_tenant_mgt.py b/tests/test_tenant_mgt.py index 018892e3a..156846343 100644 --- a/tests/test_tenant_mgt.py +++ b/tests/test_tenant_mgt.py @@ -450,7 +450,7 @@ def test_list_single_page(self, tenant_mgt_app): assert page.next_page_token == '' assert page.has_next_page is False assert page.get_next_page() is None - tenants = [tenant for tenant in page.iterate_all()] + tenants = list(page.iterate_all()) assert len(tenants) == 2 self._assert_request(recorder) @@ -514,7 +514,7 @@ def test_list_tenants_stop_iteration(self, tenant_mgt_app): _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, LIST_TENANTS_RESPONSE) page = tenant_mgt.list_tenants(app=tenant_mgt_app) iterator = page.iterate_all() - tenants = [tenant for tenant in iterator] + tenants = list(iterator) assert len(tenants) == 2 with pytest.raises(StopIteration): @@ -526,7 +526,7 @@ def test_list_tenants_no_tenants_response(self, tenant_mgt_app): _instrument_tenant_mgt(tenant_mgt_app, 200, json.dumps(response)) page = tenant_mgt.list_tenants(app=tenant_mgt_app) assert len(page.tenants) == 0 - tenants = [tenant for tenant in page.iterate_all()] + tenants = list(page.iterate_all()) assert len(tenants) == 0 def test_list_tenants_with_max_results(self, tenant_mgt_app): diff --git a/tests/testutils.py b/tests/testutils.py index 62f7bd9b5..0505eb6c7 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -183,7 +183,7 @@ def send(self, request, **kwargs): # pylint: disable=arguments-differ class MockAdapter(MockMultiRequestAdapter): """A mock HTTP adapter for the Python requests module.""" def __init__(self, data, status, recorder): - super(MockAdapter, self).__init__([data], [status], recorder) + super().__init__([data], [status], recorder) @property def status(self): From d8e2269d1cbc4e7b80aef9cc677e69ac98bcdd15 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Thu, 19 Jun 2025 14:08:39 -0400 Subject: [PATCH 03/13] change(ml): Drop AutoML model support (#894) --- firebase_admin/ml.py | 54 ++--------------------------------- integration/test_ml.py | 64 +----------------------------------------- tests/test_ml.py | 63 ----------------------------------------- 3 files changed, 3 insertions(+), 178 deletions(-) diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index 8cedc8482..5fffbd836 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -24,7 +24,6 @@ import time import os from urllib import parse -import warnings import requests @@ -33,14 +32,14 @@ from firebase_admin import _utils from firebase_admin import exceptions -# pylint: disable=import-error,no-name-in-module +# pylint: disable=import-error,no-member try: from firebase_admin import storage _GCS_ENABLED = True except ImportError: _GCS_ENABLED = False -# pylint: disable=import-error,no-name-in-module +# pylint: disable=import-error,no-member try: import tensorflow as tf _TF_ENABLED = True @@ -54,9 +53,6 @@ _TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,32}$') _GCS_TFLITE_URI_PATTERN = re.compile( r'^gs://(?P[a-z0-9_.-]{3,63})/(?P.+)$') -_AUTO_ML_MODEL_PATTERN = re.compile( - r'^projects/(?P[a-z0-9-]{6,30})/locations/(?P[^/]+)/' + - r'models/(?P[A-Za-z0-9]+)$') _RESOURCE_NAME_PATTERN = re.compile( r'^projects/(?P[a-z0-9-]{6,30})/models/(?P[A-Za-z0-9_-]{1,60})$') _OPERATION_NAME_PATTERN = re.compile( @@ -388,11 +384,6 @@ def _init_model_source(data): gcs_tflite_uri = data.pop('gcsTfliteUri', None) if gcs_tflite_uri: return TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri) - auto_ml_model = data.pop('automlModel', None) - if auto_ml_model: - warnings.warn('AutoML model support is deprecated and will be removed in the next ' - 'major version.', DeprecationWarning) - return TFLiteAutoMlSource(auto_ml_model=auto_ml_model) return None @property @@ -606,42 +597,6 @@ def as_dict(self, for_upload=False): return {'gcsTfliteUri': self._gcs_tflite_uri} - -class TFLiteAutoMlSource(TFLiteModelSource): - """TFLite model source representing a tflite model created with AutoML. - - AutoML model support is deprecated and will be removed in the next major version. - """ - - def __init__(self, auto_ml_model, app=None): - warnings.warn('AutoML model support is deprecated and will be removed in the next ' - 'major version.', DeprecationWarning) - self._app = app - self.auto_ml_model = auto_ml_model - - def __eq__(self, other): - if isinstance(other, self.__class__): - return self.auto_ml_model == other.auto_ml_model - return False - - def __ne__(self, other): - return not self.__eq__(other) - - @property - def auto_ml_model(self): - """Resource name of the model, created by the AutoML API or Cloud console.""" - return self._auto_ml_model - - @auto_ml_model.setter - def auto_ml_model(self, auto_ml_model): - self._auto_ml_model = _validate_auto_ml_model(auto_ml_model) - - def as_dict(self, for_upload=False): - """Returns a serializable representation of the object.""" - # Upload is irrelevant for auto_ml models - return {'automlModel': self._auto_ml_model} - - class ListModelsPage: """Represents a page of models in a Firebase project. @@ -786,11 +741,6 @@ def _validate_gcs_tflite_uri(uri): raise ValueError('GCS TFLite URI format is invalid.') return uri -def _validate_auto_ml_model(model): - if not _AUTO_ML_MODEL_PATTERN.match(model): - raise ValueError('Model resource name format is invalid.') - return model - def _validate_model_format(model_format): if not isinstance(model_format, ModelFormat): diff --git a/integration/test_ml.py b/integration/test_ml.py index f8dd6bb47..6deb22a69 100644 --- a/integration/test_ml.py +++ b/integration/test_ml.py @@ -22,25 +22,18 @@ import pytest -import firebase_admin from firebase_admin import exceptions from firebase_admin import ml from tests import testutils -# pylint: disable=import-error,no-name-in-module +# pylint: disable=import-error, no-member try: import tensorflow as tf _TF_ENABLED = True except ImportError: _TF_ENABLED = False -try: - from google.cloud import automl_v1 - _AUTOML_ENABLED = True -except ImportError: - _AUTOML_ENABLED = False - def _random_identifier(prefix): #pylint: disable=unused-variable suffix = ''.join([random.choice(string.ascii_letters + string.digits) for n in range(8)]) @@ -159,14 +152,6 @@ def check_tflite_gcs_format(model, validation_error=None): assert model.model_hash is not None -def check_tflite_automl_format(model): - assert model.validation_error is None - assert model.published is False - assert model.model_format.model_source.auto_ml_model.startswith('projects/') - # Automl models don't have validation errors since they are references - # to valid automl models. - - @pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True) def test_create_simple_model(firebase_model): check_model(firebase_model, NAME_AND_TAGS_ARGS) @@ -392,50 +377,3 @@ def test_from_saved_model(saved_model_dir): assert created_model.validation_error is None finally: _clean_up_model(created_model) - - -# Test AutoML functionality if AutoML is enabled. -#'pip install google-cloud-automl' in the environment if you want _AUTOML_ENABLED = True -# You will also need a predefined AutoML model named 'admin_sdk_integ_test1' to run the -# successful test. (Test is skipped otherwise) - -@pytest.fixture -def automl_model(): - assert _AUTOML_ENABLED - - # It takes > 20 minutes to train a model, so we expect a predefined AutoMl - # model named 'admin_sdk_integ_test1' to exist in the project, or we skip - # the test. - automl_client = automl_v1.AutoMlClient() - project_id = firebase_admin.get_app().project_id - parent = automl_client.location_path(project_id, 'us-central1') - models = automl_client.list_models(parent, filter_="display_name=admin_sdk_integ_test1") - # Expecting exactly one. (Ok to use last one if somehow more than 1) - automl_ref = None - for model in models: - automl_ref = model.name - - # Skip if no pre-defined model. (It takes min > 20 minutes to train a model) - if automl_ref is None: - pytest.skip("No pre-existing AutoML model found. Skipping test") - - source = ml.TFLiteAutoMlSource(automl_ref) - tflite_format = ml.TFLiteFormat(model_source=source) - ml_model = ml.Model( - display_name=_random_identifier('TestModel_automl_'), - tags=['test_automl'], - model_format=tflite_format) - model = ml.create_model(model=ml_model) - yield model - _clean_up_model(model) - -@pytest.mark.skipif(not _AUTOML_ENABLED, reason='AutoML is required for this test.') -def test_automl_model(automl_model): - # This test looks for a predefined automl model with display_name = 'admin_sdk_integ_test1' - automl_model.wait_for_unlocked() - - check_model(automl_model, { - 'display_name': automl_model.display_name, - 'tags': ['test_automl'], - }) - check_tflite_automl_format(automl_model) diff --git a/tests/test_ml.py b/tests/test_ml.py index 4aebdcab6..2af9ae42f 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -121,18 +121,6 @@ } TFLITE_FORMAT_2 = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2) -AUTOML_MODEL_NAME = 'projects/111111111111/locations/us-central1/models/ICN7683346839371803263' -AUTOML_MODEL_SOURCE = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME) -TFLITE_FORMAT_JSON_3 = { - 'automlModel': AUTOML_MODEL_NAME, - 'sizeBytes': '3456789' -} -TFLITE_FORMAT_3 = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_3) - -AUTOML_MODEL_NAME_2 = 'projects/2222222222/locations/us-central1/models/ICN2222222222222222222' -AUTOML_MODEL_NAME_JSON_2 = {'automlModel': AUTOML_MODEL_NAME_2} -AUTOML_MODEL_SOURCE_2 = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME_2) - CREATED_UPDATED_MODEL_JSON_1 = { 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_1, @@ -423,14 +411,6 @@ def test_model_keyword_based_creation_and_setters(self): 'tfliteModel': TFLITE_FORMAT_JSON_2 } - model.model_format = TFLITE_FORMAT_3 - assert model.as_dict() == { - 'displayName': DISPLAY_NAME_2, - 'tags': TAGS_2, - 'tfliteModel': TFLITE_FORMAT_JSON_3 - } - - def test_gcs_tflite_model_format_source_creation(self): model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) model_format = ml.TFLiteFormat(model_source=model_source) @@ -442,17 +422,6 @@ def test_gcs_tflite_model_format_source_creation(self): } } - def test_auto_ml_tflite_model_format_source_creation(self): - model_source = ml.TFLiteAutoMlSource(auto_ml_model=AUTOML_MODEL_NAME) - model_format = ml.TFLiteFormat(model_source=model_source) - model = ml.Model(display_name=DISPLAY_NAME_1, model_format=model_format) - assert model.as_dict() == { - 'displayName': DISPLAY_NAME_1, - 'tfliteModel': { - 'automlModel': AUTOML_MODEL_NAME - } - } - def test_source_creation_from_tflite_file(self): model_source = ml.TFLiteGCSModelSource.from_tflite_model_file( "my_model.tflite", "my_bucket") @@ -466,13 +435,6 @@ def test_gcs_tflite_model_source_setters(self): assert model_source.gcs_tflite_uri == GCS_TFLITE_URI_2 assert model_source.as_dict() == GCS_TFLITE_URI_JSON_2 - def test_auto_ml_tflite_model_source_setters(self): - model_source = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME) - model_source.auto_ml_model = AUTOML_MODEL_NAME_2 - assert model_source.auto_ml_model == AUTOML_MODEL_NAME_2 - assert model_source.as_dict() == AUTOML_MODEL_NAME_JSON_2 - - def test_model_format_setters(self): model_format = ml.TFLiteFormat(model_source=GCS_TFLITE_MODEL_SOURCE) model_format.model_source = GCS_TFLITE_MODEL_SOURCE_2 @@ -483,14 +445,6 @@ def test_model_format_setters(self): } } - model_format.model_source = AUTOML_MODEL_SOURCE - assert model_format.model_source == AUTOML_MODEL_SOURCE - assert model_format.as_dict() == { - 'tfliteModel': { - 'automlModel': AUTOML_MODEL_NAME - } - } - def test_model_as_dict_for_upload(self): model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) model_format = ml.TFLiteFormat(model_source=model_source) @@ -576,23 +530,6 @@ def test_gcs_tflite_source_validation_errors(self, uri, exc_type): ml.TFLiteGCSModelSource(gcs_tflite_uri=uri) check_error(excinfo, exc_type) - @pytest.mark.parametrize('auto_ml_model, exc_type', [ - (123, TypeError), - ('abc', ValueError), - ('/projects/123456/locations/us-central1/models/noLeadingSlash', ValueError), - ('projects/123546/models/ICN123456', ValueError), - ('projects//locations/us-central1/models/ICN123456', ValueError), - ('projects/123456/locations//models/ICN123456', ValueError), - ('projects/123456/locations/us-central1/models/', ValueError), - ('projects/ABC/locations/us-central1/models/ICN123456', ValueError), - ('projects/123456/locations/us-central1/models/@#$%^&', ValueError), - ('projects/123456/locations/us-cent/ral1/models/ICN123456', ValueError), - ]) - def test_auto_ml_tflite_source_validation_errors(self, auto_ml_model, exc_type): - with pytest.raises(exc_type) as excinfo: - ml.TFLiteAutoMlSource(auto_ml_model=auto_ml_model) - check_error(excinfo, exc_type) - def test_wait_for_unlocked_not_locked(self): model = ml.Model(display_name="not_locked") model.wait_for_unlocked() From c9dce80a787eb4e16f478c3f52f2c83dde090e9b Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Thu, 26 Jun 2025 09:54:29 -0400 Subject: [PATCH 04/13] chore: Bump `pylint` to v3.3.7 and `astroid` to v3.3.10 (#895) * chore: Bump pylint to v3 * chore: fix src lint * chore: fix unit test lint * chore: fix integration test lint * chore: fix snippets lint * chore: 2nd pass for errors * fix: corrected use of the `bad-functions` config * fix: add EoF newline --- .pylintrc | 105 +++++++----------- firebase_admin/__init__.py | 47 ++++---- firebase_admin/_auth_client.py | 15 +-- firebase_admin/_auth_providers.py | 53 +++++---- firebase_admin/_auth_utils.py | 98 ++++++++--------- firebase_admin/_messaging_encoder.py | 52 ++++----- firebase_admin/_rfc3339.py | 2 +- firebase_admin/_sseclient.py | 9 +- firebase_admin/_token_gen.py | 113 ++++++++++---------- firebase_admin/_user_mgt.py | 57 +++++----- firebase_admin/_utils.py | 21 ++-- firebase_admin/app_check.py | 4 +- firebase_admin/credentials.py | 26 ++--- firebase_admin/db.py | 59 +++++----- firebase_admin/functions.py | 6 +- firebase_admin/instance_id.py | 6 +- firebase_admin/messaging.py | 32 +++--- firebase_admin/ml.py | 20 ++-- firebase_admin/project_management.py | 43 ++++---- firebase_admin/remote_config.py | 10 +- firebase_admin/storage.py | 4 +- firebase_admin/tenant_mgt.py | 31 +++--- integration/conftest.py | 8 +- integration/test_auth.py | 17 +-- integration/test_db.py | 14 +-- integration/test_firestore.py | 14 +-- integration/test_firestore_async.py | 10 +- integration/test_messaging.py | 4 +- integration/test_ml.py | 13 ++- integration/test_project_management.py | 11 +- integration/test_storage.py | 6 +- integration/test_tenant_mgt.py | 9 +- requirements.txt | 4 +- snippets/auth/get_service_account_tokens.py | 2 +- snippets/auth/index.py | 29 ++--- snippets/database/index.py | 12 +-- snippets/messaging/cloud_messaging.py | 6 +- tests/test_app.py | 14 +-- tests/test_app_check.py | 4 +- tests/test_auth_providers.py | 21 ++-- tests/test_credentials.py | 4 +- tests/test_db.py | 65 +++++------ tests/test_instance_id.py | 6 +- tests/test_messaging.py | 40 +++---- tests/test_ml.py | 64 ++++++----- tests/test_project_management.py | 6 +- tests/test_remote_config.py | 2 +- tests/test_storage.py | 2 +- tests/test_tenant_mgt.py | 75 +++++++------ tests/test_token_gen.py | 20 ++-- tests/test_user_mgt.py | 39 ++++--- tests/testutils.py | 2 +- 52 files changed, 648 insertions(+), 688 deletions(-) diff --git a/.pylintrc b/.pylintrc index 2155853c7..ea54e481c 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,4 +1,4 @@ -[MASTER] +[MAIN] # Specify a configuration file. #rcfile= @@ -20,7 +20,9 @@ persistent=no # List of plugins (as comma separated values of python modules names) to load, # usually to register additional checkers. -load-plugins=pylint.extensions.docparams,pylint.extensions.docstyle +load-plugins=pylint.extensions.docparams, + pylint.extensions.docstyle, + pylint.extensions.bad_builtin, # Use multiple processes to speed up Pylint. jobs=1 @@ -34,15 +36,6 @@ unsafe-load-any-extension=no # run arbitrary code extension-pkg-whitelist= -# Allow optimization of some AST trees. This will activate a peephole AST -# optimizer, which will apply various small optimizations. For instance, it can -# be used to obtain the result of joining multiple strings with the addition -# operator. Joining a lot of strings can lead to a maximum recursion error in -# Pylint and this flag can prevent that. It has one side effect, the resulting -# AST will be different than the one from reality. This option is deprecated -# and it will be removed in Pylint 2.0. -optimize-ast=no - [MESSAGES CONTROL] @@ -65,21 +58,31 @@ enable=indexing-exception,old-raise-syntax # --enable=similarities". If you want to run only the classes checker, but have # no Warning level messages displayed, use"--disable=all --enable=classes # --disable=W" -disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,file-ignored,missing-type-doc +disable=design, + similarities, + no-self-use, + attribute-defined-outside-init, + locally-disabled, + star-args, + pointless-except, + bad-option-value, + lobal-statement, + fixme, + suppressed-message, + useless-suppression, + locally-enabled, + file-ignored, + missing-type-doc, + c-extension-no-member, [REPORTS] -# Set the output format. Available formats are text, parseable, colorized, msvs -# (visual studio) and html. You can also give a reporter class, eg -# mypackage.mymodule.MyReporterClass. -output-format=text - -# Put messages in a separate file for each module / package specified on the -# command line instead of printing them on stdout. Reports (if any) will be -# written in a file name "pylint_global.[txt|html]". This option is deprecated -# and it will be removed in Pylint 2.0. -files-output=no +# Set the output format. Available formats are: 'text', 'parseable', +# 'colorized', 'json2' (improved json format), 'json' (old json format), msvs +# (visual studio) and 'github' (GitHub actions). You can also give a reporter +# class, e.g. mypackage.mymodule.MyReporterClass. +output-format=colorized # Tells whether to display a full report or only the messages reports=no @@ -176,9 +179,12 @@ logging-modules=logging good-names=main,_ # Bad variable names which should always be refused, separated by a comma -bad-names= - -bad-functions=input,apply,reduce +bad-names=foo, + bar, + baz, + toto, + tutu, + tata # Colon-delimited sets of names that determine each other's naming style when # the name regexes allow several styles. @@ -194,64 +200,33 @@ property-classes=abc.abstractproperty # Regular expression matching correct function names function-rgx=[a-z_][a-z0-9_]*$ -# Naming hint for function names -function-name-hint=[a-z_][a-z0-9_]*$ - # Regular expression matching correct variable names variable-rgx=[a-z_][a-z0-9_]{2,30}$ -# Naming hint for variable names -variable-name-hint=[a-z_][a-z0-9_]{2,30}$ - # Regular expression matching correct constant names const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ - -# Naming hint for constant names -const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$ - # Regular expression matching correct attribute names attr-rgx=[a-z_][a-z0-9_]{2,30}$ -# Naming hint for attribute names -attr-name-hint=[a-z_][a-z0-9_]{2,30}$ - # Regular expression matching correct argument names argument-rgx=[a-z_][a-z0-9_]{2,30}$ -# Naming hint for argument names -argument-name-hint=[a-z_][a-z0-9_]{2,30}$ - # Regular expression matching correct class attribute names class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ -# Naming hint for class attribute names -class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ - # Regular expression matching correct inline iteration names inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ -# Naming hint for inline iteration names -inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$ - # Regular expression matching correct class names class-rgx=[A-Z_][a-zA-Z0-9]+$ -# Naming hint for class names -class-name-hint=[A-Z_][a-zA-Z0-9]+$ - # Regular expression matching correct module names module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ -# Naming hint for module names -module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ - # Regular expression matching correct method names method-rgx=[a-z_][a-z0-9_]*$ -# Naming hint for method names -method-name-hint=[a-z_][a-z0-9_]*$ - # Regular expression which should only match function or class names that do # not require a docstring. no-docstring-rgx=(__.*__|main) @@ -294,12 +269,6 @@ ignore-long-lines=^\s*(# )??$ # else. single-line-if-stmt=no -# List of optional constructs for which whitespace checking is disabled. `dict- -# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. -# `trailing-comma` allows a space between comma and closing bracket: (a, ). -# `empty-line` allows space-only lines. -no-space-check=trailing-comma,dict-separator - # Maximum number of lines in a module max-module-lines=1000 @@ -405,6 +374,12 @@ exclude-protected=_asdict,_fields,_replace,_source,_make [EXCEPTIONS] -# Exceptions that will emit a warning when being caught. Defaults to -# "Exception" -overgeneral-exceptions=Exception +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=builtins.BaseException,builtins.Exception + +[DEPRECATED_BUILTINS] + +# List of builtins function names that should not be used, separated by a comma +bad-functions=input, + apply, + reduce diff --git a/firebase_admin/__init__.py b/firebase_admin/__init__.py index 597aaa6b6..8c9f628e5 100644 --- a/firebase_admin/__init__.py +++ b/firebase_admin/__init__.py @@ -79,11 +79,11 @@ def initialize_app(credential=None, options=None, name=_DEFAULT_APP_NAME): 'apps, pass a second argument to initialize_app() to give each app ' 'a unique name.')) - raise ValueError(( - 'Firebase app named "{0}" already exists. This means you called ' + raise ValueError( + f'Firebase app named "{name}" already exists. This means you called ' 'initialize_app() more than once with the same app name as the ' 'second argument. Make sure you provide a unique name every time ' - 'you call initialize_app().').format(name)) + 'you call initialize_app().') def delete_app(app): @@ -96,8 +96,7 @@ def delete_app(app): ValueError: If the app is not initialized. """ if not isinstance(app, App): - raise ValueError('Illegal app argument type: "{}". Argument must be of ' - 'type App.'.format(type(app))) + raise ValueError(f'Illegal app argument type: "{type(app)}". Argument must be of type App.') with _apps_lock: if _apps.get(app.name) is app: del _apps[app.name] @@ -109,9 +108,9 @@ def delete_app(app): 'the default app by calling initialize_app().') raise ValueError( - ('Firebase app named "{0}" is not initialized. Make sure to initialize ' - 'the app by calling initialize_app() with your app name as the ' - 'second argument.').format(app.name)) + f'Firebase app named "{app.name}" is not initialized. Make sure to initialize ' + 'the app by calling initialize_app() with your app name as the ' + 'second argument.') def get_app(name=_DEFAULT_APP_NAME): @@ -128,8 +127,8 @@ def get_app(name=_DEFAULT_APP_NAME): app does not exist. """ if not isinstance(name, str): - raise ValueError('Illegal app name argument type: "{}". App name ' - 'must be a string.'.format(type(name))) + raise ValueError( + f'Illegal app name argument type: "{type(name)}". App name must be a string.') with _apps_lock: if name in _apps: return _apps[name] @@ -140,9 +139,9 @@ def get_app(name=_DEFAULT_APP_NAME): 'the SDK by calling initialize_app().') raise ValueError( - ('Firebase app named "{0}" does not exist. Make sure to initialize ' - 'the SDK by calling initialize_app() with your app name as the ' - 'second argument.').format(name)) + f'Firebase app named "{name}" does not exist. Make sure to initialize ' + 'the SDK by calling initialize_app() with your app name as the ' + 'second argument.') class _AppOptions: @@ -153,8 +152,9 @@ def __init__(self, options): options = self._load_from_environment() if not isinstance(options, dict): - raise ValueError('Illegal Firebase app options type: {0}. Options ' - 'must be a dictionary.'.format(type(options))) + raise ValueError( + f'Illegal Firebase app options type: {type(options)}. ' + 'Options must be a dictionary.') self._options = options def get(self, key, default=None): @@ -175,15 +175,15 @@ def _load_from_environment(self): json_str = config_file else: try: - with open(config_file, 'r') as json_file: + with open(config_file, 'r', encoding='utf-8') as json_file: json_str = json_file.read() except Exception as err: - raise ValueError('Unable to read file {}. {}'.format(config_file, err)) from err + raise ValueError(f'Unable to read file {config_file}. {err}') from err try: json_data = json.loads(json_str) except Exception as err: raise ValueError( - 'JSON string "{0}" is not valid json. {1}'.format(json_str, err)) from err + f'JSON string "{json_str}" is not valid json. {err}') from err return {k: v for k, v in json_data.items() if k in _CONFIG_VALID_KEYS} @@ -206,8 +206,9 @@ def __init__(self, name, credential, options): ValueError: If an argument is None or invalid. """ if not name or not isinstance(name, str): - raise ValueError('Illegal Firebase app name "{0}" provided. App name must be a ' - 'non-empty string.'.format(name)) + raise ValueError( + f'Illegal Firebase app name "{name}" provided. App name must be a ' + 'non-empty string.') self._name = name if isinstance(credential, GoogleAuthCredentials): @@ -228,7 +229,7 @@ def __init__(self, name, credential, options): def _validate_project_id(cls, project_id): if project_id is not None and not isinstance(project_id, str): raise ValueError( - 'Invalid project ID: "{0}". project ID must be a string.'.format(project_id)) + f'Invalid project ID: "{project_id}". project ID must be a string.') @property def name(self): @@ -293,11 +294,11 @@ def _get_service(self, name, initializer): """ if not name or not isinstance(name, str): raise ValueError( - 'Illegal name argument: "{0}". Name must be a non-empty string.'.format(name)) + f'Illegal name argument: "{name}". Name must be a non-empty string.') with self._lock: if self._services is None: raise ValueError( - 'Service requested from deleted Firebase App: "{0}".'.format(self._name)) + f'Service requested from deleted Firebase App: "{self._name}".') if name not in self._services: self._services[name] = initializer(self) return self._services[name] diff --git a/firebase_admin/_auth_client.py b/firebase_admin/_auth_client.py index 38b42993a..74261fa37 100644 --- a/firebase_admin/_auth_client.py +++ b/firebase_admin/_auth_client.py @@ -38,7 +38,7 @@ def __init__(self, app, tenant_id=None): 3. set the project ID via the GOOGLE_CLOUD_PROJECT environment variable.""") credential = None - version_header = 'Python/Admin/{0}'.format(firebase_admin.__version__) + version_header = f'Python/Admin/{firebase_admin.__version__}' timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) # Non-default endpoint URLs for emulator support are set in this dict later. endpoint_urls = {} @@ -48,7 +48,7 @@ def __init__(self, app, tenant_id=None): # endpoint URLs to use the emulator. Additionally, use a fake credential. emulator_host = _auth_utils.get_emulator_host() if emulator_host: - base_url = 'http://{0}/identitytoolkit.googleapis.com'.format(emulator_host) + base_url = f'http://{emulator_host}/identitytoolkit.googleapis.com' endpoint_urls['v1'] = base_url + '/v1' endpoint_urls['v2'] = base_url + '/v2' credential = _utils.EmulatorAdminCredentials() @@ -123,15 +123,16 @@ def verify_id_token(self, id_token, check_revoked=False, clock_skew_seconds=0): """ if not isinstance(check_revoked, bool): # guard against accidental wrong assignment. - raise ValueError('Illegal check_revoked argument. Argument must be of type ' - ' bool, but given "{0}".'.format(type(check_revoked))) + raise ValueError( + 'Illegal check_revoked argument. Argument must be of type bool, but given ' + f'"{type(check_revoked)}".') verified_claims = self._token_verifier.verify_id_token(id_token, clock_skew_seconds) if self.tenant_id: token_tenant_id = verified_claims.get('firebase', {}).get('tenant') if self.tenant_id != token_tenant_id: raise _auth_utils.TenantIdMismatchError( - 'Invalid tenant ID: {0}'.format(token_tenant_id)) + f'Invalid tenant ID: {token_tenant_id}') if check_revoked: self._check_jwt_revoked_or_disabled( @@ -249,7 +250,7 @@ def _matches(identifier, user_record): if identifier.provider_id == user_info.provider_id and identifier.provider_uid == user_info.uid ), False) - raise TypeError("Unexpected type: {}".format(type(identifier))) + raise TypeError(f"Unexpected type: {type(identifier)}") def _is_user_found(identifier, user_records): return any(_matches(identifier, user_record) for user_record in user_records) @@ -757,4 +758,4 @@ def _check_jwt_revoked_or_disabled(self, verified_claims, exc_type, label): if user.disabled: raise _auth_utils.UserDisabledError('The user record is disabled.') if verified_claims.get('iat') * 1000 < user.tokens_valid_after_timestamp: - raise exc_type('The Firebase {0} has been revoked.'.format(label)) + raise exc_type(f'The Firebase {label} has been revoked.') diff --git a/firebase_admin/_auth_providers.py b/firebase_admin/_auth_providers.py index 6512a4f7b..cc7949526 100644 --- a/firebase_admin/_auth_providers.py +++ b/firebase_admin/_auth_providers.py @@ -181,13 +181,13 @@ class ProviderConfigClient: def __init__(self, http_client, project_id, tenant_id=None, url_override=None): self.http_client = http_client url_prefix = url_override or self.PROVIDER_CONFIG_URL - self.base_url = '{0}/projects/{1}'.format(url_prefix, project_id) + self.base_url = f'{url_prefix}/projects/{project_id}' if tenant_id: - self.base_url += '/tenants/{0}'.format(tenant_id) + self.base_url += f'/tenants/{tenant_id}' def get_oidc_provider_config(self, provider_id): _validate_oidc_provider_id(provider_id) - body = self._make_request('get', '/oauthIdpConfigs/{0}'.format(provider_id)) + body = self._make_request('get', f'/oauthIdpConfigs/{provider_id}') return OIDCProviderConfig(body) def create_oidc_provider_config( @@ -218,7 +218,7 @@ def create_oidc_provider_config( if response_type: req['responseType'] = response_type - params = 'oauthIdpConfigId={0}'.format(provider_id) + params = f'oauthIdpConfigId={provider_id}' body = self._make_request('post', '/oauthIdpConfigs', json=req, params=params) return OIDCProviderConfig(body) @@ -259,14 +259,14 @@ def update_oidc_provider_config( raise ValueError('At least one parameter must be specified for update.') update_mask = _auth_utils.build_update_mask(req) - params = 'updateMask={0}'.format(','.join(update_mask)) - url = '/oauthIdpConfigs/{0}'.format(provider_id) + params = f'updateMask={",".join(update_mask)}' + url = f'/oauthIdpConfigs/{provider_id}' body = self._make_request('patch', url, json=req, params=params) return OIDCProviderConfig(body) def delete_oidc_provider_config(self, provider_id): _validate_oidc_provider_id(provider_id) - self._make_request('delete', '/oauthIdpConfigs/{0}'.format(provider_id)) + self._make_request('delete', f'/oauthIdpConfigs/{provider_id}') def list_oidc_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): return _ListOIDCProviderConfigsPage( @@ -277,7 +277,7 @@ def _fetch_oidc_provider_configs(self, page_token=None, max_results=MAX_LIST_CON def get_saml_provider_config(self, provider_id): _validate_saml_provider_id(provider_id) - body = self._make_request('get', '/inboundSamlConfigs/{0}'.format(provider_id)) + body = self._make_request('get', f'/inboundSamlConfigs/{provider_id}') return SAMLProviderConfig(body) def create_saml_provider_config( @@ -301,7 +301,7 @@ def create_saml_provider_config( if enabled is not None: req['enabled'] = _auth_utils.validate_boolean(enabled, 'enabled') - params = 'inboundSamlConfigId={0}'.format(provider_id) + params = f'inboundSamlConfigId={provider_id}' body = self._make_request('post', '/inboundSamlConfigs', json=req, params=params) return SAMLProviderConfig(body) @@ -341,14 +341,14 @@ def update_saml_provider_config( raise ValueError('At least one parameter must be specified for update.') update_mask = _auth_utils.build_update_mask(req) - params = 'updateMask={0}'.format(','.join(update_mask)) - url = '/inboundSamlConfigs/{0}'.format(provider_id) + params = f'updateMask={",".join(update_mask)}' + url = f'/inboundSamlConfigs/{provider_id}' body = self._make_request('patch', url, json=req, params=params) return SAMLProviderConfig(body) def delete_saml_provider_config(self, provider_id): _validate_saml_provider_id(provider_id) - self._make_request('delete', '/inboundSamlConfigs/{0}'.format(provider_id)) + self._make_request('delete', f'/inboundSamlConfigs/{provider_id}') def list_saml_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): return _ListSAMLProviderConfigsPage( @@ -367,15 +367,15 @@ def _fetch_provider_configs(self, path, page_token=None, max_results=MAX_LIST_CO if max_results < 1 or max_results > MAX_LIST_CONFIGS_RESULTS: raise ValueError( 'Max results must be a positive integer less than or equal to ' - '{0}.'.format(MAX_LIST_CONFIGS_RESULTS)) + f'{MAX_LIST_CONFIGS_RESULTS}.') - params = 'pageSize={0}'.format(max_results) + params = f'pageSize={max_results}' if page_token: - params += '&pageToken={0}'.format(page_token) + params += f'&pageToken={page_token}' return self._make_request('get', path, params=params) def _make_request(self, method, path, **kwargs): - url = '{0}{1}'.format(self.base_url, path) + url = f'{self.base_url}{path}' try: return self.http_client.body(method, url, **kwargs) except requests.exceptions.RequestException as error: @@ -385,29 +385,27 @@ def _make_request(self, method, path, **kwargs): def _validate_oidc_provider_id(provider_id): if not isinstance(provider_id, str): raise ValueError( - 'Invalid OIDC provider ID: {0}. Provider ID must be a non-empty string.'.format( - provider_id)) + f'Invalid OIDC provider ID: {provider_id}. Provider ID must be a non-empty string.') if not provider_id.startswith('oidc.'): - raise ValueError('Invalid OIDC provider ID: {0}.'.format(provider_id)) + raise ValueError(f'Invalid OIDC provider ID: {provider_id}.') return provider_id def _validate_saml_provider_id(provider_id): if not isinstance(provider_id, str): raise ValueError( - 'Invalid SAML provider ID: {0}. Provider ID must be a non-empty string.'.format( - provider_id)) + f'Invalid SAML provider ID: {provider_id}. Provider ID must be a non-empty string.') if not provider_id.startswith('saml.'): - raise ValueError('Invalid SAML provider ID: {0}.'.format(provider_id)) + raise ValueError(f'Invalid SAML provider ID: {provider_id}.') return provider_id def _validate_non_empty_string(value, label): """Validates that the given value is a non-empty string.""" if not isinstance(value, str): - raise ValueError('Invalid type for {0}: {1}.'.format(label, value)) + raise ValueError(f'Invalid type for {label}: {value}.') if not value: - raise ValueError('{0} must not be empty.'.format(label)) + raise ValueError(f'{label} must not be empty.') return value @@ -415,15 +413,14 @@ def _validate_url(url, label): """Validates that the given value is a well-formed URL string.""" if not isinstance(url, str) or not url: raise ValueError( - 'Invalid photo URL: "{0}". {1} must be a non-empty ' - 'string.'.format(url, label)) + f'Invalid photo URL: "{url}". {label} must be a non-empty string.') try: parsed = parse.urlparse(url) if not parsed.netloc: - raise ValueError('Malformed {0}: "{1}".'.format(label, url)) + raise ValueError(f'Malformed {label}: "{url}".') return url except Exception as exception: - raise ValueError('Malformed {0}: "{1}".'.format(label, url)) from exception + raise ValueError(f'Malformed {label}: "{url}".') from exception def _validate_x509_certificates(x509_certificates): diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index 0d56ca7fa..60d411822 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -74,8 +74,8 @@ def get_emulator_host(): emulator_host = os.getenv(EMULATOR_HOST_ENV_VAR, '') if emulator_host and '//' in emulator_host: raise ValueError( - 'Invalid {0}: "{1}". It must follow format "host:port".'.format( - EMULATOR_HOST_ENV_VAR, emulator_host)) + f'Invalid {EMULATOR_HOST_ENV_VAR}: "{emulator_host}". ' + 'It must follow format "host:port".') return emulator_host @@ -88,8 +88,8 @@ def validate_uid(uid, required=False): return None if not isinstance(uid, str) or not uid or len(uid) > 128: raise ValueError( - 'Invalid uid: "{0}". The uid must be a non-empty string with no more than 128 ' - 'characters.'.format(uid)) + f'Invalid uid: "{uid}". The uid must be a non-empty string with no more than 128 ' + 'characters.') return uid def validate_email(email, required=False): @@ -97,10 +97,10 @@ def validate_email(email, required=False): return None if not isinstance(email, str) or not email: raise ValueError( - 'Invalid email: "{0}". Email must be a non-empty string.'.format(email)) + f'Invalid email: "{email}". Email must be a non-empty string.') parts = email.split('@') if len(parts) != 2 or not parts[0] or not parts[1]: - raise ValueError('Malformed email address string: "{0}".'.format(email)) + raise ValueError(f'Malformed email address string: "{email}".') return email def validate_phone(phone, required=False): @@ -113,11 +113,12 @@ def validate_phone(phone, required=False): if phone is None and not required: return None if not isinstance(phone, str) or not phone: - raise ValueError('Invalid phone number: "{0}". Phone number must be a non-empty ' - 'string.'.format(phone)) + raise ValueError( + f'Invalid phone number: "{phone}". Phone number must be a non-empty string.') if not phone.startswith('+') or not re.search('[a-zA-Z0-9]', phone): - raise ValueError('Invalid phone number: "{0}". Phone number must be a valid, E.164 ' - 'compliant identifier.'.format(phone)) + raise ValueError( + f'Invalid phone number: "{phone}". Phone number must be a valid, E.164 ' + 'compliant identifier.') return phone def validate_password(password, required=False): @@ -132,7 +133,7 @@ def validate_bytes(value, label, required=False): if value is None and not required: return None if not isinstance(value, bytes) or not value: - raise ValueError('{0} must be a non-empty byte sequence.'.format(label)) + raise ValueError(f'{label} must be a non-empty byte sequence.') return value def validate_display_name(display_name, required=False): @@ -140,8 +141,8 @@ def validate_display_name(display_name, required=False): return None if not isinstance(display_name, str) or not display_name: raise ValueError( - 'Invalid display name: "{0}". Display name must be a non-empty ' - 'string.'.format(display_name)) + f'Invalid display name: "{display_name}". Display name must be a non-empty ' + 'string.') return display_name def validate_provider_id(provider_id, required=True): @@ -149,8 +150,7 @@ def validate_provider_id(provider_id, required=True): return None if not isinstance(provider_id, str) or not provider_id: raise ValueError( - 'Invalid provider ID: "{0}". Provider ID must be a non-empty ' - 'string.'.format(provider_id)) + f'Invalid provider ID: "{provider_id}". Provider ID must be a non-empty string.') return provider_id def validate_provider_uid(provider_uid, required=True): @@ -158,8 +158,7 @@ def validate_provider_uid(provider_uid, required=True): return None if not isinstance(provider_uid, str) or not provider_uid: raise ValueError( - 'Invalid provider UID: "{0}". Provider UID must be a non-empty ' - 'string.'.format(provider_uid)) + f'Invalid provider UID: "{provider_uid}". Provider UID must be a non-empty string.') return provider_uid def validate_photo_url(photo_url, required=False): @@ -168,15 +167,14 @@ def validate_photo_url(photo_url, required=False): return None if not isinstance(photo_url, str) or not photo_url: raise ValueError( - 'Invalid photo URL: "{0}". Photo URL must be a non-empty ' - 'string.'.format(photo_url)) + f'Invalid photo URL: "{photo_url}". Photo URL must be a non-empty string.') try: parsed = parse.urlparse(photo_url) if not parsed.netloc: - raise ValueError('Malformed photo URL: "{0}".'.format(photo_url)) + raise ValueError(f'Malformed photo URL: "{photo_url}".') return photo_url except Exception as err: - raise ValueError('Malformed photo URL: "{0}".'.format(photo_url)) from err + raise ValueError(f'Malformed photo URL: "{photo_url}".') from err def validate_timestamp(timestamp, label, required=False): """Validates the given timestamp value. Timestamps must be positive integers.""" @@ -187,13 +185,12 @@ def validate_timestamp(timestamp, label, required=False): try: timestamp_int = int(timestamp) except TypeError as err: - raise ValueError('Invalid type for timestamp value: {0}.'.format(timestamp)) from err - else: - if timestamp_int != timestamp: - raise ValueError('{0} must be a numeric value and a whole number.'.format(label)) - if timestamp_int <= 0: - raise ValueError('{0} timestamp must be a positive interger.'.format(label)) - return timestamp_int + raise ValueError(f'Invalid type for timestamp value: {timestamp}.') from err + if timestamp_int != timestamp: + raise ValueError(f'{label} must be a numeric value and a whole number.') + if timestamp_int <= 0: + raise ValueError(f'{label} timestamp must be a positive interger.') + return timestamp_int def validate_int(value, label, low=None, high=None): """Validates that the given value represents an integer. @@ -204,31 +201,30 @@ def validate_int(value, label, low=None, high=None): a developer error. """ if value is None or isinstance(value, bool): - raise ValueError('Invalid type for integer value: {0}.'.format(value)) + raise ValueError(f'Invalid type for integer value: {value}.') try: val_int = int(value) except TypeError as err: - raise ValueError('Invalid type for integer value: {0}.'.format(value)) from err - else: - if val_int != value: - # This will be True for non-numeric values like '2' and non-whole numbers like 2.5. - raise ValueError('{0} must be a numeric value and a whole number.'.format(label)) - if low is not None and val_int < low: - raise ValueError('{0} must not be smaller than {1}.'.format(label, low)) - if high is not None and val_int > high: - raise ValueError('{0} must not be larger than {1}.'.format(label, high)) - return val_int + raise ValueError(f'Invalid type for integer value: {value}.') from err + if val_int != value: + # This will be True for non-numeric values like '2' and non-whole numbers like 2.5. + raise ValueError(f'{label} must be a numeric value and a whole number.') + if low is not None and val_int < low: + raise ValueError(f'{label} must not be smaller than {low}.') + if high is not None and val_int > high: + raise ValueError(f'{label} must not be larger than {high}.') + return val_int def validate_string(value, label): """Validates that the given value is a string.""" if not isinstance(value, str): - raise ValueError('Invalid type for {0}: {1}.'.format(label, value)) + raise ValueError(f'Invalid type for {label}: {value}.') return value def validate_boolean(value, label): """Validates that the given value is a boolean.""" if not isinstance(value, bool): - raise ValueError('Invalid type for {0}: {1}.'.format(label, value)) + raise ValueError(f'Invalid type for {label}: {value}.') return value def validate_custom_claims(custom_claims, required=False): @@ -242,8 +238,7 @@ def validate_custom_claims(custom_claims, required=False): claims_str = str(custom_claims) if len(claims_str) > MAX_CLAIMS_PAYLOAD_SIZE: raise ValueError( - 'Custom claims payload must not exceed {0} characters.'.format( - MAX_CLAIMS_PAYLOAD_SIZE)) + f'Custom claims payload must not exceed {MAX_CLAIMS_PAYLOAD_SIZE} characters.') try: parsed = json.loads(claims_str) except Exception as err: @@ -254,16 +249,17 @@ def validate_custom_claims(custom_claims, required=False): invalid_claims = RESERVED_CLAIMS.intersection(set(parsed.keys())) if len(invalid_claims) > 1: joined = ', '.join(sorted(invalid_claims)) - raise ValueError('Claims "{0}" are reserved, and must not be set.'.format(joined)) + raise ValueError(f'Claims "{joined}" are reserved, and must not be set.') if len(invalid_claims) == 1: raise ValueError( - 'Claim "{0}" is reserved, and must not be set.'.format(invalid_claims.pop())) + f'Claim "{invalid_claims.pop()}" is reserved, and must not be set.') return claims_str def validate_action_type(action_type): if action_type not in VALID_EMAIL_ACTION_TYPES: - raise ValueError('Invalid action type provided action_type: {0}. \ - Valid values are {1}'.format(action_type, ', '.join(VALID_EMAIL_ACTION_TYPES))) + raise ValueError( + f'Invalid action type provided action_type: {action_type}. Valid values are ' + f'{", ".join(VALID_EMAIL_ACTION_TYPES)}') return action_type def validate_provider_ids(provider_ids, required=False): @@ -282,7 +278,7 @@ def build_update_mask(params): if isinstance(value, dict): child_mask = build_update_mask(value) for child in child_mask: - mask.append('{0}.{1}'.format(key, child)) + mask.append(f'{key}.{child}') else: mask.append(key) @@ -443,7 +439,7 @@ def handle_auth_backend_error(error): code, custom_message = _parse_error_body(error.response) if not code: - msg = 'Unexpected error response: {0}'.format(error.response.content.decode()) + msg = f'Unexpected error response: {error.response.content.decode()}' return _utils.handle_requests_error(error, message=msg) exc_type = _CODE_TO_EXC_TYPE.get(code) @@ -479,5 +475,5 @@ def _parse_error_body(response): def _build_error_message(code, exc_type, custom_message): default_message = exc_type.default_message if ( exc_type and hasattr(exc_type, 'default_message')) else 'Error while calling Auth service' - ext = ' {0}'.format(custom_message) if custom_message else '' - return '{0} ({1}).{2}'.format(default_message, code, ext) + ext = f' {custom_message}' if custom_message else '' + return f'{default_message} ({code}).{ext}' diff --git a/firebase_admin/_messaging_encoder.py b/firebase_admin/_messaging_encoder.py index 32f97875e..960a6d742 100644 --- a/firebase_admin/_messaging_encoder.py +++ b/firebase_admin/_messaging_encoder.py @@ -20,7 +20,7 @@ import numbers import re -import firebase_admin._messaging_utils as _messaging_utils +from firebase_admin import _messaging_utils class Message: @@ -99,10 +99,10 @@ def check_string(cls, label, value, non_empty=False): return None if not isinstance(value, str): if non_empty: - raise ValueError('{0} must be a non-empty string.'.format(label)) - raise ValueError('{0} must be a string.'.format(label)) + raise ValueError(f'{label} must be a non-empty string.') + raise ValueError(f'{label} must be a string.') if non_empty and not value: - raise ValueError('{0} must be a non-empty string.'.format(label)) + raise ValueError(f'{label} must be a non-empty string.') return value @classmethod @@ -110,7 +110,7 @@ def check_number(cls, label, value): if value is None: return None if not isinstance(value, numbers.Number): - raise ValueError('{0} must be a number.'.format(label)) + raise ValueError(f'{label} must be a number.') return value @classmethod @@ -119,13 +119,13 @@ def check_string_dict(cls, label, value): if value is None or value == {}: return None if not isinstance(value, dict): - raise ValueError('{0} must be a dictionary.'.format(label)) + raise ValueError(f'{label} must be a dictionary.') non_str = [k for k in value if not isinstance(k, str)] if non_str: - raise ValueError('{0} must not contain non-string keys.'.format(label)) + raise ValueError(f'{label} must not contain non-string keys.') non_str = [v for v in value.values() if not isinstance(v, str)] if non_str: - raise ValueError('{0} must not contain non-string values.'.format(label)) + raise ValueError(f'{label} must not contain non-string values.') return value @classmethod @@ -134,10 +134,10 @@ def check_string_list(cls, label, value): if value is None or value == []: return None if not isinstance(value, list): - raise ValueError('{0} must be a list of strings.'.format(label)) + raise ValueError(f'{label} must be a list of strings.') non_str = [k for k in value if not isinstance(k, str)] if non_str: - raise ValueError('{0} must not contain non-string values.'.format(label)) + raise ValueError(f'{label} must not contain non-string values.') return value @classmethod @@ -146,10 +146,10 @@ def check_number_list(cls, label, value): if value is None or value == []: return None if not isinstance(value, list): - raise ValueError('{0} must be a list of numbers.'.format(label)) + raise ValueError(f'{label} must be a list of numbers.') non_number = [k for k in value if not isinstance(k, numbers.Number)] if non_number: - raise ValueError('{0} must not contain non-number values.'.format(label)) + raise ValueError(f'{label} must not contain non-number values.') return value @classmethod @@ -157,7 +157,7 @@ def check_analytics_label(cls, label, value): """Checks if the given value is a valid analytics label.""" value = _Validators.check_string(label, value) if value is not None and not re.match(r'^[a-zA-Z0-9-_.~%]{1,50}$', value): - raise ValueError('Malformed {}.'.format(label)) + raise ValueError(f'Malformed {label}.') return value @classmethod @@ -166,7 +166,7 @@ def check_boolean(cls, label, value): if value is None: return None if not isinstance(value, bool): - raise ValueError('{0} must be a boolean.'.format(label)) + raise ValueError(f'{label} must be a boolean.') return value @classmethod @@ -175,7 +175,7 @@ def check_datetime(cls, label, value): if value is None: return None if not isinstance(value, datetime.datetime): - raise ValueError('{0} must be a datetime.'.format(label)) + raise ValueError(f'{label} must be a datetime.') return value @@ -245,8 +245,8 @@ def encode_ttl(cls, ttl): seconds = int(math.floor(total_seconds)) nanos = int((total_seconds - seconds) * 1e9) if nanos: - return '{0}.{1}s'.format(seconds, str(nanos).zfill(9)) - return '{0}s'.format(seconds) + return f'{seconds}.{str(nanos).zfill(9)}s' + return f'{seconds}s' @classmethod def encode_milliseconds(cls, label, msec): @@ -256,16 +256,16 @@ def encode_milliseconds(cls, label, msec): if isinstance(msec, numbers.Number): msec = datetime.timedelta(milliseconds=msec) if not isinstance(msec, datetime.timedelta): - raise ValueError('{0} must be a duration in milliseconds or an instance of ' - 'datetime.timedelta.'.format(label)) + raise ValueError( + f'{label} must be a duration in milliseconds or an instance of datetime.timedelta.') total_seconds = msec.total_seconds() if total_seconds < 0: - raise ValueError('{0} must not be negative.'.format(label)) + raise ValueError(f'{label} must not be negative.') seconds = int(math.floor(total_seconds)) nanos = int((total_seconds - seconds) * 1e9) if nanos: - return '{0}.{1}s'.format(seconds, str(nanos).zfill(9)) - return '{0}s'.format(seconds) + return f'{seconds}.{str(nanos).zfill(9)}s' + return f'{seconds}s' @classmethod def encode_android_notification(cls, notification): @@ -409,7 +409,7 @@ def encode_light_settings(cls, light_settings): raise ValueError( 'LightSettings.color must be in the form #RRGGBB or #RRGGBBAA.') if len(color) == 7: - color = (color+'FF') + color = color+'FF' rgba = [int(color[i:i + 2], 16) / 255.0 for i in (1, 3, 5, 7)] result['color'] = {'red': rgba[0], 'green': rgba[1], 'blue': rgba[2], 'alpha': rgba[3]} @@ -475,7 +475,7 @@ def encode_webpush_notification(cls, notification): for key, value in notification.custom_data.items(): if key in result: raise ValueError( - 'Multiple specifications for {0} in WebpushNotification.'.format(key)) + f'Multiple specifications for {key} in WebpushNotification.') result[key] = value return cls.remove_null_values(result) @@ -585,7 +585,7 @@ def encode_aps(cls, aps): for key, val in aps.custom_data.items(): _Validators.check_string('Aps.custom_data key', key) if key in result: - raise ValueError('Multiple specifications for {0} in Aps.'.format(key)) + raise ValueError(f'Multiple specifications for {key} in Aps.') result[key] = val return cls.remove_null_values(result) @@ -698,7 +698,7 @@ def default(self, o): # pylint: disable=method-hidden } result['topic'] = MessageEncoder.sanitize_topic_name(result.get('topic')) result = MessageEncoder.remove_null_values(result) - target_count = sum([t in result for t in ['token', 'topic', 'condition']]) + target_count = sum(t in result for t in ['token', 'topic', 'condition']) if target_count != 1: raise ValueError('Exactly one of token, topic or condition must be specified.') return result diff --git a/firebase_admin/_rfc3339.py b/firebase_admin/_rfc3339.py index 2c720bdd1..8489bdcb9 100644 --- a/firebase_admin/_rfc3339.py +++ b/firebase_admin/_rfc3339.py @@ -84,4 +84,4 @@ def _parse_to_datetime(datestr): except ValueError: pass - raise ValueError('time data {0} does not match RFC3339 format'.format(datestr)) + raise ValueError(f'time data {datestr} does not match RFC3339 format') diff --git a/firebase_admin/_sseclient.py b/firebase_admin/_sseclient.py index ec20cb45c..3372fe5f2 100644 --- a/firebase_admin/_sseclient.py +++ b/firebase_admin/_sseclient.py @@ -86,7 +86,7 @@ def __init__(self, url, session, retry=3000, **kwargs): self.requests_kwargs = kwargs self.should_connect = True self.last_id = None - self.buf = u'' # Keep data here as it streams in + self.buf = '' # Keep data here as it streams in headers = self.requests_kwargs.get('headers', {}) # The SSE spec requires making requests with Cache-Control: no-cache @@ -153,9 +153,6 @@ def __next__(self): self.last_id = event.event_id return event - def next(self): - return self.__next__() - class Event: """Event represents the events fired by SSE.""" @@ -184,7 +181,7 @@ def parse(cls, raw): match = cls.sse_line_pattern.match(line) if match is None: # Malformed line. Discard but warn. - warnings.warn('Invalid SSE line: "%s"' % line, SyntaxWarning) + warnings.warn(f'Invalid SSE line: "{line}"', SyntaxWarning) continue name = match.groupdict()['name'] @@ -196,7 +193,7 @@ def parse(cls, raw): # If we already have some data, then join to it with a newline. # Else this is it. if event.data: - event.data = '%s\n%s' % (event.data, value) + event.data = f'{event.data}\n{value}' else: event.data = value elif name == 'event': diff --git a/firebase_admin/_token_gen.py b/firebase_admin/_token_gen.py index 6d82bf7a6..1607ef0ba 100644 --- a/firebase_admin/_token_gen.py +++ b/firebase_admin/_token_gen.py @@ -114,7 +114,7 @@ def __init__(self, app, http_client, url_override=None): self.http_client = http_client self.request = transport.requests.Request() url_prefix = url_override or self.ID_TOOLKIT_URL - self.base_url = '{0}/projects/{1}'.format(url_prefix, app.project_id) + self.base_url = f'{url_prefix}/projects/{app.project_id}' self._signing_provider = None def _init_signing_provider(self): @@ -142,7 +142,7 @@ def _init_signing_provider(self): resp = self.request(url=METADATA_SERVICE_URL, headers={'Metadata-Flavor': 'Google'}) if resp.status != 200: raise ValueError( - 'Failed to contact the local metadata service: {0}.'.format(resp.data.decode())) + f'Failed to contact the local metadata service: {resp.data.decode()}.') service_account = resp.data.decode() return _SigningProvider.from_iam(self.request, google_cred, service_account) @@ -155,10 +155,10 @@ def signing_provider(self): except Exception as error: url = 'https://firebase.google.com/docs/auth/admin/create-custom-tokens' raise ValueError( - 'Failed to determine service account: {0}. Make sure to initialize the SDK ' - 'with service account credentials or specify a service account ID with ' - 'iam.serviceAccounts.signBlob permission. Please refer to {1} for more ' - 'details on creating custom tokens.'.format(error, url)) from error + f'Failed to determine service account: {error}. Make sure to initialize the ' + 'SDK with service account credentials or specify a service account ID with ' + f'iam.serviceAccounts.signBlob permission. Please refer to {url} for more ' + 'details on creating custom tokens.') from error return self._signing_provider def create_custom_token(self, uid, developer_claims=None, tenant_id=None): @@ -170,13 +170,13 @@ def create_custom_token(self, uid, developer_claims=None, tenant_id=None): disallowed_keys = set(developer_claims.keys()) & RESERVED_CLAIMS if disallowed_keys: if len(disallowed_keys) > 1: - error_message = ('Developer claims {0} are reserved and ' - 'cannot be specified.'.format( - ', '.join(disallowed_keys))) + error_message = ( + f'Developer claims {", ".join(disallowed_keys)} are reserved and cannot be ' + 'specified.') else: - error_message = ('Developer claim {0} is reserved and ' - 'cannot be specified.'.format( - ', '.join(disallowed_keys))) + error_message = ( + f'Developer claim {", ".join(disallowed_keys)} is reserved and cannot be ' + 'specified.') raise ValueError(error_message) if not uid or not isinstance(uid, str) or len(uid) > 128: @@ -202,7 +202,7 @@ def create_custom_token(self, uid, developer_claims=None, tenant_id=None): try: return jwt.encode(signing_provider.signer, payload, header=header) except google.auth.exceptions.TransportError as error: - msg = 'Failed to sign custom token. {0}'.format(error) + msg = f'Failed to sign custom token. {error}' raise TokenSignError(msg, error) from error @@ -211,21 +211,22 @@ def create_session_cookie(self, id_token, expires_in): id_token = id_token.decode('utf-8') if isinstance(id_token, bytes) else id_token if not isinstance(id_token, str) or not id_token: raise ValueError( - 'Illegal ID token provided: {0}. ID token must be a non-empty ' - 'string.'.format(id_token)) + f'Illegal ID token provided: {id_token}. ID token must be a non-empty string.') if isinstance(expires_in, datetime.timedelta): expires_in = int(expires_in.total_seconds()) if isinstance(expires_in, bool) or not isinstance(expires_in, int): - raise ValueError('Illegal expiry duration: {0}.'.format(expires_in)) + raise ValueError(f'Illegal expiry duration: {expires_in}.') if expires_in < MIN_SESSION_COOKIE_DURATION_SECONDS: - raise ValueError('Illegal expiry duration: {0}. Duration must be at least {1} ' - 'seconds.'.format(expires_in, MIN_SESSION_COOKIE_DURATION_SECONDS)) + raise ValueError( + f'Illegal expiry duration: {expires_in}. Duration must be at least ' + f'{MIN_SESSION_COOKIE_DURATION_SECONDS} seconds.') if expires_in > MAX_SESSION_COOKIE_DURATION_SECONDS: - raise ValueError('Illegal expiry duration: {0}. Duration must be at most {1} ' - 'seconds.'.format(expires_in, MAX_SESSION_COOKIE_DURATION_SECONDS)) + raise ValueError( + f'Illegal expiry duration: {expires_in}. Duration must be at most ' + f'{MAX_SESSION_COOKIE_DURATION_SECONDS} seconds.') - url = '{0}:createSessionCookie'.format(self.base_url) + url = f'{self.base_url}:createSessionCookie' payload = { 'idToken': id_token, 'validDuration': expires_in, @@ -234,11 +235,10 @@ def create_session_cookie(self, id_token, expires_in): body, http_resp = self.http_client.body_and_response('post', url, json=payload) except requests.exceptions.RequestException as error: raise _auth_utils.handle_auth_backend_error(error) - else: - if not body or not body.get('sessionCookie'): - raise _auth_utils.UnexpectedResponseError( - 'Failed to create session cookie.', http_response=http_resp) - return body.get('sessionCookie') + if not body or not body.get('sessionCookie'): + raise _auth_utils.UnexpectedResponseError( + 'Failed to create session cookie.', http_response=http_resp) + return body.get('sessionCookie') class CertificateFetchRequest(transport.Request): @@ -307,9 +307,9 @@ def __init__(self, **kwargs): self.cert_url = kwargs.pop('cert_url') self.issuer = kwargs.pop('issuer') if self.short_name[0].lower() in 'aeiou': - self.articled_short_name = 'an {0}'.format(self.short_name) + self.articled_short_name = f'an {self.short_name}' else: - self.articled_short_name = 'a {0}'.format(self.short_name) + self.articled_short_name = f'a {self.short_name}' self._invalid_token_error = kwargs.pop('invalid_token_error') self._expired_token_error = kwargs.pop('expired_token_error') @@ -318,20 +318,20 @@ def verify(self, token, request, clock_skew_seconds=0): token = token.encode('utf-8') if isinstance(token, str) else token if not isinstance(token, bytes) or not token: raise ValueError( - 'Illegal {0} provided: {1}. {0} must be a non-empty ' - 'string.'.format(self.short_name, token)) + f'Illegal {self.short_name} provided: {token}. {self.short_name} must be a ' + 'non-empty string.') if not self.project_id: raise ValueError( 'Failed to ascertain project ID from the credential or the environment. Project ' - 'ID is required to call {0}. Initialize the app with a credentials.Certificate ' - 'or set your Firebase project ID as an app option. Alternatively set the ' - 'GOOGLE_CLOUD_PROJECT environment variable.'.format(self.operation)) + f'ID is required to call {self.operation}. Initialize the app with a ' + 'credentials.Certificate or set your Firebase project ID as an app option. ' + 'Alternatively set the GOOGLE_CLOUD_PROJECT environment variable.') if clock_skew_seconds < 0 or clock_skew_seconds > 60: raise ValueError( - 'Illegal clock_skew_seconds value: {0}. Must be between 0 and 60, inclusive.' - .format(clock_skew_seconds)) + f'Illegal clock_skew_seconds value: {clock_skew_seconds}. Must be between 0 and 60' + ', inclusive.') header, payload = self._decode_unverified(token) issuer = payload.get('iss') @@ -340,52 +340,51 @@ def verify(self, token, request, clock_skew_seconds=0): expected_issuer = self.issuer + self.project_id project_id_match_msg = ( - 'Make sure the {0} comes from the same Firebase project as the service account used ' - 'to authenticate this SDK.'.format(self.short_name)) + f'Make sure the {self.short_name} comes from the same Firebase project as the service ' + 'account used to authenticate this SDK.') verify_id_token_msg = ( - 'See {0} for details on how to retrieve {1}.'.format(self.url, self.short_name)) + f'See {self.url} for details on how to retrieve {self.short_name}.') emulated = _auth_utils.is_emulated() error_message = None if audience == FIREBASE_AUDIENCE: error_message = ( - '{0} expects {1}, but was given a custom ' - 'token.'.format(self.operation, self.articled_short_name)) + f'{self.operation} expects {self.articled_short_name}, but was given a custom ' + 'token.') elif not emulated and not header.get('kid'): if header.get('alg') == 'HS256' and payload.get( 'v') == 0 and 'uid' in payload.get('d', {}): error_message = ( - '{0} expects {1}, but was given a legacy custom ' - 'token.'.format(self.operation, self.articled_short_name)) + f'{self.operation} expects {self.articled_short_name}, but was given a legacy ' + 'custom token.') else: - error_message = 'Firebase {0} has no "kid" claim.'.format(self.short_name) + error_message = f'Firebase {self.short_name} has no "kid" claim.' elif not emulated and header.get('alg') != 'RS256': error_message = ( - 'Firebase {0} has incorrect algorithm. Expected "RS256" but got ' - '"{1}". {2}'.format(self.short_name, header.get('alg'), verify_id_token_msg)) + f'Firebase {self.short_name} has incorrect algorithm. Expected "RS256" but got ' + f'"{header.get("alg")}". {verify_id_token_msg}') elif audience != self.project_id: error_message = ( - 'Firebase {0} has incorrect "aud" (audience) claim. Expected "{1}" but ' - 'got "{2}". {3} {4}'.format(self.short_name, self.project_id, audience, - project_id_match_msg, verify_id_token_msg)) + f'Firebase {self.short_name} has incorrect "aud" (audience) claim. Expected ' + f'"{self.project_id}" but got "{audience}". {project_id_match_msg} ' + f'{verify_id_token_msg}') elif issuer != expected_issuer: error_message = ( - 'Firebase {0} has incorrect "iss" (issuer) claim. Expected "{1}" but ' - 'got "{2}". {3} {4}'.format(self.short_name, expected_issuer, issuer, - project_id_match_msg, verify_id_token_msg)) + f'Firebase {self.short_name} has incorrect "iss" (issuer) claim. Expected ' + f'"{expected_issuer}" but got "{issuer}". {project_id_match_msg} ' + f'{verify_id_token_msg}') elif subject is None or not isinstance(subject, str): error_message = ( - 'Firebase {0} has no "sub" (subject) claim. ' - '{1}'.format(self.short_name, verify_id_token_msg)) + f'Firebase {self.short_name} has no "sub" (subject) claim. {verify_id_token_msg}') elif not subject: error_message = ( - 'Firebase {0} has an empty string "sub" (subject) claim. ' - '{1}'.format(self.short_name, verify_id_token_msg)) + f'Firebase {self.short_name} has an empty string "sub" (subject) claim. ' + f'{verify_id_token_msg}') elif len(subject) > 128: error_message = ( - 'Firebase {0} has a "sub" (subject) claim longer than 128 characters. ' - '{1}'.format(self.short_name, verify_id_token_msg)) + f'Firebase {self.short_name} has a "sub" (subject) claim longer than 128 ' + f'characters. {verify_id_token_msg}') if error_message: raise self._invalid_token_error(error_message) diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index 957b749a6..9a75b7a2e 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -130,7 +130,7 @@ class UserRecord(UserInfo): def __init__(self, data): super().__init__() if not isinstance(data, dict): - raise ValueError('Invalid data argument: {0}. Must be a dictionary.'.format(data)) + raise ValueError(f'Invalid data argument: {data}. Must be a dictionary.') if not data.get('localId'): raise ValueError('User ID must not be None or empty.') self._data = data @@ -454,7 +454,7 @@ class ProviderUserInfo(UserInfo): def __init__(self, data): super().__init__() if not isinstance(data, dict): - raise ValueError('Invalid data argument: {0}. Must be a dictionary.'.format(data)) + raise ValueError(f'Invalid data argument: {data}. Must be a dictionary.') if not data.get('rawId'): raise ValueError('User ID must not be None or empty.') self._data = data @@ -516,30 +516,30 @@ def encode_action_code_settings(settings): try: parsed = parse.urlparse(settings.url) if not parsed.netloc: - raise ValueError('Malformed dynamic action links url: "{0}".'.format(settings.url)) + raise ValueError(f'Malformed dynamic action links url: "{settings.url}".') parameters['continueUrl'] = settings.url except Exception as err: - raise ValueError('Malformed dynamic action links url: "{0}".'.format(settings.url)) from err + raise ValueError(f'Malformed dynamic action links url: "{settings.url}".') from err # handle_code_in_app if settings.handle_code_in_app is not None: if not isinstance(settings.handle_code_in_app, bool): - raise ValueError('Invalid value provided for handle_code_in_app: {0}' - .format(settings.handle_code_in_app)) + raise ValueError( + f'Invalid value provided for handle_code_in_app: {settings.handle_code_in_app}') parameters['canHandleCodeInApp'] = settings.handle_code_in_app # dynamic_link_domain if settings.dynamic_link_domain is not None: if not isinstance(settings.dynamic_link_domain, str): - raise ValueError('Invalid value provided for dynamic_link_domain: {0}' - .format(settings.dynamic_link_domain)) + raise ValueError( + f'Invalid value provided for dynamic_link_domain: {settings.dynamic_link_domain}') parameters['dynamicLinkDomain'] = settings.dynamic_link_domain # ios_bundle_id if settings.ios_bundle_id is not None: if not isinstance(settings.ios_bundle_id, str): - raise ValueError('Invalid value provided for ios_bundle_id: {0}' - .format(settings.ios_bundle_id)) + raise ValueError( + f'Invalid value provided for ios_bundle_id: {settings.ios_bundle_id}') parameters['iOSBundleId'] = settings.ios_bundle_id # android_* attributes @@ -549,20 +549,21 @@ def encode_action_code_settings(settings): if settings.android_package_name is not None: if not isinstance(settings.android_package_name, str): - raise ValueError('Invalid value provided for android_package_name: {0}' - .format(settings.android_package_name)) + raise ValueError( + f'Invalid value provided for android_package_name: {settings.android_package_name}') parameters['androidPackageName'] = settings.android_package_name if settings.android_minimum_version is not None: if not isinstance(settings.android_minimum_version, str): - raise ValueError('Invalid value provided for android_minimum_version: {0}' - .format(settings.android_minimum_version)) + raise ValueError( + 'Invalid value provided for android_minimum_version: ' + f'{settings.android_minimum_version}') parameters['androidMinimumVersion'] = settings.android_minimum_version if settings.android_install_app is not None: if not isinstance(settings.android_install_app, bool): - raise ValueError('Invalid value provided for android_install_app: {0}' - .format(settings.android_install_app)) + raise ValueError( + f'Invalid value provided for android_install_app: {settings.android_install_app}') parameters['androidInstallApp'] = settings.android_install_app return parameters @@ -576,9 +577,9 @@ class UserManager: def __init__(self, http_client, project_id, tenant_id=None, url_override=None): self.http_client = http_client url_prefix = url_override or self.ID_TOOLKIT_URL - self.base_url = '{0}/projects/{1}'.format(url_prefix, project_id) + self.base_url = f'{url_prefix}/projects/{project_id}' if tenant_id: - self.base_url += '/tenants/{0}'.format(tenant_id) + self.base_url += f'/tenants/{tenant_id}' def get_user(self, **kwargs): """Gets the user data corresponding to the provided key.""" @@ -592,12 +593,12 @@ def get_user(self, **kwargs): key, key_type = kwargs.pop('phone_number'), 'phone number' payload = {'phoneNumber' : [_auth_utils.validate_phone(key, required=True)]} else: - raise TypeError('Unsupported keyword arguments: {0}.'.format(kwargs)) + raise TypeError(f'Unsupported keyword arguments: {kwargs}.') body, http_resp = self._make_request('post', '/accounts:lookup', json=payload) if not body or not body.get('users'): raise _auth_utils.UserNotFoundError( - 'No user record found for the provided {0}: {1}.'.format(key_type, key), + f'No user record found for the provided {key_type}: {key}.', http_response=http_resp) return body['users'][0] @@ -638,8 +639,7 @@ def get_users(self, identifiers): }) else: raise ValueError( - 'Invalid entry in "identifiers" list. Unsupported type: {}' - .format(type(identifier))) + f'Invalid entry in "identifiers" list. Unsupported type: {type(identifier)}') body, http_resp = self._make_request( 'post', '/accounts:lookup', json=payload) @@ -657,8 +657,7 @@ def list_users(self, page_token=None, max_results=MAX_LIST_USERS_RESULTS): raise ValueError('Max results must be an integer.') if max_results < 1 or max_results > MAX_LIST_USERS_RESULTS: raise ValueError( - 'Max results must be a positive integer less than ' - '{0}.'.format(MAX_LIST_USERS_RESULTS)) + f'Max results must be a positive integer less than {MAX_LIST_USERS_RESULTS}.') payload = {'maxResults': max_results} if page_token: @@ -734,7 +733,7 @@ def update_user(self, uid, display_name=None, email=None, phone_number=None, body, http_resp = self._make_request('post', '/accounts:update', json=payload) if not body or not body.get('localId'): raise _auth_utils.UnexpectedResponseError( - 'Failed to update user: {0}.'.format(uid), http_response=http_resp) + f'Failed to update user: {uid}.', http_response=http_resp) return body.get('localId') def delete_user(self, uid): @@ -743,7 +742,7 @@ def delete_user(self, uid): body, http_resp = self._make_request('post', '/accounts:delete', json={'localId' : uid}) if not body or not body.get('kind'): raise _auth_utils.UnexpectedResponseError( - 'Failed to delete user: {0}.'.format(uid), http_response=http_resp) + f'Failed to delete user: {uid}.', http_response=http_resp) def delete_users(self, uids, force_delete=False): """Deletes the users identified by the specified user ids. @@ -786,8 +785,8 @@ def import_users(self, users, hash_alg=None): try: if not users or len(users) > MAX_IMPORT_USERS_SIZE: raise ValueError( - 'Users must be a non-empty list with no more than {0} elements.'.format( - MAX_IMPORT_USERS_SIZE)) + 'Users must be a non-empty list with no more than ' + f'{MAX_IMPORT_USERS_SIZE} elements.') if any(not isinstance(u, _user_import.ImportUserRecord) for u in users): raise ValueError('One or more user objects are invalid.') except TypeError as err: @@ -837,7 +836,7 @@ def generate_email_action_link(self, action_type, email, action_code_settings=No return body.get('oobLink') def _make_request(self, method, path, **kwargs): - url = '{0}{1}'.format(self.base_url, path) + url = f'{self.base_url}{path}' try: return self.http_client.body_and_response(method, url, **kwargs) except requests.exceptions.RequestException as error: diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index 765d11587..d0aca884b 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -93,8 +93,9 @@ def _get_initialized_app(app): 'initialized via the firebase module.') return app - raise ValueError('Illegal app argument. Argument must be of type ' - ' firebase_admin.App, but given "{0}".'.format(type(app))) + raise ValueError( + 'Illegal app argument. Argument must be of type firebase_admin.App, but given ' + f'"{type(app)}".') @@ -172,7 +173,7 @@ def handle_operation_error(error): """ if not isinstance(error, dict): return exceptions.UnknownError( - message='Unknown error while making a remote service call: {0}'.format(error), + message=f'Unknown error while making a remote service call: {error}', cause=error) rpc_code = error.get('code') @@ -217,15 +218,15 @@ def handle_requests_error(error, message=None, code=None): """ if isinstance(error, requests.exceptions.Timeout): return exceptions.DeadlineExceededError( - message='Timed out while making an API call: {0}'.format(error), + message=f'Timed out while making an API call: {error}', cause=error) if isinstance(error, requests.exceptions.ConnectionError): return exceptions.UnavailableError( - message='Failed to establish a connection: {0}'.format(error), + message=f'Failed to establish a connection: {error}', cause=error) if error.response is None: return exceptions.UnknownError( - message='Unknown error while making a remote service call: {0}'.format(error), + message=f'Unknown error while making a remote service call: {error}', cause=error) if not code: @@ -271,11 +272,11 @@ def handle_httpx_error(error: httpx.HTTPError, message=None, code=None) -> excep """ if isinstance(error, httpx.TimeoutException): return exceptions.DeadlineExceededError( - message='Timed out while making an API call: {0}'.format(error), + message=f'Timed out while making an API call: {error}', cause=error) if isinstance(error, httpx.ConnectError): return exceptions.UnavailableError( - message='Failed to establish a connection: {0}'.format(error), + message=f'Failed to establish a connection: {error}', cause=error) if isinstance(error, httpx.HTTPStatusError): print("printing status error", error) @@ -288,7 +289,7 @@ def handle_httpx_error(error: httpx.HTTPError, message=None, code=None) -> excep return err_type(message=message, cause=error, http_response=error.response) return exceptions.UnknownError( - message='Unknown error while making a remote service call: {0}'.format(error), + message=f'Unknown error while making a remote service call: {error}', cause=error) def _http_status_to_error_code(status): @@ -326,7 +327,7 @@ def _parse_platform_error(content, status_code): error_dict = data.get('error', {}) msg = error_dict.get('message') if not msg: - msg = 'Unexpected HTTP response with status: {0}; body: {1}'.format(status_code, content) + msg = f'Unexpected HTTP response with status: {status_code}; body: {content}' return error_dict, msg diff --git a/firebase_admin/app_check.py b/firebase_admin/app_check.py index 1224f7d80..40d857f4e 100644 --- a/firebase_admin/app_check.py +++ b/firebase_admin/app_check.py @@ -156,6 +156,6 @@ class _Validators: def check_string(cls, label: str, value: Any): """Checks if the given value is a string.""" if value is None: - raise ValueError('{0} "{1}" must be a non-empty string.'.format(label, value)) + raise ValueError(f'{label} "{value}" must be a non-empty string.') if not isinstance(value, str): - raise ValueError('{0} "{1}" must be a string.'.format(label, value)) + raise ValueError(f'{label} "{value}" must be a string.') diff --git a/firebase_admin/credentials.py b/firebase_admin/credentials.py index 8259c93b4..7117b71a9 100644 --- a/firebase_admin/credentials.py +++ b/firebase_admin/credentials.py @@ -94,24 +94,25 @@ def __init__(self, cert): """ super().__init__() if _is_file_path(cert): - with open(cert) as json_file: + with open(cert, encoding='utf-8') as json_file: json_data = json.load(json_file) elif isinstance(cert, dict): json_data = cert else: raise ValueError( - 'Invalid certificate argument: "{0}". Certificate argument must be a file path, ' - 'or a dict containing the parsed file contents.'.format(cert)) + f'Invalid certificate argument: "{cert}". Certificate argument must be a file ' + 'path, or a dict containing the parsed file contents.') if json_data.get('type') != self._CREDENTIAL_TYPE: - raise ValueError('Invalid service account certificate. Certificate must contain a ' - '"type" field set to "{0}".'.format(self._CREDENTIAL_TYPE)) + raise ValueError( + 'Invalid service account certificate. Certificate must contain a ' + f'"type" field set to "{self._CREDENTIAL_TYPE}".') try: self._g_credential = service_account.Credentials.from_service_account_info( json_data, scopes=_scopes) except ValueError as error: - raise ValueError('Failed to initialize a certificate credential. ' - 'Caused by: "{0}"'.format(error)) from error + raise ValueError( + f'Failed to initialize a certificate credential. Caused by: "{error}"') from error @property def project_id(self): @@ -195,18 +196,19 @@ def __init__(self, refresh_token): """ super().__init__() if _is_file_path(refresh_token): - with open(refresh_token) as json_file: + with open(refresh_token, encoding='utf-8') as json_file: json_data = json.load(json_file) elif isinstance(refresh_token, dict): json_data = refresh_token else: raise ValueError( - 'Invalid refresh token argument: "{0}". Refresh token argument must be a file ' - 'path, or a dict containing the parsed file contents.'.format(refresh_token)) + f'Invalid refresh token argument: "{refresh_token}". Refresh token argument must ' + 'be a file path, or a dict containing the parsed file contents.') if json_data.get('type') != self._CREDENTIAL_TYPE: - raise ValueError('Invalid refresh token configuration. JSON must contain a ' - '"type" field set to "{0}".'.format(self._CREDENTIAL_TYPE)) + raise ValueError( + 'Invalid refresh token configuration. JSON must contain a ' + f'"type" field set to "{self._CREDENTIAL_TYPE}".') self._g_credential = credentials.Credentials.from_authorized_user_info(json_data, _scopes) @property diff --git a/firebase_admin/db.py b/firebase_admin/db.py index fc69cbd83..800cbf8e3 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -39,8 +39,10 @@ _DB_ATTRIBUTE = '_database' _INVALID_PATH_CHARACTERS = '[].?#$' _RESERVED_FILTERS = ('$key', '$value', '$priority') -_USER_AGENT = 'Firebase/HTTP/{0}/{1}.{2}/AdminPython'.format( - firebase_admin.__version__, sys.version_info.major, sys.version_info.minor) +_USER_AGENT = ( + f'Firebase/HTTP/{firebase_admin.__version__}/{sys.version_info.major}' + f'.{sys.version_info.minor}/AdminPython' +) _TRANSACTION_MAX_RETRIES = 25 _EMULATOR_HOST_ENV_VAR = 'FIREBASE_DATABASE_EMULATOR_HOST' @@ -72,10 +74,9 @@ def reference(path='/', app=None, url=None): def _parse_path(path): """Parses a path string into a set of segments.""" if not isinstance(path, str): - raise ValueError('Invalid path: "{0}". Path must be a string.'.format(path)) + raise ValueError(f'Invalid path: "{path}". Path must be a string.') if any(ch in path for ch in _INVALID_PATH_CHARACTERS): - raise ValueError( - 'Invalid path: "{0}". Path contains illegal characters.'.format(path)) + raise ValueError(f'Invalid path: "{path}". Path contains illegal characters.') return [seg for seg in path.split('/') if seg] @@ -184,11 +185,9 @@ def child(self, path): ValueError: If the child path is not a string, not well-formed or begins with '/'. """ if not path or not isinstance(path, str): - raise ValueError( - 'Invalid path argument: "{0}". Path must be a non-empty string.'.format(path)) + raise ValueError(f'Invalid path argument: "{path}". Path must be a non-empty string.') if path.startswith('/'): - raise ValueError( - 'Invalid path argument: "{0}". Child path must not start with "/"'.format(path)) + raise ValueError(f'Invalid path argument: "{path}". Child path must not start with "/"') full_path = self._pathurl + '/' + path return Reference(client=self._client, path=full_path) @@ -433,7 +432,7 @@ def order_by_child(self, path): ValueError: If the child path is not a string, not well-formed or None. """ if path in _RESERVED_FILTERS: - raise ValueError('Illegal child path: {0}'.format(path)) + raise ValueError(f'Illegal child path: {path}') return Query(order_by=path, client=self._client, pathurl=self._add_suffix()) def order_by_key(self): @@ -492,8 +491,8 @@ def __init__(self, **kwargs): raise ValueError('order_by field must be a non-empty string') if order_by not in _RESERVED_FILTERS: if order_by.startswith('/'): - raise ValueError('Invalid path argument: "{0}". Child path must not start ' - 'with "/"'.format(order_by)) + raise ValueError( + f'Invalid path argument: "{order_by}". Child path must not start with "/"') segments = _parse_path(order_by) order_by = '/'.join(segments) self._client = kwargs.pop('client') @@ -501,7 +500,7 @@ def __init__(self, **kwargs): self._order_by = order_by self._params = {'orderBy' : json.dumps(order_by)} if kwargs: - raise ValueError('Unexpected keyword arguments: {0}'.format(kwargs)) + raise ValueError(f'Unexpected keyword arguments: {kwargs}') def limit_to_first(self, limit): """Creates a query with limit, and anchors it to the start of the window. @@ -604,7 +603,7 @@ def equal_to(self, value): def _querystr(self): params = [] for key in sorted(self._params): - params.append('{0}={1}'.format(key, self._params[key])) + params.append(f'{key}={self._params[key]}') return '&'.join(params) def get(self): @@ -642,7 +641,7 @@ def __init__(self, results, order_by): self.dict_input = False entries = [_SortEntry(k, v, order_by) for k, v in enumerate(results)] else: - raise ValueError('Sorting not supported for "{0}" object.'.format(type(results))) + raise ValueError(f'Sorting not supported for "{type(results)}" object.') self.sort_entries = sorted(entries) def get(self): @@ -783,8 +782,8 @@ def __init__(self, app): if emulator_host: if '//' in emulator_host: raise ValueError( - 'Invalid {0}: "{1}". It must follow format "host:port".'.format( - _EMULATOR_HOST_ENV_VAR, emulator_host)) + f'Invalid {_EMULATOR_HOST_ENV_VAR}: "{emulator_host}". It must follow format ' + '"host:port".') self._emulator_host = emulator_host else: self._emulator_host = None @@ -796,14 +795,12 @@ def get_client(self, db_url=None): if not db_url or not isinstance(db_url, str): raise ValueError( - 'Invalid database URL: "{0}". Database URL must be a non-empty ' - 'URL string.'.format(db_url)) + f'Invalid database URL: "{db_url}". Database URL must be a non-empty URL string.') parsed_url = parse.urlparse(db_url) if not parsed_url.netloc: raise ValueError( - 'Invalid database URL: "{0}". Database URL must be a wellformed ' - 'URL string.'.format(db_url)) + f'Invalid database URL: "{db_url}". Database URL must be a wellformed URL string.') emulator_config = self._get_emulator_config(parsed_url) if emulator_config: @@ -813,7 +810,7 @@ def get_client(self, db_url=None): else: # Defer credential lookup until we are certain it's going to be prod connection. credential = self._credential.get_credential() - base_url = 'https://{0}'.format(parsed_url.netloc) + base_url = f'https://{parsed_url.netloc}' params = {} @@ -835,7 +832,7 @@ def _get_emulator_config(self, parsed_url): return EmulatorConfig(base_url, namespace) if self._emulator_host: # Emulator mode enabled via environment variable - base_url = 'http://{0}'.format(self._emulator_host) + base_url = f'http://{self._emulator_host}' namespace = parsed_url.netloc.split('.')[0] return EmulatorConfig(base_url, namespace) @@ -847,21 +844,23 @@ def _parse_emulator_url(cls, parsed_url): query_ns = parse.parse_qs(parsed_url.query).get('ns') if parsed_url.scheme != 'http' or (not query_ns or len(query_ns) != 1 or not query_ns[0]): raise ValueError( - 'Invalid database URL: "{0}". Database URL must be a valid URL to a ' - 'Firebase Realtime Database instance.'.format(parsed_url.geturl())) + f'Invalid database URL: "{parsed_url.geturl()}". Database URL must be a valid URL ' + 'to a Firebase Realtime Database instance.') namespace = query_ns[0] - base_url = '{0}://{1}'.format(parsed_url.scheme, parsed_url.netloc) + base_url = f'{parsed_url.scheme}://{parsed_url.netloc}' return base_url, namespace @classmethod def _get_auth_override(cls, app): + """Gets and validates the database auth override to be used.""" auth_override = app.options.get('databaseAuthVariableOverride', cls._DEFAULT_AUTH_OVERRIDE) if auth_override == cls._DEFAULT_AUTH_OVERRIDE or auth_override is None: return auth_override if not isinstance(auth_override, dict): - raise ValueError('Invalid databaseAuthVariableOverride option: "{0}". Override ' - 'value must be a dict or None.'.format(auth_override)) + raise ValueError( + f'Invalid databaseAuthVariableOverride option: "{auth_override}". Override ' + 'value must be a dict or None.') return auth_override @@ -916,7 +915,7 @@ def request(self, method, url, **kwargs): Raises: FirebaseError: If an error occurs while making the HTTP call. """ - query = '&'.join('{0}={1}'.format(key, self.params[key]) for key in self.params) + query = '&'.join(f'{key}={value}' for key, value in self.params.items()) extra_params = kwargs.get('params') if extra_params: if query: @@ -961,6 +960,6 @@ def _extract_error_message(cls, response): pass if not message: - message = 'Unexpected response from database: {0}'.format(response.content.decode()) + message = f'Unexpected response from database: {response.content.decode()}' return message diff --git a/firebase_admin/functions.py b/firebase_admin/functions.py index fa17dfc0c..86eea557a 100644 --- a/firebase_admin/functions.py +++ b/firebase_admin/functions.py @@ -48,7 +48,7 @@ _FUNCTIONS_HEADERS = { 'X-GOOG-API-FORMAT-VERSION': '2', - 'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__), + 'X-FIREBASE-CLIENT': f'fire-admin-python/{firebase_admin.__version__}', } # Default canonical location ID of the task queue. @@ -306,9 +306,9 @@ class _Validators: def check_non_empty_string(cls, label: str, value: Any): """Checks if given value is a non-empty string and throws error if not.""" if not isinstance(value, str): - raise ValueError('{0} "{1}" must be a string.'.format(label, value)) + raise ValueError(f'{label} "{value}" must be a string.') if value == '': - raise ValueError('{0} "{1}" must be a non-empty string.'.format(label, value)) + raise ValueError(f'{label} "{value}" must be a non-empty string.') @classmethod def is_non_empty_string(cls, value: Any): diff --git a/firebase_admin/instance_id.py b/firebase_admin/instance_id.py index 604158d9c..812daf40b 100644 --- a/firebase_admin/instance_id.py +++ b/firebase_admin/instance_id.py @@ -81,7 +81,7 @@ def __init__(self, app): def delete_instance_id(self, instance_id): if not isinstance(instance_id, str) or not instance_id: raise ValueError('Instance ID must be a non-empty string.') - path = 'project/{0}/instanceId/{1}'.format(self._project_id, instance_id) + path = f'project/{self._project_id}/instanceId/{instance_id}' try: self._client.request('delete', path) except requests.exceptions.RequestException as error: @@ -94,6 +94,6 @@ def _extract_message(self, instance_id, error): status = error.response.status_code msg = self.error_codes.get(status) if msg: - return 'Instance ID "{0}": {1}'.format(instance_id, msg) + return f'Instance ID "{instance_id}": {msg}' - return 'Instance ID "{0}": {1}'.format(instance_id, error) + return f'Instance ID "{instance_id}": {error}' diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 5b2e48e80..749044436 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -301,7 +301,7 @@ class TopicManagementResponse: def __init__(self, resp): if not isinstance(resp, dict) or 'results' not in resp: - raise ValueError('Unexpected topic management response: {0}.'.format(resp)) + raise ValueError(f'Unexpected topic management response: {resp}.') self._success_count = 0 self._failure_count = 0 self._errors = [] @@ -400,7 +400,7 @@ def __init__(self, app: App) -> None: self._fcm_url = _MessagingService.FCM_URL.format(project_id) self._fcm_headers = { 'X-GOOG-API-FORMAT-VERSION': '2', - 'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__), + 'X-FIREBASE-CLIENT': f'fire-admin-python/{firebase_admin.__version__}', } timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) self._credential = app.credential.get_credential() @@ -426,8 +426,7 @@ def send(self, message: Message, dry_run: bool = False) -> str: ) except requests.exceptions.RequestException as error: raise self._handle_fcm_error(error) - else: - return cast(str, resp['name']) + return cast(str, resp['name']) def send_each(self, messages: List[Message], dry_run: bool = False) -> BatchResponse: """Sends the given messages to FCM via the FCM v1 API.""" @@ -445,8 +444,7 @@ def send_data(data): json=data) except requests.exceptions.RequestException as exception: return SendResponse(resp=None, exception=self._handle_fcm_error(exception)) - else: - return SendResponse(resp, exception=None) + return SendResponse(resp, exception=None) message_data = [self._message_data(message, dry_run) for message in messages] try: @@ -455,7 +453,7 @@ def send_data(data): return BatchResponse(responses) except Exception as error: raise exceptions.UnknownError( - message='Unknown error while making remote service calls: {0}'.format(error), + message=f'Unknown error while making remote service calls: {error}', cause=error) async def send_each_async(self, messages: List[Message], dry_run: bool = True) -> BatchResponse: @@ -477,8 +475,7 @@ async def send_data(data): # Catch errors caused by the requests library during authorization except requests.exceptions.RequestException as exception: return SendResponse(resp=None, exception=self._handle_fcm_error(exception)) - else: - return SendResponse(resp.json(), exception=None) + return SendResponse(resp.json(), exception=None) message_data = [self._message_data(message, dry_run) for message in messages] try: @@ -486,7 +483,7 @@ async def send_data(data): return BatchResponse(responses) except Exception as error: raise exceptions.UnknownError( - message='Unknown error while making remote service calls: {0}'.format(error), + message=f'Unknown error while making remote service calls: {error}', cause=error) def make_topic_management_request(self, tokens, topic, operation): @@ -502,12 +499,12 @@ def make_topic_management_request(self, tokens, topic, operation): if not isinstance(topic, str) or not topic: raise ValueError('Topic must be a non-empty string.') if not topic.startswith('/topics/'): - topic = '/topics/{0}'.format(topic) + topic = f'/topics/{topic}' data = { 'to': topic, 'registration_tokens': tokens, } - url = '{0}/{1}'.format(_MessagingService.IID_URL, operation) + url = f'{_MessagingService.IID_URL}/{operation}' try: resp = self._client.body( 'post', @@ -517,8 +514,7 @@ def make_topic_management_request(self, tokens, topic, operation): ) except requests.exceptions.RequestException as error: raise self._handle_iid_error(error) - else: - return TopicManagementResponse(resp) + return TopicManagementResponse(resp) def _message_data(self, message, dry_run): data = {'message': _MessagingService.encode_message(message)} @@ -558,10 +554,12 @@ def _handle_iid_error(self, error): code = data.get('error') msg = None if code: - msg = 'Error while calling the IID service: {0}'.format(code) + msg = f'Error while calling the IID service: {code}' else: - msg = 'Unexpected HTTP response with status: {0}; body: {1}'.format( - error.response.status_code, error.response.content.decode()) + msg = ( + f'Unexpected HTTP response with status: {error.response.status_code}; body: ' + f'{error.response.content.decode()}' + ) return _utils.handle_requests_error(error, msg) diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index 5fffbd836..3a77dd05f 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -507,8 +507,8 @@ def _assert_tf_enabled(): raise ImportError('Failed to import the tensorflow library for Python. Make sure ' 'to install the tensorflow module.') if not tf.version.VERSION.startswith('1.') and not tf.version.VERSION.startswith('2.'): - raise ImportError('Expected tensorflow version 1.x or 2.x, but found {0}' - .format(tf.version.VERSION)) + raise ImportError( + f'Expected tensorflow version 1.x or 2.x, but found {tf.version.VERSION}') @staticmethod def _tf_convert_from_saved_model(saved_model_dir): @@ -760,8 +760,8 @@ def _validate_page_size(page_size): # Specifically type() to disallow boolean which is a subtype of int raise TypeError('Page size must be a number or None.') if page_size < 1 or page_size > _MAX_PAGE_SIZE: - raise ValueError('Page size must be a positive integer between ' - '1 and {0}'.format(_MAX_PAGE_SIZE)) + raise ValueError( + f'Page size must be a positive integer between 1 and {_MAX_PAGE_SIZE}') def _validate_page_token(page_token): @@ -786,7 +786,7 @@ def __init__(self, app): 'projectId option, or use service account credentials.') self._project_url = _MLService.PROJECT_URL.format(self._project_id) ml_headers = { - 'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__), + 'X-FIREBASE-CLIENT': f'fire-admin-python/{firebase_admin.__version__}', } self._client = _http_client.JsonHttpClient( credential=app.credential.get_credential(), @@ -883,9 +883,9 @@ def create_model(self, model): def update_model(self, model, update_mask=None): _validate_model(model, update_mask) - path = 'models/{0}'.format(model.model_id) + path = f'models/{model.model_id}' if update_mask is not None: - path = path + '?updateMask={0}'.format(update_mask) + path = path + f'?updateMask={update_mask}' try: return self.handle_operation( self._client.body('patch', url=path, json=model.as_dict(for_upload=True))) @@ -894,7 +894,7 @@ def update_model(self, model, update_mask=None): def set_published(self, model_id, publish): _validate_model_id(model_id) - model_name = 'projects/{0}/models/{1}'.format(self._project_id, model_id) + model_name = f'projects/{self._project_id}/models/{model_id}' model = Model.from_dict({ 'name': model_name, 'state': { @@ -906,7 +906,7 @@ def set_published(self, model_id, publish): def get_model(self, model_id): _validate_model_id(model_id) try: - return self._client.body('get', url='models/{0}'.format(model_id)) + return self._client.body('get', url=f'models/{model_id}') except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) @@ -934,6 +934,6 @@ def list_models(self, list_filter, page_size, page_token): def delete_model(self, model_id): _validate_model_id(model_id) try: - self._client.body('delete', url='models/{0}'.format(model_id)) + self._client.body('delete', url=f'models/{model_id}') except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) diff --git a/firebase_admin/project_management.py b/firebase_admin/project_management.py index 9405c8318..73c100d3a 100644 --- a/firebase_admin/project_management.py +++ b/firebase_admin/project_management.py @@ -118,13 +118,13 @@ def create_ios_app(bundle_id, display_name=None, app=None): def _check_is_string_or_none(obj, field_name): if obj is None or isinstance(obj, str): return obj - raise ValueError('{0} must be a string.'.format(field_name)) + raise ValueError(f'{field_name} must be a string.') def _check_is_nonempty_string(obj, field_name): if isinstance(obj, str) and obj: return obj - raise ValueError('{0} must be a non-empty string.'.format(field_name)) + raise ValueError(f'{field_name} must be a non-empty string.') def _check_is_nonempty_string_or_none(obj, field_name): @@ -135,7 +135,7 @@ def _check_is_nonempty_string_or_none(obj, field_name): def _check_not_none(obj, field_name): if obj is None: - raise ValueError('{0} cannot be None.'.format(field_name)) + raise ValueError(f'{field_name} cannot be None.') return obj @@ -477,7 +477,7 @@ def __init__(self, app): 'set the projectId option, or use service account credentials. Alternatively, set ' 'the GOOGLE_CLOUD_PROJECT environment variable.') self._project_id = project_id - version_header = 'Python/Admin/{0}'.format(firebase_admin.__version__) + version_header = f'Python/Admin/{firebase_admin.__version__}' timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) self._client = _http_client.JsonHttpClient( credential=app.credential.get_credential(), @@ -502,7 +502,7 @@ def get_ios_app_metadata(self, app_id): def _get_app_metadata(self, platform_resource_name, identifier_name, metadata_class, app_id): """Retrieves detailed information about an Android or iOS app.""" _check_is_nonempty_string(app_id, 'app_id') - path = '/v1beta1/projects/-/{0}/{1}'.format(platform_resource_name, app_id) + path = f'/v1beta1/projects/-/{platform_resource_name}/{app_id}' response = self._make_request('get', path) return metadata_class( response[identifier_name], @@ -525,8 +525,7 @@ def set_ios_app_display_name(self, app_id, new_display_name): def _set_display_name(self, app_id, new_display_name, platform_resource_name): """Sets the display name of an Android or iOS app.""" - path = '/v1beta1/projects/-/{0}/{1}?updateMask=displayName'.format( - platform_resource_name, app_id) + path = f'/v1beta1/projects/-/{platform_resource_name}/{app_id}?updateMask=displayName' request_body = {'displayName': new_display_name} self._make_request('patch', path, json=request_body) @@ -542,10 +541,10 @@ def list_ios_apps(self): def _list_apps(self, platform_resource_name, app_class): """Lists all the Android or iOS apps within the Firebase project.""" - path = '/v1beta1/projects/{0}/{1}?pageSize={2}'.format( - self._project_id, - platform_resource_name, - _ProjectManagementService.MAXIMUM_LIST_APPS_PAGE_SIZE) + path = ( + f'/v1beta1/projects/{self._project_id}/{platform_resource_name}?pageSize=' + f'{_ProjectManagementService.MAXIMUM_LIST_APPS_PAGE_SIZE}' + ) response = self._make_request('get', path) apps_list = [] while True: @@ -557,11 +556,11 @@ def _list_apps(self, platform_resource_name, app_class): if not next_page_token: break # Retrieve the next page of apps. - path = '/v1beta1/projects/{0}/{1}?pageToken={2}&pageSize={3}'.format( - self._project_id, - platform_resource_name, - next_page_token, - _ProjectManagementService.MAXIMUM_LIST_APPS_PAGE_SIZE) + path = ( + f'/v1beta1/projects/{self._project_id}/{platform_resource_name}' + f'?pageToken={next_page_token}' + f'&pageSize={_ProjectManagementService.MAXIMUM_LIST_APPS_PAGE_SIZE}' + ) response = self._make_request('get', path) return apps_list @@ -590,7 +589,7 @@ def _create_app( app_class): """Creates an Android or iOS app.""" _check_is_string_or_none(display_name, 'display_name') - path = '/v1beta1/projects/{0}/{1}'.format(self._project_id, platform_resource_name) + path = f'/v1beta1/projects/{self._project_id}/{platform_resource_name}' request_body = {identifier_name: identifier} if display_name: request_body['displayName'] = display_name @@ -606,7 +605,7 @@ def _poll_app_creation(self, operation_name): _ProjectManagementService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt) wait_time_seconds = delay_factor * _ProjectManagementService.POLL_BASE_WAIT_TIME_SECONDS time.sleep(wait_time_seconds) - path = '/v1/{0}'.format(operation_name) + path = f'/v1/{operation_name}' poll_response, http_response = self._body_and_response('get', path) done = poll_response.get('done') if done: @@ -629,20 +628,20 @@ def get_ios_app_config(self, app_id): platform_resource_name=_ProjectManagementService.IOS_APPS_RESOURCE_NAME, app_id=app_id) def _get_app_config(self, platform_resource_name, app_id): - path = '/v1beta1/projects/-/{0}/{1}/config'.format(platform_resource_name, app_id) + path = f'/v1beta1/projects/-/{platform_resource_name}/{app_id}/config' response = self._make_request('get', path) # In Python 2.7, the base64 module works with strings, while in Python 3, it works with # bytes objects. This line works in both versions. return base64.standard_b64decode(response['configFileContents']).decode(encoding='utf-8') def get_sha_certificates(self, app_id): - path = '/v1beta1/projects/-/androidApps/{0}/sha'.format(app_id) + path = f'/v1beta1/projects/-/androidApps/{app_id}/sha' response = self._make_request('get', path) cert_list = response.get('certificates') or [] return [SHACertificate(sha_hash=cert['shaHash'], name=cert['name']) for cert in cert_list] def add_sha_certificate(self, app_id, certificate_to_add): - path = '/v1beta1/projects/-/androidApps/{0}/sha'.format(app_id) + path = f'/v1beta1/projects/-/androidApps/{app_id}/sha' sha_hash = _check_not_none(certificate_to_add, 'certificate_to_add').sha_hash cert_type = certificate_to_add.cert_type request_body = {'shaHash': sha_hash, 'certType': cert_type} @@ -650,7 +649,7 @@ def add_sha_certificate(self, app_id, certificate_to_add): def delete_sha_certificate(self, certificate_to_delete): name = _check_not_none(certificate_to_delete, 'certificate_to_delete').name - path = '/v1beta1/{0}'.format(name) + path = f'/v1beta1/{name}' self._make_request('delete', path) def _make_request(self, method, url, json=None): diff --git a/firebase_admin/remote_config.py b/firebase_admin/remote_config.py index 943141ccf..880804d3d 100644 --- a/firebase_admin/remote_config.py +++ b/firebase_admin/remote_config.py @@ -251,7 +251,7 @@ def __init__(self, app): self._project_id = app.project_id app_credential = app.credential.get_credential() rc_headers = { - 'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__), } + 'X-FIREBASE-CLIENT': f'fire-admin-python/{firebase_admin.__version__}', } timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) self._client = _http_client.JsonHttpClient(credential=app_credential, @@ -268,14 +268,12 @@ async def get_server_template(self): 'get', self._get_url()) except requests.exceptions.RequestException as error: raise self._handle_remote_config_error(error) - else: - template_data['etag'] = headers.get('etag') - return _ServerTemplateData(template_data) + template_data['etag'] = headers.get('etag') + return _ServerTemplateData(template_data) def _get_url(self): """Returns project prefix for url, in the format of /v1/projects/${projectId}""" - return "/v1/projects/{0}/namespaces/firebase-server/serverRemoteConfig".format( - self._project_id) + return f"/v1/projects/{self._project_id}/namespaces/firebase-server/serverRemoteConfig" @classmethod def _handle_remote_config_error(cls, error: Any): diff --git a/firebase_admin/storage.py b/firebase_admin/storage.py index 567a6abad..d2f004be6 100644 --- a/firebase_admin/storage.py +++ b/firebase_admin/storage.py @@ -82,6 +82,6 @@ def bucket(self, name=None): 'name explicitly when calling the storage.bucket() function.') if not bucket_name or not isinstance(bucket_name, str): raise ValueError( - 'Invalid storage bucket name: "{0}". Bucket name must be a non-empty ' - 'string.'.format(bucket_name)) + f'Invalid storage bucket name: "{bucket_name}". Bucket name must be a non-empty ' + 'string.') return self._client.bucket(bucket_name) diff --git a/firebase_admin/tenant_mgt.py b/firebase_admin/tenant_mgt.py index 133e80b45..9e713d988 100644 --- a/firebase_admin/tenant_mgt.py +++ b/firebase_admin/tenant_mgt.py @@ -205,7 +205,7 @@ class Tenant: def __init__(self, data): if not isinstance(data, dict): - raise ValueError('Invalid data argument in Tenant constructor: {0}'.format(data)) + raise ValueError(f'Invalid data argument in Tenant constructor: {data}') if not 'name' in data: raise ValueError('Tenant response missing required keys.') @@ -236,8 +236,8 @@ class _TenantManagementService: def __init__(self, app): credential = app.credential.get_credential() - version_header = 'Python/Admin/{0}'.format(firebase_admin.__version__) - base_url = '{0}/projects/{1}'.format(self.TENANT_MGT_URL, app.project_id) + version_header = f'Python/Admin/{firebase_admin.__version__}' + base_url = f'{self.TENANT_MGT_URL}/projects/{app.project_id}' self.app = app self.client = _http_client.JsonHttpClient( credential=credential, base_url=base_url, headers={'X-Client-Version': version_header}) @@ -248,7 +248,7 @@ def auth_for_tenant(self, tenant_id): """Gets an Auth Client instance scoped to the given tenant ID.""" if not isinstance(tenant_id, str) or not tenant_id: raise ValueError( - 'Invalid tenant ID: {0}. Tenant ID must be a non-empty string.'.format(tenant_id)) + f'Invalid tenant ID: {tenant_id}. Tenant ID must be a non-empty string.') with self.lock: if tenant_id in self.tenant_clients: @@ -262,14 +262,13 @@ def get_tenant(self, tenant_id): """Gets the tenant corresponding to the given ``tenant_id``.""" if not isinstance(tenant_id, str) or not tenant_id: raise ValueError( - 'Invalid tenant ID: {0}. Tenant ID must be a non-empty string.'.format(tenant_id)) + f'Invalid tenant ID: {tenant_id}. Tenant ID must be a non-empty string.') try: - body = self.client.body('get', '/tenants/{0}'.format(tenant_id)) + body = self.client.body('get', f'/tenants/{tenant_id}') except requests.exceptions.RequestException as error: raise _auth_utils.handle_auth_backend_error(error) - else: - return Tenant(body) + return Tenant(body) def create_tenant( self, display_name, allow_password_sign_up=None, enable_email_link_sign_in=None): @@ -287,8 +286,7 @@ def create_tenant( body = self.client.body('post', '/tenants', json=payload) except requests.exceptions.RequestException as error: raise _auth_utils.handle_auth_backend_error(error) - else: - return Tenant(body) + return Tenant(body) def update_tenant( self, tenant_id, display_name=None, allow_password_sign_up=None, @@ -310,24 +308,23 @@ def update_tenant( if not payload: raise ValueError('At least one parameter must be specified for update.') - url = '/tenants/{0}'.format(tenant_id) + url = f'/tenants/{tenant_id}' update_mask = ','.join(_auth_utils.build_update_mask(payload)) - params = 'updateMask={0}'.format(update_mask) + params = f'updateMask={update_mask}' try: body = self.client.body('patch', url, json=payload, params=params) except requests.exceptions.RequestException as error: raise _auth_utils.handle_auth_backend_error(error) - else: - return Tenant(body) + return Tenant(body) def delete_tenant(self, tenant_id): """Deletes the tenant corresponding to the given ``tenant_id``.""" if not isinstance(tenant_id, str) or not tenant_id: raise ValueError( - 'Invalid tenant ID: {0}. Tenant ID must be a non-empty string.'.format(tenant_id)) + f'Invalid tenant ID: {tenant_id}. Tenant ID must be a non-empty string.') try: - self.client.request('delete', '/tenants/{0}'.format(tenant_id)) + self.client.request('delete', f'/tenants/{tenant_id}') except requests.exceptions.RequestException as error: raise _auth_utils.handle_auth_backend_error(error) @@ -341,7 +338,7 @@ def list_tenants(self, page_token=None, max_results=_MAX_LIST_TENANTS_RESULTS): if max_results < 1 or max_results > _MAX_LIST_TENANTS_RESULTS: raise ValueError( 'Max results must be a positive integer less than or equal to ' - '{0}.'.format(_MAX_LIST_TENANTS_RESULTS)) + f'{_MAX_LIST_TENANTS_RESULTS}.') payload = {'pageSize': max_results} if page_token: diff --git a/integration/conftest.py b/integration/conftest.py index 169e02d5b..ebaf9297a 100644 --- a/integration/conftest.py +++ b/integration/conftest.py @@ -36,7 +36,7 @@ def _get_cert_path(request): def integration_conf(request): cert_path = _get_cert_path(request) - with open(cert_path) as cert: + with open(cert_path, encoding='utf-8') as cert: project_id = json.load(cert).get('project_id') if not project_id: raise ValueError('Failed to determine project ID from service account certificate.') @@ -57,8 +57,8 @@ def default_app(request): """ cred, project_id = integration_conf(request) ops = { - 'databaseURL' : 'https://{0}.firebaseio.com'.format(project_id), - 'storageBucket' : '{0}.appspot.com'.format(project_id) + 'databaseURL' : f'https://{project_id}.firebaseio.com', + 'storageBucket' : f'{project_id}.appspot.com' } return firebase_admin.initialize_app(cred, ops) @@ -68,5 +68,5 @@ def api_key(request): if not path: raise ValueError('API key file not specified. Make sure to specify the "--apikey" ' 'command-line option.') - with open(path) as keyfile: + with open(path, encoding='utf-8') as keyfile: return keyfile.read().strip() diff --git a/integration/test_auth.py b/integration/test_auth.py index e1d01a254..7f4725dfe 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -30,6 +30,7 @@ import firebase_admin from firebase_admin import auth from firebase_admin import credentials +from firebase_admin._http_client import DEFAULT_TIMEOUT_SECONDS as timeout _verify_token_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyCustomToken' @@ -67,14 +68,14 @@ def _sign_in(custom_token, api_key): body = {'token' : custom_token.decode(), 'returnSecureToken' : True} params = {'key' : api_key} - resp = requests.request('post', _verify_token_url, params=params, json=body) + resp = requests.request('post', _verify_token_url, params=params, json=body, timeout=timeout) resp.raise_for_status() return resp.json().get('idToken') def _sign_in_with_password(email, password, api_key): body = {'email': email, 'password': password, 'returnSecureToken': True} params = {'key' : api_key} - resp = requests.request('post', _verify_password_url, params=params, json=body) + resp = requests.request('post', _verify_password_url, params=params, json=body, timeout=timeout) resp.raise_for_status() return resp.json().get('idToken') @@ -84,7 +85,7 @@ def _random_string(length=10): def _random_id(): random_id = str(uuid.uuid4()).lower().replace('-', '') - email = 'test{0}@example.{1}.com'.format(random_id[:12], random_id[12:]) + email = f'test{random_id[:12]}@example.{random_id[12:]}.com' return random_id, email def _random_phone(): @@ -93,21 +94,21 @@ def _random_phone(): def _reset_password(oob_code, new_password, api_key): body = {'oobCode': oob_code, 'newPassword': new_password} params = {'key' : api_key} - resp = requests.request('post', _password_reset_url, params=params, json=body) + resp = requests.request('post', _password_reset_url, params=params, json=body, timeout=timeout) resp.raise_for_status() return resp.json().get('email') def _verify_email(oob_code, api_key): body = {'oobCode': oob_code} params = {'key' : api_key} - resp = requests.request('post', _verify_email_url, params=params, json=body) + resp = requests.request('post', _verify_email_url, params=params, json=body, timeout=timeout) resp.raise_for_status() return resp.json().get('email') def _sign_in_with_email_link(email, oob_code, api_key): body = {'oobCode': oob_code, 'email': email} params = {'key' : api_key} - resp = requests.request('post', _email_sign_in_url, params=params, json=body) + resp = requests.request('post', _email_sign_in_url, params=params, json=body, timeout=timeout) resp.raise_for_status() return resp.json().get('idToken') @@ -870,7 +871,7 @@ def test_delete_saml_provider_config(): def _create_oidc_provider_config(): - provider_id = 'oidc.{0}'.format(_random_string()) + provider_id = f'oidc.{_random_string()}' return auth.create_oidc_provider_config( provider_id=provider_id, client_id='OIDC_CLIENT_ID', @@ -882,7 +883,7 @@ def _create_oidc_provider_config(): def _create_saml_provider_config(): - provider_id = 'saml.{0}'.format(_random_string()) + provider_id = f'saml.{_random_string()}' return auth.create_saml_provider_config( provider_id=provider_id, idp_entity_id='IDP_ENTITY_ID', diff --git a/integration/test_db.py b/integration/test_db.py index 0170743dd..1ceb0b992 100644 --- a/integration/test_db.py +++ b/integration/test_db.py @@ -39,7 +39,7 @@ def integration_conf(request): def app(request): cred, project_id = integration_conf(request) ops = { - 'databaseURL' : 'https://{0}.firebaseio.com'.format(project_id), + 'databaseURL' : f'https://{project_id}.firebaseio.com', } return firebase_admin.initialize_app(cred, ops, name='integration-db') @@ -53,7 +53,7 @@ def default_app(): @pytest.fixture(scope='module') def update_rules(app): - with open(testutils.resource_filename('dinosaurs_index.json')) as rules_file: + with open(testutils.resource_filename('dinosaurs_index.json'), encoding='utf-8') as rules_file: new_rules = json.load(rules_file) client = db.reference('', app)._client rules = client.body('get', '/.settings/rules.json', params='format=strict') @@ -64,7 +64,7 @@ def update_rules(app): @pytest.fixture(scope='module') def testdata(): - with open(testutils.resource_filename('dinosaurs.json')) as dino_file: + with open(testutils.resource_filename('dinosaurs.json'), encoding='utf-8') as dino_file: return json.load(dino_file) @pytest.fixture(scope='module') @@ -195,8 +195,8 @@ def test_update_nested_children(self, testref): edward = python.child('users').push({'name' : 'Edward Cope', 'since' : 1800}) jack = python.child('users').push({'name' : 'Jack Horner', 'since' : 1940}) delta = { - '{0}/since'.format(edward.key) : 1840, - '{0}/since'.format(jack.key) : 1946 + f'{edward.key}/since' : 1840, + f'{jack.key}/since' : 1946 } python.child('users').update(delta) assert edward.get() == {'name' : 'Edward Cope', 'since' : 1840} @@ -363,7 +363,7 @@ def override_app(request, update_rules): del update_rules cred, project_id = integration_conf(request) ops = { - 'databaseURL' : 'https://{0}.firebaseio.com'.format(project_id), + 'databaseURL' : f'https://{project_id}.firebaseio.com', 'databaseAuthVariableOverride' : {'uid' : 'user1'} } app = firebase_admin.initialize_app(cred, ops, 'db-override') @@ -375,7 +375,7 @@ def none_override_app(request, update_rules): del update_rules cred, project_id = integration_conf(request) ops = { - 'databaseURL' : 'https://{0}.firebaseio.com'.format(project_id), + 'databaseURL' : f'https://{project_id}.firebaseio.com', 'databaseAuthVariableOverride' : None } app = firebase_admin.initialize_app(cred, ops, 'db-none-override') diff --git a/integration/test_firestore.py b/integration/test_firestore.py index fd39d9b8a..96cdd3fb1 100644 --- a/integration/test_firestore.py +++ b/integration/test_firestore.py @@ -18,16 +18,16 @@ from firebase_admin import firestore _CITY = { - 'name': u'Mountain View', - 'country': u'USA', + 'name': 'Mountain View', + 'country': 'USA', 'population': 77846, 'capital': False } _MOVIE = { - 'Name': u'Interstellar', + 'Name': 'Interstellar', 'Year': 2014, - 'Runtime': u'2h 49m', + 'Runtime': '2h 49m', 'Academy Award Winner': True } @@ -35,8 +35,8 @@ def test_firestore(): client = firestore.client() expected = { - 'name': u'Mountain View', - 'country': u'USA', + 'name': 'Mountain View', + 'country': 'USA', 'population': 77846, 'capital': False } @@ -93,7 +93,7 @@ def test_firestore_multi_db(): def test_server_timestamp(): client = firestore.client() expected = { - 'name': u'Mountain View', + 'name': 'Mountain View', 'timestamp': firestore.SERVER_TIMESTAMP # pylint: disable=no-member } doc = client.collection('cities').document() diff --git a/integration/test_firestore_async.py b/integration/test_firestore_async.py index 584ef590a..e899f25b2 100644 --- a/integration/test_firestore_async.py +++ b/integration/test_firestore_async.py @@ -20,16 +20,16 @@ from firebase_admin import firestore_async _CITY = { - 'name': u'Mountain View', - 'country': u'USA', + 'name': 'Mountain View', + 'country': 'USA', 'population': 77846, 'capital': False } _MOVIE = { - 'Name': u'Interstellar', + 'Name': 'Interstellar', 'Year': 2014, - 'Runtime': u'2h 49m', + 'Runtime': '2h 49m', 'Academy Award Winner': True } @@ -102,7 +102,7 @@ async def test_firestore_async_multi_db(): async def test_server_timestamp(): client = firestore_async.client() expected = { - 'name': u'Mountain View', + 'name': 'Mountain View', 'timestamp': firestore_async.SERVER_TIMESTAMP # pylint: disable=no-member } doc = client.collection('cities').document() diff --git a/integration/test_messaging.py b/integration/test_messaging.py index 7ab707c82..e72086741 100644 --- a/integration/test_messaging.py +++ b/integration/test_messaging.py @@ -121,7 +121,7 @@ def test_send_each(): def test_send_each_500(): messages = [] for msg_number in range(500): - topic = 'foo-bar-{0}'.format(msg_number % 10) + topic = f'foo-bar-{msg_number % 10}' messages.append(messaging.Message(topic=topic)) batch_response = messaging.send_each(messages, dry_run=True) @@ -193,7 +193,7 @@ async def test_send_each_async(): async def test_send_each_async_500(): messages = [] for msg_number in range(500): - topic = 'foo-bar-{0}'.format(msg_number % 10) + topic = f'foo-bar-{msg_number % 10}' messages.append(messaging.Message(topic=topic)) batch_response = await messaging.send_each_async(messages, dry_run=True) diff --git a/integration/test_ml.py b/integration/test_ml.py index 6deb22a69..ea5b10be9 100644 --- a/integration/test_ml.py +++ b/integration/test_ml.py @@ -37,7 +37,7 @@ def _random_identifier(prefix): #pylint: disable=unused-variable suffix = ''.join([random.choice(string.ascii_letters + string.digits) for n in range(8)]) - return '{0}_{1}'.format(prefix, suffix) + return f'{prefix}_{suffix}' NAME_ONLY_ARGS = { @@ -170,7 +170,7 @@ def test_create_already_existing_fails(firebase_model): ml.create_model(model=firebase_model) check_operation_error( excinfo, - 'Model \'{0}\' already exists'.format(firebase_model.display_name)) + f'Model \'{firebase_model.display_name}\' already exists') @pytest.mark.parametrize('firebase_model', [INVALID_FULL_MODEL_ARGS], indirect=True) @@ -219,7 +219,7 @@ def test_update_non_existing_model(firebase_model): ml.update_model(firebase_model) check_operation_error( excinfo, - 'Model \'{0}\' was not found'.format(firebase_model.as_dict().get('name'))) + f'Model \'{firebase_model.as_dict().get("name")}\' was not found') @pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) @@ -252,18 +252,17 @@ def test_publish_unpublish_non_existing_model(firebase_model): ml.publish_model(firebase_model.model_id) check_operation_error( excinfo, - 'Model \'{0}\' was not found'.format(firebase_model.as_dict().get('name'))) + f'Model \'{firebase_model.as_dict().get("name")}\' was not found') with pytest.raises(exceptions.NotFoundError) as excinfo: ml.unpublish_model(firebase_model.model_id) check_operation_error( excinfo, - 'Model \'{0}\' was not found'.format(firebase_model.as_dict().get('name'))) + f'Model \'{firebase_model.as_dict().get("name")}\' was not found') def test_list_models(model_list): - filter_str = 'displayName={0} OR tags:{1}'.format( - model_list[0].display_name, model_list[1].tags[0]) + filter_str = f'displayName={model_list[0].display_name} OR tags:{model_list[1].tags[0]}' all_models = ml.list_models(list_filter=filter_str) all_model_ids = [mdl.model_id for mdl in all_models.iterate_all()] diff --git a/integration/test_project_management.py b/integration/test_project_management.py index b0b7fa52a..ba2c5ec16 100644 --- a/integration/test_project_management.py +++ b/integration/test_project_management.py @@ -74,14 +74,13 @@ def test_create_android_app_already_exists(android_app): def test_android_set_display_name_and_get_metadata(android_app, project_id): app_id = android_app.app_id android_app = project_management.android_app(app_id) - new_display_name = '{0} helloworld {1}'.format( - TEST_APP_DISPLAY_NAME_PREFIX, random.randint(0, 10000)) + new_display_name = f'{TEST_APP_DISPLAY_NAME_PREFIX} helloworld {random.randint(0, 10000)}' android_app.set_display_name(new_display_name) metadata = project_management.android_app(app_id).get_metadata() android_app.set_display_name(TEST_APP_DISPLAY_NAME_PREFIX) # Revert the display name. - assert metadata._name == 'projects/{0}/androidApps/{1}'.format(project_id, app_id) + assert metadata._name == f'projects/{project_id}/androidApps/{app_id}' assert metadata.app_id == app_id assert metadata.project_id == project_id assert metadata.display_name == new_display_name @@ -149,15 +148,13 @@ def test_create_ios_app_already_exists(ios_app): def test_ios_set_display_name_and_get_metadata(ios_app, project_id): app_id = ios_app.app_id ios_app = project_management.ios_app(app_id) - new_display_name = '{0} helloworld {1}'.format( - TEST_APP_DISPLAY_NAME_PREFIX, random.randint(0, 10000)) + new_display_name = f'{TEST_APP_DISPLAY_NAME_PREFIX} helloworld {random.randint(0, 10000)}' ios_app.set_display_name(new_display_name) metadata = project_management.ios_app(app_id).get_metadata() ios_app.set_display_name(TEST_APP_DISPLAY_NAME_PREFIX) # Revert the display name. - assert metadata._name == 'projects/{0}/iosApps/{1}'.format(project_id, app_id) - assert metadata.app_id == app_id + assert metadata._name == f'projects/{project_id}/iosApps/{app_id}' assert metadata.project_id == project_id assert metadata.display_name == new_display_name assert metadata.bundle_id == TEST_APP_BUNDLE_ID diff --git a/integration/test_storage.py b/integration/test_storage.py index 4f0faf76c..32e4d86a3 100644 --- a/integration/test_storage.py +++ b/integration/test_storage.py @@ -20,10 +20,10 @@ def test_default_bucket(project_id): bucket = storage.bucket() - _verify_bucket(bucket, '{0}.appspot.com'.format(project_id)) + _verify_bucket(bucket, f'{project_id}.appspot.com') def test_custom_bucket(project_id): - bucket_name = '{0}.appspot.com'.format(project_id) + bucket_name = f'{project_id}.appspot.com' bucket = storage.bucket(bucket_name) _verify_bucket(bucket, bucket_name) @@ -33,7 +33,7 @@ def test_non_existing_bucket(): def _verify_bucket(bucket, expected_name): assert bucket.name == expected_name - file_name = 'data_{0}.txt'.format(int(time.time())) + file_name = f'data_{int(time.time())}.txt' blob = bucket.blob(file_name) blob.upload_from_string('Hello World') diff --git a/integration/test_tenant_mgt.py b/integration/test_tenant_mgt.py index c9eefd96e..f0bad58b2 100644 --- a/integration/test_tenant_mgt.py +++ b/integration/test_tenant_mgt.py @@ -25,6 +25,7 @@ from firebase_admin import auth from firebase_admin import tenant_mgt +from firebase_admin._http_client import DEFAULT_TIMEOUT_SECONDS as timeout from integration import test_auth @@ -359,7 +360,7 @@ def test_delete_saml_provider_config(sample_tenant): def _create_oidc_provider_config(client): - provider_id = 'oidc.{0}'.format(_random_string()) + provider_id = f'oidc.{_random_string()}' return client.create_oidc_provider_config( provider_id=provider_id, client_id='OIDC_CLIENT_ID', @@ -369,7 +370,7 @@ def _create_oidc_provider_config(client): def _create_saml_provider_config(client): - provider_id = 'saml.{0}'.format(_random_string()) + provider_id = f'saml.{_random_string()}' return client.create_saml_provider_config( provider_id=provider_id, idp_entity_id='IDP_ENTITY_ID', @@ -387,7 +388,7 @@ def _random_uid(): def _random_email(): random_id = str(uuid.uuid4()).lower().replace('-', '') - return 'test{0}@example.{1}.com'.format(random_id[:12], random_id[12:]) + return f'test{random_id[:12]}@example.{random_id[12:]}.com' def _random_phone(): @@ -412,6 +413,6 @@ def _sign_in(custom_token, tenant_id, api_key): 'tenantId': tenant_id, } params = {'key' : api_key} - resp = requests.request('post', VERIFY_TOKEN_URL, params=params, json=body) + resp = requests.request('post', VERIFY_TOKEN_URL, params=params, json=body, timeout=timeout) resp.raise_for_status() return resp.json().get('idToken') diff --git a/requirements.txt b/requirements.txt index 76eeb7582..3e67d1dd5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -astroid == 2.5.8 -pylint == 2.7.4 +astroid == 3.3.10 +pylint == 3.3.7 pytest >= 8.2.2 pytest-cov >= 2.4.0 pytest-localserver >= 0.4.1 diff --git a/snippets/auth/get_service_account_tokens.py b/snippets/auth/get_service_account_tokens.py index 9f60590fe..7ad67a093 100644 --- a/snippets/auth/get_service_account_tokens.py +++ b/snippets/auth/get_service_account_tokens.py @@ -26,4 +26,4 @@ # After expiration_time, you must generate a new access token # [END get_service_account_tokens] -print('The access token {} expires at {}'.format(access_token, expiration_time)) +print(f'The access token {access_token} expires at {expiration_time}') diff --git a/snippets/auth/index.py b/snippets/auth/index.py index ed324e486..6a509b8f5 100644 --- a/snippets/auth/index.py +++ b/snippets/auth/index.py @@ -169,7 +169,7 @@ def revoke_refresh_token_uid(): user = auth.get_user(uid) # Convert to seconds as the auth_time in the token claims is in seconds. revocation_second = user.tokens_valid_after_timestamp / 1000 - print('Tokens revoked at: {0}'.format(revocation_second)) + print(f'Tokens revoked at: {revocation_second}') # [END revoke_tokens] # [START save_revocation_in_db] metadata_ref = firebase_admin.db.reference("metadata/" + uid) @@ -183,7 +183,7 @@ def get_user(uid): from firebase_admin import auth user = auth.get_user(uid) - print('Successfully fetched user data: {0}'.format(user.uid)) + print(f'Successfully fetched user data: {user.uid}') # [END get_user] def get_user_by_email(): @@ -192,7 +192,7 @@ def get_user_by_email(): from firebase_admin import auth user = auth.get_user_by_email(email) - print('Successfully fetched user data: {0}'.format(user.uid)) + print(f'Successfully fetched user data: {user.uid}') # [END get_user_by_email] def bulk_get_users(): @@ -221,7 +221,7 @@ def get_user_by_phone_number(): from firebase_admin import auth user = auth.get_user_by_phone_number(phone) - print('Successfully fetched user data: {0}'.format(user.uid)) + print(f'Successfully fetched user data: {user.uid}') # [END get_user_by_phone] def create_user(): @@ -234,7 +234,7 @@ def create_user(): display_name='John Doe', photo_url='http://www.example.com/12345678/photo.png', disabled=False) - print('Sucessfully created new user: {0}'.format(user.uid)) + print(f'Sucessfully created new user: {user.uid}') # [END create_user] return user.uid @@ -242,7 +242,7 @@ def create_user_with_id(): # [START create_user_with_id] user = auth.create_user( uid='some-uid', email='user@example.com', phone_number='+15555550100') - print('Sucessfully created new user: {0}'.format(user.uid)) + print(f'Sucessfully created new user: {user.uid}') # [END create_user_with_id] def update_user(uid): @@ -256,7 +256,7 @@ def update_user(uid): display_name='John Doe', photo_url='http://www.example.com/12345678/photo.png', disabled=True) - print('Sucessfully updated user: {0}'.format(user.uid)) + print(f'Sucessfully updated user: {user.uid}') # [END update_user] def delete_user(uid): @@ -271,10 +271,10 @@ def bulk_delete_users(): result = auth.delete_users(["uid1", "uid2", "uid3"]) - print('Successfully deleted {0} users'.format(result.success_count)) - print('Failed to delete {0} users'.format(result.failure_count)) + print(f'Successfully deleted {result.success_count} users') + print(f'Failed to delete {result.failure_count} users') for err in result.errors: - print('error #{0}, reason: {1}'.format(result.index, result.reason)) + print(f'error #{result.index}, reason: {result.reason}') # [END bulk_delete_users] def set_custom_user_claims(uid): @@ -475,10 +475,11 @@ def import_users(): hash_alg = auth.UserImportHash.hmac_sha256(key=b'secret_key') try: result = auth.import_users(users, hash_alg=hash_alg) - print('Successfully imported {0} users. Failed to import {1} users.'.format( - result.success_count, result.failure_count)) + print( + f'Successfully imported {result.success_count} users. Failed to import ' + f'{result.failure_count} users.') for err in result.errors: - print('Failed to import {0} due to {1}'.format(users[err.index].uid, err.reason)) + print(f'Failed to import {users[err.index].uid} due to {err.reason}') except exceptions.FirebaseError: # Some unrecoverable error occurred that prevented the operation from running. pass @@ -1012,7 +1013,7 @@ def revoke_refresh_tokens_tenant(tenant_client, uid): user = tenant_client.get_user(uid) # Convert to seconds as the auth_time in the token claims is in seconds. revocation_second = user.tokens_valid_after_timestamp / 1000 - print('Tokens revoked at: {0}'.format(revocation_second)) + print(f'Tokens revoked at: {revocation_second}') # [END revoke_tokens_tenant] def verify_id_token_and_check_revoked_tenant(tenant_client, id_token): diff --git a/snippets/database/index.py b/snippets/database/index.py index adfa13476..99bb4981e 100644 --- a/snippets/database/index.py +++ b/snippets/database/index.py @@ -235,7 +235,7 @@ def order_by_child(): ref = db.reference('dinosaurs') snapshot = ref.order_by_child('height').get() for key, val in snapshot.items(): - print('{0} was {1} meters tall'.format(key, val)) + print(f'{key} was {val} meters tall') # [END order_by_child] def order_by_nested_child(): @@ -243,7 +243,7 @@ def order_by_nested_child(): ref = db.reference('dinosaurs') snapshot = ref.order_by_child('dimensions/height').get() for key, val in snapshot.items(): - print('{0} was {1} meters tall'.format(key, val)) + print(f'{key} was {val} meters tall') # [END order_by_nested_child] def order_by_key(): @@ -258,7 +258,7 @@ def order_by_value(): ref = db.reference('scores') snapshot = ref.order_by_value().get() for key, val in snapshot.items(): - print('The {0} dinosaur\'s score is {1}'.format(key, val)) + print(f'The {key} dinosaur\'s score is {val}') # [END order_by_value] def limit_query(): @@ -280,7 +280,7 @@ def limit_query(): scores_ref = db.reference('scores') snapshot = scores_ref.order_by_value().limit_to_last(3).get() for key, val in snapshot.items(): - print('The {0} dinosaur\'s score is {1}'.format(key, val)) + print(f'The {key} dinosaur\'s score is {val}') # [END limit_query_3] def range_query(): @@ -300,7 +300,7 @@ def range_query(): # [START range_query_3] ref = db.reference('dinosaurs') - snapshot = ref.order_by_key().start_at('b').end_at(u'b\uf8ff').get() + snapshot = ref.order_by_key().start_at('b').end_at('b\uf8ff').get() for key in snapshot: print(key) # [END range_query_3] @@ -322,7 +322,7 @@ def complex_query(): # Data is ordered by increasing height, so we want the first entry. # Second entry is stegosarus. for key in snapshot: - print('The dinosaur just shorter than the stegosaurus is {0}'.format(key)) + print(f'The dinosaur just shorter than the stegosaurus is {key}') return else: print('The stegosaurus is the shortest dino') diff --git a/snippets/messaging/cloud_messaging.py b/snippets/messaging/cloud_messaging.py index 18a992dcc..3efd223ea 100644 --- a/snippets/messaging/cloud_messaging.py +++ b/snippets/messaging/cloud_messaging.py @@ -241,7 +241,7 @@ def send_each(): response = messaging.send_each(messages) # See the BatchResponse reference documentation # for the contents of response. - print('{0} messages were sent successfully'.format(response.success_count)) + print(f'{response.success_count} messages were sent successfully') # [END send_each] @@ -262,7 +262,7 @@ def send_each_for_multicast(): response = messaging.send_each_for_multicast(message) # See the BatchResponse reference documentation # for the contents of response. - print('{0} messages were sent successfully'.format(response.success_count)) + print(f'{response.success_count} messages were sent successfully') # [END send_each_for_multicast] @@ -287,5 +287,5 @@ def send_each_for_multicast_and_handle_errors(): if not resp.success: # The order of responses corresponds to the order of the registration tokens. failed_tokens.append(registration_tokens[idx]) - print('List of tokens that caused failures: {0}'.format(failed_tokens)) + print(f'List of tokens that caused failures: {failed_tokens}') # [END send_each_for_multicast_error] diff --git a/tests/test_app.py b/tests/test_app.py index 5b203661f..0ff0854b4 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -215,11 +215,11 @@ def revert_config_env(config_old): class TestFirebaseApp: """Test cases for App initialization and life cycle.""" - invalid_credentials = ['', 'foo', 0, 1, dict(), list(), tuple(), True, False] - invalid_options = ['', 0, 1, list(), tuple(), True, False] - invalid_names = [None, '', 0, 1, dict(), list(), tuple(), True, False] + invalid_credentials = ['', 'foo', 0, 1, {}, [], tuple(), True, False] + invalid_options = ['', 0, 1, [], tuple(), True, False] + invalid_names = [None, '', 0, 1, {}, [], tuple(), True, False] invalid_apps = [ - None, '', 0, 1, dict(), list(), tuple(), True, False, + None, '', 0, 1, {}, [], tuple(), True, False, firebase_admin.App('uninitialized', CREDENTIAL, {}) ] @@ -308,11 +308,11 @@ def test_project_id_from_environment(self): variables = ['GOOGLE_CLOUD_PROJECT', 'GCLOUD_PROJECT'] for idx, var in enumerate(variables): old_project_id = os.environ.get(var) - new_project_id = 'env-project-{0}'.format(idx) + new_project_id = f'env-project-{idx}' os.environ[var] = new_project_id try: app = firebase_admin.initialize_app( - testutils.MockCredential(), name='myApp{0}'.format(var)) + testutils.MockCredential(), name=f'myApp{var}') assert app.project_id == new_project_id finally: if old_project_id: @@ -388,7 +388,7 @@ def test_app_services(self, init_app): with pytest.raises(ValueError): _utils.get_app_service(init_app, 'test.service', AppService) - @pytest.mark.parametrize('arg', [0, 1, True, False, 'str', list(), dict(), tuple()]) + @pytest.mark.parametrize('arg', [0, 1, True, False, 'str', [], {}, tuple()]) def test_app_services_invalid_arg(self, arg): with pytest.raises(ValueError): _utils.get_app_service(arg, 'test.service', AppService) diff --git a/tests/test_app_check.py b/tests/test_app_check.py index 168d0a972..e55ae39de 100644 --- a/tests/test_app_check.py +++ b/tests/test_app_check.py @@ -22,7 +22,7 @@ from firebase_admin import app_check from tests import testutils -NON_STRING_ARGS = [list(), tuple(), dict(), True, False, 1, 0] +NON_STRING_ARGS = [[], tuple(), {}, True, False, 1, 0] APP_ID = "1234567890" PROJECT_ID = "1334" @@ -71,7 +71,7 @@ def evaluate(): def test_verify_token_with_non_string_raises_error(self, token): with pytest.raises(ValueError) as excinfo: app_check.verify_token(token) - expected = 'app check token "{0}" must be a string.'.format(token) + expected = f'app check token "{token}" must be a string.' assert str(excinfo.value) == expected def test_has_valid_token_headers(self): diff --git a/tests/test_auth_providers.py b/tests/test_auth_providers.py index 304e0fd78..106e1cae3 100644 --- a/tests/test_auth_providers.py +++ b/tests/test_auth_providers.py @@ -27,8 +27,7 @@ ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v2' EMULATOR_HOST_ENV_VAR = 'FIREBASE_AUTH_EMULATOR_HOST' AUTH_EMULATOR_HOST = 'localhost:9099' -EMULATED_ID_TOOLKIT_URL = 'http://{}/identitytoolkit.googleapis.com/v2'.format( - AUTH_EMULATOR_HOST) +EMULATED_ID_TOOLKIT_URL = f'http://{AUTH_EMULATOR_HOST}/identitytoolkit.googleapis.com/v2' URL_PROJECT_SUFFIX = '/projects/mock-project-id' USER_MGT_URLS = { 'ID_TOOLKIT': ID_TOOLKIT_URL, @@ -45,7 +44,7 @@ } }""" -INVALID_PROVIDER_IDS = [None, True, False, 1, 0, list(), tuple(), dict(), ''] +INVALID_PROVIDER_IDS = [None, True, False, 1, 0, [], tuple(), {}, ''] @pytest.fixture(scope='module', params=[{'emulated': False}, {'emulated': True}]) @@ -282,12 +281,12 @@ def test_delete(self, user_mgt_app): _assert_request(recorder[0], 'DELETE', f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs/oidc.provider') - @pytest.mark.parametrize('arg', [None, 'foo', list(), dict(), 0, -1, 101, False]) + @pytest.mark.parametrize('arg', [None, 'foo', [], {}, 0, -1, 101, False]) def test_invalid_max_results(self, user_mgt_app, arg): with pytest.raises(ValueError): auth.list_oidc_provider_configs(max_results=arg, app=user_mgt_app) - @pytest.mark.parametrize('arg', ['', list(), dict(), 0, -1, 101, False]) + @pytest.mark.parametrize('arg', ['', [], {}, 0, -1, 101, False]) def test_invalid_page_token(self, user_mgt_app, arg): with pytest.raises(ValueError): auth.list_oidc_provider_configs(page_token=arg, app=user_mgt_app) @@ -346,7 +345,7 @@ def test_paged_iteration(self, user_mgt_app): for index in range(2): provider_config = next(iterator) - assert provider_config.provider_id == 'oidc.provider{0}'.format(index) + assert provider_config.provider_id == f'oidc.provider{index}' assert len(recorder) == 1 _assert_request(recorder[0], 'GET', f'{USER_MGT_URLS["PREFIX"]}/oauthIdpConfigs?pageSize=100') @@ -403,7 +402,7 @@ def _assert_page(self, page, count=2, start=0, next_page_token=''): index = start assert len(page.provider_configs) == count for provider_config in page.provider_configs: - self._assert_provider_config(provider_config, want_id='oidc.provider{0}'.format(index)) + self._assert_provider_config(provider_config, want_id=f'oidc.provider{index}') index += 1 if next_page_token: @@ -621,12 +620,12 @@ def test_config_not_found(self, user_mgt_app): assert excinfo.value.http_response is not None assert excinfo.value.cause is not None - @pytest.mark.parametrize('arg', [None, 'foo', list(), dict(), 0, -1, 101, False]) + @pytest.mark.parametrize('arg', [None, 'foo', [], {}, 0, -1, 101, False]) def test_invalid_max_results(self, user_mgt_app, arg): with pytest.raises(ValueError): auth.list_saml_provider_configs(max_results=arg, app=user_mgt_app) - @pytest.mark.parametrize('arg', ['', list(), dict(), 0, -1, 101, False]) + @pytest.mark.parametrize('arg', ['', [], {}, 0, -1, 101, False]) def test_invalid_page_token(self, user_mgt_app, arg): with pytest.raises(ValueError): auth.list_saml_provider_configs(page_token=arg, app=user_mgt_app) @@ -686,7 +685,7 @@ def test_paged_iteration(self, user_mgt_app): for index in range(2): provider_config = next(iterator) - assert provider_config.provider_id == 'saml.provider{0}'.format(index) + assert provider_config.provider_id == f'saml.provider{index}' assert len(recorder) == 1 _assert_request( recorder[0], 'GET', f'{USER_MGT_URLS["PREFIX"]}/inboundSamlConfigs?pageSize=100') @@ -735,7 +734,7 @@ def _assert_page(self, page, count=2, start=0, next_page_token=''): index = start assert len(page.provider_configs) == count for provider_config in page.provider_configs: - self._assert_provider_config(provider_config, want_id='saml.provider{0}'.format(index)) + self._assert_provider_config(provider_config, want_id=f'saml.provider{index}') index += 1 if next_page_token: diff --git a/tests/test_credentials.py b/tests/test_credentials.py index cceb6b6f9..1e1db6460 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -64,7 +64,7 @@ def test_init_from_invalid_certificate(self, file_name, error): with pytest.raises(error): credentials.Certificate(testutils.resource_filename(file_name)) - @pytest.mark.parametrize('arg', [None, 0, 1, True, False, list(), tuple(), dict()]) + @pytest.mark.parametrize('arg', [None, 0, 1, True, False, [], tuple(), {}]) def test_invalid_args(self, arg): with pytest.raises(ValueError): credentials.Certificate(arg) @@ -156,7 +156,7 @@ def test_init_from_invalid_file(self): credentials.RefreshToken( testutils.resource_filename('service_account.json')) - @pytest.mark.parametrize('arg', [None, 0, 1, True, False, list(), tuple(), dict()]) + @pytest.mark.parametrize('arg', [None, 0, 1, True, False, [], tuple(), {}]) def test_invalid_args(self, arg): with pytest.raises(ValueError): credentials.RefreshToken(arg) diff --git a/tests/test_db.py b/tests/test_db.py index 93f4672f1..abba3baa8 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -87,7 +87,7 @@ class TestReferencePath: } invalid_paths = [ - None, True, False, 0, 1, dict(), list(), tuple(), _Object(), + None, True, False, 0, 1, {}, [], tuple(), _Object(), 'foo#', 'foo.', 'foo$', 'foo[', 'foo]', ] @@ -98,7 +98,7 @@ class TestReferencePath: } invalid_children = [ - None, '', '/foo', '/foo/bar', True, False, 0, 1, dict(), list(), tuple(), + None, '', '/foo', '/foo/bar', True, False, 0, 1, {}, [], tuple(), 'foo#', 'foo.', 'foo$', 'foo[', 'foo]', _Object() ] @@ -248,7 +248,7 @@ def test_get_if_changed(self, data): self._assert_request(recorder[1], 'GET', 'https://test.firebaseio.com/test.json') assert recorder[1].headers['if-none-match'] == MockAdapter.ETAG - @pytest.mark.parametrize('etag', [0, 1, True, False, dict(), list(), tuple()]) + @pytest.mark.parametrize('etag', [0, 1, True, False, {}, [], tuple()]) def test_get_if_changed_invalid_etag(self, etag): ref = db.reference('/test') with pytest.raises(ValueError): @@ -347,7 +347,7 @@ def test_set_if_unchanged_failure(self, data): assert json.loads(recorder[0].body.decode()) == data assert recorder[0].headers['if-match'] == 'invalid-etag' - @pytest.mark.parametrize('etag', [0, 1, True, False, dict(), list(), tuple()]) + @pytest.mark.parametrize('etag', [0, 1, True, False, {}, [], tuple()]) def test_set_if_unchanged_invalid_etag(self, etag): ref = db.reference('/test') with pytest.raises(ValueError): @@ -369,7 +369,7 @@ def test_set_if_unchanged_non_json_value(self, value): ref.set_if_unchanged(MockAdapter.ETAG, value) @pytest.mark.parametrize('update', [ - None, {}, {None:'foo'}, '', 'foo', 0, 1, list(), tuple(), _Object() + None, {}, {None:'foo'}, '', 'foo', 0, 1, [], tuple(), _Object() ]) def test_set_invalid_update(self, update): ref = db.reference('/test') @@ -466,7 +466,7 @@ def test_transaction_abort(self): assert excinfo.value.http_response is None assert len(recorder) == 1 + 25 - @pytest.mark.parametrize('func', [None, 0, 1, True, False, 'foo', dict(), list(), tuple()]) + @pytest.mark.parametrize('func', [None, 0, 1, True, False, 'foo', {}, [], tuple()]) def test_transaction_invalid_function(self, func): ref = db.reference('/test') with pytest.raises(ValueError): @@ -672,7 +672,7 @@ def _assert_request(self, request, expected_method, expected_url): def test_get_value(self): ref = db.reference('/test') recorder = self.instrument(ref, json.dumps('data')) - query_str = 'auth_variable_override={0}'.format(self.encoded_override) + query_str = f'auth_variable_override={self.encoded_override}' assert ref.get() == 'data' assert len(recorder) == 1 self._assert_request( @@ -683,7 +683,7 @@ def test_set_value(self): recorder = self.instrument(ref, '') data = {'foo' : 'bar'} ref.set(data) - query_str = 'print=silent&auth_variable_override={0}'.format(self.encoded_override) + query_str = f'print=silent&auth_variable_override={self.encoded_override}' assert len(recorder) == 1 self._assert_request( recorder[0], 'PUT', 'https://test.firebaseio.com/test.json?' + query_str) @@ -693,7 +693,7 @@ def test_order_by_query(self): ref = db.reference('/test') recorder = self.instrument(ref, json.dumps('data')) query = ref.order_by_child('foo') - query_str = 'orderBy=%22foo%22&auth_variable_override={0}'.format(self.encoded_override) + query_str = f'orderBy=%22foo%22&auth_variable_override={self.encoded_override}' assert query.get() == 'data' assert len(recorder) == 1 self._assert_request( @@ -703,8 +703,9 @@ def test_range_query(self): ref = db.reference('/test') recorder = self.instrument(ref, json.dumps('data')) query = ref.order_by_child('foo').start_at(1).end_at(10) - query_str = ('endAt=10&orderBy=%22foo%22&startAt=1&' - 'auth_variable_override={0}'.format(self.encoded_override)) + query_str = ( + f'endAt=10&orderBy=%22foo%22&startAt=1&auth_variable_override={self.encoded_override}' + ) assert query.get() == 'data' assert len(recorder) == 1 self._assert_request( @@ -794,7 +795,7 @@ def test_valid_db_url(self, url): @pytest.mark.parametrize('url', [ None, '', 'foo', 'http://test.firebaseio.com', 'http://test.firebasedatabase.app', - True, False, 1, 0, dict(), list(), tuple(), _Object() + True, False, 1, 0, {}, [], tuple(), _Object() ]) def test_invalid_db_url(self, url): firebase_admin.initialize_app(testutils.MockCredential(), {'databaseURL' : url}) @@ -838,7 +839,7 @@ def test_valid_auth_override(self, override): assert ref._client.params['auth_variable_override'] == encoded @pytest.mark.parametrize('override', [ - '', 'foo', 0, 1, True, False, list(), tuple(), _Object()]) + '', 'foo', 0, 1, True, False, [], tuple(), _Object()]) def test_invalid_auth_override(self, override): firebase_admin.initialize_app(testutils.MockCredential(), { 'databaseURL' : 'https://test.firebaseio.com', @@ -885,8 +886,10 @@ def test_app_delete(self): assert other_ref._client.session is None def test_user_agent_format(self): - expected = 'Firebase/HTTP/{0}/{1}.{2}/AdminPython'.format( - firebase_admin.__version__, sys.version_info.major, sys.version_info.minor) + expected = ( + f'Firebase/HTTP/{firebase_admin.__version__}/{sys.version_info.major}.' + f'{sys.version_info.minor}/AdminPython' + ) assert db._USER_AGENT == expected def _check_timeout(self, ref, timeout): @@ -925,7 +928,7 @@ class TestQuery: ref = db.Reference(path='foo') @pytest.mark.parametrize('path', [ - '', None, '/', '/foo', 0, 1, True, False, dict(), list(), tuple(), _Object(), + '', None, '/', '/foo', 0, 1, True, False, {}, [], tuple(), _Object(), '$foo', '.foo', '#foo', '[foo', 'foo]', '$key', '$value', '$priority' ]) def test_invalid_path(self, path): @@ -935,13 +938,13 @@ def test_invalid_path(self, path): @pytest.mark.parametrize('path, expected', valid_paths.items()) def test_order_by_valid_path(self, path, expected): query = self.ref.order_by_child(path) - assert query._querystr == 'orderBy="{0}"'.format(expected) + assert query._querystr == f'orderBy="{expected}"' @pytest.mark.parametrize('path, expected', valid_paths.items()) def test_filter_by_valid_path(self, path, expected): query = self.ref.order_by_child(path) query.equal_to(10) - assert query._querystr == 'equalTo=10&orderBy="{0}"'.format(expected) + assert query._querystr == f'equalTo=10&orderBy="{expected}"' def test_order_by_key(self): query = self.ref.order_by_key() @@ -972,7 +975,7 @@ def test_multiple_limits(self): with pytest.raises(ValueError): query.limit_to_first(1) - @pytest.mark.parametrize('limit', [None, -1, 'foo', 1.2, list(), dict(), tuple(), _Object()]) + @pytest.mark.parametrize('limit', [None, -1, 'foo', 1.2, [], {}, tuple(), _Object()]) def test_invalid_limit(self, limit): query = self.ref.order_by_child('foo') with pytest.raises(ValueError): @@ -985,47 +988,47 @@ def test_start_at_none(self): with pytest.raises(ValueError): query.start_at(None) - @pytest.mark.parametrize('arg', ['', 'foo', True, False, 0, 1, dict()]) + @pytest.mark.parametrize('arg', ['', 'foo', True, False, 0, 1, {}]) def test_valid_start_at(self, arg): query = self.ref.order_by_child('foo').start_at(arg) - assert query._querystr == 'orderBy="foo"&startAt={0}'.format(json.dumps(arg)) + assert query._querystr == f'orderBy="foo"&startAt={json.dumps(arg)}' def test_end_at_none(self): query = self.ref.order_by_child('foo') with pytest.raises(ValueError): query.end_at(None) - @pytest.mark.parametrize('arg', ['', 'foo', True, False, 0, 1, dict()]) + @pytest.mark.parametrize('arg', ['', 'foo', True, False, 0, 1, {}]) def test_valid_end_at(self, arg): query = self.ref.order_by_child('foo').end_at(arg) - assert query._querystr == 'endAt={0}&orderBy="foo"'.format(json.dumps(arg)) + assert query._querystr == f'endAt={json.dumps(arg)}&orderBy="foo"' def test_equal_to_none(self): query = self.ref.order_by_child('foo') with pytest.raises(ValueError): query.equal_to(None) - @pytest.mark.parametrize('arg', ['', 'foo', True, False, 0, 1, dict()]) + @pytest.mark.parametrize('arg', ['', 'foo', True, False, 0, 1, {}]) def test_valid_equal_to(self, arg): query = self.ref.order_by_child('foo').equal_to(arg) - assert query._querystr == 'equalTo={0}&orderBy="foo"'.format(json.dumps(arg)) + assert query._querystr == f'equalTo={json.dumps(arg)}&orderBy="foo"' def test_range_query(self, initquery): query, order_by = initquery query.start_at(1) query.equal_to(2) query.end_at(3) - assert query._querystr == 'endAt=3&equalTo=2&orderBy="{0}"&startAt=1'.format(order_by) + assert query._querystr == f'endAt=3&equalTo=2&orderBy="{order_by}"&startAt=1' def test_limit_first_query(self, initquery): query, order_by = initquery query.limit_to_first(1) - assert query._querystr == 'limitToFirst=1&orderBy="{0}"'.format(order_by) + assert query._querystr == f'limitToFirst=1&orderBy="{order_by}"' def test_limit_last_query(self, initquery): query, order_by = initquery query.limit_to_last(1) - assert query._querystr == 'limitToLast=1&orderBy="{0}"'.format(order_by) + assert query._querystr == f'limitToLast=1&orderBy="{order_by}"' def test_all_in(self, initquery): query, order_by = initquery @@ -1033,7 +1036,7 @@ def test_all_in(self, initquery): query.equal_to(2) query.end_at(3) query.limit_to_first(10) - expected = 'endAt=3&equalTo=2&limitToFirst=10&orderBy="{0}"&startAt=1'.format(order_by) + expected = f'endAt=3&equalTo=2&limitToFirst=10&orderBy="{order_by}"&startAt=1' assert query._querystr == expected def test_invalid_query_args(self): @@ -1059,9 +1062,9 @@ class TestSorter: ({'k1' : False, 'k2' : 'bar', 'k3' : None}, ['k3', 'k1', 'k2']), ({'k1' : False, 'k2' : 1, 'k3' : None}, ['k3', 'k1', 'k2']), ({'k1' : True, 'k2' : 0, 'k3' : None, 'k4' : 'foo'}, ['k3', 'k1', 'k2', 'k4']), - ({'k1' : True, 'k2' : 0, 'k3' : None, 'k4' : 'foo', 'k5' : False, 'k6' : dict()}, + ({'k1' : True, 'k2' : 0, 'k3' : None, 'k4' : 'foo', 'k5' : False, 'k6' : {}}, ['k3', 'k5', 'k1', 'k2', 'k4', 'k6']), - ({'k1' : True, 'k2' : 0, 'k3' : 'foo', 'k4' : 'foo', 'k5' : False, 'k6' : dict()}, + ({'k1' : True, 'k2' : 0, 'k3' : 'foo', 'k4' : 'foo', 'k5' : False, 'k6' : {}}, ['k5', 'k1', 'k2', 'k3', 'k4', 'k6']), ] diff --git a/tests/test_instance_id.py b/tests/test_instance_id.py index 387e067c9..2b0e21079 100644 --- a/tests/test_instance_id.py +++ b/tests/test_instance_id.py @@ -72,7 +72,7 @@ def _assert_request(self, request, expected_method, expected_url): assert request.headers['x-goog-api-client'] == expected_metrics_header def _get_url(self, project_id, iid): - return instance_id._IID_SERVICE_URL + 'project/{0}/instanceId/{1}'.format(project_id, iid) + return instance_id._IID_SERVICE_URL + f'project/{project_id}/instanceId/{iid}' def test_no_project_id(self): def evaluate(): @@ -131,14 +131,14 @@ def test_delete_instance_id_unexpected_error(self): with pytest.raises(exceptions.UnknownError) as excinfo: instance_id.delete_instance_id('test_iid') url = self._get_url('explicit-project-id', 'test_iid') - message = 'Instance ID "test_iid": 501 Server Error: None for url: {0}'.format(url) + message = f'Instance ID "test_iid": 501 Server Error: None for url: {url}' assert str(excinfo.value) == message assert excinfo.value.cause is not None assert excinfo.value.http_response is not None assert len(recorder) == 1 self._assert_request(recorder[0], 'DELETE', url) - @pytest.mark.parametrize('iid', [None, '', 0, 1, True, False, list(), dict(), tuple()]) + @pytest.mark.parametrize('iid', [None, '', 0, 1, True, False, [], {}, tuple()]) def test_invalid_instance_id(self, iid): cred = testutils.MockCredential() app = firebase_admin.initialize_app(cred, {'projectId': 'explicit-project-id'}) diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 63b649485..9fa30fef9 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -30,12 +30,12 @@ from tests import testutils -NON_STRING_ARGS = [list(), tuple(), dict(), True, False, 1, 0] -NON_DICT_ARGS = ['', list(), tuple(), True, False, 1, 0, {1: 'foo'}, {'foo': 1}] -NON_OBJECT_ARGS = [list(), tuple(), dict(), 'foo', 0, 1, True, False] -NON_LIST_ARGS = ['', tuple(), dict(), True, False, 1, 0, [1], ['foo', 1]] -NON_UINT_ARGS = ['1.23s', list(), tuple(), dict(), -1.23] -NON_BOOL_ARGS = ['', list(), tuple(), dict(), 1, 0, [1], ['foo', 1], {1: 'foo'}, {'foo': 1}] +NON_STRING_ARGS = [[], tuple(), {}, True, False, 1, 0] +NON_DICT_ARGS = ['', [], tuple(), True, False, 1, 0, {1: 'foo'}, {'foo': 1}] +NON_OBJECT_ARGS = [[], tuple(), {}, 'foo', 0, 1, True, False] +NON_LIST_ARGS = ['', tuple(), {}, True, False, 1, 0, [1], ['foo', 1]] +NON_UINT_ARGS = ['1.23s', [], tuple(), {}, -1.23] +NON_BOOL_ARGS = ['', [], tuple(), {}, 1, 0, [1], ['foo', 1], {1: 'foo'}, {'foo': 1}] HTTP_ERROR_CODES = { 400: exceptions.InvalidArgumentError, 403: exceptions.PermissionDeniedError, @@ -501,7 +501,7 @@ def test_invalid_channel_id(self, data): excinfo = self._check_notification(notification) assert str(excinfo.value) == 'AndroidNotification.channel_id must be a string.' - @pytest.mark.parametrize('timestamp', [100, '', 'foo', {}, [], list(), dict()]) + @pytest.mark.parametrize('timestamp', [100, '', 'foo', {}, []]) def test_invalid_event_timestamp(self, timestamp): notification = messaging.AndroidNotification(event_timestamp=timestamp) excinfo = self._check_notification(notification) @@ -568,7 +568,7 @@ def test_negative_vibrate_timings_millis(self): expected = 'AndroidNotification.vibrate_timings_millis must not be negative.' assert str(excinfo.value) == expected - @pytest.mark.parametrize('notification_count', ['', 'foo', list(), tuple(), dict()]) + @pytest.mark.parametrize('notification_count', ['', 'foo', [], tuple(), {}]) def test_invalid_notification_count(self, notification_count): notification = messaging.AndroidNotification(notification_count=notification_count) excinfo = self._check_notification(notification) @@ -939,19 +939,19 @@ def test_invalid_tag(self, data): excinfo = self._check_notification(notification) assert str(excinfo.value) == 'WebpushNotification.tag must be a string.' - @pytest.mark.parametrize('data', ['', 'foo', list(), tuple(), dict()]) + @pytest.mark.parametrize('data', ['', 'foo', [], tuple(), {}]) def test_invalid_timestamp(self, data): notification = messaging.WebpushNotification(timestamp_millis=data) excinfo = self._check_notification(notification) assert str(excinfo.value) == 'WebpushNotification.timestamp_millis must be a number.' - @pytest.mark.parametrize('data', ['', list(), tuple(), True, False, 1, 0]) + @pytest.mark.parametrize('data', ['', [], tuple(), True, False, 1, 0]) def test_invalid_custom_data(self, data): notification = messaging.WebpushNotification(custom_data=data) excinfo = self._check_notification(notification) assert str(excinfo.value) == 'WebpushNotification.custom_data must be a dict.' - @pytest.mark.parametrize('data', ['', dict(), tuple(), True, False, 1, 0, [1, 2]]) + @pytest.mark.parametrize('data', ['', {}, tuple(), True, False, 1, 0, [1, 2]]) def test_invalid_actions(self, data): notification = messaging.WebpushNotification(actions=data) excinfo = self._check_notification(notification) @@ -1172,7 +1172,7 @@ def test_invalid_alert(self, data): expected = 'Aps.alert must be a string or an instance of ApsAlert class.' assert str(excinfo.value) == expected - @pytest.mark.parametrize('data', [list(), tuple(), dict(), 'foo']) + @pytest.mark.parametrize('data', [[], tuple(), {}, 'foo']) def test_invalid_badge(self, data): aps = messaging.Aps(badge=data) with pytest.raises(ValueError) as excinfo: @@ -1204,7 +1204,7 @@ def test_invalid_thread_id(self, data): expected = 'Aps.thread_id must be a string.' assert str(excinfo.value) == expected - @pytest.mark.parametrize('data', ['', list(), tuple(), True, False, 1, 0, ]) + @pytest.mark.parametrize('data', ['', [], tuple(), True, False, 1, 0, ]) def test_invalid_custom_data_dict(self, data): if isinstance(data, dict): return @@ -1309,7 +1309,7 @@ def test_invalid_name(self, data): expected = 'CriticalSound.name must be a non-empty string.' assert str(excinfo.value) == expected - @pytest.mark.parametrize('data', [list(), tuple(), dict(), 'foo']) + @pytest.mark.parametrize('data', [[], tuple(), {}, 'foo']) def test_invalid_volume(self, data): sound = messaging.CriticalSound(name='default', volume=data) excinfo = self._check_sound(sound) @@ -1659,7 +1659,7 @@ def test_topic_management_custom_timeout(self, options, timeout): class TestSend: _DEFAULT_RESPONSE = json.dumps({'name': 'message-id'}) - _CLIENT_VERSION = 'fire-admin-python/{0}'.format(firebase_admin.__version__) + _CLIENT_VERSION = f'fire-admin-python/{firebase_admin.__version__}' @classmethod def setup_class(cls): @@ -1736,7 +1736,7 @@ def test_send_error(self, status, exc_type): msg = messaging.Message(topic='foo') with pytest.raises(exc_type) as excinfo: messaging.send(msg) - expected = 'Unexpected HTTP response with status: {0}; body: {{}}'.format(status) + expected = f'Unexpected HTTP response with status: {status}; body: {{}}' check_exception(excinfo.value, expected, status) assert len(recorder) == 1 body = {'message': messaging._MessagingService.JSON_ENCODER.default(msg)} @@ -2332,9 +2332,9 @@ def _assert_request(self, request, expected_method, expected_url): assert request.headers['x-goog-api-client'] == expected_metrics_header def _get_url(self, path): - return '{0}/{1}'.format(messaging._MessagingService.IID_URL, path) + return f'{messaging._MessagingService.IID_URL}/{path}' - @pytest.mark.parametrize('tokens', [None, '', list(), dict(), tuple()]) + @pytest.mark.parametrize('tokens', [None, '', [], {}, tuple()]) def test_invalid_tokens(self, tokens): expected = 'Tokens must be a string or a non-empty list of strings.' if isinstance(tokens, str): @@ -2383,7 +2383,7 @@ def test_subscribe_to_topic_non_json_error(self, status, exc_type): _, recorder = self._instrument_iid_service(status=status, payload='not json') with pytest.raises(exc_type) as excinfo: messaging.subscribe_to_topic('foo', 'test-topic') - reason = 'Unexpected HTTP response with status: {0}; body: not json'.format(status) + reason = f'Unexpected HTTP response with status: {status}; body: not json' assert str(excinfo.value) == reason assert len(recorder) == 1 self._assert_request(recorder[0], 'POST', self._get_url('iid/v1:batchAdd')) @@ -2412,7 +2412,7 @@ def test_unsubscribe_from_topic_non_json_error(self, status, exc_type): _, recorder = self._instrument_iid_service(status=status, payload='not json') with pytest.raises(exc_type) as excinfo: messaging.unsubscribe_from_topic('foo', 'test-topic') - reason = 'Unexpected HTTP response with status: {0}; body: not json'.format(status) + reason = f'Unexpected HTTP response with status: {status}; body: not json' assert str(excinfo.value) == reason assert len(recorder) == 1 self._assert_request(recorder[0], 'POST', self._get_url('iid/v1:batchRemove')) diff --git a/tests/test_ml.py b/tests/test_ml.py index 2af9ae42f..bcc93fd05 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -49,7 +49,7 @@ TAGS_2 = [TAG_1, TAG_3] MODEL_ID_1 = 'modelId1' -MODEL_NAME_1 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1) +MODEL_NAME_1 = f'projects/{PROJECT_ID}/models/{MODEL_ID_1}' DISPLAY_NAME_1 = 'displayName1' MODEL_JSON_1 = { 'name': MODEL_NAME_1, @@ -58,7 +58,7 @@ MODEL_1 = ml.Model.from_dict(MODEL_JSON_1) MODEL_ID_2 = 'modelId2' -MODEL_NAME_2 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_2) +MODEL_NAME_2 = f'projects/{PROJECT_ID}/models/{MODEL_ID_2}' DISPLAY_NAME_2 = 'displayName2' MODEL_JSON_2 = { 'name': MODEL_NAME_2, @@ -67,7 +67,7 @@ MODEL_2 = ml.Model.from_dict(MODEL_JSON_2) MODEL_ID_3 = 'modelId3' -MODEL_NAME_3 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_3) +MODEL_NAME_3 = f'projects/{PROJECT_ID}/models/{MODEL_ID_3}' DISPLAY_NAME_3 = 'displayName3' MODEL_JSON_3 = { 'name': MODEL_NAME_3, @@ -79,7 +79,7 @@ 'published': True } VALIDATION_ERROR_CODE = 400 -VALIDATION_ERROR_MSG = 'No model format found for {0}.'.format(MODEL_ID_1) +VALIDATION_ERROR_MSG = f'No model format found for {MODEL_ID_1}.' MODEL_STATE_ERROR_JSON = { 'validationError': { 'code': VALIDATION_ERROR_CODE, @@ -87,19 +87,19 @@ } } -OPERATION_NAME_1 = 'projects/{0}/operations/123'.format(PROJECT_ID) +OPERATION_NAME_1 = f'projects/{PROJECT_ID}/operations/123' OPERATION_NOT_DONE_JSON_1 = { 'name': OPERATION_NAME_1, 'metadata': { '@type': 'type.googleapis.com/google.firebase.ml.v1beta2.ModelOperationMetadata', - 'name': 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1), + 'name': f'projects/{PROJECT_ID}/models/{MODEL_ID_1}', 'basic_operation_status': 'BASIC_OPERATION_STATUS_UPLOADING' } } GCS_BUCKET_NAME = 'my_bucket' GCS_BLOB_NAME = 'mymodel.tflite' -GCS_TFLITE_URI = 'gs://{0}/{1}'.format(GCS_BUCKET_NAME, GCS_BLOB_NAME) +GCS_TFLITE_URI = f'gs://{GCS_BUCKET_NAME}/{GCS_BLOB_NAME}' GCS_TFLITE_URI_JSON = {'gcsTfliteUri': GCS_TFLITE_URI} GCS_TFLITE_MODEL_SOURCE = ml.TFLiteGCSModelSource(GCS_TFLITE_URI) TFLITE_FORMAT_JSON = { @@ -257,8 +257,8 @@ INVALID_MODEL_ARGS = [ 'abc', 4.2, - list(), - dict(), + [], + {}, True, -1, 0, @@ -272,9 +272,10 @@ 'projects/$#@/operations/123', 'projects/1234/operations/123/extrathing', ] -PAGE_SIZE_VALUE_ERROR_MSG = 'Page size must be a positive integer between ' \ - '1 and {0}'.format(ml._MAX_PAGE_SIZE) -INVALID_STRING_OR_NONE_ARGS = [0, -1, 4.2, 0x10, False, list(), dict()] +PAGE_SIZE_VALUE_ERROR_MSG = ( + f'Page size must be a positive integer between 1 and {ml._MAX_PAGE_SIZE}' +) +INVALID_STRING_OR_NONE_ARGS = [0, -1, 4.2, 0x10, False, [], {}] # For validation type errors @@ -358,8 +359,7 @@ def teardown_class(cls): @staticmethod def _op_url(project_id): - return BASE_URL + \ - 'projects/{0}/operations/123'.format(project_id) + return BASE_URL + f'projects/{project_id}/operations/123' def test_model_success_err_state_lro(self): model = ml.Model.from_dict(FULL_MODEL_ERR_STATE_LRO_JSON) @@ -569,16 +569,15 @@ def teardown_class(cls): @staticmethod def _url(project_id): - return BASE_URL + 'projects/{0}/models'.format(project_id) + return BASE_URL + f'projects/{project_id}/models' @staticmethod def _op_url(project_id): - return BASE_URL + \ - 'projects/{0}/operations/123'.format(project_id) + return BASE_URL + f'projects/{project_id}/operations/123' @staticmethod def _get_url(project_id, model_id): - return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + return BASE_URL + f'projects/{project_id}/models/{model_id}' def test_immediate_done(self): instrument_ml_service(status=200, payload=OPERATION_DONE_RESPONSE) @@ -663,12 +662,11 @@ def teardown_class(cls): @staticmethod def _url(project_id, model_id): - return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + return BASE_URL + f'projects/{project_id}/models/{model_id}' @staticmethod def _op_url(project_id): - return BASE_URL + \ - 'projects/{0}/operations/123'.format(project_id) + return BASE_URL + f'projects/{project_id}/operations/123' def test_immediate_done(self): instrument_ml_service(status=200, payload=OPERATION_DONE_RESPONSE) @@ -760,18 +758,16 @@ def teardown_class(cls): @staticmethod def _update_url(project_id, model_id): - update_url = 'projects/{0}/models/{1}?updateMask=state.published'.format( - project_id, model_id) + update_url = f'projects/{project_id}/models/{model_id}?updateMask=state.published' return BASE_URL + update_url @staticmethod def _get_url(project_id, model_id): - return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + return BASE_URL + f'projects/{project_id}/models/{model_id}' @staticmethod def _op_url(project_id): - return BASE_URL + \ - 'projects/{0}/operations/123'.format(project_id) + return BASE_URL + f'projects/{project_id}/operations/123' @pytest.mark.parametrize('publish_function, published', PUBLISH_UNPUBLISH_WITH_ARGS) def test_immediate_done(self, publish_function, published): @@ -842,7 +838,7 @@ def teardown_class(cls): @staticmethod def _url(project_id, model_id): - return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + return BASE_URL + f'projects/{project_id}/models/{model_id}' def test_get_model(self): recorder = instrument_ml_service(status=200, payload=DEFAULT_GET_RESPONSE) @@ -893,7 +889,7 @@ def teardown_class(cls): @staticmethod def _url(project_id, model_id): - return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + return BASE_URL + f'projects/{project_id}/models/{model_id}' def test_delete_model(self): recorder = instrument_ml_service(status=200, payload=EMPTY_RESPONSE) @@ -941,7 +937,7 @@ def teardown_class(cls): @staticmethod def _url(project_id): - return BASE_URL + 'projects/{0}/models'.format(project_id) + return BASE_URL + f'projects/{project_id}/models' @staticmethod def _check_page(page, model_count): @@ -970,8 +966,8 @@ def test_list_models_with_all_args(self): assert len(recorder) == 1 _assert_request(recorder[0], 'GET', ( TestListModels._url(PROJECT_ID) + - '?filter=display_name%3DdisplayName3&page_size=10&page_token={0}' - .format(PAGE_TOKEN))) + f'?filter=display_name%3DdisplayName3&page_size=10&page_token={PAGE_TOKEN}' + )) assert isinstance(models_page, ml.ListModelsPage) assert len(models_page.models) == 1 assert models_page.models[0] == MODEL_3 @@ -986,8 +982,8 @@ def test_list_models_list_filter_validation(self, list_filter): @pytest.mark.parametrize('page_size, exc_type, error_message', [ ('abc', TypeError, 'Page size must be a number or None.'), (4.2, TypeError, 'Page size must be a number or None.'), - (list(), TypeError, 'Page size must be a number or None.'), - (dict(), TypeError, 'Page size must be a number or None.'), + ([], TypeError, 'Page size must be a number or None.'), + ({}, TypeError, 'Page size must be a number or None.'), (True, TypeError, 'Page size must be a number or None.'), (-1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), (0, ValueError, PAGE_SIZE_VALUE_ERROR_MSG), @@ -1061,7 +1057,7 @@ def test_list_models_paged_iteration(self): iterator = page.iterate_all() for index in range(2): model = next(iterator) - assert model.display_name == 'displayName{0}'.format(index+1) + assert model.display_name == f'displayName{index+1}' assert len(recorder) == 1 # Page 2 diff --git a/tests/test_project_management.py b/tests/test_project_management.py index a242f523f..89e48c2e5 100644 --- a/tests/test_project_management.py +++ b/tests/test_project_management.py @@ -545,7 +545,7 @@ def test_custom_timeout(self, timeout): 'projectId': 'test-project-id' } app = firebase_admin.initialize_app( - testutils.MockCredential(), options, 'timeout-{0}'.format(timeout)) + testutils.MockCredential(), options, f'timeout-{timeout}') project_management_service = project_management._get_project_management_service(app) assert project_management_service._client.timeout == timeout @@ -820,7 +820,7 @@ def test_list_android_apps_rpc_error(self): assert len(recorder) == 1 def test_list_android_apps_empty_list(self): - recorder = self._instrument_service(statuses=[200], responses=[json.dumps(dict())]) + recorder = self._instrument_service(statuses=[200], responses=[json.dumps({})]) android_apps = project_management.list_android_apps() @@ -883,7 +883,7 @@ def test_list_ios_apps_rpc_error(self): assert len(recorder) == 1 def test_list_ios_apps_empty_list(self): - recorder = self._instrument_service(statuses=[200], responses=[json.dumps(dict())]) + recorder = self._instrument_service(statuses=[200], responses=[json.dumps({})]) ios_apps = project_management.list_ios_apps() diff --git a/tests/test_remote_config.py b/tests/test_remote_config.py index 14b54838f..7bbf9721d 100644 --- a/tests/test_remote_config.py +++ b/tests/test_remote_config.py @@ -865,7 +865,7 @@ async def test_rc_instance_get_server_template(self): template = await rc_instance.get_server_template() - assert template.parameters == dict(test_key="test_value") + assert template.parameters == {"test_key": 'test_value'} assert str(template.version) == 'test' assert str(template.etag) == 'etag' diff --git a/tests/test_storage.py b/tests/test_storage.py index e15c4e2ab..c874ef640 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -33,7 +33,7 @@ def test_invalid_config(): with pytest.raises(ValueError): storage.bucket() -@pytest.mark.parametrize('name', [None, '', 0, 1, True, False, list(), tuple(), dict()]) +@pytest.mark.parametrize('name', [None, '', 0, 1, True, False, [], tuple(), {}]) def test_invalid_name(name): with pytest.raises(ValueError): storage.bucket(name) diff --git a/tests/test_tenant_mgt.py b/tests/test_tenant_mgt.py index 156846343..900faa376 100644 --- a/tests/test_tenant_mgt.py +++ b/tests/test_tenant_mgt.py @@ -107,8 +107,8 @@ LIST_OIDC_PROVIDER_CONFIGS_RESPONSE = testutils.resource('list_oidc_provider_configs.json') LIST_SAML_PROVIDER_CONFIGS_RESPONSE = testutils.resource('list_saml_provider_configs.json') -INVALID_TENANT_IDS = [None, '', 0, 1, True, False, list(), tuple(), dict()] -INVALID_BOOLEANS = ['', 1, 0, list(), tuple(), dict()] +INVALID_TENANT_IDS = [None, '', 0, 1, True, False, [], tuple(), {}] +INVALID_BOOLEANS = ['', 1, 0, [], tuple(), {}] USER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v1/projects/mock-project-id' PROVIDER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v2/projects/mock-project-id' @@ -152,7 +152,7 @@ def _instrument_provider_mgt(client, status, payload): class TestTenant: - @pytest.mark.parametrize('data', [None, 'foo', 0, 1, True, False, list(), tuple(), dict()]) + @pytest.mark.parametrize('data', [None, 'foo', 0, 1, True, False, [], tuple(), {}]) def test_invalid_data(self, data): with pytest.raises(ValueError): tenant_mgt.Tenant(data) @@ -197,7 +197,7 @@ def test_get_tenant(self, tenant_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}/tenants/tenant-id'.format(TENANT_MGT_URL_PREFIX) + assert req.url == f'{TENANT_MGT_URL_PREFIX}/tenants/tenant-id' assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' assert req.headers['x-goog-api-client'] == expected_metrics_header @@ -216,7 +216,7 @@ def test_tenant_not_found(self, tenant_mgt_app): class TestCreateTenant: - @pytest.mark.parametrize('display_name', [True, False, 1, 0, list(), tuple(), dict()]) + @pytest.mark.parametrize('display_name', [True, False, 1, 0, [], tuple(), {}]) def test_invalid_display_name_type(self, display_name, tenant_mgt_app): with pytest.raises(ValueError) as excinfo: tenant_mgt.create_tenant(display_name=display_name, app=tenant_mgt_app) @@ -290,7 +290,7 @@ def _assert_request(self, recorder, body): assert len(recorder) == 1 req = recorder[0] assert req.method == 'POST' - assert req.url == '{0}/tenants'.format(TENANT_MGT_URL_PREFIX) + assert req.url == f'{TENANT_MGT_URL_PREFIX}/tenants' assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' assert req.headers['x-goog-api-client'] == expected_metrics_header @@ -306,7 +306,7 @@ def test_invalid_tenant_id(self, tenant_id, tenant_mgt_app): tenant_mgt.update_tenant(tenant_id, display_name='My Tenant', app=tenant_mgt_app) assert str(excinfo.value).startswith('Tenant ID must be a non-empty string') - @pytest.mark.parametrize('display_name', [True, False, 1, 0, list(), tuple(), dict()]) + @pytest.mark.parametrize('display_name', [True, False, 1, 0, [], tuple(), {}]) def test_invalid_display_name_type(self, display_name, tenant_mgt_app): with pytest.raises(ValueError) as excinfo: tenant_mgt.update_tenant('tenant-id', display_name=display_name, app=tenant_mgt_app) @@ -390,8 +390,7 @@ def _assert_request(self, recorder, body, mask): assert len(recorder) == 1 req = recorder[0] assert req.method == 'PATCH' - assert req.url == '{0}/tenants/tenant-id?updateMask={1}'.format( - TENANT_MGT_URL_PREFIX, ','.join(mask)) + assert req.url == f'{TENANT_MGT_URL_PREFIX}/tenants/tenant-id?updateMask={",".join(mask)}' assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' assert req.headers['x-goog-api-client'] == expected_metrics_header @@ -414,7 +413,7 @@ def test_delete_tenant(self, tenant_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'DELETE' - assert req.url == '{0}/tenants/tenant-id'.format(TENANT_MGT_URL_PREFIX) + assert req.url == f'{TENANT_MGT_URL_PREFIX}/tenants/tenant-id' assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' assert req.headers['x-goog-api-client'] == expected_metrics_header @@ -433,12 +432,12 @@ def test_tenant_not_found(self, tenant_mgt_app): class TestListTenants: - @pytest.mark.parametrize('arg', [None, 'foo', list(), dict(), 0, -1, 101, False]) + @pytest.mark.parametrize('arg', [None, 'foo', [], {}, 0, -1, 101, False]) def test_invalid_max_results(self, tenant_mgt_app, arg): with pytest.raises(ValueError): tenant_mgt.list_tenants(max_results=arg, app=tenant_mgt_app) - @pytest.mark.parametrize('arg', ['', list(), dict(), 0, -1, True, False]) + @pytest.mark.parametrize('arg', ['', [], {}, 0, -1, True, False]) def test_invalid_page_token(self, tenant_mgt_app, arg): with pytest.raises(ValueError): tenant_mgt.list_tenants(page_token=arg, app=tenant_mgt_app) @@ -480,7 +479,7 @@ def test_list_tenants_paged_iteration(self, tenant_mgt_app): iterator = page.iterate_all() for index in range(3): tenant = next(iterator) - assert tenant.tenant_id == 'tenant{0}'.format(index) + assert tenant.tenant_id == f'tenant{index}' self._assert_request(recorder) # Page 2 (also the last page) @@ -551,7 +550,7 @@ def _assert_tenants_page(self, page): assert isinstance(page, tenant_mgt.ListTenantsPage) assert len(page.tenants) == 2 for idx, tenant in enumerate(page.tenants): - _assert_tenant(tenant, 'tenant{0}'.format(idx)) + _assert_tenant(tenant, f'tenant{idx}') def _assert_request(self, recorder, expected=None): if expected is None: @@ -671,8 +670,7 @@ def test_revoke_refresh_tokens(self, tenant_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'POST' - assert req.url == '{0}/tenants/tenant-id/accounts:update'.format( - USER_MGT_URL_PREFIX) + assert req.url == f'{USER_MGT_URL_PREFIX}/tenants/tenant-id/accounts:update' body = json.loads(req.body.decode()) assert body['localId'] == 'testuser' assert 'validSince' in body @@ -693,8 +691,9 @@ def test_list_users(self, tenant_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}/tenants/tenant-id/accounts:batchGet?maxResults=1000'.format( - USER_MGT_URL_PREFIX) + assert req.url == ( + f'{USER_MGT_URL_PREFIX}/tenants/tenant-id/accounts:batchGet?maxResults=1000' + ) def test_import_users(self, tenant_mgt_app): client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) @@ -765,8 +764,9 @@ def test_get_oidc_provider_config(self, tenant_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}/tenants/tenant-id/oauthIdpConfigs/oidc.provider'.format( - PROVIDER_MGT_URL_PREFIX) + assert req.url == ( + f'{PROVIDER_MGT_URL_PREFIX}/tenants/tenant-id/oauthIdpConfigs/oidc.provider' + ) def test_create_oidc_provider_config(self, tenant_mgt_app): client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) @@ -791,7 +791,7 @@ def test_update_oidc_provider_config(self, tenant_mgt_app): self._assert_oidc_provider_config(provider_config) mask = ['clientId', 'displayName', 'enabled', 'issuer'] - url = '/oauthIdpConfigs/oidc.provider?updateMask={0}'.format(','.join(mask)) + url = f'/oauthIdpConfigs/oidc.provider?updateMask={",".join(mask)}' self._assert_request( recorder, url, OIDC_PROVIDER_CONFIG_REQUEST, method='PATCH', prefix=PROVIDER_MGT_URL_PREFIX) @@ -805,8 +805,9 @@ def test_delete_oidc_provider_config(self, tenant_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'DELETE' - assert req.url == '{0}/tenants/tenant-id/oauthIdpConfigs/oidc.provider'.format( - PROVIDER_MGT_URL_PREFIX) + assert req.url == ( + f'{PROVIDER_MGT_URL_PREFIX}/tenants/tenant-id/oauthIdpConfigs/oidc.provider' + ) def test_list_oidc_provider_configs(self, tenant_mgt_app): client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) @@ -819,7 +820,7 @@ def test_list_oidc_provider_configs(self, tenant_mgt_app): assert len(page.provider_configs) == 2 for provider_config in page.provider_configs: self._assert_oidc_provider_config( - provider_config, want_id='oidc.provider{0}'.format(index)) + provider_config, want_id=f'oidc.provider{index}') index += 1 assert page.next_page_token == '' @@ -831,8 +832,9 @@ def test_list_oidc_provider_configs(self, tenant_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}{1}'.format( - PROVIDER_MGT_URL_PREFIX, '/tenants/tenant-id/oauthIdpConfigs?pageSize=100') + assert req.url == ( + f'{PROVIDER_MGT_URL_PREFIX}/tenants/tenant-id/oauthIdpConfigs?pageSize=100' + ) def test_get_saml_provider_config(self, tenant_mgt_app): client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) @@ -844,8 +846,9 @@ def test_get_saml_provider_config(self, tenant_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}/tenants/tenant-id/inboundSamlConfigs/saml.provider'.format( - PROVIDER_MGT_URL_PREFIX) + assert req.url == ( + f'{PROVIDER_MGT_URL_PREFIX}/tenants/tenant-id/inboundSamlConfigs/saml.provider' + ) def test_create_saml_provider_config(self, tenant_mgt_app): client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) @@ -877,7 +880,7 @@ def test_update_saml_provider_config(self, tenant_mgt_app): 'displayName', 'enabled', 'idpConfig.idpCertificates', 'idpConfig.idpEntityId', 'idpConfig.ssoUrl', 'spConfig.callbackUri', 'spConfig.spEntityId', ] - url = '/inboundSamlConfigs/saml.provider?updateMask={0}'.format(','.join(mask)) + url = f'/inboundSamlConfigs/saml.provider?updateMask={",".join(mask)}' self._assert_request( recorder, url, SAML_PROVIDER_CONFIG_REQUEST, method='PATCH', prefix=PROVIDER_MGT_URL_PREFIX) @@ -891,8 +894,9 @@ def test_delete_saml_provider_config(self, tenant_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'DELETE' - assert req.url == '{0}/tenants/tenant-id/inboundSamlConfigs/saml.provider'.format( - PROVIDER_MGT_URL_PREFIX) + assert req.url == ( + f'{PROVIDER_MGT_URL_PREFIX}/tenants/tenant-id/inboundSamlConfigs/saml.provider' + ) def test_list_saml_provider_configs(self, tenant_mgt_app): client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) @@ -905,7 +909,7 @@ def test_list_saml_provider_configs(self, tenant_mgt_app): assert len(page.provider_configs) == 2 for provider_config in page.provider_configs: self._assert_saml_provider_config( - provider_config, want_id='saml.provider{0}'.format(index)) + provider_config, want_id=f'saml.provider{index}') index += 1 assert page.next_page_token == '' @@ -917,8 +921,9 @@ def test_list_saml_provider_configs(self, tenant_mgt_app): assert len(recorder) == 1 req = recorder[0] assert req.method == 'GET' - assert req.url == '{0}{1}'.format( - PROVIDER_MGT_URL_PREFIX, '/tenants/tenant-id/inboundSamlConfigs?pageSize=100') + assert req.url == ( + f'{PROVIDER_MGT_URL_PREFIX}/tenants/tenant-id/inboundSamlConfigs?pageSize=100' + ) def test_tenant_not_found(self, tenant_mgt_app): client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) @@ -937,7 +942,7 @@ def _assert_request( assert len(recorder) == 1 req = recorder[0] assert req.method == method - assert req.url == '{0}/tenants/tenant-id{1}'.format(prefix, want_url) + assert req.url == f'{prefix}/tenants/tenant-id{want_url}' assert req.headers['X-Client-Version'] == f'Python/Admin/{firebase_admin.__version__}' expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag' assert req.headers['x-goog-api-client'] == expected_metrics_header diff --git a/tests/test_token_gen.py b/tests/test_token_gen.py index fe0b28dbe..384bc22c3 100644 --- a/tests/test_token_gen.py +++ b/tests/test_token_gen.py @@ -48,8 +48,8 @@ MOCK_SERVICE_ACCOUNT_EMAIL = MOCK_CREDENTIAL.service_account_email MOCK_REQUEST = testutils.MockRequest(200, MOCK_PUBLIC_CERTS) -INVALID_STRINGS = [None, '', 0, 1, True, False, list(), tuple(), dict()] -INVALID_BOOLS = [None, '', 'foo', 0, 1, list(), tuple(), dict()] +INVALID_STRINGS = [None, '', 0, 1, True, False, [], tuple(), {}] +INVALID_BOOLS = [None, '', 'foo', 0, 1, [], tuple(), {}] INVALID_JWT_ARGS = { 'NoneToken': None, 'EmptyToken': '', @@ -63,7 +63,7 @@ ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1' EMULATOR_HOST_ENV_VAR = 'FIREBASE_AUTH_EMULATOR_HOST' AUTH_EMULATOR_HOST = 'localhost:9099' -EMULATED_ID_TOOLKIT_URL = 'http://{}/identitytoolkit.googleapis.com/v1'.format(AUTH_EMULATOR_HOST) +EMULATED_ID_TOOLKIT_URL = f'http://{AUTH_EMULATOR_HOST}/identitytoolkit.googleapis.com/v1' TOKEN_MGT_URLS = { 'ID_TOOLKIT': ID_TOOLKIT_URL, } @@ -136,8 +136,9 @@ def _get_session_cookie( payload_overrides=None, header_overrides=None, current_time=MOCK_CURRENT_TIME): payload_overrides = payload_overrides or {} if 'iss' not in payload_overrides: - payload_overrides['iss'] = 'https://session.firebase.google.com/{0}'.format( - MOCK_CREDENTIAL.project_id) + payload_overrides['iss'] = ( + f'https://session.firebase.google.com/{MOCK_CREDENTIAL.project_id}' + ) return _get_id_token(payload_overrides, header_overrides, current_time=current_time) def _instrument_user_manager(app, status, payload): @@ -282,7 +283,7 @@ def test_sign_with_iam(self): testutils.MockCredential(), name='iam-signer-app', options=options) try: signature = base64.b64encode(b'test').decode() - iam_resp = '{{"signedBlob": "{0}"}}'.format(signature) + iam_resp = json.dumps({'signedBlob': signature}) _overwrite_iam_request(app, testutils.MockRequest(200, iam_resp)) custom_token = auth.create_custom_token(MOCK_UID, app=app).decode() assert custom_token.endswith('.' + signature.rstrip('=')) @@ -319,8 +320,7 @@ def test_sign_with_discovered_service_account(self): # Now invoke the IAM signer. signature = base64.b64encode(b'test').decode() - request.response = testutils.MockResponse( - 200, '{{"signedBlob": "{0}"}}'.format(signature)) + request.response = testutils.MockResponse(200, json.dumps({'signedBlob': signature})) custom_token = auth.create_custom_token(MOCK_UID, app=app).decode() assert custom_token.endswith('.' + signature.rstrip('=')) self._verify_signer(custom_token, 'discovered-service-account') @@ -354,13 +354,13 @@ def _verify_signer(self, token, signer): class TestCreateSessionCookie: - @pytest.mark.parametrize('id_token', [None, '', 0, 1, True, False, list(), dict(), tuple()]) + @pytest.mark.parametrize('id_token', [None, '', 0, 1, True, False, [], {}, tuple()]) def test_invalid_id_token(self, user_mgt_app, id_token): with pytest.raises(ValueError): auth.create_session_cookie(id_token, expires_in=3600, app=user_mgt_app) @pytest.mark.parametrize('expires_in', [ - None, '', True, False, list(), dict(), tuple(), + None, '', True, False, [], {}, tuple(), _token_gen.MIN_SESSION_COOKIE_DURATION_SECONDS - 1, _token_gen.MAX_SESSION_COOKIE_DURATION_SECONDS + 1, ]) diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index 34b698be4..2c747ee5e 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -32,10 +32,10 @@ from tests import testutils -INVALID_STRINGS = [None, '', 0, 1, True, False, list(), tuple(), dict()] -INVALID_DICTS = [None, 'foo', 0, 1, True, False, list(), tuple()] -INVALID_INTS = [None, 'foo', '1', -1, 1.1, True, False, list(), tuple(), dict()] -INVALID_TIMESTAMPS = ['foo', '1', 0, -1, 1.1, True, False, list(), tuple(), dict()] +INVALID_STRINGS = [None, '', 0, 1, True, False, [], tuple(), {}] +INVALID_DICTS = [None, 'foo', 0, 1, True, False, [], tuple()] +INVALID_INTS = [None, 'foo', '1', -1, 1.1, True, False, [], tuple(), {}] +INVALID_TIMESTAMPS = ['foo', '1', 0, -1, 1.1, True, False, [], tuple(), {}] MOCK_GET_USER_RESPONSE = testutils.resource('get_user.json') MOCK_LIST_USERS_RESPONSE = testutils.resource('list_users.json') @@ -56,7 +56,7 @@ ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1' EMULATOR_HOST_ENV_VAR = 'FIREBASE_AUTH_EMULATOR_HOST' AUTH_EMULATOR_HOST = 'localhost:9099' -EMULATED_ID_TOOLKIT_URL = 'http://{}/identitytoolkit.googleapis.com/v1'.format(AUTH_EMULATOR_HOST) +EMULATED_ID_TOOLKIT_URL = f'http://{AUTH_EMULATOR_HOST}/identitytoolkit.googleapis.com/v1' URL_PROJECT_SUFFIX = '/projects/mock-project-id' USER_MGT_URLS = { 'ID_TOOLKIT': ID_TOOLKIT_URL, @@ -135,7 +135,7 @@ def _check_request(recorder, want_url, want_body=None, want_timeout=None): assert len(recorder) == 1 req = recorder[0] assert req.method == 'POST' - assert req.url == '{0}{1}'.format(USER_MGT_URLS['PREFIX'], want_url) + assert req.url == f'{USER_MGT_URLS["PREFIX"]}{want_url}' expected_metrics_header = [ _utils.get_metrics_header(), _utils.get_metrics_header() + ' mock-cred-metric-tag' @@ -538,7 +538,7 @@ def test_user_already_exists(self, user_mgt_app, error_code): with pytest.raises(exc_type) as excinfo: auth.create_user(app=user_mgt_app) assert isinstance(excinfo.value, exceptions.AlreadyExistsError) - assert str(excinfo.value) == '{0} ({1}).'.format(exc_type.default_message, error_code) + assert str(excinfo.value) == f'{exc_type.default_message} ({error_code}).' assert excinfo.value.http_response is not None assert excinfo.value.cause is not None @@ -704,15 +704,14 @@ def test_single_reserved_claim(self, user_mgt_app, key): claims = {key : 'value'} with pytest.raises(ValueError) as excinfo: auth.set_custom_user_claims('user', claims, app=user_mgt_app) - assert str(excinfo.value) == 'Claim "{0}" is reserved, and must not be set.'.format(key) + assert str(excinfo.value) == f'Claim "{key}" is reserved, and must not be set.' def test_multiple_reserved_claims(self, user_mgt_app): claims = {key : 'value' for key in _auth_utils.RESERVED_CLAIMS} with pytest.raises(ValueError) as excinfo: auth.set_custom_user_claims('user', claims, app=user_mgt_app) joined = ', '.join(sorted(claims.keys())) - assert str(excinfo.value) == ('Claims "{0}" are reserved, and must not be ' - 'set.'.format(joined)) + assert str(excinfo.value) == f'Claims "{joined}" are reserved, and must not be set.' def test_large_claims_payload(self, user_mgt_app): claims = {'key' : 'A'*1000} @@ -830,12 +829,12 @@ def test_success(self, user_mgt_app): class TestListUsers: - @pytest.mark.parametrize('arg', [None, 'foo', list(), dict(), 0, -1, 1001, False]) + @pytest.mark.parametrize('arg', [None, 'foo', [], {}, 0, -1, 1001, False]) def test_invalid_max_results(self, user_mgt_app, arg): with pytest.raises(ValueError): auth.list_users(max_results=arg, app=user_mgt_app) - @pytest.mark.parametrize('arg', ['', list(), dict(), 0, -1, 1001, False]) + @pytest.mark.parametrize('arg', ['', [], {}, 0, -1, 1001, False]) def test_invalid_page_token(self, user_mgt_app, arg): with pytest.raises(ValueError): auth.list_users(page_token=arg, app=user_mgt_app) @@ -887,7 +886,7 @@ def test_list_users_paged_iteration(self, user_mgt_app): iterator = page.iterate_all() for index in range(3): user = next(iterator) - assert user.uid == 'user{0}'.format(index+1) + assert user.uid == f'user{index+1}' assert len(recorder) == 1 self._check_rpc_calls(recorder) @@ -912,7 +911,7 @@ def test_list_users_iterator_state(self, user_mgt_app): iterator = page.iterate_all() for user in iterator: index += 1 - assert user.uid == 'user{0}'.format(index) + assert user.uid == f'user{index}' if index == 2: break @@ -986,7 +985,7 @@ def _check_page(self, page): assert len(page.users) == 2 for user in page.users: assert isinstance(user, auth.ExportedUserRecord) - _check_user_record(user, 'testuser{0}'.format(index)) + _check_user_record(user, f'testuser{index}') assert user.password_hash == 'passwordHash' assert user.password_salt == 'passwordSalt' index += 1 @@ -1061,8 +1060,8 @@ class TestImportUserRecord: [{'email': arg} for arg in INVALID_STRINGS[1:] + ['not-an-email']] + [{'photo_url': arg} for arg in INVALID_STRINGS[1:] + ['not-a-url']] + [{'phone_number': arg} for arg in INVALID_STRINGS[1:] + ['not-a-phone']] + - [{'password_hash': arg} for arg in INVALID_STRINGS[1:] + [u'test']] + - [{'password_salt': arg} for arg in INVALID_STRINGS[1:] + [u'test']] + + [{'password_hash': arg} for arg in INVALID_STRINGS[1:] + ['test']] + + [{'password_salt': arg} for arg in INVALID_STRINGS[1:] + ['test']] + [{'custom_claims': arg} for arg in INVALID_DICTS[1:] + ['"json"', {'key': 'a'*1000}]] + [{'provider_data': arg} for arg in ['foo', 1, True]] ) @@ -1245,13 +1244,13 @@ def test_invalid_standard_scrypt(self, arg): class TestImportUsers: - @pytest.mark.parametrize('arg', [None, list(), tuple(), dict(), 0, 1, 'foo']) + @pytest.mark.parametrize('arg', [None, [], tuple(), {}, 0, 1, 'foo']) def test_invalid_users(self, user_mgt_app, arg): with pytest.raises(Exception): auth.import_users(arg, app=user_mgt_app) def test_too_many_users(self, user_mgt_app): - users = [auth.ImportUserRecord(uid='test{0}'.format(i)) for i in range(1001)] + users = [auth.ImportUserRecord(uid=f'test{i}') for i in range(1001)] with pytest.raises(ValueError): auth.import_users(users, app=user_mgt_app) @@ -1384,7 +1383,7 @@ def test_valid_data(self): {'android_install_app':'nonboolean'}, {'dynamic_link_domain': False}, {'ios_bundle_id':11}, - {'android_package_name':dict()}, + {'android_package_name':{}}, {'android_minimum_version':tuple()}, {'android_minimum_version':'7'}, {'android_install_app': True}]) diff --git a/tests/testutils.py b/tests/testutils.py index 0505eb6c7..598a929b4 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -33,7 +33,7 @@ def resource_filename(filename): def resource(filename): """Returns the contents of a test resource.""" - with open(resource_filename(filename), 'r') as file_obj: + with open(resource_filename(filename), 'r', encoding='utf-8') as file_obj: return file_obj.read() From d6807f58c8adcfb08bd40601436edfe13ad4a062 Mon Sep 17 00:00:00 2001 From: ViktorSky Date: Mon, 30 Jun 2025 15:09:17 -0300 Subject: [PATCH 05/13] add type annotations --- firebase_admin/__init__.py | 84 ++++-- firebase_admin/_auth_client.py | 202 ++++++++++--- firebase_admin/_auth_providers.py | 216 +++++++++----- firebase_admin/_auth_utils.py | 413 ++++++++++++++++++++++----- firebase_admin/_http_client.py | 155 ++++++---- firebase_admin/_messaging_encoder.py | 202 +++++++++---- firebase_admin/_messaging_utils.py | 226 ++++++++++++--- firebase_admin/_retry.py | 55 ++-- firebase_admin/_rfc3339.py | 21 +- firebase_admin/_sseclient.py | 54 ++-- firebase_admin/_token_gen.py | 225 +++++++++++---- firebase_admin/_user_identifier.py | 33 ++- firebase_admin/_user_import.py | 174 ++++++----- firebase_admin/_user_mgt.py | 291 ++++++++++++------- firebase_admin/_utils.py | 116 +++++--- firebase_admin/app_check.py | 55 ++-- firebase_admin/auth.py | 240 ++++++++++++---- firebase_admin/credentials.py | 82 +++--- firebase_admin/db.py | 282 ++++++++++++------ firebase_admin/exceptions.py | 200 ++++++++++--- firebase_admin/firestore.py | 29 +- firebase_admin/firestore_async.py | 34 ++- firebase_admin/functions.py | 119 ++++---- firebase_admin/instance_id.py | 15 +- firebase_admin/messaging.py | 178 +++++++----- firebase_admin/ml.py | 347 +++++++++++++--------- firebase_admin/project_management.py | 248 ++++++++++------ firebase_admin/remote_config.py | 245 ++++++++++------ firebase_admin/storage.py | 19 +- firebase_admin/tenant_mgt.py | 116 +++++--- pyrightconfig.json | 33 +++ requirements.txt | 6 +- setup.py | 7 +- 33 files changed, 3309 insertions(+), 1413 deletions(-) create mode 100644 pyrightconfig.json diff --git a/firebase_admin/__init__.py b/firebase_admin/__init__.py index 8c9f628e5..3d485c831 100644 --- a/firebase_admin/__init__.py +++ b/firebase_admin/__init__.py @@ -13,27 +13,41 @@ # limitations under the License. """Firebase Admin SDK for Python.""" -import datetime + import json import os import threading +from collections.abc import Callable +from typing import Any, Optional, TypeVar, Union, overload + +import google.auth.credentials +import google.auth.exceptions -from google.auth.credentials import Credentials as GoogleAuthCredentials -from google.auth.exceptions import DefaultCredentialsError from firebase_admin import credentials from firebase_admin.__about__ import __version__ +__all__ = ( + 'App', + 'delete_app', + 'get_app', + 'initialize_app', +) -_apps = {} +_T = TypeVar('_T') + +_apps: dict[str, 'App'] = {} _apps_lock = threading.RLock() -_clock = datetime.datetime.utcnow _DEFAULT_APP_NAME = '[DEFAULT]' _FIREBASE_CONFIG_ENV_VAR = 'FIREBASE_CONFIG' _CONFIG_VALID_KEYS = ['databaseAuthVariableOverride', 'databaseURL', 'httpTimeout', 'projectId', 'storageBucket'] -def initialize_app(credential=None, options=None, name=_DEFAULT_APP_NAME): +def initialize_app( + credential: Optional[Union[credentials.Base, google.auth.credentials.Credentials]] = None, + options: Optional[dict[str, Any]] = None, + name: str = _DEFAULT_APP_NAME, +) -> 'App': """Initializes and returns a new App instance. Creates a new App instance using the specified options @@ -86,7 +100,7 @@ def initialize_app(credential=None, options=None, name=_DEFAULT_APP_NAME): 'you call initialize_app().') -def delete_app(app): +def delete_app(app: 'App') -> None: """Gracefully deletes an App instance. Args: @@ -113,7 +127,7 @@ def delete_app(app): 'second argument.') -def get_app(name=_DEFAULT_APP_NAME): +def get_app(name: str = _DEFAULT_APP_NAME) -> 'App': """Retrieves an App instance by name. Args: @@ -147,7 +161,7 @@ def get_app(name=_DEFAULT_APP_NAME): class _AppOptions: """A collection of configuration options for an App.""" - def __init__(self, options): + def __init__(self, options: Optional[dict[str, Any]]) -> None: if options is None: options = self._load_from_environment() @@ -157,11 +171,16 @@ def __init__(self, options): 'Options must be a dictionary.') self._options = options - def get(self, key, default=None): + @overload + def get(self, key: str, default: None = None) -> Optional[Any]: ... + # possible issue: needs return Any | _T ? + @overload + def get(self, key: str, default: _T) -> _T: ... + def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]: """Returns the option identified by the provided key.""" return self._options.get(key, default) - def _load_from_environment(self): + def _load_from_environment(self) -> dict[str, Any]: """Invoked when no options are passed to __init__, loads options from FIREBASE_CONFIG. If the value of the FIREBASE_CONFIG environment variable starts with "{" an attempt is made @@ -194,7 +213,12 @@ class App: common to all Firebase APIs. """ - def __init__(self, name, credential, options): + def __init__( + self, + name: str, + credential: Union[credentials.Base, google.auth.credentials.Credentials], + options: Optional[dict[str, Any]], + ) -> None: """Constructs a new App using the provided name and options. Args: @@ -211,7 +235,7 @@ def __init__(self, name, credential, options): 'non-empty string.') self._name = name - if isinstance(credential, GoogleAuthCredentials): + if isinstance(credential, google.auth.credentials.Credentials): self._credential = credentials._ExternalCredentials(credential) # pylint: disable=protected-access elif isinstance(credential, credentials.Base): self._credential = credential @@ -220,37 +244,38 @@ def __init__(self, name, credential, options): 'with a valid credential instance.') self._options = _AppOptions(options) self._lock = threading.RLock() - self._services = {} + self._services: Optional[dict[str, Any]] = {} App._validate_project_id(self._options.get('projectId')) - self._project_id_initialized = False + self._project_id_initialized: bool = False @classmethod - def _validate_project_id(cls, project_id): + def _validate_project_id(cls, project_id: Optional[Any]) -> Optional[str]: if project_id is not None and not isinstance(project_id, str): raise ValueError( f'Invalid project ID: "{project_id}". project ID must be a string.') + return project_id @property - def name(self): + def name(self) -> str: return self._name @property - def credential(self): + def credential(self) -> credentials.Base: return self._credential @property - def options(self): + def options(self) -> _AppOptions: return self._options @property - def project_id(self): + def project_id(self) -> Optional[str]: if not self._project_id_initialized: self._project_id = self._lookup_project_id() self._project_id_initialized = True return self._project_id - def _lookup_project_id(self): + def _lookup_project_id(self) -> Optional[str]: """Looks up the Firebase project ID associated with an App. If a ``projectId`` is specified in app options, it is returned. Then tries to @@ -264,8 +289,8 @@ def _lookup_project_id(self): project_id = self._options.get('projectId') if not project_id: try: - project_id = self._credential.project_id - except (AttributeError, DefaultCredentialsError): + project_id = getattr(self._credential, 'project_id') + except (AttributeError, google.auth.exceptions.DefaultCredentialsError): pass if not project_id: project_id = os.environ.get('GOOGLE_CLOUD_PROJECT', @@ -273,7 +298,7 @@ def _lookup_project_id(self): App._validate_project_id(self._options.get('projectId')) return project_id - def _get_service(self, name, initializer): + def _get_service(self, name: str, initializer: Callable[['App'], _T]) -> _T: """Returns the service instance identified by the given name. Services are functional entities exposed by the Admin SDK (e.g. auth, database). Each @@ -303,7 +328,7 @@ def _get_service(self, name, initializer): self._services[name] = initializer(self) return self._services[name] - def _cleanup(self): + def _cleanup(self) -> None: """Cleans up any services associated with this App. Checks whether each service contains a close() method, and calls it if available. @@ -311,7 +336,8 @@ def _cleanup(self): any services started by the App. """ with self._lock: - for service in self._services.values(): - if hasattr(service, 'close') and hasattr(service.close, '__call__'): - service.close() - self._services = None + if self._services is not None: + for service in self._services.values(): + if hasattr(service, 'close') and hasattr(service.close, '__call__'): + service.close() + self._services = None diff --git a/firebase_admin/_auth_client.py b/firebase_admin/_auth_client.py index 74261fa37..170c05851 100644 --- a/firebase_admin/_auth_client.py +++ b/firebase_admin/_auth_client.py @@ -15,6 +15,8 @@ """Firebase auth client sub module.""" import time +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, Any, Optional, Union import firebase_admin from firebase_admin import _auth_providers @@ -25,12 +27,18 @@ from firebase_admin import _user_import from firebase_admin import _user_mgt from firebase_admin import _utils +from firebase_admin import exceptions + +if TYPE_CHECKING: + from _typeshed import ConvertibleToInt + +__all__ = ('Client',) class Client: """Firebase Authentication client scoped to a specific tenant.""" - def __init__(self, app, tenant_id=None): + def __init__(self, app: firebase_admin.App, tenant_id: Optional[str] = None) -> None: if not app.project_id: raise ValueError("""A project ID is required to access the auth service. 1. Use a service account credential, or @@ -41,7 +49,7 @@ def __init__(self, app, tenant_id=None): version_header = f'Python/Admin/{firebase_admin.__version__}' timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) # Non-default endpoint URLs for emulator support are set in this dict later. - endpoint_urls = {} + endpoint_urls: dict[str, str] = {} self.emulated = False # If an emulator is present, check that the given value matches the expected format and set @@ -70,11 +78,15 @@ def __init__(self, app, tenant_id=None): http_client, app.project_id, tenant_id, url_override=endpoint_urls.get('v2')) @property - def tenant_id(self): + def tenant_id(self) -> Optional[str]: """Tenant ID associated with this client.""" return self._tenant_id - def create_custom_token(self, uid, developer_claims=None): + def create_custom_token( + self, + uid: str, + developer_claims: Optional[dict[str, Any]] = None, + ) -> bytes: """Builds and signs a Firebase custom auth token. Args: @@ -92,7 +104,12 @@ def create_custom_token(self, uid, developer_claims=None): return self._token_generator.create_custom_token( uid, developer_claims, tenant_id=self.tenant_id) - def verify_id_token(self, id_token, check_revoked=False, clock_skew_seconds=0): + def verify_id_token( + self, + id_token: Union[bytes, str], + check_revoked: bool = False, + clock_skew_seconds: int = 0, + ) -> dict[str, Any]: """Verifies the signature and data for the provided JWT. Accepts a signed token string, verifies that it is current, was issued @@ -139,7 +156,7 @@ def verify_id_token(self, id_token, check_revoked=False, clock_skew_seconds=0): verified_claims, _token_gen.RevokedIdTokenError, 'ID token') return verified_claims - def revoke_refresh_tokens(self, uid): + def revoke_refresh_tokens(self, uid: str) -> None: """Revokes all refresh tokens for an existing user. This method updates the user's ``tokens_valid_after_timestamp`` to the current UTC @@ -160,7 +177,7 @@ def revoke_refresh_tokens(self, uid): """ self._user_manager.update_user(uid, valid_since=int(time.time())) - def get_user(self, uid): + def get_user(self, uid: str) -> _user_mgt.UserRecord: """Gets the user data corresponding to the specified user ID. Args: @@ -177,7 +194,7 @@ def get_user(self, uid): response = self._user_manager.get_user(uid=uid) return _user_mgt.UserRecord(response) - def get_user_by_email(self, email): + def get_user_by_email(self, email: str) -> _user_mgt.UserRecord: """Gets the user data corresponding to the specified user email. Args: @@ -194,7 +211,7 @@ def get_user_by_email(self, email): response = self._user_manager.get_user(email=email) return _user_mgt.UserRecord(response) - def get_user_by_phone_number(self, phone_number): + def get_user_by_phone_number(self, phone_number: str) -> _user_mgt.UserRecord: """Gets the user data corresponding to the specified phone number. Args: @@ -211,7 +228,7 @@ def get_user_by_phone_number(self, phone_number): response = self._user_manager.get_user(phone_number=phone_number) return _user_mgt.UserRecord(response) - def get_users(self, identifiers): + def get_users(self, identifiers: 'Sequence[_user_identifier.UserIdentifier]') -> _user_mgt.GetUsersResult: """Gets the user data corresponding to the specified identifiers. There are no ordering guarantees; in particular, the nth entry in the @@ -236,7 +253,7 @@ def get_users(self, identifiers): """ response = self._user_manager.get_users(identifiers=identifiers) - def _matches(identifier, user_record): + def _matches(identifier: _user_identifier.UserIdentifier, user_record: _user_mgt.UserRecord) -> bool: if isinstance(identifier, _user_identifier.UidIdentifier): return identifier.uid == user_record.uid if isinstance(identifier, _user_identifier.EmailIdentifier): @@ -252,7 +269,10 @@ def _matches(identifier, user_record): ), False) raise TypeError(f"Unexpected type: {type(identifier)}") - def _is_user_found(identifier, user_records): + def _is_user_found( + identifier: _user_identifier.UserIdentifier, + user_records: list[_user_mgt.UserRecord], + ) -> bool: return any(_matches(identifier, user_record) for user_record in user_records) users = [_user_mgt.UserRecord(user) for user in response] @@ -261,7 +281,11 @@ def _is_user_found(identifier, user_records): return _user_mgt.GetUsersResult(users=users, not_found=not_found) - def list_users(self, page_token=None, max_results=_user_mgt.MAX_LIST_USERS_RESULTS): + def list_users( + self, + page_token: Optional[str] = None, + max_results: int = _user_mgt.MAX_LIST_USERS_RESULTS, + ) -> _user_mgt.ListUsersPage: """Retrieves a page of user accounts from a Firebase project. The ``page_token`` argument governs the starting point of the page. The ``max_results`` @@ -283,11 +307,23 @@ def list_users(self, page_token=None, max_results=_user_mgt.MAX_LIST_USERS_RESUL ValueError: If max_results or page_token are invalid. FirebaseError: If an error occurs while retrieving the user accounts. """ - def download(page_token, max_results): + def download(page_token: Optional[str], max_results: int) -> dict[str, Any]: return self._user_manager.list_users(page_token, max_results) return _user_mgt.ListUsersPage(download, page_token, max_results) - def create_user(self, **kwargs): # pylint: disable=differing-param-doc + def create_user( + self, + *, + uid: Optional[str] = None, + display_name: Optional[str] = None, + email: Optional[str] = None, + phone_number: Optional[str] = None, + photo_url: Optional[str] = None, + password: Optional[str] = None, + disabled: Optional[bool] = None, + email_verified: Optional[bool] = None, + **kwargs: Any, + ) -> _user_mgt.UserRecord: """Creates a new user account with the specified properties. Args: @@ -311,10 +347,27 @@ def create_user(self, **kwargs): # pylint: disable=differing-param-doc ValueError: If the specified user properties are invalid. FirebaseError: If an error occurs while creating the user account. """ - uid = self._user_manager.create_user(**kwargs) + uid = self._user_manager.create_user(uid=uid, display_name=display_name, email=email, + phone_number=phone_number, photo_url=photo_url, password=password, disabled=disabled, + email_verified=email_verified, **kwargs) return self.get_user(uid=uid) - def update_user(self, uid, **kwargs): # pylint: disable=differing-param-doc + def update_user( + self, + uid: str, + *, + display_name: Optional[str] = None, + email: Optional[str] = None, + phone_number: Optional[str] = None, + photo_url: Optional[str] = None, + password: Optional[str] = None, + disabled: Optional[bool] = None, + email_verified: Optional[bool] = None, + valid_since: Optional['ConvertibleToInt'] = None, + custom_claims: Optional[Union[dict[str, Any], str]] = None, + providers_to_delete: Optional[list[str]] = None, + **kwargs: Any, + ) -> _user_mgt.UserRecord: """Updates an existing user account with the specified properties. Args: @@ -349,10 +402,16 @@ def update_user(self, uid, **kwargs): # pylint: disable=differing-param-doc ValueError: If the specified user ID or properties are invalid. FirebaseError: If an error occurs while updating the user account. """ - self._user_manager.update_user(uid, **kwargs) + self._user_manager.update_user(uid, display_name=display_name, email=email, phone_number=phone_number, + photo_url=photo_url, password=password, disabled=disabled, email_verified=email_verified, + valid_since=valid_since, custom_claims=custom_claims, providers_to_delete=providers_to_delete, **kwargs) return self.get_user(uid=uid) - def set_custom_user_claims(self, uid, custom_claims): + def set_custom_user_claims( + self, + uid: str, + custom_claims: Optional[Union[dict[str, Any], str]], + ) -> None: """Sets additional claims on an existing user account. Custom claims set via this function can be used to define user roles and privilege levels. @@ -375,7 +434,7 @@ def set_custom_user_claims(self, uid, custom_claims): custom_claims = _user_mgt.DELETE_ATTRIBUTE self._user_manager.update_user(uid, custom_claims=custom_claims) - def delete_user(self, uid): + def delete_user(self, uid: str) -> None: """Deletes the user identified by the specified user ID. Args: @@ -387,7 +446,7 @@ def delete_user(self, uid): """ self._user_manager.delete_user(uid) - def delete_users(self, uids): + def delete_users(self, uids: Sequence[str]) -> _user_mgt.DeleteUsersResult: """Deletes the users specified by the given identifiers. Deleting a non-existing user does not generate an error (the method is @@ -414,7 +473,11 @@ def delete_users(self, uids): result = self._user_manager.delete_users(uids, force_delete=True) return _user_mgt.DeleteUsersResult(result, len(uids)) - def import_users(self, users, hash_alg=None): + def import_users( + self, + users: Sequence[_user_import.ImportUserRecord], + hash_alg: Optional[_user_import.UserImportHash] = None, + ) -> _user_import.UserImportResult: """Imports the specified list of users into Firebase Auth. At most 1000 users can be imported at a time. This operation is optimized for bulk imports @@ -438,7 +501,11 @@ def import_users(self, users, hash_alg=None): result = self._user_manager.import_users(users, hash_alg) return _user_import.UserImportResult(result, len(users)) - def generate_password_reset_link(self, email, action_code_settings=None): + def generate_password_reset_link( + self, + email: Optional[str], + action_code_settings: Optional[_user_mgt.ActionCodeSettings] = None, + ) -> str: """Generates the out-of-band email action link for password reset flows for the specified email address. @@ -459,7 +526,11 @@ def generate_password_reset_link(self, email, action_code_settings=None): return self._user_manager.generate_email_action_link( 'PASSWORD_RESET', email, action_code_settings=action_code_settings) - def generate_email_verification_link(self, email, action_code_settings=None): + def generate_email_verification_link( + self, + email: Optional[str], + action_code_settings: Optional[_user_mgt.ActionCodeSettings] = None, + ) -> str: """Generates the out-of-band email action link for email verification flows for the specified email address. @@ -480,7 +551,11 @@ def generate_email_verification_link(self, email, action_code_settings=None): return self._user_manager.generate_email_action_link( 'VERIFY_EMAIL', email, action_code_settings=action_code_settings) - def generate_sign_in_with_email_link(self, email, action_code_settings): + def generate_sign_in_with_email_link( + self, + email: Optional[str], + action_code_settings: Optional[_user_mgt.ActionCodeSettings], + ) -> str: """Generates the out-of-band email action link for email link sign-in flows, using the action code settings provided. @@ -500,7 +575,7 @@ def generate_sign_in_with_email_link(self, email, action_code_settings): return self._user_manager.generate_email_action_link( 'EMAIL_SIGNIN', email, action_code_settings=action_code_settings) - def get_oidc_provider_config(self, provider_id): + def get_oidc_provider_config(self, provider_id: str) -> _auth_providers.OIDCProviderConfig: """Returns the ``OIDCProviderConfig`` with the given ID. Args: @@ -517,8 +592,16 @@ def get_oidc_provider_config(self, provider_id): return self._provider_manager.get_oidc_provider_config(provider_id) def create_oidc_provider_config( - self, provider_id, client_id, issuer, display_name=None, enabled=None, - client_secret=None, id_token_response_type=None, code_response_type=None): + self, + provider_id: str, + client_id: str, + issuer: str, + display_name: Optional[str] = None, + enabled: Optional[bool] = None, + client_secret: Optional[str] = None, + id_token_response_type: Optional[bool] = None, + code_response_type: Optional[bool] = None, + ) -> _auth_providers.OIDCProviderConfig: """Creates a new OIDC provider config from the given parameters. OIDC provider support requires Google Cloud's Identity Platform (GCIP). To learn more about @@ -556,8 +639,16 @@ def create_oidc_provider_config( id_token_response_type=id_token_response_type, code_response_type=code_response_type) def update_oidc_provider_config( - self, provider_id, client_id=None, issuer=None, display_name=None, enabled=None, - client_secret=None, id_token_response_type=None, code_response_type=None): + self, + provider_id: str, + client_id: Optional[str] = None, + issuer: Optional[str] = None, + display_name: Optional[str] = None, + enabled: Optional[bool] = None, + client_secret: Optional[str] = None, + id_token_response_type: Optional[bool] = None, + code_response_type: Optional[bool] = None, + ) -> _auth_providers.OIDCProviderConfig: """Updates an existing OIDC provider config with the given parameters. Args: @@ -591,7 +682,7 @@ def update_oidc_provider_config( enabled=enabled, client_secret=client_secret, id_token_response_type=id_token_response_type, code_response_type=code_response_type) - def delete_oidc_provider_config(self, provider_id): + def delete_oidc_provider_config(self, provider_id: str) -> None: """Deletes the ``OIDCProviderConfig`` with the given ID. Args: @@ -605,7 +696,10 @@ def delete_oidc_provider_config(self, provider_id): self._provider_manager.delete_oidc_provider_config(provider_id) def list_oidc_provider_configs( - self, page_token=None, max_results=_auth_providers.MAX_LIST_CONFIGS_RESULTS): + self, + page_token: Optional[str] = None, + max_results: int = _auth_providers.MAX_LIST_CONFIGS_RESULTS, + ) -> _auth_providers._ListOIDCProviderConfigsPage: """Retrieves a page of OIDC provider configs from a Firebase project. The ``page_token`` argument governs the starting point of the page. The ``max_results`` @@ -629,7 +723,7 @@ def list_oidc_provider_configs( """ return self._provider_manager.list_oidc_provider_configs(page_token, max_results) - def get_saml_provider_config(self, provider_id): + def get_saml_provider_config(self, provider_id: str) -> _auth_providers.SAMLProviderConfig: """Returns the ``SAMLProviderConfig`` with the given ID. Args: @@ -646,8 +740,16 @@ def get_saml_provider_config(self, provider_id): return self._provider_manager.get_saml_provider_config(provider_id) def create_saml_provider_config( - self, provider_id, idp_entity_id, sso_url, x509_certificates, rp_entity_id, - callback_url, display_name=None, enabled=None): + self, + provider_id: str, + idp_entity_id: str, + sso_url: str, + x509_certificates: list[str], + rp_entity_id: str, + callback_url: str, + display_name: Optional[str] = None, + enabled: Optional[bool] = None, + ) -> _auth_providers.SAMLProviderConfig: """Creates a new SAML provider config from the given parameters. SAML provider support requires Google Cloud's Identity Platform (GCIP). To learn more about @@ -686,8 +788,16 @@ def create_saml_provider_config( callback_url=callback_url, display_name=display_name, enabled=enabled) def update_saml_provider_config( - self, provider_id, idp_entity_id=None, sso_url=None, x509_certificates=None, - rp_entity_id=None, callback_url=None, display_name=None, enabled=None): + self, + provider_id: str, + idp_entity_id: Optional[str] = None, + sso_url: Optional[str] = None, + x509_certificates: Optional[list[str]] = None, + rp_entity_id: Optional[str] = None, + callback_url: Optional[str] = None, + display_name: Optional[str] = None, + enabled: Optional[bool] = None, + ) -> _auth_providers.SAMLProviderConfig: """Updates an existing SAML provider config with the given parameters. Args: @@ -715,7 +825,7 @@ def update_saml_provider_config( x509_certificates=x509_certificates, rp_entity_id=rp_entity_id, callback_url=callback_url, display_name=display_name, enabled=enabled) - def delete_saml_provider_config(self, provider_id): + def delete_saml_provider_config(self, provider_id: str) -> None: """Deletes the ``SAMLProviderConfig`` with the given ID. Args: @@ -729,7 +839,10 @@ def delete_saml_provider_config(self, provider_id): self._provider_manager.delete_saml_provider_config(provider_id) def list_saml_provider_configs( - self, page_token=None, max_results=_auth_providers.MAX_LIST_CONFIGS_RESULTS): + self, + page_token: Optional[str] = None, + max_results: int = _auth_providers.MAX_LIST_CONFIGS_RESULTS, + ) -> _auth_providers._ListSAMLProviderConfigsPage: """Retrieves a page of SAML provider configs from a Firebase project. The ``page_token`` argument governs the starting point of the page. The ``max_results`` @@ -753,9 +866,14 @@ def list_saml_provider_configs( """ return self._provider_manager.list_saml_provider_configs(page_token, max_results) - def _check_jwt_revoked_or_disabled(self, verified_claims, exc_type, label): - user = self.get_user(verified_claims.get('uid')) + def _check_jwt_revoked_or_disabled( + self, + verified_claims: dict[str, Any], + exc_type: Callable[[str], exceptions.FirebaseError], + label: str, + ) -> None: + user = self.get_user(verified_claims['uid']) if user.disabled: raise _auth_utils.UserDisabledError('The user record is disabled.') - if verified_claims.get('iat') * 1000 < user.tokens_valid_after_timestamp: + if verified_claims['iat'] * 1000 < user.tokens_valid_after_timestamp: raise exc_type(f'The Firebase {label} has been revoked.') diff --git a/firebase_admin/_auth_providers.py b/firebase_admin/_auth_providers.py index cc7949526..9c1653e53 100644 --- a/firebase_admin/_auth_providers.py +++ b/firebase_admin/_auth_providers.py @@ -14,13 +14,28 @@ """Firebase auth providers management sub module.""" +from collections.abc import Callable +from typing import Any, Generic, Optional, cast +from typing_extensions import Self, TypeVar from urllib import parse import requests from firebase_admin import _auth_utils +from firebase_admin import _http_client from firebase_admin import _user_mgt +__all__ = ( + 'MAX_LIST_CONFIGS_RESULTS', + 'ListProviderConfigsPage', + 'OIDCProviderConfig', + 'ProviderConfig', + 'ProviderConfigClient', + 'SAMLProviderConfig', +) + +_ProviderConfigT = TypeVar('_ProviderConfigT', bound='ProviderConfig', default='ProviderConfig') + MAX_LIST_CONFIGS_RESULTS = 100 @@ -28,20 +43,20 @@ class ProviderConfig: """Parent type for all authentication provider config types.""" - def __init__(self, data): + def __init__(self, data: dict[str, Any]) -> None: self._data = data @property - def provider_id(self): - name = self._data['name'] + def provider_id(self) -> str: + name = cast(str, self._data['name']) return name.split('/')[-1] @property - def display_name(self): + def display_name(self) -> Optional[str]: return self._data.get('displayName') @property - def enabled(self): + def enabled(self) -> bool: return self._data.get('enabled', False) @@ -80,55 +95,60 @@ class SAMLProviderConfig(ProviderConfig): @property def idp_entity_id(self): - return self._data.get('idpConfig', {})['idpEntityId'] + return self._data['idpConfig']['idpEntityId'] @property def sso_url(self): - return self._data.get('idpConfig', {})['ssoUrl'] + return self._data['idpConfig']['ssoUrl'] @property def x509_certificates(self): - certs = self._data.get('idpConfig', {})['idpCertificates'] + certs = self._data['idpConfig']['idpCertificates'] return [c['x509Certificate'] for c in certs] @property def callback_url(self): - return self._data.get('spConfig', {})['callbackUri'] + return self._data['spConfig']['callbackUri'] @property def rp_entity_id(self): - return self._data.get('spConfig', {})['spEntityId'] + return self._data['spConfig']['spEntityId'] -class ListProviderConfigsPage: - """Represents a page of AuthProviderConfig instances retrieved from a Firebase project. +class ListProviderConfigsPage(Generic[_ProviderConfigT]): + """Represents a page of ProviderConfig instances retrieved from a Firebase project. Provides methods for traversing the provider configs included in this page, as well as retrieving subsequent pages. The iterator returned by ``iterate_all()`` can be used to iterate through all provider configs in the Firebase project starting from this page. """ - def __init__(self, download, page_token, max_results): + def __init__( + self, + download: Callable[[Optional[str], int], dict[str, Any]], + page_token: Optional[str], + max_results: int, + ) -> None: self._download = download self._max_results = max_results self._current = download(page_token, max_results) @property - def provider_configs(self): - """A list of ``AuthProviderConfig`` instances available in this page.""" + def provider_configs(self) -> list[_ProviderConfigT]: + """A list of ``ProviderConfig`` instances available in this page.""" raise NotImplementedError @property - def next_page_token(self): + def next_page_token(self) -> str: """Page token string for the next page (empty string indicates no more pages).""" return self._current.get('nextPageToken', '') @property - def has_next_page(self): + def has_next_page(self) -> bool: """A boolean indicating whether more pages are available.""" return bool(self.next_page_token) - def get_next_page(self): + def get_next_page(self) -> Optional[Self]: """Retrieves the next page of provider configs, if available. Returns: @@ -139,7 +159,7 @@ def get_next_page(self): return self.__class__(self._download, self.next_page_token, self._max_results) return None - def iterate_all(self): + def iterate_all(self) -> '_ProviderConfigIterator[_ProviderConfigT]': """Retrieves an iterator for provider configs. Returned iterator will iterate through all the provider configs in the Firebase project @@ -147,30 +167,39 @@ def iterate_all(self): in memory at a time. Returns: - iterator: An iterator of AuthProviderConfig instances. + iterator: An iterator of ProviderConfig instances. """ return _ProviderConfigIterator(self) -class _ListOIDCProviderConfigsPage(ListProviderConfigsPage): - +class _ListOIDCProviderConfigsPage(ListProviderConfigsPage[OIDCProviderConfig]): @property - def provider_configs(self): - return [OIDCProviderConfig(data) for data in self._current.get('oauthIdpConfigs', [])] + def provider_configs(self) -> list[OIDCProviderConfig]: + return [ + OIDCProviderConfig(data) + for data in cast( + list[dict[str, Any]], + self._current.get('oauthIdpConfigs', []), + ) + ] -class _ListSAMLProviderConfigsPage(ListProviderConfigsPage): - +class _ListSAMLProviderConfigsPage(ListProviderConfigsPage[SAMLProviderConfig]): @property - def provider_configs(self): - return [SAMLProviderConfig(data) for data in self._current.get('inboundSamlConfigs', [])] - + def provider_configs(self) -> list[SAMLProviderConfig]: + return [ + SAMLProviderConfig(data) + for data in cast( + list[dict[str, Any]], + self._current.get('inboundSamlConfigs', []), + ) + ] -class _ProviderConfigIterator(_auth_utils.PageIterator): +class _ProviderConfigIterator(_auth_utils.PageIterator[ListProviderConfigsPage[_ProviderConfigT]]): @property - def items(self): - return self._current_page.provider_configs + def items(self) -> list[_ProviderConfigT]: + return self._current_page.provider_configs if self._current_page else [] class ProviderConfigClient: @@ -178,24 +207,38 @@ class ProviderConfigClient: PROVIDER_CONFIG_URL = 'https://identitytoolkit.googleapis.com/v2' - def __init__(self, http_client, project_id, tenant_id=None, url_override=None): + def __init__( + self, + http_client: _http_client.HttpClient[dict[str, Any]], + project_id: str, + tenant_id: Optional[str] = None, + url_override: Optional[str] = None, + ) -> None: self.http_client = http_client url_prefix = url_override or self.PROVIDER_CONFIG_URL self.base_url = f'{url_prefix}/projects/{project_id}' if tenant_id: self.base_url += f'/tenants/{tenant_id}' - def get_oidc_provider_config(self, provider_id): + def get_oidc_provider_config(self, provider_id: str) -> OIDCProviderConfig: _validate_oidc_provider_id(provider_id) body = self._make_request('get', f'/oauthIdpConfigs/{provider_id}') return OIDCProviderConfig(body) def create_oidc_provider_config( - self, provider_id, client_id, issuer, display_name=None, enabled=None, - client_secret=None, id_token_response_type=None, code_response_type=None): + self, + provider_id: str, + client_id: str, + issuer: str, + display_name: Optional[str] = None, + enabled: Optional[bool] = None, + client_secret: Optional[str] = None, + id_token_response_type: Optional[bool] = None, + code_response_type: Optional[bool] = None, + ) -> OIDCProviderConfig: """Creates a new OIDC provider config from the given parameters.""" _validate_oidc_provider_id(provider_id) - req = { + req: dict[str, Any] = { 'clientId': _validate_non_empty_string(client_id, 'client_id'), 'issuer': _validate_url(issuer, 'issuer'), } @@ -204,7 +247,7 @@ def create_oidc_provider_config( if enabled is not None: req['enabled'] = _auth_utils.validate_boolean(enabled, 'enabled') - response_type = {} + response_type: dict[str, Any] = {} if id_token_response_type is False and code_response_type is False: raise ValueError('At least one response type must be returned.') if id_token_response_type is not None: @@ -223,12 +266,19 @@ def create_oidc_provider_config( return OIDCProviderConfig(body) def update_oidc_provider_config( - self, provider_id, client_id=None, issuer=None, display_name=None, - enabled=None, client_secret=None, id_token_response_type=None, - code_response_type=None): + self, + provider_id: str, + client_id: Optional[str] = None, + issuer: Optional[str] = None, + display_name: Optional[str] = None, + enabled: Optional[bool] = None, + client_secret: Optional[str] = None, + id_token_response_type: Optional[bool] = None, + code_response_type: Optional[bool] = None, + ) -> OIDCProviderConfig: """Updates an existing OIDC provider config with the given parameters.""" _validate_oidc_provider_id(provider_id) - req = {} + req: dict[str, Any] = {} if display_name is not None: if display_name == _user_mgt.DELETE_ATTRIBUTE: req['displayName'] = None @@ -264,28 +314,44 @@ def update_oidc_provider_config( body = self._make_request('patch', url, json=req, params=params) return OIDCProviderConfig(body) - def delete_oidc_provider_config(self, provider_id): + def delete_oidc_provider_config(self, provider_id: str) -> None: _validate_oidc_provider_id(provider_id) self._make_request('delete', f'/oauthIdpConfigs/{provider_id}') - def list_oidc_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + def list_oidc_provider_configs( + self, + page_token: Optional[str] = None, + max_results: int = MAX_LIST_CONFIGS_RESULTS, + ) -> _ListOIDCProviderConfigsPage: return _ListOIDCProviderConfigsPage( self._fetch_oidc_provider_configs, page_token, max_results) - def _fetch_oidc_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + def _fetch_oidc_provider_configs( + self, + page_token: Optional[str] = None, + max_results: int = MAX_LIST_CONFIGS_RESULTS, + ) -> dict[str, Any]: return self._fetch_provider_configs('/oauthIdpConfigs', page_token, max_results) - def get_saml_provider_config(self, provider_id): + def get_saml_provider_config(self, provider_id: str) -> SAMLProviderConfig: _validate_saml_provider_id(provider_id) body = self._make_request('get', f'/inboundSamlConfigs/{provider_id}') return SAMLProviderConfig(body) def create_saml_provider_config( - self, provider_id, idp_entity_id, sso_url, x509_certificates, - rp_entity_id, callback_url, display_name=None, enabled=None): + self, + provider_id: str, + idp_entity_id: str, + sso_url: str, + x509_certificates: list[str], + rp_entity_id: str, + callback_url: str, + display_name: Optional[str] = None, + enabled: Optional[bool] = None, + ) -> SAMLProviderConfig: """Creates a new SAML provider config from the given parameters.""" _validate_saml_provider_id(provider_id) - req = { + req: dict[str, Any] = { 'idpConfig': { 'idpEntityId': _validate_non_empty_string(idp_entity_id, 'idp_entity_id'), 'ssoUrl': _validate_url(sso_url, 'sso_url'), @@ -306,11 +372,19 @@ def create_saml_provider_config( return SAMLProviderConfig(body) def update_saml_provider_config( - self, provider_id, idp_entity_id=None, sso_url=None, x509_certificates=None, - rp_entity_id=None, callback_url=None, display_name=None, enabled=None): + self, + provider_id: str, + idp_entity_id: Optional[str] = None, + sso_url: Optional[str] = None, + x509_certificates: Optional[list[str]]=None, + rp_entity_id: Optional[str] = None, + callback_url: Optional[str] = None, + display_name: Optional[str] = None, + enabled: Optional[bool] = None, + ) -> SAMLProviderConfig: """Updates an existing SAML provider config with the given parameters.""" _validate_saml_provider_id(provider_id) - idp_config = {} + idp_config: dict[str, Any] = {} if idp_entity_id is not None: idp_config['idpEntityId'] = _validate_non_empty_string(idp_entity_id, 'idp_entity_id') if sso_url is not None: @@ -318,13 +392,13 @@ def update_saml_provider_config( if x509_certificates is not None: idp_config['idpCertificates'] = _validate_x509_certificates(x509_certificates) - sp_config = {} + sp_config: dict[str, Any] = {} if rp_entity_id is not None: sp_config['spEntityId'] = _validate_non_empty_string(rp_entity_id, 'rp_entity_id') if callback_url is not None: sp_config['callbackUri'] = _validate_url(callback_url, 'callback_url') - req = {} + req: dict[str, Any] = {} if display_name is not None: if display_name == _user_mgt.DELETE_ATTRIBUTE: req['displayName'] = None @@ -346,18 +420,31 @@ def update_saml_provider_config( body = self._make_request('patch', url, json=req, params=params) return SAMLProviderConfig(body) - def delete_saml_provider_config(self, provider_id): + def delete_saml_provider_config(self, provider_id: str) -> None: _validate_saml_provider_id(provider_id) self._make_request('delete', f'/inboundSamlConfigs/{provider_id}') - def list_saml_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + def list_saml_provider_configs( + self, + page_token: Optional[str] = None, + max_results: int = MAX_LIST_CONFIGS_RESULTS, + ) -> _ListSAMLProviderConfigsPage: return _ListSAMLProviderConfigsPage( self._fetch_saml_provider_configs, page_token, max_results) - def _fetch_saml_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + def _fetch_saml_provider_configs( + self, + page_token: Optional[str] = None, + max_results: int = MAX_LIST_CONFIGS_RESULTS, + ) -> dict[str, Any]: return self._fetch_provider_configs('/inboundSamlConfigs', page_token, max_results) - def _fetch_provider_configs(self, path, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + def _fetch_provider_configs( + self, + path: str, + page_token: Optional[str] = None, + max_results: int = MAX_LIST_CONFIGS_RESULTS, + ) -> dict[str, Any]: """Fetches a page of auth provider configs""" if page_token is not None: if not isinstance(page_token, str) or not page_token: @@ -374,7 +461,7 @@ def _fetch_provider_configs(self, path, page_token=None, max_results=MAX_LIST_CO params += f'&pageToken={page_token}' return self._make_request('get', path, params=params) - def _make_request(self, method, path, **kwargs): + def _make_request(self, method: str, path: str, **kwargs: Any) -> dict[str, Any]: url = f'{self.base_url}{path}' try: return self.http_client.body(method, url, **kwargs) @@ -382,7 +469,7 @@ def _make_request(self, method, path, **kwargs): raise _auth_utils.handle_auth_backend_error(error) -def _validate_oidc_provider_id(provider_id): +def _validate_oidc_provider_id(provider_id: Any) -> str: if not isinstance(provider_id, str): raise ValueError( f'Invalid OIDC provider ID: {provider_id}. Provider ID must be a non-empty string.') @@ -391,7 +478,7 @@ def _validate_oidc_provider_id(provider_id): return provider_id -def _validate_saml_provider_id(provider_id): +def _validate_saml_provider_id(provider_id: Any) -> str: if not isinstance(provider_id, str): raise ValueError( f'Invalid SAML provider ID: {provider_id}. Provider ID must be a non-empty string.') @@ -400,7 +487,7 @@ def _validate_saml_provider_id(provider_id): return provider_id -def _validate_non_empty_string(value, label): +def _validate_non_empty_string(value: Any, label: str) -> str: """Validates that the given value is a non-empty string.""" if not isinstance(value, str): raise ValueError(f'Invalid type for {label}: {value}.') @@ -409,7 +496,7 @@ def _validate_non_empty_string(value, label): return value -def _validate_url(url, label): +def _validate_url(url: Any, label: str) -> str: """Validates that the given value is a well-formed URL string.""" if not isinstance(url, str) or not url: raise ValueError( @@ -423,9 +510,10 @@ def _validate_url(url, label): raise ValueError(f'Malformed {label}: "{url}".') from exception -def _validate_x509_certificates(x509_certificates): +def _validate_x509_certificates(x509_certificates: Any) -> list[dict[str, str]]: if not isinstance(x509_certificates, list) or not x509_certificates: raise ValueError('x509_certificates must be a non-empty list.') + x509_certificates = cast(list[Any], x509_certificates) if not all(isinstance(cert, str) and cert for cert in x509_certificates): raise ValueError('x509_certificates must only contain non-empty strings.') return [{'x509Certificate': cert} for cert in x509_certificates] diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index 60d411822..e702ff8f2 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -17,22 +17,97 @@ import json import os import re +from collections.abc import Callable, Iterator, Sequence +from typing import ( + Any, + Generic, + Literal, + Optional, + Protocol, + Union, + cast, + overload, +) from urllib import parse +import httpx +import requests +from typing_extensions import Self, TypeVar + from firebase_admin import exceptions from firebase_admin import _utils +__all__ = ( + 'EMULATOR_HOST_ENV_VAR', + 'MAX_CLAIMS_PAYLOAD_SIZE', + 'RESERVED_CLAIMS', + 'VALID_EMAIL_ACTION_TYPES', + 'ConfigurationNotFoundError', + 'EmailAlreadyExistsError', + 'EmailNotFoundError', + 'InsufficientPermissionError', + 'InvalidDynamicLinkDomainError', + 'InvalidIdTokenError', + 'PhoneNumberAlreadyExistsError', + 'ResetPasswordExceedLimitError', + 'TenantNotFoundError', + 'TenantIdMismatchError', + 'TooManyAttemptsTryLaterError', + 'UidAlreadyExistsError', + 'UnexpectedResponseError', + 'UserDisabledError', + 'UserNotFoundError', + 'PageIterator', + 'build_update_mask', + 'get_emulator_host', + 'handle_auth_backend_error', + 'is_emulated', + 'validate_action_type', + 'validate_boolean', + 'validate_bytes', + 'validate_custom_claims', + 'validate_display_name', + 'validate_email', + 'validate_int', + 'validate_password', + 'validate_phone', + 'validate_photo_url', + 'validate_provider_id', + 'validate_provider_ids', + 'validate_provider_uid', + 'validate_string', + 'validate_timestamp', + 'validate_uid', +) + +_PageT = TypeVar('_PageT', bound='_Page') +_ErrorT = TypeVar( + '_ErrorT', bound=exceptions.FirebaseError, default=exceptions.FirebaseError +) + +_EmailActionType = Literal[ + 'VERIFY_EMAIL', + 'EMAIL_SIGNIN', + 'PASSWORD_RESET', +] EMULATOR_HOST_ENV_VAR = 'FIREBASE_AUTH_EMULATOR_HOST' MAX_CLAIMS_PAYLOAD_SIZE = 1000 -RESERVED_CLAIMS = set([ +RESERVED_CLAIMS = { 'acr', 'amr', 'at_hash', 'aud', 'auth_time', 'azp', 'cnf', 'c_hash', 'exp', 'iat', 'iss', 'jti', 'nbf', 'nonce', 'sub', 'firebase', -]) -VALID_EMAIL_ACTION_TYPES = set(['VERIFY_EMAIL', 'EMAIL_SIGNIN', 'PASSWORD_RESET']) +} +VALID_EMAIL_ACTION_TYPES = {'VERIFY_EMAIL', 'EMAIL_SIGNIN', 'PASSWORD_RESET'} + + +class _Page(Protocol): + @property + def has_next_page(self) -> bool: ... + + def get_next_page(self) -> Optional[Self]: ... -class PageIterator: +class PageIterator(Generic[_PageT]): """An iterator that allows iterating over a sequence of items, one at a time. This implementation loads a page of items into memory, and iterates on them. When the whole @@ -40,21 +115,21 @@ class PageIterator: of entries in memory. """ - def __init__(self, current_page): + def __init__(self, current_page: _PageT) -> None: if not current_page: raise ValueError('Current page must not be None.') - self._current_page = current_page - self._iter = None + self._current_page: Optional[_PageT] = current_page + self._iter: Optional[Iterator[_PageT]] = None - def __next__(self): + def __next__(self) -> _PageT: if self._iter is None: self._iter = iter(self.items) try: return next(self._iter) except StopIteration: - if self._current_page.has_next_page: + if self._current_page and self._current_page.has_next_page: self._current_page = self._current_page.get_next_page() self._iter = iter(self.items) @@ -62,15 +137,15 @@ def __next__(self): raise - def __iter__(self): + def __iter__(self) -> Iterator[_PageT]: return self @property - def items(self): + def items(self) -> Sequence[Any]: raise NotImplementedError -def get_emulator_host(): +def get_emulator_host() -> str: emulator_host = os.getenv(EMULATOR_HOST_ENV_VAR, '') if emulator_host and '//' in emulator_host: raise ValueError( @@ -79,11 +154,15 @@ def get_emulator_host(): return emulator_host -def is_emulated(): +def is_emulated() -> bool: return get_emulator_host() != '' -def validate_uid(uid, required=False): +@overload +def validate_uid(uid: Optional[Any], required: Literal[True]) -> str: ... +@overload +def validate_uid(uid: Optional[Any], required: bool = False) -> Optional[str]: ... +def validate_uid(uid: Optional[Any], required: bool = False) -> Optional[str]: if uid is None and not required: return None if not isinstance(uid, str) or not uid or len(uid) > 128: @@ -92,7 +171,12 @@ def validate_uid(uid, required=False): 'characters.') return uid -def validate_email(email, required=False): + +@overload +def validate_email(email: Optional[Any], required: Literal[True]) -> str: ... +@overload +def validate_email(email: Optional[Any], required: bool = False) -> Optional[str]: ... +def validate_email(email: Optional[Any], required: bool = False) -> Optional[str]: if email is None and not required: return None if not isinstance(email, str) or not email: @@ -103,7 +187,12 @@ def validate_email(email, required=False): raise ValueError(f'Malformed email address string: "{email}".') return email -def validate_phone(phone, required=False): + +@overload +def validate_phone(phone: Optional[Any], required: Literal[True]) -> str: ... +@overload +def validate_phone(phone: Optional[Any], required: bool = False) -> Optional[str]: ... +def validate_phone(phone: Optional[Any], required: bool = False) -> Optional[str]: """Validates the specified phone number. Phone number vlidation is very lax here. Backend will enforce E.164 spec compliance, and @@ -121,7 +210,14 @@ def validate_phone(phone, required=False): 'compliant identifier.') return phone -def validate_password(password, required=False): + +@overload +def validate_password(password: Optional[Any], required: Literal[True]) -> str: ... +@overload +def validate_password( + password: Optional[Any], required: bool = False +) -> Optional[str]: ... +def validate_password(password: Optional[Any], required: bool = False) -> Optional[str]: if password is None and not required: return None if not isinstance(password, str) or len(password) < 6: @@ -129,14 +225,36 @@ def validate_password(password, required=False): 'Invalid password string. Password must be a string at least 6 characters long.') return password -def validate_bytes(value, label, required=False): + +@overload +def validate_bytes( + value: Optional[Any], label: Any, required: Literal[True] +) -> bytes: ... +@overload +def validate_bytes( + value: Optional[Any], label: Any, required: bool = False +) -> Optional[bytes]: ... +def validate_bytes( + value: Optional[Any], label: Any, required: bool = False +) -> Optional[bytes]: if value is None and not required: return None if not isinstance(value, bytes) or not value: raise ValueError(f'{label} must be a non-empty byte sequence.') return value -def validate_display_name(display_name, required=False): + +@overload +def validate_display_name( + display_name: Optional[Any], required: Literal[True] +) -> str: ... +@overload +def validate_display_name( + display_name: Optional[Any], required: bool = False +) -> Optional[str]: ... +def validate_display_name( + display_name: Optional[Any], required: bool = False +) -> Optional[str]: if display_name is None and not required: return None if not isinstance(display_name, str) or not display_name: @@ -145,7 +263,18 @@ def validate_display_name(display_name, required=False): 'string.') return display_name -def validate_provider_id(provider_id, required=True): + +@overload +def validate_provider_id( + provider_id: Optional[Any], required: Literal[True] +) -> str: ... +@overload +def validate_provider_id( + provider_id: Optional[Any], required: bool = True +) -> Optional[str]: ... +def validate_provider_id( + provider_id: Optional[Any], required: bool = True +) -> Optional[str]: if provider_id is None and not required: return None if not isinstance(provider_id, str) or not provider_id: @@ -153,7 +282,18 @@ def validate_provider_id(provider_id, required=True): f'Invalid provider ID: "{provider_id}". Provider ID must be a non-empty string.') return provider_id -def validate_provider_uid(provider_uid, required=True): + +@overload +def validate_provider_uid( + provider_uid: Optional[Any], required: Literal[True] = True +) -> str: ... +@overload +def validate_provider_uid( + provider_uid: Optional[Any], required: bool = True +) -> Optional[str]: ... +def validate_provider_uid( + provider_uid: Optional[Any], required: bool = True +) -> Optional[str]: if provider_uid is None and not required: return None if not isinstance(provider_uid, str) or not provider_uid: @@ -161,7 +301,16 @@ def validate_provider_uid(provider_uid, required=True): f'Invalid provider UID: "{provider_uid}". Provider UID must be a non-empty string.') return provider_uid -def validate_photo_url(photo_url, required=False): + +@overload +def validate_photo_url(photo_url: Optional[Any], required: Literal[True]) -> str: ... +@overload +def validate_photo_url( + photo_url: Optional[Any], required: bool = False +) -> Optional[str]: ... +def validate_photo_url( + photo_url: Optional[Any], required: bool = False +) -> Optional[str]: """Parses and validates the given URL string.""" if photo_url is None and not required: return None @@ -176,14 +325,31 @@ def validate_photo_url(photo_url, required=False): except Exception as err: raise ValueError(f'Malformed photo URL: "{photo_url}".') from err -def validate_timestamp(timestamp, label, required=False): + +@overload +def validate_timestamp( + timestamp: Optional[Any], + label: Any, + required: Literal[True], +) -> int: ... +@overload +def validate_timestamp( + timestamp: Optional[Any], + label: Any, + required: bool = False, +) -> Optional[int]: ... +def validate_timestamp( + timestamp: Optional[Any], + label: Any, + required: bool = False, +) -> Optional[int]: """Validates the given timestamp value. Timestamps must be positive integers.""" if timestamp is None and not required: return None if isinstance(timestamp, bool): raise ValueError('Boolean value specified as timestamp.') try: - timestamp_int = int(timestamp) + timestamp_int = int(timestamp) # pyright: ignore[reportArgumentType] except TypeError as err: raise ValueError(f'Invalid type for timestamp value: {timestamp}.') from err if timestamp_int != timestamp: @@ -192,7 +358,13 @@ def validate_timestamp(timestamp, label, required=False): raise ValueError(f'{label} timestamp must be a positive interger.') return timestamp_int -def validate_int(value, label, low=None, high=None): + +def validate_int( + value: Any, + label: Any, + low: Optional[int] = None, + high: Optional[int] = None, +) -> int: """Validates that the given value represents an integer. There are several ways to represent an integer in Python (e.g. 2, 2L, 2.0). This method allows @@ -215,19 +387,28 @@ def validate_int(value, label, low=None, high=None): raise ValueError(f'{label} must not be larger than {high}.') return val_int -def validate_string(value, label): + +def validate_string(value: Any, label: Any) -> str: """Validates that the given value is a string.""" if not isinstance(value, str): raise ValueError(f'Invalid type for {label}: {value}.') return value -def validate_boolean(value, label): + +def validate_boolean(value: Any, label: Any) -> bool: """Validates that the given value is a boolean.""" if not isinstance(value, bool): raise ValueError(f'Invalid type for {label}: {value}.') return value -def validate_custom_claims(custom_claims, required=False): + +@overload +def validate_custom_claims(custom_claims: Any, required: Literal[True]) -> str: ... +@overload +def validate_custom_claims( + custom_claims: Any, required: bool = False +) -> Optional[str]: ... +def validate_custom_claims(custom_claims: Any, required: bool = False) -> Optional[str]: """Validates the specified custom claims. Custom claims must be specified as a JSON string. The string must not exceed 1000 @@ -255,14 +436,18 @@ def validate_custom_claims(custom_claims, required=False): f'Claim "{invalid_claims.pop()}" is reserved, and must not be set.') return claims_str -def validate_action_type(action_type): + +def validate_action_type( + action_type: Any, +) -> Literal['VERIFY_EMAIL', 'EMAIL_SIGNIN', 'PASSWORD_RESET']: if action_type not in VALID_EMAIL_ACTION_TYPES: raise ValueError( f'Invalid action type provided action_type: {action_type}. Valid values are ' f'{", ".join(VALID_EMAIL_ACTION_TYPES)}') return action_type -def validate_provider_ids(provider_ids, required=False): + +def validate_provider_ids(provider_ids: Any, required: bool = False) -> list[str]: if not provider_ids: if required: raise ValueError('Invalid provider IDs. Provider ids should be provided') @@ -271,9 +456,10 @@ def validate_provider_ids(provider_ids, required=False): validate_provider_id(provider_id, True) return provider_ids -def build_update_mask(params): + +def build_update_mask(params: dict[str, Any]) -> list[str]: """Creates an update mask list from the given dictionary.""" - mask = [] + mask: list[str] = [] for key, value in params.items(): if isinstance(value, dict): child_mask = build_update_mask(value) @@ -290,8 +476,13 @@ class UidAlreadyExistsError(exceptions.AlreadyExistsError): default_message = 'The user with the provided uid already exists' - def __init__(self, message, cause, http_response): - exceptions.AlreadyExistsError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception], + http_response: Optional[Union[httpx.Response, requests.Response]], + ) -> None: + super().__init__(message, cause, http_response) class EmailAlreadyExistsError(exceptions.AlreadyExistsError): @@ -299,8 +490,13 @@ class EmailAlreadyExistsError(exceptions.AlreadyExistsError): default_message = 'The user with the provided email already exists' - def __init__(self, message, cause, http_response): - exceptions.AlreadyExistsError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception], + http_response: Optional[Union[httpx.Response, requests.Response]], + ) -> None: + super().__init__(message, cause, http_response) class InsufficientPermissionError(exceptions.PermissionDeniedError): @@ -311,8 +507,13 @@ class InsufficientPermissionError(exceptions.PermissionDeniedError): 'https://firebase.google.com/docs/admin/setup for details ' 'on how to initialize the Admin SDK with appropriate permissions') - def __init__(self, message, cause, http_response): - exceptions.PermissionDeniedError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception], + http_response: Optional[Union[httpx.Response, requests.Response]], + ) -> None: + super().__init__(message, cause, http_response) class InvalidDynamicLinkDomainError(exceptions.InvalidArgumentError): @@ -320,8 +521,13 @@ class InvalidDynamicLinkDomainError(exceptions.InvalidArgumentError): default_message = 'Dynamic link domain specified in ActionCodeSettings is not authorized' - def __init__(self, message, cause, http_response): - exceptions.InvalidArgumentError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception], + http_response: Optional[Union[httpx.Response, requests.Response]], + ) -> None: + super().__init__(message, cause, http_response) class InvalidIdTokenError(exceptions.InvalidArgumentError): @@ -329,8 +535,13 @@ class InvalidIdTokenError(exceptions.InvalidArgumentError): default_message = 'The provided ID token is invalid' - def __init__(self, message, cause=None, http_response=None): - exceptions.InvalidArgumentError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message, cause, http_response) class PhoneNumberAlreadyExistsError(exceptions.AlreadyExistsError): @@ -338,15 +549,25 @@ class PhoneNumberAlreadyExistsError(exceptions.AlreadyExistsError): default_message = 'The user with the provided phone number already exists' - def __init__(self, message, cause, http_response): - exceptions.AlreadyExistsError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception], + http_response: Optional[Union[httpx.Response, requests.Response]], + ) -> None: + super().__init__(message, cause, http_response) class UnexpectedResponseError(exceptions.UnknownError): """Backend service responded with an unexpected or malformed response.""" - def __init__(self, message, cause=None, http_response=None): - exceptions.UnknownError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message, cause, http_response) class UserNotFoundError(exceptions.NotFoundError): @@ -354,8 +575,13 @@ class UserNotFoundError(exceptions.NotFoundError): default_message = 'No user record found for the given identifier' - def __init__(self, message, cause=None, http_response=None): - exceptions.NotFoundError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message, cause, http_response) class EmailNotFoundError(exceptions.NotFoundError): @@ -363,8 +589,13 @@ class EmailNotFoundError(exceptions.NotFoundError): default_message = 'No user record found for the given email' - def __init__(self, message, cause=None, http_response=None): - exceptions.NotFoundError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message, cause, http_response) class TenantNotFoundError(exceptions.NotFoundError): @@ -372,15 +603,20 @@ class TenantNotFoundError(exceptions.NotFoundError): default_message = 'No tenant found for the given identifier' - def __init__(self, message, cause=None, http_response=None): - exceptions.NotFoundError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message, cause, http_response) class TenantIdMismatchError(exceptions.InvalidArgumentError): """Missing or invalid tenant ID field in the given JWT.""" - def __init__(self, message): - exceptions.InvalidArgumentError.__init__(self, message) + def __init__(self, message: str) -> None: + super().__init__(message) class ConfigurationNotFoundError(exceptions.NotFoundError): @@ -388,8 +624,13 @@ class ConfigurationNotFoundError(exceptions.NotFoundError): default_message = 'No auth provider found for the given identifier' - def __init__(self, message, cause=None, http_response=None): - exceptions.NotFoundError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message, cause, http_response) class UserDisabledError(exceptions.InvalidArgumentError): @@ -397,22 +638,37 @@ class UserDisabledError(exceptions.InvalidArgumentError): default_message = 'The user record is disabled' - def __init__(self, message, cause=None, http_response=None): - exceptions.InvalidArgumentError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message, cause, http_response) class TooManyAttemptsTryLaterError(exceptions.ResourceExhaustedError): """Rate limited because of too many attempts.""" - def __init__(self, message, cause=None, http_response=None): - exceptions.ResourceExhaustedError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message, cause, http_response) class ResetPasswordExceedLimitError(exceptions.ResourceExhaustedError): """Reset password emails exceeded their limits.""" - def __init__(self, message, cause=None, http_response=None): - exceptions.ResourceExhaustedError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message, cause, http_response) _CODE_TO_EXC_TYPE = { @@ -432,7 +688,7 @@ def __init__(self, message, cause=None, http_response=None): } -def handle_auth_backend_error(error): +def handle_auth_backend_error(error: requests.RequestException) -> exceptions.FirebaseError: """Converts a requests error received from the Firebase Auth service into a FirebaseError.""" if error.response is None: return _utils.handle_requests_error(error) @@ -450,19 +706,26 @@ def handle_auth_backend_error(error): return exc_type(msg, cause=error, http_response=error.response) -def _parse_error_body(response): +def _parse_error_body( + response: requests.Response, +) -> tuple[Optional[str], Optional[str]]: """Parses the given error response to extract Auth error code and message.""" - error_dict = {} + parsed_body = None try: parsed_body = response.json() - if isinstance(parsed_body, dict): - error_dict = parsed_body.get('error', {}) except ValueError: pass + if not isinstance(parsed_body, dict): + return None, None + # Auth error response format: {"error": {"message": "AUTH_ERROR_CODE: Optional text"}} - code = error_dict.get('message') if isinstance(error_dict, dict) else None - custom_message = None + parsed_body = cast(dict[str, Any], parsed_body) + error_dict = parsed_body.get('error', {}) + if not isinstance(error_dict, dict): + return None, None + error_dict = cast(dict[str, str], error_dict) + code, custom_message = error_dict.get('message'), None if code: separator = code.find(':') if separator != -1: @@ -472,8 +735,14 @@ def _parse_error_body(response): return code, custom_message -def _build_error_message(code, exc_type, custom_message): - default_message = exc_type.default_message if ( - exc_type and hasattr(exc_type, 'default_message')) else 'Error while calling Auth service' +def _build_error_message( + code: str, + exc_type: Optional[Callable[ + [str, Optional[Exception], Optional[requests.Response]], + exceptions.FirebaseError + ]], + custom_message: Optional[str], +) -> str: + default_message = getattr(exc_type, 'default_message', 'Error while calling Auth service') ext = f' {custom_message}' if custom_message else '' return f'{default_message} ({code}).{ext}' diff --git a/firebase_admin/_http_client.py b/firebase_admin/_http_client.py index 6d2582291..0ecc69cb5 100644 --- a/firebase_admin/_http_client.py +++ b/firebase_admin/_http_client.py @@ -17,25 +17,48 @@ This module provides utilities for making HTTP calls using the requests library. """ -from __future__ import annotations import logging -from typing import Any, Dict, Generator, Optional, Tuple, Union +from collections.abc import Generator, Iterable +from typing import TYPE_CHECKING, Any, Generic, Optional, Union + import httpx +import google.auth.transport.requests +import google.auth.credentials import requests.adapters -from requests.packages.urllib3.util import retry # pylint: disable=import-error -from google.auth import credentials -from google.auth import transport -from google.auth.transport import requests as google_auth_requests +import requests.structures +import typing_extensions +from firebase_admin import _retry from firebase_admin import _utils -from firebase_admin._retry import HttpxRetry, HttpxRetryTransport + +if TYPE_CHECKING: + from urllib3.util import retry + from _typeshed import SupportsKeysAndGetItem +else: + from requests.packages.urllib3.util import retry # pylint: disable=import-error + +__all__ = ( + 'DEFAULT_HTTPX_RETRY_CONFIG', + 'DEFAULT_RETRY_CONFIG', + 'DEFAULT_TIMEOUT_SECONDS', + 'METRICS_HEADERS', + 'GoogleAuthCredentialFlow', + 'HttpClient', + 'HttpxAsyncClient', + 'JsonHttpClient', +) logger = logging.getLogger(__name__) +_T = typing_extensions.TypeVar('_T', default=Any) + +_ANY_METHOD: dict[str, Any] = {} + if hasattr(retry.Retry.DEFAULT, 'allowed_methods'): - _ANY_METHOD = {'allowed_methods': None} + _ANY_METHOD['allowed_methods'] = None else: - _ANY_METHOD = {'method_whitelist': None} + _ANY_METHOD['method_whitelist'] = None + # Default retry configuration: Retries once on low-level connection and socket read errors. # Retries up to 4 times on HTTP 500 and 503 errors, with exponential backoff. Returns the # last response upon exhausting all retries. @@ -43,17 +66,16 @@ connect=1, read=1, status=4, status_forcelist=[500, 503], raise_on_status=False, backoff_factor=0.5, **_ANY_METHOD) -DEFAULT_HTTPX_RETRY_CONFIG = HttpxRetry( +DEFAULT_HTTPX_RETRY_CONFIG = _retry.HttpxRetry( max_retries=4, status_forcelist=[500, 503], backoff_factor=0.5) - DEFAULT_TIMEOUT_SECONDS = 120 METRICS_HEADERS = { 'x-goog-api-client': _utils.get_metrics_header(), } -class HttpClient: +class HttpClient(Generic[_T]): """Base HTTP client used to make HTTP calls. HttpClient maintains an HTTP session, and handles request authentication and retries if @@ -61,8 +83,17 @@ class HttpClient: """ def __init__( - self, credential=None, session=None, base_url='', headers=None, - retries=DEFAULT_RETRY_CONFIG, timeout=DEFAULT_TIMEOUT_SECONDS): + self, + credential: Optional[google.auth.credentials.Credentials] = None, + session: Optional[requests.Session] = None, + base_url: str = '', + headers: Optional[Union[ + 'SupportsKeysAndGetItem[str, Union[bytes, str]]', + Iterable[tuple[str, Union[bytes, str]]], + ]] = None, + retries: retry.Retry = DEFAULT_RETRY_CONFIG, + timeout: int = DEFAULT_TIMEOUT_SECONDS, + ) -> None: """Creates a new HttpClient instance from the provided arguments. If a credential is provided, initializes a new HTTP session authorized with it. If neither @@ -79,8 +110,9 @@ def __init__( timeout: HTTP timeout in seconds. Defaults to 120 seconds when not specified. Set to None to disable timeouts (optional). """ + self._session: Optional[requests.Session] if credential: - self._session = transport.requests.AuthorizedSession(credential) + self._session = google.auth.transport.requests.AuthorizedSession(credential) elif session: self._session = session else: @@ -95,21 +127,21 @@ def __init__( self._timeout = timeout @property - def session(self): + def session(self) -> Optional[requests.Session]: return self._session @property - def base_url(self): + def base_url(self) -> str: return self._base_url @property - def timeout(self): + def timeout(self) -> int: return self._timeout - def parse_body(self, resp): + def parse_body(self, resp: requests.Response) -> _T: raise NotImplementedError - def request(self, method, url, **kwargs): + def request(self, method: str, url: str, **kwargs: Any) -> requests.Response: """Makes an HTTP call using the Python requests library. This is the sole entry point to the requests library. All other helper methods in this @@ -132,51 +164,58 @@ class call this method to send HTTP requests out. Refer to if 'timeout' not in kwargs: kwargs['timeout'] = self.timeout kwargs.setdefault('headers', {}).update(METRICS_HEADERS) + # possible issue: _session can be None resp = self._session.request(method, self.base_url + url, **kwargs) resp.raise_for_status() return resp - def headers(self, method, url, **kwargs): + def headers(self, method: str, url: str, **kwargs: Any) -> 'requests.structures.CaseInsensitiveDict[str]': resp = self.request(method, url, **kwargs) return resp.headers - def body_and_response(self, method, url, **kwargs): + def body_and_response(self, method: str, url: str, **kwargs: Any) -> tuple[_T, requests.Response]: resp = self.request(method, url, **kwargs) return self.parse_body(resp), resp - def body(self, method, url, **kwargs): + def body(self, method: str, url: str, **kwargs: Any) -> _T: resp = self.request(method, url, **kwargs) return self.parse_body(resp) - def headers_and_body(self, method, url, **kwargs): + def headers_and_body( + self, + method: str, + url: str, + **kwargs: Any, + ) -> tuple[requests.structures.CaseInsensitiveDict[str], _T]: resp = self.request(method, url, **kwargs) return resp.headers, self.parse_body(resp) - def close(self): - self._session.close() - self._session = None + def close(self) -> None: + if self._session is not None: + self._session.close() + self._session = None -class JsonHttpClient(HttpClient): - """An HTTP client that parses response messages as JSON.""" - def __init__(self, **kwargs): - HttpClient.__init__(self, **kwargs) +class JsonHttpClient(HttpClient[dict[str, Any]]): + """An HTTP client that parses response messages as JSON.""" - def parse_body(self, resp): + def parse_body(self, resp: requests.Response) -> dict[str, Any]: return resp.json() + + class GoogleAuthCredentialFlow(httpx.Auth): """Google Auth Credential Auth Flow""" - def __init__(self, credential: credentials.Credentials): + def __init__(self, credential: google.auth.credentials.Credentials) -> None: self._credential = credential self._max_refresh_attempts = 2 self._refresh_status_codes = (401,) def apply_auth_headers( - self, - request: httpx.Request, - auth_request: google_auth_requests.Request - ) -> None: + self, + request: httpx.Request, + auth_request: google.auth.transport.requests.Request, + ) -> None: """A helper function that refreshes credentials if needed and mutates the request headers to contain access token and any other Google Auth headers.""" @@ -194,7 +233,7 @@ def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Re _credential_refresh_attempt = 0 # Create a Google auth request object to be used for refreshing credentials - auth_request = google_auth_requests.Request() + auth_request = google.auth.transport.requests.Request() while True: # Copy original headers for each attempt @@ -237,20 +276,22 @@ def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Re break # The last yielded response is automatically returned by httpx's auth flow. -class HttpxAsyncClient(): + +class HttpxAsyncClient: """Async HTTP client used to make HTTP/2 calls using HTTPX. HttpxAsyncClient maintains an async HTTPX client, handles request authentication, and retries if necessary. """ + def __init__( - self, - credential: Optional[credentials.Credentials] = None, - base_url: str = '', - headers: Optional[Union[httpx.Headers, Dict[str, str]]] = None, - retry_config: HttpxRetry = DEFAULT_HTTPX_RETRY_CONFIG, - timeout: int = DEFAULT_TIMEOUT_SECONDS, - http2: bool = True + self, + credential: Optional[google.auth.credentials.Credentials] = None, + base_url: str = '', + headers: Optional[Union[httpx.Headers, dict[str, str]]] = None, + retry_config: _retry.HttpxRetry = DEFAULT_HTTPX_RETRY_CONFIG, + timeout: int = DEFAULT_TIMEOUT_SECONDS, + http2: bool = True, ) -> None: """Creates a new HttpxAsyncClient instance from the provided arguments. @@ -274,8 +315,8 @@ def __init__( # Only set up retries on urls starting with 'http://' and 'https://' self._mounts = { - 'http://': HttpxRetryTransport(retry=self._retry_config, http2=http2), - 'https://': HttpxRetryTransport(retry=self._retry_config, http2=http2) + 'http://': _retry.HttpxRetryTransport(retry=self._retry_config, http2=http2), + 'https://': _retry.HttpxRetryTransport(retry=self._retry_config, http2=http2) } if credential: @@ -295,15 +336,15 @@ def __init__( ) @property - def base_url(self): + def base_url(self) -> str: return self._base_url @property - def timeout(self): + def timeout(self) -> int: return self._timeout @property - def async_client(self): + def async_client(self) -> httpx.AsyncClient: return self._async_client async def request(self, method: str, url: str, **kwargs: Any) -> httpx.Response: @@ -337,7 +378,11 @@ async def headers(self, method: str, url: str, **kwargs: Any) -> httpx.Headers: return resp.headers async def body_and_response( - self, method: str, url: str, **kwargs: Any) -> Tuple[Any, httpx.Response]: + self, + method: str, + url: str, + **kwargs: Any, + ) -> tuple[Any, httpx.Response]: resp = await self.request(method, url, **kwargs) return self.parse_body(resp), resp @@ -346,7 +391,11 @@ async def body(self, method: str, url: str, **kwargs: Any) -> Any: return self.parse_body(resp) async def headers_and_body( - self, method: str, url: str, **kwargs: Any) -> Tuple[httpx.Headers, Any]: + self, + method: str, + url: str, + **kwargs: Any, + ) -> tuple[httpx.Headers, Any]: resp = await self.request(method, url, **kwargs) return resp.headers, self.parse_body(resp) diff --git a/firebase_admin/_messaging_encoder.py b/firebase_admin/_messaging_encoder.py index 960a6d742..068015b56 100644 --- a/firebase_admin/_messaging_encoder.py +++ b/firebase_admin/_messaging_encoder.py @@ -19,9 +19,19 @@ import math import numbers import re +from typing import Any, Optional, TypeVar, Union, cast from firebase_admin import _messaging_utils +_K = TypeVar('_K') +_V = TypeVar('_V') + +__all__ = ( + 'Message', + 'MessageEncoder', + 'MulticastMessage', +) + class Message: """A message that can be sent via Firebase Cloud Messaging. @@ -35,7 +45,7 @@ class Message: notification: An instance of ``messaging.Notification`` (optional). android: An instance of ``messaging.AndroidConfig`` (optional). webpush: An instance of ``messaging.WebpushConfig`` (optional). - apns: An instance of ``messaging.ApnsConfig`` (optional). + apns: An instance of ``messaging.APNSConfig`` (optional). fcm_options: An instance of ``messaging.FCMOptions`` (optional). token: The registration token of the device to which the message should be sent (optional). topic: Name of the FCM topic to which the message should be sent (optional). Topic name @@ -43,8 +53,18 @@ class Message: condition: The FCM condition to which the message should be sent (optional). """ - def __init__(self, data=None, notification=None, android=None, webpush=None, apns=None, - fcm_options=None, token=None, topic=None, condition=None): + def __init__( + self, + data: Optional[dict[str, str]] = None, + notification: Optional[_messaging_utils.Notification] = None, + android: Optional[_messaging_utils.AndroidConfig] = None, + webpush: Optional[_messaging_utils.WebpushConfig] = None, + apns: Optional[_messaging_utils.APNSConfig] = None, + fcm_options: Optional[_messaging_utils.FCMOptions] = None, + token: Optional[str] = None, + topic: Optional[str] = None, + condition: Optional[str] = None, + ) -> None: self.data = data self.notification = notification self.android = android @@ -55,7 +75,7 @@ def __init__(self, data=None, notification=None, android=None, webpush=None, apn self.topic = topic self.condition = condition - def __str__(self): + def __str__(self) -> str: return json.dumps(self, cls=MessageEncoder, sort_keys=True) @@ -69,11 +89,19 @@ class MulticastMessage: notification: An instance of ``messaging.Notification`` (optional). android: An instance of ``messaging.AndroidConfig`` (optional). webpush: An instance of ``messaging.WebpushConfig`` (optional). - apns: An instance of ``messaging.ApnsConfig`` (optional). + apns: An instance of ``messaging.APNSConfig`` (optional). fcm_options: An instance of ``messaging.FCMOptions`` (optional). """ - def __init__(self, tokens, data=None, notification=None, android=None, webpush=None, apns=None, - fcm_options=None): + def __init__( + self, + tokens: list[str], + data: Optional[dict[str, str]] = None, + notification: Optional[_messaging_utils.Notification] = None, + android: Optional[_messaging_utils.AndroidConfig] = None, + webpush: Optional[_messaging_utils.WebpushConfig] = None, + apns: Optional[_messaging_utils.APNSConfig] = None, + fcm_options: Optional[_messaging_utils.FCMOptions] = None, + ) -> None: _Validators.check_string_list('MulticastMessage.tokens', tokens) if len(tokens) > 500: raise ValueError('MulticastMessage.tokens must not contain more than 500 tokens.') @@ -93,7 +121,7 @@ class _Validators: """ @classmethod - def check_string(cls, label, value, non_empty=False): + def check_string(cls, label: str, value: Any, non_empty: bool = False) -> Optional[str]: """Checks if the given value is a string.""" if value is None: return None @@ -106,7 +134,7 @@ def check_string(cls, label, value, non_empty=False): return value @classmethod - def check_number(cls, label, value): + def check_number(cls, label: str, value: Any) -> Optional[numbers.Number]: if value is None: return None if not isinstance(value, numbers.Number): @@ -114,12 +142,17 @@ def check_number(cls, label, value): return value @classmethod - def check_string_dict(cls, label, value): + def check_string_dict( + cls, + label: str, + value: Optional[Any], + ) -> Optional[dict[str, str]]: """Checks if the given value is a dictionary comprised only of string keys and values.""" if value is None or value == {}: return None if not isinstance(value, dict): raise ValueError(f'{label} must be a dictionary.') + value = cast(dict[Any, Any], value) non_str = [k for k in value if not isinstance(k, str)] if non_str: raise ValueError(f'{label} must not contain non-string keys.') @@ -129,31 +162,41 @@ def check_string_dict(cls, label, value): return value @classmethod - def check_string_list(cls, label, value): + def check_string_list( + cls, + label: str, + value: Optional[Any], + ) -> Optional[list[str]]: """Checks if the given value is a list comprised only of strings.""" if value is None or value == []: return None if not isinstance(value, list): raise ValueError(f'{label} must be a list of strings.') + value = cast(list[Any], value) non_str = [k for k in value if not isinstance(k, str)] if non_str: raise ValueError(f'{label} must not contain non-string values.') return value @classmethod - def check_number_list(cls, label, value): + def check_number_list( + cls, + label: str, + value: Optional[Any], + ) -> Optional[list[numbers.Number]]: """Checks if the given value is a list comprised only of numbers.""" if value is None or value == []: return None if not isinstance(value, list): raise ValueError(f'{label} must be a list of numbers.') + value = cast(list[Any], value) non_number = [k for k in value if not isinstance(k, numbers.Number)] if non_number: raise ValueError(f'{label} must not contain non-number values.') return value @classmethod - def check_analytics_label(cls, label, value): + def check_analytics_label(cls, label: str, value: Optional[Any]) -> Optional[str]: """Checks if the given value is a valid analytics label.""" value = _Validators.check_string(label, value) if value is not None and not re.match(r'^[a-zA-Z0-9-_.~%]{1,50}$', value): @@ -161,7 +204,7 @@ def check_analytics_label(cls, label, value): return value @classmethod - def check_boolean(cls, label, value): + def check_boolean(cls, label: str, value: Optional[Any]) -> Optional[bool]: """Checks if the given value is boolean.""" if value is None: return None @@ -170,7 +213,7 @@ def check_boolean(cls, label, value): return value @classmethod - def check_datetime(cls, label, value): + def check_datetime(cls, label: str, value: Optional[Any]) -> Optional[datetime.datetime]: """Checks if the given value is a datetime.""" if value is None: return None @@ -183,17 +226,20 @@ class MessageEncoder(json.JSONEncoder): """A custom ``JSONEncoder`` implementation for serializing Message instances into JSON.""" @classmethod - def remove_null_values(cls, dict_value): - return {k: v for k, v in dict_value.items() if v not in [None, [], {}]} + def remove_null_values(cls, dict_value: dict[_K, Optional[_V]]) -> dict[_K, _V]: + return {k: cast(_V, v) for k, v in dict_value.items() if v not in [None, [], {}]} @classmethod - def encode_android(cls, android): + def encode_android( + cls, + android: Optional[_messaging_utils.AndroidConfig], + ) -> Optional[dict[str, Any]]: """Encodes an ``AndroidConfig`` instance into JSON.""" if android is None: return None if not isinstance(android, _messaging_utils.AndroidConfig): raise ValueError('Message.android must be an instance of AndroidConfig class.') - result = { + result: dict[str, Any] = { 'collapse_key': _Validators.check_string( 'AndroidConfig.collapse_key', android.collapse_key), 'data': _Validators.check_string_dict( @@ -215,7 +261,10 @@ def encode_android(cls, android): return result @classmethod - def encode_android_fcm_options(cls, fcm_options): + def encode_android_fcm_options( + cls, + fcm_options: Optional[_messaging_utils.AndroidFCMOptions], + ) -> Optional[dict[str, str]]: """Encodes an ``AndroidFCMOptions`` instance into JSON.""" if fcm_options is None: return None @@ -230,12 +279,12 @@ def encode_android_fcm_options(cls, fcm_options): return result @classmethod - def encode_ttl(cls, ttl): + def encode_ttl(cls, ttl: Optional[Union[numbers.Real, datetime.timedelta]]) -> Optional[str]: """Encodes an ``AndroidConfig`` ``TTL`` duration into a string.""" if ttl is None: return None - if isinstance(ttl, numbers.Number): - ttl = datetime.timedelta(seconds=ttl) + if isinstance(ttl, numbers.Real): + ttl = datetime.timedelta(seconds=float(ttl)) if not isinstance(ttl, datetime.timedelta): raise ValueError('AndroidConfig.ttl must be a duration in seconds or an instance of ' 'datetime.timedelta.') @@ -249,12 +298,16 @@ def encode_ttl(cls, ttl): return f'{seconds}s' @classmethod - def encode_milliseconds(cls, label, msec): + def encode_milliseconds( + cls, + label: str, + msec: Optional[Union[numbers.Real, datetime.timedelta]], + ) -> Optional[str]: """Encodes a duration in milliseconds into a string.""" if msec is None: return None - if isinstance(msec, numbers.Number): - msec = datetime.timedelta(milliseconds=msec) + if isinstance(msec, numbers.Real): + msec = datetime.timedelta(milliseconds=float(msec)) if not isinstance(msec, datetime.timedelta): raise ValueError( f'{label} must be a duration in milliseconds or an instance of datetime.timedelta.') @@ -268,14 +321,17 @@ def encode_milliseconds(cls, label, msec): return f'{seconds}s' @classmethod - def encode_android_notification(cls, notification): + def encode_android_notification( + cls, + notification: Optional[_messaging_utils.AndroidNotification], + ) -> Optional[dict[str, Any]]: """Encodes an ``AndroidNotification`` instance into JSON.""" if notification is None: return None if not isinstance(notification, _messaging_utils.AndroidNotification): raise ValueError('AndroidConfig.notification must be an instance of ' 'AndroidNotification class.') - result = { + result: dict[str, Any] = { 'body': _Validators.check_string( 'AndroidNotification.body', notification.body), 'body_loc_args': _Validators.check_string_list( @@ -324,7 +380,7 @@ def encode_android_notification(cls, notification): 'AndroidNotification.proxy', notification.proxy, non_empty=True) } result = cls.remove_null_values(result) - color = result.get('color') + color: Optional[str] = result.get('color') if color and not re.match(r'^#[0-9a-fA-F]{6}$', color): raise ValueError( 'AndroidNotification.color must be in the form #RRGGBB.') @@ -335,7 +391,7 @@ def encode_android_notification(cls, notification): raise ValueError( 'AndroidNotification.title_loc_key is required when specifying title_loc_args.') - event_time = result.get('event_time') + event_time: Optional[datetime.datetime] = result.get('event_time') if event_time: # if the datetime instance is not naive (tzinfo is present), convert to UTC # otherwise (tzinfo is None) assume the datetime instance is already in UTC @@ -357,9 +413,9 @@ def encode_android_notification(cls, notification): 'AndroidNotification.visibility must be "private", "public" or "secret".') result['visibility'] = visibility.upper() - vibrate_timings_millis = result.get('vibrate_timings') + vibrate_timings_millis: Optional[list[Any]] = result.get('vibrate_timings') if vibrate_timings_millis: - vibrate_timing_strings = [] + vibrate_timing_strings: list[Optional[str]] = [] for msec in vibrate_timings_millis: formated_string = cls.encode_milliseconds( 'AndroidNotification.vibrate_timings_millis', msec) @@ -375,14 +431,17 @@ def encode_android_notification(cls, notification): return result @classmethod - def encode_light_settings(cls, light_settings): + def encode_light_settings( + cls, + light_settings: Optional[_messaging_utils.LightSettings], + ) -> Optional[dict[str, Any]]: """Encodes a ``LightSettings`` instance into JSON.""" if light_settings is None: return None if not isinstance(light_settings, _messaging_utils.LightSettings): raise ValueError( 'AndroidNotification.light_settings must be an instance of LightSettings class.') - result = { + result: dict[str, Any] = { 'color': _Validators.check_string( 'LightSettings.color', light_settings.color, non_empty=True), 'light_on_duration': cls.encode_milliseconds( @@ -416,7 +475,10 @@ def encode_light_settings(cls, light_settings): return result @classmethod - def encode_webpush(cls, webpush): + def encode_webpush( + cls, + webpush: Optional[_messaging_utils.WebpushConfig], + ) -> Optional[dict[str, Any]]: """Encodes a ``WebpushConfig`` instance into JSON.""" if webpush is None: return None @@ -433,14 +495,17 @@ def encode_webpush(cls, webpush): return cls.remove_null_values(result) @classmethod - def encode_webpush_notification(cls, notification): + def encode_webpush_notification( + cls, + notification: Optional[_messaging_utils.WebpushNotification], + ) -> Optional[dict[str, Any]]: """Encodes a ``WebpushNotification`` instance into JSON.""" if notification is None: return None if not isinstance(notification, _messaging_utils.WebpushNotification): raise ValueError('WebpushConfig.notification must be an instance of ' 'WebpushNotification class.') - result = { + result: dict[str, Any] = { 'actions': cls.encode_webpush_notification_actions(notification.actions), 'badge': _Validators.check_string( 'WebpushNotification.badge', notification.badge), @@ -480,14 +545,17 @@ def encode_webpush_notification(cls, notification): return cls.remove_null_values(result) @classmethod - def encode_webpush_notification_actions(cls, actions): + def encode_webpush_notification_actions( + cls, + actions: Optional[list[_messaging_utils.WebpushNotificationAction]], + ) -> Optional[list[dict[str, str]]]: """Encodes a list of ``WebpushNotificationActions`` into JSON.""" if actions is None: return None if not isinstance(actions, list): raise ValueError('WebpushConfig.notification.actions must be a list of ' 'WebpushNotificationAction instances.') - results = [] + results: list[dict[str, str]] = [] for action in actions: if not isinstance(action, _messaging_utils.WebpushNotificationAction): raise ValueError('WebpushConfig.notification.actions must be a list of ' @@ -504,7 +572,10 @@ def encode_webpush_notification_actions(cls, actions): return results @classmethod - def encode_webpush_fcm_options(cls, options): + def encode_webpush_fcm_options( + cls, + options: Optional[_messaging_utils.WebpushFCMOptions], + ) -> Optional[dict[str, str]]: """Encodes a ``WebpushFCMOptions`` instance into JSON.""" if options is None: return None @@ -518,13 +589,16 @@ def encode_webpush_fcm_options(cls, options): return result @classmethod - def encode_apns(cls, apns): + def encode_apns( + cls, + apns: Optional[_messaging_utils.APNSConfig], + ) -> Optional[dict[str, Any]]: """Encodes an ``APNSConfig`` instance into JSON.""" if apns is None: return None if not isinstance(apns, _messaging_utils.APNSConfig): raise ValueError('Message.apns must be an instance of APNSConfig class.') - result = { + result: dict[str, Any] = { 'headers': _Validators.check_string_dict( 'APNSConfig.headers', apns.headers), 'payload': cls.encode_apns_payload(apns.payload), @@ -535,13 +609,16 @@ def encode_apns(cls, apns): return cls.remove_null_values(result) @classmethod - def encode_apns_payload(cls, payload): + def encode_apns_payload( + cls, + payload: Optional[_messaging_utils.APNSPayload], + ) -> Optional[dict[str, Any]]: """Encodes an ``APNSPayload`` instance into JSON.""" if payload is None: return None if not isinstance(payload, _messaging_utils.APNSPayload): raise ValueError('APNSConfig.payload must be an instance of APNSPayload class.') - result = { + result: dict[str, Any] = { 'aps': cls.encode_aps(payload.aps) } for key, value in payload.custom_data.items(): @@ -549,7 +626,10 @@ def encode_apns_payload(cls, payload): return cls.remove_null_values(result) @classmethod - def encode_apns_fcm_options(cls, fcm_options): + def encode_apns_fcm_options( + cls, + fcm_options: Optional[_messaging_utils.APNSFCMOptions], + ) -> Optional[dict[str, str]]: """Encodes an ``APNSFCMOptions`` instance into JSON.""" if fcm_options is None: return None @@ -564,11 +644,11 @@ def encode_apns_fcm_options(cls, fcm_options): return result @classmethod - def encode_aps(cls, aps): + def encode_aps(cls, aps: _messaging_utils.Aps) -> dict[str, Any]: """Encodes an ``Aps`` instance into JSON.""" if not isinstance(aps, _messaging_utils.Aps): raise ValueError('APNSPayload.aps must be an instance of Aps class.') - result = { + result: dict[str, Any] = { 'alert': cls.encode_aps_alert(aps.alert), 'badge': _Validators.check_number('Aps.badge', aps.badge), 'sound': cls.encode_aps_sound(aps.sound), @@ -585,12 +665,15 @@ def encode_aps(cls, aps): for key, val in aps.custom_data.items(): _Validators.check_string('Aps.custom_data key', key) if key in result: - raise ValueError(f'Multiple specifications for {key} in Aps.') + raise ValueError('Multiple specifications for {0} in Aps.'.format(key)) result[key] = val return cls.remove_null_values(result) @classmethod - def encode_aps_sound(cls, sound): + def encode_aps_sound( + cls, + sound: Optional[Union[str, _messaging_utils.CriticalSound]], + ) -> Optional[Union[str, dict[str, Any]]]: """Encodes an APNs sound configuration into JSON.""" if sound is None: return None @@ -599,7 +682,7 @@ def encode_aps_sound(cls, sound): if not isinstance(sound, _messaging_utils.CriticalSound): raise ValueError( 'Aps.sound must be a non-empty string or an instance of CriticalSound class.') - result = { + result: dict[str, Any] = { 'name': _Validators.check_string('CriticalSound.name', sound.name, non_empty=True), 'volume': _Validators.check_number('CriticalSound.volume', sound.volume), } @@ -613,7 +696,10 @@ def encode_aps_sound(cls, sound): return cls.remove_null_values(result) @classmethod - def encode_aps_alert(cls, alert): + def encode_aps_alert( + cls, + alert: Optional[Union[_messaging_utils.ApsAlert, str]], + ) -> Optional[Union[str, dict[str, Any]]]: """Encodes an ``ApsAlert`` instance into JSON.""" if alert is None: return None @@ -655,7 +741,10 @@ def encode_aps_alert(cls, alert): return cls.remove_null_values(result) @classmethod - def encode_notification(cls, notification): + def encode_notification( + cls, + notification: Optional[_messaging_utils.Notification], + ) -> Optional[dict[str, str]]: """Encodes a ``Notification`` instance into JSON.""" if notification is None: return None @@ -669,7 +758,7 @@ def encode_notification(cls, notification): return cls.remove_null_values(result) @classmethod - def sanitize_topic_name(cls, topic): + def sanitize_topic_name(cls, topic: Optional[str]) -> Optional[str]: """Removes the /topics/ prefix from the topic name, if present.""" if not topic: return None @@ -681,7 +770,7 @@ def sanitize_topic_name(cls, topic): raise ValueError('Malformed topic name.') return topic - def default(self, o): # pylint: disable=method-hidden + def default(self, o: Any) -> dict[str, Any]: # pylint: disable=method-hidden if not isinstance(o, Message): return json.JSONEncoder.default(self, o) result = { @@ -704,7 +793,10 @@ def default(self, o): # pylint: disable=method-hidden return result @classmethod - def encode_fcm_options(cls, fcm_options): + def encode_fcm_options( + cls, + fcm_options: Optional[_messaging_utils.FCMOptions], + ) -> Optional[dict[str, str]]: """Encodes an ``FCMOptions`` instance into JSON.""" if fcm_options is None: return None diff --git a/firebase_admin/_messaging_utils.py b/firebase_admin/_messaging_utils.py index 8fd720701..3b942a31e 100644 --- a/firebase_admin/_messaging_utils.py +++ b/firebase_admin/_messaging_utils.py @@ -14,8 +14,43 @@ """Types and utilities used by the messaging (FCM) module.""" +import datetime +import numbers +from typing import TYPE_CHECKING, Any, Literal, Optional, Union + +import httpx +import requests + from firebase_admin import exceptions +if TYPE_CHECKING: + from _typeshed import Incomplete +else: + Incomplete = Any + +__all__ = ( + 'APNSConfig', + 'APNSFCMOptions', + 'APNSPayload', + 'AndroidConfig', + 'AndroidFCMOptions', + 'AndroidNotification', + 'Aps', + 'ApsAlert', + 'CriticalSound', + 'FCMOptions', + 'LightSettings', + 'Notification', + 'QuotaExceededError', + 'SenderIdMismatchError', + 'ThirdPartyAuthError', + 'UnregisteredError', + 'WebpushConfig', + 'WebpushFCMOptions', + 'WebpushNotification', + 'WebpushNotificationAction', +) + class Notification: """A notification that can be included in a message. @@ -26,7 +61,12 @@ class Notification: image: Image url of the notification (optional) """ - def __init__(self, title=None, body=None, image=None): + def __init__( + self, + title: Optional[str] = None, + body: Optional[str] = None, + image: Optional[str] = None, + ) -> None: self.title = title self.body = body self.image = image @@ -53,8 +93,17 @@ class AndroidConfig: the app while the device is in direct boot mode (optional). """ - def __init__(self, collapse_key=None, priority=None, ttl=None, restricted_package_name=None, - data=None, notification=None, fcm_options=None, direct_boot_ok=None): + def __init__( + self, + collapse_key: Optional[str] = None, + priority: Optional[Literal['high', 'normal']] = None, + ttl: Optional[Union[numbers.Real, datetime.timedelta]] = None, + restricted_package_name: Optional[str] = None, + data: Optional[dict[str, str]] = None, + notification: Optional['AndroidNotification'] = None, + fcm_options: Optional['AndroidFCMOptions'] = None, + direct_boot_ok: Optional[bool] = None, + ) -> None: self.collapse_key = collapse_key self.priority = priority self.ttl = ttl @@ -153,13 +202,35 @@ class AndroidNotification: """ - def __init__(self, title=None, body=None, icon=None, color=None, sound=None, tag=None, - click_action=None, body_loc_key=None, body_loc_args=None, title_loc_key=None, - title_loc_args=None, channel_id=None, image=None, ticker=None, sticky=None, - event_timestamp=None, local_only=None, priority=None, vibrate_timings_millis=None, - default_vibrate_timings=None, default_sound=None, light_settings=None, - default_light_settings=None, visibility=None, notification_count=None, - proxy=None): + def __init__( + self, + title: Optional[str] = None, + body: Optional[str] = None, + icon: Optional[str] = None, + color: Optional[str] = None, + sound: Optional[str] = None, + tag: Optional[str] = None, + click_action: Optional[Incomplete] = None, + body_loc_key: Optional[str] = None, + body_loc_args: Optional[list[str]] = None, + title_loc_key: Optional[str] = None, + title_loc_args: Optional[list[str]] = None, + channel_id: Optional[Incomplete] = None, + image: Optional[str] = None, + ticker: Optional[Incomplete] = None, + sticky: Optional[bool] = None, + event_timestamp: Optional[datetime.datetime] = None, + local_only: Optional[Incomplete] = None, + priority: Optional[Literal['default', 'min', 'low', 'high', 'max', 'normal']] = None, + vibrate_timings_millis: Optional[float] = None, + default_vibrate_timings: Optional[bool] = None, + default_sound: Optional[bool] = None, + light_settings: Optional['LightSettings'] = None, + default_light_settings: Optional[bool] = None, + visibility: Optional[Literal['private', 'public', 'secret']] = None, + notification_count: Optional[int] = None, + proxy: Optional[Literal['allow', 'deny']] = None, + ) -> None: self.title = title self.body = body self.icon = icon @@ -199,8 +270,12 @@ class LightSettings: light_off_duration_millis: Along with ``light_on_duration``, defines the blink rate of LED flashes. """ - def __init__(self, color, light_on_duration_millis, - light_off_duration_millis): + def __init__( + self, + color: str, + light_on_duration_millis: Union[numbers.Real, datetime.timedelta], + light_off_duration_millis: Union[numbers.Real, datetime.timedelta], + ) -> None: self.color = color self.light_on_duration_millis = light_on_duration_millis self.light_off_duration_millis = light_off_duration_millis @@ -214,7 +289,7 @@ class AndroidFCMOptions: (optional). """ - def __init__(self, analytics_label=None): + def __init__(self, analytics_label: Optional[Incomplete] = None) -> None: self.analytics_label = analytics_label @@ -233,7 +308,13 @@ class WebpushConfig: .. _Webpush Specification: https://tools.ietf.org/html/rfc8030#section-5 """ - def __init__(self, headers=None, data=None, notification=None, fcm_options=None): + def __init__( + self, + headers: Optional[dict[str, str]] = None, + data: Optional[dict[str, str]] = None, + notification: Optional['WebpushNotification'] = None, + fcm_options: Optional['WebpushFCMOptions'] = None, + ) -> None: self.headers = headers self.data = data self.notification = notification @@ -249,7 +330,7 @@ class WebpushNotificationAction: icon: Icon URL for the action (optional). """ - def __init__(self, action, title, icon=None): + def __init__(self, action: str, title: str, icon: Optional[str] = None) -> None: self.action = action self.title = title self.icon = icon @@ -290,10 +371,25 @@ class WebpushNotification: /notification/Notification """ - def __init__(self, title=None, body=None, icon=None, actions=None, badge=None, data=None, - direction=None, image=None, language=None, renotify=None, - require_interaction=None, silent=None, tag=None, timestamp_millis=None, - vibrate=None, custom_data=None): + def __init__( + self, + title: Optional[str] = None, + body: Optional[str] = None, + icon: Optional[str] = None, + actions: Optional[list[WebpushNotificationAction]] = None, + badge: Optional[str] = None, + data: Optional[Any] = None, + direction: Optional[Literal['auto', 'ltr', 'rtl']] = None, + image: Optional[str] = None, + language: Optional[str] = None, + renotify: Optional[bool] = None, + require_interaction: Optional[bool] = None, + silent: Optional[bool] = None, + tag: Optional[str] = None, + timestamp_millis: Optional[int] = None, + vibrate: Optional[list[int]] = None, + custom_data: Optional[dict[str, Any]] = None, + ) -> None: self.title = title self.body = body self.icon = icon @@ -320,7 +416,7 @@ class WebpushFCMOptions: (optional). """ - def __init__(self, link=None): + def __init__(self, link: Optional[str] = None) -> None: self.link = link @@ -340,7 +436,13 @@ class APNSConfig: /NetworkingInternet/Conceptual/RemoteNotificationsPG/CommunicatingwithAPNs.html """ - def __init__(self, headers=None, payload=None, fcm_options=None, live_activity_token=None): + def __init__( + self, + headers: Optional[dict[str, str]] = None, + payload: Optional['APNSPayload'] = None, + fcm_options: Optional['APNSFCMOptions'] = None, + live_activity_token: Optional[str] = None, + ) -> None: self.headers = headers self.payload = payload self.fcm_options = fcm_options @@ -356,7 +458,7 @@ class APNSPayload: (optional). """ - def __init__(self, aps, **kwargs): + def __init__(self, aps: 'Aps', **kwargs: Any) -> None: self.aps = aps self.custom_data = kwargs @@ -379,8 +481,17 @@ class Aps: (optional). """ - def __init__(self, alert=None, badge=None, sound=None, content_available=None, category=None, - thread_id=None, mutable_content=None, custom_data=None): + def __init__( + self, + alert: Optional[Union['ApsAlert', str]] = None, + badge: Optional[float] = None, # should it be int? + sound: Optional[Union[str, 'CriticalSound']] = None, + content_available: Optional[bool] = None, + category: Optional[str] = None, + thread_id: Optional[str] = None, + mutable_content: Optional[bool] = None, + custom_data: Optional[dict[str, Any]] = None, + ) -> None: self.alert = alert self.badge = badge self.sound = sound @@ -404,7 +515,12 @@ class CriticalSound: and 1.0 (full volume) (optional). """ - def __init__(self, name, critical=None, volume=None): + def __init__( + self, + name: str, + critical: Optional[bool] = None, + volume: Optional[float] = None, + ) -> None: self.name = name self.critical = critical self.volume = volume @@ -434,9 +550,19 @@ class ApsAlert: (optional) """ - def __init__(self, title=None, subtitle=None, body=None, loc_key=None, loc_args=None, - title_loc_key=None, title_loc_args=None, action_loc_key=None, launch_image=None, - custom_data=None): + def __init__( + self, + title: Optional[str] = None, + subtitle: Optional[str] = None, + body: Optional[str] = None, + loc_key: Optional[str] = None, + loc_args: Optional[list[str]] = None, + title_loc_key: Optional[str] = None, + title_loc_args: Optional[list[str]] = None, + action_loc_key: Optional[str] = None, + launch_image: Optional[str] = None, + custom_data: Optional[dict[str, Any]] = None, + ) -> None: self.title = title self.subtitle = subtitle self.body = body @@ -459,7 +585,11 @@ class APNSFCMOptions: (optional). """ - def __init__(self, analytics_label=None, image=None): + def __init__( + self, + analytics_label: Optional[Incomplete] = None, + image: Optional[str] = None, + ) -> None: self.analytics_label = analytics_label self.image = image @@ -471,29 +601,44 @@ class FCMOptions: analytics_label: contains additional options to use across all platforms (optional). """ - def __init__(self, analytics_label=None): + def __init__(self, analytics_label: Optional[Incomplete] = None) -> None: self.analytics_label = analytics_label class ThirdPartyAuthError(exceptions.UnauthenticatedError): """APNs certificate or web push auth key was invalid or missing.""" - def __init__(self, message, cause=None, http_response=None): - exceptions.UnauthenticatedError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message, cause, http_response) class QuotaExceededError(exceptions.ResourceExhaustedError): """Sending limit exceeded for the message target.""" - def __init__(self, message, cause=None, http_response=None): - exceptions.ResourceExhaustedError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message, cause, http_response) class SenderIdMismatchError(exceptions.PermissionDeniedError): """The authenticated sender ID is different from the sender ID for the registration token.""" - def __init__(self, message, cause=None, http_response=None): - exceptions.PermissionDeniedError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message, cause, http_response) class UnregisteredError(exceptions.NotFoundError): @@ -501,5 +646,10 @@ class UnregisteredError(exceptions.NotFoundError): This usually means that the token used is no longer valid and a new one must be used.""" - def __init__(self, message, cause=None, http_response=None): - exceptions.NotFoundError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message, cause, http_response) diff --git a/firebase_admin/_retry.py b/firebase_admin/_retry.py index efd90a743..84c27ccf1 100644 --- a/firebase_admin/_retry.py +++ b/firebase_admin/_retry.py @@ -17,17 +17,22 @@ This module provides utilities for adding retry logic to HTTPX requests """ -from __future__ import annotations import copy import email.utils import random import re import time -from typing import Any, Callable, List, Optional, Tuple, Coroutine import logging +from collections.abc import Callable, Coroutine +from typing import Any, Optional + import asyncio +from typing_extensions import Self + import httpx +__all__ = ('HttpxRetry', 'HttpxRetryTransport') + logger = logging.getLogger(__name__) @@ -40,18 +45,18 @@ class HttpxRetry: DEFAULT_BACKOFF_MAX = 120 def __init__( - self, - max_retries: int = 10, - status_forcelist: Optional[List[int]] = None, - backoff_factor: float = 0, - backoff_max: float = DEFAULT_BACKOFF_MAX, - backoff_jitter: float = 0, - history: Optional[List[Tuple[ - httpx.Request, - Optional[httpx.Response], - Optional[Exception] - ]]] = None, - respect_retry_after_header: bool = False, + self, + max_retries: int = 10, + status_forcelist: Optional[list[int]] = None, + backoff_factor: float = 0, + backoff_max: float = DEFAULT_BACKOFF_MAX, + backoff_jitter: float = 0, + history: Optional[list[tuple[ + httpx.Request, + Optional[httpx.Response], + Optional[Exception] + ]]] = None, + respect_retry_after_header: bool = False, ) -> None: self.retries_left = max_retries self.status_forcelist = status_forcelist @@ -64,7 +69,7 @@ def __init__( self.history = [] self.respect_retry_after_header = respect_retry_after_header - def copy(self) -> HttpxRetry: + def copy(self) -> Self: """Creates a deep copy of this instance.""" return copy.deepcopy(self) @@ -89,7 +94,7 @@ def is_exhausted(self) -> bool: return self.retries_left < 0 # Identical implementation of `urllib3.Retry.parse_retry_after()` - def _parse_retry_after(self, retry_after_header: str) -> float | None: + def _parse_retry_after(self, retry_after_header: str) -> Optional[float]: """Parses Retry-After string into a float with unit seconds.""" seconds: float # Whitespace: https://tools.ietf.org/html/rfc7230#section-3.2.4 @@ -107,7 +112,7 @@ def _parse_retry_after(self, retry_after_header: str) -> float | None: return seconds - def get_retry_after(self, response: httpx.Response) -> float | None: + def get_retry_after(self, response: httpx.Response) -> Optional[float]: """Determine the Retry-After time needed before sending the next request.""" retry_after_header = response.headers.get('Retry-After', None) if retry_after_header: @@ -115,7 +120,7 @@ def get_retry_after(self, response: httpx.Response) -> float | None: return self._parse_retry_after(retry_after_header) return None - def get_backoff_time(self): + def get_backoff_time(self) -> float: """Determine the backoff time needed before sending the next request.""" # attempt_count is the number of previous request attempts attempt_count = len(self.history) @@ -147,10 +152,10 @@ async def sleep(self, response: httpx.Response) -> None: await self.sleep_for_backoff() def increment( - self, - request: httpx.Request, - response: Optional[httpx.Response] = None, - error: Optional[Exception] = None + self, + request: httpx.Request, + response: Optional[httpx.Response] = None, + error: Optional[Exception] = None, ) -> None: """Update the retry state based on request attempt.""" self.retries_left -= 1 @@ -177,9 +182,9 @@ async def handle_async_request(self, request: httpx.Request) -> httpx.Response: request, self._wrapped_transport.handle_async_request) async def _dispatch_with_retry( - self, - request: httpx.Request, - dispatch_method: Callable[[httpx.Request], Coroutine[Any, Any, httpx.Response]] + self, + request: httpx.Request, + dispatch_method: Callable[[httpx.Request], Coroutine[Any, Any, httpx.Response]], ) -> httpx.Response: """Sends a request with retry logic using a provided dispatch method.""" # This request config is used across all requests that use this transport and therefore diff --git a/firebase_admin/_rfc3339.py b/firebase_admin/_rfc3339.py index 8489bdcb9..7132911e4 100644 --- a/firebase_admin/_rfc3339.py +++ b/firebase_admin/_rfc3339.py @@ -14,10 +14,13 @@ """Parse RFC3339 date strings""" -from datetime import datetime, timezone +import datetime import re -def parse_to_epoch(datestr): +__all__ = ('parse_to_epoch',) + + +def parse_to_epoch(datestr: str) -> float: """Parse an RFC3339 date string and return the number of seconds since the epoch (as a float). @@ -37,7 +40,7 @@ def parse_to_epoch(datestr): return _parse_to_datetime(datestr).timestamp() -def _parse_to_datetime(datestr): +def _parse_to_datetime(datestr: str) -> datetime.datetime: """Parse an RFC3339 date string and return a python datetime instance. Args: @@ -55,16 +58,16 @@ def _parse_to_datetime(datestr): # This format is the one we actually expect to occur from our backend. The # others are only present because the spec says we *should* accept them. try: - return datetime.strptime( + return datetime.datetime.strptime( datestr_modified, '%Y-%m-%dT%H:%M:%S.%fZ' - ).replace(tzinfo=timezone.utc) + ).replace(tzinfo=datetime.timezone.utc) except ValueError: pass try: - return datetime.strptime( + return datetime.datetime.strptime( datestr_modified, '%Y-%m-%dT%H:%M:%SZ' - ).replace(tzinfo=timezone.utc) + ).replace(tzinfo=datetime.timezone.utc) except ValueError: pass @@ -75,12 +78,12 @@ def _parse_to_datetime(datestr): datestr_modified = re.sub(r'(\d\d):(\d\d)$', r'\1\2', datestr_modified) try: - return datetime.strptime(datestr_modified, '%Y-%m-%dT%H:%M:%S.%f%z') + return datetime.datetime.strptime(datestr_modified, '%Y-%m-%dT%H:%M:%S.%f%z') except ValueError: pass try: - return datetime.strptime(datestr_modified, '%Y-%m-%dT%H:%M:%S%z') + return datetime.datetime.strptime(datestr_modified, '%Y-%m-%dT%H:%M:%S%z') except ValueError: pass diff --git a/firebase_admin/_sseclient.py b/firebase_admin/_sseclient.py index 3372fe5f2..ea0d5ac23 100644 --- a/firebase_admin/_sseclient.py +++ b/firebase_admin/_sseclient.py @@ -20,58 +20,67 @@ import re import time import warnings +from collections.abc import Iterator +from typing import Any, Optional +from typing_extensions import Self -from google.auth import transport +import google.auth.credentials +import google.auth.transport.requests import requests +__all__ = ( + 'Event', + 'KeepAuthSession', + 'SSEClient', +) # Technically, we should support streams that mix line endings. This regex, # however, assumes that a system will provide consistent line endings. end_of_field = re.compile(r'\r\n\r\n|\r\r|\n\n') -class KeepAuthSession(transport.requests.AuthorizedSession): +class KeepAuthSession(google.auth.transport.requests.AuthorizedSession): """A session that does not drop authentication on redirects between domains.""" - def __init__(self, credential): + def __init__(self, credential: Optional[google.auth.credentials.Credentials]) -> None: super().__init__(credential) - def rebuild_auth(self, prepared_request, response): + def rebuild_auth(self, prepared_request: requests.PreparedRequest, response: requests.Response) -> None: pass class _EventBuffer: """A helper class for buffering and parsing raw SSE data.""" - def __init__(self): - self._buffer = [] + def __init__(self) -> None: + self._buffer: list[str] = [] self._tail = '' - def append(self, char): + def append(self, char: str) -> None: self._buffer.append(char) self._tail += char self._tail = self._tail[-4:] - def truncate(self): + def truncate(self) -> None: head, sep, _ = self.buffer_string.rpartition('\n') rem = head + sep self._buffer = list(rem) self._tail = rem[-4:] @property - def is_end_of_field(self): + def is_end_of_field(self) -> bool: last_two_chars = self._tail[-2:] return last_two_chars == '\n\n' or last_two_chars == '\r\r' or self._tail == '\r\n\r\n' @property - def buffer_string(self): + def buffer_string(self) -> str: return ''.join(self._buffer) class SSEClient: """SSE client implementation.""" - def __init__(self, url, session, retry=3000, **kwargs): + def __init__(self, url: str, session: requests.Session, retry: int = 3000, **kwargs: Any) -> None: """Initializes the SSEClient. Args: @@ -85,7 +94,7 @@ def __init__(self, url, session, retry=3000, **kwargs): self.retry = retry self.requests_kwargs = kwargs self.should_connect = True - self.last_id = None + self.last_id: Optional[str] = None self.buf = '' # Keep data here as it streams in headers = self.requests_kwargs.get('headers', {}) @@ -96,13 +105,13 @@ def __init__(self, url, session, retry=3000, **kwargs): self.requests_kwargs['headers'] = headers self._connect() - def close(self): + def close(self) -> None: """Closes the SSEClient instance.""" self.should_connect = False self.retry = 0 self.resp.close() - def _connect(self): + def _connect(self) -> None: """Connects to the server using requests.""" if self.should_connect: if self.last_id: @@ -113,10 +122,10 @@ def _connect(self): else: raise StopIteration() - def __iter__(self): + def __iter__(self) -> Iterator[Optional['Event']]: return self - def __next__(self): + def __next__(self) -> Optional['Event']: if not re.search(end_of_field, self.buf): temp_buffer = _EventBuffer() while not temp_buffer.is_end_of_field: @@ -153,20 +162,29 @@ def __next__(self): self.last_id = event.event_id return event + def next(self) -> Optional['Event']: + return self.__next__() + class Event: """Event represents the events fired by SSE.""" sse_line_pattern = re.compile('(?P[^:]*):?( ?(?P.*))?') - def __init__(self, data='', event_type='message', event_id=None, retry=None): + def __init__( + self, + data: str = '', + event_type: str = 'message', + event_id: Optional[str] = None, + retry: Optional[int] = None, + ) -> None: self.data = data self.event_type = event_type self.event_id = event_id self.retry = retry @classmethod - def parse(cls, raw): + def parse(cls, raw: str) -> Self: """Given a possibly-multiline string representing an SSE message, parses it and returns an Event object. diff --git a/firebase_admin/_token_gen.py b/firebase_admin/_token_gen.py index 1607ef0ba..22d2aed7e 100644 --- a/firebase_admin/_token_gen.py +++ b/firebase_admin/_token_gen.py @@ -16,21 +16,55 @@ import datetime import time +from collections.abc import Callable, Mapping +from typing import TYPE_CHECKING, Any, Optional, Union, cast import cachecontrol import requests from google.auth import credentials from google.auth import iam from google.auth import jwt -from google.auth import transport +import google.auth.transport.requests +import google.auth.crypt import google.auth.exceptions import google.oauth2.id_token import google.oauth2.service_account +import firebase_admin from firebase_admin import exceptions from firebase_admin import _auth_utils from firebase_admin import _http_client +if TYPE_CHECKING: + from _typeshed import Incomplete +else: + Incomplete = Any + +__all__ = ( + 'ALGORITHM_NONE', + 'ALGORITHM_RS256', + 'AUTH_EMULATOR_EMAIL', + 'COOKIE_CERT_URI', + 'COOKIE_ISSUER_PREFIX', + 'FIREBASE_AUDIENCE', + 'ID_TOKEN_CERT_URI', + 'ID_TOKEN_ISSUER_PREFIX', + 'MAX_SESSION_COOKIE_DURATION_SECONDS', + 'MAX_TOKEN_LIFETIME_SECONDS', + 'METADATA_SERVICE_URL', + 'MIN_SESSION_COOKIE_DURATION_SECONDS', + 'RESERVED_CLAIMS', + 'CertificateFetchRequest', + 'TokenGenerator', + 'TokenSignError', + 'TokenVerifier', + 'CertificateFetchError', + 'ExpiredIdTokenError', + 'ExpiredSessionCookieError', + 'InvalidSessionCookieError', + 'RevokedIdTokenError', + 'RevokedSessionCookieError', +) # ID token constants ID_TOKEN_ISSUER_PREFIX = 'https://securetoken.google.com/' @@ -61,19 +95,26 @@ class _EmulatedSigner(google.auth.crypt.Signer): - key_id = None + @property + def key_id(self) -> Optional[str]: + return None - def __init__(self): + def __init__(self) -> None: pass - def sign(self, message): + def sign(self, message: Union[str, bytes]) -> bytes: return b'' class _SigningProvider: """Stores a reference to a google.auth.crypto.Signer.""" - def __init__(self, signer, signer_email, alg=ALGORITHM_RS256): + def __init__( + self, + signer: google.auth.crypt.Signer, + signer_email: Optional[str], + alg: str = ALGORITHM_RS256, + ) -> None: self._signer = signer self._signer_email = signer_email self._alg = alg @@ -87,20 +128,28 @@ def signer_email(self): return self._signer_email @property - def alg(self): + def alg(self) -> str: return self._alg @classmethod - def from_credential(cls, google_cred): + def from_credential( + cls, + google_cred: Union[google.oauth2.service_account.Credentials, credentials.Signing] + ) -> '_SigningProvider': return _SigningProvider(google_cred.signer, google_cred.signer_email) @classmethod - def from_iam(cls, request, google_cred, service_account): + def from_iam( + cls, + request: google.auth.transport.Request, + google_cred: credentials.Credentials, + service_account: str, + ) -> '_SigningProvider': signer = iam.Signer(request, google_cred, service_account) return _SigningProvider(signer, service_account) @classmethod - def for_emulator(cls): + def for_emulator(cls) -> '_SigningProvider': return _SigningProvider(_EmulatedSigner(), AUTH_EMULATOR_EMAIL, ALGORITHM_NONE) @@ -109,15 +158,20 @@ class TokenGenerator: ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1' - def __init__(self, app, http_client, url_override=None): + def __init__( + self, + app: firebase_admin.App, + http_client: _http_client.HttpClient[dict[str, Any]], + url_override: Optional[str] = None, + ) -> None: self.app = app self.http_client = http_client - self.request = transport.requests.Request() + self.request = google.auth.transport.requests.Request() url_prefix = url_override or self.ID_TOOLKIT_URL self.base_url = f'{url_prefix}/projects/{app.project_id}' - self._signing_provider = None + self._signing_provider: Optional[_SigningProvider] = None - def _init_signing_provider(self): + def _init_signing_provider(self) -> _SigningProvider: """Initializes a signing provider by following the go/firebase-admin-sign protocol.""" if _auth_utils.is_emulated(): return _SigningProvider.for_emulator() @@ -143,11 +197,11 @@ def _init_signing_provider(self): if resp.status != 200: raise ValueError( f'Failed to contact the local metadata service: {resp.data.decode()}.') - service_account = resp.data.decode() + service_account = cast(str, resp.data.decode()) return _SigningProvider.from_iam(self.request, google_cred, service_account) @property - def signing_provider(self): + def signing_provider(self) -> _SigningProvider: """Initializes and returns the SigningProvider instance to be used.""" if not self._signing_provider: try: @@ -161,7 +215,12 @@ def signing_provider(self): 'details on creating custom tokens.') from error return self._signing_provider - def create_custom_token(self, uid, developer_claims=None, tenant_id=None): + def create_custom_token( + self, + uid: str, + developer_claims: Optional[dict[str, Any]] = None, + tenant_id: Optional[str] = None, + ) -> bytes: """Builds and signs a Firebase custom auth token.""" if developer_claims is not None: if not isinstance(developer_claims, dict): @@ -184,7 +243,7 @@ def create_custom_token(self, uid, developer_claims=None, tenant_id=None): signing_provider = self.signing_provider now = int(time.time()) - payload = { + payload: dict[str, Any] = { 'iss': signing_provider.signer_email, 'sub': signing_provider.signer_email, 'aud': FIREBASE_AUDIENCE, @@ -206,7 +265,11 @@ def create_custom_token(self, uid, developer_claims=None, tenant_id=None): raise TokenSignError(msg, error) from error - def create_session_cookie(self, id_token, expires_in): + def create_session_cookie( + self, + id_token: Union[bytes, str], + expires_in: Union[datetime.timedelta, int], + ) -> str: """Creates a session cookie from the provided ID token.""" id_token = id_token.decode('utf-8') if isinstance(id_token, bytes) else id_token if not isinstance(id_token, str) or not id_token: @@ -238,38 +301,47 @@ def create_session_cookie(self, id_token, expires_in): if not body or not body.get('sessionCookie'): raise _auth_utils.UnexpectedResponseError( 'Failed to create session cookie.', http_response=http_resp) - return body.get('sessionCookie') + return cast(str, body['sessionCookie']) -class CertificateFetchRequest(transport.Request): +class CertificateFetchRequest(google.auth.transport.Request): """A google-auth transport that supports HTTP cache-control. Also injects a timeout to each outgoing HTTP request. """ - def __init__(self, timeout_seconds=None): + def __init__(self, timeout_seconds: Optional[float] = None) -> None: self._session = cachecontrol.CacheControl(requests.Session()) - self._delegate = transport.requests.Request(self.session) + self._delegate = google.auth.transport.requests.Request(self.session) self._timeout_seconds = timeout_seconds @property - def session(self): + def session(self) -> requests.Session: return self._session @property - def timeout_seconds(self): + def timeout_seconds(self) -> Optional[float]: return self._timeout_seconds - def __call__(self, url, method='GET', body=None, headers=None, timeout=None, **kwargs): + def __call__( + self, + url: str, + method: str = 'GET', + body: Optional[Incomplete] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + **kwargs: Incomplete, + ) -> google.auth.transport.Response: timeout = timeout or self.timeout_seconds return self._delegate( - url, method=method, body=body, headers=headers, timeout=timeout, **kwargs) + url, method=method, body=body, headers=headers, + timeout=timeout, **kwargs) # pyright: ignore[reportArgumentType] class TokenVerifier: """Verifies ID tokens and session cookies.""" - def __init__(self, app): + def __init__(self, app: firebase_admin.App) -> None: timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) self.request = CertificateFetchRequest(timeout) self.id_token_verifier = _JWTVerifier( @@ -289,31 +361,56 @@ def __init__(self, app): invalid_token_error=InvalidSessionCookieError, expired_token_error=ExpiredSessionCookieError) - def verify_id_token(self, id_token, clock_skew_seconds=0): + def verify_id_token( + self, + id_token: Union[bytes, str], + clock_skew_seconds: int = 0, + ) -> dict[str, Any]: return self.id_token_verifier.verify(id_token, self.request, clock_skew_seconds) - def verify_session_cookie(self, cookie, clock_skew_seconds=0): + def verify_session_cookie( + self, + cookie: Union[bytes, str], + clock_skew_seconds: int = 0, + ) -> dict[str, Any]: return self.cookie_verifier.verify(cookie, self.request, clock_skew_seconds) class _JWTVerifier: """Verifies Firebase JWTs (ID tokens or session cookies).""" - def __init__(self, **kwargs): - self.project_id = kwargs.pop('project_id') - self.short_name = kwargs.pop('short_name') - self.operation = kwargs.pop('operation') - self.url = kwargs.pop('doc_url') - self.cert_url = kwargs.pop('cert_url') - self.issuer = kwargs.pop('issuer') + def __init__( + self, + *, + project_id: Optional[str], + short_name: str, + operation: str, + doc_url: str, + cert_url: str, + issuer: str, + invalid_token_error: Callable[[str, Optional[Exception]], exceptions.FirebaseError], + expired_token_error: Callable[[str, Optional[Exception]], exceptions.FirebaseError], + **kwargs: Any, + ) -> None: + self.project_id = project_id + self.short_name = short_name + self.operation = operation + self.url = doc_url + self.cert_url = cert_url + self.issuer = issuer if self.short_name[0].lower() in 'aeiou': self.articled_short_name = f'an {self.short_name}' else: self.articled_short_name = f'a {self.short_name}' - self._invalid_token_error = kwargs.pop('invalid_token_error') - self._expired_token_error = kwargs.pop('expired_token_error') - - def verify(self, token, request, clock_skew_seconds=0): + self._invalid_token_error = invalid_token_error + self._expired_token_error = expired_token_error + + def verify( + self, + token: Union[bytes, str], + request: google.auth.transport.Request, + clock_skew_seconds: int = 0, + ) -> dict[str, Any]: """Verifies the signature and data for the provided JWT.""" token = token.encode('utf-8') if isinstance(token, str) else token if not isinstance(token, bytes) or not token: @@ -387,7 +484,7 @@ def verify(self, token, request, clock_skew_seconds=0): f'characters. {verify_id_token_msg}') if error_message: - raise self._invalid_token_error(error_message) + raise self._invalid_token_error(error_message, None) try: if emulated: @@ -399,68 +496,72 @@ def verify(self, token, request, clock_skew_seconds=0): audience=self.project_id, certs_url=self.cert_url, clock_skew_in_seconds=clock_skew_seconds) + verified_claims = cast(dict[str, Any], verified_claims) verified_claims['uid'] = verified_claims['sub'] return verified_claims except google.auth.exceptions.TransportError as error: raise CertificateFetchError(str(error), cause=error) from error except ValueError as error: if 'Token expired' in str(error): - raise self._expired_token_error(str(error), cause=error) - raise self._invalid_token_error(str(error), cause=error) + raise self._expired_token_error(str(error), error) + raise self._invalid_token_error(str(error), error) - def _decode_unverified(self, token): + def _decode_unverified( + self, + token: Union[bytes, str], + ) -> tuple[dict[str, str], dict[str, Any]]: try: - header = jwt.decode_header(token) - payload = jwt.decode(token, verify=False) - return header, payload + header = cast(Mapping[str, str], jwt.decode_header(token)) + payload = cast(Mapping[str, Any], jwt.decode(token, verify=False)) + return dict(header), dict(payload) except ValueError as error: - raise self._invalid_token_error(str(error), cause=error) + raise self._invalid_token_error(str(error), error) class TokenSignError(exceptions.UnknownError): """Unexpected error while signing a Firebase custom token.""" - def __init__(self, message, cause): - exceptions.UnknownError.__init__(self, message, cause) + def __init__(self, message: str, cause: Optional[Exception]) -> None: + super().__init__(message, cause) class CertificateFetchError(exceptions.UnknownError): """Failed to fetch some public key certificates required to verify a token.""" - def __init__(self, message, cause): - exceptions.UnknownError.__init__(self, message, cause) + def __init__(self, message: str, cause: Optional[Exception]) -> None: + super().__init__(message, cause) class ExpiredIdTokenError(_auth_utils.InvalidIdTokenError): """The provided ID token is expired.""" - def __init__(self, message, cause): - _auth_utils.InvalidIdTokenError.__init__(self, message, cause) + def __init__(self, message: str, cause: Optional[Exception]) -> None: + super().__init__(message, cause) class RevokedIdTokenError(_auth_utils.InvalidIdTokenError): """The provided ID token has been revoked.""" - def __init__(self, message): - _auth_utils.InvalidIdTokenError.__init__(self, message) + def __init__(self, message: str) -> None: + super().__init__(message) class InvalidSessionCookieError(exceptions.InvalidArgumentError): """The provided string is not a valid Firebase session cookie.""" - def __init__(self, message, cause=None): - exceptions.InvalidArgumentError.__init__(self, message, cause) + def __init__(self, message: str, cause: Optional[Exception] = None) -> None: + super().__init__(message, cause) class ExpiredSessionCookieError(InvalidSessionCookieError): """The provided session cookie is expired.""" - def __init__(self, message, cause): - InvalidSessionCookieError.__init__(self, message, cause) + def __init__(self, message: str, cause: Optional[Exception]) -> None: + super().__init__(message, cause) class RevokedSessionCookieError(InvalidSessionCookieError): """The provided session cookie has been revoked.""" - def __init__(self, message): - InvalidSessionCookieError.__init__(self, message) + def __init__(self, message: str) -> None: + super().__init__(message) diff --git a/firebase_admin/_user_identifier.py b/firebase_admin/_user_identifier.py index 85a224e0b..37ac388b7 100644 --- a/firebase_admin/_user_identifier.py +++ b/firebase_admin/_user_identifier.py @@ -16,6 +16,15 @@ from firebase_admin import _auth_utils +__all__ = ( + 'EmailIdentifier', + 'PhoneIdentifier', + 'ProviderIdentifier', + 'UidIdentifier', + 'UserIdentifier', +) + + class UserIdentifier: """Identifies a user to be looked up.""" @@ -26,7 +35,7 @@ class UidIdentifier(UserIdentifier): See ``auth.get_user()``. """ - def __init__(self, uid): + def __init__(self, uid: str) -> None: """Constructs a new `UidIdentifier` object. Args: @@ -35,7 +44,7 @@ def __init__(self, uid): self._uid = _auth_utils.validate_uid(uid, required=True) @property - def uid(self): + def uid(self) -> str: return self._uid @@ -45,7 +54,7 @@ class EmailIdentifier(UserIdentifier): See ``auth.get_user()``. """ - def __init__(self, email): + def __init__(self, email: str) -> None: """Constructs a new `EmailIdentifier` object. Args: @@ -54,7 +63,7 @@ def __init__(self, email): self._email = _auth_utils.validate_email(email, required=True) @property - def email(self): + def email(self) -> str: return self._email @@ -64,7 +73,7 @@ class PhoneIdentifier(UserIdentifier): See ``auth.get_user()``. """ - def __init__(self, phone_number): + def __init__(self, phone_number: str) -> None: """Constructs a new `PhoneIdentifier` object. Args: @@ -73,7 +82,7 @@ def __init__(self, phone_number): self._phone_number = _auth_utils.validate_phone(phone_number, required=True) @property - def phone_number(self): + def phone_number(self) -> str: return self._phone_number @@ -83,21 +92,21 @@ class ProviderIdentifier(UserIdentifier): See ``auth.get_user()``. """ - def __init__(self, provider_id, provider_uid): + def __init__(self, provider_id: str, provider_uid: str) -> None: """Constructs a new `ProviderIdentifier` object. -   Args: -     provider_id: A provider ID string. -     provider_uid: A provider UID string. + Args: + provider_id: A provider ID string. + provider_uid: A provider UID string. """ self._provider_id = _auth_utils.validate_provider_id(provider_id, required=True) self._provider_uid = _auth_utils.validate_provider_uid( provider_uid, required=True) @property - def provider_id(self): + def provider_id(self) -> str: return self._provider_id @property - def provider_uid(self): + def provider_uid(self) -> str: return self._provider_uid diff --git a/firebase_admin/_user_import.py b/firebase_admin/_user_import.py index 7c7a9e70b..b2ef1421e 100644 --- a/firebase_admin/_user_import.py +++ b/firebase_admin/_user_import.py @@ -16,11 +16,22 @@ import base64 import json +from typing import Any, Optional, cast from firebase_admin import _auth_utils +from firebase_admin import _user_mgt +__all__ = ( + 'ErrorInfo', + 'ImportUserRecord', + 'UserImportHash', + 'UserImportResult', + 'UserProvider', + 'b64_encode', +) -def b64_encode(bytes_value): + +def b64_encode(bytes_value: bytes) -> str: return base64.urlsafe_b64encode(bytes_value).decode() @@ -39,7 +50,14 @@ class UserProvider: photo_url: User's photo URL (optional). """ - def __init__(self, uid, provider_id, email=None, display_name=None, photo_url=None): + def __init__( + self, + uid: str, + provider_id: str, + email: Optional[str] = None, + display_name: Optional[str] = None, + photo_url: Optional[str] = None, + ) -> None: self.uid = uid self.provider_id = provider_id self.email = email @@ -47,46 +65,46 @@ def __init__(self, uid, provider_id, email=None, display_name=None, photo_url=No self.photo_url = photo_url @property - def uid(self): + def uid(self) -> str: return self._uid @uid.setter - def uid(self, uid): + def uid(self, uid: str) -> None: self._uid = _auth_utils.validate_uid(uid, required=True) @property - def provider_id(self): + def provider_id(self) -> str: return self._provider_id @provider_id.setter - def provider_id(self, provider_id): + def provider_id(self, provider_id: str) -> None: self._provider_id = _auth_utils.validate_provider_id(provider_id, required=True) @property - def email(self): + def email(self) -> Optional[str]: return self._email @email.setter - def email(self, email): + def email(self, email: Optional[str]) -> None: self._email = _auth_utils.validate_email(email) @property - def display_name(self): + def display_name(self) -> Optional[str]: return self._display_name @display_name.setter - def display_name(self, display_name): + def display_name(self, display_name: Optional[str]) -> None: self._display_name = _auth_utils.validate_display_name(display_name) @property - def photo_url(self): + def photo_url(self) -> Optional[str]: return self._photo_url @photo_url.setter - def photo_url(self, photo_url): + def photo_url(self, photo_url: Optional[str]): self._photo_url = _auth_utils.validate_photo_url(photo_url) - def to_dict(self): + def to_dict(self) -> dict[str, str]: payload = { 'rawId': self.uid, 'providerId': self.provider_id, @@ -123,9 +141,21 @@ class ImportUserRecord: ValueError: If provided arguments are invalid. """ - def __init__(self, uid, email=None, email_verified=None, display_name=None, phone_number=None, - photo_url=None, disabled=None, user_metadata=None, provider_data=None, - custom_claims=None, password_hash=None, password_salt=None): + def __init__( + self, + uid: str, + email: Optional[str] = None, + email_verified: Optional[bool] = None, + display_name: Optional[str] = None, + phone_number: Optional[str] = None, + photo_url: Optional[str] = None, + disabled: Optional[bool] = None, + user_metadata: Optional['_user_mgt.UserMetadata'] = None, + provider_data: Optional[list[UserProvider]] = None, + custom_claims: Optional[dict[str, Any]] = None, + password_hash: Optional[bytes] = None, + password_salt: Optional[bytes] = None, + ) -> None: self.uid = uid self.email = email self.display_name = display_name @@ -140,67 +170,67 @@ def __init__(self, uid, email=None, email_verified=None, display_name=None, phon self.custom_claims = custom_claims @property - def uid(self): + def uid(self) -> str: return self._uid @uid.setter - def uid(self, uid): + def uid(self, uid: str) -> None: self._uid = _auth_utils.validate_uid(uid, required=True) @property - def email(self): + def email(self) -> Optional[str]: return self._email @email.setter - def email(self, email): + def email(self, email: Optional[str]) -> None: self._email = _auth_utils.validate_email(email) @property - def display_name(self): + def display_name(self) -> Optional[str]: return self._display_name @display_name.setter - def display_name(self, display_name): + def display_name(self, display_name: Optional[str]) -> None: self._display_name = _auth_utils.validate_display_name(display_name) @property - def phone_number(self): + def phone_number(self) -> Optional[str]: return self._phone_number @phone_number.setter - def phone_number(self, phone_number): + def phone_number(self, phone_number: Optional[str]) -> None: self._phone_number = _auth_utils.validate_phone(phone_number) @property - def photo_url(self): + def photo_url(self) -> Optional[str]: return self._photo_url @photo_url.setter - def photo_url(self, photo_url): + def photo_url(self, photo_url: Optional[str]) -> None: self._photo_url = _auth_utils.validate_photo_url(photo_url) @property - def password_hash(self): + def password_hash(self) -> Optional[bytes]: return self._password_hash @password_hash.setter - def password_hash(self, password_hash): + def password_hash(self, password_hash: Optional[bytes]) -> None: self._password_hash = _auth_utils.validate_bytes(password_hash, 'password_hash') @property - def password_salt(self): + def password_salt(self) -> Optional[bytes]: return self._password_salt @password_salt.setter - def password_salt(self, password_salt): + def password_salt(self, password_salt: Optional[bytes]) -> None: self._password_salt = _auth_utils.validate_bytes(password_salt, 'password_salt') @property - def user_metadata(self): + def user_metadata(self) -> Optional['_user_mgt.UserMetadata']: return self._user_metadata @user_metadata.setter - def user_metadata(self, user_metadata): + def user_metadata(self, user_metadata: Optional['_user_mgt.UserMetadata']) -> None: created_at = user_metadata.creation_timestamp if user_metadata is not None else None last_login_at = user_metadata.last_sign_in_timestamp if user_metadata is not None else None self._created_at = _auth_utils.validate_timestamp(created_at, 'creation_timestamp') @@ -209,11 +239,11 @@ def user_metadata(self, user_metadata): self._user_metadata = user_metadata @property - def provider_data(self): + def provider_data(self) -> Optional[list[UserProvider]]: return self._provider_data @provider_data.setter - def provider_data(self, provider_data): + def provider_data(self, provider_data: Optional[list[UserProvider]]) -> None: if provider_data is not None: try: if any(not isinstance(p, UserProvider) for p in provider_data): @@ -223,19 +253,19 @@ def provider_data(self, provider_data): self._provider_data = provider_data @property - def custom_claims(self): + def custom_claims(self) -> Optional[dict[str, Any]]: return self._custom_claims @custom_claims.setter - def custom_claims(self, custom_claims): + def custom_claims(self, custom_claims: Optional[dict[str, Any]]) -> None: json_claims = json.dumps(custom_claims) if isinstance( custom_claims, dict) else custom_claims self._custom_claims_str = _auth_utils.validate_custom_claims(json_claims) self._custom_claims = custom_claims - def to_dict(self): + def to_dict(self) -> dict[str, Any]: """Returns a dict representation of the user. For internal use only.""" - payload = { + payload: dict[str, Any] = { 'localId': self.uid, 'email': self.email, 'displayName': self.display_name, @@ -265,25 +295,25 @@ class UserImportHash: .. _documentation: https://firebase.google.com/docs/auth/admin/import-users """ - def __init__(self, name, data=None): + def __init__(self, name: str, data: Optional[dict[str, Any]] = None) -> None: self._name = name self._data = data - def to_dict(self): - payload = {'hashAlgorithm': self._name} + def to_dict(self) -> dict[str, Any]: + payload: dict[str, Any] = {'hashAlgorithm': self._name} if self._data: payload.update(self._data) return payload @classmethod - def _hmac(cls, name, key): + def _hmac(cls, name: str, key: bytes) -> 'UserImportHash': data = { 'signerKey': b64_encode(_auth_utils.validate_bytes(key, 'key', required=True)) } return UserImportHash(name, data) @classmethod - def hmac_sha512(cls, key): + def hmac_sha512(cls, key: bytes) -> 'UserImportHash': """Creates a new HMAC SHA512 algorithm instance. Args: @@ -295,7 +325,7 @@ def hmac_sha512(cls, key): return cls._hmac('HMAC_SHA512', key) @classmethod - def hmac_sha256(cls, key): + def hmac_sha256(cls, key: bytes) -> 'UserImportHash': """Creates a new HMAC SHA256 algorithm instance. Args: @@ -307,7 +337,7 @@ def hmac_sha256(cls, key): return cls._hmac('HMAC_SHA256', key) @classmethod - def hmac_sha1(cls, key): + def hmac_sha1(cls, key: bytes) -> 'UserImportHash': """Creates a new HMAC SHA1 algorithm instance. Args: @@ -319,7 +349,7 @@ def hmac_sha1(cls, key): return cls._hmac('HMAC_SHA1', key) @classmethod - def hmac_md5(cls, key): + def hmac_md5(cls, key: bytes) -> 'UserImportHash': """Creates a new HMAC MD5 algorithm instance. Args: @@ -331,7 +361,7 @@ def hmac_md5(cls, key): return cls._hmac('HMAC_MD5', key) @classmethod - def md5(cls, rounds): + def md5(cls, rounds: int) -> 'UserImportHash': """Creates a new MD5 algorithm instance. Args: @@ -345,7 +375,7 @@ def md5(cls, rounds): {'rounds': _auth_utils.validate_int(rounds, 'rounds', 0, 8192)}) @classmethod - def sha1(cls, rounds): + def sha1(cls, rounds: int) -> 'UserImportHash': """Creates a new SHA1 algorithm instance. Args: @@ -359,7 +389,7 @@ def sha1(cls, rounds): {'rounds': _auth_utils.validate_int(rounds, 'rounds', 1, 8192)}) @classmethod - def sha256(cls, rounds): + def sha256(cls, rounds: int) -> 'UserImportHash': """Creates a new SHA256 algorithm instance. Args: @@ -373,7 +403,7 @@ def sha256(cls, rounds): {'rounds': _auth_utils.validate_int(rounds, 'rounds', 1, 8192)}) @classmethod - def sha512(cls, rounds): + def sha512(cls, rounds: int) -> 'UserImportHash': """Creates a new SHA512 algorithm instance. Args: @@ -387,7 +417,7 @@ def sha512(cls, rounds): {'rounds': _auth_utils.validate_int(rounds, 'rounds', 1, 8192)}) @classmethod - def pbkdf_sha1(cls, rounds): + def pbkdf_sha1(cls, rounds: int) -> 'UserImportHash': """Creates a new PBKDF SHA1 algorithm instance. Args: @@ -401,7 +431,7 @@ def pbkdf_sha1(cls, rounds): {'rounds': _auth_utils.validate_int(rounds, 'rounds', 0, 120000)}) @classmethod - def pbkdf2_sha256(cls, rounds): + def pbkdf2_sha256(cls, rounds: int) -> 'UserImportHash': """Creates a new PBKDF2 SHA256 algorithm instance. Args: @@ -415,7 +445,13 @@ def pbkdf2_sha256(cls, rounds): {'rounds': _auth_utils.validate_int(rounds, 'rounds', 0, 120000)}) @classmethod - def scrypt(cls, key, rounds, memory_cost, salt_separator=None): + def scrypt( + cls, + key: bytes, + rounds: int, + memory_cost: int, + salt_separator: Optional[bytes] = None, + ) -> 'UserImportHash': """Creates a new Scrypt algorithm instance. This is the modified Scrypt algorithm used by Firebase Auth. See ``standard_scrypt()`` @@ -430,18 +466,18 @@ def scrypt(cls, key, rounds, memory_cost, salt_separator=None): Returns: UserImportHash: A new ``UserImportHash``. """ - data = { + data: dict[str, Any] = { 'signerKey': b64_encode(_auth_utils.validate_bytes(key, 'key', required=True)), 'rounds': _auth_utils.validate_int(rounds, 'rounds', 1, 8), 'memoryCost': _auth_utils.validate_int(memory_cost, 'memory_cost', 1, 14), } if salt_separator: data['saltSeparator'] = b64_encode(_auth_utils.validate_bytes( - salt_separator, 'salt_separator')) + salt_separator, 'salt_separator', True)) return UserImportHash('SCRYPT', data) @classmethod - def bcrypt(cls): + def bcrypt(cls) -> 'UserImportHash': """Creates a new Bcrypt algorithm instance. Returns: @@ -450,7 +486,13 @@ def bcrypt(cls): return UserImportHash('BCRYPT') @classmethod - def standard_scrypt(cls, memory_cost, parallelization, block_size, derived_key_length): + def standard_scrypt( + cls, + memory_cost: int, + parallelization: int, + block_size: int, + derived_key_length: int, + ) -> 'UserImportHash': """Creates a new standard Scrypt algorithm instance. Args: @@ -479,16 +521,16 @@ class ErrorInfo: # it's home in _user_import.py). It's now also used by bulk deletion of # users. Move this to a more common location. - def __init__(self, error): - self._index = error['index'] - self._reason = error['message'] + def __init__(self, error: dict[str, Any]) -> None: + self._index = cast(int, error['index']) + self._reason = cast(str, error['message']) @property - def index(self): + def index(self) -> int: return self._index @property - def reason(self): + def reason(self) -> str: return self._reason @@ -498,23 +540,23 @@ class UserImportResult: See ``auth.import_users()`` API for more details. """ - def __init__(self, result, total): + def __init__(self, result: dict[str, Any], total: int) -> None: errors = result.get('error', []) self._success_count = total - len(errors) self._failure_count = len(errors) self._errors = [ErrorInfo(err) for err in errors] @property - def success_count(self): + def success_count(self) -> int: """Returns the number of users successfully imported.""" return self._success_count @property - def failure_count(self): + def failure_count(self) -> int: """Returns the number of users that failed to be imported.""" return self._failure_count @property - def errors(self): + def errors(self) -> list[ErrorInfo]: """Returns a list of ``auth.ErrorInfo`` instances describing the errors encountered.""" return self._errors diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index 9a75b7a2e..9a26c9773 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -15,18 +15,42 @@ """Firebase user management sub module.""" import base64 -from collections import defaultdict +import collections import json +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast from urllib import parse import requests from firebase_admin import _auth_utils +from firebase_admin import _http_client from firebase_admin import _rfc3339 from firebase_admin import _user_identifier from firebase_admin import _user_import -from firebase_admin._user_import import ErrorInfo +if TYPE_CHECKING: + from _typeshed import ConvertibleToInt + +__all__ = ( + 'B64_REDACTED', + 'DELETE_ATTRIBUTE', + 'MAX_IMPORT_USERS_SIZE', + 'MAX_LIST_USERS_RESULTS', + 'ActionCodeSettings', + 'BatchDeleteAccountsResponse', + 'DeleteUsersResult', + 'ExportedUserRecord', + 'GetUsersResult', + 'ListUsersPage', + 'ProviderUserInfo', + 'Sentinel', + 'UserInfo', + 'UserManager', + 'UserMetadata', + 'UserRecord', + 'encode_action_code_settings', +) MAX_LIST_USERS_RESULTS = 1000 MAX_IMPORT_USERS_SIZE = 1000 @@ -34,19 +58,22 @@ class Sentinel: - - def __init__(self, description): + def __init__(self, description: str) -> None: self.description = description -DELETE_ATTRIBUTE = Sentinel('Value used to delete an attribute from a user profile') +DELETE_ATTRIBUTE: Any = Sentinel('Value used to delete an attribute from a user profile') class UserMetadata: """Contains additional metadata associated with a user account.""" - def __init__(self, creation_timestamp=None, last_sign_in_timestamp=None, - last_refresh_timestamp=None): + def __init__( + self, + creation_timestamp: Optional[Any] = None, + last_sign_in_timestamp: Optional[Any] = None, + last_refresh_timestamp: Optional[Any] = None, + ) -> None: self._creation_timestamp = _auth_utils.validate_timestamp( creation_timestamp, 'creation_timestamp') self._last_sign_in_timestamp = _auth_utils.validate_timestamp( @@ -55,7 +82,7 @@ def __init__(self, creation_timestamp=None, last_sign_in_timestamp=None, last_refresh_timestamp, 'last_refresh_timestamp') @property - def creation_timestamp(self): + def creation_timestamp(self) -> Optional[int]: """ Creation timestamp in milliseconds since the epoch. Returns: @@ -64,7 +91,7 @@ def creation_timestamp(self): return self._creation_timestamp @property - def last_sign_in_timestamp(self): + def last_sign_in_timestamp(self) -> Optional[int]: """ Last sign in timestamp in milliseconds since the epoch. Returns: @@ -73,7 +100,7 @@ def last_sign_in_timestamp(self): return self._last_sign_in_timestamp @property - def last_refresh_timestamp(self): + def last_refresh_timestamp(self) -> Optional[int]: """The time at which the user was last active (ID token refreshed). Returns: @@ -90,32 +117,32 @@ class UserInfo: """ @property - def uid(self): + def uid(self) -> str: """Returns the user ID of this user.""" raise NotImplementedError @property - def display_name(self): + def display_name(self) -> Optional[str]: """Returns the display name of this user.""" raise NotImplementedError @property - def email(self): + def email(self) -> Optional[str]: """Returns the email address associated with this user.""" raise NotImplementedError @property - def phone_number(self): + def phone_number(self) -> Optional[str]: """Returns the phone number associated with this user.""" raise NotImplementedError @property - def photo_url(self): + def photo_url(self) -> Optional[str]: """Returns the photo URL of this user.""" raise NotImplementedError @property - def provider_id(self): + def provider_id(self) -> str: """Returns the ID of the identity provider. This can be a short domain name (e.g. google.com), or the identity of an OpenID @@ -127,8 +154,7 @@ def provider_id(self): class UserRecord(UserInfo): """Contains metadata associated with a Firebase user account.""" - def __init__(self, data): - super().__init__() + def __init__(self, data: dict[str, Any]) -> None: if not isinstance(data, dict): raise ValueError(f'Invalid data argument: {data}. Must be a dictionary.') if not data.get('localId'): @@ -136,16 +162,16 @@ def __init__(self, data): self._data = data @property - def uid(self): + def uid(self) -> str: """Returns the user ID of this user. Returns: string: A user ID string. This value is never None or empty. """ - return self._data.get('localId') + return self._data['localId'] @property - def display_name(self): + def display_name(self) -> Optional[str]: """Returns the display name of this user. Returns: @@ -154,7 +180,7 @@ def display_name(self): return self._data.get('displayName') @property - def email(self): + def email(self) -> Optional[str]: """Returns the email address associated with this user. Returns: @@ -163,7 +189,7 @@ def email(self): return self._data.get('email') @property - def phone_number(self): + def phone_number(self) -> Optional[str]: """Returns the phone number associated with this user. Returns: @@ -172,7 +198,7 @@ def phone_number(self): return self._data.get('phoneNumber') @property - def photo_url(self): + def photo_url(self) -> Optional[str]: """Returns the photo URL of this user. Returns: @@ -181,7 +207,7 @@ def photo_url(self): return self._data.get('photoUrl') @property - def provider_id(self): + def provider_id(self) -> str: """Returns the provider ID of this user. Returns: @@ -190,7 +216,7 @@ def provider_id(self): return 'firebase' @property - def email_verified(self): + def email_verified(self) -> bool: """Returns whether the email address of this user has been verified. Returns: @@ -199,7 +225,7 @@ def email_verified(self): return bool(self._data.get('emailVerified')) @property - def disabled(self): + def disabled(self) -> bool: """Returns whether this user account is disabled. Returns: @@ -208,7 +234,7 @@ def disabled(self): return bool(self._data.get('disabled')) @property - def tokens_valid_after_timestamp(self): + def tokens_valid_after_timestamp(self) -> int: """Returns the time, in milliseconds since the epoch, before which tokens are invalid. Note: this is truncated to 1 second accuracy. @@ -223,16 +249,17 @@ def tokens_valid_after_timestamp(self): return 0 @property - def user_metadata(self): + def user_metadata(self) -> UserMetadata: """Returns additional metadata associated with this user. Returns: UserMetadata: A UserMetadata instance. Does not return None. """ - def _int_or_none(key): + def _int_or_none(key: str) -> Optional[int]: if key in self._data: return int(self._data[key]) return None + last_refresh_at_millis = None last_refresh_at_rfc3339 = self._data.get('lastRefreshAt', None) if last_refresh_at_rfc3339: @@ -241,7 +268,7 @@ def _int_or_none(key): _int_or_none('createdAt'), _int_or_none('lastLoginAt'), last_refresh_at_millis) @property - def provider_data(self): + def provider_data(self) -> list['ProviderUserInfo']: """Returns a list of UserInfo instances. Each object represents an identity from an identity provider that is linked to this user. @@ -253,7 +280,7 @@ def provider_data(self): return [ProviderUserInfo(entry) for entry in providers] @property - def custom_claims(self): + def custom_claims(self) -> Optional[dict[str, Any]]: """Returns any custom claims set on this user account. Returns: @@ -267,7 +294,7 @@ def custom_claims(self): return None @property - def tenant_id(self): + def tenant_id(self) -> Optional[str]: """Returns the tenant ID of this user. Returns: @@ -280,7 +307,7 @@ class ExportedUserRecord(UserRecord): """Contains metadata associated with a user including password hash and salt.""" @property - def password_hash(self): + def password_hash(self) -> Optional[str]: """The user's password hash as a base64-encoded string. If the Firebase Auth hashing algorithm (SCRYPT) was used to create the user account, this @@ -299,7 +326,7 @@ def password_hash(self): return password_hash @property - def password_salt(self): + def password_salt(self) -> Optional[str]: """The user's password salt as a base64-encoded string. If the Firebase Auth hashing algorithm (SCRYPT) was used to create the user account, this @@ -314,7 +341,11 @@ def password_salt(self): class GetUsersResult: """Represents the result of the ``auth.get_users()`` API.""" - def __init__(self, users, not_found): + def __init__( + self, + users: list[UserRecord], + not_found: list[_user_identifier.UserIdentifier], + ) -> None: """Constructs a `GetUsersResult` object. Args: @@ -325,7 +356,7 @@ def __init__(self, users, not_found): self._not_found = not_found @property - def users(self): + def users(self) -> list[UserRecord]: """Set of `UserRecord` instances, corresponding to the set of users that were requested. Only users that were found are listed here. The result set is unordered. @@ -333,7 +364,7 @@ def users(self): return self._users @property - def not_found(self): + def not_found(self) -> list[_user_identifier.UserIdentifier]: """Set of `UserIdentifier` instances that were requested, but not found. """ @@ -348,27 +379,38 @@ class ListUsersPage: through all users in the Firebase project starting from this page. """ - def __init__(self, download, page_token, max_results): + def __init__( + self, + download: Callable[[Optional[str], int], dict[str, Any]], + page_token: Optional[str], + max_results: int, + ) -> None: self._download = download self._max_results = max_results self._current = download(page_token, max_results) @property - def users(self): + def users(self) -> list[ExportedUserRecord]: """A list of ``ExportedUserRecord`` instances available in this page.""" - return [ExportedUserRecord(user) for user in self._current.get('users', [])] + return [ + ExportedUserRecord(user) + for user in cast( + list[dict[str, Any]], + self._current.get('users', []), + ) + ] @property - def next_page_token(self): + def next_page_token(self) -> str: """Page token string for the next page (empty string indicates no more pages).""" - return self._current.get('nextPageToken', '') + return cast(str, self._current.get('nextPageToken', '')) @property - def has_next_page(self): + def has_next_page(self) -> bool: """A boolean indicating whether more pages are available.""" return bool(self.next_page_token) - def get_next_page(self): + def get_next_page(self) -> Optional['ListUsersPage']: """Retrieves the next page of user accounts, if available. Returns: @@ -378,7 +420,7 @@ def get_next_page(self): return ListUsersPage(self._download, self.next_page_token, self._max_results) return None - def iterate_all(self): + def iterate_all(self) -> '_UserIterator': """Retrieves an iterator for user accounts. Returned iterator will iterate through all the user accounts in the Firebase project @@ -394,7 +436,7 @@ def iterate_all(self): class DeleteUsersResult: """Represents the result of the ``auth.delete_users()`` API.""" - def __init__(self, result, total): + def __init__(self, result: 'BatchDeleteAccountsResponse', total: int) -> None: """Constructs a `DeleteUsersResult` object. Args: @@ -408,7 +450,7 @@ def __init__(self, result, total): self._errors = errors @property - def success_count(self): + def success_count(self) -> int: """Returns the number of users that were deleted successfully (possibly zero). @@ -418,14 +460,14 @@ def success_count(self): return self._success_count @property - def failure_count(self): + def failure_count(self) -> int: """Returns the number of users that failed to be deleted (possibly zero). """ return self._failure_count @property - def errors(self): + def errors(self) -> list[_user_import.ErrorInfo]: """A list of `auth.ErrorInfo` instances describing the errors that were encountered during the deletion. Length of this list is equal to `failure_count`. @@ -436,7 +478,7 @@ def errors(self): class BatchDeleteAccountsResponse: """Represents the results of a `delete_users()` call.""" - def __init__(self, errors=None): + def __init__(self, errors: Optional[list[dict[str, Any]]] = None) -> None: """Constructs a `BatchDeleteAccountsResponse` instance, corresponding to the JSON representing the `BatchDeleteAccountsResponse` proto. @@ -445,14 +487,13 @@ def __init__(self, errors=None): `ErrorInfo` instance as returned by the server. `None` implies an empty list. """ - self.errors = [ErrorInfo(err) for err in errors] if errors else [] + self.errors = [_user_import.ErrorInfo(err) for err in errors] if errors else [] class ProviderUserInfo(UserInfo): """Contains metadata regarding how a user is known by a particular identity provider.""" - def __init__(self, data): - super().__init__() + def __init__(self, data: dict[str, Any]) -> None: if not isinstance(data, dict): raise ValueError(f'Invalid data argument: {data}. Must be a dictionary.') if not data.get('rawId'): @@ -460,28 +501,29 @@ def __init__(self, data): self._data = data @property - def uid(self): - return self._data.get('rawId') + def uid(self) -> str: + return self._data['rawId'] @property - def display_name(self): + def display_name(self) -> Optional[str]: return self._data.get('displayName') @property - def email(self): + def email(self) -> Optional[str]: return self._data.get('email') @property - def phone_number(self): + def phone_number(self) -> Optional[str]: return self._data.get('phoneNumber') @property - def photo_url(self): + def photo_url(self) -> Optional[str]: return self._data.get('photoUrl') @property - def provider_id(self): - return self._data.get('providerId') + def provider_id(self) -> str: + # possible issue: can providerId be `None`? + return self._data.get('providerId') # pyright: ignore[reportReturnType] class ActionCodeSettings: @@ -489,8 +531,16 @@ class ActionCodeSettings: Used when invoking the email action link generation APIs. """ - def __init__(self, url, handle_code_in_app=None, dynamic_link_domain=None, ios_bundle_id=None, - android_package_name=None, android_install_app=None, android_minimum_version=None): + def __init__( + self, + url: str, + handle_code_in_app: Optional[bool] = None, + dynamic_link_domain: Optional[str] = None, + ios_bundle_id: Optional[str] = None, + android_package_name: Optional[str] = None, + android_install_app: Optional[bool] = None, + android_minimum_version: Optional[str] = None, + ) -> None: self.url = url self.handle_code_in_app = handle_code_in_app self.dynamic_link_domain = dynamic_link_domain @@ -500,7 +550,7 @@ def __init__(self, url, handle_code_in_app=None, dynamic_link_domain=None, ios_b self.android_minimum_version = android_minimum_version -def encode_action_code_settings(settings): +def encode_action_code_settings(settings: ActionCodeSettings) -> dict[str, Any]: """ Validates the provided action code settings for email link generation and populates the REST api parameters. @@ -508,7 +558,7 @@ def encode_action_code_settings(settings): returns - dict of parameters to be passed for link gereration. """ - parameters = {} + parameters: dict[str, Any] = {} # url if not settings.url: raise ValueError("Dynamic action links url is mandatory") @@ -574,14 +624,20 @@ class UserManager: ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1' - def __init__(self, http_client, project_id, tenant_id=None, url_override=None): + def __init__( + self, + http_client: _http_client.HttpClient[dict[str, Any]], + project_id: str, + tenant_id: Optional[str] = None, + url_override: Optional[str] = None, + ) -> None: self.http_client = http_client url_prefix = url_override or self.ID_TOOLKIT_URL self.base_url = f'{url_prefix}/projects/{project_id}' if tenant_id: self.base_url += f'/tenants/{tenant_id}' - def get_user(self, **kwargs): + def get_user(self, **kwargs: Any) -> dict[str, Any]: """Gets the user data corresponding to the provided key.""" if 'uid' in kwargs: key, key_type = kwargs.pop('uid'), 'user ID' @@ -600,9 +656,12 @@ def get_user(self, **kwargs): raise _auth_utils.UserNotFoundError( f'No user record found for the provided {key_type}: {key}.', http_response=http_resp) - return body['users'][0] + return cast(list[dict[str, Any]], body['users'])[0] - def get_users(self, identifiers): + def get_users( + self, + identifiers: Sequence[_user_identifier.UserIdentifier], + ) -> list[dict[str, Any]]: """Looks up multiple users by their identifiers (uid, email, etc.) Args: @@ -624,7 +683,7 @@ def get_users(self, identifiers): if len(identifiers) > 100: raise ValueError('`identifiers` parameter must have <= 100 entries.') - payload = defaultdict(list) + payload: dict[str, list[Any]] = collections.defaultdict(list) for identifier in identifiers: if isinstance(identifier, _user_identifier.UidIdentifier): payload['localId'].append(identifier.uid) @@ -646,9 +705,13 @@ def get_users(self, identifiers): if not http_resp.ok: raise _auth_utils.UnexpectedResponseError( 'Failed to get users.', http_response=http_resp) - return body.get('users', []) + return cast(list[dict[str, Any]], body.get('users', [])) - def list_users(self, page_token=None, max_results=MAX_LIST_USERS_RESULTS): + def list_users( + self, + page_token: Optional[str] = None, + max_results: int = MAX_LIST_USERS_RESULTS, + ) -> dict[str, Any]: """Retrieves a batch of users.""" if page_token is not None: if not isinstance(page_token, str) or not page_token: @@ -659,14 +722,23 @@ def list_users(self, page_token=None, max_results=MAX_LIST_USERS_RESULTS): raise ValueError( f'Max results must be a positive integer less than {MAX_LIST_USERS_RESULTS}.') - payload = {'maxResults': max_results} + payload: dict[str, Any] = {'maxResults': max_results} if page_token: payload['nextPageToken'] = page_token body, _ = self._make_request('get', '/accounts:batchGet', params=payload) return body - def create_user(self, uid=None, display_name=None, email=None, phone_number=None, - photo_url=None, password=None, disabled=None, email_verified=None): + def create_user( + self, + uid: Optional[str] = None, + display_name: Optional[str] = None, + email: Optional[str] = None, + phone_number: Optional[str] = None, + photo_url: Optional[str] = None, + password: Optional[str] = None, + disabled: Optional[bool] = None, + email_verified: Optional[bool] = None, + ) -> str: """Creates a new user account with the specified properties.""" payload = { 'localId': _auth_utils.validate_uid(uid), @@ -683,13 +755,24 @@ def create_user(self, uid=None, display_name=None, email=None, phone_number=None if not body or not body.get('localId'): raise _auth_utils.UnexpectedResponseError( 'Failed to create new user.', http_response=http_resp) - return body.get('localId') - - def update_user(self, uid, display_name=None, email=None, phone_number=None, - photo_url=None, password=None, disabled=None, email_verified=None, - valid_since=None, custom_claims=None, providers_to_delete=None): + return cast(str, body['localId']) + + def update_user( + self, + uid: str, + display_name: Optional[str] = None, + email: Optional[str] = None, + phone_number: Optional[str] = None, + photo_url: Optional[str] = None, + password: Optional[str] = None, + disabled: Optional[bool] = None, + email_verified: Optional[bool] = None, + valid_since: Optional['ConvertibleToInt'] = None, + custom_claims: Optional[Union[dict[str, Any], str]] = None, + providers_to_delete: Optional[list[str]] = None, + ) -> str: """Updates an existing user account with the specified properties""" - payload = { + payload: dict[str, Any] = { 'localId': _auth_utils.validate_uid(uid, required=True), 'email': _auth_utils.validate_email(email), 'password': _auth_utils.validate_password(password), @@ -698,7 +781,7 @@ def update_user(self, uid, display_name=None, email=None, phone_number=None, 'disableUser': bool(disabled) if disabled is not None else None, } - remove = [] + remove: list[str] = [] remove_provider = _auth_utils.validate_provider_ids(providers_to_delete) if display_name is not None: if display_name is DELETE_ATTRIBUTE: @@ -734,9 +817,9 @@ def update_user(self, uid, display_name=None, email=None, phone_number=None, if not body or not body.get('localId'): raise _auth_utils.UnexpectedResponseError( f'Failed to update user: {uid}.', http_response=http_resp) - return body.get('localId') + return cast(str, body['localId']) - def delete_user(self, uid): + def delete_user(self, uid: str) -> None: """Deletes the user identified by the specified user ID.""" _auth_utils.validate_uid(uid, required=True) body, http_resp = self._make_request('post', '/accounts:delete', json={'localId' : uid}) @@ -744,7 +827,7 @@ def delete_user(self, uid): raise _auth_utils.UnexpectedResponseError( f'Failed to delete user: {uid}.', http_response=http_resp) - def delete_users(self, uids, force_delete=False): + def delete_users(self, uids: Sequence[str], force_delete: bool = False) -> BatchDeleteAccountsResponse: """Deletes the users identified by the specified user ids. Args: @@ -778,9 +861,14 @@ def delete_users(self, uids, force_delete=False): raise _auth_utils.UnexpectedResponseError( 'Unexpected response from server while attempting to delete users.', http_response=http_resp) - return BatchDeleteAccountsResponse(body.get('errors', [])) - - def import_users(self, users, hash_alg=None): + return BatchDeleteAccountsResponse(cast(list[dict[str, Any]], + body.get('errors', []))) + + def import_users( + self, + users: Sequence[_user_import.ImportUserRecord], + hash_alg: Optional[_user_import.UserImportHash] = None, + ) -> dict[str, Any]: """Imports the given list of users to Firebase Auth.""" try: if not users or len(users) > MAX_IMPORT_USERS_SIZE: @@ -803,7 +891,12 @@ def import_users(self, users, hash_alg=None): 'Failed to import users.', http_response=http_resp) return body - def generate_email_action_link(self, action_type, email, action_code_settings=None): + def generate_email_action_link( + self, + action_type: Literal['VERIFY_EMAIL', 'EMAIL_SIGNIN', 'PASSWORD_RESET'], + email: Optional[str], + action_code_settings: Optional[ActionCodeSettings] = None, + ) -> str: """Fetches the email action links for types Args: @@ -833,9 +926,14 @@ def generate_email_action_link(self, action_type, email, action_code_settings=No if not body or not body.get('oobLink'): raise _auth_utils.UnexpectedResponseError( 'Failed to generate email action link.', http_response=http_resp) - return body.get('oobLink') - - def _make_request(self, method, path, **kwargs): + return cast(str, body['oobLink']) + + def _make_request( + self, + method: str, + path: str, + **kwargs: Any, + ) -> tuple[dict[str, Any], requests.Response]: url = f'{self.base_url}{path}' try: return self.http_client.body_and_response(method, url, **kwargs) @@ -843,8 +941,7 @@ def _make_request(self, method, path, **kwargs): raise _auth_utils.handle_auth_backend_error(error) -class _UserIterator(_auth_utils.PageIterator): - +class _UserIterator(_auth_utils.PageIterator[ListUsersPage]): @property - def items(self): - return self._current_page.users + def items(self) -> list[ExportedUserRecord]: + return self._current_page.users if self._current_page else [] diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index d0aca884b..2d9e82aa5 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -15,16 +15,30 @@ """Internal utilities common to all modules.""" import json +from collections.abc import Callable +from typing import Any, Optional, TypeVar, Union, cast from platform import python_version -from typing import Callable, Optional -import google.auth -import requests +import google.auth.credentials +import google.auth.transport import httpx +import requests import firebase_admin from firebase_admin import exceptions +__all__ = ( + 'EmulatorAdminCredentials', + 'get_app_service', + 'get_metrics_header', + 'handle_httpx_error', + 'handle_operation_error', + 'handle_platform_error_from_httpx', + 'handle_platform_error_from_requests', + 'handle_requests_error', +) + +_T = TypeVar('_T') _ERROR_CODE_TO_EXCEPTION_TYPE = { exceptions.INVALID_ARGUMENT: exceptions.InvalidArgumentError, @@ -46,7 +60,7 @@ } -_HTTP_STATUS_TO_ERROR_CODE = { +_HTTP_STATUS_TO_ERROR_CODE: dict[int, str] = { 400: exceptions.INVALID_ARGUMENT, 401: exceptions.UNAUTHENTICATED, 403: exceptions.PERMISSION_DENIED, @@ -60,7 +74,7 @@ # See https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto -_RPC_CODE_TO_ERROR_CODE = { +_RPC_CODE_TO_ERROR_CODE: dict[int, str] = { 1: exceptions.CANCELLED, 2: exceptions.UNKNOWN, 3: exceptions.INVALID_ARGUMENT, @@ -78,10 +92,12 @@ 16: exceptions.UNAUTHENTICATED, } -def get_metrics_header(): + +def get_metrics_header() -> str: return f'gl-python/{python_version()} fire-admin/{firebase_admin.__version__}' -def _get_initialized_app(app): + +def _get_initialized_app(app: Optional[firebase_admin.App]) -> firebase_admin.App: """Returns a reference to an initialized App instance.""" if app is None: return firebase_admin.get_app() @@ -98,13 +114,22 @@ def _get_initialized_app(app): f'"{type(app)}".') - -def get_app_service(app, name, initializer): +def get_app_service( + app: Optional[firebase_admin.App], + name: str, + initializer: Callable[[firebase_admin.App], _T], +) -> _T: app = _get_initialized_app(app) return app._get_service(name, initializer) # pylint: disable=protected-access -def handle_platform_error_from_requests(error, handle_func=None): +def handle_platform_error_from_requests( + error: requests.RequestException, + handle_func: Optional[Callable[ + [requests.RequestException, str, dict[str, Any]], + Optional[exceptions.FirebaseError], + ]] = None, +) -> exceptions.FirebaseError: """Constructs a ``FirebaseError`` from the given requests error. This can be used to handle errors returned by Google Cloud Platform (GCP) APIs. @@ -131,9 +156,10 @@ def handle_platform_error_from_requests(error, handle_func=None): return exc if exc else _handle_func_requests(error, message, error_dict) + def handle_platform_error_from_httpx( - error: httpx.HTTPError, - handle_func: Optional[Callable[..., Optional[exceptions.FirebaseError]]] = None + error: httpx.HTTPError, + handle_func: Optional[Callable[..., Optional[exceptions.FirebaseError]]] = None, ) -> exceptions.FirebaseError: """Constructs a ``FirebaseError`` from the given httpx error. @@ -162,7 +188,7 @@ def handle_platform_error_from_httpx( return handle_httpx_error(error) -def handle_operation_error(error): +def handle_operation_error(error: Union[dict[str, Any], Exception]) -> exceptions.FirebaseError: """Constructs a ``FirebaseError`` from the given operation error. Args: @@ -173,17 +199,22 @@ def handle_operation_error(error): """ if not isinstance(error, dict): return exceptions.UnknownError( - message=f'Unknown error while making a remote service call: {error}', + message='Unknown error while making a remote service call: {0}'.format(error), cause=error) - rpc_code = error.get('code') - message = error.get('message') + rpc_code = error.get('code', 0) + # possible issue: needs be str | None ? + message = cast(str, error.get('message')) error_code = _rpc_code_to_error_code(rpc_code) err_type = _error_code_to_exception_type(error_code) - return err_type(message=message) + return err_type(message, None, None) -def _handle_func_requests(error, message, error_dict): +def _handle_func_requests( + error: requests.RequestException, + message: str, + error_dict: dict[str, Any], +) -> exceptions.FirebaseError: """Constructs a ``FirebaseError`` from the given GCP error. Args: @@ -198,7 +229,11 @@ def _handle_func_requests(error, message, error_dict): return handle_requests_error(error, message, code) -def handle_requests_error(error, message=None, code=None): +def handle_requests_error( + error: requests.RequestException, + message: Optional[str] = None, + code: Optional[str] = None, +) -> exceptions.FirebaseError: """Constructs a ``FirebaseError`` from the given requests error. This method is agnostic of the remote service that produced the error, whether it is a GCP @@ -235,9 +270,14 @@ def handle_requests_error(error, message=None, code=None): message = str(error) err_type = _error_code_to_exception_type(code) - return err_type(message=message, cause=error, http_response=error.response) + return err_type(message, error, error.response) -def _handle_func_httpx(error: httpx.HTTPError, message, error_dict) -> exceptions.FirebaseError: + +def _handle_func_httpx( + error: httpx.HTTPError, + message: str, + error_dict: dict[str, Any], +) -> exceptions.FirebaseError: """Constructs a ``FirebaseError`` from the given GCP error. Args: @@ -252,7 +292,11 @@ def _handle_func_httpx(error: httpx.HTTPError, message, error_dict) -> exception return handle_httpx_error(error, message, code) -def handle_httpx_error(error: httpx.HTTPError, message=None, code=None) -> exceptions.FirebaseError: +def handle_httpx_error( + error: Exception, + message: Optional[str] = None, + code: Optional[str] = None, +) -> exceptions.FirebaseError: """Constructs a ``FirebaseError`` from the given httpx error. This method is agnostic of the remote service that produced the error, whether it is a GCP @@ -286,26 +330,34 @@ def handle_httpx_error(error: httpx.HTTPError, message=None, code=None) -> excep message = str(error) err_type = _error_code_to_exception_type(code) - return err_type(message=message, cause=error, http_response=error.response) + return err_type(message, error, error.response) return exceptions.UnknownError( message=f'Unknown error while making a remote service call: {error}', cause=error) -def _http_status_to_error_code(status): + +def _http_status_to_error_code(status: int) -> str: """Maps an HTTP status to a platform error code.""" return _HTTP_STATUS_TO_ERROR_CODE.get(status, exceptions.UNKNOWN) -def _rpc_code_to_error_code(rpc_code): + +def _rpc_code_to_error_code(rpc_code: int) -> str: """Maps an RPC code to a platform error code.""" return _RPC_CODE_TO_ERROR_CODE.get(rpc_code, exceptions.UNKNOWN) -def _error_code_to_exception_type(code): + +def _error_code_to_exception_type( + code: str, +) -> Callable[ + [str, Optional[Exception], Optional[Union[httpx.Response, requests.Response]]], + exceptions.FirebaseError +]: """Maps a platform error code to an exception type.""" return _ERROR_CODE_TO_EXCEPTION_TYPE.get(code, exceptions.UnknownError) -def _parse_platform_error(content, status_code): +def _parse_platform_error(content: str, status_code: int) -> tuple[dict[str, Any], str]: """Parses an HTTP error response from a Google Cloud Platform API and extracts the error code and message fields. @@ -316,15 +368,15 @@ def _parse_platform_error(content, status_code): Returns: tuple: A tuple containing error code and message. """ - data = {} + data: dict[str, Any] = {} try: parsed_body = json.loads(content) if isinstance(parsed_body, dict): - data = parsed_body + data = cast(dict[str, Any], parsed_body) except ValueError: pass - error_dict = data.get('error', {}) + error_dict: dict[str, Any] = data.get('error', {}) msg = error_dict.get('message') if not msg: msg = f'Unexpected HTTP response with status: {status_code}; body: {content}' @@ -340,9 +392,9 @@ class EmulatorAdminCredentials(google.auth.credentials.Credentials): This is used instead of user-supplied credentials or ADC. It will silently do nothing when asked to refresh credentials. """ - def __init__(self): + def __init__(self) -> None: google.auth.credentials.Credentials.__init__(self) self.token = 'owner' - def refresh(self, request): + def refresh(self, request: google.auth.transport.Request) -> None: pass diff --git a/firebase_admin/app_check.py b/firebase_admin/app_check.py index 40d857f4e..dfc011bcf 100644 --- a/firebase_admin/app_check.py +++ b/firebase_admin/app_check.py @@ -14,18 +14,27 @@ """Firebase App Check module.""" -from typing import Any, Dict +from typing import Any, Optional, cast + import jwt -from jwt import PyJWKClient, ExpiredSignatureError, InvalidTokenError, DecodeError -from jwt import InvalidAudienceError, InvalidIssuerError, InvalidSignatureError + +import firebase_admin from firebase_admin import _utils +__all__ = ('verify_token',) + + _APP_CHECK_ATTRIBUTE = '_app_check' -def _get_app_check_service(app) -> Any: + +def _get_app_check_service(app: Optional[firebase_admin.App]) -> '_AppCheckService': return _utils.get_app_service(app, _APP_CHECK_ATTRIBUTE, _AppCheckService) -def verify_token(token: str, app=None) -> Dict[str, Any]: + +def verify_token( + token: str, + app: Optional[firebase_admin.App] = None, +) -> dict[str, Any]: """Verifies a Firebase App Check token. Args: @@ -42,20 +51,18 @@ def verify_token(token: str, app=None) -> Dict[str, Any]: """ return _get_app_check_service(app).verify_token(token) + class _AppCheckService: """Service class that implements Firebase App Check functionality.""" _APP_CHECK_ISSUER = 'https://firebaseappcheck.googleapis.com/' _JWKS_URL = 'https://firebaseappcheck.googleapis.com/v1/jwks' - _project_id = None - _scoped_project_id = None - _jwks_client = None _APP_CHECK_HEADERS = { 'x-goog-api-client': _utils.get_metrics_header(), } - def __init__(self, app): + def __init__(self, app: firebase_admin.App) -> None: # Validate and store the project_id to validate the JWT claims self._project_id = app.project_id if not self._project_id: @@ -64,13 +71,12 @@ def __init__(self, app): 'service. Either set the projectId option, use service ' 'account credentials, or set the ' 'GOOGLE_CLOUD_PROJECT environment variable.') - self._scoped_project_id = 'projects/' + app.project_id + self._scoped_project_id = 'projects/' + self._project_id # Default lifespan is 300 seconds (5 minutes) so we change it to 21600 seconds (6 hours). - self._jwks_client = PyJWKClient( + self._jwks_client = jwt.PyJWKClient( self._JWKS_URL, lifespan=21600, headers=self._APP_CHECK_HEADERS) - - def verify_token(self, token: str) -> Dict[str, Any]: + def verify_token(self, token: str) -> dict[str, Any]: """Verifies a Firebase App Check token.""" _Validators.check_string("app check token", token) @@ -81,7 +87,7 @@ def verify_token(self, token: str) -> Dict[str, Any]: signing_key = self._jwks_client.get_signing_key_from_jwt(token) self._has_valid_token_headers(jwt.get_unverified_header(token)) verified_claims = self._decode_and_verify(token, signing_key.key) - except (InvalidTokenError, DecodeError) as exception: + except (jwt.InvalidTokenError, jwt.DecodeError) as exception: raise ValueError( f'Verifying App Check token failed. Error: {exception}' ) from exception @@ -89,7 +95,7 @@ def verify_token(self, token: str) -> Dict[str, Any]: verified_claims['app_id'] = verified_claims.get('sub') return verified_claims - def _has_valid_token_headers(self, headers: Any) -> None: + def _has_valid_token_headers(self, headers: dict[str, Any]) -> None: """Checks whether the token has valid headers for App Check.""" # Ensure the token's header has type JWT if headers.get('typ') != 'JWT': @@ -102,9 +108,9 @@ def _has_valid_token_headers(self, headers: Any) -> None: f'Expected RS256 but got {algorithm}.' ) - def _decode_and_verify(self, token: str, signing_key: str): + def _decode_and_verify(self, token: str, signing_key: str) -> dict[str, Any]: """Decodes and verifies the token from App Check.""" - payload = {} + payload: dict[str, Any] = {} try: payload = jwt.decode( token, @@ -112,25 +118,25 @@ def _decode_and_verify(self, token: str, signing_key: str): algorithms=["RS256"], audience=self._scoped_project_id ) - except InvalidSignatureError as exception: + except jwt.InvalidSignatureError as exception: raise ValueError( 'The provided App Check token has an invalid signature.' ) from exception - except InvalidAudienceError as exception: + except jwt.InvalidAudienceError as exception: raise ValueError( 'The provided App Check token has an incorrect "aud" (audience) claim. ' f'Expected payload to include {self._scoped_project_id}.' ) from exception - except InvalidIssuerError as exception: + except jwt.InvalidIssuerError as exception: raise ValueError( 'The provided App Check token has an incorrect "iss" (issuer) claim. ' f'Expected claim to include {self._APP_CHECK_ISSUER}' ) from exception - except ExpiredSignatureError as exception: + except jwt.ExpiredSignatureError as exception: raise ValueError( 'The provided App Check token has expired.' ) from exception - except InvalidTokenError as exception: + except jwt.InvalidTokenError as exception: raise ValueError( f'Decoding App Check token failed. Error: {exception}' ) from exception @@ -138,7 +144,7 @@ def _decode_and_verify(self, token: str, signing_key: str): audience = payload.get('aud') if not isinstance(audience, list) or self._scoped_project_id not in audience: raise ValueError('Firebase App Check token has incorrect "aud" (audience) claim.') - if not payload.get('iss').startswith(self._APP_CHECK_ISSUER): + if not cast(str, payload['iss']).startswith(self._APP_CHECK_ISSUER): raise ValueError('Token does not contain the correct "iss" (issuer).') _Validators.check_string( 'The provided App Check token "sub" (subject) claim', @@ -146,6 +152,7 @@ def _decode_and_verify(self, token: str, signing_key: str): return payload + class _Validators: """A collection of data validation utilities. @@ -153,7 +160,7 @@ class _Validators: """ @classmethod - def check_string(cls, label: str, value: Any): + def check_string(cls, label: str, value: Any) -> None: """Checks if the given value is a string.""" if value is None: raise ValueError(f'{label} "{value}" must be a non-empty string.') diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index ced143112..dfe9244cc 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -19,6 +19,11 @@ creating and managing user accounts in Firebase projects. """ +import datetime +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Optional, Union + +import firebase_admin from firebase_admin import _auth_client from firebase_admin import _auth_providers from firebase_admin import _auth_utils @@ -28,11 +33,10 @@ from firebase_admin import _user_mgt from firebase_admin import _utils +if TYPE_CHECKING: + from _typeshed import ConvertibleToInt -_AUTH_ATTRIBUTE = '_auth' - - -__all__ = [ +__all__ = ( 'ActionCodeSettings', 'CertificateFetchError', 'Client', @@ -107,7 +111,9 @@ 'update_user', 'verify_id_token', 'verify_session_cookie', -] +) + +_AUTH_ATTRIBUTE = '_auth' ActionCodeSettings = _user_mgt.ActionCodeSettings CertificateFetchError = _token_gen.CertificateFetchError @@ -156,7 +162,7 @@ ProviderIdentifier = _user_identifier.ProviderIdentifier -def _get_client(app): +def _get_client(app: Optional[firebase_admin.App]) -> Client: """Returns a client instance for an App. If the App already has a client associated with it, simply returns @@ -175,7 +181,11 @@ def _get_client(app): return _utils.get_app_service(app, _AUTH_ATTRIBUTE, Client) -def create_custom_token(uid, developer_claims=None, app=None): +def create_custom_token( + uid: str, + developer_claims: Optional[dict[str, Any]] = None, + app: Optional[firebase_admin.App] = None, +) -> bytes: """Builds and signs a Firebase custom auth token. Args: @@ -195,7 +205,12 @@ def create_custom_token(uid, developer_claims=None, app=None): return client.create_custom_token(uid, developer_claims) -def verify_id_token(id_token, app=None, check_revoked=False, clock_skew_seconds=0): +def verify_id_token( + id_token: Union[bytes, str], + app: Optional[firebase_admin.App] = None, + check_revoked: bool = False, + clock_skew_seconds: int = 0, +) -> dict[str, Any]: """Verifies the signature and data for the provided JWT. Accepts a signed token string, verifies that it is current, and issued @@ -226,7 +241,11 @@ def verify_id_token(id_token, app=None, check_revoked=False, clock_skew_seconds= id_token, check_revoked=check_revoked, clock_skew_seconds=clock_skew_seconds) -def create_session_cookie(id_token, expires_in, app=None): +def create_session_cookie( + id_token: Union[bytes, str], + expires_in: Union[datetime.timedelta, int], + app: Optional[firebase_admin.App] = None, +) -> str: """Creates a new Firebase session cookie from the given ID token and options. The returned JWT can be set as a server-side session cookie with a custom cookie policy. @@ -249,7 +268,12 @@ def create_session_cookie(id_token, expires_in, app=None): return client._token_generator.create_session_cookie(id_token, expires_in) -def verify_session_cookie(session_cookie, check_revoked=False, app=None, clock_skew_seconds=0): +def verify_session_cookie( + session_cookie: Union[bytes, str], + check_revoked: bool = False, + app: Optional[firebase_admin.App] = None, + clock_skew_seconds: int = 0, +) -> dict[str, Any]: """Verifies a Firebase session cookie. Accepts a session cookie string, verifies that it is current, and issued @@ -285,7 +309,7 @@ def verify_session_cookie(session_cookie, check_revoked=False, app=None, clock_s return verified_claims -def revoke_refresh_tokens(uid, app=None): +def revoke_refresh_tokens(uid: str, app: Optional[firebase_admin.App] = None) -> None: """Revokes all refresh tokens for an existing user. This function updates the user's ``tokens_valid_after_timestamp`` to the current UTC @@ -309,7 +333,7 @@ def revoke_refresh_tokens(uid, app=None): client.revoke_refresh_tokens(uid) -def get_user(uid, app=None): +def get_user(uid: str, app: Optional[firebase_admin.App] = None) -> _user_mgt.UserRecord: """Gets the user data corresponding to the specified user ID. Args: @@ -328,7 +352,10 @@ def get_user(uid, app=None): return client.get_user(uid=uid) -def get_user_by_email(email, app=None): +def get_user_by_email( + email: str, + app: Optional[firebase_admin.App] = None, +) -> _user_mgt.UserRecord: """Gets the user data corresponding to the specified user email. Args: @@ -347,7 +374,10 @@ def get_user_by_email(email, app=None): return client.get_user_by_email(email=email) -def get_user_by_phone_number(phone_number, app=None): +def get_user_by_phone_number( + phone_number: str, + app: Optional[firebase_admin.App] = None, +) -> _user_mgt.UserRecord: """Gets the user data corresponding to the specified phone number. Args: @@ -366,7 +396,10 @@ def get_user_by_phone_number(phone_number, app=None): return client.get_user_by_phone_number(phone_number=phone_number) -def get_users(identifiers, app=None): +def get_users( + identifiers: Sequence[_user_identifier.UserIdentifier], + app: Optional[firebase_admin.App] = None, +) -> _user_mgt.GetUsersResult: """Gets the user data corresponding to the specified identifiers. There are no ordering guarantees; in particular, the nth entry in the @@ -394,7 +427,11 @@ def get_users(identifiers, app=None): return client.get_users(identifiers) -def list_users(page_token=None, max_results=_user_mgt.MAX_LIST_USERS_RESULTS, app=None): +def list_users( + page_token: Optional[str] = None, + max_results: int = _user_mgt.MAX_LIST_USERS_RESULTS, + app: Optional[firebase_admin.App] = None, +) -> _user_mgt.ListUsersPage: """Retrieves a page of user accounts from a Firebase project. The ``page_token`` argument governs the starting point of the page. The ``max_results`` @@ -420,7 +457,18 @@ def list_users(page_token=None, max_results=_user_mgt.MAX_LIST_USERS_RESULTS, ap return client.list_users(page_token=page_token, max_results=max_results) -def create_user(**kwargs): # pylint: disable=differing-param-doc +def create_user( + uid: Optional[str] = None, + display_name: Optional[str] = None, + email: Optional[str] = None, + email_verified: Optional[bool] = None, + phone_number: Optional[str] = None, + photo_url: Optional[str] = None, + password: Optional[str] = None, + disabled: Optional[bool] = None, + app: Optional[firebase_admin.App] = None, + **kwargs: Any, +) -> _user_mgt.UserRecord: # pylint: disable=differing-param-doc """Creates a new user account with the specified properties. Args: @@ -445,12 +493,28 @@ def create_user(**kwargs): # pylint: disable=differing-param-doc ValueError: If the specified user properties are invalid. FirebaseError: If an error occurs while creating the user account. """ - app = kwargs.pop('app', None) client = _get_client(app) - return client.create_user(**kwargs) - - -def update_user(uid, **kwargs): # pylint: disable=differing-param-doc + return client.create_user(uid=uid, display_name=display_name, email=email, + email_verified=email_verified, phone_number=phone_number, photo_url=photo_url, + password=password, disabled=disabled, **kwargs) + + +def update_user( + uid: str, + *, + display_name: Optional[str] = None, + email: Optional[str] = None, + email_verified: Optional[bool] = None, + phone_number: Optional[str] = None, + photo_url: Optional[str] = None, + password: Optional[str] = None, + disabled: Optional[bool] = None, + custom_claims: Optional[Union[dict[str, Any], str]] = None, + valid_since: Optional['ConvertibleToInt'] = None, + providers_to_delete: Optional[list[str]] = None, + app: Optional[firebase_admin.App] = None, + **kwargs: Any, +) -> _user_mgt.UserRecord: """Updates an existing user account with the specified properties. Args: @@ -473,6 +537,8 @@ def update_user(uid, **kwargs): # pylint: disable=differing-param-doc user account (optional). To remove all custom claims, pass ``auth.DELETE_ATTRIBUTE``. valid_since: An integer signifying the seconds since the epoch (optional). This field is set by ``revoke_refresh_tokens`` and it is discouraged to set this field directly. + providers_to_delete: The list of provider IDs to unlink, + eg: 'google.com', 'password', etc. app: An App instance (optional). Returns: @@ -482,12 +548,18 @@ def update_user(uid, **kwargs): # pylint: disable=differing-param-doc ValueError: If the specified user ID or properties are invalid. FirebaseError: If an error occurs while updating the user account. """ - app = kwargs.pop('app', None) client = _get_client(app) - return client.update_user(uid, **kwargs) + return client.update_user(uid, display_name=display_name, email=email, + phone_number=phone_number, photo_url=photo_url, password=password, disabled=disabled, + email_verified=email_verified, valid_since=valid_since, custom_claims=custom_claims, + providers_to_delete=providers_to_delete, **kwargs) -def set_custom_user_claims(uid, custom_claims, app=None): +def set_custom_user_claims( + uid: str, + custom_claims: Optional[Union[dict[str, Any], str]], + app: Optional[firebase_admin.App] = None, +) -> None: """Sets additional claims on an existing user account. Custom claims set via this function can be used to define user roles and privilege levels. @@ -511,7 +583,7 @@ def set_custom_user_claims(uid, custom_claims, app=None): client.set_custom_user_claims(uid, custom_claims=custom_claims) -def delete_user(uid, app=None): +def delete_user(uid: str, app: Optional[firebase_admin.App] = None) -> None: """Deletes the user identified by the specified user ID. Args: @@ -526,7 +598,10 @@ def delete_user(uid, app=None): client.delete_user(uid) -def delete_users(uids, app=None): +def delete_users( + uids: Sequence[str], + app: Optional[firebase_admin.App] = None, +) -> _user_mgt.DeleteUsersResult: """Deletes the users specified by the given identifiers. Deleting a non-existing user does not generate an error (the method is @@ -553,7 +628,11 @@ def delete_users(uids, app=None): return client.delete_users(uids) -def import_users(users, hash_alg=None, app=None): +def import_users( + users: Sequence[_user_import.ImportUserRecord], + hash_alg: Optional[_user_import.UserImportHash] = None, + app: Optional[firebase_admin.App] = None, +) -> _user_import.UserImportResult: """Imports the specified list of users into Firebase Auth. At most 1000 users can be imported at a time. This operation is optimized for bulk imports and @@ -579,7 +658,11 @@ def import_users(users, hash_alg=None, app=None): return client.import_users(users, hash_alg) -def generate_password_reset_link(email, action_code_settings=None, app=None): +def generate_password_reset_link( + email: Optional[str], + action_code_settings: Optional[_user_mgt.ActionCodeSettings] = None, + app: Optional[firebase_admin.App] = None, +) -> str: """Generates the out-of-band email action link for password reset flows for the specified email address. @@ -600,7 +683,11 @@ def generate_password_reset_link(email, action_code_settings=None, app=None): return client.generate_password_reset_link(email, action_code_settings=action_code_settings) -def generate_email_verification_link(email, action_code_settings=None, app=None): +def generate_email_verification_link( + email: Optional[str], + action_code_settings: Optional[_user_mgt.ActionCodeSettings] = None, + app: Optional[firebase_admin.App] = None, +) -> str: """Generates the out-of-band email action link for email verification flows for the specified email address. @@ -622,7 +709,11 @@ def generate_email_verification_link(email, action_code_settings=None, app=None) email, action_code_settings=action_code_settings) -def generate_sign_in_with_email_link(email, action_code_settings, app=None): +def generate_sign_in_with_email_link( + email: Optional[str], + action_code_settings: Optional[_user_mgt.ActionCodeSettings], + app: Optional[firebase_admin.App] = None, +) -> str: """Generates the out-of-band email action link for email link sign-in flows, using the action code settings provided. @@ -645,7 +736,10 @@ def generate_sign_in_with_email_link(email, action_code_settings, app=None): email, action_code_settings=action_code_settings) -def get_oidc_provider_config(provider_id, app=None): +def get_oidc_provider_config( + provider_id: str, + app: Optional[firebase_admin.App] = None, +) -> _auth_providers.OIDCProviderConfig: """Returns the ``OIDCProviderConfig`` with the given ID. Args: @@ -663,9 +757,18 @@ def get_oidc_provider_config(provider_id, app=None): client = _get_client(app) return client.get_oidc_provider_config(provider_id) + def create_oidc_provider_config( - provider_id, client_id, issuer, display_name=None, enabled=None, client_secret=None, - id_token_response_type=None, code_response_type=None, app=None): + provider_id: str, + client_id: str, + issuer: str, + display_name: Optional[str] = None, + enabled: Optional[bool] = None, + client_secret: Optional[str] = None, + id_token_response_type: Optional[bool] = None, + code_response_type: Optional[bool] = None, + app: Optional[firebase_admin.App] = None, +) -> _auth_providers.OIDCProviderConfig: """Creates a new OIDC provider config from the given parameters. OIDC provider support requires Google Cloud's Identity Platform (GCIP). To learn more about @@ -705,8 +808,16 @@ def create_oidc_provider_config( def update_oidc_provider_config( - provider_id, client_id=None, issuer=None, display_name=None, enabled=None, - client_secret=None, id_token_response_type=None, code_response_type=None, app=None): + provider_id: str, + client_id: Optional[str] = None, + issuer: Optional[str] = None, + display_name: Optional[str] = None, + enabled: Optional[bool] = None, + client_secret: Optional[str] = None, + id_token_response_type: Optional[bool] = None, + code_response_type: Optional[bool] = None, + app: Optional[firebase_admin.App] = None, +) -> _auth_providers.OIDCProviderConfig: """Updates an existing OIDC provider config with the given parameters. Args: @@ -717,16 +828,16 @@ def update_oidc_provider_config( Pass ``auth.DELETE_ATTRIBUTE`` to delete the current display name. enabled: A boolean indicating whether the provider configuration is enabled or disabled (optional). - app: An App instance (optional). client_secret: A string which sets the client secret for the new provider. This is required for the code flow. + id_token_response_type: A boolean which sets whether to enable the ID token response flow + for the new provider. By default, this is enabled if no response type is specified. + Having both the code and ID token response flows is currently not supported. code_response_type: A boolean which sets whether to enable the code response flow for the new provider. By default, this is not enabled if no response type is specified. A client secret must be set for this response type. Having both the code and ID token response flows is currently not supported. - id_token_response_type: A boolean which sets whether to enable the ID token response flow - for the new provider. By default, this is enabled if no response type is specified. - Having both the code and ID token response flows is currently not supported. + app: An App instance (optional). Returns: OIDCProviderConfig: The updated OIDC provider config instance. @@ -742,7 +853,10 @@ def update_oidc_provider_config( code_response_type=code_response_type) -def delete_oidc_provider_config(provider_id, app=None): +def delete_oidc_provider_config( + provider_id: str, + app: Optional[firebase_admin.App] = None, +) -> None: """Deletes the ``OIDCProviderConfig`` with the given ID. Args: @@ -759,7 +873,10 @@ def delete_oidc_provider_config(provider_id, app=None): def list_oidc_provider_configs( - page_token=None, max_results=_auth_providers.MAX_LIST_CONFIGS_RESULTS, app=None): + page_token: Optional[str] = None, + max_results: int = _auth_providers.MAX_LIST_CONFIGS_RESULTS, + app: Optional[firebase_admin.App] = None, +) -> _auth_providers._ListOIDCProviderConfigsPage: """Retrieves a page of OIDC provider configs from a Firebase project. The ``page_token`` argument governs the starting point of the page. The ``max_results`` @@ -786,7 +903,10 @@ def list_oidc_provider_configs( return client.list_oidc_provider_configs(page_token, max_results) -def get_saml_provider_config(provider_id, app=None): +def get_saml_provider_config( + provider_id: str, + app: Optional[firebase_admin.App] = None, +) -> _auth_providers.SAMLProviderConfig: """Returns the ``SAMLProviderConfig`` with the given ID. Args: @@ -806,8 +926,16 @@ def get_saml_provider_config(provider_id, app=None): def create_saml_provider_config( - provider_id, idp_entity_id, sso_url, x509_certificates, rp_entity_id, callback_url, - display_name=None, enabled=None, app=None): + provider_id: str, + idp_entity_id: str, + sso_url: str, + x509_certificates: list[str], + rp_entity_id: str, + callback_url: str, + display_name: Optional[str] = None, + enabled: Optional[bool] = None, + app: Optional[firebase_admin.App] = None, +) -> _auth_providers.SAMLProviderConfig: """Creates a new SAML provider config from the given parameters. SAML provider support requires Google Cloud's Identity Platform (GCIP). To learn more about @@ -848,8 +976,16 @@ def create_saml_provider_config( def update_saml_provider_config( - provider_id, idp_entity_id=None, sso_url=None, x509_certificates=None, - rp_entity_id=None, callback_url=None, display_name=None, enabled=None, app=None): + provider_id: str, + idp_entity_id: Optional[str] = None, + sso_url: Optional[str] = None, + x509_certificates: Optional[list[str]] = None, + rp_entity_id: Optional[str] = None, + callback_url: Optional[str] = None, + display_name: Optional[str] = None, + enabled: Optional[bool] = None, + app: Optional[firebase_admin.App] = None, +) -> _auth_providers.SAMLProviderConfig: """Updates an existing SAML provider config with the given parameters. Args: @@ -880,7 +1016,10 @@ def update_saml_provider_config( callback_url=callback_url, display_name=display_name, enabled=enabled) -def delete_saml_provider_config(provider_id, app=None): +def delete_saml_provider_config( + provider_id: str, + app: Optional[firebase_admin.App] = None, +) -> None: """Deletes the ``SAMLProviderConfig`` with the given ID. Args: @@ -897,7 +1036,10 @@ def delete_saml_provider_config(provider_id, app=None): def list_saml_provider_configs( - page_token=None, max_results=_auth_providers.MAX_LIST_CONFIGS_RESULTS, app=None): + page_token: Optional[str] = None, + max_results: int = _auth_providers.MAX_LIST_CONFIGS_RESULTS, + app: Optional[firebase_admin.App] = None, +) -> _auth_providers._ListSAMLProviderConfigsPage: """Retrieves a page of SAML provider configs from a Firebase project. The ``page_token`` argument governs the starting point of the page. The ``max_results`` diff --git a/firebase_admin/credentials.py b/firebase_admin/credentials.py index 7117b71a9..9af328f52 100644 --- a/firebase_admin/credentials.py +++ b/firebase_admin/credentials.py @@ -13,16 +13,30 @@ # limitations under the License. """Firebase credentials module.""" -import collections +import datetime import json import pathlib +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union, cast +from typing_extensions import TypeGuard import google.auth + from google.auth.credentials import Credentials as GoogleAuthCredentials from google.auth.transport import requests +from google.auth import crypt from google.oauth2 import credentials from google.oauth2 import service_account +if TYPE_CHECKING: + from _typeshed import StrPath + +__all__ = ( + 'AccessTokenInfo', + 'ApplicationDefault', + 'Base', + 'Certificate', + 'RefreshToken', +) _request = requests.Request() _scopes = [ @@ -34,18 +48,21 @@ 'https://www.googleapis.com/auth/userinfo.email' ] -AccessTokenInfo = collections.namedtuple('AccessTokenInfo', ['access_token', 'expiry']) -"""Data included in an OAuth2 access token. -Contains the access token string and the expiry time. The expirty time is exposed as a -``datetime`` value. -""" +class AccessTokenInfo(NamedTuple): + """Data included in an OAuth2 access token. + + Contains the access token string and the expiry time. The expirty time is exposed as a + ``datetime`` value. + """ + access_token: Any + expiry: Optional[datetime.datetime] class Base: """Provides OAuth2 access tokens for accessing Firebase services.""" - def get_access_token(self): + def get_access_token(self) -> AccessTokenInfo: """Fetches a Google OAuth2 access token using this credential instance. Returns: @@ -55,30 +72,31 @@ def get_access_token(self): google_cred.refresh(_request) return AccessTokenInfo(google_cred.token, google_cred.expiry) - def get_credential(self): + def get_credential(self) -> GoogleAuthCredentials: """Returns the Google credential instance used for authentication.""" raise NotImplementedError -class _ExternalCredentials(Base): + +class _ExternalCredentials(Base): # pyright: ignore[reportUnusedClass] """A wrapper for google.auth.credentials.Credentials typed credential instances""" - def __init__(self, credential: GoogleAuthCredentials): - super().__init__() + def __init__(self, credential: GoogleAuthCredentials) -> None: self._g_credential = credential - def get_credential(self): + def get_credential(self) -> GoogleAuthCredentials: """Returns the underlying Google Credential Returns: google.auth.credentials.Credentials: A Google Auth credential instance.""" return self._g_credential + class Certificate(Base): """A credential initialized from a JSON certificate keyfile.""" _CREDENTIAL_TYPE = 'service_account' - def __init__(self, cert): + def __init__(self, cert: Union['StrPath', dict[str, Any]]) -> None: """Initializes a credential from a Google service account certificate. Service account certificates can be downloaded as JSON files from the Firebase console. @@ -92,7 +110,6 @@ def __init__(self, cert): IOError: If the specified certificate file doesn't exist or cannot be read. ValueError: If the specified certificate is invalid. """ - super().__init__() if _is_file_path(cert): with open(cert, encoding='utf-8') as json_file: json_data = json.load(json_file) @@ -115,18 +132,18 @@ def __init__(self, cert): f'Failed to initialize a certificate credential. Caused by: "{error}"') from error @property - def project_id(self): + def project_id(self) -> Optional[str]: return self._g_credential.project_id @property - def signer(self): + def signer(self) -> crypt.Signer: return self._g_credential.signer @property - def service_account_email(self): + def service_account_email(self) -> str: return self._g_credential.service_account_email - def get_credential(self): + def get_credential(self) -> GoogleAuthCredentials: """Returns the underlying Google credential. Returns: @@ -137,16 +154,17 @@ def get_credential(self): class ApplicationDefault(Base): """A Google Application Default credential.""" - def __init__(self): + def __init__(self) -> None: """Creates an instance that will use Application Default credentials. The credentials will be lazily initialized when get_credential() or project_id() is called. See those methods for possible errors raised. """ - super().__init__() - self._g_credential = None # Will be lazily-loaded via _load_credential(). + # Will be lazily-loaded via _load_credential(). + self._g_credential: Optional[GoogleAuthCredentials] = None + self._project_id: Optional[str] - def get_credential(self): + def get_credential(self) -> GoogleAuthCredentials: """Returns the underlying Google credential. Raises: @@ -155,10 +173,10 @@ def get_credential(self): Returns: google.auth.credentials.Credentials: A Google Auth credential instance.""" self._load_credential() - return self._g_credential + return cast(GoogleAuthCredentials, self._g_credential) @property - def project_id(self): + def project_id(self) -> Optional[str]: """Returns the project_id from the underlying Google credential. Raises: @@ -169,16 +187,17 @@ def project_id(self): self._load_credential() return self._project_id - def _load_credential(self): + def _load_credential(self) -> None: if not self._g_credential: self._g_credential, self._project_id = google.auth.default(scopes=_scopes) + class RefreshToken(Base): """A credential initialized from an existing refresh token.""" _CREDENTIAL_TYPE = 'authorized_user' - def __init__(self, refresh_token): + def __init__(self, refresh_token: Union['StrPath', dict[str, Any]]) -> None: """Initializes a credential from a refresh token JSON file. The JSON must consist of client_id, client_secret and refresh_token fields. Refresh @@ -194,7 +213,6 @@ def __init__(self, refresh_token): IOError: If the specified file doesn't exist or cannot be read. ValueError: If the refresh token configuration is invalid. """ - super().__init__() if _is_file_path(refresh_token): with open(refresh_token, encoding='utf-8') as json_file: json_data = json.load(json_file) @@ -212,18 +230,18 @@ def __init__(self, refresh_token): self._g_credential = credentials.Credentials.from_authorized_user_info(json_data, _scopes) @property - def client_id(self): + def client_id(self) -> Optional[str]: return self._g_credential.client_id @property - def client_secret(self): + def client_secret(self) -> Optional[str]: return self._g_credential.client_secret @property - def refresh_token(self): + def refresh_token(self) -> Optional[str]: return self._g_credential.refresh_token - def get_credential(self): + def get_credential(self) -> GoogleAuthCredentials: """Returns the underlying Google credential. Returns: @@ -231,7 +249,7 @@ def get_credential(self): return self._g_credential -def _is_file_path(path): +def _is_file_path(path: Any) -> TypeGuard['StrPath']: try: pathlib.Path(path) return True diff --git a/firebase_admin/db.py b/firebase_admin/db.py index 800cbf8e3..de9cb520b 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -25,8 +25,21 @@ import os import sys import threading +from collections.abc import Callable +from typing import ( + Any, + Generic, + Literal, + NamedTuple, + Optional, + Union, + cast, + overload, +) +from typing_extensions import Self, TypeVar from urllib import parse +import google.auth.credentials import requests import firebase_admin @@ -35,6 +48,19 @@ from firebase_admin import _sseclient from firebase_admin import _utils +__all__ = ( + 'EmulatorConfig', + 'Event', + 'ListenerRegistration', + 'Query', + 'Reference', + 'TransactionAbortedError', + 'reference', +) + +_K = TypeVar('_K', default=Any) +_V = TypeVar('_V', default=Any) +_JsonT = TypeVar('_JsonT', bound='_Json', default='_Json') _DB_ATTRIBUTE = '_database' _INVALID_PATH_CHARACTERS = '[].?#$' @@ -45,9 +71,19 @@ ) _TRANSACTION_MAX_RETRIES = 25 _EMULATOR_HOST_ENV_VAR = 'FIREBASE_DATABASE_EMULATOR_HOST' - - -def reference(path='/', app=None, url=None): +_Json = Optional[Union[ + dict[str, '_Json'], + list['_Json'], + str, + float, +]] + + +def reference( + path: str = '/', + app: Optional[firebase_admin.App] = None, + url: Optional[str] = None, +) -> 'Reference': """Returns a database ``Reference`` representing the node at the specified path. If no path is specified, this function returns a ``Reference`` that represents the database @@ -71,7 +107,8 @@ def reference(path='/', app=None, url=None): client = service.get_client(url) return Reference(client=client, path=path) -def _parse_path(path): + +def _parse_path(path: Any) -> list[str]: """Parses a path string into a set of segments.""" if not isinstance(path, str): raise ValueError(f'Invalid path: "{path}". Path must be a string.') @@ -83,7 +120,7 @@ def _parse_path(path): class Event: """Represents a realtime update event received from the database.""" - def __init__(self, sse_event): + def __init__(self, sse_event: _sseclient.Event) -> None: self._sse_event = sse_event self._data = json.loads(sse_event.data) @@ -98,7 +135,7 @@ def path(self): return self._data['path'] @property - def event_type(self): + def event_type(self) -> str: """Event type string (put, patch).""" return self._sse_event.event_type @@ -106,7 +143,11 @@ def event_type(self): class ListenerRegistration: """Represents the addition of an event listener to a database reference.""" - def __init__(self, callback, sse): + def __init__( + self, + callback: Callable[[Event], None], + sse: _sseclient.SSEClient, + ) -> None: """Initializes a new listener with given parameters. This is an internal API. Use the ``db.Reference.listen()`` method to start a @@ -121,14 +162,14 @@ def __init__(self, callback, sse): self._thread = threading.Thread(target=self._start_listen) self._thread.start() - def _start_listen(self): + def _start_listen(self) -> None: # iterate the sse client's generator for sse_event in self._sse: # only inject data events if sse_event: self._callback(Event(sse_event)) - def close(self): + def close(self) -> None: """Stops the event listener represented by this registration This closes the SSE HTTP connection, and joins the background thread. @@ -140,36 +181,59 @@ def close(self): class Reference: """Reference represents a node in the Firebase realtime database.""" - def __init__(self, **kwargs): + @overload + def __init__( + self, + *, + segments: list[str], + client: Optional['_Client'] = None, + **kwargs: Any, + ) -> None: ... + @overload + def __init__( + self, + *, + path: str, + client: Optional['_Client'] = None, + **kwargs: Any, + ) -> None: ... + def __init__( + self, + *, + path: Optional[str] = None, + segments: Optional[list[str]] = None, + client: Optional['_Client'] = None, + **kwargs: Any, + ) -> None: """Creates a new Reference using the provided parameters. This method is for internal use only. Use db.reference() to obtain an instance of Reference. """ - self._client = kwargs.get('client') - if 'segments' in kwargs: - self._segments = kwargs.get('segments') + self._client = client + if segments is not None: + self._segments = segments else: - self._segments = _parse_path(kwargs.get('path')) + self._segments = _parse_path(path) self._pathurl = '/' + '/'.join(self._segments) @property - def key(self): + def key(self) -> Optional[str]: if self._segments: return self._segments[-1] return None @property - def path(self): + def path(self) -> str: return self._pathurl @property - def parent(self): + def parent(self) -> Optional['Reference']: if self._segments: return Reference(client=self._client, segments=self._segments[:-1]) return None - def child(self, path): + def child(self, path: Optional[str]) -> 'Reference': """Returns a Reference to the specified child node. The path may point to an immediate child of the current Reference, or a deeply nested @@ -191,7 +255,23 @@ def child(self, path): full_path = self._pathurl + '/' + path return Reference(client=self._client, path=full_path) - def get(self, etag=False, shallow=False): + @overload + def get( # pyright: ignore[reportOverlappingOverload] + self, + etag: Literal[True], + shallow: bool = False, + ) -> tuple[_Json, str]: ... + @overload + def get( + self, + etag: bool = False, + shallow: bool = False, + ) -> _Json: ... + def get( + self, + etag: bool = False, + shallow: bool = False, + ) -> Union[tuple[_Json, str], _Json]: """Returns the value, and optionally the ETag, at the current location of the database. Args: @@ -214,12 +294,12 @@ def get(self, etag=False, shallow=False): raise ValueError('etag and shallow cannot both be set to True.') headers, data = self._client.headers_and_body( 'get', self._add_suffix(), headers={'X-Firebase-ETag' : 'true'}) - return data, headers.get('ETag') + return data, cast(str, headers.get('ETag')) params = 'shallow=true' if shallow else None return self._client.body('get', self._add_suffix(), params=params) - def get_if_changed(self, etag): + def get_if_changed(self, etag: str) -> tuple[bool, Optional[Any], Optional[str]]: """Gets data in this location only if the specified ETag does not match. Args: @@ -245,7 +325,7 @@ def get_if_changed(self, etag): return True, resp.json(), resp.headers.get('ETag') - def set(self, value): + def set(self, value: _Json) -> None: """Sets the data at this location to the given value. The value must be JSON-serializable and not None. @@ -262,7 +342,11 @@ def set(self, value): raise ValueError('Value must not be None.') self._client.request('put', self._add_suffix(), json=value, params='print=silent') - def set_if_unchanged(self, expected_etag, value): + def set_if_unchanged( + self, + expected_etag: str, + value: _JsonT + ) -> tuple[bool, _JsonT, str]: """Conditonally sets the data at this location to the given value. Sets the data at this location to the given value only if ``expected_etag`` is same as the @@ -290,7 +374,7 @@ def set_if_unchanged(self, expected_etag, value): try: headers = self._client.headers( 'put', self._add_suffix(), json=value, headers={'if-match': expected_etag}) - return True, value, headers.get('ETag') + return True, value, cast(str, headers.get('ETag')) except exceptions.FailedPreconditionError as error: http_response = error.http_response if http_response is not None and 'ETag' in http_response.headers: @@ -300,7 +384,7 @@ def set_if_unchanged(self, expected_etag, value): raise error - def push(self, value=''): + def push(self, value: _Json = '') -> 'Reference': """Creates a new child node. The optional value argument can be used to provide an initial value for the child node. If @@ -320,10 +404,10 @@ def push(self, value=''): if value is None: raise ValueError('Value must not be None.') output = self._client.body('post', self._add_suffix(), json=value) - push_id = output.get('name') + push_id = cast(Optional[str], output.get('name')) return self.child(push_id) - def update(self, value): + def update(self, value: _Json) -> None: """Updates the specified child keys of this Reference to the provided values. Args: @@ -339,7 +423,7 @@ def update(self, value): raise ValueError('Dictionary must not contain None keys.') self._client.request('patch', self._add_suffix(), json=value, params='print=silent') - def delete(self): + def delete(self) -> None: """Deletes this node from the database. Raises: @@ -347,7 +431,7 @@ def delete(self): """ self._client.request('delete', self._add_suffix()) - def listen(self, callback): + def listen(self, callback: Callable[[Event], None]) -> ListenerRegistration: """Registers the ``callback`` function to receive realtime updates. The specified callback function will get invoked with ``db.Event`` objects for each @@ -373,7 +457,7 @@ def listen(self, callback): """ return self._listen_with_session(callback) - def transaction(self, transaction_update): + def transaction(self, transaction_update: Callable[[_Json], _Json]) -> _Json: """Atomically modifies the data at this location. Unlike a normal ``set()``, which just overwrites the data regardless of its previous state, @@ -416,7 +500,7 @@ def transaction(self, transaction_update): raise TransactionAbortedError('Transaction aborted after failed retries.') - def order_by_child(self, path): + def order_by_child(self, path: str) -> 'Query': """Returns a Query that orders data by child values. Returned Query can be used to set additional parameters, and execute complex database @@ -435,7 +519,7 @@ def order_by_child(self, path): raise ValueError(f'Illegal child path: {path}') return Query(order_by=path, client=self._client, pathurl=self._add_suffix()) - def order_by_key(self): + def order_by_key(self) -> 'Query': """Creates a Query that orderes data by key. Returned Query can be used to set additional parameters, and execute complex database @@ -446,7 +530,7 @@ def order_by_key(self): """ return Query(order_by='$key', client=self._client, pathurl=self._add_suffix()) - def order_by_value(self): + def order_by_value(self) -> 'Query': """Creates a Query that orderes data by value. Returned Query can be used to set additional parameters, and execute complex database @@ -457,16 +541,20 @@ def order_by_value(self): """ return Query(order_by='$value', client=self._client, pathurl=self._add_suffix()) - def _add_suffix(self, suffix='.json'): + def _add_suffix(self, suffix: str = '.json') -> str: return self._pathurl + suffix - def _listen_with_session(self, callback, session=None): + def _listen_with_session( + self, + callback: Callable[[Event], None], + session: Optional[requests.Session] = None, + ) -> ListenerRegistration: url = self._client.base_url + self._add_suffix() if not session: session = self._client.create_listener_session() try: - sse = _sseclient.SSEClient(url, session, **{"params": self._client.params}) + sse = _sseclient.SSEClient(url, session, params=self._client.params) return ListenerRegistration(callback, sse) except requests.exceptions.RequestException as error: raise _Client.handle_rtdb_error(error) @@ -485,8 +573,7 @@ class Query: OrderedDict. """ - def __init__(self, **kwargs): - order_by = kwargs.pop('order_by') + def __init__(self, *, client: '_Client', order_by: str, pathurl: str, **kwargs: Any) -> None: if not order_by or not isinstance(order_by, str): raise ValueError('order_by field must be a non-empty string') if order_by not in _RESERVED_FILTERS: @@ -495,14 +582,14 @@ def __init__(self, **kwargs): f'Invalid path argument: "{order_by}". Child path must not start with "/"') segments = _parse_path(order_by) order_by = '/'.join(segments) - self._client = kwargs.pop('client') - self._pathurl = kwargs.pop('pathurl') + self._client = client + self._pathurl = pathurl self._order_by = order_by - self._params = {'orderBy' : json.dumps(order_by)} + self._params: dict[str, Any] = {'orderBy' : json.dumps(order_by)} if kwargs: raise ValueError(f'Unexpected keyword arguments: {kwargs}') - def limit_to_first(self, limit): + def limit_to_first(self, limit: int) -> Self: """Creates a query with limit, and anchors it to the start of the window. Args: @@ -521,7 +608,7 @@ def limit_to_first(self, limit): self._params['limitToFirst'] = limit return self - def limit_to_last(self, limit): + def limit_to_last(self, limit: int) -> Self: """Creates a query with limit, and anchors it to the end of the window. Args: @@ -540,7 +627,7 @@ def limit_to_last(self, limit): self._params['limitToLast'] = limit return self - def start_at(self, start): + def start_at(self, start: _Json) -> Self: """Sets the lower bound for a range query. The Query will only return child nodes with a value greater than or equal to the specified @@ -560,7 +647,7 @@ def start_at(self, start): self._params['startAt'] = json.dumps(start) return self - def end_at(self, end): + def end_at(self, end: _Json) -> Self: """Sets the upper bound for a range query. The Query will only return child nodes with a value less than or equal to the specified @@ -580,7 +667,7 @@ def end_at(self, end): self._params['endAt'] = json.dumps(end) return self - def equal_to(self, value): + def equal_to(self, value: _Json) -> Self: """Sets an equals constraint on the Query. The Query will only return child nodes whose value is equal to the specified value. @@ -600,13 +687,13 @@ def equal_to(self, value): return self @property - def _querystr(self): - params = [] + def _querystr(self) -> str: + params: list[str] = [] for key in sorted(self._params): params.append(f'{key}={self._params[key]}') return '&'.join(params) - def get(self): + def get(self) -> Union[dict[str, _Json], list[_Json]]: """Executes this Query and returns the results. The results will be returned as a sorted list or an OrderedDict. @@ -626,32 +713,40 @@ def get(self): class TransactionAbortedError(exceptions.AbortedError): """A transaction was aborted aftr exceeding the maximum number of retries.""" - def __init__(self, message): - exceptions.AbortedError.__init__(self, message) + def __init__(self, message: str) -> None: + super().__init__(message) -class _Sorter: +class _Sorter(Generic[_K, _V]): """Helper class for sorting query results.""" - def __init__(self, results, order_by): + @overload + def __init__(self, results: dict[_K, _V], order_by: str) -> None: ... + @overload + def __init__( + self: '_Sorter[int, _V]', # pyright: ignore[reportInvalidTypeVarUse] + results: list[_V], + order_by: str, + ) -> None: ... + def __init__(self, results: Union[dict[_K, _V], list[_V]], order_by: str) -> None: if isinstance(results, dict): self.dict_input = True entries = [_SortEntry(k, v, order_by) for k, v in results.items()] elif isinstance(results, list): self.dict_input = False - entries = [_SortEntry(k, v, order_by) for k, v in enumerate(results)] + entries = [_SortEntry(cast(_K, k), v, order_by) for k, v in enumerate(results)] else: raise ValueError(f'Sorting not supported for "{type(results)}" object.') self.sort_entries = sorted(entries) - def get(self): + def get(self) -> Union[collections.OrderedDict[_K, _V], list[_V]]: if self.dict_input: return collections.OrderedDict([(e.key, e.value) for e in self.sort_entries]) return [e.value for e in self.sort_entries] -class _SortEntry: +class _SortEntry(Generic[_K, _V]): """A wrapper that is capable of sorting items in a dictionary.""" _type_none = 0 @@ -661,7 +756,7 @@ class _SortEntry: _type_string = 4 _type_object = 5 - def __init__(self, key, value, order_by): + def __init__(self, key: _K, value: _V, order_by: str) -> None: self._key = key self._value = value if order_by in ('$key', '$priority'): @@ -673,23 +768,23 @@ def __init__(self, key, value, order_by): self._index_type = _SortEntry._get_index_type(self._index) @property - def key(self): + def key(self) -> _K: return self._key @property - def index(self): + def index(self) -> Optional[Any]: return self._index @property - def index_type(self): + def index_type(self) -> int: return self._index_type @property - def value(self): + def value(self) -> _V: return self._value @classmethod - def _get_index_type(cls, index): + def _get_index_type(cls, index: Any) -> int: """Assigns an integer code to the type of the index. The index type determines how differently typed values are sorted. This ordering is based @@ -709,17 +804,18 @@ def _get_index_type(cls, index): return cls._type_object @classmethod - def _extract_child(cls, value, path): + def _extract_child(cls, value: Any, path: str) -> Optional[Any]: segments = path.split('/') current = value for segment in segments: if isinstance(current, dict): + current = cast(dict[str, Any], current) current = current.get(segment) else: return None return current - def _compare(self, other): + def _compare(self, other: '_SortEntry') -> Literal[-1, 0, 1]: """Compares two _SortEntry instances. If the indices have the same numeric or string type, compare them directly. Ties are @@ -734,39 +830,44 @@ def _compare(self, other): else: self_key, other_key = self.key, other.key - if self_key < other_key: + if self_key < other_key: # pyright: ignore[reportOperatorIssue] return -1 - if self_key > other_key: + if self_key > other_key: # pyright: ignore[reportOperatorIssue] return 1 return 0 - def __lt__(self, other): + def __lt__(self, other: '_SortEntry') -> bool: return self._compare(other) < 0 - def __le__(self, other): + def __le__(self, other: '_SortEntry') -> bool: return self._compare(other) <= 0 - def __gt__(self, other): + def __gt__(self, other: '_SortEntry') -> bool: return self._compare(other) > 0 - def __ge__(self, other): + def __ge__(self, other: '_SortEntry') -> bool: return self._compare(other) >= 0 - def __eq__(self, other): + def __eq__(self, other: '_SortEntry') -> bool: # pyright: ignore[reportIncompatibleMethodOverride] return self._compare(other) == 0 +class EmulatorConfig(NamedTuple): + base_url: str + namespace: str + + class _DatabaseService: """Service that maintains a collection of database clients.""" _DEFAULT_AUTH_OVERRIDE = '_admin_' - def __init__(self, app): + def __init__(self, app: firebase_admin.App) -> None: self._credential = app.credential db_url = app.options.get('databaseURL') if db_url: - self._db_url = db_url + self._db_url: Optional[str] = db_url else: self._db_url = None @@ -776,7 +877,7 @@ def __init__(self, app): else: self._auth_override = None self._timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) - self._clients = {} + self._clients: dict[tuple[str, str], _Client] = {} emulator_host = os.environ.get(_EMULATOR_HOST_ENV_VAR) if emulator_host: @@ -788,7 +889,7 @@ def __init__(self, app): else: self._emulator_host = None - def get_client(self, db_url=None): + def get_client(self, db_url: Optional[str] = None) -> '_Client': """Creates a client based on the db_url. Clients may be cached.""" if db_url is None: db_url = self._db_url @@ -813,7 +914,6 @@ def get_client(self, db_url=None): base_url = f'https://{parsed_url.netloc}' params = {} - if self._auth_override: params['auth_variable_override'] = self._auth_override @@ -823,9 +923,8 @@ def get_client(self, db_url=None): self._clients[client_cache_key] = client return self._clients[client_cache_key] - def _get_emulator_config(self, parsed_url): + def _get_emulator_config(self, parsed_url: parse.ParseResult) -> Optional[EmulatorConfig]: """Checks whether the SDK should connect to the RTDB emulator.""" - EmulatorConfig = collections.namedtuple('EmulatorConfig', ['base_url', 'namespace']) if parsed_url.scheme != 'https': # Emulator mode enabled by passing http URL via AppOptions base_url, namespace = _DatabaseService._parse_emulator_url(parsed_url) @@ -839,7 +938,7 @@ def _get_emulator_config(self, parsed_url): return None @classmethod - def _parse_emulator_url(cls, parsed_url): + def _parse_emulator_url(cls, parsed_url: parse.ParseResult) -> tuple[str, str]: """Parses emulator URL like http://localhost:8080/?ns=foo-bar""" query_ns = parse.parse_qs(parsed_url.query).get('ns') if parsed_url.scheme != 'http' or (not query_ns or len(query_ns) != 1 or not query_ns[0]): @@ -852,9 +951,10 @@ def _parse_emulator_url(cls, parsed_url): return base_url, namespace @classmethod - def _get_auth_override(cls, app): + def _get_auth_override(cls, app: firebase_admin.App) -> Optional[Union[dict[str, Any], str]]: """Gets and validates the database auth override to be used.""" - auth_override = app.options.get('databaseAuthVariableOverride', cls._DEFAULT_AUTH_OVERRIDE) + auth_override = cast(Optional[str], app.options.get( + 'databaseAuthVariableOverride', cls._DEFAULT_AUTH_OVERRIDE)) if auth_override == cls._DEFAULT_AUTH_OVERRIDE or auth_override is None: return auth_override if not isinstance(auth_override, dict): @@ -864,7 +964,7 @@ def _get_auth_override(cls, app): return auth_override - def close(self): + def close(self) -> None: for value in self._clients.values(): value.close() self._clients = {} @@ -877,7 +977,13 @@ class _Client(_http_client.JsonHttpClient): marshalling and unmarshalling of JSON data. """ - def __init__(self, credential, base_url, timeout, params=None): + def __init__( + self, + credential: Optional[google.auth.credentials.Credentials], + base_url: str, + timeout: int, + params: Optional[dict[str, Any]] = None, + ) -> None: """Creates a new _Client from the given parameters. This exists primarily to enable testing. For regular use, obtain _Client instances by @@ -897,7 +1003,7 @@ def __init__(self, credential, base_url, timeout, params=None): self.credential = credential self.params = params if params else {} - def request(self, method, url, **kwargs): + def request(self, method: str, url: str, **kwargs: Any) -> requests.Response: """Makes an HTTP call using the Python requests library. Extends the request() method of the parent JsonHttpClient class. Handles default @@ -929,11 +1035,11 @@ def request(self, method, url, **kwargs): except requests.exceptions.RequestException as error: raise _Client.handle_rtdb_error(error) - def create_listener_session(self): + def create_listener_session(self) -> _sseclient.KeepAuthSession: return _sseclient.KeepAuthSession(self.credential) @classmethod - def handle_rtdb_error(cls, error): + def handle_rtdb_error(cls, error: requests.RequestException) -> exceptions.FirebaseError: """Converts an error encountered while calling RTDB into a FirebaseError.""" if error.response is None: return _utils.handle_requests_error(error) @@ -942,7 +1048,7 @@ def handle_rtdb_error(cls, error): return _utils.handle_requests_error(error, message=message) @classmethod - def _extract_error_message(cls, response): + def _extract_error_message(cls, response: requests.Response) -> str: """Extracts an error message from an error response. If the server has sent a JSON response with an 'error' field, which is the typical @@ -953,7 +1059,7 @@ def _extract_error_message(cls, response): message = None try: # RTDB error format: {"error": "text message"} - data = response.json() + data: dict[str, str] = response.json() if isinstance(data, dict): message = data.get('error') except ValueError: diff --git a/firebase_admin/exceptions.py b/firebase_admin/exceptions.py index 947f36806..00143a117 100644 --- a/firebase_admin/exceptions.py +++ b/firebase_admin/exceptions.py @@ -31,6 +31,46 @@ subtype error handlers. """ +from typing import Optional, Union + +import httpx +import requests + +__all__ = ( + 'ABORTED', + 'ALREADY_EXISTS', + 'CANCELLED', + 'CONFLICT', + 'DATA_LOSS', + 'DEADLINE_EXCEEDED', + 'FAILED_PRECONDITION', + 'INTERNAL', + 'INVALID_ARGUMENT', + 'NOT_FOUND', + 'OUT_OF_RANGE', + 'PERMISSION_DENIED', + 'RESOURCE_EXHAUSTED', + 'UNAUTHENTICATED', + 'UNAVAILABLE', + 'UNKNOWN', + 'AbortedError', + 'AlreadyExistsError', + 'CancelledError', + 'ConflictError', + 'DataLossError', + 'DeadlineExceededError', + 'FailedPreconditionError', + 'FirebaseError', + 'InternalError', + 'InvalidArgumentError', + 'NotFoundError', + 'OutOfRangeError', + 'PermissionDeniedError', + 'ResourceExhaustedError', + 'UnauthenticatedError', + 'UnavailableError', + 'UnknownError', +) #: Error code for ``InvalidArgumentError`` type. INVALID_ARGUMENT = 'INVALID_ARGUMENT' @@ -95,52 +135,78 @@ class FirebaseError(Exception): this object. """ - def __init__(self, code, message, cause=None, http_response=None): - Exception.__init__(self, message) + def __init__( + self, + code: str, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message) self._code = code self._cause = cause self._http_response = http_response @property - def code(self): + def code(self) -> str: return self._code @property - def cause(self): + def cause(self) -> Optional[Exception]: return self._cause @property - def http_response(self): + def http_response(self) -> Optional[Union[httpx.Response, requests.Response]]: return self._http_response class InvalidArgumentError(FirebaseError): """Client specified an invalid argument.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, INVALID_ARGUMENT, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(INVALID_ARGUMENT, message, cause, http_response) class FailedPreconditionError(FirebaseError): """Request can not be executed in the current system state, such as deleting a non-empty directory.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, FAILED_PRECONDITION, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(FAILED_PRECONDITION, message, cause, http_response) class OutOfRangeError(FirebaseError): """Client specified an invalid range.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, OUT_OF_RANGE, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(OUT_OF_RANGE, message, cause, http_response) class UnauthenticatedError(FirebaseError): """Request not authenticated due to missing, invalid, or expired OAuth token.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, UNAUTHENTICATED, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(UNAUTHENTICATED, message, cause, http_response) class PermissionDeniedError(FirebaseError): @@ -150,79 +216,134 @@ class PermissionDeniedError(FirebaseError): have permission, or the API has not been enabled for the client project. """ - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, PERMISSION_DENIED, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(PERMISSION_DENIED, message, cause, http_response) class NotFoundError(FirebaseError): """A specified resource is not found, or the request is rejected by undisclosed reasons, such as whitelisting.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, NOT_FOUND, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(NOT_FOUND, message, cause, http_response) class ConflictError(FirebaseError): """Concurrency conflict, such as read-modify-write conflict.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, CONFLICT, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(CONFLICT, message, cause, http_response) class AbortedError(FirebaseError): """Concurrency conflict, such as read-modify-write conflict.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, ABORTED, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(ABORTED, message, cause, http_response) class AlreadyExistsError(FirebaseError): """The resource that a client tried to create already exists.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, ALREADY_EXISTS, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(ALREADY_EXISTS, message, cause, http_response) class ResourceExhaustedError(FirebaseError): """Either out of resource quota or reaching rate limiting.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, RESOURCE_EXHAUSTED, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(RESOURCE_EXHAUSTED, message, cause, http_response) class CancelledError(FirebaseError): """Request cancelled by the client.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, CANCELLED, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(CANCELLED, message, cause, http_response) class DataLossError(FirebaseError): """Unrecoverable data loss or data corruption.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, DATA_LOSS, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(DATA_LOSS, message, cause, http_response) class UnknownError(FirebaseError): """Unknown server error.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, UNKNOWN, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(UNKNOWN, message, cause, http_response) class InternalError(FirebaseError): """Internal server error.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, INTERNAL, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(INTERNAL, message, cause, http_response) class UnavailableError(FirebaseError): """Service unavailable. Typically the server is down.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, UNAVAILABLE, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(UNAVAILABLE, message, cause, http_response) class DeadlineExceededError(FirebaseError): @@ -233,5 +354,10 @@ class DeadlineExceededError(FirebaseError): request) and the request did not finish within the deadline. """ - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, DEADLINE_EXCEEDED, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(DEADLINE_EXCEEDED, message, cause, http_response) diff --git a/firebase_admin/firestore.py b/firebase_admin/firestore.py index 52ea90671..496afe237 100644 --- a/firebase_admin/firestore.py +++ b/firebase_admin/firestore.py @@ -18,27 +18,28 @@ Firebase apps. This requires the ``google-cloud-firestore`` Python module. """ -from __future__ import annotations -from typing import Optional, Dict -from firebase_admin import App +from typing import Optional + +import firebase_admin from firebase_admin import _utils try: - from google.cloud import firestore + import google.cloud.firestore + # firestore defines __all__ for safe import * + from google.cloud.firestore import * # pyright: ignore[reportWildcardImportFromLibrary] from google.cloud.firestore_v1.base_client import DEFAULT_DATABASE - existing = globals().keys() - for key, value in firestore.__dict__.items(): - if not key.startswith('_') and key not in existing: - globals()[key] = value except ImportError as error: raise ImportError('Failed to import the Cloud Firestore library for Python. Make sure ' 'to install the "google-cloud-firestore" module.') from error +__all__ = ['client'] +__all__.extend(google.cloud.firestore.__all__) # pyright: ignore[reportUnsupportedDunderAll] + _FIRESTORE_ATTRIBUTE = '_firestore' -def client(app: Optional[App] = None, database_id: Optional[str] = None) -> firestore.Client: +def client(app: Optional[firebase_admin.App] = None, database_id: Optional[str] = None) -> Client: """Returns a client that can be used to interact with Google Cloud Firestore. Args: @@ -68,11 +69,11 @@ def client(app: Optional[App] = None, database_id: Optional[str] = None) -> fire class _FirestoreService: """Service that maintains a collection of firestore clients.""" - def __init__(self, app: App) -> None: - self._app: App = app - self._clients: Dict[str, firestore.Client] = {} + def __init__(self, app: firebase_admin.App) -> None: + self._app = app + self._clients: dict[str, Client] = {} - def get_client(self, database_id: Optional[str]) -> firestore.Client: + def get_client(self, database_id: Optional[str]) -> Client: """Creates a client based on the database_id. These clients are cached.""" database_id = database_id or DEFAULT_DATABASE if database_id not in self._clients: @@ -85,7 +86,7 @@ def get_client(self, database_id: Optional[str]) -> firestore.Client: 'or use service account credentials. Alternatively, set the ' 'GOOGLE_CLOUD_PROJECT environment variable.') - fs_client = firestore.Client( + fs_client = Client( credentials=credentials, project=project, database=database_id) self._clients[database_id] = fs_client diff --git a/firebase_admin/firestore_async.py b/firebase_admin/firestore_async.py index 4a197e9df..71694adb9 100644 --- a/firebase_admin/firestore_async.py +++ b/firebase_admin/firestore_async.py @@ -18,27 +18,31 @@ associated with Firebase apps. This requires the ``google-cloud-firestore`` Python module. """ -from __future__ import annotations -from typing import Optional, Dict -from firebase_admin import App +from typing import Optional + +import firebase_admin from firebase_admin import _utils try: - from google.cloud import firestore + import google.cloud.firestore + # firestore defines __all__ for safe import * + from google.cloud.firestore import * # pyright: ignore[reportWildcardImportFromLibrary] from google.cloud.firestore_v1.base_client import DEFAULT_DATABASE - existing = globals().keys() - for key, value in firestore.__dict__.items(): - if not key.startswith('_') and key not in existing: - globals()[key] = value except ImportError as error: raise ImportError('Failed to import the Cloud Firestore library for Python. Make sure ' 'to install the "google-cloud-firestore" module.') from error +__all__ = ['client'] +__all__.extend(google.cloud.firestore.__all__) # pyright: ignore[reportUnsupportedDunderAll] + -_FIRESTORE_ASYNC_ATTRIBUTE: str = '_firestore_async' +_FIRESTORE_ASYNC_ATTRIBUTE = '_firestore_async' -def client(app: Optional[App] = None, database_id: Optional[str] = None) -> firestore.AsyncClient: +def client( + app: Optional[firebase_admin.App] = None, + database_id: Optional[str] = None, +) -> AsyncClient: """Returns an async client that can be used to interact with Google Cloud Firestore. Args: @@ -68,11 +72,11 @@ def client(app: Optional[App] = None, database_id: Optional[str] = None) -> fire class _FirestoreAsyncService: """Service that maintains a collection of firestore async clients.""" - def __init__(self, app: App) -> None: - self._app: App = app - self._clients: Dict[str, firestore.AsyncClient] = {} + def __init__(self, app: firebase_admin.App) -> None: + self._app = app + self._clients: dict[str, AsyncClient] = {} - def get_client(self, database_id: Optional[str]) -> firestore.AsyncClient: + def get_client(self, database_id: Optional[str]) -> AsyncClient: """Creates an async client based on the database_id. These clients are cached.""" database_id = database_id or DEFAULT_DATABASE if database_id not in self._clients: @@ -85,7 +89,7 @@ def get_client(self, database_id: Optional[str]) -> firestore.AsyncClient: 'or use service account credentials. Alternatively, set the ' 'GOOGLE_CLOUD_PROJECT environment variable.') - fs_client = firestore.AsyncClient( + fs_client = AsyncClient( credentials=credentials, project=project, database=database_id) self._clients[database_id] = fs_client diff --git a/firebase_admin/functions.py b/firebase_admin/functions.py index 86eea557a..473d6985e 100644 --- a/firebase_admin/functions.py +++ b/firebase_admin/functions.py @@ -14,27 +14,32 @@ """Firebase Functions module.""" -from __future__ import annotations -from datetime import datetime, timedelta -from urllib import parse -import re +import base64 +import dataclasses +import datetime import json -from base64 import b64encode -from typing import Any, Optional, Dict -from dataclasses import dataclass -from google.auth.compute_engine import Credentials as ComputeEngineCredentials +import re +from urllib import parse +from typing import Any, Optional, cast +from typing_extensions import TypeGuard + import requests +from google.auth.credentials import Credentials as GoogleAuthCredentials +from google.auth.compute_engine import Credentials as ComputeEngineCredentials + import firebase_admin -from firebase_admin import App from firebase_admin import _http_client from firebase_admin import _utils +from firebase_admin import exceptions _FUNCTIONS_ATTRIBUTE = '_functions' __all__ = [ + 'Resource', + 'Task', 'TaskOptions', - + 'TaskQueue', 'task_queue', ] @@ -54,14 +59,14 @@ # Default canonical location ID of the task queue. _DEFAULT_LOCATION = 'us-central1' -def _get_functions_service(app) -> _FunctionsService: +def _get_functions_service(app: Optional[firebase_admin.App]) -> '_FunctionsService': return _utils.get_app_service(app, _FUNCTIONS_ATTRIBUTE, _FunctionsService) def task_queue( - function_name: str, - extension_id: Optional[str] = None, - app: Optional[App] = None - ) -> TaskQueue: + function_name: str, + extension_id: Optional[str] = None, + app: Optional[firebase_admin.App] = None, +) -> 'TaskQueue': """Creates a reference to a TaskQueue for a given function name. The function name can be either: @@ -89,9 +94,10 @@ def task_queue( """ return _get_functions_service(app).task_queue(function_name, extension_id) + class _FunctionsService: """Service class that implements Firebase Functions functionality.""" - def __init__(self, app: App): + def __init__(self, app: firebase_admin.App) -> None: self._project_id = app.project_id if not self._project_id: raise ValueError( @@ -102,28 +108,27 @@ def __init__(self, app: App): self._credential = app.credential.get_credential() self._http_client = _http_client.JsonHttpClient(credential=self._credential) - def task_queue(self, function_name: str, extension_id: Optional[str] = None) -> TaskQueue: + def task_queue(self, function_name: str, extension_id: Optional[str] = None) -> 'TaskQueue': """Creates a TaskQueue instance.""" return TaskQueue( function_name, extension_id, self._project_id, self._credential, self._http_client) @classmethod - def handle_functions_error(cls, error: Any): + def handle_functions_error(cls, error: requests.RequestException) -> exceptions.FirebaseError: """Handles errors received from the Cloud Functions API.""" - return _utils.handle_platform_error_from_requests(error) + class TaskQueue: """TaskQueue class that implements Firebase Cloud Tasks Queues functionality.""" def __init__( - self, - function_name: str, - extension_id: Optional[str], - project_id, - credential, - http_client - ) -> None: - + self, + function_name: str, + extension_id: Optional[str], + project_id: Optional[str], + credential: GoogleAuthCredentials, + http_client: _http_client.HttpClient[dict[str, Any]], + ) -> None: # Validate function_name _Validators.check_non_empty_string('function_name', function_name) @@ -144,8 +149,7 @@ def __init__( _Validators.check_non_empty_string('extension_id', self._extension_id) self._resource.resource_id = f'ext-{self._extension_id}-{self._resource.resource_id}' - - def enqueue(self, task_data: Any, opts: Optional[TaskOptions] = None) -> str: + def enqueue(self, task_data: Any, opts: Optional['TaskOptions'] = None) -> str: """Creates a task and adds it to the queue. Tasks cannot be updated after creation. This action requires `cloudtasks.tasks.create` IAM permission on the service account. @@ -172,7 +176,7 @@ def enqueue(self, task_data: Any, opts: Optional[TaskOptions] = None) -> str: headers=_FUNCTIONS_HEADERS, json={'task': task_payload.__dict__} ) - task_name = resp.get('name', None) + task_name = cast(str, resp['name']) task_resource = \ self._parse_resource_name(task_name, f'queues/{self._resource.resource_id}/tasks') return task_resource.resource_id @@ -203,8 +207,7 @@ def delete(self, task_id: str) -> None: except requests.exceptions.RequestException as error: raise _FunctionsService.handle_functions_error(error) - - def _parse_resource_name(self, resource_name: str, resource_id_key: str) -> Resource: + def _parse_resource_name(self, resource_name: str, resource_id_key: str) -> 'Resource': """Parses a full or partial resource path into a ``Resource``.""" if '/' not in resource_name: return Resource(resource_id=resource_name) @@ -215,7 +218,7 @@ def _parse_resource_name(self, resource_name: str, resource_id_key: str) -> Reso raise ValueError('Invalid resource name format.') return Resource(project_id=match[2], location_id=match[3], resource_id=match[4]) - def _get_url(self, resource: Resource, url_format: str) -> str: + def _get_url(self, resource: 'Resource', url_format: str) -> str: """Generates url path from a ``Resource`` and url format string.""" return url_format.format( project_id=resource.project_id, @@ -223,18 +226,18 @@ def _get_url(self, resource: Resource, url_format: str) -> str: resource_id=resource.resource_id) def _validate_task_options( - self, - data: Any, - resource: Resource, - opts: Optional[TaskOptions] = None - ) -> Task: + self, + data: dict[str, Any], + resource: 'Resource', + opts: Optional['TaskOptions'] = None, + ) -> 'Task': """Validate and create a Task from optional ``TaskOptions``.""" task_http_request = { 'url': '', 'oidc_token': { 'service_account_email': '' }, - 'body': b64encode(json.dumps(data).encode()).decode(), + 'body': base64.b64encode(json.dumps(data).encode()).decode(), 'headers': { 'Content-Type': 'application/json', } @@ -248,14 +251,15 @@ def _validate_task_options( raise ValueError( 'Both sechdule_delay_seconds and schedule_time cannot be set at the same time.') if opts.schedule_time is not None and opts.schedule_delay_seconds is None: - if not isinstance(opts.schedule_time, datetime): + if not isinstance(opts.schedule_time, datetime.datetime): raise ValueError('schedule_time should be UTC datetime.') task.schedule_time = opts.schedule_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ') if opts.schedule_delay_seconds is not None and opts.schedule_time is None: if not isinstance(opts.schedule_delay_seconds, int) \ or opts.schedule_delay_seconds < 0: raise ValueError('schedule_delay_seconds should be positive int.') - schedule_time = datetime.utcnow() + timedelta(seconds=opts.schedule_delay_seconds) + schedule_time = datetime.datetime.now(datetime.timezone.utc) + \ + datetime.timedelta(seconds=opts.schedule_delay_seconds) task.schedule_time = schedule_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ') if opts.dispatch_deadline_seconds is not None: if not isinstance(opts.dispatch_deadline_seconds, int) \ @@ -279,7 +283,12 @@ def _validate_task_options( task.http_request['url'] = opts.uri return task - def _update_task_payload(self, task: Task, resource: Resource, extension_id: str) -> Task: + def _update_task_payload( + self, + task: 'Task', + resource: 'Resource', + extension_id: Optional[str], + ) -> 'Task': """Prepares task to be sent with credentials.""" # Get function url from task or generate from resources if not _Validators.is_non_empty_string(task.http_request['url']): @@ -289,21 +298,22 @@ def _update_task_payload(self, task: Task, resource: Resource, extension_id: str if _Validators.is_non_empty_string(extension_id) and \ isinstance(self._credential, ComputeEngineCredentials): - id_token = self._credential.token + id_token = cast(str, self._credential.token) task.http_request['headers'] = \ {**task.http_request['headers'], 'Authorization': f'Bearer ${id_token}'} # Delete oidc token del task.http_request['oidc_token'] else: + # possible issue: _credential needs more specific annotation task.http_request['oidc_token'] = \ - {'service_account_email': self._credential.service_account_email} + {'service_account_email': self._credential.service_account_email} # pyright: ignore[reportAttributeAccessIssue] return task class _Validators: """A collection of data validation utilities.""" @classmethod - def check_non_empty_string(cls, label: str, value: Any): + def check_non_empty_string(cls, label: str, value: Any) -> None: """Checks if given value is a non-empty string and throws error if not.""" if not isinstance(value, str): raise ValueError(f'{label} "{value}" must be a string.') @@ -311,14 +321,14 @@ def check_non_empty_string(cls, label: str, value: Any): raise ValueError(f'{label} "{value}" must be a non-empty string.') @classmethod - def is_non_empty_string(cls, value: Any): + def is_non_empty_string(cls, value: Any) -> TypeGuard[str]: """Checks if given value is a non-empty string and returns bool.""" if not isinstance(value, str) or value == '': return False return True @classmethod - def is_task_id(cls, task_id: Any): + def is_task_id(cls, task_id: str) -> bool: """Checks if given value is a valid task id.""" reg = '^[A-Za-z0-9_-]+$' if re.match(reg, task_id) is not None and len(task_id) <= 500: @@ -326,7 +336,7 @@ def is_task_id(cls, task_id: Any): return False @classmethod - def is_url(cls, url: Any): + def is_url(cls, url: Any) -> TypeGuard[str]: """Checks if given value is a valid url.""" if not isinstance(url, str): return False @@ -339,7 +349,7 @@ def is_url(cls, url: Any): return False -@dataclass +@dataclasses.dataclass class TaskOptions: """Task Options that can be applied to a Task. @@ -398,13 +408,14 @@ class TaskOptions: http URL. """ schedule_delay_seconds: Optional[int] = None - schedule_time: Optional[datetime] = None + schedule_time: Optional[datetime.datetime] = None dispatch_deadline_seconds: Optional[int] = None task_id: Optional[str] = None - headers: Optional[Dict[str, str]] = None + headers: Optional[dict[str, str]] = None uri: Optional[str] = None -@dataclass + +@dataclasses.dataclass class Task: """Contains the relevant fields for enqueueing tasks that trigger Cloud Functions. @@ -418,13 +429,13 @@ class Task: schedule_time: The time when the task is scheduled to be attempted or retried. dispatch_deadline: The deadline for requests sent to the worker. """ - http_request: Dict[str, Optional[str | dict]] + http_request: dict[str, Any] name: Optional[str] = None schedule_time: Optional[str] = None dispatch_deadline: Optional[str] = None -@dataclass +@dataclasses.dataclass class Resource: """Contains the parsed address of a resource. diff --git a/firebase_admin/instance_id.py b/firebase_admin/instance_id.py index 812daf40b..8c57eb1dc 100644 --- a/firebase_admin/instance_id.py +++ b/firebase_admin/instance_id.py @@ -16,22 +16,25 @@ This module enables deleting instance IDs associated with Firebase projects. """ +from typing import Optional import requests +import firebase_admin from firebase_admin import _http_client from firebase_admin import _utils +__all__ = ('delete_instance_id',) _IID_SERVICE_URL = 'https://console.firebase.google.com/v1/' _IID_ATTRIBUTE = '_iid' -def _get_iid_service(app): +def _get_iid_service(app: Optional[firebase_admin.App]) -> '_InstanceIdService': return _utils.get_app_service(app, _IID_ATTRIBUTE, _InstanceIdService) -def delete_instance_id(instance_id, app=None): +def delete_instance_id(instance_id: str, app: Optional[firebase_admin.App] = None) -> None: """Deletes the specified instance ID and the associated data from Firebase. Note that Google Analytics for Firebase uses its own form of Instance ID to @@ -55,7 +58,7 @@ def delete_instance_id(instance_id, app=None): class _InstanceIdService: """Provides methods for interacting with the remote instance ID service.""" - error_codes = { + error_codes: dict[int, str] = { 400: 'Malformed instance ID argument.', 401: 'Request not authorized.', 403: 'Project does not match instance ID or the client does not have ' @@ -67,7 +70,7 @@ class _InstanceIdService: 503: 'Backend servers are over capacity. Try again later.' } - def __init__(self, app): + def __init__(self, app: firebase_admin.App) -> None: project_id = app.project_id if not project_id: raise ValueError( @@ -78,7 +81,7 @@ def __init__(self, app): self._client = _http_client.JsonHttpClient( credential=app.credential.get_credential(), base_url=_IID_SERVICE_URL) - def delete_instance_id(self, instance_id): + def delete_instance_id(self, instance_id: str) -> None: if not isinstance(instance_id, str) or not instance_id: raise ValueError('Instance ID must be a non-empty string.') path = f'project/{self._project_id}/instanceId/{instance_id}' @@ -88,7 +91,7 @@ def delete_instance_id(self, instance_id): msg = self._extract_message(instance_id, error) raise _utils.handle_requests_error(error, msg) - def _extract_message(self, instance_id, error): + def _extract_message(self, instance_id: str, error: requests.RequestException) -> Optional[str]: if error.response is None: return None status = error.response.status_code diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 749044436..4a9c4503c 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -14,14 +14,15 @@ """Firebase Cloud Messaging module.""" -from __future__ import annotations -from typing import Any, Callable, Dict, List, Optional, cast import concurrent.futures import json -import asyncio import logging -import requests +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Optional, Union, cast + +import asyncio import httpx +import requests import firebase_admin from firebase_admin import ( @@ -29,13 +30,11 @@ _messaging_encoder, _messaging_utils, _utils, - exceptions, - App + exceptions ) -logger = logging.getLogger(__name__) - -_MESSAGING_ATTRIBUTE = '_messaging' +if TYPE_CHECKING: + import httplib2 __all__ = [ @@ -75,6 +74,10 @@ 'unsubscribe_from_topic', ] +logger = logging.getLogger(__name__) + +_MESSAGING_ATTRIBUTE = '_messaging' + AndroidConfig = _messaging_utils.AndroidConfig AndroidFCMOptions = _messaging_utils.AndroidFCMOptions @@ -101,10 +104,11 @@ UnregisteredError = _messaging_utils.UnregisteredError -def _get_messaging_service(app: Optional[App]) -> _MessagingService: +def _get_messaging_service(app: Optional[firebase_admin.App]) -> '_MessagingService': return _utils.get_app_service(app, _MESSAGING_ATTRIBUTE, _MessagingService) -def send(message: Message, dry_run: bool = False, app: Optional[App] = None) -> str: + +def send(message: Message, dry_run: bool = False, app: Optional[firebase_admin.App] = None) -> str: """Sends the given message via Firebase Cloud Messaging (FCM). If the ``dry_run`` mode is enabled, the message will not be actually delivered to the @@ -124,11 +128,12 @@ def send(message: Message, dry_run: bool = False, app: Optional[App] = None) -> """ return _get_messaging_service(app).send(message, dry_run) + def send_each( - messages: List[Message], - dry_run: bool = False, - app: Optional[App] = None - ) -> BatchResponse: + messages: list[Message], + dry_run: bool = False, + app: Optional[firebase_admin.App] = None, +) -> 'BatchResponse': """Sends each message in the given list via Firebase Cloud Messaging. If the ``dry_run`` mode is enabled, the message will not be actually delivered to the @@ -148,11 +153,12 @@ def send_each( """ return _get_messaging_service(app).send_each(messages, dry_run) + async def send_each_async( - messages: List[Message], - dry_run: bool = False, - app: Optional[App] = None - ) -> BatchResponse: + messages: list[Message], + dry_run: bool = False, + app: Optional[firebase_admin.App] = None, +) -> 'BatchResponse': """Sends each message in the given list asynchronously via Firebase Cloud Messaging. If the ``dry_run`` mode is enabled, the message will not be actually delivered to the @@ -172,11 +178,12 @@ async def send_each_async( """ return await _get_messaging_service(app).send_each_async(messages, dry_run) + async def send_each_for_multicast_async( - multicast_message: MulticastMessage, - dry_run: bool = False, - app: Optional[App] = None - ) -> BatchResponse: + multicast_message: MulticastMessage, + dry_run: bool = False, + app: Optional[firebase_admin.App] = None, +) -> 'BatchResponse': """Sends the given mutlicast message to each token asynchronously via Firebase Cloud Messaging (FCM). @@ -208,7 +215,12 @@ async def send_each_for_multicast_async( ) for token in multicast_message.tokens] return await _get_messaging_service(app).send_each_async(messages, dry_run) -def send_each_for_multicast(multicast_message, dry_run=False, app=None): + +def send_each_for_multicast( + multicast_message: MulticastMessage, + dry_run: bool = False, + app: Optional[firebase_admin.App] = None, +) -> 'BatchResponse': """Sends the given mutlicast message to each token via Firebase Cloud Messaging (FCM). If the ``dry_run`` mode is enabled, the message will not be actually delivered to the @@ -239,7 +251,12 @@ def send_each_for_multicast(multicast_message, dry_run=False, app=None): ) for token in multicast_message.tokens] return _get_messaging_service(app).send_each(messages, dry_run) -def subscribe_to_topic(tokens, topic, app=None): + +def subscribe_to_topic( + tokens: Union[list[str], str], + topic: str, + app: Optional[firebase_admin.App] = None, +) -> 'TopicManagementResponse': """Subscribes a list of registration tokens to an FCM topic. Args: @@ -258,7 +275,12 @@ def subscribe_to_topic(tokens, topic, app=None): return _get_messaging_service(app).make_topic_management_request( tokens, topic, 'iid/v1:batchAdd') -def unsubscribe_from_topic(tokens, topic, app=None): + +def unsubscribe_from_topic( + tokens: Union[list[str], str], + topic: str, + app: Optional[firebase_admin.App] = None, +) -> 'TopicManagementResponse': """Unsubscribes a list of registration tokens from an FCM topic. Args: @@ -281,17 +303,17 @@ def unsubscribe_from_topic(tokens, topic, app=None): class ErrorInfo: """An error encountered when performing a topic management operation.""" - def __init__(self, index, reason): + def __init__(self, index: int, reason: str) -> None: self._index = index self._reason = reason @property - def index(self): + def index(self) -> int: """Index of the registration token to which this error is related to.""" return self._index @property - def reason(self): + def reason(self) -> str: """String describing the nature of the error.""" return self._reason @@ -299,12 +321,12 @@ def reason(self): class TopicManagementResponse: """The response received from a topic management operation.""" - def __init__(self, resp): + def __init__(self, resp: dict[str, Any]) -> None: if not isinstance(resp, dict) or 'results' not in resp: raise ValueError(f'Unexpected topic management response: {resp}.') self._success_count = 0 self._failure_count = 0 - self._errors = [] + self._errors: list[ErrorInfo] = [] for index, result in enumerate(resp['results']): if 'error' in result: self._failure_count += 1 @@ -313,17 +335,17 @@ def __init__(self, resp): self._success_count += 1 @property - def success_count(self): + def success_count(self) -> int: """Number of tokens that were successfully subscribed or unsubscribed.""" return self._success_count @property - def failure_count(self): + def failure_count(self) -> int: """Number of tokens that could not be subscribed or unsubscribed due to errors.""" return self._failure_count @property - def errors(self): + def errors(self) ->list[ErrorInfo]: """A list of ``messaging.ErrorInfo`` objects (possibly empty).""" return self._errors @@ -331,12 +353,12 @@ def errors(self): class BatchResponse: """The response received from a batch request to the FCM API.""" - def __init__(self, responses: List[SendResponse]) -> None: + def __init__(self, responses: list['SendResponse']) -> None: self._responses = responses self._success_count = sum(1 for resp in responses if resp.success) @property - def responses(self) -> List[SendResponse]: + def responses(self) -> list['SendResponse']: """A list of ``messaging.SendResponse`` objects (possibly empty).""" return self._responses @@ -352,27 +374,32 @@ def failure_count(self) -> int: class SendResponse: """The response received from an individual batched request to the FCM API.""" - def __init__(self, resp, exception): + def __init__( + self, + resp: Optional[dict[str, Any]], + exception: Optional[exceptions.FirebaseError], + ) -> None: self._exception = exception - self._message_id = None + self._message_id: Optional[str] = None if resp: self._message_id = resp.get('name', None) @property - def message_id(self): + def message_id(self) -> Optional[str]: """A message ID string that uniquely identifies the message.""" return self._message_id @property - def success(self): + def success(self) -> bool: """A boolean indicating if the request was successful.""" return self._message_id is not None and not self._exception @property - def exception(self): + def exception(self) -> Optional[exceptions.FirebaseError]: """A ``FirebaseError`` if an error occurs while sending the message to the FCM service.""" return self._exception + class _MessagingService: """Service class that implements Firebase Cloud Messaging (FCM) functionality.""" @@ -390,7 +417,7 @@ class _MessagingService: 'UNREGISTERED': UnregisteredError, } - def __init__(self, app: App) -> None: + def __init__(self, app: firebase_admin.App) -> None: project_id = app.project_id if not project_id: raise ValueError( @@ -409,7 +436,7 @@ def __init__(self, app: App) -> None: credential=self._credential, timeout=timeout) @classmethod - def encode_message(cls, message): + def encode_message(cls, message: Message) -> dict[str, Any]: if not isinstance(message, Message): raise ValueError('Message must be an instance of messaging.Message class.') return cls.JSON_ENCODER.default(message) @@ -428,14 +455,14 @@ def send(self, message: Message, dry_run: bool = False) -> str: raise self._handle_fcm_error(error) return cast(str, resp['name']) - def send_each(self, messages: List[Message], dry_run: bool = False) -> BatchResponse: + def send_each(self, messages: list[Message], dry_run: bool = False) -> BatchResponse: """Sends the given messages to FCM via the FCM v1 API.""" if not isinstance(messages, list): raise ValueError('messages must be a list of messaging.Message instances.') if len(messages) > 500: raise ValueError('messages must not contain more than 500 elements.') - def send_data(data): + def send_data(data: dict[str, Any]) -> SendResponse: try: resp = self._client.body( 'post', @@ -456,14 +483,14 @@ def send_data(data): message=f'Unknown error while making remote service calls: {error}', cause=error) - async def send_each_async(self, messages: List[Message], dry_run: bool = True) -> BatchResponse: + async def send_each_async(self, messages: list[Message], dry_run: bool = True) -> BatchResponse: """Sends the given messages to FCM via the FCM v1 API.""" if not isinstance(messages, list): raise ValueError('messages must be a list of messaging.Message instances.') if len(messages) > 500: raise ValueError('messages must not contain more than 500 elements.') - async def send_data(data): + async def send_data(data: dict[str, Any]) -> SendResponse: try: resp = await self._async_client.request( 'post', @@ -486,7 +513,12 @@ async def send_data(data): message=f'Unknown error while making remote service calls: {error}', cause=error) - def make_topic_management_request(self, tokens, topic, operation): + def make_topic_management_request( + self, + tokens: Union[list[str], str], + topic: str, + operation: str, + ) -> TopicManagementResponse: """Invokes the IID service for topic management functionality.""" if isinstance(tokens, str): tokens = [tokens] @@ -516,18 +548,18 @@ def make_topic_management_request(self, tokens, topic, operation): raise self._handle_iid_error(error) return TopicManagementResponse(resp) - def _message_data(self, message, dry_run): - data = {'message': _MessagingService.encode_message(message)} + def _message_data(self, message: Message, dry_run: bool) -> dict[str, Any]: + data: dict[str, Any] = {'message': _MessagingService.encode_message(message)} if dry_run: data['validate_only'] = True return data - def _postproc(self, _, body): + def _postproc(self, _: 'httplib2.Response', body: bytes) -> Any: """Handle response from batch API request.""" # This only gets called for 2xx responses. return json.loads(body.decode()) - def _handle_fcm_error(self, error): + def _handle_fcm_error(self, error: requests.RequestException) -> exceptions.FirebaseError: """Handles errors received from the FCM API.""" return _utils.handle_platform_error_from_requests( error, _MessagingService._build_fcm_error_requests) @@ -537,12 +569,12 @@ def _handle_fcm_httpx_error(self, error: httpx.HTTPError) -> exceptions.Firebase return _utils.handle_platform_error_from_httpx( error, _MessagingService._build_fcm_error_httpx) - def _handle_iid_error(self, error): + def _handle_iid_error(self, error: requests.RequestException) -> exceptions.FirebaseError: """Handles errors received from the Instance ID API.""" if error.response is None: raise _utils.handle_requests_error(error) - data = {} + data: dict[str, Any] = {} try: parsed_body = error.response.json() if isinstance(parsed_body, dict): @@ -567,41 +599,45 @@ def close(self) -> None: asyncio.run(self._async_client.aclose()) @classmethod - def _build_fcm_error_requests(cls, error, message, error_dict): + def _build_fcm_error_requests( + cls, + error: requests.RequestException, + message: str, + error_dict: dict[str, Any], + ) -> Optional[exceptions.FirebaseError]: """Parses an error response from the FCM API and creates a FCM-specific exception if appropriate.""" exc_type = cls._build_fcm_error(error_dict) - # pylint: disable=not-callable - return exc_type(message, cause=error, http_response=error.response) if exc_type else None + return exc_type(message, error, error.response) if exc_type else None @classmethod def _build_fcm_error_httpx( - cls, - error: httpx.HTTPError, - message: str, - error_dict: Optional[Dict[str, Any]] - ) -> Optional[exceptions.FirebaseError]: + cls, + error: httpx.HTTPError, + message: str, + error_dict: Optional[dict[str, Any]], + ) -> Optional[exceptions.FirebaseError]: """Parses a httpx error response from the FCM API and creates a FCM-specific exception if appropriate.""" exc_type = cls._build_fcm_error(error_dict) if isinstance(error, httpx.HTTPStatusError): - # pylint: disable=not-callable - return exc_type( - message, cause=error, http_response=error.response) if exc_type else None - # pylint: disable=not-callable - return exc_type(message, cause=error) if exc_type else None + return exc_type(message, error, error.response) if exc_type else None + return exc_type(message, error, None) if exc_type else None @classmethod def _build_fcm_error( - cls, - error_dict: Optional[Dict[str, Any]] - ) -> Optional[Callable[..., exceptions.FirebaseError]]: + cls, + error_dict: Optional[dict[str, Any]], + ) -> Optional[Callable[ + [str, Optional[Exception], Optional[Union[httpx.Response, requests.Response]]], + exceptions.FirebaseError + ]]: """Parses an error response to determine the appropriate FCM-specific error type.""" if not error_dict: return None fcm_code = None for detail in error_dict.get('details', []): if detail.get('@type') == 'type.googleapis.com/google.firebase.fcm.v1.FcmError': - fcm_code = detail.get('errorCode') + fcm_code = cast(str, detail.get('errorCode')) break return _MessagingService.FCM_ERROR_TYPES.get(fcm_code) if fcm_code else None diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index 3a77dd05f..38b2f69af 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -18,12 +18,13 @@ deleting, publishing and unpublishing Firebase ML models. """ - import datetime +import os import re import time -import os -from urllib import parse +import urllib.parse +from collections.abc import Callable, Iterator +from typing import TYPE_CHECKING, Any, Optional, Union, cast import requests @@ -32,19 +33,46 @@ from firebase_admin import _utils from firebase_admin import exceptions -# pylint: disable=import-error,no-member -try: +if TYPE_CHECKING: + import tensorflow as tf + from _typeshed import Incomplete from firebase_admin import storage - _GCS_ENABLED = True -except ImportError: - _GCS_ENABLED = False -# pylint: disable=import-error,no-member -try: - import tensorflow as tf - _TF_ENABLED = True -except ImportError: - _TF_ENABLED = False + _GCS_ENABLED: bool + _TF_ENABLED: bool +else: + Incomplete = Any + + # pylint: disable=import-error,no-member + try: + from firebase_admin import storage + _GCS_ENABLED = True + except ImportError: + _GCS_ENABLED = False + + # pylint: disable=import-error,no-member + try: + import tensorflow as tf + _TF_ENABLED = True + except ImportError: + _TF_ENABLED = False + +__all__ = ( + 'ListModelsPage', + 'Model', + 'ModelFormat', + 'TFLiteFormat', + 'TFLiteGCSModelSource', + 'TFLiteModelSource', + 'create_model', + 'delete_model', + 'get_model', + 'list_models', + 'publish_model', + 'unpublish_model', + 'update_model', +) + _ML_ATTRIBUTE = '_ml' _MAX_PAGE_SIZE = 100 @@ -59,7 +87,7 @@ r'^projects/(?P[a-z0-9-]{6,30})/operations/[^/]+$') -def _get_ml_service(app): +def _get_ml_service(app: Optional[firebase_admin.App]) -> '_MLService': """ Returns an _MLService instance for an App. Args: @@ -74,7 +102,7 @@ def _get_ml_service(app): return _utils.get_app_service(app, _ML_ATTRIBUTE, _MLService) -def create_model(model, app=None): +def create_model(model: 'Model', app: Optional[firebase_admin.App] = None) -> 'Model': """Creates a model in the current Firebase project. Args: @@ -88,7 +116,7 @@ def create_model(model, app=None): return Model.from_dict(ml_service.create_model(model), app=app) -def update_model(model, app=None): +def update_model(model: 'Model', app: Optional[firebase_admin.App] = None) -> 'Model': """Updates a model's metadata or model file. Args: @@ -102,7 +130,7 @@ def update_model(model, app=None): return Model.from_dict(ml_service.update_model(model), app=app) -def publish_model(model_id, app=None): +def publish_model(model_id: str, app: Optional[firebase_admin.App] = None) -> 'Model': """Publishes a Firebase ML model. A published model can be downloaded to client apps. @@ -118,7 +146,7 @@ def publish_model(model_id, app=None): return Model.from_dict(ml_service.set_published(model_id, publish=True), app=app) -def unpublish_model(model_id, app=None): +def unpublish_model(model_id: str, app: Optional[firebase_admin.App] = None) -> 'Model': """Unpublishes a Firebase ML model. Args: @@ -132,7 +160,7 @@ def unpublish_model(model_id, app=None): return Model.from_dict(ml_service.set_published(model_id, publish=False), app=app) -def get_model(model_id, app=None): +def get_model(model_id: str, app: Optional[firebase_admin.App] = None) -> 'Model': """Gets the model specified by the given ID. Args: @@ -146,7 +174,12 @@ def get_model(model_id, app=None): return Model.from_dict(ml_service.get_model(model_id), app=app) -def list_models(list_filter=None, page_size=None, page_token=None, app=None): +def list_models( + list_filter: Optional[str] = None, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + app: Optional[firebase_admin.App] = None, +) -> 'ListModelsPage': """Lists the current project's models. Args: @@ -165,7 +198,7 @@ def list_models(list_filter=None, page_size=None, page_token=None, app=None): ml_service.list_models, list_filter, page_size, page_token, app=app) -def delete_model(model_id, app=None): +def delete_model(model_id: str, app: Optional[firebase_admin.App] = None) -> None: """Deletes a model from the current project. Args: @@ -184,10 +217,15 @@ class Model: tags: Optional list of strings associated with your model. Can be used in list queries. model_format: A subclass of ModelFormat. (e.g. TFLiteFormat) Specifies the model details. """ - def __init__(self, display_name=None, tags=None, model_format=None): - self._app = None # Only needed for wait_for_unlo - self._data = {} - self._model_format = None + def __init__( + self, + display_name: Optional[str] = None, + tags: Optional[list[str]] = None, + model_format: Optional['ModelFormat'] = None, + ) -> None: + self._app: Optional[firebase_admin.App] = None # Only needed for wait_for_unlo + self._data: dict[str, Any] = {} + self._model_format: Optional[ModelFormat] = None if display_name is not None: self.display_name = display_name @@ -197,7 +235,11 @@ def __init__(self, display_name=None, tags=None, model_format=None): self.model_format = model_format @classmethod - def from_dict(cls, data, app=None): + def from_dict( + cls, + data: dict[str, Any], + app: Optional[firebase_admin.App] = None, + ) -> 'Model': """Create an instance of the object from a dict.""" data_copy = dict(data) tflite_format = None @@ -210,22 +252,22 @@ def from_dict(cls, data, app=None): model._app = app # pylint: disable=protected-access return model - def _update_from_dict(self, data): + def _update_from_dict(self, data: dict[str, Any]) -> None: copy = Model.from_dict(data) self.model_format = copy.model_format self._data = copy._data # pylint: disable=protected-access - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, self.__class__): # pylint: disable=protected-access return self._data == other._data and self._model_format == other._model_format return False - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) @property - def model_id(self): + def model_id(self) -> Optional[str]: """The model's ID, unique to the project.""" if not self._data.get('name'): return None @@ -233,74 +275,72 @@ def model_id(self): return model_id @property - def display_name(self): + def display_name(self) -> Optional[str]: """The model's display name, used to refer to the model in code and in the Firebase console.""" return self._data.get('displayName') @display_name.setter - def display_name(self, display_name): + def display_name(self, display_name: str) -> None: self._data['displayName'] = _validate_display_name(display_name) - return self @staticmethod - def _convert_to_millis(date_string): + def _convert_to_millis(date_string: Optional[str]) -> Optional[int]: if not date_string: return None format_str = '%Y-%m-%dT%H:%M:%S.%fZ' - epoch = datetime.datetime.utcfromtimestamp(0) - datetime_object = datetime.datetime.strptime(date_string, format_str) + epoch = datetime.datetime.fromtimestamp(0, datetime.timezone.utc) + datetime_object = datetime.datetime.strptime(date_string, format_str).replace(tzinfo=datetime.timezone.utc) millis = int((datetime_object - epoch).total_seconds() * 1000) return millis @property - def create_time(self): + def create_time(self) -> Optional[int]: """The time the model was created.""" return Model._convert_to_millis(self._data.get('createTime', None)) @property - def update_time(self): + def update_time(self) -> Optional[int]: """The time the model was last updated.""" return Model._convert_to_millis(self._data.get('updateTime', None)) @property - def validation_error(self): + def validation_error(self) -> Optional[str]: """Validation error message.""" return self._data.get('state', {}).get('validationError', {}).get('message') @property - def published(self): + def published(self) -> bool: """True if the model is published and available for clients to download.""" return bool(self._data.get('state', {}).get('published')) @property - def etag(self): + def etag(self) -> Optional[Incomplete]: """The entity tag (ETag) of the model resource.""" return self._data.get('etag') @property - def model_hash(self): + def model_hash(self) -> Optional[Incomplete]: """SHA256 hash of the model binary.""" return self._data.get('modelHash') @property - def tags(self): + def tags(self) -> Optional[list[str]]: """Tag strings, used for filtering query results.""" return self._data.get('tags') @tags.setter - def tags(self, tags): + def tags(self, tags: list[str]) -> None: self._data['tags'] = _validate_tags(tags) - return self @property - def locked(self): + def locked(self) -> bool: """True if the Model object is locked by an active operation.""" return bool(self._data.get('activeOperations') and - len(self._data.get('activeOperations')) > 0) + len(self._data['activeOperations']) > 0) - def wait_for_unlocked(self, max_time_seconds=None): + def wait_for_unlocked(self, max_time_seconds: Optional[float] = None) -> None: """Waits for the model to be unlocked. (All active operations complete) Args: @@ -313,7 +353,7 @@ def wait_for_unlocked(self, max_time_seconds=None): if not self.locked: return ml_service = _get_ml_service(self._app) - op_name = self._data.get('activeOperations')[0].get('name') + op_name = self._data['activeOperations'][0].get('name') model_dict = ml_service.handle_operation( ml_service.get_operation(op_name), wait_for_operation=True, @@ -321,19 +361,18 @@ def wait_for_unlocked(self, max_time_seconds=None): self._update_from_dict(model_dict) @property - def model_format(self): + def model_format(self) -> Optional['ModelFormat']: """The model's ``ModelFormat`` object, which represents the model's format and storage location.""" return self._model_format @model_format.setter - def model_format(self, model_format): + def model_format(self, model_format: Optional['ModelFormat']) -> None: if model_format is not None: _validate_model_format(model_format) self._model_format = model_format #Can be None - return self - def as_dict(self, for_upload=False): + def as_dict(self, for_upload: bool = False) -> dict[str, Any]: """Returns a serializable representation of the object.""" copy = dict(self._data) if self._model_format: @@ -343,7 +382,7 @@ def as_dict(self, for_upload=False): class ModelFormat: """Abstract base class representing a Model Format such as TFLite.""" - def as_dict(self, for_upload=False): + def as_dict(self, for_upload: bool = False) -> dict[str, Any]: """Returns a serializable representation of the object.""" raise NotImplementedError @@ -354,32 +393,32 @@ class TFLiteFormat(ModelFormat): Args: model_source: A TFLiteModelSource sub class. Specifies the details of the model source. """ - def __init__(self, model_source=None): - self._data = {} - self._model_source = None + def __init__(self, model_source: Optional['TFLiteModelSource'] = None) -> None: + self._data: dict[str, Any] = {} + self._model_source: Optional[TFLiteModelSource] = None if model_source is not None: self.model_source = model_source @classmethod - def from_dict(cls, data): + def from_dict(cls, data: dict[str, Any]) -> 'TFLiteFormat': """Create an instance of the object from a dict.""" data_copy = dict(data) tflite_format = TFLiteFormat(model_source=cls._init_model_source(data_copy)) tflite_format._data = data_copy # pylint: disable=protected-access return tflite_format - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, self.__class__): # pylint: disable=protected-access return self._data == other._data and self._model_source == other._model_source return False - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) @staticmethod - def _init_model_source(data): + def _init_model_source(data: dict[str, Any]) -> Optional['TFLiteModelSource']: """Initialize the ML model source.""" gcs_tflite_uri = data.pop('gcsTfliteUri', None) if gcs_tflite_uri: @@ -387,23 +426,23 @@ def _init_model_source(data): return None @property - def model_source(self): + def model_source(self) -> Optional['TFLiteModelSource']: """The TF Lite model's location.""" return self._model_source @model_source.setter - def model_source(self, model_source): + def model_source(self, model_source: Optional['TFLiteModelSource']) -> None: if model_source is not None: if not isinstance(model_source, TFLiteModelSource): raise TypeError('Model source must be a TFLiteModelSource object.') self._model_source = model_source # Can be None @property - def size_bytes(self): + def size_bytes(self) -> Optional[Incomplete]: """The size in bytes of the TF Lite model.""" return self._data.get('sizeBytes') - def as_dict(self, for_upload=False): + def as_dict(self, for_upload: bool = False) -> dict[str, Any]: """Returns a serializable representation of the object.""" copy = dict(self._data) if self._model_source: @@ -413,7 +452,7 @@ def as_dict(self, for_upload=False): class TFLiteModelSource: """Abstract base class representing a model source for TFLite format models.""" - def as_dict(self, for_upload=False): + def as_dict(self, for_upload: bool = False) -> dict[str, Any]: """Returns a serializable representation of the object.""" raise NotImplementedError @@ -425,13 +464,13 @@ class _CloudStorageClient: BLOB_NAME = 'Firebase/ML/Models/{0}' @staticmethod - def _assert_gcs_enabled(): + def _assert_gcs_enabled() -> None: if not _GCS_ENABLED: raise ImportError('Failed to import the Cloud Storage library for Python. Make sure ' 'to install the "google-cloud-storage" module.') @staticmethod - def _parse_gcs_tflite_uri(uri): + def _parse_gcs_tflite_uri(uri: str) -> tuple[str, str]: # GCS Bucket naming rules are complex. The regex is not comprehensive. # See https://cloud.google.com/storage/docs/naming for full details. matcher = _GCS_TFLITE_URI_PATTERN.match(uri) @@ -440,10 +479,13 @@ def _parse_gcs_tflite_uri(uri): return matcher.group('bucket_name'), matcher.group('blob_name') @staticmethod - def upload(bucket_name, model_file_name, app): + def upload( + bucket_name: Optional[str], + model_file_name: Union[str, os.PathLike[str]], + app: Optional[firebase_admin.App], + ) -> str: """Upload a model file to the specified Storage bucket.""" _CloudStorageClient._assert_gcs_enabled() - file_name = os.path.basename(model_file_name) bucket = storage.bucket(bucket_name, app=app) blob_name = _CloudStorageClient.BLOB_NAME.format(file_name) @@ -452,7 +494,7 @@ def upload(bucket_name, model_file_name, app): return _CloudStorageClient.GCS_URI.format(bucket.name, blob_name) @staticmethod - def sign_uri(gcs_tflite_uri, app): + def sign_uri(gcs_tflite_uri: str, app: Optional[firebase_admin.App]) -> str: """Makes the gcs_tflite_uri readable for GET for 10 minutes via signed_uri.""" _CloudStorageClient._assert_gcs_enabled() bucket_name, blob_name = _CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri) @@ -470,20 +512,29 @@ class TFLiteGCSModelSource(TFLiteModelSource): _STORAGE_CLIENT = _CloudStorageClient() - def __init__(self, gcs_tflite_uri, app=None): + def __init__( + self, + gcs_tflite_uri: str, + app: Optional[firebase_admin.App] = None, + ) -> None: self._app = app self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, self.__class__): return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access return False - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) @classmethod - def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None): + def from_tflite_model_file( + cls, + model_file_name: Union[str, os.PathLike[str]], + bucket_name: Optional[str] = None, + app: Optional[firebase_admin.App] = None, + ) -> 'TFLiteGCSModelSource': """Uploads the model file to an existing Google Cloud Storage bucket. Args: @@ -502,7 +553,7 @@ def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None): return TFLiteGCSModelSource(gcs_tflite_uri=gcs_uri, app=app) @staticmethod - def _assert_tf_enabled(): + def _assert_tf_enabled() -> None: if not _TF_ENABLED: raise ImportError('Failed to import the tensorflow library for Python. Make sure ' 'to install the tensorflow module.') @@ -511,13 +562,13 @@ def _assert_tf_enabled(): f'Expected tensorflow version 1.x or 2.x, but found {tf.version.VERSION}') @staticmethod - def _tf_convert_from_saved_model(saved_model_dir): + def _tf_convert_from_saved_model(saved_model_dir: Incomplete) -> Incomplete: # Same for both v1.x and v2.x converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) return converter.convert() @staticmethod - def _tf_convert_from_keras_model(keras_model): + def _tf_convert_from_keras_model(keras_model: Incomplete) -> Incomplete: """Converts the given Keras model into a TF Lite model.""" # Version 1.x conversion function takes a model file. Version 2.x takes the model itself. if tf.version.VERSION.startswith('1.'): @@ -530,8 +581,13 @@ def _tf_convert_from_keras_model(keras_model): return converter.convert() @classmethod - def from_saved_model(cls, saved_model_dir, model_file_name='firebase_ml_model.tflite', - bucket_name=None, app=None): + def from_saved_model( + cls, + saved_model_dir: Incomplete, + model_file_name: Union[str, os.PathLike[str]] = 'firebase_ml_model.tflite', + bucket_name: Optional[str] = None, + app: Optional[firebase_admin.App] = None, + ) -> 'TFLiteGCSModelSource': """Creates a Tensor Flow Lite model from the saved model, and uploads the model to GCS. Args: @@ -554,8 +610,13 @@ def from_saved_model(cls, saved_model_dir, model_file_name='firebase_ml_model.tf return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app) @classmethod - def from_keras_model(cls, keras_model, model_file_name='firebase_ml_model.tflite', - bucket_name=None, app=None): + def from_keras_model( + cls, + keras_model: os.PathLike[str], + model_file_name: str = 'firebase_ml_model.tflite', + bucket_name: Optional[str] = None, + app: Optional[firebase_admin.App] = None, + ) -> 'TFLiteGCSModelSource': """Creates a Tensor Flow Lite model from the keras model, and uploads the model to GCS. Args: @@ -578,25 +639,26 @@ def from_keras_model(cls, keras_model, model_file_name='firebase_ml_model.tflite return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app) @property - def gcs_tflite_uri(self): + def gcs_tflite_uri(self) -> str: """URI of the model file in Cloud Storage.""" return self._gcs_tflite_uri @gcs_tflite_uri.setter - def gcs_tflite_uri(self, gcs_tflite_uri): + def gcs_tflite_uri(self, gcs_tflite_uri: str) -> None: self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri) - def _get_signed_gcs_tflite_uri(self): + def _get_signed_gcs_tflite_uri(self) -> str: """Signs the GCS uri, so the model file can be uploaded to Firebase ML and verified.""" return TFLiteGCSModelSource._STORAGE_CLIENT.sign_uri(self._gcs_tflite_uri, self._app) - def as_dict(self, for_upload=False): + def as_dict(self, for_upload: bool = False) -> dict[str, Any]: """Returns a serializable representation of the object.""" if for_upload: return {'gcsTfliteUri': self._get_signed_gcs_tflite_uri()} return {'gcsTfliteUri': self._gcs_tflite_uri} + class ListModelsPage: """Represents a page of models in a Firebase project. @@ -605,7 +667,17 @@ class ListModelsPage: ``iterate_all()`` can be used to iterate through all the models in the Firebase project starting from this page. """ - def __init__(self, list_models_func, list_filter, page_size, page_token, app): + def __init__( + self, + list_models_func: Callable[ + [Optional[str], Optional[int], Optional[str]], + dict[str, Any] + ], + list_filter: Optional[str], + page_size: Optional[int], + page_token: Optional[str], + app: Optional[firebase_admin.App], + ) -> None: self._list_models_func = list_models_func self._list_filter = list_filter self._page_size = page_size @@ -614,28 +686,32 @@ def __init__(self, list_models_func, list_filter, page_size, page_token, app): self._list_response = list_models_func(list_filter, page_size, page_token) @property - def models(self): + def models(self) -> list[Model]: """A list of Models from this page.""" return [ - Model.from_dict(model, app=self._app) for model in self._list_response.get('models', []) + Model.from_dict(model, app=self._app) + for model in cast( + list[dict[str, Any]], + self._list_response.get('models', []), + ) ] @property - def list_filter(self): + def list_filter(self) -> Optional[str]: """The filter string used to filter the models.""" return self._list_filter @property - def next_page_token(self): + def next_page_token(self) -> str: """Token identifying the next page of results.""" - return self._list_response.get('nextPageToken', '') + return cast(str, self._list_response.get('nextPageToken', '')) @property - def has_next_page(self): + def has_next_page(self) -> bool: """True if more pages are available.""" return bool(self.next_page_token) - def get_next_page(self): + def get_next_page(self) -> Optional['ListModelsPage']: """Retrieves the next page of models if available. Returns: @@ -650,7 +726,7 @@ def get_next_page(self): self._app) return None - def iterate_all(self): + def iterate_all(self) -> '_ModelIterator': """Retrieves an iterator for Models. Returned iterator will iterate through all the models in the Firebase @@ -670,16 +746,16 @@ class _ModelIterator: When the whole page has been traversed, it loads another page. This class never keeps more than one page of entries in memory. """ - def __init__(self, current_page): + def __init__(self, current_page: ListModelsPage) -> None: if not isinstance(current_page, ListModelsPage): raise TypeError('Current page must be a ListModelsPage') self._current_page = current_page - self._index = 0 + self._index: int = 0 - def __next__(self): + def __next__(self) -> Model: if self._index == len(self._current_page.models): if self._current_page.has_next_page: - self._current_page = self._current_page.get_next_page() + self._current_page = cast(ListModelsPage, self._current_page.get_next_page()) self._index = 0 if self._index < len(self._current_page.models): result = self._current_page.models[self._index] @@ -687,11 +763,11 @@ def __next__(self): return result raise StopIteration - def __iter__(self): + def __iter__(self) -> Iterator[Model]: return self -def _validate_and_parse_name(name): +def _validate_and_parse_name(name: Any) -> tuple[str, str]: # The resource name is added automatically from API call responses. # The only way it could be invalid is if someone tries to # create a model from a dictionary manually and does it incorrectly. @@ -701,40 +777,41 @@ def _validate_and_parse_name(name): return matcher.group('project_id'), matcher.group('model_id') -def _validate_model(model, update_mask=None): +def _validate_model(model: Model, update_mask: Optional[str] = None) -> None: if not isinstance(model, Model): raise TypeError('Model must be an ml.Model.') if update_mask is None and not model.display_name: raise ValueError('Model must have a display name.') -def _validate_model_id(model_id): +def _validate_model_id(model_id: str) -> None: if not _MODEL_ID_PATTERN.match(model_id): raise ValueError('Model ID format is invalid.') -def _validate_operation_name(op_name): +def _validate_operation_name(op_name: Any) -> str: if not _OPERATION_NAME_PATTERN.match(op_name): raise ValueError('Operation name format is invalid.') return op_name -def _validate_display_name(display_name): +def _validate_display_name(display_name: Any) -> str: if not _DISPLAY_NAME_PATTERN.match(display_name): raise ValueError('Display name format is invalid.') return display_name -def _validate_tags(tags): +def _validate_tags(tags: Any) -> list[str]: if not isinstance(tags, list) or not \ all(isinstance(tag, str) for tag in tags): raise TypeError('Tags must be a list of strings.') + tags = cast(list[str], tags) if not all(_TAG_PATTERN.match(tag) for tag in tags): raise ValueError('Tag format is invalid.') return tags -def _validate_gcs_tflite_uri(uri): +def _validate_gcs_tflite_uri(uri: str) -> str: # GCS Bucket naming rules are complex. The regex is not comprehensive. # See https://cloud.google.com/storage/docs/naming for full details. if not _GCS_TFLITE_URI_PATTERN.match(uri): @@ -742,19 +819,19 @@ def _validate_gcs_tflite_uri(uri): return uri -def _validate_model_format(model_format): +def _validate_model_format(model_format: Any) -> ModelFormat: if not isinstance(model_format, ModelFormat): raise TypeError('Model format must be a ModelFormat object.') return model_format -def _validate_list_filter(list_filter): +def _validate_list_filter(list_filter: Optional[str]) -> None: if list_filter is not None: if not isinstance(list_filter, str): raise TypeError('List filter must be a string or None.') -def _validate_page_size(page_size): +def _validate_page_size(page_size: Optional[int]) -> None: if page_size is not None: if type(page_size) is not int: # pylint: disable=unidiomatic-typecheck # Specifically type() to disallow boolean which is a subtype of int @@ -764,7 +841,7 @@ def _validate_page_size(page_size): f'Page size must be a positive integer between 1 and {_MAX_PAGE_SIZE}') -def _validate_page_token(page_token): +def _validate_page_token(page_token: Optional[str]) -> None: if page_token is not None: if not isinstance(page_token, str): raise TypeError('Page token must be a string or None.') @@ -778,7 +855,7 @@ class _MLService: POLL_EXPONENTIAL_BACKOFF_FACTOR = 1.5 POLL_BASE_WAIT_TIME_SECONDS = 3 - def __init__(self, app): + def __init__(self, app: firebase_admin.App) -> None: self._project_id = app.project_id if not self._project_id: raise ValueError( @@ -797,14 +874,14 @@ def __init__(self, app): headers=ml_headers, base_url=_MLService.OPERATION_URL) - def get_operation(self, op_name): + def get_operation(self, op_name: str) -> dict[str, Any]: _validate_operation_name(op_name) try: return self._operation_client.body('get', url=op_name) except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) - def _exponential_backoff(self, current_attempt, stop_time): + def _exponential_backoff(self, current_attempt: int, stop_time: Optional[datetime.datetime]) -> None: """Sleeps for the appropriate amount of time. Or throws deadline exceeded.""" delay_factor = pow(_MLService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt) wait_time_seconds = delay_factor * _MLService.POLL_BASE_WAIT_TIME_SECONDS @@ -816,7 +893,12 @@ def _exponential_backoff(self, current_attempt, stop_time): wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1) time.sleep(wait_time_seconds) - def handle_operation(self, operation, wait_for_operation=False, max_time_seconds=None): + def handle_operation( + self, + operation: dict[str, Any], + wait_for_operation: bool = False, + max_time_seconds: Optional[float] = None, + ) -> dict[str, Any]: """Handles long running operations. Args: @@ -841,13 +923,14 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds if operation.get('done'): # Operations which are immediately done don't have an operation name if operation.get('response'): - return operation.get('response') + return cast(dict[str, Any], operation['response']) if operation.get('error'): - raise _utils.handle_operation_error(operation.get('error')) + error = cast(dict[str, Any], operation['error']) + raise _utils.handle_operation_error(error) raise exceptions.UnknownError(message='Internal Error: Malformed Operation.') op_name = _validate_operation_name(operation.get('name')) - metadata = operation.get('metadata', {}) + metadata = cast(dict[str, Any], operation.get('metadata', {})) metadata_type = metadata.get('@type', '') if not metadata_type.endswith('ModelOperationMetadata'): raise TypeError('Unknown type of operation metadata.') @@ -865,15 +948,16 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds if operation.get('done'): if operation.get('response'): - return operation.get('response') + return cast(dict[str, Any], operation['response']) if operation.get('error'): - raise _utils.handle_operation_error(operation.get('error')) + error = cast(dict[str, Any], operation['error']) + raise _utils.handle_operation_error(error) # If the operation is not complete or timed out, return a (locked) model instead return get_model(model_id).as_dict() - def create_model(self, model): + def create_model(self, model: Model) -> dict[str, Any]: _validate_model(model) try: return self.handle_operation( @@ -881,7 +965,7 @@ def create_model(self, model): except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) - def update_model(self, model, update_mask=None): + def update_model(self, model: Model, update_mask: Optional[str] = None) -> dict[str, Any]: _validate_model(model, update_mask) path = f'models/{model.model_id}' if update_mask is not None: @@ -892,7 +976,7 @@ def update_model(self, model, update_mask=None): except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) - def set_published(self, model_id, publish): + def set_published(self, model_id: str, publish: bool) -> dict[str, Any]: _validate_model_id(model_id) model_name = f'projects/{self._project_id}/models/{model_id}' model = Model.from_dict({ @@ -903,19 +987,24 @@ def set_published(self, model_id, publish): }) return self.update_model(model, update_mask='state.published') - def get_model(self, model_id): + def get_model(self, model_id: str) -> dict[str, Any]: _validate_model_id(model_id) try: return self._client.body('get', url=f'models/{model_id}') except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) - def list_models(self, list_filter, page_size, page_token): + def list_models( + self, + list_filter: Optional[str], + page_size: Optional[int], + page_token: Optional[str], + ) -> dict[str, Any]: """ lists Firebase ML models.""" _validate_list_filter(list_filter) _validate_page_size(page_size) _validate_page_token(page_token) - params = {} + params: dict[str, Any] = {} if list_filter: params['filter'] = list_filter if page_size: @@ -924,14 +1013,14 @@ def list_models(self, list_filter, page_size, page_token): params['page_token'] = page_token path = 'models' if params: - param_str = parse.urlencode(sorted(params.items()), True) + param_str = urllib.parse.urlencode(sorted(params.items()), True) path = path + '?' + param_str try: return self._client.body('get', url=path) except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) - def delete_model(self, model_id): + def delete_model(self, model_id: str) -> None: _validate_model_id(model_id) try: self._client.body('delete', url=f'models/{model_id}') diff --git a/firebase_admin/project_management.py b/firebase_admin/project_management.py index 73c100d3a..e952390ae 100644 --- a/firebase_admin/project_management.py +++ b/firebase_admin/project_management.py @@ -20,6 +20,15 @@ import base64 import re import time +from collections.abc import Callable +from typing import ( + Any, + Optional, + NoReturn, + TypeVar, + cast, + overload, +) import requests @@ -28,15 +37,31 @@ from firebase_admin import _http_client from firebase_admin import _utils +__all__ = ( + 'AndroidApp', + 'AndroidAppMetadata', + 'IOSApp', + 'IOSAppMetadata', + 'SHACertificate', + 'android_app', + 'create_android_app', + 'create_ios_app', + 'ios_app', + 'list_android_apps', + 'list_ios_apps', +) + +_T = TypeVar('_T') +_AppMetadataT = TypeVar('_AppMetadataT', bound='_AppMetadata') _PROJECT_MANAGEMENT_ATTRIBUTE = '_project_management' -def _get_project_management_service(app): +def _get_project_management_service(app: Optional[firebase_admin.App]) -> '_ProjectManagementService': return _utils.get_app_service(app, _PROJECT_MANAGEMENT_ATTRIBUTE, _ProjectManagementService) -def android_app(app_id, app=None): +def android_app(app_id: str, app: Optional[firebase_admin.App] = None) -> 'AndroidApp': """Obtains a reference to an Android app in the associated Firebase project. Args: @@ -49,7 +74,7 @@ def android_app(app_id, app=None): return AndroidApp(app_id=app_id, service=_get_project_management_service(app)) -def ios_app(app_id, app=None): +def ios_app(app_id: str, app: Optional[firebase_admin.App] = None) -> 'IOSApp': """Obtains a reference to an iOS app in the associated Firebase project. Args: @@ -62,7 +87,7 @@ def ios_app(app_id, app=None): return IOSApp(app_id=app_id, service=_get_project_management_service(app)) -def list_android_apps(app=None): +def list_android_apps(app: Optional[firebase_admin.App] = None) -> list['AndroidApp']: """Lists all Android apps in the associated Firebase project. Args: @@ -75,7 +100,7 @@ def list_android_apps(app=None): return _get_project_management_service(app).list_android_apps() -def list_ios_apps(app=None): +def list_ios_apps(app: Optional[firebase_admin.App] = None) -> list['IOSApp']: """Lists all iOS apps in the associated Firebase project. Args: @@ -87,7 +112,11 @@ def list_ios_apps(app=None): return _get_project_management_service(app).list_ios_apps() -def create_android_app(package_name, display_name=None, app=None): +def create_android_app( + package_name: str, + display_name: Optional[str] = None, + app: Optional[firebase_admin.App] = None, +) -> 'AndroidApp': """Creates a new Android app in the associated Firebase project. Args: @@ -101,7 +130,11 @@ def create_android_app(package_name, display_name=None, app=None): return _get_project_management_service(app).create_android_app(package_name, display_name) -def create_ios_app(bundle_id, display_name=None, app=None): +def create_ios_app( + bundle_id: str, + display_name: Optional[str] = None, + app: Optional[firebase_admin.App] = None, +) -> 'IOSApp': """Creates a new iOS app in the associated Firebase project. Args: @@ -115,25 +148,29 @@ def create_ios_app(bundle_id, display_name=None, app=None): return _get_project_management_service(app).create_ios_app(bundle_id, display_name) -def _check_is_string_or_none(obj, field_name): +def _check_is_string_or_none(obj: Any, field_name: str) -> Optional[str]: if obj is None or isinstance(obj, str): return obj raise ValueError(f'{field_name} must be a string.') -def _check_is_nonempty_string(obj, field_name): +def _check_is_nonempty_string(obj: Any, field_name: str) -> str: if isinstance(obj, str) and obj: return obj raise ValueError(f'{field_name} must be a non-empty string.') -def _check_is_nonempty_string_or_none(obj, field_name): +def _check_is_nonempty_string_or_none(obj: Any, field_name: str) -> Optional[str]: if obj is None: return None return _check_is_nonempty_string(obj, field_name) -def _check_not_none(obj, field_name): +@overload +def _check_not_none(obj: None, field_name: str) -> NoReturn: ... +@overload +def _check_not_none(obj: _T, field_name: str) -> _T: ... +def _check_not_none(obj: Optional[_T], field_name: str) -> _T: if obj is None: raise ValueError(f'{field_name} cannot be None.') return obj @@ -148,12 +185,12 @@ class AndroidApp: instead of instantiating it directly. """ - def __init__(self, app_id, service): + def __init__(self, app_id: str, service: '_ProjectManagementService') -> None: self._app_id = app_id self._service = service @property - def app_id(self): + def app_id(self) -> str: """Returns the app ID of the Android app to which this instance refers. Note: This method does not make an RPC. @@ -163,7 +200,7 @@ def app_id(self): """ return self._app_id - def get_metadata(self): + def get_metadata(self) -> 'AndroidAppMetadata': """Retrieves detailed information about this Android app. Returns: @@ -175,7 +212,7 @@ def get_metadata(self): """ return self._service.get_android_app_metadata(self._app_id) - def set_display_name(self, new_display_name): + def set_display_name(self, new_display_name: Optional[str]) -> None: """Updates the display name attribute of this Android app to the one given. Args: @@ -188,13 +225,13 @@ def set_display_name(self, new_display_name): FirebaseError: If an error occurs while communicating with the Firebase Project Management Service. """ - return self._service.set_android_app_display_name(self._app_id, new_display_name) + self._service.set_android_app_display_name(self._app_id, new_display_name) - def get_config(self): + def get_config(self) -> str: """Retrieves the configuration artifact associated with this Android app.""" return self._service.get_android_app_config(self._app_id) - def get_sha_certificates(self): + def get_sha_certificates(self) -> list['SHACertificate']: """Retrieves the entire list of SHA certificates associated with this Android app. Returns: @@ -206,7 +243,7 @@ def get_sha_certificates(self): """ return self._service.get_sha_certificates(self._app_id) - def add_sha_certificate(self, certificate_to_add): + def add_sha_certificate(self, certificate_to_add: 'SHACertificate') -> None: """Adds a SHA certificate to this Android app. Args: @@ -219,9 +256,9 @@ def add_sha_certificate(self, certificate_to_add): FirebaseError: If an error occurs while communicating with the Firebase Project Management Service. (For example, if the certificate_to_add already exists.) """ - return self._service.add_sha_certificate(self._app_id, certificate_to_add) + self._service.add_sha_certificate(self._app_id, certificate_to_add) - def delete_sha_certificate(self, certificate_to_delete): + def delete_sha_certificate(self, certificate_to_delete: 'SHACertificate') -> None: """Removes a SHA certificate from this Android app. Args: @@ -234,7 +271,7 @@ def delete_sha_certificate(self, certificate_to_delete): FirebaseError: If an error occurs while communicating with the Firebase Project Management Service. (For example, if the certificate_to_delete is not found.) """ - return self._service.delete_sha_certificate(certificate_to_delete) + self._service.delete_sha_certificate(certificate_to_delete) class IOSApp: @@ -246,12 +283,12 @@ class IOSApp: instead of instantiating it directly. """ - def __init__(self, app_id, service): + def __init__(self, app_id: str, service: '_ProjectManagementService') -> None: self._app_id = app_id self._service = service @property - def app_id(self): + def app_id(self) -> str: """Returns the app ID of the iOS app to which this instance refers. Note: This method does not make an RPC. @@ -261,7 +298,7 @@ def app_id(self): """ return self._app_id - def get_metadata(self): + def get_metadata(self) -> 'IOSAppMetadata': """Retrieves detailed information about this iOS app. Returns: @@ -273,7 +310,7 @@ def get_metadata(self): """ return self._service.get_ios_app_metadata(self._app_id) - def set_display_name(self, new_display_name): + def set_display_name(self, new_display_name: Optional[str]) -> None: """Updates the display name attribute of this iOS app to the one given. Args: @@ -286,9 +323,9 @@ def set_display_name(self, new_display_name): FirebaseError: If an error occurs while communicating with the Firebase Project Management Service. """ - return self._service.set_ios_app_display_name(self._app_id, new_display_name) + self._service.set_ios_app_display_name(self._app_id, new_display_name) - def get_config(self): + def get_config(self) -> str: """Retrieves the configuration artifact associated with this iOS app.""" return self._service.get_ios_app_config(self._app_id) @@ -296,7 +333,7 @@ def get_config(self): class _AppMetadata: """Detailed information about a Firebase Android or iOS app.""" - def __init__(self, name, app_id, display_name, project_id): + def __init__(self, name: str, app_id: str, display_name: Optional[str], project_id: str) -> None: # _name is the fully qualified resource name of this Android or iOS app; currently it is not # exposed to client code. self._name = _check_is_nonempty_string(name, 'name') @@ -305,7 +342,7 @@ def __init__(self, name, app_id, display_name, project_id): self._project_id = _check_is_nonempty_string(project_id, 'project_id') @property - def app_id(self): + def app_id(self) -> str: """The globally unique, Firebase-assigned identifier of this Android or iOS app. This ID is unique even across apps of different platforms. @@ -313,18 +350,18 @@ def app_id(self): return self._app_id @property - def display_name(self): + def display_name(self) -> Optional[str]: """The user-assigned display name of this Android or iOS app. Note that the display name can be None if it has never been set by the user.""" return self._display_name @property - def project_id(self): + def project_id(self) -> str: """The permanent, globally unique, user-assigned ID of the parent Firebase project.""" return self._project_id - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, type(self)): return False # pylint: disable=protected-access @@ -336,24 +373,31 @@ def __eq__(self, other): class AndroidAppMetadata(_AppMetadata): """Android-specific information about an Android Firebase app.""" - def __init__(self, package_name, name, app_id, display_name, project_id): + def __init__( + self, + package_name: str, + name: str, + app_id: str, + display_name: Optional[str], + project_id: str, + ) -> None: """Clients should not instantiate this class directly.""" super().__init__(name, app_id, display_name, project_id) self._package_name = _check_is_nonempty_string(package_name, 'package_name') @property - def package_name(self): + def package_name(self) -> str: """The canonical package name of this Android app as it would appear in the Play Store.""" return self._package_name - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return (super().__eq__(other) and self.package_name == other.package_name) - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) - def __hash__(self): + def __hash__(self) -> int: return hash( (self._name, self.app_id, self.display_name, self.project_id, self.package_name)) @@ -361,23 +405,30 @@ def __hash__(self): class IOSAppMetadata(_AppMetadata): """iOS-specific information about an iOS Firebase app.""" - def __init__(self, bundle_id, name, app_id, display_name, project_id): + def __init__( + self, + bundle_id: str, + name: str, + app_id: str, + display_name: Optional[str], + project_id: str, + ) -> None: """Clients should not instantiate this class directly.""" super().__init__(name, app_id, display_name, project_id) self._bundle_id = _check_is_nonempty_string(bundle_id, 'bundle_id') @property - def bundle_id(self): + def bundle_id(self) -> str: """The canonical bundle ID of this iOS app as it would appear in the iOS AppStore.""" return self._bundle_id - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return super().__eq__(other) and self.bundle_id == other.bundle_id - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) - def __hash__(self): + def __hash__(self) -> int: return hash((self._name, self.app_id, self.display_name, self.project_id, self.bundle_id)) @@ -390,7 +441,7 @@ class SHACertificate: _SHA_1_RE = re.compile('^[0-9A-Fa-f]{40}$') _SHA_256_RE = re.compile('^[0-9A-Fa-f]{64}$') - def __init__(self, sha_hash, name=None): + def __init__(self, sha_hash: str, name: Optional[str] = None) -> None: """Creates a new SHACertificate instance. Args: @@ -415,7 +466,7 @@ def __init__(self, sha_hash, name=None): 'The supplied certificate hash is neither a valid SHA-1 nor SHA_256 hash.') @property - def name(self): + def name(self) -> Optional[str]: """Returns the fully qualified resource name of this certificate, if known. Returns: @@ -425,7 +476,7 @@ def name(self): return self._name @property - def sha_hash(self): + def sha_hash(self) -> str: """Returns the certificate hash. Returns: @@ -434,7 +485,7 @@ def sha_hash(self): return self._sha_hash @property - def cert_type(self): + def cert_type(self) -> str: """Returns the type of the SHA certificate encoded in the hash. Returns: @@ -442,16 +493,16 @@ def cert_type(self): """ return self._cert_type - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, SHACertificate): return False return (self.name == other.name and self.sha_hash == other.sha_hash and self.cert_type == other.cert_type) - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) - def __hash__(self): + def __hash__(self) -> int: return hash((self.name, self.sha_hash, self.cert_type)) @@ -469,7 +520,7 @@ class _ProjectManagementService: IOS_APPS_RESOURCE_NAME = 'iosApps' IOS_APP_IDENTIFIER_NAME = 'bundleId' - def __init__(self, app): + def __init__(self, app: firebase_admin.App) -> None: project_id = app.project_id if not project_id: raise ValueError( @@ -485,73 +536,83 @@ def __init__(self, app): headers={'X-Client-Version': version_header}, timeout=timeout) - def get_android_app_metadata(self, app_id): + def get_android_app_metadata(self, app_id: str) -> AndroidAppMetadata: return self._get_app_metadata( platform_resource_name=_ProjectManagementService.ANDROID_APPS_RESOURCE_NAME, identifier_name=_ProjectManagementService.ANDROID_APP_IDENTIFIER_NAME, metadata_class=AndroidAppMetadata, app_id=app_id) - def get_ios_app_metadata(self, app_id): + def get_ios_app_metadata(self, app_id: str) -> IOSAppMetadata: return self._get_app_metadata( platform_resource_name=_ProjectManagementService.IOS_APPS_RESOURCE_NAME, identifier_name=_ProjectManagementService.IOS_APP_IDENTIFIER_NAME, metadata_class=IOSAppMetadata, app_id=app_id) - def _get_app_metadata(self, platform_resource_name, identifier_name, metadata_class, app_id): + def _get_app_metadata( + self, + platform_resource_name: str, + identifier_name: str, + metadata_class: Callable[[str, str, str, Optional[str], str], _AppMetadataT], + app_id: str, + ) -> _AppMetadataT: """Retrieves detailed information about an Android or iOS app.""" _check_is_nonempty_string(app_id, 'app_id') path = f'/v1beta1/projects/-/{platform_resource_name}/{app_id}' response = self._make_request('get', path) return metadata_class( response[identifier_name], - name=response['name'], - app_id=response['appId'], - display_name=response.get('displayName') or None, - project_id=response['projectId']) + response['name'], + response['appId'], + response.get('displayName') or None, + response['projectId']) - def set_android_app_display_name(self, app_id, new_display_name): + def set_android_app_display_name(self, app_id: str, new_display_name: Optional[str]) -> None: self._set_display_name( app_id=app_id, new_display_name=new_display_name, platform_resource_name=_ProjectManagementService.ANDROID_APPS_RESOURCE_NAME) - def set_ios_app_display_name(self, app_id, new_display_name): + def set_ios_app_display_name(self, app_id: str, new_display_name: Optional[str]) -> None: self._set_display_name( app_id=app_id, new_display_name=new_display_name, platform_resource_name=_ProjectManagementService.IOS_APPS_RESOURCE_NAME) - def _set_display_name(self, app_id, new_display_name, platform_resource_name): + def _set_display_name(self, app_id: str, new_display_name: Optional[str], platform_resource_name: str) -> None: """Sets the display name of an Android or iOS app.""" path = f'/v1beta1/projects/-/{platform_resource_name}/{app_id}?updateMask=displayName' request_body = {'displayName': new_display_name} self._make_request('patch', path, json=request_body) - def list_android_apps(self): + def list_android_apps(self) -> list[AndroidApp]: return self._list_apps( platform_resource_name=_ProjectManagementService.ANDROID_APPS_RESOURCE_NAME, app_class=AndroidApp) - def list_ios_apps(self): + def list_ios_apps(self) -> list[IOSApp]: return self._list_apps( platform_resource_name=_ProjectManagementService.IOS_APPS_RESOURCE_NAME, app_class=IOSApp) - def _list_apps(self, platform_resource_name, app_class): + def _list_apps( + self, + platform_resource_name: str, + app_class: Callable[[str, '_ProjectManagementService'], _T], + ) -> list[_T]: """Lists all the Android or iOS apps within the Firebase project.""" path = ( f'/v1beta1/projects/{self._project_id}/{platform_resource_name}?pageSize=' f'{_ProjectManagementService.MAXIMUM_LIST_APPS_PAGE_SIZE}' ) response = self._make_request('get', path) - apps_list = [] + apps_list: list[_T] = [] while True: - apps = response.get('apps') + apps = cast(list[dict[str, Any]], response.get('apps', [])) if not apps: break - apps_list.extend(app_class(app_id=app['appId'], service=self) for app in apps) + apps_list.extend(app_class(app['appId'], self) for app in apps) next_page_token = response.get('nextPageToken') if not next_page_token: break @@ -564,7 +625,7 @@ def _list_apps(self, platform_resource_name, app_class): response = self._make_request('get', path) return apps_list - def create_android_app(self, package_name, display_name=None): + def create_android_app(self, package_name: str, display_name: Optional[str] = None) -> AndroidApp: return self._create_app( platform_resource_name=_ProjectManagementService.ANDROID_APPS_RESOURCE_NAME, identifier_name=_ProjectManagementService.ANDROID_APP_IDENTIFIER_NAME, @@ -572,7 +633,7 @@ def create_android_app(self, package_name, display_name=None): display_name=display_name, app_class=AndroidApp) - def create_ios_app(self, bundle_id, display_name=None): + def create_ios_app(self, bundle_id: str, display_name: Optional[str] = None) -> IOSApp: return self._create_app( platform_resource_name=_ProjectManagementService.IOS_APPS_RESOURCE_NAME, identifier_name=_ProjectManagementService.IOS_APP_IDENTIFIER_NAME, @@ -581,12 +642,13 @@ def create_ios_app(self, bundle_id, display_name=None): app_class=IOSApp) def _create_app( - self, - platform_resource_name, - identifier_name, - identifier, - display_name, - app_class): + self, + platform_resource_name: str, + identifier_name: str, + identifier: str, + display_name: Optional[str], + app_class: Callable[[str, '_ProjectManagementService'], _T], + ) -> _T: """Creates an Android or iOS app.""" _check_is_string_or_none(display_name, 'display_name') path = f'/v1beta1/projects/{self._project_id}/{platform_resource_name}' @@ -596,9 +658,9 @@ def _create_app( response = self._make_request('post', path, json=request_body) operation_name = response['name'] poll_response = self._poll_app_creation(operation_name) - return app_class(app_id=poll_response['appId'], service=self) + return app_class(poll_response['appId'], self) - def _poll_app_creation(self, operation_name): + def _poll_app_creation(self, operation_name: object) -> dict[str, Any]: """Polls the Long-Running Operation repeatedly until it is done with exponential backoff.""" for current_attempt in range(_ProjectManagementService.MAXIMUM_POLLING_ATTEMPTS): delay_factor = pow( @@ -609,7 +671,7 @@ def _poll_app_creation(self, operation_name): poll_response, http_response = self._body_and_response('get', path) done = poll_response.get('done') if done: - response = poll_response.get('response') + response: Optional[dict[str, Any]] = poll_response.get('response') if response: return response @@ -618,45 +680,55 @@ def _poll_app_creation(self, operation_name): http_response=http_response) raise exceptions.DeadlineExceededError('Polling deadline exceeded.') - def get_android_app_config(self, app_id): + def get_android_app_config(self, app_id: str) -> str: return self._get_app_config( platform_resource_name=_ProjectManagementService.ANDROID_APPS_RESOURCE_NAME, app_id=app_id) - def get_ios_app_config(self, app_id): + def get_ios_app_config(self, app_id: str) -> str: return self._get_app_config( platform_resource_name=_ProjectManagementService.IOS_APPS_RESOURCE_NAME, app_id=app_id) - def _get_app_config(self, platform_resource_name, app_id): + def _get_app_config(self, platform_resource_name: str, app_id: str) -> str: path = f'/v1beta1/projects/-/{platform_resource_name}/{app_id}/config' response = self._make_request('get', path) # In Python 2.7, the base64 module works with strings, while in Python 3, it works with # bytes objects. This line works in both versions. return base64.standard_b64decode(response['configFileContents']).decode(encoding='utf-8') - def get_sha_certificates(self, app_id): + def get_sha_certificates(self, app_id: str) -> list[SHACertificate]: path = f'/v1beta1/projects/-/androidApps/{app_id}/sha' - response = self._make_request('get', path) - cert_list = response.get('certificates') or [] + response: dict[str, Any] = self._make_request('get', path) + cert_list: list[dict[str, Any]] = response.get('certificates') or [] return [SHACertificate(sha_hash=cert['shaHash'], name=cert['name']) for cert in cert_list] - def add_sha_certificate(self, app_id, certificate_to_add): + def add_sha_certificate(self, app_id: str, certificate_to_add: SHACertificate) -> None: path = f'/v1beta1/projects/-/androidApps/{app_id}/sha' sha_hash = _check_not_none(certificate_to_add, 'certificate_to_add').sha_hash cert_type = certificate_to_add.cert_type request_body = {'shaHash': sha_hash, 'certType': cert_type} self._make_request('post', path, json=request_body) - def delete_sha_certificate(self, certificate_to_delete): + def delete_sha_certificate(self, certificate_to_delete: SHACertificate) -> None: name = _check_not_none(certificate_to_delete, 'certificate_to_delete').name path = f'/v1beta1/{name}' self._make_request('delete', path) - def _make_request(self, method, url, json=None): + def _make_request( + self, + method: str, + url: str, + json: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: body, _ = self._body_and_response(method, url, json) return body - def _body_and_response(self, method, url, json=None): + def _body_and_response( + self, + method: str, + url: str, + json: Optional[dict[str, Any]] = None, + ) -> tuple[dict[str, Any], requests.Response]: try: return self._client.body_and_response(method=method, url=url, json=json) except requests.exceptions.RequestException as error: diff --git a/firebase_admin/remote_config.py b/firebase_admin/remote_config.py index 880804d3d..b6b4955e8 100644 --- a/firebase_admin/remote_config.py +++ b/firebase_admin/remote_config.py @@ -20,13 +20,32 @@ import json import logging import threading -from typing import Dict, Optional, Literal, Union, Any -from enum import Enum +import enum import re import hashlib +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Literal, Optional, Union + import requests -from firebase_admin import App, _http_client, _utils + import firebase_admin +from firebase_admin import _http_client +from firebase_admin import _utils +from firebase_admin import exceptions + +if TYPE_CHECKING: + from _typeshed import ConvertibleToFloat + +__all__ = ( + 'MAX_CONDITION_RECURSION_DEPTH', + 'CustomSignalOperator', + 'PercentConditionOperator', + 'ServerConfig', + 'ServerTemplate', + 'ValueSource', + 'get_server_template', + 'init_server_template', +) # Set up logging (you can customize the level and output) logging.basicConfig(level=logging.INFO) @@ -36,7 +55,8 @@ MAX_CONDITION_RECURSION_DEPTH = 10 ValueSource = Literal['default', 'remote', 'static'] # Define the ValueSource type -class PercentConditionOperator(Enum): + +class PercentConditionOperator(enum.Enum): """Enum representing the available operators for percent conditions. """ LESS_OR_EQUAL = "LESS_OR_EQUAL" @@ -44,7 +64,8 @@ class PercentConditionOperator(Enum): BETWEEN = "BETWEEN" UNKNOWN = "UNKNOWN" -class CustomSignalOperator(Enum): + +class CustomSignalOperator(enum.Enum): """Enum representing the available operators for custom signal conditions. """ STRING_CONTAINS = "STRING_CONTAINS" @@ -65,9 +86,10 @@ class CustomSignalOperator(Enum): SEMANTIC_VERSION_GREATER_EQUAL = "SEMANTIC_VERSION_GREATER_EQUAL" UNKNOWN = "UNKNOWN" + class _ServerTemplateData: """Parses, validates and encapsulates template data and metadata.""" - def __init__(self, template_data): + def __init__(self, template_data: dict[str, Any]) -> None: """Initializes a new ServerTemplateData instance. Args: @@ -82,7 +104,7 @@ def __init__(self, template_data): else: raise ValueError('Remote Config parameters must be a non-null object') else: - self._parameters = {} + self._parameters: dict[str, dict[str, Any]] = {} if 'conditions' in template_data: if template_data['conditions'] is not None: @@ -90,28 +112,28 @@ def __init__(self, template_data): else: raise ValueError('Remote Config conditions must be a non-null object') else: - self._conditions = [] + self._conditions: list[dict[str, Any]] = [] - self._version = '' + self._version: str = '' if 'version' in template_data: self._version = template_data['version'] - self._etag = '' + self._etag: str = '' if 'etag' in template_data and isinstance(template_data['etag'], str): self._etag = template_data['etag'] self._template_data_json = json.dumps(template_data) @property - def parameters(self): + def parameters(self) -> dict[str, dict[str, Any]]: return self._parameters @property - def etag(self): + def etag(self) -> str: return self._etag @property - def version(self): + def version(self) -> str: return self._version @property @@ -119,13 +141,17 @@ def conditions(self): return self._conditions @property - def template_data_json(self): + def template_data_json(self) -> str: return self._template_data_json class ServerTemplate: """Represents a Server Template with implementations for loading and evaluating the template.""" - def __init__(self, app: App = None, default_config: Optional[Dict[str, str]] = None): + def __init__( + self, + app: Optional[firebase_admin.App] = None, + default_config: Optional[dict[str, str]] = None, + ) -> None: """Initializes a ServerTemplate instance. Args: @@ -137,8 +163,8 @@ def __init__(self, app: App = None, default_config: Optional[Dict[str, str]] = N _REMOTE_CONFIG_ATTRIBUTE, _RemoteConfigService) # This gets set when the template is # fetched from RC servers via the load API, or via the set API. - self._cache = None - self._stringified_default_config: Dict[str, str] = {} + self._cache: Optional[_ServerTemplateData] = None + self._stringified_default_config: dict[str, str] = {} self._lock = threading.RLock() # RC stores all remote values as string, but it's more intuitive @@ -148,13 +174,13 @@ def __init__(self, app: App = None, default_config: Optional[Dict[str, str]] = N for key in default_config: self._stringified_default_config[key] = str(default_config[key]) - async def load(self): + async def load(self) -> None: """Fetches the server template and caches the data.""" rc_server_template = await self._rc_service.get_server_template() with self._lock: self._cache = rc_server_template - def evaluate(self, context: Optional[Dict[str, Union[str, int]]] = None) -> 'ServerConfig': + def evaluate(self, context: Optional[dict[str, Union[str, int]]] = None) -> 'ServerConfig': """Evaluates the cached server template to produce a ServerConfig. Args: @@ -170,14 +196,14 @@ def evaluate(self, context: Optional[Dict[str, Union[str, int]]] = None) -> 'Ser raise ValueError("""No Remote Config Server template in cache. Call load() before calling evaluate().""") context = context or {} - config_values = {} + config_values: dict[str, _Value] = {} with self._lock: template_conditions = self._cache.conditions template_parameters = self._cache.parameters # Initializes config Value objects with default values. - if self._stringified_default_config is not None: + if self._stringified_default_config: for key, value in self._stringified_default_config.items(): config_values[key] = _Value('default', value) self._evaluator = _ConditionEvaluator(template_conditions, @@ -185,7 +211,7 @@ def evaluate(self, context: Optional[Dict[str, Union[str, int]]] = None) -> 'Ser config_values) return ServerConfig(config_values=self._evaluator.evaluate()) - def set(self, template_data_json: str): + def set(self, template_data_json: str) -> None: """Updates the cache to store the given template is of type ServerTemplateData. Args: @@ -197,7 +223,7 @@ def set(self, template_data_json: str): with self._lock: self._cache = template_data - def to_json(self): + def to_json(self) -> str: """Provides the server template in a JSON format to be used for initialization later.""" if not self._cache: raise ValueError("""No Remote Config Server template in cache. @@ -209,30 +235,30 @@ def to_json(self): class ServerConfig: """Represents a Remote Config Server Side Config.""" - def __init__(self, config_values): + def __init__(self, config_values: dict[str, '_Value']): self._config_values = config_values # dictionary of param key to values - def get_boolean(self, key): + def get_boolean(self, key: str) -> bool: """Returns the value as a boolean.""" return self._get_value(key).as_boolean() - def get_string(self, key): + def get_string(self, key: str) -> str: """Returns the value as a string.""" return self._get_value(key).as_string() - def get_int(self, key): + def get_int(self, key: str) -> int: """Returns the value as an integer.""" return self._get_value(key).as_int() - def get_float(self, key): + def get_float(self, key: str) -> float: """Returns the value as a float.""" return self._get_value(key).as_float() - def get_value_source(self, key): + def get_value_source(self, key: str) -> ValueSource: """Returns the source of the value.""" return self._get_value(key).get_source() - def _get_value(self, key): + def _get_value(self, key: str) -> "_Value": return self._config_values.get(key, _Value('static')) @@ -240,7 +266,7 @@ class _RemoteConfigService: """Internal class that facilitates sending requests to the Firebase Remote Config backend API. """ - def __init__(self, app): + def __init__(self, app: firebase_admin.App) -> None: """Initialize a JsonHttpClient with necessary inputs. Args: @@ -258,7 +284,7 @@ def __init__(self, app): base_url=remote_config_base_url, headers=rc_headers, timeout=timeout) - async def get_server_template(self): + async def get_server_template(self) -> _ServerTemplateData: """Requests for a server template and converts the response to an instance of ServerTemplateData for storing the template parameters and conditions.""" try: @@ -271,12 +297,12 @@ async def get_server_template(self): template_data['etag'] = headers.get('etag') return _ServerTemplateData(template_data) - def _get_url(self): + def _get_url(self) -> str: """Returns project prefix for url, in the format of /v1/projects/${projectId}""" return f"/v1/projects/{self._project_id}/namespaces/firebase-server/serverRemoteConfig" @classmethod - def _handle_remote_config_error(cls, error: Any): + def _handle_remote_config_error(cls, error: requests.RequestException) -> exceptions.FirebaseError: """Handles errors received from the Cloud Functions API.""" return _utils.handle_platform_error_from_requests(error) @@ -284,13 +310,19 @@ def _handle_remote_config_error(cls, error: Any): class _ConditionEvaluator: """Internal class that facilitates sending requests to the Firebase Remote Config backend API.""" - def __init__(self, conditions, parameters, context, config_values): + def __init__( + self, + conditions: list[dict[str, Any]], + parameters: dict[str, dict[str, Any]], + context: dict[str, Any], + config_values: dict[str, '_Value'], + ) -> None: self._context = context self._conditions = conditions self._parameters = parameters self._config_values = config_values - def evaluate(self): + def evaluate(self) -> dict[str, '_Value']: """Internal function that evaluates the cached server template to produce a ServerConfig""" evaluated_conditions = self.evaluate_conditions(self._conditions, self._context) @@ -298,9 +330,9 @@ def evaluate(self): # Overlays config Value objects derived by evaluating the template. if self._parameters: for key, parameter in self._parameters.items(): - conditional_values = parameter.get('conditionalValues', {}) - default_value = parameter.get('defaultValue', {}) - parameter_value_wrapper = None + conditional_values: dict[str, Any] = parameter.get('conditionalValues', {}) + default_value: dict[str, Any] = parameter.get('defaultValue', {}) + parameter_value_wrapper: Optional[dict[str, Any]] = None # Iterates in order over condition list. If there is a value associated # with a condition, this checks if the condition is true. if evaluated_conditions: @@ -314,6 +346,7 @@ def evaluate(self): continue if parameter_value_wrapper: + # possible issue: Is `None` a valid value for `_Value`? parameter_value = parameter_value_wrapper.get('value') self._config_values[key] = _Value('remote', parameter_value) continue @@ -328,7 +361,11 @@ def evaluate(self): self._config_values[key] = _Value('remote', default_value.get('value')) return self._config_values - def evaluate_conditions(self, conditions, context)-> Dict[str, bool]: + def evaluate_conditions( + self, + conditions: list[dict[str, Any]], + context: dict[str, Any], + )-> dict[str, bool]: """Evaluates a list of conditions and returns a dictionary of results. Args: @@ -338,15 +375,20 @@ def evaluate_conditions(self, conditions, context)-> Dict[str, bool]: Returns: A dictionary that maps condition names to boolean evaluation results. """ - evaluated_conditions = {} + evaluated_conditions: dict[Any, Any] = {} for condition in conditions: + # possible issue: does condition always have `name`? evaluated_conditions[condition.get('name')] = self.evaluate_condition( - condition.get('condition'), context + condition['condition'], context ) return evaluated_conditions - def evaluate_condition(self, condition, context, - nesting_level: int = 0) -> bool: + def evaluate_condition( + self, + condition: dict[str, Any], + context: dict[str, Any], + nesting_level: int = 0, + ) -> bool: """Recursively evaluates a condition. Args: @@ -361,25 +403,28 @@ def evaluate_condition(self, condition, context, logger.warning("Maximum condition recursion depth exceeded.") return False if condition.get('orCondition') is not None: - return self.evaluate_or_condition(condition.get('orCondition'), + return self.evaluate_or_condition(condition['orCondition'], context, nesting_level + 1) if condition.get('andCondition') is not None: - return self.evaluate_and_condition(condition.get('andCondition'), + return self.evaluate_and_condition(condition['andCondition'], context, nesting_level + 1) if condition.get('true') is not None: return True if condition.get('false') is not None: return False if condition.get('percent') is not None: - return self.evaluate_percent_condition(condition.get('percent'), context) + return self.evaluate_percent_condition(condition['percent'], context) if condition.get('customSignal') is not None: - return self.evaluate_custom_signal_condition(condition.get('customSignal'), context) + return self.evaluate_custom_signal_condition(condition['customSignal'], context) logger.warning("Unknown condition type encountered.") return False - def evaluate_or_condition(self, or_condition, - context, - nesting_level: int = 0) -> bool: + def evaluate_or_condition( + self, + or_condition: dict[str, Any], + context: dict[str, Any], + nesting_level: int = 0, + ) -> bool: """Evaluates an OR condition. Args: @@ -390,16 +435,19 @@ def evaluate_or_condition(self, or_condition, Returns: True if any of the subconditions are true, False otherwise. """ - sub_conditions = or_condition.get('conditions') or [] + sub_conditions: list[dict[str, Any]] = or_condition.get('conditions') or [] for sub_condition in sub_conditions: result = self.evaluate_condition(sub_condition, context, nesting_level + 1) if result: return True return False - def evaluate_and_condition(self, and_condition, - context, - nesting_level: int = 0) -> bool: + def evaluate_and_condition( + self, + and_condition: dict[str, Any], + context: dict[str, Any], + nesting_level: int = 0, + ) -> bool: """Evaluates an AND condition. Args: @@ -410,15 +458,18 @@ def evaluate_and_condition(self, and_condition, Returns: True if all of the subconditions are met; False otherwise. """ - sub_conditions = and_condition.get('conditions') or [] + sub_conditions: list[dict[str, Any]] = and_condition.get('conditions') or [] for sub_condition in sub_conditions: result = self.evaluate_condition(sub_condition, context, nesting_level + 1) if not result: return False return True - def evaluate_percent_condition(self, percent_condition, - context) -> bool: + def evaluate_percent_condition( + self, + percent_condition: dict[str, Any], + context: dict[str, Any], + ) -> bool: """Evaluates a percent condition. Args: @@ -462,6 +513,7 @@ def evaluate_percent_condition(self, percent_condition, return norm_percent_lower_bound < instance_micro_percentile <= norm_percent_upper_bound logger.warning("Unknown percent operator: %s", percent_operator) return False + def hash_seeded_randomization_id(self, seeded_randomization_id: str) -> int: """Hashes a seeded randomization ID. @@ -476,8 +528,11 @@ def hash_seeded_randomization_id(self, seeded_randomization_id: str) -> int: hash64 = hash_object.hexdigest() return abs(int(hash64, 16)) - def evaluate_custom_signal_condition(self, custom_signal_condition, - context) -> bool: + def evaluate_custom_signal_condition( + self, + custom_signal_condition: dict[str, Any], + context: dict[str, Any], + ) -> bool: """Evaluates a custom signal condition. Args: @@ -487,18 +542,16 @@ def evaluate_custom_signal_condition(self, custom_signal_condition, Returns: True if the condition is met, False otherwise. """ - custom_signal_operator = custom_signal_condition.get('customSignalOperator') or {} - custom_signal_key = custom_signal_condition.get('customSignalKey') or {} - target_custom_signal_values = ( - custom_signal_condition.get('targetCustomSignalValues') or {}) + custom_signal_operator: Optional[str] = custom_signal_condition.get('customSignalOperator') + custom_signal_key: Optional[str] = custom_signal_condition.get('customSignalKey') + target_custom_signal_values: Optional[list[Any]] = ( + custom_signal_condition.get('targetCustomSignalValues')) - if not all([custom_signal_operator, custom_signal_key, target_custom_signal_values]): + if not (custom_signal_operator and custom_signal_key and target_custom_signal_values): logger.warning("Missing operator, key, or target values for custom signal condition.") return False - if not target_custom_signal_values: - return False - actual_custom_signal_value = context.get(custom_signal_key) or {} + actual_custom_signal_value: Optional[Any] = context.get(custom_signal_key) if not actual_custom_signal_value: logger.debug("Custom signal value not found in context: %s", custom_signal_key) @@ -519,7 +572,7 @@ def evaluate_custom_signal_condition(self, custom_signal_condition, if custom_signal_operator == CustomSignalOperator.STRING_CONTAINS_REGEX.value: return self._compare_strings(target_custom_signal_values, actual_custom_signal_value, - re.search) + lambda pattern, string: bool(re.search(pattern, string))) # For numeric operators only one target value is allowed. if custom_signal_operator == CustomSignalOperator.NUMERIC_LESS_THAN.value: @@ -587,7 +640,12 @@ def evaluate_custom_signal_condition(self, custom_signal_condition, logger.warning("Unknown custom signal operator: %s", custom_signal_operator) return False - def _compare_strings(self, target_values, actual_value, predicate_fn) -> bool: + def _compare_strings( + self, + target_values: list[str], + actual_value: str, + predicate_fn: Callable[[str, str], bool], + ) -> bool: """Compares the actual string value of a signal against a list of target values. Args: @@ -607,7 +665,13 @@ def _compare_strings(self, target_values, actual_value, predicate_fn) -> bool: return True return False - def _compare_numbers(self, custom_signal_key, target_value, actual_value, predicate_fn) -> bool: + def _compare_numbers( + self, + custom_signal_key: str, + target_value: 'ConvertibleToFloat', + actual_value: 'ConvertibleToFloat', + predicate_fn: Callable[[float], bool], + ) -> bool: try: target = float(target_value) actual = float(actual_value) @@ -618,8 +682,13 @@ def _compare_numbers(self, custom_signal_key, target_value, actual_value, predic custom_signal_key) return False - def _compare_semantic_versions(self, custom_signal_key, - target_value, actual_value, predicate_fn) -> bool: + def _compare_semantic_versions( + self, + custom_signal_key: str, + target_value: str, + actual_value: str, + predicate_fn: Callable[[Literal[-1, 0, 1]], bool], + ) -> bool: """Compares the actual semantic version value of a signal against a target value. Calls the predicate function with -1, 0, 1 if actual is less than, equal to, or greater than target. @@ -637,8 +706,13 @@ def _compare_semantic_versions(self, custom_signal_key, return self._compare_versions(custom_signal_key, str(actual_value), str(target_value), predicate_fn) - def _compare_versions(self, custom_signal_key, - sem_version_1, sem_version_2, predicate_fn) -> bool: + def _compare_versions( + self, + custom_signal_key: str, + sem_version_1: str, + sem_version_2: str, + predicate_fn: Callable[[Literal[-1, 0, 1]], bool], + ) -> bool: """Compares two semantic version strings. Args: @@ -671,7 +745,11 @@ def _compare_versions(self, custom_signal_key, custom_signal_key) return False -async def get_server_template(app: App = None, default_config: Optional[Dict[str, str]] = None): + +async def get_server_template( + app: Optional[firebase_admin.App] = None, + default_config: Optional[dict[str, str]] = None, +) -> ServerTemplate: """Initializes a new ServerTemplate instance and fetches the server template. Args: @@ -686,8 +764,12 @@ async def get_server_template(app: App = None, default_config: Optional[Dict[str await template.load() return template -def init_server_template(app: App = None, default_config: Optional[Dict[str, str]] = None, - template_data_json: Optional[str] = None): + +def init_server_template( + app: Optional[firebase_admin.App] = None, + default_config: Optional[dict[str, str]] = None, + template_data_json: Optional[str] = None, +) -> ServerTemplate: """Initializes a new ServerTemplate instance. Args: @@ -705,6 +787,7 @@ def init_server_template(app: App = None, default_config: Optional[Dict[str, str template.set(template_data_json) return template + class _Value: """Represents a value fetched from Remote Config. """ @@ -714,7 +797,7 @@ class _Value: DEFAULT_VALUE_FOR_FLOAT_NUMBER = 0.0 BOOLEAN_TRUTHY_VALUES = ['1', 'true', 't', 'yes', 'y', 'on'] - def __init__(self, source: ValueSource, value: str = DEFAULT_VALUE_FOR_STRING): + def __init__(self, source: ValueSource, value: Any = DEFAULT_VALUE_FOR_STRING) -> None: """Initializes a Value instance. Args: @@ -724,7 +807,7 @@ def __init__(self, source: ValueSource, value: str = DEFAULT_VALUE_FOR_STRING): "remote" indicates the value was defined by config produced by evaluating a template. value: The string value. """ - self.source = source + self.source: ValueSource = source self.value = value def as_string(self) -> str: @@ -739,7 +822,7 @@ def as_boolean(self) -> bool: return self.DEFAULT_VALUE_FOR_BOOLEAN return str(self.value).lower() in self.BOOLEAN_TRUTHY_VALUES - def as_int(self) -> float: + def as_int(self) -> int: """Returns the value as a number.""" if self.source == 'static': return self.DEFAULT_VALUE_FOR_INTEGER diff --git a/firebase_admin/storage.py b/firebase_admin/storage.py index d2f004be6..1a3d33c31 100644 --- a/firebase_admin/storage.py +++ b/firebase_admin/storage.py @@ -18,6 +18,8 @@ Firebase apps. This requires the ``google-cloud-storage`` Python module. """ +from typing import Optional + # pylint: disable=import-error,no-name-in-module try: from google.cloud import storage @@ -25,12 +27,16 @@ raise ImportError('Failed to import the Cloud Storage library for Python. Make sure ' 'to install the "google-cloud-storage" module.') from exception +from google.auth import credentials + +import firebase_admin from firebase_admin import _utils +__all__ = ('bucket',) _STORAGE_ATTRIBUTE = '_storage' -def bucket(name=None, app=None) -> storage.Bucket: +def bucket(name: Optional[str] = None, app: Optional[firebase_admin.App] = None) -> storage.Bucket: """Returns a handle to a Google Cloud Storage bucket. If the name argument is not provided, uses the 'storageBucket' option specified when @@ -59,20 +65,25 @@ class _StorageClient: 'x-goog-api-client': _utils.get_metrics_header(), } - def __init__(self, credentials, project, default_bucket): + def __init__( + self, + credentials: credentials.Credentials, + project: Optional[str], + default_bucket: Optional[str], + ) -> None: self._client = storage.Client( credentials=credentials, project=project, extra_headers=self.STORAGE_HEADERS) self._default_bucket = default_bucket @classmethod - def from_app(cls, app): + def from_app(cls, app: firebase_admin.App) -> '_StorageClient': credentials = app.credential.get_credential() default_bucket = app.options.get('storageBucket') # Specifying project ID is not required, but providing it when available # significantly speeds up the initialization of the storage client. return _StorageClient(credentials, app.project_id, default_bucket) - def bucket(self, name=None): + def bucket(self, name: Optional[str] = None) -> storage.Bucket: """Returns a handle to the specified Cloud Storage Bucket.""" bucket_name = name if name is not None else self._default_bucket if bucket_name is None: diff --git a/firebase_admin/tenant_mgt.py b/firebase_admin/tenant_mgt.py index 9e713d988..e1dc0d8b1 100644 --- a/firebase_admin/tenant_mgt.py +++ b/firebase_admin/tenant_mgt.py @@ -20,6 +20,8 @@ import re import threading +from collections.abc import Callable, Iterator +from typing import Any, Optional, cast import requests @@ -29,12 +31,6 @@ from firebase_admin import _http_client from firebase_admin import _utils - -_TENANT_MGT_ATTRIBUTE = '_tenant_mgt' -_MAX_LIST_TENANTS_RESULTS = 100 -_DISPLAY_NAME_PATTERN = re.compile('^[a-zA-Z][a-zA-Z0-9-]{3,19}$') - - __all__ = [ 'ListTenantsPage', 'Tenant', @@ -49,12 +45,15 @@ 'update_tenant', ] +_TENANT_MGT_ATTRIBUTE = '_tenant_mgt' +_MAX_LIST_TENANTS_RESULTS = 100 +_DISPLAY_NAME_PATTERN = re.compile('^[a-zA-Z][a-zA-Z0-9-]{3,19}$') TenantIdMismatchError = _auth_utils.TenantIdMismatchError TenantNotFoundError = _auth_utils.TenantNotFoundError -def auth_for_tenant(tenant_id, app=None): +def auth_for_tenant(tenant_id: str, app: Optional[firebase_admin.App] = None) -> auth.Client: """Gets an Auth Client instance scoped to the given tenant ID. Args: @@ -71,7 +70,7 @@ def auth_for_tenant(tenant_id, app=None): return tenant_mgt_service.auth_for_tenant(tenant_id) -def get_tenant(tenant_id, app=None): +def get_tenant(tenant_id: str, app: Optional[firebase_admin.App] = None) -> 'Tenant': """Gets the tenant corresponding to the given ``tenant_id``. Args: @@ -91,7 +90,11 @@ def get_tenant(tenant_id, app=None): def create_tenant( - display_name, allow_password_sign_up=None, enable_email_link_sign_in=None, app=None): + display_name: str, + allow_password_sign_up: Optional[bool] = None, + enable_email_link_sign_in: Optional[bool] = None, + app: Optional[firebase_admin.App] = None, +) -> 'Tenant': """Creates a new tenant from the given options. Args: @@ -117,8 +120,12 @@ def create_tenant( def update_tenant( - tenant_id, display_name=None, allow_password_sign_up=None, enable_email_link_sign_in=None, - app=None): + tenant_id: str, + display_name: Optional[str] = None, + allow_password_sign_up: Optional[bool] = None, + enable_email_link_sign_in: Optional[bool] = None, + app: Optional[firebase_admin.App] = None, +) -> 'Tenant': """Updates an existing tenant with the given options. Args: @@ -144,7 +151,7 @@ def update_tenant( enable_email_link_sign_in=enable_email_link_sign_in) -def delete_tenant(tenant_id, app=None): +def delete_tenant(tenant_id: str, app: Optional[firebase_admin.App] = None) -> None: """Deletes the tenant corresponding to the given ``tenant_id``. Args: @@ -160,7 +167,11 @@ def delete_tenant(tenant_id, app=None): tenant_mgt_service.delete_tenant(tenant_id) -def list_tenants(page_token=None, max_results=_MAX_LIST_TENANTS_RESULTS, app=None): +def list_tenants( + page_token: Optional[str] = None, + max_results: int = _MAX_LIST_TENANTS_RESULTS, + app: Optional[firebase_admin.App] = None, +) -> 'ListTenantsPage': """Retrieves a page of tenants from a Firebase project. The ``page_token`` argument governs the starting point of the page. The ``max_results`` @@ -183,12 +194,12 @@ def list_tenants(page_token=None, max_results=_MAX_LIST_TENANTS_RESULTS, app=Non FirebaseError: If an error occurs while retrieving the user accounts. """ tenant_mgt_service = _get_tenant_mgt_service(app) - def download(page_token, max_results): + def download(page_token: Optional[str], max_results: int) -> dict[str, Any]: return tenant_mgt_service.list_tenants(page_token, max_results) return ListTenantsPage(download, page_token, max_results) -def _get_tenant_mgt_service(app): +def _get_tenant_mgt_service(app: Optional[firebase_admin.App]) -> '_TenantManagementService': return _utils.get_app_service(app, _TENANT_MGT_ATTRIBUTE, _TenantManagementService) @@ -203,7 +214,7 @@ class Tenant: such as the display name, tenant identifier and email authentication configuration. """ - def __init__(self, data): + def __init__(self, data: dict[str, Any]) -> None: if not isinstance(data, dict): raise ValueError(f'Invalid data argument in Tenant constructor: {data}') if not 'name' in data: @@ -212,20 +223,20 @@ def __init__(self, data): self._data = data @property - def tenant_id(self): + def tenant_id(self) -> str: name = self._data['name'] return name.split('/')[-1] @property - def display_name(self): + def display_name(self) -> Optional[str]: return self._data.get('displayName') @property - def allow_password_sign_up(self): + def allow_password_sign_up(self) -> bool: return self._data.get('allowPasswordSignup', False) @property - def enable_email_link_sign_in(self): + def enable_email_link_sign_in(self) -> bool: return self._data.get('enableEmailLinkSignin', False) @@ -234,17 +245,17 @@ class _TenantManagementService: TENANT_MGT_URL = 'https://identitytoolkit.googleapis.com/v2' - def __init__(self, app): + def __init__(self, app: firebase_admin.App) -> None: credential = app.credential.get_credential() version_header = f'Python/Admin/{firebase_admin.__version__}' base_url = f'{self.TENANT_MGT_URL}/projects/{app.project_id}' self.app = app self.client = _http_client.JsonHttpClient( credential=credential, base_url=base_url, headers={'X-Client-Version': version_header}) - self.tenant_clients = {} + self.tenant_clients: dict[str, auth.Client] = {} self.lock = threading.RLock() - def auth_for_tenant(self, tenant_id): + def auth_for_tenant(self, tenant_id: str) -> auth.Client: """Gets an Auth Client instance scoped to the given tenant ID.""" if not isinstance(tenant_id, str) or not tenant_id: raise ValueError( @@ -256,9 +267,9 @@ def auth_for_tenant(self, tenant_id): client = auth.Client(self.app, tenant_id=tenant_id) self.tenant_clients[tenant_id] = client - return client + return client - def get_tenant(self, tenant_id): + def get_tenant(self, tenant_id: str) -> Tenant: """Gets the tenant corresponding to the given ``tenant_id``.""" if not isinstance(tenant_id, str) or not tenant_id: raise ValueError( @@ -271,10 +282,14 @@ def get_tenant(self, tenant_id): return Tenant(body) def create_tenant( - self, display_name, allow_password_sign_up=None, enable_email_link_sign_in=None): + self, + display_name: str, + allow_password_sign_up: Optional[bool] = None, + enable_email_link_sign_in: Optional[bool] = None, + ) -> Tenant: """Creates a new tenant from the given parameters.""" - payload = {'displayName': _validate_display_name(display_name)} + payload: dict[str, Any] = {'displayName': _validate_display_name(display_name)} if allow_password_sign_up is not None: payload['allowPasswordSignup'] = _auth_utils.validate_boolean( allow_password_sign_up, 'allowPasswordSignup') @@ -289,13 +304,17 @@ def create_tenant( return Tenant(body) def update_tenant( - self, tenant_id, display_name=None, allow_password_sign_up=None, - enable_email_link_sign_in=None): + self, + tenant_id: str, + display_name: Optional[str] = None, + allow_password_sign_up: Optional[bool] = None, + enable_email_link_sign_in: Optional[bool] = None + ) -> Tenant: """Updates the specified tenant with the given parameters.""" if not isinstance(tenant_id, str) or not tenant_id: raise ValueError('Tenant ID must be a non-empty string.') - payload = {} + payload: dict[str, Any] = {} if display_name is not None: payload['displayName'] = _validate_display_name(display_name) if allow_password_sign_up is not None: @@ -317,18 +336,22 @@ def update_tenant( raise _auth_utils.handle_auth_backend_error(error) return Tenant(body) - def delete_tenant(self, tenant_id): + def delete_tenant(self, tenant_id: str) -> None: """Deletes the tenant corresponding to the given ``tenant_id``.""" if not isinstance(tenant_id, str) or not tenant_id: raise ValueError( - f'Invalid tenant ID: {tenant_id}. Tenant ID must be a non-empty string.') + 'Invalid tenant ID: {0}. Tenant ID must be a non-empty string.'.format(tenant_id)) try: self.client.request('delete', f'/tenants/{tenant_id}') except requests.exceptions.RequestException as error: raise _auth_utils.handle_auth_backend_error(error) - def list_tenants(self, page_token=None, max_results=_MAX_LIST_TENANTS_RESULTS): + def list_tenants( + self, + page_token: Optional[str] = None, + max_results: int = _MAX_LIST_TENANTS_RESULTS, + ) -> dict[str, Any]: """Retrieves a batch of tenants.""" if page_token is not None: if not isinstance(page_token, str) or not page_token: @@ -340,7 +363,7 @@ def list_tenants(self, page_token=None, max_results=_MAX_LIST_TENANTS_RESULTS): 'Max results must be a positive integer less than or equal to ' f'{_MAX_LIST_TENANTS_RESULTS}.') - payload = {'pageSize': max_results} + payload: dict[str, Any] = {'pageSize': max_results} if page_token: payload['pageToken'] = page_token try: @@ -357,27 +380,32 @@ class ListTenantsPage: through all tenants in the Firebase project starting from this page. """ - def __init__(self, download, page_token, max_results): + def __init__( + self, + download: Callable[[Optional[str], int], dict[str, Any]], + page_token: Optional[str], + max_results: int, + ) -> None: self._download = download self._max_results = max_results self._current = download(page_token, max_results) @property - def tenants(self): + def tenants(self) -> list[Tenant]: """A list of ``ExportedUserRecord`` instances available in this page.""" return [Tenant(data) for data in self._current.get('tenants', [])] @property - def next_page_token(self): + def next_page_token(self) -> str: """Page token string for the next page (empty string indicates no more pages).""" return self._current.get('nextPageToken', '') @property - def has_next_page(self): + def has_next_page(self) -> bool: """A boolean indicating whether more pages are available.""" return bool(self.next_page_token) - def get_next_page(self): + def get_next_page(self) -> Optional['ListTenantsPage']: """Retrieves the next page of tenants, if available. Returns: @@ -408,16 +436,16 @@ class _TenantIterator: of entries in memory. """ - def __init__(self, current_page): + def __init__(self, current_page: ListTenantsPage) -> None: if not current_page: raise ValueError('Current page must not be None.') self._current_page = current_page self._index = 0 - def __next__(self): + def __next__(self) -> Tenant: if self._index == len(self._current_page.tenants): if self._current_page.has_next_page: - self._current_page = self._current_page.get_next_page() + self._current_page = cast(ListTenantsPage, self._current_page.get_next_page()) self._index = 0 if self._index < len(self._current_page.tenants): result = self._current_page.tenants[self._index] @@ -425,11 +453,11 @@ def __next__(self): return result raise StopIteration - def __iter__(self): + def __iter__(self) -> Iterator[Tenant]: return self -def _validate_display_name(display_name): +def _validate_display_name(display_name: Any) -> str: if not isinstance(display_name, str): raise ValueError('Invalid type for displayName') if not _DISPLAY_NAME_PATTERN.search(display_name): diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 000000000..772975727 --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,33 @@ +{ + "pythonVersion": "3.9", + "typeCheckingMode": "strict", + + "include": ["firebase_admin"], + + "ignore": [ + "integration", + "snippets", + "tests", + "setup.py", + ], + + // Suppress import cycle errors (using forward references as needed) + "reportImportCycles": "none", + + // Allow dependencies without type annotations or stubs + "reportIncompleteStub": "none", + "reportMissingTypeStubs": "none", + + // Permit usage of private members across modules + "reportPrivateUsage": "none", + + // Allow `isinstance` for type assertions and runtime checks + "reportUnnecessaryIsInstance": "none", + + // Warn when a previously ignored type check is no longer needed + "reportUnnecessaryTypeIgnoreComment": "warning", + "reportMissingParameterType": "warning", + "reportUnknownArgumentType": "warning", + "reportUnknownMemberType": "warning", + "reportUnknownVariableType": "warning" +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 3e67d1dd5..25146ba64 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,10 +6,14 @@ pytest-localserver >= 0.4.1 pytest-asyncio >= 0.26.0 pytest-mock >= 3.6.1 respx == 0.22.0 +pyright >= 1.1.402 cachecontrol >= 0.12.14 google-api-core[grpc] >= 1.22.1, < 3.0.0dev; platform.python_implementation != 'PyPy' google-cloud-firestore >= 2.19.0; platform.python_implementation != 'PyPy' google-cloud-storage >= 1.37.1 pyjwt[crypto] >= 2.5.0 -httpx[http2] == 0.28.1 \ No newline at end of file +httpx[http2] == 0.28.1 +typing-extensions >= 4.12.0 +types-requests +types-httplib2 \ No newline at end of file diff --git a/setup.py b/setup.py index 25cf12672..6678ccb76 100644 --- a/setup.py +++ b/setup.py @@ -21,8 +21,7 @@ from setuptools import setup -(major, minor) = (sys.version_info.major, sys.version_info.minor) -if major != 3 or minor < 7: +if sys.version_info < (3, 9): print('firebase_admin requires python >= 3.9', file=sys.stderr) sys.exit(1) @@ -43,6 +42,9 @@ 'google-cloud-storage>=1.37.1', 'pyjwt[crypto] >= 2.5.0', 'httpx[http2] == 0.28.1', + 'typing-extensions >= 4.12.0' + 'types-requests' + 'types-httplib2' ] setup( @@ -72,5 +74,6 @@ 'Programming Language :: Python :: 3.12', 'Programming Language :: Python :: 3.13', 'License :: OSI Approved :: Apache Software License', + 'Typing :: Typed', ], ) From cc86db21be3061b788ad63ed1df0b4b83af64f95 Mon Sep 17 00:00:00 2001 From: ViktorSky Date: Tue, 1 Jul 2025 14:28:48 -0300 Subject: [PATCH 06/13] - remove redundant overriding methods - fix some pylint issues --- firebase_admin/_auth_utils.py | 117 -------------------------------- firebase_admin/_sseclient.py | 2 +- firebase_admin/db.py | 8 ++- firebase_admin/ml.py | 9 ++- firebase_admin/remote_config.py | 5 +- firebase_admin/tenant_mgt.py | 2 +- 6 files changed, 19 insertions(+), 124 deletions(-) diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index e702ff8f2..9cb7d5774 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -24,13 +24,11 @@ Literal, Optional, Protocol, - Union, cast, overload, ) from urllib import parse -import httpx import requests from typing_extensions import Self, TypeVar @@ -476,28 +474,12 @@ class UidAlreadyExistsError(exceptions.AlreadyExistsError): default_message = 'The user with the provided uid already exists' - def __init__( - self, - message: str, - cause: Optional[Exception], - http_response: Optional[Union[httpx.Response, requests.Response]], - ) -> None: - super().__init__(message, cause, http_response) - class EmailAlreadyExistsError(exceptions.AlreadyExistsError): """The user with the provided email already exists.""" default_message = 'The user with the provided email already exists' - def __init__( - self, - message: str, - cause: Optional[Exception], - http_response: Optional[Union[httpx.Response, requests.Response]], - ) -> None: - super().__init__(message, cause, http_response) - class InsufficientPermissionError(exceptions.PermissionDeniedError): """The credential used to initialize the SDK lacks required permissions.""" @@ -507,169 +489,70 @@ class InsufficientPermissionError(exceptions.PermissionDeniedError): 'https://firebase.google.com/docs/admin/setup for details ' 'on how to initialize the Admin SDK with appropriate permissions') - def __init__( - self, - message: str, - cause: Optional[Exception], - http_response: Optional[Union[httpx.Response, requests.Response]], - ) -> None: - super().__init__(message, cause, http_response) - class InvalidDynamicLinkDomainError(exceptions.InvalidArgumentError): """Dynamic link domain in ActionCodeSettings is not authorized.""" default_message = 'Dynamic link domain specified in ActionCodeSettings is not authorized' - def __init__( - self, - message: str, - cause: Optional[Exception], - http_response: Optional[Union[httpx.Response, requests.Response]], - ) -> None: - super().__init__(message, cause, http_response) - class InvalidIdTokenError(exceptions.InvalidArgumentError): """The provided ID token is not a valid Firebase ID token.""" default_message = 'The provided ID token is invalid' - def __init__( - self, - message: str, - cause: Optional[Exception] = None, - http_response: Optional[Union[httpx.Response, requests.Response]] = None, - ) -> None: - super().__init__(message, cause, http_response) - class PhoneNumberAlreadyExistsError(exceptions.AlreadyExistsError): """The user with the provided phone number already exists.""" default_message = 'The user with the provided phone number already exists' - def __init__( - self, - message: str, - cause: Optional[Exception], - http_response: Optional[Union[httpx.Response, requests.Response]], - ) -> None: - super().__init__(message, cause, http_response) - class UnexpectedResponseError(exceptions.UnknownError): """Backend service responded with an unexpected or malformed response.""" - def __init__( - self, - message: str, - cause: Optional[Exception] = None, - http_response: Optional[Union[httpx.Response, requests.Response]] = None, - ) -> None: - super().__init__(message, cause, http_response) - class UserNotFoundError(exceptions.NotFoundError): """No user record found for the specified identifier.""" default_message = 'No user record found for the given identifier' - def __init__( - self, - message: str, - cause: Optional[Exception] = None, - http_response: Optional[Union[httpx.Response, requests.Response]] = None, - ) -> None: - super().__init__(message, cause, http_response) - class EmailNotFoundError(exceptions.NotFoundError): """No user record found for the specified email.""" default_message = 'No user record found for the given email' - def __init__( - self, - message: str, - cause: Optional[Exception] = None, - http_response: Optional[Union[httpx.Response, requests.Response]] = None, - ) -> None: - super().__init__(message, cause, http_response) - class TenantNotFoundError(exceptions.NotFoundError): """No tenant found for the specified identifier.""" default_message = 'No tenant found for the given identifier' - def __init__( - self, - message: str, - cause: Optional[Exception] = None, - http_response: Optional[Union[httpx.Response, requests.Response]] = None, - ) -> None: - super().__init__(message, cause, http_response) - class TenantIdMismatchError(exceptions.InvalidArgumentError): """Missing or invalid tenant ID field in the given JWT.""" - def __init__(self, message: str) -> None: - super().__init__(message) - class ConfigurationNotFoundError(exceptions.NotFoundError): """No auth provider found for the specified identifier.""" default_message = 'No auth provider found for the given identifier' - def __init__( - self, - message: str, - cause: Optional[Exception] = None, - http_response: Optional[Union[httpx.Response, requests.Response]] = None, - ) -> None: - super().__init__(message, cause, http_response) - class UserDisabledError(exceptions.InvalidArgumentError): """An operation failed due to a user record being disabled.""" default_message = 'The user record is disabled' - def __init__( - self, - message: str, - cause: Optional[Exception] = None, - http_response: Optional[Union[httpx.Response, requests.Response]] = None, - ) -> None: - super().__init__(message, cause, http_response) - class TooManyAttemptsTryLaterError(exceptions.ResourceExhaustedError): """Rate limited because of too many attempts.""" - def __init__( - self, - message: str, - cause: Optional[Exception] = None, - http_response: Optional[Union[httpx.Response, requests.Response]] = None, - ) -> None: - super().__init__(message, cause, http_response) - class ResetPasswordExceedLimitError(exceptions.ResourceExhaustedError): """Reset password emails exceeded their limits.""" - def __init__( - self, - message: str, - cause: Optional[Exception] = None, - http_response: Optional[Union[httpx.Response, requests.Response]] = None, - ) -> None: - super().__init__(message, cause, http_response) - _CODE_TO_EXC_TYPE = { 'CONFIGURATION_NOT_FOUND': ConfigurationNotFoundError, diff --git a/firebase_admin/_sseclient.py b/firebase_admin/_sseclient.py index ea0d5ac23..796b65d82 100644 --- a/firebase_admin/_sseclient.py +++ b/firebase_admin/_sseclient.py @@ -163,7 +163,7 @@ def __next__(self) -> Optional['Event']: return event def next(self) -> Optional['Event']: - return self.__next__() + return next(self) class Event: diff --git a/firebase_admin/db.py b/firebase_admin/db.py index de9cb520b..8545e914d 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -36,11 +36,12 @@ cast, overload, ) -from typing_extensions import Self, TypeVar from urllib import parse import google.auth.credentials import requests +from typing_extensions import Self, TypeVar + import firebase_admin from firebase_admin import exceptions @@ -849,7 +850,10 @@ def __gt__(self, other: '_SortEntry') -> bool: def __ge__(self, other: '_SortEntry') -> bool: return self._compare(other) >= 0 - def __eq__(self, other: '_SortEntry') -> bool: # pyright: ignore[reportIncompatibleMethodOverride] + def __eq__( # pyright: ignore[reportIncompatibleMethodOverride] + self, + other: '_SortEntry', + ) -> bool: return self._compare(other) == 0 diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index 38b2f69af..3e906cd84 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -290,7 +290,8 @@ def _convert_to_millis(date_string: Optional[str]) -> Optional[int]: return None format_str = '%Y-%m-%dT%H:%M:%S.%fZ' epoch = datetime.datetime.fromtimestamp(0, datetime.timezone.utc) - datetime_object = datetime.datetime.strptime(date_string, format_str).replace(tzinfo=datetime.timezone.utc) + datetime_object = datetime.datetime.strptime( + date_string, format_str).replace(tzinfo=datetime.timezone.utc) millis = int((datetime_object - epoch).total_seconds() * 1000) return millis @@ -881,7 +882,11 @@ def get_operation(self, op_name: str) -> dict[str, Any]: except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) - def _exponential_backoff(self, current_attempt: int, stop_time: Optional[datetime.datetime]) -> None: + def _exponential_backoff( + self, + current_attempt: int, + stop_time: Optional[datetime.datetime], + ) -> None: """Sleeps for the appropriate amount of time. Or throws deadline exceeded.""" delay_factor = pow(_MLService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt) wait_time_seconds = delay_factor * _MLService.POLL_BASE_WAIT_TIME_SECONDS diff --git a/firebase_admin/remote_config.py b/firebase_admin/remote_config.py index b6b4955e8..c9af51883 100644 --- a/firebase_admin/remote_config.py +++ b/firebase_admin/remote_config.py @@ -302,7 +302,10 @@ def _get_url(self) -> str: return f"/v1/projects/{self._project_id}/namespaces/firebase-server/serverRemoteConfig" @classmethod - def _handle_remote_config_error(cls, error: requests.RequestException) -> exceptions.FirebaseError: + def _handle_remote_config_error( + cls, + error: requests.RequestException, + ) -> exceptions.FirebaseError: """Handles errors received from the Cloud Functions API.""" return _utils.handle_platform_error_from_requests(error) diff --git a/firebase_admin/tenant_mgt.py b/firebase_admin/tenant_mgt.py index e1dc0d8b1..41c21548c 100644 --- a/firebase_admin/tenant_mgt.py +++ b/firebase_admin/tenant_mgt.py @@ -340,7 +340,7 @@ def delete_tenant(self, tenant_id: str) -> None: """Deletes the tenant corresponding to the given ``tenant_id``.""" if not isinstance(tenant_id, str) or not tenant_id: raise ValueError( - 'Invalid tenant ID: {0}. Tenant ID must be a non-empty string.'.format(tenant_id)) + f'Invalid tenant ID: {tenant_id}. Tenant ID must be a non-empty string.') try: self.client.request('delete', f'/tenants/{tenant_id}') From 7f6e9f11be8edbbd0a7e877d5105fd26107777bf Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Thu, 3 Jul 2025 10:05:51 -0400 Subject: [PATCH 07/13] chore: Upgraded Google API Core, Cloud Firestore, and Cloud Storage dependencies (#897) * chore: Bump dependencies * fix: Also update setup.py --- requirements.txt | 10 +++++----- setup.py | 12 ++++++------ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/requirements.txt b/requirements.txt index 3e67d1dd5..ff15072a6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,9 +7,9 @@ pytest-asyncio >= 0.26.0 pytest-mock >= 3.6.1 respx == 0.22.0 -cachecontrol >= 0.12.14 -google-api-core[grpc] >= 1.22.1, < 3.0.0dev; platform.python_implementation != 'PyPy' -google-cloud-firestore >= 2.19.0; platform.python_implementation != 'PyPy' -google-cloud-storage >= 1.37.1 -pyjwt[crypto] >= 2.5.0 +cachecontrol >= 0.14.3 +google-api-core[grpc] >= 2.25.1, < 3.0.0dev; platform.python_implementation != 'PyPy' +google-cloud-firestore >= 2.21.0; platform.python_implementation != 'PyPy' +google-cloud-storage >= 3.1.1 +pyjwt[crypto] >= 2.10.1 httpx[http2] == 0.28.1 \ No newline at end of file diff --git a/setup.py b/setup.py index 25cf12672..21e29332e 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ (major, minor) = (sys.version_info.major, sys.version_info.minor) -if major != 3 or minor < 7: +if major != 3 or minor < 9: print('firebase_admin requires python >= 3.9', file=sys.stderr) sys.exit(1) @@ -37,11 +37,11 @@ long_description = ('The Firebase Admin Python SDK enables server-side (backend) Python developers ' 'to integrate Firebase into their services and applications.') install_requires = [ - 'cachecontrol>=0.12.14', - 'google-api-core[grpc] >= 1.22.1, < 3.0.0dev; platform.python_implementation != "PyPy"', - 'google-cloud-firestore>=2.19.0; platform.python_implementation != "PyPy"', - 'google-cloud-storage>=1.37.1', - 'pyjwt[crypto] >= 2.5.0', + 'cachecontrol>=0.14.3', + 'google-api-core[grpc] >= 2.25.1, < 3.0.0dev; platform.python_implementation != "PyPy"', + 'google-cloud-firestore>=2.21.0; platform.python_implementation != "PyPy"', + 'google-cloud-storage>=3.1.1', + 'pyjwt[crypto] >= 2.10.1', 'httpx[http2] == 0.28.1', ] From 4dfb7399de352fcf2eb9c329c137042e041b5849 Mon Sep 17 00:00:00 2001 From: Jonathan Edey <145066863+jonathanedey@users.noreply.github.com> Date: Wed, 2 Jul 2025 16:04:09 -0400 Subject: [PATCH 08/13] fix(functions): Remove usage of deprecated `datetime.utcnow() and fix flaky unit test` (#896) --- firebase_admin/functions.py | 5 ++-- tests/test_functions.py | 58 +++++++++++++++++++++++-------------- 2 files changed, 40 insertions(+), 23 deletions(-) diff --git a/firebase_admin/functions.py b/firebase_admin/functions.py index 86eea557a..6db0fbb42 100644 --- a/firebase_admin/functions.py +++ b/firebase_admin/functions.py @@ -15,7 +15,7 @@ """Firebase Functions module.""" from __future__ import annotations -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from urllib import parse import re import json @@ -255,7 +255,8 @@ def _validate_task_options( if not isinstance(opts.schedule_delay_seconds, int) \ or opts.schedule_delay_seconds < 0: raise ValueError('schedule_delay_seconds should be positive int.') - schedule_time = datetime.utcnow() + timedelta(seconds=opts.schedule_delay_seconds) + schedule_time = ( + datetime.now(timezone.utc) + timedelta(seconds=opts.schedule_delay_seconds)) task.schedule_time = schedule_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ') if opts.dispatch_deadline_seconds is not None: if not isinstance(opts.dispatch_deadline_seconds, int) \ diff --git a/tests/test_functions.py b/tests/test_functions.py index 1856426d9..52e92c1b2 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -14,7 +14,7 @@ """Test cases for the firebase_admin.functions module.""" -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone import json import time import pytest @@ -33,8 +33,6 @@ _CLOUD_TASKS_URL + 'projects/test-project/locations/us-central1/queues/test-function-name/tasks' _DEFAULT_TASK_URL = _CLOUD_TASKS_URL + _DEFAULT_TASK_PATH _DEFAULT_RESPONSE = json.dumps({'name': _DEFAULT_TASK_PATH}) -_ENQUEUE_TIME = datetime.utcnow() -_SCHEDULE_TIME = _ENQUEUE_TIME + timedelta(seconds=100) class TestTaskQueue: @classmethod @@ -185,27 +183,46 @@ def _instrument_functions_service(self, app=None, status=200, payload=_DEFAULT_R testutils.MockAdapter(payload, status, recorder)) return functions_service, recorder - - @pytest.mark.parametrize('task_opts_params', [ - { + def test_task_options_delay_seconds(self): + _, recorder = self._instrument_functions_service() + enqueue_time = datetime.now(timezone.utc) + expected_schedule_time = enqueue_time + timedelta(seconds=100) + task_opts_params = { 'schedule_delay_seconds': 100, 'schedule_time': None, 'dispatch_deadline_seconds': 200, 'task_id': 'test-task-id', 'headers': {'x-test-header': 'test-header-value'}, 'uri': 'https://google.com' - }, - { + } + queue = functions.task_queue('test-function-name') + task_opts = functions.TaskOptions(**task_opts_params) + queue.enqueue(_DEFAULT_DATA, task_opts) + + assert len(recorder) == 1 + task = json.loads(recorder[0].body.decode())['task'] + + task_schedule_time = datetime.fromisoformat(task['schedule_time'].replace('Z', '+00:00')) + delta = abs(task_schedule_time - expected_schedule_time) + assert delta <= timedelta(seconds=1) + + assert task['dispatch_deadline'] == '200s' + assert task['http_request']['headers']['x-test-header'] == 'test-header-value' + assert task['http_request']['url'] in ['http://google.com', 'https://google.com'] + assert task['name'] == _DEFAULT_TASK_PATH + + def test_task_options_utc_time(self): + _, recorder = self._instrument_functions_service() + enqueue_time = datetime.now(timezone.utc) + expected_schedule_time = enqueue_time + timedelta(seconds=100) + task_opts_params = { 'schedule_delay_seconds': None, - 'schedule_time': _SCHEDULE_TIME, + 'schedule_time': expected_schedule_time, 'dispatch_deadline_seconds': 200, 'task_id': 'test-task-id', 'headers': {'x-test-header': 'test-header-value'}, 'uri': 'http://google.com' - }, - ]) - def test_task_options(self, task_opts_params): - _, recorder = self._instrument_functions_service() + } queue = functions.task_queue('test-function-name') task_opts = functions.TaskOptions(**task_opts_params) queue.enqueue(_DEFAULT_DATA, task_opts) @@ -213,19 +230,18 @@ def test_task_options(self, task_opts_params): assert len(recorder) == 1 task = json.loads(recorder[0].body.decode())['task'] - schedule_time = datetime.fromisoformat(task['schedule_time'][:-1]) - delta = abs(schedule_time - _SCHEDULE_TIME) - assert delta <= timedelta(seconds=15) + task_schedule_time = datetime.fromisoformat(task['schedule_time'].replace('Z', '+00:00')) + assert task_schedule_time == expected_schedule_time assert task['dispatch_deadline'] == '200s' assert task['http_request']['headers']['x-test-header'] == 'test-header-value' assert task['http_request']['url'] in ['http://google.com', 'https://google.com'] assert task['name'] == _DEFAULT_TASK_PATH - def test_schedule_set_twice_error(self): _, recorder = self._instrument_functions_service() - opts = functions.TaskOptions(schedule_delay_seconds=100, schedule_time=datetime.utcnow()) + opts = functions.TaskOptions( + schedule_delay_seconds=100, schedule_time=datetime.now(timezone.utc)) queue = functions.task_queue('test-function-name') with pytest.raises(ValueError) as excinfo: queue.enqueue(_DEFAULT_DATA, opts) @@ -236,9 +252,9 @@ def test_schedule_set_twice_error(self): @pytest.mark.parametrize('schedule_time', [ time.time(), - str(datetime.utcnow()), - datetime.utcnow().isoformat(), - datetime.utcnow().isoformat() + 'Z', + str(datetime.now(timezone.utc)), + datetime.now(timezone.utc).isoformat(), + datetime.now(timezone.utc).isoformat() + 'Z', '', ' ' ]) def test_invalid_schedule_time_error(self, schedule_time): From c864d2dfd6e713b0d3ef5d4e8454fea88e2679f5 Mon Sep 17 00:00:00 2001 From: ViktorSky Date: Mon, 30 Jun 2025 15:09:17 -0300 Subject: [PATCH 09/13] add type annotations rebase --- firebase_admin/__init__.py | 84 ++++-- firebase_admin/_auth_client.py | 202 ++++++++++--- firebase_admin/_auth_providers.py | 216 +++++++++----- firebase_admin/_auth_utils.py | 413 ++++++++++++++++++++++----- firebase_admin/_http_client.py | 155 ++++++---- firebase_admin/_messaging_encoder.py | 202 +++++++++---- firebase_admin/_messaging_utils.py | 226 ++++++++++++--- firebase_admin/_retry.py | 55 ++-- firebase_admin/_rfc3339.py | 21 +- firebase_admin/_sseclient.py | 54 ++-- firebase_admin/_token_gen.py | 225 +++++++++++---- firebase_admin/_user_identifier.py | 33 ++- firebase_admin/_user_import.py | 174 ++++++----- firebase_admin/_user_mgt.py | 291 ++++++++++++------- firebase_admin/_utils.py | 116 +++++--- firebase_admin/app_check.py | 55 ++-- firebase_admin/auth.py | 240 ++++++++++++---- firebase_admin/credentials.py | 82 +++--- firebase_admin/db.py | 282 ++++++++++++------ firebase_admin/exceptions.py | 200 ++++++++++--- firebase_admin/firestore.py | 29 +- firebase_admin/firestore_async.py | 34 ++- firebase_admin/functions.py | 120 ++++---- firebase_admin/instance_id.py | 15 +- firebase_admin/messaging.py | 178 +++++++----- firebase_admin/ml.py | 347 +++++++++++++--------- firebase_admin/project_management.py | 248 ++++++++++------ firebase_admin/remote_config.py | 245 ++++++++++------ firebase_admin/storage.py | 19 +- firebase_admin/tenant_mgt.py | 116 +++++--- pyrightconfig.json | 33 +++ requirements.txt | 6 +- setup.py | 7 +- 33 files changed, 3309 insertions(+), 1414 deletions(-) create mode 100644 pyrightconfig.json diff --git a/firebase_admin/__init__.py b/firebase_admin/__init__.py index 8c9f628e5..3d485c831 100644 --- a/firebase_admin/__init__.py +++ b/firebase_admin/__init__.py @@ -13,27 +13,41 @@ # limitations under the License. """Firebase Admin SDK for Python.""" -import datetime + import json import os import threading +from collections.abc import Callable +from typing import Any, Optional, TypeVar, Union, overload + +import google.auth.credentials +import google.auth.exceptions -from google.auth.credentials import Credentials as GoogleAuthCredentials -from google.auth.exceptions import DefaultCredentialsError from firebase_admin import credentials from firebase_admin.__about__ import __version__ +__all__ = ( + 'App', + 'delete_app', + 'get_app', + 'initialize_app', +) -_apps = {} +_T = TypeVar('_T') + +_apps: dict[str, 'App'] = {} _apps_lock = threading.RLock() -_clock = datetime.datetime.utcnow _DEFAULT_APP_NAME = '[DEFAULT]' _FIREBASE_CONFIG_ENV_VAR = 'FIREBASE_CONFIG' _CONFIG_VALID_KEYS = ['databaseAuthVariableOverride', 'databaseURL', 'httpTimeout', 'projectId', 'storageBucket'] -def initialize_app(credential=None, options=None, name=_DEFAULT_APP_NAME): +def initialize_app( + credential: Optional[Union[credentials.Base, google.auth.credentials.Credentials]] = None, + options: Optional[dict[str, Any]] = None, + name: str = _DEFAULT_APP_NAME, +) -> 'App': """Initializes and returns a new App instance. Creates a new App instance using the specified options @@ -86,7 +100,7 @@ def initialize_app(credential=None, options=None, name=_DEFAULT_APP_NAME): 'you call initialize_app().') -def delete_app(app): +def delete_app(app: 'App') -> None: """Gracefully deletes an App instance. Args: @@ -113,7 +127,7 @@ def delete_app(app): 'second argument.') -def get_app(name=_DEFAULT_APP_NAME): +def get_app(name: str = _DEFAULT_APP_NAME) -> 'App': """Retrieves an App instance by name. Args: @@ -147,7 +161,7 @@ def get_app(name=_DEFAULT_APP_NAME): class _AppOptions: """A collection of configuration options for an App.""" - def __init__(self, options): + def __init__(self, options: Optional[dict[str, Any]]) -> None: if options is None: options = self._load_from_environment() @@ -157,11 +171,16 @@ def __init__(self, options): 'Options must be a dictionary.') self._options = options - def get(self, key, default=None): + @overload + def get(self, key: str, default: None = None) -> Optional[Any]: ... + # possible issue: needs return Any | _T ? + @overload + def get(self, key: str, default: _T) -> _T: ... + def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]: """Returns the option identified by the provided key.""" return self._options.get(key, default) - def _load_from_environment(self): + def _load_from_environment(self) -> dict[str, Any]: """Invoked when no options are passed to __init__, loads options from FIREBASE_CONFIG. If the value of the FIREBASE_CONFIG environment variable starts with "{" an attempt is made @@ -194,7 +213,12 @@ class App: common to all Firebase APIs. """ - def __init__(self, name, credential, options): + def __init__( + self, + name: str, + credential: Union[credentials.Base, google.auth.credentials.Credentials], + options: Optional[dict[str, Any]], + ) -> None: """Constructs a new App using the provided name and options. Args: @@ -211,7 +235,7 @@ def __init__(self, name, credential, options): 'non-empty string.') self._name = name - if isinstance(credential, GoogleAuthCredentials): + if isinstance(credential, google.auth.credentials.Credentials): self._credential = credentials._ExternalCredentials(credential) # pylint: disable=protected-access elif isinstance(credential, credentials.Base): self._credential = credential @@ -220,37 +244,38 @@ def __init__(self, name, credential, options): 'with a valid credential instance.') self._options = _AppOptions(options) self._lock = threading.RLock() - self._services = {} + self._services: Optional[dict[str, Any]] = {} App._validate_project_id(self._options.get('projectId')) - self._project_id_initialized = False + self._project_id_initialized: bool = False @classmethod - def _validate_project_id(cls, project_id): + def _validate_project_id(cls, project_id: Optional[Any]) -> Optional[str]: if project_id is not None and not isinstance(project_id, str): raise ValueError( f'Invalid project ID: "{project_id}". project ID must be a string.') + return project_id @property - def name(self): + def name(self) -> str: return self._name @property - def credential(self): + def credential(self) -> credentials.Base: return self._credential @property - def options(self): + def options(self) -> _AppOptions: return self._options @property - def project_id(self): + def project_id(self) -> Optional[str]: if not self._project_id_initialized: self._project_id = self._lookup_project_id() self._project_id_initialized = True return self._project_id - def _lookup_project_id(self): + def _lookup_project_id(self) -> Optional[str]: """Looks up the Firebase project ID associated with an App. If a ``projectId`` is specified in app options, it is returned. Then tries to @@ -264,8 +289,8 @@ def _lookup_project_id(self): project_id = self._options.get('projectId') if not project_id: try: - project_id = self._credential.project_id - except (AttributeError, DefaultCredentialsError): + project_id = getattr(self._credential, 'project_id') + except (AttributeError, google.auth.exceptions.DefaultCredentialsError): pass if not project_id: project_id = os.environ.get('GOOGLE_CLOUD_PROJECT', @@ -273,7 +298,7 @@ def _lookup_project_id(self): App._validate_project_id(self._options.get('projectId')) return project_id - def _get_service(self, name, initializer): + def _get_service(self, name: str, initializer: Callable[['App'], _T]) -> _T: """Returns the service instance identified by the given name. Services are functional entities exposed by the Admin SDK (e.g. auth, database). Each @@ -303,7 +328,7 @@ def _get_service(self, name, initializer): self._services[name] = initializer(self) return self._services[name] - def _cleanup(self): + def _cleanup(self) -> None: """Cleans up any services associated with this App. Checks whether each service contains a close() method, and calls it if available. @@ -311,7 +336,8 @@ def _cleanup(self): any services started by the App. """ with self._lock: - for service in self._services.values(): - if hasattr(service, 'close') and hasattr(service.close, '__call__'): - service.close() - self._services = None + if self._services is not None: + for service in self._services.values(): + if hasattr(service, 'close') and hasattr(service.close, '__call__'): + service.close() + self._services = None diff --git a/firebase_admin/_auth_client.py b/firebase_admin/_auth_client.py index 74261fa37..170c05851 100644 --- a/firebase_admin/_auth_client.py +++ b/firebase_admin/_auth_client.py @@ -15,6 +15,8 @@ """Firebase auth client sub module.""" import time +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, Any, Optional, Union import firebase_admin from firebase_admin import _auth_providers @@ -25,12 +27,18 @@ from firebase_admin import _user_import from firebase_admin import _user_mgt from firebase_admin import _utils +from firebase_admin import exceptions + +if TYPE_CHECKING: + from _typeshed import ConvertibleToInt + +__all__ = ('Client',) class Client: """Firebase Authentication client scoped to a specific tenant.""" - def __init__(self, app, tenant_id=None): + def __init__(self, app: firebase_admin.App, tenant_id: Optional[str] = None) -> None: if not app.project_id: raise ValueError("""A project ID is required to access the auth service. 1. Use a service account credential, or @@ -41,7 +49,7 @@ def __init__(self, app, tenant_id=None): version_header = f'Python/Admin/{firebase_admin.__version__}' timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) # Non-default endpoint URLs for emulator support are set in this dict later. - endpoint_urls = {} + endpoint_urls: dict[str, str] = {} self.emulated = False # If an emulator is present, check that the given value matches the expected format and set @@ -70,11 +78,15 @@ def __init__(self, app, tenant_id=None): http_client, app.project_id, tenant_id, url_override=endpoint_urls.get('v2')) @property - def tenant_id(self): + def tenant_id(self) -> Optional[str]: """Tenant ID associated with this client.""" return self._tenant_id - def create_custom_token(self, uid, developer_claims=None): + def create_custom_token( + self, + uid: str, + developer_claims: Optional[dict[str, Any]] = None, + ) -> bytes: """Builds and signs a Firebase custom auth token. Args: @@ -92,7 +104,12 @@ def create_custom_token(self, uid, developer_claims=None): return self._token_generator.create_custom_token( uid, developer_claims, tenant_id=self.tenant_id) - def verify_id_token(self, id_token, check_revoked=False, clock_skew_seconds=0): + def verify_id_token( + self, + id_token: Union[bytes, str], + check_revoked: bool = False, + clock_skew_seconds: int = 0, + ) -> dict[str, Any]: """Verifies the signature and data for the provided JWT. Accepts a signed token string, verifies that it is current, was issued @@ -139,7 +156,7 @@ def verify_id_token(self, id_token, check_revoked=False, clock_skew_seconds=0): verified_claims, _token_gen.RevokedIdTokenError, 'ID token') return verified_claims - def revoke_refresh_tokens(self, uid): + def revoke_refresh_tokens(self, uid: str) -> None: """Revokes all refresh tokens for an existing user. This method updates the user's ``tokens_valid_after_timestamp`` to the current UTC @@ -160,7 +177,7 @@ def revoke_refresh_tokens(self, uid): """ self._user_manager.update_user(uid, valid_since=int(time.time())) - def get_user(self, uid): + def get_user(self, uid: str) -> _user_mgt.UserRecord: """Gets the user data corresponding to the specified user ID. Args: @@ -177,7 +194,7 @@ def get_user(self, uid): response = self._user_manager.get_user(uid=uid) return _user_mgt.UserRecord(response) - def get_user_by_email(self, email): + def get_user_by_email(self, email: str) -> _user_mgt.UserRecord: """Gets the user data corresponding to the specified user email. Args: @@ -194,7 +211,7 @@ def get_user_by_email(self, email): response = self._user_manager.get_user(email=email) return _user_mgt.UserRecord(response) - def get_user_by_phone_number(self, phone_number): + def get_user_by_phone_number(self, phone_number: str) -> _user_mgt.UserRecord: """Gets the user data corresponding to the specified phone number. Args: @@ -211,7 +228,7 @@ def get_user_by_phone_number(self, phone_number): response = self._user_manager.get_user(phone_number=phone_number) return _user_mgt.UserRecord(response) - def get_users(self, identifiers): + def get_users(self, identifiers: 'Sequence[_user_identifier.UserIdentifier]') -> _user_mgt.GetUsersResult: """Gets the user data corresponding to the specified identifiers. There are no ordering guarantees; in particular, the nth entry in the @@ -236,7 +253,7 @@ def get_users(self, identifiers): """ response = self._user_manager.get_users(identifiers=identifiers) - def _matches(identifier, user_record): + def _matches(identifier: _user_identifier.UserIdentifier, user_record: _user_mgt.UserRecord) -> bool: if isinstance(identifier, _user_identifier.UidIdentifier): return identifier.uid == user_record.uid if isinstance(identifier, _user_identifier.EmailIdentifier): @@ -252,7 +269,10 @@ def _matches(identifier, user_record): ), False) raise TypeError(f"Unexpected type: {type(identifier)}") - def _is_user_found(identifier, user_records): + def _is_user_found( + identifier: _user_identifier.UserIdentifier, + user_records: list[_user_mgt.UserRecord], + ) -> bool: return any(_matches(identifier, user_record) for user_record in user_records) users = [_user_mgt.UserRecord(user) for user in response] @@ -261,7 +281,11 @@ def _is_user_found(identifier, user_records): return _user_mgt.GetUsersResult(users=users, not_found=not_found) - def list_users(self, page_token=None, max_results=_user_mgt.MAX_LIST_USERS_RESULTS): + def list_users( + self, + page_token: Optional[str] = None, + max_results: int = _user_mgt.MAX_LIST_USERS_RESULTS, + ) -> _user_mgt.ListUsersPage: """Retrieves a page of user accounts from a Firebase project. The ``page_token`` argument governs the starting point of the page. The ``max_results`` @@ -283,11 +307,23 @@ def list_users(self, page_token=None, max_results=_user_mgt.MAX_LIST_USERS_RESUL ValueError: If max_results or page_token are invalid. FirebaseError: If an error occurs while retrieving the user accounts. """ - def download(page_token, max_results): + def download(page_token: Optional[str], max_results: int) -> dict[str, Any]: return self._user_manager.list_users(page_token, max_results) return _user_mgt.ListUsersPage(download, page_token, max_results) - def create_user(self, **kwargs): # pylint: disable=differing-param-doc + def create_user( + self, + *, + uid: Optional[str] = None, + display_name: Optional[str] = None, + email: Optional[str] = None, + phone_number: Optional[str] = None, + photo_url: Optional[str] = None, + password: Optional[str] = None, + disabled: Optional[bool] = None, + email_verified: Optional[bool] = None, + **kwargs: Any, + ) -> _user_mgt.UserRecord: """Creates a new user account with the specified properties. Args: @@ -311,10 +347,27 @@ def create_user(self, **kwargs): # pylint: disable=differing-param-doc ValueError: If the specified user properties are invalid. FirebaseError: If an error occurs while creating the user account. """ - uid = self._user_manager.create_user(**kwargs) + uid = self._user_manager.create_user(uid=uid, display_name=display_name, email=email, + phone_number=phone_number, photo_url=photo_url, password=password, disabled=disabled, + email_verified=email_verified, **kwargs) return self.get_user(uid=uid) - def update_user(self, uid, **kwargs): # pylint: disable=differing-param-doc + def update_user( + self, + uid: str, + *, + display_name: Optional[str] = None, + email: Optional[str] = None, + phone_number: Optional[str] = None, + photo_url: Optional[str] = None, + password: Optional[str] = None, + disabled: Optional[bool] = None, + email_verified: Optional[bool] = None, + valid_since: Optional['ConvertibleToInt'] = None, + custom_claims: Optional[Union[dict[str, Any], str]] = None, + providers_to_delete: Optional[list[str]] = None, + **kwargs: Any, + ) -> _user_mgt.UserRecord: """Updates an existing user account with the specified properties. Args: @@ -349,10 +402,16 @@ def update_user(self, uid, **kwargs): # pylint: disable=differing-param-doc ValueError: If the specified user ID or properties are invalid. FirebaseError: If an error occurs while updating the user account. """ - self._user_manager.update_user(uid, **kwargs) + self._user_manager.update_user(uid, display_name=display_name, email=email, phone_number=phone_number, + photo_url=photo_url, password=password, disabled=disabled, email_verified=email_verified, + valid_since=valid_since, custom_claims=custom_claims, providers_to_delete=providers_to_delete, **kwargs) return self.get_user(uid=uid) - def set_custom_user_claims(self, uid, custom_claims): + def set_custom_user_claims( + self, + uid: str, + custom_claims: Optional[Union[dict[str, Any], str]], + ) -> None: """Sets additional claims on an existing user account. Custom claims set via this function can be used to define user roles and privilege levels. @@ -375,7 +434,7 @@ def set_custom_user_claims(self, uid, custom_claims): custom_claims = _user_mgt.DELETE_ATTRIBUTE self._user_manager.update_user(uid, custom_claims=custom_claims) - def delete_user(self, uid): + def delete_user(self, uid: str) -> None: """Deletes the user identified by the specified user ID. Args: @@ -387,7 +446,7 @@ def delete_user(self, uid): """ self._user_manager.delete_user(uid) - def delete_users(self, uids): + def delete_users(self, uids: Sequence[str]) -> _user_mgt.DeleteUsersResult: """Deletes the users specified by the given identifiers. Deleting a non-existing user does not generate an error (the method is @@ -414,7 +473,11 @@ def delete_users(self, uids): result = self._user_manager.delete_users(uids, force_delete=True) return _user_mgt.DeleteUsersResult(result, len(uids)) - def import_users(self, users, hash_alg=None): + def import_users( + self, + users: Sequence[_user_import.ImportUserRecord], + hash_alg: Optional[_user_import.UserImportHash] = None, + ) -> _user_import.UserImportResult: """Imports the specified list of users into Firebase Auth. At most 1000 users can be imported at a time. This operation is optimized for bulk imports @@ -438,7 +501,11 @@ def import_users(self, users, hash_alg=None): result = self._user_manager.import_users(users, hash_alg) return _user_import.UserImportResult(result, len(users)) - def generate_password_reset_link(self, email, action_code_settings=None): + def generate_password_reset_link( + self, + email: Optional[str], + action_code_settings: Optional[_user_mgt.ActionCodeSettings] = None, + ) -> str: """Generates the out-of-band email action link for password reset flows for the specified email address. @@ -459,7 +526,11 @@ def generate_password_reset_link(self, email, action_code_settings=None): return self._user_manager.generate_email_action_link( 'PASSWORD_RESET', email, action_code_settings=action_code_settings) - def generate_email_verification_link(self, email, action_code_settings=None): + def generate_email_verification_link( + self, + email: Optional[str], + action_code_settings: Optional[_user_mgt.ActionCodeSettings] = None, + ) -> str: """Generates the out-of-band email action link for email verification flows for the specified email address. @@ -480,7 +551,11 @@ def generate_email_verification_link(self, email, action_code_settings=None): return self._user_manager.generate_email_action_link( 'VERIFY_EMAIL', email, action_code_settings=action_code_settings) - def generate_sign_in_with_email_link(self, email, action_code_settings): + def generate_sign_in_with_email_link( + self, + email: Optional[str], + action_code_settings: Optional[_user_mgt.ActionCodeSettings], + ) -> str: """Generates the out-of-band email action link for email link sign-in flows, using the action code settings provided. @@ -500,7 +575,7 @@ def generate_sign_in_with_email_link(self, email, action_code_settings): return self._user_manager.generate_email_action_link( 'EMAIL_SIGNIN', email, action_code_settings=action_code_settings) - def get_oidc_provider_config(self, provider_id): + def get_oidc_provider_config(self, provider_id: str) -> _auth_providers.OIDCProviderConfig: """Returns the ``OIDCProviderConfig`` with the given ID. Args: @@ -517,8 +592,16 @@ def get_oidc_provider_config(self, provider_id): return self._provider_manager.get_oidc_provider_config(provider_id) def create_oidc_provider_config( - self, provider_id, client_id, issuer, display_name=None, enabled=None, - client_secret=None, id_token_response_type=None, code_response_type=None): + self, + provider_id: str, + client_id: str, + issuer: str, + display_name: Optional[str] = None, + enabled: Optional[bool] = None, + client_secret: Optional[str] = None, + id_token_response_type: Optional[bool] = None, + code_response_type: Optional[bool] = None, + ) -> _auth_providers.OIDCProviderConfig: """Creates a new OIDC provider config from the given parameters. OIDC provider support requires Google Cloud's Identity Platform (GCIP). To learn more about @@ -556,8 +639,16 @@ def create_oidc_provider_config( id_token_response_type=id_token_response_type, code_response_type=code_response_type) def update_oidc_provider_config( - self, provider_id, client_id=None, issuer=None, display_name=None, enabled=None, - client_secret=None, id_token_response_type=None, code_response_type=None): + self, + provider_id: str, + client_id: Optional[str] = None, + issuer: Optional[str] = None, + display_name: Optional[str] = None, + enabled: Optional[bool] = None, + client_secret: Optional[str] = None, + id_token_response_type: Optional[bool] = None, + code_response_type: Optional[bool] = None, + ) -> _auth_providers.OIDCProviderConfig: """Updates an existing OIDC provider config with the given parameters. Args: @@ -591,7 +682,7 @@ def update_oidc_provider_config( enabled=enabled, client_secret=client_secret, id_token_response_type=id_token_response_type, code_response_type=code_response_type) - def delete_oidc_provider_config(self, provider_id): + def delete_oidc_provider_config(self, provider_id: str) -> None: """Deletes the ``OIDCProviderConfig`` with the given ID. Args: @@ -605,7 +696,10 @@ def delete_oidc_provider_config(self, provider_id): self._provider_manager.delete_oidc_provider_config(provider_id) def list_oidc_provider_configs( - self, page_token=None, max_results=_auth_providers.MAX_LIST_CONFIGS_RESULTS): + self, + page_token: Optional[str] = None, + max_results: int = _auth_providers.MAX_LIST_CONFIGS_RESULTS, + ) -> _auth_providers._ListOIDCProviderConfigsPage: """Retrieves a page of OIDC provider configs from a Firebase project. The ``page_token`` argument governs the starting point of the page. The ``max_results`` @@ -629,7 +723,7 @@ def list_oidc_provider_configs( """ return self._provider_manager.list_oidc_provider_configs(page_token, max_results) - def get_saml_provider_config(self, provider_id): + def get_saml_provider_config(self, provider_id: str) -> _auth_providers.SAMLProviderConfig: """Returns the ``SAMLProviderConfig`` with the given ID. Args: @@ -646,8 +740,16 @@ def get_saml_provider_config(self, provider_id): return self._provider_manager.get_saml_provider_config(provider_id) def create_saml_provider_config( - self, provider_id, idp_entity_id, sso_url, x509_certificates, rp_entity_id, - callback_url, display_name=None, enabled=None): + self, + provider_id: str, + idp_entity_id: str, + sso_url: str, + x509_certificates: list[str], + rp_entity_id: str, + callback_url: str, + display_name: Optional[str] = None, + enabled: Optional[bool] = None, + ) -> _auth_providers.SAMLProviderConfig: """Creates a new SAML provider config from the given parameters. SAML provider support requires Google Cloud's Identity Platform (GCIP). To learn more about @@ -686,8 +788,16 @@ def create_saml_provider_config( callback_url=callback_url, display_name=display_name, enabled=enabled) def update_saml_provider_config( - self, provider_id, idp_entity_id=None, sso_url=None, x509_certificates=None, - rp_entity_id=None, callback_url=None, display_name=None, enabled=None): + self, + provider_id: str, + idp_entity_id: Optional[str] = None, + sso_url: Optional[str] = None, + x509_certificates: Optional[list[str]] = None, + rp_entity_id: Optional[str] = None, + callback_url: Optional[str] = None, + display_name: Optional[str] = None, + enabled: Optional[bool] = None, + ) -> _auth_providers.SAMLProviderConfig: """Updates an existing SAML provider config with the given parameters. Args: @@ -715,7 +825,7 @@ def update_saml_provider_config( x509_certificates=x509_certificates, rp_entity_id=rp_entity_id, callback_url=callback_url, display_name=display_name, enabled=enabled) - def delete_saml_provider_config(self, provider_id): + def delete_saml_provider_config(self, provider_id: str) -> None: """Deletes the ``SAMLProviderConfig`` with the given ID. Args: @@ -729,7 +839,10 @@ def delete_saml_provider_config(self, provider_id): self._provider_manager.delete_saml_provider_config(provider_id) def list_saml_provider_configs( - self, page_token=None, max_results=_auth_providers.MAX_LIST_CONFIGS_RESULTS): + self, + page_token: Optional[str] = None, + max_results: int = _auth_providers.MAX_LIST_CONFIGS_RESULTS, + ) -> _auth_providers._ListSAMLProviderConfigsPage: """Retrieves a page of SAML provider configs from a Firebase project. The ``page_token`` argument governs the starting point of the page. The ``max_results`` @@ -753,9 +866,14 @@ def list_saml_provider_configs( """ return self._provider_manager.list_saml_provider_configs(page_token, max_results) - def _check_jwt_revoked_or_disabled(self, verified_claims, exc_type, label): - user = self.get_user(verified_claims.get('uid')) + def _check_jwt_revoked_or_disabled( + self, + verified_claims: dict[str, Any], + exc_type: Callable[[str], exceptions.FirebaseError], + label: str, + ) -> None: + user = self.get_user(verified_claims['uid']) if user.disabled: raise _auth_utils.UserDisabledError('The user record is disabled.') - if verified_claims.get('iat') * 1000 < user.tokens_valid_after_timestamp: + if verified_claims['iat'] * 1000 < user.tokens_valid_after_timestamp: raise exc_type(f'The Firebase {label} has been revoked.') diff --git a/firebase_admin/_auth_providers.py b/firebase_admin/_auth_providers.py index cc7949526..9c1653e53 100644 --- a/firebase_admin/_auth_providers.py +++ b/firebase_admin/_auth_providers.py @@ -14,13 +14,28 @@ """Firebase auth providers management sub module.""" +from collections.abc import Callable +from typing import Any, Generic, Optional, cast +from typing_extensions import Self, TypeVar from urllib import parse import requests from firebase_admin import _auth_utils +from firebase_admin import _http_client from firebase_admin import _user_mgt +__all__ = ( + 'MAX_LIST_CONFIGS_RESULTS', + 'ListProviderConfigsPage', + 'OIDCProviderConfig', + 'ProviderConfig', + 'ProviderConfigClient', + 'SAMLProviderConfig', +) + +_ProviderConfigT = TypeVar('_ProviderConfigT', bound='ProviderConfig', default='ProviderConfig') + MAX_LIST_CONFIGS_RESULTS = 100 @@ -28,20 +43,20 @@ class ProviderConfig: """Parent type for all authentication provider config types.""" - def __init__(self, data): + def __init__(self, data: dict[str, Any]) -> None: self._data = data @property - def provider_id(self): - name = self._data['name'] + def provider_id(self) -> str: + name = cast(str, self._data['name']) return name.split('/')[-1] @property - def display_name(self): + def display_name(self) -> Optional[str]: return self._data.get('displayName') @property - def enabled(self): + def enabled(self) -> bool: return self._data.get('enabled', False) @@ -80,55 +95,60 @@ class SAMLProviderConfig(ProviderConfig): @property def idp_entity_id(self): - return self._data.get('idpConfig', {})['idpEntityId'] + return self._data['idpConfig']['idpEntityId'] @property def sso_url(self): - return self._data.get('idpConfig', {})['ssoUrl'] + return self._data['idpConfig']['ssoUrl'] @property def x509_certificates(self): - certs = self._data.get('idpConfig', {})['idpCertificates'] + certs = self._data['idpConfig']['idpCertificates'] return [c['x509Certificate'] for c in certs] @property def callback_url(self): - return self._data.get('spConfig', {})['callbackUri'] + return self._data['spConfig']['callbackUri'] @property def rp_entity_id(self): - return self._data.get('spConfig', {})['spEntityId'] + return self._data['spConfig']['spEntityId'] -class ListProviderConfigsPage: - """Represents a page of AuthProviderConfig instances retrieved from a Firebase project. +class ListProviderConfigsPage(Generic[_ProviderConfigT]): + """Represents a page of ProviderConfig instances retrieved from a Firebase project. Provides methods for traversing the provider configs included in this page, as well as retrieving subsequent pages. The iterator returned by ``iterate_all()`` can be used to iterate through all provider configs in the Firebase project starting from this page. """ - def __init__(self, download, page_token, max_results): + def __init__( + self, + download: Callable[[Optional[str], int], dict[str, Any]], + page_token: Optional[str], + max_results: int, + ) -> None: self._download = download self._max_results = max_results self._current = download(page_token, max_results) @property - def provider_configs(self): - """A list of ``AuthProviderConfig`` instances available in this page.""" + def provider_configs(self) -> list[_ProviderConfigT]: + """A list of ``ProviderConfig`` instances available in this page.""" raise NotImplementedError @property - def next_page_token(self): + def next_page_token(self) -> str: """Page token string for the next page (empty string indicates no more pages).""" return self._current.get('nextPageToken', '') @property - def has_next_page(self): + def has_next_page(self) -> bool: """A boolean indicating whether more pages are available.""" return bool(self.next_page_token) - def get_next_page(self): + def get_next_page(self) -> Optional[Self]: """Retrieves the next page of provider configs, if available. Returns: @@ -139,7 +159,7 @@ def get_next_page(self): return self.__class__(self._download, self.next_page_token, self._max_results) return None - def iterate_all(self): + def iterate_all(self) -> '_ProviderConfigIterator[_ProviderConfigT]': """Retrieves an iterator for provider configs. Returned iterator will iterate through all the provider configs in the Firebase project @@ -147,30 +167,39 @@ def iterate_all(self): in memory at a time. Returns: - iterator: An iterator of AuthProviderConfig instances. + iterator: An iterator of ProviderConfig instances. """ return _ProviderConfigIterator(self) -class _ListOIDCProviderConfigsPage(ListProviderConfigsPage): - +class _ListOIDCProviderConfigsPage(ListProviderConfigsPage[OIDCProviderConfig]): @property - def provider_configs(self): - return [OIDCProviderConfig(data) for data in self._current.get('oauthIdpConfigs', [])] + def provider_configs(self) -> list[OIDCProviderConfig]: + return [ + OIDCProviderConfig(data) + for data in cast( + list[dict[str, Any]], + self._current.get('oauthIdpConfigs', []), + ) + ] -class _ListSAMLProviderConfigsPage(ListProviderConfigsPage): - +class _ListSAMLProviderConfigsPage(ListProviderConfigsPage[SAMLProviderConfig]): @property - def provider_configs(self): - return [SAMLProviderConfig(data) for data in self._current.get('inboundSamlConfigs', [])] - + def provider_configs(self) -> list[SAMLProviderConfig]: + return [ + SAMLProviderConfig(data) + for data in cast( + list[dict[str, Any]], + self._current.get('inboundSamlConfigs', []), + ) + ] -class _ProviderConfigIterator(_auth_utils.PageIterator): +class _ProviderConfigIterator(_auth_utils.PageIterator[ListProviderConfigsPage[_ProviderConfigT]]): @property - def items(self): - return self._current_page.provider_configs + def items(self) -> list[_ProviderConfigT]: + return self._current_page.provider_configs if self._current_page else [] class ProviderConfigClient: @@ -178,24 +207,38 @@ class ProviderConfigClient: PROVIDER_CONFIG_URL = 'https://identitytoolkit.googleapis.com/v2' - def __init__(self, http_client, project_id, tenant_id=None, url_override=None): + def __init__( + self, + http_client: _http_client.HttpClient[dict[str, Any]], + project_id: str, + tenant_id: Optional[str] = None, + url_override: Optional[str] = None, + ) -> None: self.http_client = http_client url_prefix = url_override or self.PROVIDER_CONFIG_URL self.base_url = f'{url_prefix}/projects/{project_id}' if tenant_id: self.base_url += f'/tenants/{tenant_id}' - def get_oidc_provider_config(self, provider_id): + def get_oidc_provider_config(self, provider_id: str) -> OIDCProviderConfig: _validate_oidc_provider_id(provider_id) body = self._make_request('get', f'/oauthIdpConfigs/{provider_id}') return OIDCProviderConfig(body) def create_oidc_provider_config( - self, provider_id, client_id, issuer, display_name=None, enabled=None, - client_secret=None, id_token_response_type=None, code_response_type=None): + self, + provider_id: str, + client_id: str, + issuer: str, + display_name: Optional[str] = None, + enabled: Optional[bool] = None, + client_secret: Optional[str] = None, + id_token_response_type: Optional[bool] = None, + code_response_type: Optional[bool] = None, + ) -> OIDCProviderConfig: """Creates a new OIDC provider config from the given parameters.""" _validate_oidc_provider_id(provider_id) - req = { + req: dict[str, Any] = { 'clientId': _validate_non_empty_string(client_id, 'client_id'), 'issuer': _validate_url(issuer, 'issuer'), } @@ -204,7 +247,7 @@ def create_oidc_provider_config( if enabled is not None: req['enabled'] = _auth_utils.validate_boolean(enabled, 'enabled') - response_type = {} + response_type: dict[str, Any] = {} if id_token_response_type is False and code_response_type is False: raise ValueError('At least one response type must be returned.') if id_token_response_type is not None: @@ -223,12 +266,19 @@ def create_oidc_provider_config( return OIDCProviderConfig(body) def update_oidc_provider_config( - self, provider_id, client_id=None, issuer=None, display_name=None, - enabled=None, client_secret=None, id_token_response_type=None, - code_response_type=None): + self, + provider_id: str, + client_id: Optional[str] = None, + issuer: Optional[str] = None, + display_name: Optional[str] = None, + enabled: Optional[bool] = None, + client_secret: Optional[str] = None, + id_token_response_type: Optional[bool] = None, + code_response_type: Optional[bool] = None, + ) -> OIDCProviderConfig: """Updates an existing OIDC provider config with the given parameters.""" _validate_oidc_provider_id(provider_id) - req = {} + req: dict[str, Any] = {} if display_name is not None: if display_name == _user_mgt.DELETE_ATTRIBUTE: req['displayName'] = None @@ -264,28 +314,44 @@ def update_oidc_provider_config( body = self._make_request('patch', url, json=req, params=params) return OIDCProviderConfig(body) - def delete_oidc_provider_config(self, provider_id): + def delete_oidc_provider_config(self, provider_id: str) -> None: _validate_oidc_provider_id(provider_id) self._make_request('delete', f'/oauthIdpConfigs/{provider_id}') - def list_oidc_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + def list_oidc_provider_configs( + self, + page_token: Optional[str] = None, + max_results: int = MAX_LIST_CONFIGS_RESULTS, + ) -> _ListOIDCProviderConfigsPage: return _ListOIDCProviderConfigsPage( self._fetch_oidc_provider_configs, page_token, max_results) - def _fetch_oidc_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + def _fetch_oidc_provider_configs( + self, + page_token: Optional[str] = None, + max_results: int = MAX_LIST_CONFIGS_RESULTS, + ) -> dict[str, Any]: return self._fetch_provider_configs('/oauthIdpConfigs', page_token, max_results) - def get_saml_provider_config(self, provider_id): + def get_saml_provider_config(self, provider_id: str) -> SAMLProviderConfig: _validate_saml_provider_id(provider_id) body = self._make_request('get', f'/inboundSamlConfigs/{provider_id}') return SAMLProviderConfig(body) def create_saml_provider_config( - self, provider_id, idp_entity_id, sso_url, x509_certificates, - rp_entity_id, callback_url, display_name=None, enabled=None): + self, + provider_id: str, + idp_entity_id: str, + sso_url: str, + x509_certificates: list[str], + rp_entity_id: str, + callback_url: str, + display_name: Optional[str] = None, + enabled: Optional[bool] = None, + ) -> SAMLProviderConfig: """Creates a new SAML provider config from the given parameters.""" _validate_saml_provider_id(provider_id) - req = { + req: dict[str, Any] = { 'idpConfig': { 'idpEntityId': _validate_non_empty_string(idp_entity_id, 'idp_entity_id'), 'ssoUrl': _validate_url(sso_url, 'sso_url'), @@ -306,11 +372,19 @@ def create_saml_provider_config( return SAMLProviderConfig(body) def update_saml_provider_config( - self, provider_id, idp_entity_id=None, sso_url=None, x509_certificates=None, - rp_entity_id=None, callback_url=None, display_name=None, enabled=None): + self, + provider_id: str, + idp_entity_id: Optional[str] = None, + sso_url: Optional[str] = None, + x509_certificates: Optional[list[str]]=None, + rp_entity_id: Optional[str] = None, + callback_url: Optional[str] = None, + display_name: Optional[str] = None, + enabled: Optional[bool] = None, + ) -> SAMLProviderConfig: """Updates an existing SAML provider config with the given parameters.""" _validate_saml_provider_id(provider_id) - idp_config = {} + idp_config: dict[str, Any] = {} if idp_entity_id is not None: idp_config['idpEntityId'] = _validate_non_empty_string(idp_entity_id, 'idp_entity_id') if sso_url is not None: @@ -318,13 +392,13 @@ def update_saml_provider_config( if x509_certificates is not None: idp_config['idpCertificates'] = _validate_x509_certificates(x509_certificates) - sp_config = {} + sp_config: dict[str, Any] = {} if rp_entity_id is not None: sp_config['spEntityId'] = _validate_non_empty_string(rp_entity_id, 'rp_entity_id') if callback_url is not None: sp_config['callbackUri'] = _validate_url(callback_url, 'callback_url') - req = {} + req: dict[str, Any] = {} if display_name is not None: if display_name == _user_mgt.DELETE_ATTRIBUTE: req['displayName'] = None @@ -346,18 +420,31 @@ def update_saml_provider_config( body = self._make_request('patch', url, json=req, params=params) return SAMLProviderConfig(body) - def delete_saml_provider_config(self, provider_id): + def delete_saml_provider_config(self, provider_id: str) -> None: _validate_saml_provider_id(provider_id) self._make_request('delete', f'/inboundSamlConfigs/{provider_id}') - def list_saml_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + def list_saml_provider_configs( + self, + page_token: Optional[str] = None, + max_results: int = MAX_LIST_CONFIGS_RESULTS, + ) -> _ListSAMLProviderConfigsPage: return _ListSAMLProviderConfigsPage( self._fetch_saml_provider_configs, page_token, max_results) - def _fetch_saml_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + def _fetch_saml_provider_configs( + self, + page_token: Optional[str] = None, + max_results: int = MAX_LIST_CONFIGS_RESULTS, + ) -> dict[str, Any]: return self._fetch_provider_configs('/inboundSamlConfigs', page_token, max_results) - def _fetch_provider_configs(self, path, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + def _fetch_provider_configs( + self, + path: str, + page_token: Optional[str] = None, + max_results: int = MAX_LIST_CONFIGS_RESULTS, + ) -> dict[str, Any]: """Fetches a page of auth provider configs""" if page_token is not None: if not isinstance(page_token, str) or not page_token: @@ -374,7 +461,7 @@ def _fetch_provider_configs(self, path, page_token=None, max_results=MAX_LIST_CO params += f'&pageToken={page_token}' return self._make_request('get', path, params=params) - def _make_request(self, method, path, **kwargs): + def _make_request(self, method: str, path: str, **kwargs: Any) -> dict[str, Any]: url = f'{self.base_url}{path}' try: return self.http_client.body(method, url, **kwargs) @@ -382,7 +469,7 @@ def _make_request(self, method, path, **kwargs): raise _auth_utils.handle_auth_backend_error(error) -def _validate_oidc_provider_id(provider_id): +def _validate_oidc_provider_id(provider_id: Any) -> str: if not isinstance(provider_id, str): raise ValueError( f'Invalid OIDC provider ID: {provider_id}. Provider ID must be a non-empty string.') @@ -391,7 +478,7 @@ def _validate_oidc_provider_id(provider_id): return provider_id -def _validate_saml_provider_id(provider_id): +def _validate_saml_provider_id(provider_id: Any) -> str: if not isinstance(provider_id, str): raise ValueError( f'Invalid SAML provider ID: {provider_id}. Provider ID must be a non-empty string.') @@ -400,7 +487,7 @@ def _validate_saml_provider_id(provider_id): return provider_id -def _validate_non_empty_string(value, label): +def _validate_non_empty_string(value: Any, label: str) -> str: """Validates that the given value is a non-empty string.""" if not isinstance(value, str): raise ValueError(f'Invalid type for {label}: {value}.') @@ -409,7 +496,7 @@ def _validate_non_empty_string(value, label): return value -def _validate_url(url, label): +def _validate_url(url: Any, label: str) -> str: """Validates that the given value is a well-formed URL string.""" if not isinstance(url, str) or not url: raise ValueError( @@ -423,9 +510,10 @@ def _validate_url(url, label): raise ValueError(f'Malformed {label}: "{url}".') from exception -def _validate_x509_certificates(x509_certificates): +def _validate_x509_certificates(x509_certificates: Any) -> list[dict[str, str]]: if not isinstance(x509_certificates, list) or not x509_certificates: raise ValueError('x509_certificates must be a non-empty list.') + x509_certificates = cast(list[Any], x509_certificates) if not all(isinstance(cert, str) and cert for cert in x509_certificates): raise ValueError('x509_certificates must only contain non-empty strings.') return [{'x509Certificate': cert} for cert in x509_certificates] diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index 60d411822..e702ff8f2 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -17,22 +17,97 @@ import json import os import re +from collections.abc import Callable, Iterator, Sequence +from typing import ( + Any, + Generic, + Literal, + Optional, + Protocol, + Union, + cast, + overload, +) from urllib import parse +import httpx +import requests +from typing_extensions import Self, TypeVar + from firebase_admin import exceptions from firebase_admin import _utils +__all__ = ( + 'EMULATOR_HOST_ENV_VAR', + 'MAX_CLAIMS_PAYLOAD_SIZE', + 'RESERVED_CLAIMS', + 'VALID_EMAIL_ACTION_TYPES', + 'ConfigurationNotFoundError', + 'EmailAlreadyExistsError', + 'EmailNotFoundError', + 'InsufficientPermissionError', + 'InvalidDynamicLinkDomainError', + 'InvalidIdTokenError', + 'PhoneNumberAlreadyExistsError', + 'ResetPasswordExceedLimitError', + 'TenantNotFoundError', + 'TenantIdMismatchError', + 'TooManyAttemptsTryLaterError', + 'UidAlreadyExistsError', + 'UnexpectedResponseError', + 'UserDisabledError', + 'UserNotFoundError', + 'PageIterator', + 'build_update_mask', + 'get_emulator_host', + 'handle_auth_backend_error', + 'is_emulated', + 'validate_action_type', + 'validate_boolean', + 'validate_bytes', + 'validate_custom_claims', + 'validate_display_name', + 'validate_email', + 'validate_int', + 'validate_password', + 'validate_phone', + 'validate_photo_url', + 'validate_provider_id', + 'validate_provider_ids', + 'validate_provider_uid', + 'validate_string', + 'validate_timestamp', + 'validate_uid', +) + +_PageT = TypeVar('_PageT', bound='_Page') +_ErrorT = TypeVar( + '_ErrorT', bound=exceptions.FirebaseError, default=exceptions.FirebaseError +) + +_EmailActionType = Literal[ + 'VERIFY_EMAIL', + 'EMAIL_SIGNIN', + 'PASSWORD_RESET', +] EMULATOR_HOST_ENV_VAR = 'FIREBASE_AUTH_EMULATOR_HOST' MAX_CLAIMS_PAYLOAD_SIZE = 1000 -RESERVED_CLAIMS = set([ +RESERVED_CLAIMS = { 'acr', 'amr', 'at_hash', 'aud', 'auth_time', 'azp', 'cnf', 'c_hash', 'exp', 'iat', 'iss', 'jti', 'nbf', 'nonce', 'sub', 'firebase', -]) -VALID_EMAIL_ACTION_TYPES = set(['VERIFY_EMAIL', 'EMAIL_SIGNIN', 'PASSWORD_RESET']) +} +VALID_EMAIL_ACTION_TYPES = {'VERIFY_EMAIL', 'EMAIL_SIGNIN', 'PASSWORD_RESET'} + + +class _Page(Protocol): + @property + def has_next_page(self) -> bool: ... + + def get_next_page(self) -> Optional[Self]: ... -class PageIterator: +class PageIterator(Generic[_PageT]): """An iterator that allows iterating over a sequence of items, one at a time. This implementation loads a page of items into memory, and iterates on them. When the whole @@ -40,21 +115,21 @@ class PageIterator: of entries in memory. """ - def __init__(self, current_page): + def __init__(self, current_page: _PageT) -> None: if not current_page: raise ValueError('Current page must not be None.') - self._current_page = current_page - self._iter = None + self._current_page: Optional[_PageT] = current_page + self._iter: Optional[Iterator[_PageT]] = None - def __next__(self): + def __next__(self) -> _PageT: if self._iter is None: self._iter = iter(self.items) try: return next(self._iter) except StopIteration: - if self._current_page.has_next_page: + if self._current_page and self._current_page.has_next_page: self._current_page = self._current_page.get_next_page() self._iter = iter(self.items) @@ -62,15 +137,15 @@ def __next__(self): raise - def __iter__(self): + def __iter__(self) -> Iterator[_PageT]: return self @property - def items(self): + def items(self) -> Sequence[Any]: raise NotImplementedError -def get_emulator_host(): +def get_emulator_host() -> str: emulator_host = os.getenv(EMULATOR_HOST_ENV_VAR, '') if emulator_host and '//' in emulator_host: raise ValueError( @@ -79,11 +154,15 @@ def get_emulator_host(): return emulator_host -def is_emulated(): +def is_emulated() -> bool: return get_emulator_host() != '' -def validate_uid(uid, required=False): +@overload +def validate_uid(uid: Optional[Any], required: Literal[True]) -> str: ... +@overload +def validate_uid(uid: Optional[Any], required: bool = False) -> Optional[str]: ... +def validate_uid(uid: Optional[Any], required: bool = False) -> Optional[str]: if uid is None and not required: return None if not isinstance(uid, str) or not uid or len(uid) > 128: @@ -92,7 +171,12 @@ def validate_uid(uid, required=False): 'characters.') return uid -def validate_email(email, required=False): + +@overload +def validate_email(email: Optional[Any], required: Literal[True]) -> str: ... +@overload +def validate_email(email: Optional[Any], required: bool = False) -> Optional[str]: ... +def validate_email(email: Optional[Any], required: bool = False) -> Optional[str]: if email is None and not required: return None if not isinstance(email, str) or not email: @@ -103,7 +187,12 @@ def validate_email(email, required=False): raise ValueError(f'Malformed email address string: "{email}".') return email -def validate_phone(phone, required=False): + +@overload +def validate_phone(phone: Optional[Any], required: Literal[True]) -> str: ... +@overload +def validate_phone(phone: Optional[Any], required: bool = False) -> Optional[str]: ... +def validate_phone(phone: Optional[Any], required: bool = False) -> Optional[str]: """Validates the specified phone number. Phone number vlidation is very lax here. Backend will enforce E.164 spec compliance, and @@ -121,7 +210,14 @@ def validate_phone(phone, required=False): 'compliant identifier.') return phone -def validate_password(password, required=False): + +@overload +def validate_password(password: Optional[Any], required: Literal[True]) -> str: ... +@overload +def validate_password( + password: Optional[Any], required: bool = False +) -> Optional[str]: ... +def validate_password(password: Optional[Any], required: bool = False) -> Optional[str]: if password is None and not required: return None if not isinstance(password, str) or len(password) < 6: @@ -129,14 +225,36 @@ def validate_password(password, required=False): 'Invalid password string. Password must be a string at least 6 characters long.') return password -def validate_bytes(value, label, required=False): + +@overload +def validate_bytes( + value: Optional[Any], label: Any, required: Literal[True] +) -> bytes: ... +@overload +def validate_bytes( + value: Optional[Any], label: Any, required: bool = False +) -> Optional[bytes]: ... +def validate_bytes( + value: Optional[Any], label: Any, required: bool = False +) -> Optional[bytes]: if value is None and not required: return None if not isinstance(value, bytes) or not value: raise ValueError(f'{label} must be a non-empty byte sequence.') return value -def validate_display_name(display_name, required=False): + +@overload +def validate_display_name( + display_name: Optional[Any], required: Literal[True] +) -> str: ... +@overload +def validate_display_name( + display_name: Optional[Any], required: bool = False +) -> Optional[str]: ... +def validate_display_name( + display_name: Optional[Any], required: bool = False +) -> Optional[str]: if display_name is None and not required: return None if not isinstance(display_name, str) or not display_name: @@ -145,7 +263,18 @@ def validate_display_name(display_name, required=False): 'string.') return display_name -def validate_provider_id(provider_id, required=True): + +@overload +def validate_provider_id( + provider_id: Optional[Any], required: Literal[True] +) -> str: ... +@overload +def validate_provider_id( + provider_id: Optional[Any], required: bool = True +) -> Optional[str]: ... +def validate_provider_id( + provider_id: Optional[Any], required: bool = True +) -> Optional[str]: if provider_id is None and not required: return None if not isinstance(provider_id, str) or not provider_id: @@ -153,7 +282,18 @@ def validate_provider_id(provider_id, required=True): f'Invalid provider ID: "{provider_id}". Provider ID must be a non-empty string.') return provider_id -def validate_provider_uid(provider_uid, required=True): + +@overload +def validate_provider_uid( + provider_uid: Optional[Any], required: Literal[True] = True +) -> str: ... +@overload +def validate_provider_uid( + provider_uid: Optional[Any], required: bool = True +) -> Optional[str]: ... +def validate_provider_uid( + provider_uid: Optional[Any], required: bool = True +) -> Optional[str]: if provider_uid is None and not required: return None if not isinstance(provider_uid, str) or not provider_uid: @@ -161,7 +301,16 @@ def validate_provider_uid(provider_uid, required=True): f'Invalid provider UID: "{provider_uid}". Provider UID must be a non-empty string.') return provider_uid -def validate_photo_url(photo_url, required=False): + +@overload +def validate_photo_url(photo_url: Optional[Any], required: Literal[True]) -> str: ... +@overload +def validate_photo_url( + photo_url: Optional[Any], required: bool = False +) -> Optional[str]: ... +def validate_photo_url( + photo_url: Optional[Any], required: bool = False +) -> Optional[str]: """Parses and validates the given URL string.""" if photo_url is None and not required: return None @@ -176,14 +325,31 @@ def validate_photo_url(photo_url, required=False): except Exception as err: raise ValueError(f'Malformed photo URL: "{photo_url}".') from err -def validate_timestamp(timestamp, label, required=False): + +@overload +def validate_timestamp( + timestamp: Optional[Any], + label: Any, + required: Literal[True], +) -> int: ... +@overload +def validate_timestamp( + timestamp: Optional[Any], + label: Any, + required: bool = False, +) -> Optional[int]: ... +def validate_timestamp( + timestamp: Optional[Any], + label: Any, + required: bool = False, +) -> Optional[int]: """Validates the given timestamp value. Timestamps must be positive integers.""" if timestamp is None and not required: return None if isinstance(timestamp, bool): raise ValueError('Boolean value specified as timestamp.') try: - timestamp_int = int(timestamp) + timestamp_int = int(timestamp) # pyright: ignore[reportArgumentType] except TypeError as err: raise ValueError(f'Invalid type for timestamp value: {timestamp}.') from err if timestamp_int != timestamp: @@ -192,7 +358,13 @@ def validate_timestamp(timestamp, label, required=False): raise ValueError(f'{label} timestamp must be a positive interger.') return timestamp_int -def validate_int(value, label, low=None, high=None): + +def validate_int( + value: Any, + label: Any, + low: Optional[int] = None, + high: Optional[int] = None, +) -> int: """Validates that the given value represents an integer. There are several ways to represent an integer in Python (e.g. 2, 2L, 2.0). This method allows @@ -215,19 +387,28 @@ def validate_int(value, label, low=None, high=None): raise ValueError(f'{label} must not be larger than {high}.') return val_int -def validate_string(value, label): + +def validate_string(value: Any, label: Any) -> str: """Validates that the given value is a string.""" if not isinstance(value, str): raise ValueError(f'Invalid type for {label}: {value}.') return value -def validate_boolean(value, label): + +def validate_boolean(value: Any, label: Any) -> bool: """Validates that the given value is a boolean.""" if not isinstance(value, bool): raise ValueError(f'Invalid type for {label}: {value}.') return value -def validate_custom_claims(custom_claims, required=False): + +@overload +def validate_custom_claims(custom_claims: Any, required: Literal[True]) -> str: ... +@overload +def validate_custom_claims( + custom_claims: Any, required: bool = False +) -> Optional[str]: ... +def validate_custom_claims(custom_claims: Any, required: bool = False) -> Optional[str]: """Validates the specified custom claims. Custom claims must be specified as a JSON string. The string must not exceed 1000 @@ -255,14 +436,18 @@ def validate_custom_claims(custom_claims, required=False): f'Claim "{invalid_claims.pop()}" is reserved, and must not be set.') return claims_str -def validate_action_type(action_type): + +def validate_action_type( + action_type: Any, +) -> Literal['VERIFY_EMAIL', 'EMAIL_SIGNIN', 'PASSWORD_RESET']: if action_type not in VALID_EMAIL_ACTION_TYPES: raise ValueError( f'Invalid action type provided action_type: {action_type}. Valid values are ' f'{", ".join(VALID_EMAIL_ACTION_TYPES)}') return action_type -def validate_provider_ids(provider_ids, required=False): + +def validate_provider_ids(provider_ids: Any, required: bool = False) -> list[str]: if not provider_ids: if required: raise ValueError('Invalid provider IDs. Provider ids should be provided') @@ -271,9 +456,10 @@ def validate_provider_ids(provider_ids, required=False): validate_provider_id(provider_id, True) return provider_ids -def build_update_mask(params): + +def build_update_mask(params: dict[str, Any]) -> list[str]: """Creates an update mask list from the given dictionary.""" - mask = [] + mask: list[str] = [] for key, value in params.items(): if isinstance(value, dict): child_mask = build_update_mask(value) @@ -290,8 +476,13 @@ class UidAlreadyExistsError(exceptions.AlreadyExistsError): default_message = 'The user with the provided uid already exists' - def __init__(self, message, cause, http_response): - exceptions.AlreadyExistsError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception], + http_response: Optional[Union[httpx.Response, requests.Response]], + ) -> None: + super().__init__(message, cause, http_response) class EmailAlreadyExistsError(exceptions.AlreadyExistsError): @@ -299,8 +490,13 @@ class EmailAlreadyExistsError(exceptions.AlreadyExistsError): default_message = 'The user with the provided email already exists' - def __init__(self, message, cause, http_response): - exceptions.AlreadyExistsError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception], + http_response: Optional[Union[httpx.Response, requests.Response]], + ) -> None: + super().__init__(message, cause, http_response) class InsufficientPermissionError(exceptions.PermissionDeniedError): @@ -311,8 +507,13 @@ class InsufficientPermissionError(exceptions.PermissionDeniedError): 'https://firebase.google.com/docs/admin/setup for details ' 'on how to initialize the Admin SDK with appropriate permissions') - def __init__(self, message, cause, http_response): - exceptions.PermissionDeniedError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception], + http_response: Optional[Union[httpx.Response, requests.Response]], + ) -> None: + super().__init__(message, cause, http_response) class InvalidDynamicLinkDomainError(exceptions.InvalidArgumentError): @@ -320,8 +521,13 @@ class InvalidDynamicLinkDomainError(exceptions.InvalidArgumentError): default_message = 'Dynamic link domain specified in ActionCodeSettings is not authorized' - def __init__(self, message, cause, http_response): - exceptions.InvalidArgumentError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception], + http_response: Optional[Union[httpx.Response, requests.Response]], + ) -> None: + super().__init__(message, cause, http_response) class InvalidIdTokenError(exceptions.InvalidArgumentError): @@ -329,8 +535,13 @@ class InvalidIdTokenError(exceptions.InvalidArgumentError): default_message = 'The provided ID token is invalid' - def __init__(self, message, cause=None, http_response=None): - exceptions.InvalidArgumentError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message, cause, http_response) class PhoneNumberAlreadyExistsError(exceptions.AlreadyExistsError): @@ -338,15 +549,25 @@ class PhoneNumberAlreadyExistsError(exceptions.AlreadyExistsError): default_message = 'The user with the provided phone number already exists' - def __init__(self, message, cause, http_response): - exceptions.AlreadyExistsError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception], + http_response: Optional[Union[httpx.Response, requests.Response]], + ) -> None: + super().__init__(message, cause, http_response) class UnexpectedResponseError(exceptions.UnknownError): """Backend service responded with an unexpected or malformed response.""" - def __init__(self, message, cause=None, http_response=None): - exceptions.UnknownError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message, cause, http_response) class UserNotFoundError(exceptions.NotFoundError): @@ -354,8 +575,13 @@ class UserNotFoundError(exceptions.NotFoundError): default_message = 'No user record found for the given identifier' - def __init__(self, message, cause=None, http_response=None): - exceptions.NotFoundError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message, cause, http_response) class EmailNotFoundError(exceptions.NotFoundError): @@ -363,8 +589,13 @@ class EmailNotFoundError(exceptions.NotFoundError): default_message = 'No user record found for the given email' - def __init__(self, message, cause=None, http_response=None): - exceptions.NotFoundError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message, cause, http_response) class TenantNotFoundError(exceptions.NotFoundError): @@ -372,15 +603,20 @@ class TenantNotFoundError(exceptions.NotFoundError): default_message = 'No tenant found for the given identifier' - def __init__(self, message, cause=None, http_response=None): - exceptions.NotFoundError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message, cause, http_response) class TenantIdMismatchError(exceptions.InvalidArgumentError): """Missing or invalid tenant ID field in the given JWT.""" - def __init__(self, message): - exceptions.InvalidArgumentError.__init__(self, message) + def __init__(self, message: str) -> None: + super().__init__(message) class ConfigurationNotFoundError(exceptions.NotFoundError): @@ -388,8 +624,13 @@ class ConfigurationNotFoundError(exceptions.NotFoundError): default_message = 'No auth provider found for the given identifier' - def __init__(self, message, cause=None, http_response=None): - exceptions.NotFoundError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message, cause, http_response) class UserDisabledError(exceptions.InvalidArgumentError): @@ -397,22 +638,37 @@ class UserDisabledError(exceptions.InvalidArgumentError): default_message = 'The user record is disabled' - def __init__(self, message, cause=None, http_response=None): - exceptions.InvalidArgumentError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message, cause, http_response) class TooManyAttemptsTryLaterError(exceptions.ResourceExhaustedError): """Rate limited because of too many attempts.""" - def __init__(self, message, cause=None, http_response=None): - exceptions.ResourceExhaustedError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message, cause, http_response) class ResetPasswordExceedLimitError(exceptions.ResourceExhaustedError): """Reset password emails exceeded their limits.""" - def __init__(self, message, cause=None, http_response=None): - exceptions.ResourceExhaustedError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message, cause, http_response) _CODE_TO_EXC_TYPE = { @@ -432,7 +688,7 @@ def __init__(self, message, cause=None, http_response=None): } -def handle_auth_backend_error(error): +def handle_auth_backend_error(error: requests.RequestException) -> exceptions.FirebaseError: """Converts a requests error received from the Firebase Auth service into a FirebaseError.""" if error.response is None: return _utils.handle_requests_error(error) @@ -450,19 +706,26 @@ def handle_auth_backend_error(error): return exc_type(msg, cause=error, http_response=error.response) -def _parse_error_body(response): +def _parse_error_body( + response: requests.Response, +) -> tuple[Optional[str], Optional[str]]: """Parses the given error response to extract Auth error code and message.""" - error_dict = {} + parsed_body = None try: parsed_body = response.json() - if isinstance(parsed_body, dict): - error_dict = parsed_body.get('error', {}) except ValueError: pass + if not isinstance(parsed_body, dict): + return None, None + # Auth error response format: {"error": {"message": "AUTH_ERROR_CODE: Optional text"}} - code = error_dict.get('message') if isinstance(error_dict, dict) else None - custom_message = None + parsed_body = cast(dict[str, Any], parsed_body) + error_dict = parsed_body.get('error', {}) + if not isinstance(error_dict, dict): + return None, None + error_dict = cast(dict[str, str], error_dict) + code, custom_message = error_dict.get('message'), None if code: separator = code.find(':') if separator != -1: @@ -472,8 +735,14 @@ def _parse_error_body(response): return code, custom_message -def _build_error_message(code, exc_type, custom_message): - default_message = exc_type.default_message if ( - exc_type and hasattr(exc_type, 'default_message')) else 'Error while calling Auth service' +def _build_error_message( + code: str, + exc_type: Optional[Callable[ + [str, Optional[Exception], Optional[requests.Response]], + exceptions.FirebaseError + ]], + custom_message: Optional[str], +) -> str: + default_message = getattr(exc_type, 'default_message', 'Error while calling Auth service') ext = f' {custom_message}' if custom_message else '' return f'{default_message} ({code}).{ext}' diff --git a/firebase_admin/_http_client.py b/firebase_admin/_http_client.py index 6d2582291..0ecc69cb5 100644 --- a/firebase_admin/_http_client.py +++ b/firebase_admin/_http_client.py @@ -17,25 +17,48 @@ This module provides utilities for making HTTP calls using the requests library. """ -from __future__ import annotations import logging -from typing import Any, Dict, Generator, Optional, Tuple, Union +from collections.abc import Generator, Iterable +from typing import TYPE_CHECKING, Any, Generic, Optional, Union + import httpx +import google.auth.transport.requests +import google.auth.credentials import requests.adapters -from requests.packages.urllib3.util import retry # pylint: disable=import-error -from google.auth import credentials -from google.auth import transport -from google.auth.transport import requests as google_auth_requests +import requests.structures +import typing_extensions +from firebase_admin import _retry from firebase_admin import _utils -from firebase_admin._retry import HttpxRetry, HttpxRetryTransport + +if TYPE_CHECKING: + from urllib3.util import retry + from _typeshed import SupportsKeysAndGetItem +else: + from requests.packages.urllib3.util import retry # pylint: disable=import-error + +__all__ = ( + 'DEFAULT_HTTPX_RETRY_CONFIG', + 'DEFAULT_RETRY_CONFIG', + 'DEFAULT_TIMEOUT_SECONDS', + 'METRICS_HEADERS', + 'GoogleAuthCredentialFlow', + 'HttpClient', + 'HttpxAsyncClient', + 'JsonHttpClient', +) logger = logging.getLogger(__name__) +_T = typing_extensions.TypeVar('_T', default=Any) + +_ANY_METHOD: dict[str, Any] = {} + if hasattr(retry.Retry.DEFAULT, 'allowed_methods'): - _ANY_METHOD = {'allowed_methods': None} + _ANY_METHOD['allowed_methods'] = None else: - _ANY_METHOD = {'method_whitelist': None} + _ANY_METHOD['method_whitelist'] = None + # Default retry configuration: Retries once on low-level connection and socket read errors. # Retries up to 4 times on HTTP 500 and 503 errors, with exponential backoff. Returns the # last response upon exhausting all retries. @@ -43,17 +66,16 @@ connect=1, read=1, status=4, status_forcelist=[500, 503], raise_on_status=False, backoff_factor=0.5, **_ANY_METHOD) -DEFAULT_HTTPX_RETRY_CONFIG = HttpxRetry( +DEFAULT_HTTPX_RETRY_CONFIG = _retry.HttpxRetry( max_retries=4, status_forcelist=[500, 503], backoff_factor=0.5) - DEFAULT_TIMEOUT_SECONDS = 120 METRICS_HEADERS = { 'x-goog-api-client': _utils.get_metrics_header(), } -class HttpClient: +class HttpClient(Generic[_T]): """Base HTTP client used to make HTTP calls. HttpClient maintains an HTTP session, and handles request authentication and retries if @@ -61,8 +83,17 @@ class HttpClient: """ def __init__( - self, credential=None, session=None, base_url='', headers=None, - retries=DEFAULT_RETRY_CONFIG, timeout=DEFAULT_TIMEOUT_SECONDS): + self, + credential: Optional[google.auth.credentials.Credentials] = None, + session: Optional[requests.Session] = None, + base_url: str = '', + headers: Optional[Union[ + 'SupportsKeysAndGetItem[str, Union[bytes, str]]', + Iterable[tuple[str, Union[bytes, str]]], + ]] = None, + retries: retry.Retry = DEFAULT_RETRY_CONFIG, + timeout: int = DEFAULT_TIMEOUT_SECONDS, + ) -> None: """Creates a new HttpClient instance from the provided arguments. If a credential is provided, initializes a new HTTP session authorized with it. If neither @@ -79,8 +110,9 @@ def __init__( timeout: HTTP timeout in seconds. Defaults to 120 seconds when not specified. Set to None to disable timeouts (optional). """ + self._session: Optional[requests.Session] if credential: - self._session = transport.requests.AuthorizedSession(credential) + self._session = google.auth.transport.requests.AuthorizedSession(credential) elif session: self._session = session else: @@ -95,21 +127,21 @@ def __init__( self._timeout = timeout @property - def session(self): + def session(self) -> Optional[requests.Session]: return self._session @property - def base_url(self): + def base_url(self) -> str: return self._base_url @property - def timeout(self): + def timeout(self) -> int: return self._timeout - def parse_body(self, resp): + def parse_body(self, resp: requests.Response) -> _T: raise NotImplementedError - def request(self, method, url, **kwargs): + def request(self, method: str, url: str, **kwargs: Any) -> requests.Response: """Makes an HTTP call using the Python requests library. This is the sole entry point to the requests library. All other helper methods in this @@ -132,51 +164,58 @@ class call this method to send HTTP requests out. Refer to if 'timeout' not in kwargs: kwargs['timeout'] = self.timeout kwargs.setdefault('headers', {}).update(METRICS_HEADERS) + # possible issue: _session can be None resp = self._session.request(method, self.base_url + url, **kwargs) resp.raise_for_status() return resp - def headers(self, method, url, **kwargs): + def headers(self, method: str, url: str, **kwargs: Any) -> 'requests.structures.CaseInsensitiveDict[str]': resp = self.request(method, url, **kwargs) return resp.headers - def body_and_response(self, method, url, **kwargs): + def body_and_response(self, method: str, url: str, **kwargs: Any) -> tuple[_T, requests.Response]: resp = self.request(method, url, **kwargs) return self.parse_body(resp), resp - def body(self, method, url, **kwargs): + def body(self, method: str, url: str, **kwargs: Any) -> _T: resp = self.request(method, url, **kwargs) return self.parse_body(resp) - def headers_and_body(self, method, url, **kwargs): + def headers_and_body( + self, + method: str, + url: str, + **kwargs: Any, + ) -> tuple[requests.structures.CaseInsensitiveDict[str], _T]: resp = self.request(method, url, **kwargs) return resp.headers, self.parse_body(resp) - def close(self): - self._session.close() - self._session = None + def close(self) -> None: + if self._session is not None: + self._session.close() + self._session = None -class JsonHttpClient(HttpClient): - """An HTTP client that parses response messages as JSON.""" - def __init__(self, **kwargs): - HttpClient.__init__(self, **kwargs) +class JsonHttpClient(HttpClient[dict[str, Any]]): + """An HTTP client that parses response messages as JSON.""" - def parse_body(self, resp): + def parse_body(self, resp: requests.Response) -> dict[str, Any]: return resp.json() + + class GoogleAuthCredentialFlow(httpx.Auth): """Google Auth Credential Auth Flow""" - def __init__(self, credential: credentials.Credentials): + def __init__(self, credential: google.auth.credentials.Credentials) -> None: self._credential = credential self._max_refresh_attempts = 2 self._refresh_status_codes = (401,) def apply_auth_headers( - self, - request: httpx.Request, - auth_request: google_auth_requests.Request - ) -> None: + self, + request: httpx.Request, + auth_request: google.auth.transport.requests.Request, + ) -> None: """A helper function that refreshes credentials if needed and mutates the request headers to contain access token and any other Google Auth headers.""" @@ -194,7 +233,7 @@ def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Re _credential_refresh_attempt = 0 # Create a Google auth request object to be used for refreshing credentials - auth_request = google_auth_requests.Request() + auth_request = google.auth.transport.requests.Request() while True: # Copy original headers for each attempt @@ -237,20 +276,22 @@ def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Re break # The last yielded response is automatically returned by httpx's auth flow. -class HttpxAsyncClient(): + +class HttpxAsyncClient: """Async HTTP client used to make HTTP/2 calls using HTTPX. HttpxAsyncClient maintains an async HTTPX client, handles request authentication, and retries if necessary. """ + def __init__( - self, - credential: Optional[credentials.Credentials] = None, - base_url: str = '', - headers: Optional[Union[httpx.Headers, Dict[str, str]]] = None, - retry_config: HttpxRetry = DEFAULT_HTTPX_RETRY_CONFIG, - timeout: int = DEFAULT_TIMEOUT_SECONDS, - http2: bool = True + self, + credential: Optional[google.auth.credentials.Credentials] = None, + base_url: str = '', + headers: Optional[Union[httpx.Headers, dict[str, str]]] = None, + retry_config: _retry.HttpxRetry = DEFAULT_HTTPX_RETRY_CONFIG, + timeout: int = DEFAULT_TIMEOUT_SECONDS, + http2: bool = True, ) -> None: """Creates a new HttpxAsyncClient instance from the provided arguments. @@ -274,8 +315,8 @@ def __init__( # Only set up retries on urls starting with 'http://' and 'https://' self._mounts = { - 'http://': HttpxRetryTransport(retry=self._retry_config, http2=http2), - 'https://': HttpxRetryTransport(retry=self._retry_config, http2=http2) + 'http://': _retry.HttpxRetryTransport(retry=self._retry_config, http2=http2), + 'https://': _retry.HttpxRetryTransport(retry=self._retry_config, http2=http2) } if credential: @@ -295,15 +336,15 @@ def __init__( ) @property - def base_url(self): + def base_url(self) -> str: return self._base_url @property - def timeout(self): + def timeout(self) -> int: return self._timeout @property - def async_client(self): + def async_client(self) -> httpx.AsyncClient: return self._async_client async def request(self, method: str, url: str, **kwargs: Any) -> httpx.Response: @@ -337,7 +378,11 @@ async def headers(self, method: str, url: str, **kwargs: Any) -> httpx.Headers: return resp.headers async def body_and_response( - self, method: str, url: str, **kwargs: Any) -> Tuple[Any, httpx.Response]: + self, + method: str, + url: str, + **kwargs: Any, + ) -> tuple[Any, httpx.Response]: resp = await self.request(method, url, **kwargs) return self.parse_body(resp), resp @@ -346,7 +391,11 @@ async def body(self, method: str, url: str, **kwargs: Any) -> Any: return self.parse_body(resp) async def headers_and_body( - self, method: str, url: str, **kwargs: Any) -> Tuple[httpx.Headers, Any]: + self, + method: str, + url: str, + **kwargs: Any, + ) -> tuple[httpx.Headers, Any]: resp = await self.request(method, url, **kwargs) return resp.headers, self.parse_body(resp) diff --git a/firebase_admin/_messaging_encoder.py b/firebase_admin/_messaging_encoder.py index 960a6d742..068015b56 100644 --- a/firebase_admin/_messaging_encoder.py +++ b/firebase_admin/_messaging_encoder.py @@ -19,9 +19,19 @@ import math import numbers import re +from typing import Any, Optional, TypeVar, Union, cast from firebase_admin import _messaging_utils +_K = TypeVar('_K') +_V = TypeVar('_V') + +__all__ = ( + 'Message', + 'MessageEncoder', + 'MulticastMessage', +) + class Message: """A message that can be sent via Firebase Cloud Messaging. @@ -35,7 +45,7 @@ class Message: notification: An instance of ``messaging.Notification`` (optional). android: An instance of ``messaging.AndroidConfig`` (optional). webpush: An instance of ``messaging.WebpushConfig`` (optional). - apns: An instance of ``messaging.ApnsConfig`` (optional). + apns: An instance of ``messaging.APNSConfig`` (optional). fcm_options: An instance of ``messaging.FCMOptions`` (optional). token: The registration token of the device to which the message should be sent (optional). topic: Name of the FCM topic to which the message should be sent (optional). Topic name @@ -43,8 +53,18 @@ class Message: condition: The FCM condition to which the message should be sent (optional). """ - def __init__(self, data=None, notification=None, android=None, webpush=None, apns=None, - fcm_options=None, token=None, topic=None, condition=None): + def __init__( + self, + data: Optional[dict[str, str]] = None, + notification: Optional[_messaging_utils.Notification] = None, + android: Optional[_messaging_utils.AndroidConfig] = None, + webpush: Optional[_messaging_utils.WebpushConfig] = None, + apns: Optional[_messaging_utils.APNSConfig] = None, + fcm_options: Optional[_messaging_utils.FCMOptions] = None, + token: Optional[str] = None, + topic: Optional[str] = None, + condition: Optional[str] = None, + ) -> None: self.data = data self.notification = notification self.android = android @@ -55,7 +75,7 @@ def __init__(self, data=None, notification=None, android=None, webpush=None, apn self.topic = topic self.condition = condition - def __str__(self): + def __str__(self) -> str: return json.dumps(self, cls=MessageEncoder, sort_keys=True) @@ -69,11 +89,19 @@ class MulticastMessage: notification: An instance of ``messaging.Notification`` (optional). android: An instance of ``messaging.AndroidConfig`` (optional). webpush: An instance of ``messaging.WebpushConfig`` (optional). - apns: An instance of ``messaging.ApnsConfig`` (optional). + apns: An instance of ``messaging.APNSConfig`` (optional). fcm_options: An instance of ``messaging.FCMOptions`` (optional). """ - def __init__(self, tokens, data=None, notification=None, android=None, webpush=None, apns=None, - fcm_options=None): + def __init__( + self, + tokens: list[str], + data: Optional[dict[str, str]] = None, + notification: Optional[_messaging_utils.Notification] = None, + android: Optional[_messaging_utils.AndroidConfig] = None, + webpush: Optional[_messaging_utils.WebpushConfig] = None, + apns: Optional[_messaging_utils.APNSConfig] = None, + fcm_options: Optional[_messaging_utils.FCMOptions] = None, + ) -> None: _Validators.check_string_list('MulticastMessage.tokens', tokens) if len(tokens) > 500: raise ValueError('MulticastMessage.tokens must not contain more than 500 tokens.') @@ -93,7 +121,7 @@ class _Validators: """ @classmethod - def check_string(cls, label, value, non_empty=False): + def check_string(cls, label: str, value: Any, non_empty: bool = False) -> Optional[str]: """Checks if the given value is a string.""" if value is None: return None @@ -106,7 +134,7 @@ def check_string(cls, label, value, non_empty=False): return value @classmethod - def check_number(cls, label, value): + def check_number(cls, label: str, value: Any) -> Optional[numbers.Number]: if value is None: return None if not isinstance(value, numbers.Number): @@ -114,12 +142,17 @@ def check_number(cls, label, value): return value @classmethod - def check_string_dict(cls, label, value): + def check_string_dict( + cls, + label: str, + value: Optional[Any], + ) -> Optional[dict[str, str]]: """Checks if the given value is a dictionary comprised only of string keys and values.""" if value is None or value == {}: return None if not isinstance(value, dict): raise ValueError(f'{label} must be a dictionary.') + value = cast(dict[Any, Any], value) non_str = [k for k in value if not isinstance(k, str)] if non_str: raise ValueError(f'{label} must not contain non-string keys.') @@ -129,31 +162,41 @@ def check_string_dict(cls, label, value): return value @classmethod - def check_string_list(cls, label, value): + def check_string_list( + cls, + label: str, + value: Optional[Any], + ) -> Optional[list[str]]: """Checks if the given value is a list comprised only of strings.""" if value is None or value == []: return None if not isinstance(value, list): raise ValueError(f'{label} must be a list of strings.') + value = cast(list[Any], value) non_str = [k for k in value if not isinstance(k, str)] if non_str: raise ValueError(f'{label} must not contain non-string values.') return value @classmethod - def check_number_list(cls, label, value): + def check_number_list( + cls, + label: str, + value: Optional[Any], + ) -> Optional[list[numbers.Number]]: """Checks if the given value is a list comprised only of numbers.""" if value is None or value == []: return None if not isinstance(value, list): raise ValueError(f'{label} must be a list of numbers.') + value = cast(list[Any], value) non_number = [k for k in value if not isinstance(k, numbers.Number)] if non_number: raise ValueError(f'{label} must not contain non-number values.') return value @classmethod - def check_analytics_label(cls, label, value): + def check_analytics_label(cls, label: str, value: Optional[Any]) -> Optional[str]: """Checks if the given value is a valid analytics label.""" value = _Validators.check_string(label, value) if value is not None and not re.match(r'^[a-zA-Z0-9-_.~%]{1,50}$', value): @@ -161,7 +204,7 @@ def check_analytics_label(cls, label, value): return value @classmethod - def check_boolean(cls, label, value): + def check_boolean(cls, label: str, value: Optional[Any]) -> Optional[bool]: """Checks if the given value is boolean.""" if value is None: return None @@ -170,7 +213,7 @@ def check_boolean(cls, label, value): return value @classmethod - def check_datetime(cls, label, value): + def check_datetime(cls, label: str, value: Optional[Any]) -> Optional[datetime.datetime]: """Checks if the given value is a datetime.""" if value is None: return None @@ -183,17 +226,20 @@ class MessageEncoder(json.JSONEncoder): """A custom ``JSONEncoder`` implementation for serializing Message instances into JSON.""" @classmethod - def remove_null_values(cls, dict_value): - return {k: v for k, v in dict_value.items() if v not in [None, [], {}]} + def remove_null_values(cls, dict_value: dict[_K, Optional[_V]]) -> dict[_K, _V]: + return {k: cast(_V, v) for k, v in dict_value.items() if v not in [None, [], {}]} @classmethod - def encode_android(cls, android): + def encode_android( + cls, + android: Optional[_messaging_utils.AndroidConfig], + ) -> Optional[dict[str, Any]]: """Encodes an ``AndroidConfig`` instance into JSON.""" if android is None: return None if not isinstance(android, _messaging_utils.AndroidConfig): raise ValueError('Message.android must be an instance of AndroidConfig class.') - result = { + result: dict[str, Any] = { 'collapse_key': _Validators.check_string( 'AndroidConfig.collapse_key', android.collapse_key), 'data': _Validators.check_string_dict( @@ -215,7 +261,10 @@ def encode_android(cls, android): return result @classmethod - def encode_android_fcm_options(cls, fcm_options): + def encode_android_fcm_options( + cls, + fcm_options: Optional[_messaging_utils.AndroidFCMOptions], + ) -> Optional[dict[str, str]]: """Encodes an ``AndroidFCMOptions`` instance into JSON.""" if fcm_options is None: return None @@ -230,12 +279,12 @@ def encode_android_fcm_options(cls, fcm_options): return result @classmethod - def encode_ttl(cls, ttl): + def encode_ttl(cls, ttl: Optional[Union[numbers.Real, datetime.timedelta]]) -> Optional[str]: """Encodes an ``AndroidConfig`` ``TTL`` duration into a string.""" if ttl is None: return None - if isinstance(ttl, numbers.Number): - ttl = datetime.timedelta(seconds=ttl) + if isinstance(ttl, numbers.Real): + ttl = datetime.timedelta(seconds=float(ttl)) if not isinstance(ttl, datetime.timedelta): raise ValueError('AndroidConfig.ttl must be a duration in seconds or an instance of ' 'datetime.timedelta.') @@ -249,12 +298,16 @@ def encode_ttl(cls, ttl): return f'{seconds}s' @classmethod - def encode_milliseconds(cls, label, msec): + def encode_milliseconds( + cls, + label: str, + msec: Optional[Union[numbers.Real, datetime.timedelta]], + ) -> Optional[str]: """Encodes a duration in milliseconds into a string.""" if msec is None: return None - if isinstance(msec, numbers.Number): - msec = datetime.timedelta(milliseconds=msec) + if isinstance(msec, numbers.Real): + msec = datetime.timedelta(milliseconds=float(msec)) if not isinstance(msec, datetime.timedelta): raise ValueError( f'{label} must be a duration in milliseconds or an instance of datetime.timedelta.') @@ -268,14 +321,17 @@ def encode_milliseconds(cls, label, msec): return f'{seconds}s' @classmethod - def encode_android_notification(cls, notification): + def encode_android_notification( + cls, + notification: Optional[_messaging_utils.AndroidNotification], + ) -> Optional[dict[str, Any]]: """Encodes an ``AndroidNotification`` instance into JSON.""" if notification is None: return None if not isinstance(notification, _messaging_utils.AndroidNotification): raise ValueError('AndroidConfig.notification must be an instance of ' 'AndroidNotification class.') - result = { + result: dict[str, Any] = { 'body': _Validators.check_string( 'AndroidNotification.body', notification.body), 'body_loc_args': _Validators.check_string_list( @@ -324,7 +380,7 @@ def encode_android_notification(cls, notification): 'AndroidNotification.proxy', notification.proxy, non_empty=True) } result = cls.remove_null_values(result) - color = result.get('color') + color: Optional[str] = result.get('color') if color and not re.match(r'^#[0-9a-fA-F]{6}$', color): raise ValueError( 'AndroidNotification.color must be in the form #RRGGBB.') @@ -335,7 +391,7 @@ def encode_android_notification(cls, notification): raise ValueError( 'AndroidNotification.title_loc_key is required when specifying title_loc_args.') - event_time = result.get('event_time') + event_time: Optional[datetime.datetime] = result.get('event_time') if event_time: # if the datetime instance is not naive (tzinfo is present), convert to UTC # otherwise (tzinfo is None) assume the datetime instance is already in UTC @@ -357,9 +413,9 @@ def encode_android_notification(cls, notification): 'AndroidNotification.visibility must be "private", "public" or "secret".') result['visibility'] = visibility.upper() - vibrate_timings_millis = result.get('vibrate_timings') + vibrate_timings_millis: Optional[list[Any]] = result.get('vibrate_timings') if vibrate_timings_millis: - vibrate_timing_strings = [] + vibrate_timing_strings: list[Optional[str]] = [] for msec in vibrate_timings_millis: formated_string = cls.encode_milliseconds( 'AndroidNotification.vibrate_timings_millis', msec) @@ -375,14 +431,17 @@ def encode_android_notification(cls, notification): return result @classmethod - def encode_light_settings(cls, light_settings): + def encode_light_settings( + cls, + light_settings: Optional[_messaging_utils.LightSettings], + ) -> Optional[dict[str, Any]]: """Encodes a ``LightSettings`` instance into JSON.""" if light_settings is None: return None if not isinstance(light_settings, _messaging_utils.LightSettings): raise ValueError( 'AndroidNotification.light_settings must be an instance of LightSettings class.') - result = { + result: dict[str, Any] = { 'color': _Validators.check_string( 'LightSettings.color', light_settings.color, non_empty=True), 'light_on_duration': cls.encode_milliseconds( @@ -416,7 +475,10 @@ def encode_light_settings(cls, light_settings): return result @classmethod - def encode_webpush(cls, webpush): + def encode_webpush( + cls, + webpush: Optional[_messaging_utils.WebpushConfig], + ) -> Optional[dict[str, Any]]: """Encodes a ``WebpushConfig`` instance into JSON.""" if webpush is None: return None @@ -433,14 +495,17 @@ def encode_webpush(cls, webpush): return cls.remove_null_values(result) @classmethod - def encode_webpush_notification(cls, notification): + def encode_webpush_notification( + cls, + notification: Optional[_messaging_utils.WebpushNotification], + ) -> Optional[dict[str, Any]]: """Encodes a ``WebpushNotification`` instance into JSON.""" if notification is None: return None if not isinstance(notification, _messaging_utils.WebpushNotification): raise ValueError('WebpushConfig.notification must be an instance of ' 'WebpushNotification class.') - result = { + result: dict[str, Any] = { 'actions': cls.encode_webpush_notification_actions(notification.actions), 'badge': _Validators.check_string( 'WebpushNotification.badge', notification.badge), @@ -480,14 +545,17 @@ def encode_webpush_notification(cls, notification): return cls.remove_null_values(result) @classmethod - def encode_webpush_notification_actions(cls, actions): + def encode_webpush_notification_actions( + cls, + actions: Optional[list[_messaging_utils.WebpushNotificationAction]], + ) -> Optional[list[dict[str, str]]]: """Encodes a list of ``WebpushNotificationActions`` into JSON.""" if actions is None: return None if not isinstance(actions, list): raise ValueError('WebpushConfig.notification.actions must be a list of ' 'WebpushNotificationAction instances.') - results = [] + results: list[dict[str, str]] = [] for action in actions: if not isinstance(action, _messaging_utils.WebpushNotificationAction): raise ValueError('WebpushConfig.notification.actions must be a list of ' @@ -504,7 +572,10 @@ def encode_webpush_notification_actions(cls, actions): return results @classmethod - def encode_webpush_fcm_options(cls, options): + def encode_webpush_fcm_options( + cls, + options: Optional[_messaging_utils.WebpushFCMOptions], + ) -> Optional[dict[str, str]]: """Encodes a ``WebpushFCMOptions`` instance into JSON.""" if options is None: return None @@ -518,13 +589,16 @@ def encode_webpush_fcm_options(cls, options): return result @classmethod - def encode_apns(cls, apns): + def encode_apns( + cls, + apns: Optional[_messaging_utils.APNSConfig], + ) -> Optional[dict[str, Any]]: """Encodes an ``APNSConfig`` instance into JSON.""" if apns is None: return None if not isinstance(apns, _messaging_utils.APNSConfig): raise ValueError('Message.apns must be an instance of APNSConfig class.') - result = { + result: dict[str, Any] = { 'headers': _Validators.check_string_dict( 'APNSConfig.headers', apns.headers), 'payload': cls.encode_apns_payload(apns.payload), @@ -535,13 +609,16 @@ def encode_apns(cls, apns): return cls.remove_null_values(result) @classmethod - def encode_apns_payload(cls, payload): + def encode_apns_payload( + cls, + payload: Optional[_messaging_utils.APNSPayload], + ) -> Optional[dict[str, Any]]: """Encodes an ``APNSPayload`` instance into JSON.""" if payload is None: return None if not isinstance(payload, _messaging_utils.APNSPayload): raise ValueError('APNSConfig.payload must be an instance of APNSPayload class.') - result = { + result: dict[str, Any] = { 'aps': cls.encode_aps(payload.aps) } for key, value in payload.custom_data.items(): @@ -549,7 +626,10 @@ def encode_apns_payload(cls, payload): return cls.remove_null_values(result) @classmethod - def encode_apns_fcm_options(cls, fcm_options): + def encode_apns_fcm_options( + cls, + fcm_options: Optional[_messaging_utils.APNSFCMOptions], + ) -> Optional[dict[str, str]]: """Encodes an ``APNSFCMOptions`` instance into JSON.""" if fcm_options is None: return None @@ -564,11 +644,11 @@ def encode_apns_fcm_options(cls, fcm_options): return result @classmethod - def encode_aps(cls, aps): + def encode_aps(cls, aps: _messaging_utils.Aps) -> dict[str, Any]: """Encodes an ``Aps`` instance into JSON.""" if not isinstance(aps, _messaging_utils.Aps): raise ValueError('APNSPayload.aps must be an instance of Aps class.') - result = { + result: dict[str, Any] = { 'alert': cls.encode_aps_alert(aps.alert), 'badge': _Validators.check_number('Aps.badge', aps.badge), 'sound': cls.encode_aps_sound(aps.sound), @@ -585,12 +665,15 @@ def encode_aps(cls, aps): for key, val in aps.custom_data.items(): _Validators.check_string('Aps.custom_data key', key) if key in result: - raise ValueError(f'Multiple specifications for {key} in Aps.') + raise ValueError('Multiple specifications for {0} in Aps.'.format(key)) result[key] = val return cls.remove_null_values(result) @classmethod - def encode_aps_sound(cls, sound): + def encode_aps_sound( + cls, + sound: Optional[Union[str, _messaging_utils.CriticalSound]], + ) -> Optional[Union[str, dict[str, Any]]]: """Encodes an APNs sound configuration into JSON.""" if sound is None: return None @@ -599,7 +682,7 @@ def encode_aps_sound(cls, sound): if not isinstance(sound, _messaging_utils.CriticalSound): raise ValueError( 'Aps.sound must be a non-empty string or an instance of CriticalSound class.') - result = { + result: dict[str, Any] = { 'name': _Validators.check_string('CriticalSound.name', sound.name, non_empty=True), 'volume': _Validators.check_number('CriticalSound.volume', sound.volume), } @@ -613,7 +696,10 @@ def encode_aps_sound(cls, sound): return cls.remove_null_values(result) @classmethod - def encode_aps_alert(cls, alert): + def encode_aps_alert( + cls, + alert: Optional[Union[_messaging_utils.ApsAlert, str]], + ) -> Optional[Union[str, dict[str, Any]]]: """Encodes an ``ApsAlert`` instance into JSON.""" if alert is None: return None @@ -655,7 +741,10 @@ def encode_aps_alert(cls, alert): return cls.remove_null_values(result) @classmethod - def encode_notification(cls, notification): + def encode_notification( + cls, + notification: Optional[_messaging_utils.Notification], + ) -> Optional[dict[str, str]]: """Encodes a ``Notification`` instance into JSON.""" if notification is None: return None @@ -669,7 +758,7 @@ def encode_notification(cls, notification): return cls.remove_null_values(result) @classmethod - def sanitize_topic_name(cls, topic): + def sanitize_topic_name(cls, topic: Optional[str]) -> Optional[str]: """Removes the /topics/ prefix from the topic name, if present.""" if not topic: return None @@ -681,7 +770,7 @@ def sanitize_topic_name(cls, topic): raise ValueError('Malformed topic name.') return topic - def default(self, o): # pylint: disable=method-hidden + def default(self, o: Any) -> dict[str, Any]: # pylint: disable=method-hidden if not isinstance(o, Message): return json.JSONEncoder.default(self, o) result = { @@ -704,7 +793,10 @@ def default(self, o): # pylint: disable=method-hidden return result @classmethod - def encode_fcm_options(cls, fcm_options): + def encode_fcm_options( + cls, + fcm_options: Optional[_messaging_utils.FCMOptions], + ) -> Optional[dict[str, str]]: """Encodes an ``FCMOptions`` instance into JSON.""" if fcm_options is None: return None diff --git a/firebase_admin/_messaging_utils.py b/firebase_admin/_messaging_utils.py index 8fd720701..3b942a31e 100644 --- a/firebase_admin/_messaging_utils.py +++ b/firebase_admin/_messaging_utils.py @@ -14,8 +14,43 @@ """Types and utilities used by the messaging (FCM) module.""" +import datetime +import numbers +from typing import TYPE_CHECKING, Any, Literal, Optional, Union + +import httpx +import requests + from firebase_admin import exceptions +if TYPE_CHECKING: + from _typeshed import Incomplete +else: + Incomplete = Any + +__all__ = ( + 'APNSConfig', + 'APNSFCMOptions', + 'APNSPayload', + 'AndroidConfig', + 'AndroidFCMOptions', + 'AndroidNotification', + 'Aps', + 'ApsAlert', + 'CriticalSound', + 'FCMOptions', + 'LightSettings', + 'Notification', + 'QuotaExceededError', + 'SenderIdMismatchError', + 'ThirdPartyAuthError', + 'UnregisteredError', + 'WebpushConfig', + 'WebpushFCMOptions', + 'WebpushNotification', + 'WebpushNotificationAction', +) + class Notification: """A notification that can be included in a message. @@ -26,7 +61,12 @@ class Notification: image: Image url of the notification (optional) """ - def __init__(self, title=None, body=None, image=None): + def __init__( + self, + title: Optional[str] = None, + body: Optional[str] = None, + image: Optional[str] = None, + ) -> None: self.title = title self.body = body self.image = image @@ -53,8 +93,17 @@ class AndroidConfig: the app while the device is in direct boot mode (optional). """ - def __init__(self, collapse_key=None, priority=None, ttl=None, restricted_package_name=None, - data=None, notification=None, fcm_options=None, direct_boot_ok=None): + def __init__( + self, + collapse_key: Optional[str] = None, + priority: Optional[Literal['high', 'normal']] = None, + ttl: Optional[Union[numbers.Real, datetime.timedelta]] = None, + restricted_package_name: Optional[str] = None, + data: Optional[dict[str, str]] = None, + notification: Optional['AndroidNotification'] = None, + fcm_options: Optional['AndroidFCMOptions'] = None, + direct_boot_ok: Optional[bool] = None, + ) -> None: self.collapse_key = collapse_key self.priority = priority self.ttl = ttl @@ -153,13 +202,35 @@ class AndroidNotification: """ - def __init__(self, title=None, body=None, icon=None, color=None, sound=None, tag=None, - click_action=None, body_loc_key=None, body_loc_args=None, title_loc_key=None, - title_loc_args=None, channel_id=None, image=None, ticker=None, sticky=None, - event_timestamp=None, local_only=None, priority=None, vibrate_timings_millis=None, - default_vibrate_timings=None, default_sound=None, light_settings=None, - default_light_settings=None, visibility=None, notification_count=None, - proxy=None): + def __init__( + self, + title: Optional[str] = None, + body: Optional[str] = None, + icon: Optional[str] = None, + color: Optional[str] = None, + sound: Optional[str] = None, + tag: Optional[str] = None, + click_action: Optional[Incomplete] = None, + body_loc_key: Optional[str] = None, + body_loc_args: Optional[list[str]] = None, + title_loc_key: Optional[str] = None, + title_loc_args: Optional[list[str]] = None, + channel_id: Optional[Incomplete] = None, + image: Optional[str] = None, + ticker: Optional[Incomplete] = None, + sticky: Optional[bool] = None, + event_timestamp: Optional[datetime.datetime] = None, + local_only: Optional[Incomplete] = None, + priority: Optional[Literal['default', 'min', 'low', 'high', 'max', 'normal']] = None, + vibrate_timings_millis: Optional[float] = None, + default_vibrate_timings: Optional[bool] = None, + default_sound: Optional[bool] = None, + light_settings: Optional['LightSettings'] = None, + default_light_settings: Optional[bool] = None, + visibility: Optional[Literal['private', 'public', 'secret']] = None, + notification_count: Optional[int] = None, + proxy: Optional[Literal['allow', 'deny']] = None, + ) -> None: self.title = title self.body = body self.icon = icon @@ -199,8 +270,12 @@ class LightSettings: light_off_duration_millis: Along with ``light_on_duration``, defines the blink rate of LED flashes. """ - def __init__(self, color, light_on_duration_millis, - light_off_duration_millis): + def __init__( + self, + color: str, + light_on_duration_millis: Union[numbers.Real, datetime.timedelta], + light_off_duration_millis: Union[numbers.Real, datetime.timedelta], + ) -> None: self.color = color self.light_on_duration_millis = light_on_duration_millis self.light_off_duration_millis = light_off_duration_millis @@ -214,7 +289,7 @@ class AndroidFCMOptions: (optional). """ - def __init__(self, analytics_label=None): + def __init__(self, analytics_label: Optional[Incomplete] = None) -> None: self.analytics_label = analytics_label @@ -233,7 +308,13 @@ class WebpushConfig: .. _Webpush Specification: https://tools.ietf.org/html/rfc8030#section-5 """ - def __init__(self, headers=None, data=None, notification=None, fcm_options=None): + def __init__( + self, + headers: Optional[dict[str, str]] = None, + data: Optional[dict[str, str]] = None, + notification: Optional['WebpushNotification'] = None, + fcm_options: Optional['WebpushFCMOptions'] = None, + ) -> None: self.headers = headers self.data = data self.notification = notification @@ -249,7 +330,7 @@ class WebpushNotificationAction: icon: Icon URL for the action (optional). """ - def __init__(self, action, title, icon=None): + def __init__(self, action: str, title: str, icon: Optional[str] = None) -> None: self.action = action self.title = title self.icon = icon @@ -290,10 +371,25 @@ class WebpushNotification: /notification/Notification """ - def __init__(self, title=None, body=None, icon=None, actions=None, badge=None, data=None, - direction=None, image=None, language=None, renotify=None, - require_interaction=None, silent=None, tag=None, timestamp_millis=None, - vibrate=None, custom_data=None): + def __init__( + self, + title: Optional[str] = None, + body: Optional[str] = None, + icon: Optional[str] = None, + actions: Optional[list[WebpushNotificationAction]] = None, + badge: Optional[str] = None, + data: Optional[Any] = None, + direction: Optional[Literal['auto', 'ltr', 'rtl']] = None, + image: Optional[str] = None, + language: Optional[str] = None, + renotify: Optional[bool] = None, + require_interaction: Optional[bool] = None, + silent: Optional[bool] = None, + tag: Optional[str] = None, + timestamp_millis: Optional[int] = None, + vibrate: Optional[list[int]] = None, + custom_data: Optional[dict[str, Any]] = None, + ) -> None: self.title = title self.body = body self.icon = icon @@ -320,7 +416,7 @@ class WebpushFCMOptions: (optional). """ - def __init__(self, link=None): + def __init__(self, link: Optional[str] = None) -> None: self.link = link @@ -340,7 +436,13 @@ class APNSConfig: /NetworkingInternet/Conceptual/RemoteNotificationsPG/CommunicatingwithAPNs.html """ - def __init__(self, headers=None, payload=None, fcm_options=None, live_activity_token=None): + def __init__( + self, + headers: Optional[dict[str, str]] = None, + payload: Optional['APNSPayload'] = None, + fcm_options: Optional['APNSFCMOptions'] = None, + live_activity_token: Optional[str] = None, + ) -> None: self.headers = headers self.payload = payload self.fcm_options = fcm_options @@ -356,7 +458,7 @@ class APNSPayload: (optional). """ - def __init__(self, aps, **kwargs): + def __init__(self, aps: 'Aps', **kwargs: Any) -> None: self.aps = aps self.custom_data = kwargs @@ -379,8 +481,17 @@ class Aps: (optional). """ - def __init__(self, alert=None, badge=None, sound=None, content_available=None, category=None, - thread_id=None, mutable_content=None, custom_data=None): + def __init__( + self, + alert: Optional[Union['ApsAlert', str]] = None, + badge: Optional[float] = None, # should it be int? + sound: Optional[Union[str, 'CriticalSound']] = None, + content_available: Optional[bool] = None, + category: Optional[str] = None, + thread_id: Optional[str] = None, + mutable_content: Optional[bool] = None, + custom_data: Optional[dict[str, Any]] = None, + ) -> None: self.alert = alert self.badge = badge self.sound = sound @@ -404,7 +515,12 @@ class CriticalSound: and 1.0 (full volume) (optional). """ - def __init__(self, name, critical=None, volume=None): + def __init__( + self, + name: str, + critical: Optional[bool] = None, + volume: Optional[float] = None, + ) -> None: self.name = name self.critical = critical self.volume = volume @@ -434,9 +550,19 @@ class ApsAlert: (optional) """ - def __init__(self, title=None, subtitle=None, body=None, loc_key=None, loc_args=None, - title_loc_key=None, title_loc_args=None, action_loc_key=None, launch_image=None, - custom_data=None): + def __init__( + self, + title: Optional[str] = None, + subtitle: Optional[str] = None, + body: Optional[str] = None, + loc_key: Optional[str] = None, + loc_args: Optional[list[str]] = None, + title_loc_key: Optional[str] = None, + title_loc_args: Optional[list[str]] = None, + action_loc_key: Optional[str] = None, + launch_image: Optional[str] = None, + custom_data: Optional[dict[str, Any]] = None, + ) -> None: self.title = title self.subtitle = subtitle self.body = body @@ -459,7 +585,11 @@ class APNSFCMOptions: (optional). """ - def __init__(self, analytics_label=None, image=None): + def __init__( + self, + analytics_label: Optional[Incomplete] = None, + image: Optional[str] = None, + ) -> None: self.analytics_label = analytics_label self.image = image @@ -471,29 +601,44 @@ class FCMOptions: analytics_label: contains additional options to use across all platforms (optional). """ - def __init__(self, analytics_label=None): + def __init__(self, analytics_label: Optional[Incomplete] = None) -> None: self.analytics_label = analytics_label class ThirdPartyAuthError(exceptions.UnauthenticatedError): """APNs certificate or web push auth key was invalid or missing.""" - def __init__(self, message, cause=None, http_response=None): - exceptions.UnauthenticatedError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message, cause, http_response) class QuotaExceededError(exceptions.ResourceExhaustedError): """Sending limit exceeded for the message target.""" - def __init__(self, message, cause=None, http_response=None): - exceptions.ResourceExhaustedError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message, cause, http_response) class SenderIdMismatchError(exceptions.PermissionDeniedError): """The authenticated sender ID is different from the sender ID for the registration token.""" - def __init__(self, message, cause=None, http_response=None): - exceptions.PermissionDeniedError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message, cause, http_response) class UnregisteredError(exceptions.NotFoundError): @@ -501,5 +646,10 @@ class UnregisteredError(exceptions.NotFoundError): This usually means that the token used is no longer valid and a new one must be used.""" - def __init__(self, message, cause=None, http_response=None): - exceptions.NotFoundError.__init__(self, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message, cause, http_response) diff --git a/firebase_admin/_retry.py b/firebase_admin/_retry.py index efd90a743..84c27ccf1 100644 --- a/firebase_admin/_retry.py +++ b/firebase_admin/_retry.py @@ -17,17 +17,22 @@ This module provides utilities for adding retry logic to HTTPX requests """ -from __future__ import annotations import copy import email.utils import random import re import time -from typing import Any, Callable, List, Optional, Tuple, Coroutine import logging +from collections.abc import Callable, Coroutine +from typing import Any, Optional + import asyncio +from typing_extensions import Self + import httpx +__all__ = ('HttpxRetry', 'HttpxRetryTransport') + logger = logging.getLogger(__name__) @@ -40,18 +45,18 @@ class HttpxRetry: DEFAULT_BACKOFF_MAX = 120 def __init__( - self, - max_retries: int = 10, - status_forcelist: Optional[List[int]] = None, - backoff_factor: float = 0, - backoff_max: float = DEFAULT_BACKOFF_MAX, - backoff_jitter: float = 0, - history: Optional[List[Tuple[ - httpx.Request, - Optional[httpx.Response], - Optional[Exception] - ]]] = None, - respect_retry_after_header: bool = False, + self, + max_retries: int = 10, + status_forcelist: Optional[list[int]] = None, + backoff_factor: float = 0, + backoff_max: float = DEFAULT_BACKOFF_MAX, + backoff_jitter: float = 0, + history: Optional[list[tuple[ + httpx.Request, + Optional[httpx.Response], + Optional[Exception] + ]]] = None, + respect_retry_after_header: bool = False, ) -> None: self.retries_left = max_retries self.status_forcelist = status_forcelist @@ -64,7 +69,7 @@ def __init__( self.history = [] self.respect_retry_after_header = respect_retry_after_header - def copy(self) -> HttpxRetry: + def copy(self) -> Self: """Creates a deep copy of this instance.""" return copy.deepcopy(self) @@ -89,7 +94,7 @@ def is_exhausted(self) -> bool: return self.retries_left < 0 # Identical implementation of `urllib3.Retry.parse_retry_after()` - def _parse_retry_after(self, retry_after_header: str) -> float | None: + def _parse_retry_after(self, retry_after_header: str) -> Optional[float]: """Parses Retry-After string into a float with unit seconds.""" seconds: float # Whitespace: https://tools.ietf.org/html/rfc7230#section-3.2.4 @@ -107,7 +112,7 @@ def _parse_retry_after(self, retry_after_header: str) -> float | None: return seconds - def get_retry_after(self, response: httpx.Response) -> float | None: + def get_retry_after(self, response: httpx.Response) -> Optional[float]: """Determine the Retry-After time needed before sending the next request.""" retry_after_header = response.headers.get('Retry-After', None) if retry_after_header: @@ -115,7 +120,7 @@ def get_retry_after(self, response: httpx.Response) -> float | None: return self._parse_retry_after(retry_after_header) return None - def get_backoff_time(self): + def get_backoff_time(self) -> float: """Determine the backoff time needed before sending the next request.""" # attempt_count is the number of previous request attempts attempt_count = len(self.history) @@ -147,10 +152,10 @@ async def sleep(self, response: httpx.Response) -> None: await self.sleep_for_backoff() def increment( - self, - request: httpx.Request, - response: Optional[httpx.Response] = None, - error: Optional[Exception] = None + self, + request: httpx.Request, + response: Optional[httpx.Response] = None, + error: Optional[Exception] = None, ) -> None: """Update the retry state based on request attempt.""" self.retries_left -= 1 @@ -177,9 +182,9 @@ async def handle_async_request(self, request: httpx.Request) -> httpx.Response: request, self._wrapped_transport.handle_async_request) async def _dispatch_with_retry( - self, - request: httpx.Request, - dispatch_method: Callable[[httpx.Request], Coroutine[Any, Any, httpx.Response]] + self, + request: httpx.Request, + dispatch_method: Callable[[httpx.Request], Coroutine[Any, Any, httpx.Response]], ) -> httpx.Response: """Sends a request with retry logic using a provided dispatch method.""" # This request config is used across all requests that use this transport and therefore diff --git a/firebase_admin/_rfc3339.py b/firebase_admin/_rfc3339.py index 8489bdcb9..7132911e4 100644 --- a/firebase_admin/_rfc3339.py +++ b/firebase_admin/_rfc3339.py @@ -14,10 +14,13 @@ """Parse RFC3339 date strings""" -from datetime import datetime, timezone +import datetime import re -def parse_to_epoch(datestr): +__all__ = ('parse_to_epoch',) + + +def parse_to_epoch(datestr: str) -> float: """Parse an RFC3339 date string and return the number of seconds since the epoch (as a float). @@ -37,7 +40,7 @@ def parse_to_epoch(datestr): return _parse_to_datetime(datestr).timestamp() -def _parse_to_datetime(datestr): +def _parse_to_datetime(datestr: str) -> datetime.datetime: """Parse an RFC3339 date string and return a python datetime instance. Args: @@ -55,16 +58,16 @@ def _parse_to_datetime(datestr): # This format is the one we actually expect to occur from our backend. The # others are only present because the spec says we *should* accept them. try: - return datetime.strptime( + return datetime.datetime.strptime( datestr_modified, '%Y-%m-%dT%H:%M:%S.%fZ' - ).replace(tzinfo=timezone.utc) + ).replace(tzinfo=datetime.timezone.utc) except ValueError: pass try: - return datetime.strptime( + return datetime.datetime.strptime( datestr_modified, '%Y-%m-%dT%H:%M:%SZ' - ).replace(tzinfo=timezone.utc) + ).replace(tzinfo=datetime.timezone.utc) except ValueError: pass @@ -75,12 +78,12 @@ def _parse_to_datetime(datestr): datestr_modified = re.sub(r'(\d\d):(\d\d)$', r'\1\2', datestr_modified) try: - return datetime.strptime(datestr_modified, '%Y-%m-%dT%H:%M:%S.%f%z') + return datetime.datetime.strptime(datestr_modified, '%Y-%m-%dT%H:%M:%S.%f%z') except ValueError: pass try: - return datetime.strptime(datestr_modified, '%Y-%m-%dT%H:%M:%S%z') + return datetime.datetime.strptime(datestr_modified, '%Y-%m-%dT%H:%M:%S%z') except ValueError: pass diff --git a/firebase_admin/_sseclient.py b/firebase_admin/_sseclient.py index 3372fe5f2..ea0d5ac23 100644 --- a/firebase_admin/_sseclient.py +++ b/firebase_admin/_sseclient.py @@ -20,58 +20,67 @@ import re import time import warnings +from collections.abc import Iterator +from typing import Any, Optional +from typing_extensions import Self -from google.auth import transport +import google.auth.credentials +import google.auth.transport.requests import requests +__all__ = ( + 'Event', + 'KeepAuthSession', + 'SSEClient', +) # Technically, we should support streams that mix line endings. This regex, # however, assumes that a system will provide consistent line endings. end_of_field = re.compile(r'\r\n\r\n|\r\r|\n\n') -class KeepAuthSession(transport.requests.AuthorizedSession): +class KeepAuthSession(google.auth.transport.requests.AuthorizedSession): """A session that does not drop authentication on redirects between domains.""" - def __init__(self, credential): + def __init__(self, credential: Optional[google.auth.credentials.Credentials]) -> None: super().__init__(credential) - def rebuild_auth(self, prepared_request, response): + def rebuild_auth(self, prepared_request: requests.PreparedRequest, response: requests.Response) -> None: pass class _EventBuffer: """A helper class for buffering and parsing raw SSE data.""" - def __init__(self): - self._buffer = [] + def __init__(self) -> None: + self._buffer: list[str] = [] self._tail = '' - def append(self, char): + def append(self, char: str) -> None: self._buffer.append(char) self._tail += char self._tail = self._tail[-4:] - def truncate(self): + def truncate(self) -> None: head, sep, _ = self.buffer_string.rpartition('\n') rem = head + sep self._buffer = list(rem) self._tail = rem[-4:] @property - def is_end_of_field(self): + def is_end_of_field(self) -> bool: last_two_chars = self._tail[-2:] return last_two_chars == '\n\n' or last_two_chars == '\r\r' or self._tail == '\r\n\r\n' @property - def buffer_string(self): + def buffer_string(self) -> str: return ''.join(self._buffer) class SSEClient: """SSE client implementation.""" - def __init__(self, url, session, retry=3000, **kwargs): + def __init__(self, url: str, session: requests.Session, retry: int = 3000, **kwargs: Any) -> None: """Initializes the SSEClient. Args: @@ -85,7 +94,7 @@ def __init__(self, url, session, retry=3000, **kwargs): self.retry = retry self.requests_kwargs = kwargs self.should_connect = True - self.last_id = None + self.last_id: Optional[str] = None self.buf = '' # Keep data here as it streams in headers = self.requests_kwargs.get('headers', {}) @@ -96,13 +105,13 @@ def __init__(self, url, session, retry=3000, **kwargs): self.requests_kwargs['headers'] = headers self._connect() - def close(self): + def close(self) -> None: """Closes the SSEClient instance.""" self.should_connect = False self.retry = 0 self.resp.close() - def _connect(self): + def _connect(self) -> None: """Connects to the server using requests.""" if self.should_connect: if self.last_id: @@ -113,10 +122,10 @@ def _connect(self): else: raise StopIteration() - def __iter__(self): + def __iter__(self) -> Iterator[Optional['Event']]: return self - def __next__(self): + def __next__(self) -> Optional['Event']: if not re.search(end_of_field, self.buf): temp_buffer = _EventBuffer() while not temp_buffer.is_end_of_field: @@ -153,20 +162,29 @@ def __next__(self): self.last_id = event.event_id return event + def next(self) -> Optional['Event']: + return self.__next__() + class Event: """Event represents the events fired by SSE.""" sse_line_pattern = re.compile('(?P[^:]*):?( ?(?P.*))?') - def __init__(self, data='', event_type='message', event_id=None, retry=None): + def __init__( + self, + data: str = '', + event_type: str = 'message', + event_id: Optional[str] = None, + retry: Optional[int] = None, + ) -> None: self.data = data self.event_type = event_type self.event_id = event_id self.retry = retry @classmethod - def parse(cls, raw): + def parse(cls, raw: str) -> Self: """Given a possibly-multiline string representing an SSE message, parses it and returns an Event object. diff --git a/firebase_admin/_token_gen.py b/firebase_admin/_token_gen.py index 1607ef0ba..22d2aed7e 100644 --- a/firebase_admin/_token_gen.py +++ b/firebase_admin/_token_gen.py @@ -16,21 +16,55 @@ import datetime import time +from collections.abc import Callable, Mapping +from typing import TYPE_CHECKING, Any, Optional, Union, cast import cachecontrol import requests from google.auth import credentials from google.auth import iam from google.auth import jwt -from google.auth import transport +import google.auth.transport.requests +import google.auth.crypt import google.auth.exceptions import google.oauth2.id_token import google.oauth2.service_account +import firebase_admin from firebase_admin import exceptions from firebase_admin import _auth_utils from firebase_admin import _http_client +if TYPE_CHECKING: + from _typeshed import Incomplete +else: + Incomplete = Any + +__all__ = ( + 'ALGORITHM_NONE', + 'ALGORITHM_RS256', + 'AUTH_EMULATOR_EMAIL', + 'COOKIE_CERT_URI', + 'COOKIE_ISSUER_PREFIX', + 'FIREBASE_AUDIENCE', + 'ID_TOKEN_CERT_URI', + 'ID_TOKEN_ISSUER_PREFIX', + 'MAX_SESSION_COOKIE_DURATION_SECONDS', + 'MAX_TOKEN_LIFETIME_SECONDS', + 'METADATA_SERVICE_URL', + 'MIN_SESSION_COOKIE_DURATION_SECONDS', + 'RESERVED_CLAIMS', + 'CertificateFetchRequest', + 'TokenGenerator', + 'TokenSignError', + 'TokenVerifier', + 'CertificateFetchError', + 'ExpiredIdTokenError', + 'ExpiredSessionCookieError', + 'InvalidSessionCookieError', + 'RevokedIdTokenError', + 'RevokedSessionCookieError', +) # ID token constants ID_TOKEN_ISSUER_PREFIX = 'https://securetoken.google.com/' @@ -61,19 +95,26 @@ class _EmulatedSigner(google.auth.crypt.Signer): - key_id = None + @property + def key_id(self) -> Optional[str]: + return None - def __init__(self): + def __init__(self) -> None: pass - def sign(self, message): + def sign(self, message: Union[str, bytes]) -> bytes: return b'' class _SigningProvider: """Stores a reference to a google.auth.crypto.Signer.""" - def __init__(self, signer, signer_email, alg=ALGORITHM_RS256): + def __init__( + self, + signer: google.auth.crypt.Signer, + signer_email: Optional[str], + alg: str = ALGORITHM_RS256, + ) -> None: self._signer = signer self._signer_email = signer_email self._alg = alg @@ -87,20 +128,28 @@ def signer_email(self): return self._signer_email @property - def alg(self): + def alg(self) -> str: return self._alg @classmethod - def from_credential(cls, google_cred): + def from_credential( + cls, + google_cred: Union[google.oauth2.service_account.Credentials, credentials.Signing] + ) -> '_SigningProvider': return _SigningProvider(google_cred.signer, google_cred.signer_email) @classmethod - def from_iam(cls, request, google_cred, service_account): + def from_iam( + cls, + request: google.auth.transport.Request, + google_cred: credentials.Credentials, + service_account: str, + ) -> '_SigningProvider': signer = iam.Signer(request, google_cred, service_account) return _SigningProvider(signer, service_account) @classmethod - def for_emulator(cls): + def for_emulator(cls) -> '_SigningProvider': return _SigningProvider(_EmulatedSigner(), AUTH_EMULATOR_EMAIL, ALGORITHM_NONE) @@ -109,15 +158,20 @@ class TokenGenerator: ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1' - def __init__(self, app, http_client, url_override=None): + def __init__( + self, + app: firebase_admin.App, + http_client: _http_client.HttpClient[dict[str, Any]], + url_override: Optional[str] = None, + ) -> None: self.app = app self.http_client = http_client - self.request = transport.requests.Request() + self.request = google.auth.transport.requests.Request() url_prefix = url_override or self.ID_TOOLKIT_URL self.base_url = f'{url_prefix}/projects/{app.project_id}' - self._signing_provider = None + self._signing_provider: Optional[_SigningProvider] = None - def _init_signing_provider(self): + def _init_signing_provider(self) -> _SigningProvider: """Initializes a signing provider by following the go/firebase-admin-sign protocol.""" if _auth_utils.is_emulated(): return _SigningProvider.for_emulator() @@ -143,11 +197,11 @@ def _init_signing_provider(self): if resp.status != 200: raise ValueError( f'Failed to contact the local metadata service: {resp.data.decode()}.') - service_account = resp.data.decode() + service_account = cast(str, resp.data.decode()) return _SigningProvider.from_iam(self.request, google_cred, service_account) @property - def signing_provider(self): + def signing_provider(self) -> _SigningProvider: """Initializes and returns the SigningProvider instance to be used.""" if not self._signing_provider: try: @@ -161,7 +215,12 @@ def signing_provider(self): 'details on creating custom tokens.') from error return self._signing_provider - def create_custom_token(self, uid, developer_claims=None, tenant_id=None): + def create_custom_token( + self, + uid: str, + developer_claims: Optional[dict[str, Any]] = None, + tenant_id: Optional[str] = None, + ) -> bytes: """Builds and signs a Firebase custom auth token.""" if developer_claims is not None: if not isinstance(developer_claims, dict): @@ -184,7 +243,7 @@ def create_custom_token(self, uid, developer_claims=None, tenant_id=None): signing_provider = self.signing_provider now = int(time.time()) - payload = { + payload: dict[str, Any] = { 'iss': signing_provider.signer_email, 'sub': signing_provider.signer_email, 'aud': FIREBASE_AUDIENCE, @@ -206,7 +265,11 @@ def create_custom_token(self, uid, developer_claims=None, tenant_id=None): raise TokenSignError(msg, error) from error - def create_session_cookie(self, id_token, expires_in): + def create_session_cookie( + self, + id_token: Union[bytes, str], + expires_in: Union[datetime.timedelta, int], + ) -> str: """Creates a session cookie from the provided ID token.""" id_token = id_token.decode('utf-8') if isinstance(id_token, bytes) else id_token if not isinstance(id_token, str) or not id_token: @@ -238,38 +301,47 @@ def create_session_cookie(self, id_token, expires_in): if not body or not body.get('sessionCookie'): raise _auth_utils.UnexpectedResponseError( 'Failed to create session cookie.', http_response=http_resp) - return body.get('sessionCookie') + return cast(str, body['sessionCookie']) -class CertificateFetchRequest(transport.Request): +class CertificateFetchRequest(google.auth.transport.Request): """A google-auth transport that supports HTTP cache-control. Also injects a timeout to each outgoing HTTP request. """ - def __init__(self, timeout_seconds=None): + def __init__(self, timeout_seconds: Optional[float] = None) -> None: self._session = cachecontrol.CacheControl(requests.Session()) - self._delegate = transport.requests.Request(self.session) + self._delegate = google.auth.transport.requests.Request(self.session) self._timeout_seconds = timeout_seconds @property - def session(self): + def session(self) -> requests.Session: return self._session @property - def timeout_seconds(self): + def timeout_seconds(self) -> Optional[float]: return self._timeout_seconds - def __call__(self, url, method='GET', body=None, headers=None, timeout=None, **kwargs): + def __call__( + self, + url: str, + method: str = 'GET', + body: Optional[Incomplete] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + **kwargs: Incomplete, + ) -> google.auth.transport.Response: timeout = timeout or self.timeout_seconds return self._delegate( - url, method=method, body=body, headers=headers, timeout=timeout, **kwargs) + url, method=method, body=body, headers=headers, + timeout=timeout, **kwargs) # pyright: ignore[reportArgumentType] class TokenVerifier: """Verifies ID tokens and session cookies.""" - def __init__(self, app): + def __init__(self, app: firebase_admin.App) -> None: timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) self.request = CertificateFetchRequest(timeout) self.id_token_verifier = _JWTVerifier( @@ -289,31 +361,56 @@ def __init__(self, app): invalid_token_error=InvalidSessionCookieError, expired_token_error=ExpiredSessionCookieError) - def verify_id_token(self, id_token, clock_skew_seconds=0): + def verify_id_token( + self, + id_token: Union[bytes, str], + clock_skew_seconds: int = 0, + ) -> dict[str, Any]: return self.id_token_verifier.verify(id_token, self.request, clock_skew_seconds) - def verify_session_cookie(self, cookie, clock_skew_seconds=0): + def verify_session_cookie( + self, + cookie: Union[bytes, str], + clock_skew_seconds: int = 0, + ) -> dict[str, Any]: return self.cookie_verifier.verify(cookie, self.request, clock_skew_seconds) class _JWTVerifier: """Verifies Firebase JWTs (ID tokens or session cookies).""" - def __init__(self, **kwargs): - self.project_id = kwargs.pop('project_id') - self.short_name = kwargs.pop('short_name') - self.operation = kwargs.pop('operation') - self.url = kwargs.pop('doc_url') - self.cert_url = kwargs.pop('cert_url') - self.issuer = kwargs.pop('issuer') + def __init__( + self, + *, + project_id: Optional[str], + short_name: str, + operation: str, + doc_url: str, + cert_url: str, + issuer: str, + invalid_token_error: Callable[[str, Optional[Exception]], exceptions.FirebaseError], + expired_token_error: Callable[[str, Optional[Exception]], exceptions.FirebaseError], + **kwargs: Any, + ) -> None: + self.project_id = project_id + self.short_name = short_name + self.operation = operation + self.url = doc_url + self.cert_url = cert_url + self.issuer = issuer if self.short_name[0].lower() in 'aeiou': self.articled_short_name = f'an {self.short_name}' else: self.articled_short_name = f'a {self.short_name}' - self._invalid_token_error = kwargs.pop('invalid_token_error') - self._expired_token_error = kwargs.pop('expired_token_error') - - def verify(self, token, request, clock_skew_seconds=0): + self._invalid_token_error = invalid_token_error + self._expired_token_error = expired_token_error + + def verify( + self, + token: Union[bytes, str], + request: google.auth.transport.Request, + clock_skew_seconds: int = 0, + ) -> dict[str, Any]: """Verifies the signature and data for the provided JWT.""" token = token.encode('utf-8') if isinstance(token, str) else token if not isinstance(token, bytes) or not token: @@ -387,7 +484,7 @@ def verify(self, token, request, clock_skew_seconds=0): f'characters. {verify_id_token_msg}') if error_message: - raise self._invalid_token_error(error_message) + raise self._invalid_token_error(error_message, None) try: if emulated: @@ -399,68 +496,72 @@ def verify(self, token, request, clock_skew_seconds=0): audience=self.project_id, certs_url=self.cert_url, clock_skew_in_seconds=clock_skew_seconds) + verified_claims = cast(dict[str, Any], verified_claims) verified_claims['uid'] = verified_claims['sub'] return verified_claims except google.auth.exceptions.TransportError as error: raise CertificateFetchError(str(error), cause=error) from error except ValueError as error: if 'Token expired' in str(error): - raise self._expired_token_error(str(error), cause=error) - raise self._invalid_token_error(str(error), cause=error) + raise self._expired_token_error(str(error), error) + raise self._invalid_token_error(str(error), error) - def _decode_unverified(self, token): + def _decode_unverified( + self, + token: Union[bytes, str], + ) -> tuple[dict[str, str], dict[str, Any]]: try: - header = jwt.decode_header(token) - payload = jwt.decode(token, verify=False) - return header, payload + header = cast(Mapping[str, str], jwt.decode_header(token)) + payload = cast(Mapping[str, Any], jwt.decode(token, verify=False)) + return dict(header), dict(payload) except ValueError as error: - raise self._invalid_token_error(str(error), cause=error) + raise self._invalid_token_error(str(error), error) class TokenSignError(exceptions.UnknownError): """Unexpected error while signing a Firebase custom token.""" - def __init__(self, message, cause): - exceptions.UnknownError.__init__(self, message, cause) + def __init__(self, message: str, cause: Optional[Exception]) -> None: + super().__init__(message, cause) class CertificateFetchError(exceptions.UnknownError): """Failed to fetch some public key certificates required to verify a token.""" - def __init__(self, message, cause): - exceptions.UnknownError.__init__(self, message, cause) + def __init__(self, message: str, cause: Optional[Exception]) -> None: + super().__init__(message, cause) class ExpiredIdTokenError(_auth_utils.InvalidIdTokenError): """The provided ID token is expired.""" - def __init__(self, message, cause): - _auth_utils.InvalidIdTokenError.__init__(self, message, cause) + def __init__(self, message: str, cause: Optional[Exception]) -> None: + super().__init__(message, cause) class RevokedIdTokenError(_auth_utils.InvalidIdTokenError): """The provided ID token has been revoked.""" - def __init__(self, message): - _auth_utils.InvalidIdTokenError.__init__(self, message) + def __init__(self, message: str) -> None: + super().__init__(message) class InvalidSessionCookieError(exceptions.InvalidArgumentError): """The provided string is not a valid Firebase session cookie.""" - def __init__(self, message, cause=None): - exceptions.InvalidArgumentError.__init__(self, message, cause) + def __init__(self, message: str, cause: Optional[Exception] = None) -> None: + super().__init__(message, cause) class ExpiredSessionCookieError(InvalidSessionCookieError): """The provided session cookie is expired.""" - def __init__(self, message, cause): - InvalidSessionCookieError.__init__(self, message, cause) + def __init__(self, message: str, cause: Optional[Exception]) -> None: + super().__init__(message, cause) class RevokedSessionCookieError(InvalidSessionCookieError): """The provided session cookie has been revoked.""" - def __init__(self, message): - InvalidSessionCookieError.__init__(self, message) + def __init__(self, message: str) -> None: + super().__init__(message) diff --git a/firebase_admin/_user_identifier.py b/firebase_admin/_user_identifier.py index 85a224e0b..37ac388b7 100644 --- a/firebase_admin/_user_identifier.py +++ b/firebase_admin/_user_identifier.py @@ -16,6 +16,15 @@ from firebase_admin import _auth_utils +__all__ = ( + 'EmailIdentifier', + 'PhoneIdentifier', + 'ProviderIdentifier', + 'UidIdentifier', + 'UserIdentifier', +) + + class UserIdentifier: """Identifies a user to be looked up.""" @@ -26,7 +35,7 @@ class UidIdentifier(UserIdentifier): See ``auth.get_user()``. """ - def __init__(self, uid): + def __init__(self, uid: str) -> None: """Constructs a new `UidIdentifier` object. Args: @@ -35,7 +44,7 @@ def __init__(self, uid): self._uid = _auth_utils.validate_uid(uid, required=True) @property - def uid(self): + def uid(self) -> str: return self._uid @@ -45,7 +54,7 @@ class EmailIdentifier(UserIdentifier): See ``auth.get_user()``. """ - def __init__(self, email): + def __init__(self, email: str) -> None: """Constructs a new `EmailIdentifier` object. Args: @@ -54,7 +63,7 @@ def __init__(self, email): self._email = _auth_utils.validate_email(email, required=True) @property - def email(self): + def email(self) -> str: return self._email @@ -64,7 +73,7 @@ class PhoneIdentifier(UserIdentifier): See ``auth.get_user()``. """ - def __init__(self, phone_number): + def __init__(self, phone_number: str) -> None: """Constructs a new `PhoneIdentifier` object. Args: @@ -73,7 +82,7 @@ def __init__(self, phone_number): self._phone_number = _auth_utils.validate_phone(phone_number, required=True) @property - def phone_number(self): + def phone_number(self) -> str: return self._phone_number @@ -83,21 +92,21 @@ class ProviderIdentifier(UserIdentifier): See ``auth.get_user()``. """ - def __init__(self, provider_id, provider_uid): + def __init__(self, provider_id: str, provider_uid: str) -> None: """Constructs a new `ProviderIdentifier` object. -   Args: -     provider_id: A provider ID string. -     provider_uid: A provider UID string. + Args: + provider_id: A provider ID string. + provider_uid: A provider UID string. """ self._provider_id = _auth_utils.validate_provider_id(provider_id, required=True) self._provider_uid = _auth_utils.validate_provider_uid( provider_uid, required=True) @property - def provider_id(self): + def provider_id(self) -> str: return self._provider_id @property - def provider_uid(self): + def provider_uid(self) -> str: return self._provider_uid diff --git a/firebase_admin/_user_import.py b/firebase_admin/_user_import.py index 7c7a9e70b..b2ef1421e 100644 --- a/firebase_admin/_user_import.py +++ b/firebase_admin/_user_import.py @@ -16,11 +16,22 @@ import base64 import json +from typing import Any, Optional, cast from firebase_admin import _auth_utils +from firebase_admin import _user_mgt +__all__ = ( + 'ErrorInfo', + 'ImportUserRecord', + 'UserImportHash', + 'UserImportResult', + 'UserProvider', + 'b64_encode', +) -def b64_encode(bytes_value): + +def b64_encode(bytes_value: bytes) -> str: return base64.urlsafe_b64encode(bytes_value).decode() @@ -39,7 +50,14 @@ class UserProvider: photo_url: User's photo URL (optional). """ - def __init__(self, uid, provider_id, email=None, display_name=None, photo_url=None): + def __init__( + self, + uid: str, + provider_id: str, + email: Optional[str] = None, + display_name: Optional[str] = None, + photo_url: Optional[str] = None, + ) -> None: self.uid = uid self.provider_id = provider_id self.email = email @@ -47,46 +65,46 @@ def __init__(self, uid, provider_id, email=None, display_name=None, photo_url=No self.photo_url = photo_url @property - def uid(self): + def uid(self) -> str: return self._uid @uid.setter - def uid(self, uid): + def uid(self, uid: str) -> None: self._uid = _auth_utils.validate_uid(uid, required=True) @property - def provider_id(self): + def provider_id(self) -> str: return self._provider_id @provider_id.setter - def provider_id(self, provider_id): + def provider_id(self, provider_id: str) -> None: self._provider_id = _auth_utils.validate_provider_id(provider_id, required=True) @property - def email(self): + def email(self) -> Optional[str]: return self._email @email.setter - def email(self, email): + def email(self, email: Optional[str]) -> None: self._email = _auth_utils.validate_email(email) @property - def display_name(self): + def display_name(self) -> Optional[str]: return self._display_name @display_name.setter - def display_name(self, display_name): + def display_name(self, display_name: Optional[str]) -> None: self._display_name = _auth_utils.validate_display_name(display_name) @property - def photo_url(self): + def photo_url(self) -> Optional[str]: return self._photo_url @photo_url.setter - def photo_url(self, photo_url): + def photo_url(self, photo_url: Optional[str]): self._photo_url = _auth_utils.validate_photo_url(photo_url) - def to_dict(self): + def to_dict(self) -> dict[str, str]: payload = { 'rawId': self.uid, 'providerId': self.provider_id, @@ -123,9 +141,21 @@ class ImportUserRecord: ValueError: If provided arguments are invalid. """ - def __init__(self, uid, email=None, email_verified=None, display_name=None, phone_number=None, - photo_url=None, disabled=None, user_metadata=None, provider_data=None, - custom_claims=None, password_hash=None, password_salt=None): + def __init__( + self, + uid: str, + email: Optional[str] = None, + email_verified: Optional[bool] = None, + display_name: Optional[str] = None, + phone_number: Optional[str] = None, + photo_url: Optional[str] = None, + disabled: Optional[bool] = None, + user_metadata: Optional['_user_mgt.UserMetadata'] = None, + provider_data: Optional[list[UserProvider]] = None, + custom_claims: Optional[dict[str, Any]] = None, + password_hash: Optional[bytes] = None, + password_salt: Optional[bytes] = None, + ) -> None: self.uid = uid self.email = email self.display_name = display_name @@ -140,67 +170,67 @@ def __init__(self, uid, email=None, email_verified=None, display_name=None, phon self.custom_claims = custom_claims @property - def uid(self): + def uid(self) -> str: return self._uid @uid.setter - def uid(self, uid): + def uid(self, uid: str) -> None: self._uid = _auth_utils.validate_uid(uid, required=True) @property - def email(self): + def email(self) -> Optional[str]: return self._email @email.setter - def email(self, email): + def email(self, email: Optional[str]) -> None: self._email = _auth_utils.validate_email(email) @property - def display_name(self): + def display_name(self) -> Optional[str]: return self._display_name @display_name.setter - def display_name(self, display_name): + def display_name(self, display_name: Optional[str]) -> None: self._display_name = _auth_utils.validate_display_name(display_name) @property - def phone_number(self): + def phone_number(self) -> Optional[str]: return self._phone_number @phone_number.setter - def phone_number(self, phone_number): + def phone_number(self, phone_number: Optional[str]) -> None: self._phone_number = _auth_utils.validate_phone(phone_number) @property - def photo_url(self): + def photo_url(self) -> Optional[str]: return self._photo_url @photo_url.setter - def photo_url(self, photo_url): + def photo_url(self, photo_url: Optional[str]) -> None: self._photo_url = _auth_utils.validate_photo_url(photo_url) @property - def password_hash(self): + def password_hash(self) -> Optional[bytes]: return self._password_hash @password_hash.setter - def password_hash(self, password_hash): + def password_hash(self, password_hash: Optional[bytes]) -> None: self._password_hash = _auth_utils.validate_bytes(password_hash, 'password_hash') @property - def password_salt(self): + def password_salt(self) -> Optional[bytes]: return self._password_salt @password_salt.setter - def password_salt(self, password_salt): + def password_salt(self, password_salt: Optional[bytes]) -> None: self._password_salt = _auth_utils.validate_bytes(password_salt, 'password_salt') @property - def user_metadata(self): + def user_metadata(self) -> Optional['_user_mgt.UserMetadata']: return self._user_metadata @user_metadata.setter - def user_metadata(self, user_metadata): + def user_metadata(self, user_metadata: Optional['_user_mgt.UserMetadata']) -> None: created_at = user_metadata.creation_timestamp if user_metadata is not None else None last_login_at = user_metadata.last_sign_in_timestamp if user_metadata is not None else None self._created_at = _auth_utils.validate_timestamp(created_at, 'creation_timestamp') @@ -209,11 +239,11 @@ def user_metadata(self, user_metadata): self._user_metadata = user_metadata @property - def provider_data(self): + def provider_data(self) -> Optional[list[UserProvider]]: return self._provider_data @provider_data.setter - def provider_data(self, provider_data): + def provider_data(self, provider_data: Optional[list[UserProvider]]) -> None: if provider_data is not None: try: if any(not isinstance(p, UserProvider) for p in provider_data): @@ -223,19 +253,19 @@ def provider_data(self, provider_data): self._provider_data = provider_data @property - def custom_claims(self): + def custom_claims(self) -> Optional[dict[str, Any]]: return self._custom_claims @custom_claims.setter - def custom_claims(self, custom_claims): + def custom_claims(self, custom_claims: Optional[dict[str, Any]]) -> None: json_claims = json.dumps(custom_claims) if isinstance( custom_claims, dict) else custom_claims self._custom_claims_str = _auth_utils.validate_custom_claims(json_claims) self._custom_claims = custom_claims - def to_dict(self): + def to_dict(self) -> dict[str, Any]: """Returns a dict representation of the user. For internal use only.""" - payload = { + payload: dict[str, Any] = { 'localId': self.uid, 'email': self.email, 'displayName': self.display_name, @@ -265,25 +295,25 @@ class UserImportHash: .. _documentation: https://firebase.google.com/docs/auth/admin/import-users """ - def __init__(self, name, data=None): + def __init__(self, name: str, data: Optional[dict[str, Any]] = None) -> None: self._name = name self._data = data - def to_dict(self): - payload = {'hashAlgorithm': self._name} + def to_dict(self) -> dict[str, Any]: + payload: dict[str, Any] = {'hashAlgorithm': self._name} if self._data: payload.update(self._data) return payload @classmethod - def _hmac(cls, name, key): + def _hmac(cls, name: str, key: bytes) -> 'UserImportHash': data = { 'signerKey': b64_encode(_auth_utils.validate_bytes(key, 'key', required=True)) } return UserImportHash(name, data) @classmethod - def hmac_sha512(cls, key): + def hmac_sha512(cls, key: bytes) -> 'UserImportHash': """Creates a new HMAC SHA512 algorithm instance. Args: @@ -295,7 +325,7 @@ def hmac_sha512(cls, key): return cls._hmac('HMAC_SHA512', key) @classmethod - def hmac_sha256(cls, key): + def hmac_sha256(cls, key: bytes) -> 'UserImportHash': """Creates a new HMAC SHA256 algorithm instance. Args: @@ -307,7 +337,7 @@ def hmac_sha256(cls, key): return cls._hmac('HMAC_SHA256', key) @classmethod - def hmac_sha1(cls, key): + def hmac_sha1(cls, key: bytes) -> 'UserImportHash': """Creates a new HMAC SHA1 algorithm instance. Args: @@ -319,7 +349,7 @@ def hmac_sha1(cls, key): return cls._hmac('HMAC_SHA1', key) @classmethod - def hmac_md5(cls, key): + def hmac_md5(cls, key: bytes) -> 'UserImportHash': """Creates a new HMAC MD5 algorithm instance. Args: @@ -331,7 +361,7 @@ def hmac_md5(cls, key): return cls._hmac('HMAC_MD5', key) @classmethod - def md5(cls, rounds): + def md5(cls, rounds: int) -> 'UserImportHash': """Creates a new MD5 algorithm instance. Args: @@ -345,7 +375,7 @@ def md5(cls, rounds): {'rounds': _auth_utils.validate_int(rounds, 'rounds', 0, 8192)}) @classmethod - def sha1(cls, rounds): + def sha1(cls, rounds: int) -> 'UserImportHash': """Creates a new SHA1 algorithm instance. Args: @@ -359,7 +389,7 @@ def sha1(cls, rounds): {'rounds': _auth_utils.validate_int(rounds, 'rounds', 1, 8192)}) @classmethod - def sha256(cls, rounds): + def sha256(cls, rounds: int) -> 'UserImportHash': """Creates a new SHA256 algorithm instance. Args: @@ -373,7 +403,7 @@ def sha256(cls, rounds): {'rounds': _auth_utils.validate_int(rounds, 'rounds', 1, 8192)}) @classmethod - def sha512(cls, rounds): + def sha512(cls, rounds: int) -> 'UserImportHash': """Creates a new SHA512 algorithm instance. Args: @@ -387,7 +417,7 @@ def sha512(cls, rounds): {'rounds': _auth_utils.validate_int(rounds, 'rounds', 1, 8192)}) @classmethod - def pbkdf_sha1(cls, rounds): + def pbkdf_sha1(cls, rounds: int) -> 'UserImportHash': """Creates a new PBKDF SHA1 algorithm instance. Args: @@ -401,7 +431,7 @@ def pbkdf_sha1(cls, rounds): {'rounds': _auth_utils.validate_int(rounds, 'rounds', 0, 120000)}) @classmethod - def pbkdf2_sha256(cls, rounds): + def pbkdf2_sha256(cls, rounds: int) -> 'UserImportHash': """Creates a new PBKDF2 SHA256 algorithm instance. Args: @@ -415,7 +445,13 @@ def pbkdf2_sha256(cls, rounds): {'rounds': _auth_utils.validate_int(rounds, 'rounds', 0, 120000)}) @classmethod - def scrypt(cls, key, rounds, memory_cost, salt_separator=None): + def scrypt( + cls, + key: bytes, + rounds: int, + memory_cost: int, + salt_separator: Optional[bytes] = None, + ) -> 'UserImportHash': """Creates a new Scrypt algorithm instance. This is the modified Scrypt algorithm used by Firebase Auth. See ``standard_scrypt()`` @@ -430,18 +466,18 @@ def scrypt(cls, key, rounds, memory_cost, salt_separator=None): Returns: UserImportHash: A new ``UserImportHash``. """ - data = { + data: dict[str, Any] = { 'signerKey': b64_encode(_auth_utils.validate_bytes(key, 'key', required=True)), 'rounds': _auth_utils.validate_int(rounds, 'rounds', 1, 8), 'memoryCost': _auth_utils.validate_int(memory_cost, 'memory_cost', 1, 14), } if salt_separator: data['saltSeparator'] = b64_encode(_auth_utils.validate_bytes( - salt_separator, 'salt_separator')) + salt_separator, 'salt_separator', True)) return UserImportHash('SCRYPT', data) @classmethod - def bcrypt(cls): + def bcrypt(cls) -> 'UserImportHash': """Creates a new Bcrypt algorithm instance. Returns: @@ -450,7 +486,13 @@ def bcrypt(cls): return UserImportHash('BCRYPT') @classmethod - def standard_scrypt(cls, memory_cost, parallelization, block_size, derived_key_length): + def standard_scrypt( + cls, + memory_cost: int, + parallelization: int, + block_size: int, + derived_key_length: int, + ) -> 'UserImportHash': """Creates a new standard Scrypt algorithm instance. Args: @@ -479,16 +521,16 @@ class ErrorInfo: # it's home in _user_import.py). It's now also used by bulk deletion of # users. Move this to a more common location. - def __init__(self, error): - self._index = error['index'] - self._reason = error['message'] + def __init__(self, error: dict[str, Any]) -> None: + self._index = cast(int, error['index']) + self._reason = cast(str, error['message']) @property - def index(self): + def index(self) -> int: return self._index @property - def reason(self): + def reason(self) -> str: return self._reason @@ -498,23 +540,23 @@ class UserImportResult: See ``auth.import_users()`` API for more details. """ - def __init__(self, result, total): + def __init__(self, result: dict[str, Any], total: int) -> None: errors = result.get('error', []) self._success_count = total - len(errors) self._failure_count = len(errors) self._errors = [ErrorInfo(err) for err in errors] @property - def success_count(self): + def success_count(self) -> int: """Returns the number of users successfully imported.""" return self._success_count @property - def failure_count(self): + def failure_count(self) -> int: """Returns the number of users that failed to be imported.""" return self._failure_count @property - def errors(self): + def errors(self) -> list[ErrorInfo]: """Returns a list of ``auth.ErrorInfo`` instances describing the errors encountered.""" return self._errors diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index 9a75b7a2e..9a26c9773 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -15,18 +15,42 @@ """Firebase user management sub module.""" import base64 -from collections import defaultdict +import collections import json +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast from urllib import parse import requests from firebase_admin import _auth_utils +from firebase_admin import _http_client from firebase_admin import _rfc3339 from firebase_admin import _user_identifier from firebase_admin import _user_import -from firebase_admin._user_import import ErrorInfo +if TYPE_CHECKING: + from _typeshed import ConvertibleToInt + +__all__ = ( + 'B64_REDACTED', + 'DELETE_ATTRIBUTE', + 'MAX_IMPORT_USERS_SIZE', + 'MAX_LIST_USERS_RESULTS', + 'ActionCodeSettings', + 'BatchDeleteAccountsResponse', + 'DeleteUsersResult', + 'ExportedUserRecord', + 'GetUsersResult', + 'ListUsersPage', + 'ProviderUserInfo', + 'Sentinel', + 'UserInfo', + 'UserManager', + 'UserMetadata', + 'UserRecord', + 'encode_action_code_settings', +) MAX_LIST_USERS_RESULTS = 1000 MAX_IMPORT_USERS_SIZE = 1000 @@ -34,19 +58,22 @@ class Sentinel: - - def __init__(self, description): + def __init__(self, description: str) -> None: self.description = description -DELETE_ATTRIBUTE = Sentinel('Value used to delete an attribute from a user profile') +DELETE_ATTRIBUTE: Any = Sentinel('Value used to delete an attribute from a user profile') class UserMetadata: """Contains additional metadata associated with a user account.""" - def __init__(self, creation_timestamp=None, last_sign_in_timestamp=None, - last_refresh_timestamp=None): + def __init__( + self, + creation_timestamp: Optional[Any] = None, + last_sign_in_timestamp: Optional[Any] = None, + last_refresh_timestamp: Optional[Any] = None, + ) -> None: self._creation_timestamp = _auth_utils.validate_timestamp( creation_timestamp, 'creation_timestamp') self._last_sign_in_timestamp = _auth_utils.validate_timestamp( @@ -55,7 +82,7 @@ def __init__(self, creation_timestamp=None, last_sign_in_timestamp=None, last_refresh_timestamp, 'last_refresh_timestamp') @property - def creation_timestamp(self): + def creation_timestamp(self) -> Optional[int]: """ Creation timestamp in milliseconds since the epoch. Returns: @@ -64,7 +91,7 @@ def creation_timestamp(self): return self._creation_timestamp @property - def last_sign_in_timestamp(self): + def last_sign_in_timestamp(self) -> Optional[int]: """ Last sign in timestamp in milliseconds since the epoch. Returns: @@ -73,7 +100,7 @@ def last_sign_in_timestamp(self): return self._last_sign_in_timestamp @property - def last_refresh_timestamp(self): + def last_refresh_timestamp(self) -> Optional[int]: """The time at which the user was last active (ID token refreshed). Returns: @@ -90,32 +117,32 @@ class UserInfo: """ @property - def uid(self): + def uid(self) -> str: """Returns the user ID of this user.""" raise NotImplementedError @property - def display_name(self): + def display_name(self) -> Optional[str]: """Returns the display name of this user.""" raise NotImplementedError @property - def email(self): + def email(self) -> Optional[str]: """Returns the email address associated with this user.""" raise NotImplementedError @property - def phone_number(self): + def phone_number(self) -> Optional[str]: """Returns the phone number associated with this user.""" raise NotImplementedError @property - def photo_url(self): + def photo_url(self) -> Optional[str]: """Returns the photo URL of this user.""" raise NotImplementedError @property - def provider_id(self): + def provider_id(self) -> str: """Returns the ID of the identity provider. This can be a short domain name (e.g. google.com), or the identity of an OpenID @@ -127,8 +154,7 @@ def provider_id(self): class UserRecord(UserInfo): """Contains metadata associated with a Firebase user account.""" - def __init__(self, data): - super().__init__() + def __init__(self, data: dict[str, Any]) -> None: if not isinstance(data, dict): raise ValueError(f'Invalid data argument: {data}. Must be a dictionary.') if not data.get('localId'): @@ -136,16 +162,16 @@ def __init__(self, data): self._data = data @property - def uid(self): + def uid(self) -> str: """Returns the user ID of this user. Returns: string: A user ID string. This value is never None or empty. """ - return self._data.get('localId') + return self._data['localId'] @property - def display_name(self): + def display_name(self) -> Optional[str]: """Returns the display name of this user. Returns: @@ -154,7 +180,7 @@ def display_name(self): return self._data.get('displayName') @property - def email(self): + def email(self) -> Optional[str]: """Returns the email address associated with this user. Returns: @@ -163,7 +189,7 @@ def email(self): return self._data.get('email') @property - def phone_number(self): + def phone_number(self) -> Optional[str]: """Returns the phone number associated with this user. Returns: @@ -172,7 +198,7 @@ def phone_number(self): return self._data.get('phoneNumber') @property - def photo_url(self): + def photo_url(self) -> Optional[str]: """Returns the photo URL of this user. Returns: @@ -181,7 +207,7 @@ def photo_url(self): return self._data.get('photoUrl') @property - def provider_id(self): + def provider_id(self) -> str: """Returns the provider ID of this user. Returns: @@ -190,7 +216,7 @@ def provider_id(self): return 'firebase' @property - def email_verified(self): + def email_verified(self) -> bool: """Returns whether the email address of this user has been verified. Returns: @@ -199,7 +225,7 @@ def email_verified(self): return bool(self._data.get('emailVerified')) @property - def disabled(self): + def disabled(self) -> bool: """Returns whether this user account is disabled. Returns: @@ -208,7 +234,7 @@ def disabled(self): return bool(self._data.get('disabled')) @property - def tokens_valid_after_timestamp(self): + def tokens_valid_after_timestamp(self) -> int: """Returns the time, in milliseconds since the epoch, before which tokens are invalid. Note: this is truncated to 1 second accuracy. @@ -223,16 +249,17 @@ def tokens_valid_after_timestamp(self): return 0 @property - def user_metadata(self): + def user_metadata(self) -> UserMetadata: """Returns additional metadata associated with this user. Returns: UserMetadata: A UserMetadata instance. Does not return None. """ - def _int_or_none(key): + def _int_or_none(key: str) -> Optional[int]: if key in self._data: return int(self._data[key]) return None + last_refresh_at_millis = None last_refresh_at_rfc3339 = self._data.get('lastRefreshAt', None) if last_refresh_at_rfc3339: @@ -241,7 +268,7 @@ def _int_or_none(key): _int_or_none('createdAt'), _int_or_none('lastLoginAt'), last_refresh_at_millis) @property - def provider_data(self): + def provider_data(self) -> list['ProviderUserInfo']: """Returns a list of UserInfo instances. Each object represents an identity from an identity provider that is linked to this user. @@ -253,7 +280,7 @@ def provider_data(self): return [ProviderUserInfo(entry) for entry in providers] @property - def custom_claims(self): + def custom_claims(self) -> Optional[dict[str, Any]]: """Returns any custom claims set on this user account. Returns: @@ -267,7 +294,7 @@ def custom_claims(self): return None @property - def tenant_id(self): + def tenant_id(self) -> Optional[str]: """Returns the tenant ID of this user. Returns: @@ -280,7 +307,7 @@ class ExportedUserRecord(UserRecord): """Contains metadata associated with a user including password hash and salt.""" @property - def password_hash(self): + def password_hash(self) -> Optional[str]: """The user's password hash as a base64-encoded string. If the Firebase Auth hashing algorithm (SCRYPT) was used to create the user account, this @@ -299,7 +326,7 @@ def password_hash(self): return password_hash @property - def password_salt(self): + def password_salt(self) -> Optional[str]: """The user's password salt as a base64-encoded string. If the Firebase Auth hashing algorithm (SCRYPT) was used to create the user account, this @@ -314,7 +341,11 @@ def password_salt(self): class GetUsersResult: """Represents the result of the ``auth.get_users()`` API.""" - def __init__(self, users, not_found): + def __init__( + self, + users: list[UserRecord], + not_found: list[_user_identifier.UserIdentifier], + ) -> None: """Constructs a `GetUsersResult` object. Args: @@ -325,7 +356,7 @@ def __init__(self, users, not_found): self._not_found = not_found @property - def users(self): + def users(self) -> list[UserRecord]: """Set of `UserRecord` instances, corresponding to the set of users that were requested. Only users that were found are listed here. The result set is unordered. @@ -333,7 +364,7 @@ def users(self): return self._users @property - def not_found(self): + def not_found(self) -> list[_user_identifier.UserIdentifier]: """Set of `UserIdentifier` instances that were requested, but not found. """ @@ -348,27 +379,38 @@ class ListUsersPage: through all users in the Firebase project starting from this page. """ - def __init__(self, download, page_token, max_results): + def __init__( + self, + download: Callable[[Optional[str], int], dict[str, Any]], + page_token: Optional[str], + max_results: int, + ) -> None: self._download = download self._max_results = max_results self._current = download(page_token, max_results) @property - def users(self): + def users(self) -> list[ExportedUserRecord]: """A list of ``ExportedUserRecord`` instances available in this page.""" - return [ExportedUserRecord(user) for user in self._current.get('users', [])] + return [ + ExportedUserRecord(user) + for user in cast( + list[dict[str, Any]], + self._current.get('users', []), + ) + ] @property - def next_page_token(self): + def next_page_token(self) -> str: """Page token string for the next page (empty string indicates no more pages).""" - return self._current.get('nextPageToken', '') + return cast(str, self._current.get('nextPageToken', '')) @property - def has_next_page(self): + def has_next_page(self) -> bool: """A boolean indicating whether more pages are available.""" return bool(self.next_page_token) - def get_next_page(self): + def get_next_page(self) -> Optional['ListUsersPage']: """Retrieves the next page of user accounts, if available. Returns: @@ -378,7 +420,7 @@ def get_next_page(self): return ListUsersPage(self._download, self.next_page_token, self._max_results) return None - def iterate_all(self): + def iterate_all(self) -> '_UserIterator': """Retrieves an iterator for user accounts. Returned iterator will iterate through all the user accounts in the Firebase project @@ -394,7 +436,7 @@ def iterate_all(self): class DeleteUsersResult: """Represents the result of the ``auth.delete_users()`` API.""" - def __init__(self, result, total): + def __init__(self, result: 'BatchDeleteAccountsResponse', total: int) -> None: """Constructs a `DeleteUsersResult` object. Args: @@ -408,7 +450,7 @@ def __init__(self, result, total): self._errors = errors @property - def success_count(self): + def success_count(self) -> int: """Returns the number of users that were deleted successfully (possibly zero). @@ -418,14 +460,14 @@ def success_count(self): return self._success_count @property - def failure_count(self): + def failure_count(self) -> int: """Returns the number of users that failed to be deleted (possibly zero). """ return self._failure_count @property - def errors(self): + def errors(self) -> list[_user_import.ErrorInfo]: """A list of `auth.ErrorInfo` instances describing the errors that were encountered during the deletion. Length of this list is equal to `failure_count`. @@ -436,7 +478,7 @@ def errors(self): class BatchDeleteAccountsResponse: """Represents the results of a `delete_users()` call.""" - def __init__(self, errors=None): + def __init__(self, errors: Optional[list[dict[str, Any]]] = None) -> None: """Constructs a `BatchDeleteAccountsResponse` instance, corresponding to the JSON representing the `BatchDeleteAccountsResponse` proto. @@ -445,14 +487,13 @@ def __init__(self, errors=None): `ErrorInfo` instance as returned by the server. `None` implies an empty list. """ - self.errors = [ErrorInfo(err) for err in errors] if errors else [] + self.errors = [_user_import.ErrorInfo(err) for err in errors] if errors else [] class ProviderUserInfo(UserInfo): """Contains metadata regarding how a user is known by a particular identity provider.""" - def __init__(self, data): - super().__init__() + def __init__(self, data: dict[str, Any]) -> None: if not isinstance(data, dict): raise ValueError(f'Invalid data argument: {data}. Must be a dictionary.') if not data.get('rawId'): @@ -460,28 +501,29 @@ def __init__(self, data): self._data = data @property - def uid(self): - return self._data.get('rawId') + def uid(self) -> str: + return self._data['rawId'] @property - def display_name(self): + def display_name(self) -> Optional[str]: return self._data.get('displayName') @property - def email(self): + def email(self) -> Optional[str]: return self._data.get('email') @property - def phone_number(self): + def phone_number(self) -> Optional[str]: return self._data.get('phoneNumber') @property - def photo_url(self): + def photo_url(self) -> Optional[str]: return self._data.get('photoUrl') @property - def provider_id(self): - return self._data.get('providerId') + def provider_id(self) -> str: + # possible issue: can providerId be `None`? + return self._data.get('providerId') # pyright: ignore[reportReturnType] class ActionCodeSettings: @@ -489,8 +531,16 @@ class ActionCodeSettings: Used when invoking the email action link generation APIs. """ - def __init__(self, url, handle_code_in_app=None, dynamic_link_domain=None, ios_bundle_id=None, - android_package_name=None, android_install_app=None, android_minimum_version=None): + def __init__( + self, + url: str, + handle_code_in_app: Optional[bool] = None, + dynamic_link_domain: Optional[str] = None, + ios_bundle_id: Optional[str] = None, + android_package_name: Optional[str] = None, + android_install_app: Optional[bool] = None, + android_minimum_version: Optional[str] = None, + ) -> None: self.url = url self.handle_code_in_app = handle_code_in_app self.dynamic_link_domain = dynamic_link_domain @@ -500,7 +550,7 @@ def __init__(self, url, handle_code_in_app=None, dynamic_link_domain=None, ios_b self.android_minimum_version = android_minimum_version -def encode_action_code_settings(settings): +def encode_action_code_settings(settings: ActionCodeSettings) -> dict[str, Any]: """ Validates the provided action code settings for email link generation and populates the REST api parameters. @@ -508,7 +558,7 @@ def encode_action_code_settings(settings): returns - dict of parameters to be passed for link gereration. """ - parameters = {} + parameters: dict[str, Any] = {} # url if not settings.url: raise ValueError("Dynamic action links url is mandatory") @@ -574,14 +624,20 @@ class UserManager: ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1' - def __init__(self, http_client, project_id, tenant_id=None, url_override=None): + def __init__( + self, + http_client: _http_client.HttpClient[dict[str, Any]], + project_id: str, + tenant_id: Optional[str] = None, + url_override: Optional[str] = None, + ) -> None: self.http_client = http_client url_prefix = url_override or self.ID_TOOLKIT_URL self.base_url = f'{url_prefix}/projects/{project_id}' if tenant_id: self.base_url += f'/tenants/{tenant_id}' - def get_user(self, **kwargs): + def get_user(self, **kwargs: Any) -> dict[str, Any]: """Gets the user data corresponding to the provided key.""" if 'uid' in kwargs: key, key_type = kwargs.pop('uid'), 'user ID' @@ -600,9 +656,12 @@ def get_user(self, **kwargs): raise _auth_utils.UserNotFoundError( f'No user record found for the provided {key_type}: {key}.', http_response=http_resp) - return body['users'][0] + return cast(list[dict[str, Any]], body['users'])[0] - def get_users(self, identifiers): + def get_users( + self, + identifiers: Sequence[_user_identifier.UserIdentifier], + ) -> list[dict[str, Any]]: """Looks up multiple users by their identifiers (uid, email, etc.) Args: @@ -624,7 +683,7 @@ def get_users(self, identifiers): if len(identifiers) > 100: raise ValueError('`identifiers` parameter must have <= 100 entries.') - payload = defaultdict(list) + payload: dict[str, list[Any]] = collections.defaultdict(list) for identifier in identifiers: if isinstance(identifier, _user_identifier.UidIdentifier): payload['localId'].append(identifier.uid) @@ -646,9 +705,13 @@ def get_users(self, identifiers): if not http_resp.ok: raise _auth_utils.UnexpectedResponseError( 'Failed to get users.', http_response=http_resp) - return body.get('users', []) + return cast(list[dict[str, Any]], body.get('users', [])) - def list_users(self, page_token=None, max_results=MAX_LIST_USERS_RESULTS): + def list_users( + self, + page_token: Optional[str] = None, + max_results: int = MAX_LIST_USERS_RESULTS, + ) -> dict[str, Any]: """Retrieves a batch of users.""" if page_token is not None: if not isinstance(page_token, str) or not page_token: @@ -659,14 +722,23 @@ def list_users(self, page_token=None, max_results=MAX_LIST_USERS_RESULTS): raise ValueError( f'Max results must be a positive integer less than {MAX_LIST_USERS_RESULTS}.') - payload = {'maxResults': max_results} + payload: dict[str, Any] = {'maxResults': max_results} if page_token: payload['nextPageToken'] = page_token body, _ = self._make_request('get', '/accounts:batchGet', params=payload) return body - def create_user(self, uid=None, display_name=None, email=None, phone_number=None, - photo_url=None, password=None, disabled=None, email_verified=None): + def create_user( + self, + uid: Optional[str] = None, + display_name: Optional[str] = None, + email: Optional[str] = None, + phone_number: Optional[str] = None, + photo_url: Optional[str] = None, + password: Optional[str] = None, + disabled: Optional[bool] = None, + email_verified: Optional[bool] = None, + ) -> str: """Creates a new user account with the specified properties.""" payload = { 'localId': _auth_utils.validate_uid(uid), @@ -683,13 +755,24 @@ def create_user(self, uid=None, display_name=None, email=None, phone_number=None if not body or not body.get('localId'): raise _auth_utils.UnexpectedResponseError( 'Failed to create new user.', http_response=http_resp) - return body.get('localId') - - def update_user(self, uid, display_name=None, email=None, phone_number=None, - photo_url=None, password=None, disabled=None, email_verified=None, - valid_since=None, custom_claims=None, providers_to_delete=None): + return cast(str, body['localId']) + + def update_user( + self, + uid: str, + display_name: Optional[str] = None, + email: Optional[str] = None, + phone_number: Optional[str] = None, + photo_url: Optional[str] = None, + password: Optional[str] = None, + disabled: Optional[bool] = None, + email_verified: Optional[bool] = None, + valid_since: Optional['ConvertibleToInt'] = None, + custom_claims: Optional[Union[dict[str, Any], str]] = None, + providers_to_delete: Optional[list[str]] = None, + ) -> str: """Updates an existing user account with the specified properties""" - payload = { + payload: dict[str, Any] = { 'localId': _auth_utils.validate_uid(uid, required=True), 'email': _auth_utils.validate_email(email), 'password': _auth_utils.validate_password(password), @@ -698,7 +781,7 @@ def update_user(self, uid, display_name=None, email=None, phone_number=None, 'disableUser': bool(disabled) if disabled is not None else None, } - remove = [] + remove: list[str] = [] remove_provider = _auth_utils.validate_provider_ids(providers_to_delete) if display_name is not None: if display_name is DELETE_ATTRIBUTE: @@ -734,9 +817,9 @@ def update_user(self, uid, display_name=None, email=None, phone_number=None, if not body or not body.get('localId'): raise _auth_utils.UnexpectedResponseError( f'Failed to update user: {uid}.', http_response=http_resp) - return body.get('localId') + return cast(str, body['localId']) - def delete_user(self, uid): + def delete_user(self, uid: str) -> None: """Deletes the user identified by the specified user ID.""" _auth_utils.validate_uid(uid, required=True) body, http_resp = self._make_request('post', '/accounts:delete', json={'localId' : uid}) @@ -744,7 +827,7 @@ def delete_user(self, uid): raise _auth_utils.UnexpectedResponseError( f'Failed to delete user: {uid}.', http_response=http_resp) - def delete_users(self, uids, force_delete=False): + def delete_users(self, uids: Sequence[str], force_delete: bool = False) -> BatchDeleteAccountsResponse: """Deletes the users identified by the specified user ids. Args: @@ -778,9 +861,14 @@ def delete_users(self, uids, force_delete=False): raise _auth_utils.UnexpectedResponseError( 'Unexpected response from server while attempting to delete users.', http_response=http_resp) - return BatchDeleteAccountsResponse(body.get('errors', [])) - - def import_users(self, users, hash_alg=None): + return BatchDeleteAccountsResponse(cast(list[dict[str, Any]], + body.get('errors', []))) + + def import_users( + self, + users: Sequence[_user_import.ImportUserRecord], + hash_alg: Optional[_user_import.UserImportHash] = None, + ) -> dict[str, Any]: """Imports the given list of users to Firebase Auth.""" try: if not users or len(users) > MAX_IMPORT_USERS_SIZE: @@ -803,7 +891,12 @@ def import_users(self, users, hash_alg=None): 'Failed to import users.', http_response=http_resp) return body - def generate_email_action_link(self, action_type, email, action_code_settings=None): + def generate_email_action_link( + self, + action_type: Literal['VERIFY_EMAIL', 'EMAIL_SIGNIN', 'PASSWORD_RESET'], + email: Optional[str], + action_code_settings: Optional[ActionCodeSettings] = None, + ) -> str: """Fetches the email action links for types Args: @@ -833,9 +926,14 @@ def generate_email_action_link(self, action_type, email, action_code_settings=No if not body or not body.get('oobLink'): raise _auth_utils.UnexpectedResponseError( 'Failed to generate email action link.', http_response=http_resp) - return body.get('oobLink') - - def _make_request(self, method, path, **kwargs): + return cast(str, body['oobLink']) + + def _make_request( + self, + method: str, + path: str, + **kwargs: Any, + ) -> tuple[dict[str, Any], requests.Response]: url = f'{self.base_url}{path}' try: return self.http_client.body_and_response(method, url, **kwargs) @@ -843,8 +941,7 @@ def _make_request(self, method, path, **kwargs): raise _auth_utils.handle_auth_backend_error(error) -class _UserIterator(_auth_utils.PageIterator): - +class _UserIterator(_auth_utils.PageIterator[ListUsersPage]): @property - def items(self): - return self._current_page.users + def items(self) -> list[ExportedUserRecord]: + return self._current_page.users if self._current_page else [] diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index d0aca884b..2d9e82aa5 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -15,16 +15,30 @@ """Internal utilities common to all modules.""" import json +from collections.abc import Callable +from typing import Any, Optional, TypeVar, Union, cast from platform import python_version -from typing import Callable, Optional -import google.auth -import requests +import google.auth.credentials +import google.auth.transport import httpx +import requests import firebase_admin from firebase_admin import exceptions +__all__ = ( + 'EmulatorAdminCredentials', + 'get_app_service', + 'get_metrics_header', + 'handle_httpx_error', + 'handle_operation_error', + 'handle_platform_error_from_httpx', + 'handle_platform_error_from_requests', + 'handle_requests_error', +) + +_T = TypeVar('_T') _ERROR_CODE_TO_EXCEPTION_TYPE = { exceptions.INVALID_ARGUMENT: exceptions.InvalidArgumentError, @@ -46,7 +60,7 @@ } -_HTTP_STATUS_TO_ERROR_CODE = { +_HTTP_STATUS_TO_ERROR_CODE: dict[int, str] = { 400: exceptions.INVALID_ARGUMENT, 401: exceptions.UNAUTHENTICATED, 403: exceptions.PERMISSION_DENIED, @@ -60,7 +74,7 @@ # See https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto -_RPC_CODE_TO_ERROR_CODE = { +_RPC_CODE_TO_ERROR_CODE: dict[int, str] = { 1: exceptions.CANCELLED, 2: exceptions.UNKNOWN, 3: exceptions.INVALID_ARGUMENT, @@ -78,10 +92,12 @@ 16: exceptions.UNAUTHENTICATED, } -def get_metrics_header(): + +def get_metrics_header() -> str: return f'gl-python/{python_version()} fire-admin/{firebase_admin.__version__}' -def _get_initialized_app(app): + +def _get_initialized_app(app: Optional[firebase_admin.App]) -> firebase_admin.App: """Returns a reference to an initialized App instance.""" if app is None: return firebase_admin.get_app() @@ -98,13 +114,22 @@ def _get_initialized_app(app): f'"{type(app)}".') - -def get_app_service(app, name, initializer): +def get_app_service( + app: Optional[firebase_admin.App], + name: str, + initializer: Callable[[firebase_admin.App], _T], +) -> _T: app = _get_initialized_app(app) return app._get_service(name, initializer) # pylint: disable=protected-access -def handle_platform_error_from_requests(error, handle_func=None): +def handle_platform_error_from_requests( + error: requests.RequestException, + handle_func: Optional[Callable[ + [requests.RequestException, str, dict[str, Any]], + Optional[exceptions.FirebaseError], + ]] = None, +) -> exceptions.FirebaseError: """Constructs a ``FirebaseError`` from the given requests error. This can be used to handle errors returned by Google Cloud Platform (GCP) APIs. @@ -131,9 +156,10 @@ def handle_platform_error_from_requests(error, handle_func=None): return exc if exc else _handle_func_requests(error, message, error_dict) + def handle_platform_error_from_httpx( - error: httpx.HTTPError, - handle_func: Optional[Callable[..., Optional[exceptions.FirebaseError]]] = None + error: httpx.HTTPError, + handle_func: Optional[Callable[..., Optional[exceptions.FirebaseError]]] = None, ) -> exceptions.FirebaseError: """Constructs a ``FirebaseError`` from the given httpx error. @@ -162,7 +188,7 @@ def handle_platform_error_from_httpx( return handle_httpx_error(error) -def handle_operation_error(error): +def handle_operation_error(error: Union[dict[str, Any], Exception]) -> exceptions.FirebaseError: """Constructs a ``FirebaseError`` from the given operation error. Args: @@ -173,17 +199,22 @@ def handle_operation_error(error): """ if not isinstance(error, dict): return exceptions.UnknownError( - message=f'Unknown error while making a remote service call: {error}', + message='Unknown error while making a remote service call: {0}'.format(error), cause=error) - rpc_code = error.get('code') - message = error.get('message') + rpc_code = error.get('code', 0) + # possible issue: needs be str | None ? + message = cast(str, error.get('message')) error_code = _rpc_code_to_error_code(rpc_code) err_type = _error_code_to_exception_type(error_code) - return err_type(message=message) + return err_type(message, None, None) -def _handle_func_requests(error, message, error_dict): +def _handle_func_requests( + error: requests.RequestException, + message: str, + error_dict: dict[str, Any], +) -> exceptions.FirebaseError: """Constructs a ``FirebaseError`` from the given GCP error. Args: @@ -198,7 +229,11 @@ def _handle_func_requests(error, message, error_dict): return handle_requests_error(error, message, code) -def handle_requests_error(error, message=None, code=None): +def handle_requests_error( + error: requests.RequestException, + message: Optional[str] = None, + code: Optional[str] = None, +) -> exceptions.FirebaseError: """Constructs a ``FirebaseError`` from the given requests error. This method is agnostic of the remote service that produced the error, whether it is a GCP @@ -235,9 +270,14 @@ def handle_requests_error(error, message=None, code=None): message = str(error) err_type = _error_code_to_exception_type(code) - return err_type(message=message, cause=error, http_response=error.response) + return err_type(message, error, error.response) -def _handle_func_httpx(error: httpx.HTTPError, message, error_dict) -> exceptions.FirebaseError: + +def _handle_func_httpx( + error: httpx.HTTPError, + message: str, + error_dict: dict[str, Any], +) -> exceptions.FirebaseError: """Constructs a ``FirebaseError`` from the given GCP error. Args: @@ -252,7 +292,11 @@ def _handle_func_httpx(error: httpx.HTTPError, message, error_dict) -> exception return handle_httpx_error(error, message, code) -def handle_httpx_error(error: httpx.HTTPError, message=None, code=None) -> exceptions.FirebaseError: +def handle_httpx_error( + error: Exception, + message: Optional[str] = None, + code: Optional[str] = None, +) -> exceptions.FirebaseError: """Constructs a ``FirebaseError`` from the given httpx error. This method is agnostic of the remote service that produced the error, whether it is a GCP @@ -286,26 +330,34 @@ def handle_httpx_error(error: httpx.HTTPError, message=None, code=None) -> excep message = str(error) err_type = _error_code_to_exception_type(code) - return err_type(message=message, cause=error, http_response=error.response) + return err_type(message, error, error.response) return exceptions.UnknownError( message=f'Unknown error while making a remote service call: {error}', cause=error) -def _http_status_to_error_code(status): + +def _http_status_to_error_code(status: int) -> str: """Maps an HTTP status to a platform error code.""" return _HTTP_STATUS_TO_ERROR_CODE.get(status, exceptions.UNKNOWN) -def _rpc_code_to_error_code(rpc_code): + +def _rpc_code_to_error_code(rpc_code: int) -> str: """Maps an RPC code to a platform error code.""" return _RPC_CODE_TO_ERROR_CODE.get(rpc_code, exceptions.UNKNOWN) -def _error_code_to_exception_type(code): + +def _error_code_to_exception_type( + code: str, +) -> Callable[ + [str, Optional[Exception], Optional[Union[httpx.Response, requests.Response]]], + exceptions.FirebaseError +]: """Maps a platform error code to an exception type.""" return _ERROR_CODE_TO_EXCEPTION_TYPE.get(code, exceptions.UnknownError) -def _parse_platform_error(content, status_code): +def _parse_platform_error(content: str, status_code: int) -> tuple[dict[str, Any], str]: """Parses an HTTP error response from a Google Cloud Platform API and extracts the error code and message fields. @@ -316,15 +368,15 @@ def _parse_platform_error(content, status_code): Returns: tuple: A tuple containing error code and message. """ - data = {} + data: dict[str, Any] = {} try: parsed_body = json.loads(content) if isinstance(parsed_body, dict): - data = parsed_body + data = cast(dict[str, Any], parsed_body) except ValueError: pass - error_dict = data.get('error', {}) + error_dict: dict[str, Any] = data.get('error', {}) msg = error_dict.get('message') if not msg: msg = f'Unexpected HTTP response with status: {status_code}; body: {content}' @@ -340,9 +392,9 @@ class EmulatorAdminCredentials(google.auth.credentials.Credentials): This is used instead of user-supplied credentials or ADC. It will silently do nothing when asked to refresh credentials. """ - def __init__(self): + def __init__(self) -> None: google.auth.credentials.Credentials.__init__(self) self.token = 'owner' - def refresh(self, request): + def refresh(self, request: google.auth.transport.Request) -> None: pass diff --git a/firebase_admin/app_check.py b/firebase_admin/app_check.py index 40d857f4e..dfc011bcf 100644 --- a/firebase_admin/app_check.py +++ b/firebase_admin/app_check.py @@ -14,18 +14,27 @@ """Firebase App Check module.""" -from typing import Any, Dict +from typing import Any, Optional, cast + import jwt -from jwt import PyJWKClient, ExpiredSignatureError, InvalidTokenError, DecodeError -from jwt import InvalidAudienceError, InvalidIssuerError, InvalidSignatureError + +import firebase_admin from firebase_admin import _utils +__all__ = ('verify_token',) + + _APP_CHECK_ATTRIBUTE = '_app_check' -def _get_app_check_service(app) -> Any: + +def _get_app_check_service(app: Optional[firebase_admin.App]) -> '_AppCheckService': return _utils.get_app_service(app, _APP_CHECK_ATTRIBUTE, _AppCheckService) -def verify_token(token: str, app=None) -> Dict[str, Any]: + +def verify_token( + token: str, + app: Optional[firebase_admin.App] = None, +) -> dict[str, Any]: """Verifies a Firebase App Check token. Args: @@ -42,20 +51,18 @@ def verify_token(token: str, app=None) -> Dict[str, Any]: """ return _get_app_check_service(app).verify_token(token) + class _AppCheckService: """Service class that implements Firebase App Check functionality.""" _APP_CHECK_ISSUER = 'https://firebaseappcheck.googleapis.com/' _JWKS_URL = 'https://firebaseappcheck.googleapis.com/v1/jwks' - _project_id = None - _scoped_project_id = None - _jwks_client = None _APP_CHECK_HEADERS = { 'x-goog-api-client': _utils.get_metrics_header(), } - def __init__(self, app): + def __init__(self, app: firebase_admin.App) -> None: # Validate and store the project_id to validate the JWT claims self._project_id = app.project_id if not self._project_id: @@ -64,13 +71,12 @@ def __init__(self, app): 'service. Either set the projectId option, use service ' 'account credentials, or set the ' 'GOOGLE_CLOUD_PROJECT environment variable.') - self._scoped_project_id = 'projects/' + app.project_id + self._scoped_project_id = 'projects/' + self._project_id # Default lifespan is 300 seconds (5 minutes) so we change it to 21600 seconds (6 hours). - self._jwks_client = PyJWKClient( + self._jwks_client = jwt.PyJWKClient( self._JWKS_URL, lifespan=21600, headers=self._APP_CHECK_HEADERS) - - def verify_token(self, token: str) -> Dict[str, Any]: + def verify_token(self, token: str) -> dict[str, Any]: """Verifies a Firebase App Check token.""" _Validators.check_string("app check token", token) @@ -81,7 +87,7 @@ def verify_token(self, token: str) -> Dict[str, Any]: signing_key = self._jwks_client.get_signing_key_from_jwt(token) self._has_valid_token_headers(jwt.get_unverified_header(token)) verified_claims = self._decode_and_verify(token, signing_key.key) - except (InvalidTokenError, DecodeError) as exception: + except (jwt.InvalidTokenError, jwt.DecodeError) as exception: raise ValueError( f'Verifying App Check token failed. Error: {exception}' ) from exception @@ -89,7 +95,7 @@ def verify_token(self, token: str) -> Dict[str, Any]: verified_claims['app_id'] = verified_claims.get('sub') return verified_claims - def _has_valid_token_headers(self, headers: Any) -> None: + def _has_valid_token_headers(self, headers: dict[str, Any]) -> None: """Checks whether the token has valid headers for App Check.""" # Ensure the token's header has type JWT if headers.get('typ') != 'JWT': @@ -102,9 +108,9 @@ def _has_valid_token_headers(self, headers: Any) -> None: f'Expected RS256 but got {algorithm}.' ) - def _decode_and_verify(self, token: str, signing_key: str): + def _decode_and_verify(self, token: str, signing_key: str) -> dict[str, Any]: """Decodes and verifies the token from App Check.""" - payload = {} + payload: dict[str, Any] = {} try: payload = jwt.decode( token, @@ -112,25 +118,25 @@ def _decode_and_verify(self, token: str, signing_key: str): algorithms=["RS256"], audience=self._scoped_project_id ) - except InvalidSignatureError as exception: + except jwt.InvalidSignatureError as exception: raise ValueError( 'The provided App Check token has an invalid signature.' ) from exception - except InvalidAudienceError as exception: + except jwt.InvalidAudienceError as exception: raise ValueError( 'The provided App Check token has an incorrect "aud" (audience) claim. ' f'Expected payload to include {self._scoped_project_id}.' ) from exception - except InvalidIssuerError as exception: + except jwt.InvalidIssuerError as exception: raise ValueError( 'The provided App Check token has an incorrect "iss" (issuer) claim. ' f'Expected claim to include {self._APP_CHECK_ISSUER}' ) from exception - except ExpiredSignatureError as exception: + except jwt.ExpiredSignatureError as exception: raise ValueError( 'The provided App Check token has expired.' ) from exception - except InvalidTokenError as exception: + except jwt.InvalidTokenError as exception: raise ValueError( f'Decoding App Check token failed. Error: {exception}' ) from exception @@ -138,7 +144,7 @@ def _decode_and_verify(self, token: str, signing_key: str): audience = payload.get('aud') if not isinstance(audience, list) or self._scoped_project_id not in audience: raise ValueError('Firebase App Check token has incorrect "aud" (audience) claim.') - if not payload.get('iss').startswith(self._APP_CHECK_ISSUER): + if not cast(str, payload['iss']).startswith(self._APP_CHECK_ISSUER): raise ValueError('Token does not contain the correct "iss" (issuer).') _Validators.check_string( 'The provided App Check token "sub" (subject) claim', @@ -146,6 +152,7 @@ def _decode_and_verify(self, token: str, signing_key: str): return payload + class _Validators: """A collection of data validation utilities. @@ -153,7 +160,7 @@ class _Validators: """ @classmethod - def check_string(cls, label: str, value: Any): + def check_string(cls, label: str, value: Any) -> None: """Checks if the given value is a string.""" if value is None: raise ValueError(f'{label} "{value}" must be a non-empty string.') diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index ced143112..dfe9244cc 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -19,6 +19,11 @@ creating and managing user accounts in Firebase projects. """ +import datetime +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Optional, Union + +import firebase_admin from firebase_admin import _auth_client from firebase_admin import _auth_providers from firebase_admin import _auth_utils @@ -28,11 +33,10 @@ from firebase_admin import _user_mgt from firebase_admin import _utils +if TYPE_CHECKING: + from _typeshed import ConvertibleToInt -_AUTH_ATTRIBUTE = '_auth' - - -__all__ = [ +__all__ = ( 'ActionCodeSettings', 'CertificateFetchError', 'Client', @@ -107,7 +111,9 @@ 'update_user', 'verify_id_token', 'verify_session_cookie', -] +) + +_AUTH_ATTRIBUTE = '_auth' ActionCodeSettings = _user_mgt.ActionCodeSettings CertificateFetchError = _token_gen.CertificateFetchError @@ -156,7 +162,7 @@ ProviderIdentifier = _user_identifier.ProviderIdentifier -def _get_client(app): +def _get_client(app: Optional[firebase_admin.App]) -> Client: """Returns a client instance for an App. If the App already has a client associated with it, simply returns @@ -175,7 +181,11 @@ def _get_client(app): return _utils.get_app_service(app, _AUTH_ATTRIBUTE, Client) -def create_custom_token(uid, developer_claims=None, app=None): +def create_custom_token( + uid: str, + developer_claims: Optional[dict[str, Any]] = None, + app: Optional[firebase_admin.App] = None, +) -> bytes: """Builds and signs a Firebase custom auth token. Args: @@ -195,7 +205,12 @@ def create_custom_token(uid, developer_claims=None, app=None): return client.create_custom_token(uid, developer_claims) -def verify_id_token(id_token, app=None, check_revoked=False, clock_skew_seconds=0): +def verify_id_token( + id_token: Union[bytes, str], + app: Optional[firebase_admin.App] = None, + check_revoked: bool = False, + clock_skew_seconds: int = 0, +) -> dict[str, Any]: """Verifies the signature and data for the provided JWT. Accepts a signed token string, verifies that it is current, and issued @@ -226,7 +241,11 @@ def verify_id_token(id_token, app=None, check_revoked=False, clock_skew_seconds= id_token, check_revoked=check_revoked, clock_skew_seconds=clock_skew_seconds) -def create_session_cookie(id_token, expires_in, app=None): +def create_session_cookie( + id_token: Union[bytes, str], + expires_in: Union[datetime.timedelta, int], + app: Optional[firebase_admin.App] = None, +) -> str: """Creates a new Firebase session cookie from the given ID token and options. The returned JWT can be set as a server-side session cookie with a custom cookie policy. @@ -249,7 +268,12 @@ def create_session_cookie(id_token, expires_in, app=None): return client._token_generator.create_session_cookie(id_token, expires_in) -def verify_session_cookie(session_cookie, check_revoked=False, app=None, clock_skew_seconds=0): +def verify_session_cookie( + session_cookie: Union[bytes, str], + check_revoked: bool = False, + app: Optional[firebase_admin.App] = None, + clock_skew_seconds: int = 0, +) -> dict[str, Any]: """Verifies a Firebase session cookie. Accepts a session cookie string, verifies that it is current, and issued @@ -285,7 +309,7 @@ def verify_session_cookie(session_cookie, check_revoked=False, app=None, clock_s return verified_claims -def revoke_refresh_tokens(uid, app=None): +def revoke_refresh_tokens(uid: str, app: Optional[firebase_admin.App] = None) -> None: """Revokes all refresh tokens for an existing user. This function updates the user's ``tokens_valid_after_timestamp`` to the current UTC @@ -309,7 +333,7 @@ def revoke_refresh_tokens(uid, app=None): client.revoke_refresh_tokens(uid) -def get_user(uid, app=None): +def get_user(uid: str, app: Optional[firebase_admin.App] = None) -> _user_mgt.UserRecord: """Gets the user data corresponding to the specified user ID. Args: @@ -328,7 +352,10 @@ def get_user(uid, app=None): return client.get_user(uid=uid) -def get_user_by_email(email, app=None): +def get_user_by_email( + email: str, + app: Optional[firebase_admin.App] = None, +) -> _user_mgt.UserRecord: """Gets the user data corresponding to the specified user email. Args: @@ -347,7 +374,10 @@ def get_user_by_email(email, app=None): return client.get_user_by_email(email=email) -def get_user_by_phone_number(phone_number, app=None): +def get_user_by_phone_number( + phone_number: str, + app: Optional[firebase_admin.App] = None, +) -> _user_mgt.UserRecord: """Gets the user data corresponding to the specified phone number. Args: @@ -366,7 +396,10 @@ def get_user_by_phone_number(phone_number, app=None): return client.get_user_by_phone_number(phone_number=phone_number) -def get_users(identifiers, app=None): +def get_users( + identifiers: Sequence[_user_identifier.UserIdentifier], + app: Optional[firebase_admin.App] = None, +) -> _user_mgt.GetUsersResult: """Gets the user data corresponding to the specified identifiers. There are no ordering guarantees; in particular, the nth entry in the @@ -394,7 +427,11 @@ def get_users(identifiers, app=None): return client.get_users(identifiers) -def list_users(page_token=None, max_results=_user_mgt.MAX_LIST_USERS_RESULTS, app=None): +def list_users( + page_token: Optional[str] = None, + max_results: int = _user_mgt.MAX_LIST_USERS_RESULTS, + app: Optional[firebase_admin.App] = None, +) -> _user_mgt.ListUsersPage: """Retrieves a page of user accounts from a Firebase project. The ``page_token`` argument governs the starting point of the page. The ``max_results`` @@ -420,7 +457,18 @@ def list_users(page_token=None, max_results=_user_mgt.MAX_LIST_USERS_RESULTS, ap return client.list_users(page_token=page_token, max_results=max_results) -def create_user(**kwargs): # pylint: disable=differing-param-doc +def create_user( + uid: Optional[str] = None, + display_name: Optional[str] = None, + email: Optional[str] = None, + email_verified: Optional[bool] = None, + phone_number: Optional[str] = None, + photo_url: Optional[str] = None, + password: Optional[str] = None, + disabled: Optional[bool] = None, + app: Optional[firebase_admin.App] = None, + **kwargs: Any, +) -> _user_mgt.UserRecord: # pylint: disable=differing-param-doc """Creates a new user account with the specified properties. Args: @@ -445,12 +493,28 @@ def create_user(**kwargs): # pylint: disable=differing-param-doc ValueError: If the specified user properties are invalid. FirebaseError: If an error occurs while creating the user account. """ - app = kwargs.pop('app', None) client = _get_client(app) - return client.create_user(**kwargs) - - -def update_user(uid, **kwargs): # pylint: disable=differing-param-doc + return client.create_user(uid=uid, display_name=display_name, email=email, + email_verified=email_verified, phone_number=phone_number, photo_url=photo_url, + password=password, disabled=disabled, **kwargs) + + +def update_user( + uid: str, + *, + display_name: Optional[str] = None, + email: Optional[str] = None, + email_verified: Optional[bool] = None, + phone_number: Optional[str] = None, + photo_url: Optional[str] = None, + password: Optional[str] = None, + disabled: Optional[bool] = None, + custom_claims: Optional[Union[dict[str, Any], str]] = None, + valid_since: Optional['ConvertibleToInt'] = None, + providers_to_delete: Optional[list[str]] = None, + app: Optional[firebase_admin.App] = None, + **kwargs: Any, +) -> _user_mgt.UserRecord: """Updates an existing user account with the specified properties. Args: @@ -473,6 +537,8 @@ def update_user(uid, **kwargs): # pylint: disable=differing-param-doc user account (optional). To remove all custom claims, pass ``auth.DELETE_ATTRIBUTE``. valid_since: An integer signifying the seconds since the epoch (optional). This field is set by ``revoke_refresh_tokens`` and it is discouraged to set this field directly. + providers_to_delete: The list of provider IDs to unlink, + eg: 'google.com', 'password', etc. app: An App instance (optional). Returns: @@ -482,12 +548,18 @@ def update_user(uid, **kwargs): # pylint: disable=differing-param-doc ValueError: If the specified user ID or properties are invalid. FirebaseError: If an error occurs while updating the user account. """ - app = kwargs.pop('app', None) client = _get_client(app) - return client.update_user(uid, **kwargs) + return client.update_user(uid, display_name=display_name, email=email, + phone_number=phone_number, photo_url=photo_url, password=password, disabled=disabled, + email_verified=email_verified, valid_since=valid_since, custom_claims=custom_claims, + providers_to_delete=providers_to_delete, **kwargs) -def set_custom_user_claims(uid, custom_claims, app=None): +def set_custom_user_claims( + uid: str, + custom_claims: Optional[Union[dict[str, Any], str]], + app: Optional[firebase_admin.App] = None, +) -> None: """Sets additional claims on an existing user account. Custom claims set via this function can be used to define user roles and privilege levels. @@ -511,7 +583,7 @@ def set_custom_user_claims(uid, custom_claims, app=None): client.set_custom_user_claims(uid, custom_claims=custom_claims) -def delete_user(uid, app=None): +def delete_user(uid: str, app: Optional[firebase_admin.App] = None) -> None: """Deletes the user identified by the specified user ID. Args: @@ -526,7 +598,10 @@ def delete_user(uid, app=None): client.delete_user(uid) -def delete_users(uids, app=None): +def delete_users( + uids: Sequence[str], + app: Optional[firebase_admin.App] = None, +) -> _user_mgt.DeleteUsersResult: """Deletes the users specified by the given identifiers. Deleting a non-existing user does not generate an error (the method is @@ -553,7 +628,11 @@ def delete_users(uids, app=None): return client.delete_users(uids) -def import_users(users, hash_alg=None, app=None): +def import_users( + users: Sequence[_user_import.ImportUserRecord], + hash_alg: Optional[_user_import.UserImportHash] = None, + app: Optional[firebase_admin.App] = None, +) -> _user_import.UserImportResult: """Imports the specified list of users into Firebase Auth. At most 1000 users can be imported at a time. This operation is optimized for bulk imports and @@ -579,7 +658,11 @@ def import_users(users, hash_alg=None, app=None): return client.import_users(users, hash_alg) -def generate_password_reset_link(email, action_code_settings=None, app=None): +def generate_password_reset_link( + email: Optional[str], + action_code_settings: Optional[_user_mgt.ActionCodeSettings] = None, + app: Optional[firebase_admin.App] = None, +) -> str: """Generates the out-of-band email action link for password reset flows for the specified email address. @@ -600,7 +683,11 @@ def generate_password_reset_link(email, action_code_settings=None, app=None): return client.generate_password_reset_link(email, action_code_settings=action_code_settings) -def generate_email_verification_link(email, action_code_settings=None, app=None): +def generate_email_verification_link( + email: Optional[str], + action_code_settings: Optional[_user_mgt.ActionCodeSettings] = None, + app: Optional[firebase_admin.App] = None, +) -> str: """Generates the out-of-band email action link for email verification flows for the specified email address. @@ -622,7 +709,11 @@ def generate_email_verification_link(email, action_code_settings=None, app=None) email, action_code_settings=action_code_settings) -def generate_sign_in_with_email_link(email, action_code_settings, app=None): +def generate_sign_in_with_email_link( + email: Optional[str], + action_code_settings: Optional[_user_mgt.ActionCodeSettings], + app: Optional[firebase_admin.App] = None, +) -> str: """Generates the out-of-band email action link for email link sign-in flows, using the action code settings provided. @@ -645,7 +736,10 @@ def generate_sign_in_with_email_link(email, action_code_settings, app=None): email, action_code_settings=action_code_settings) -def get_oidc_provider_config(provider_id, app=None): +def get_oidc_provider_config( + provider_id: str, + app: Optional[firebase_admin.App] = None, +) -> _auth_providers.OIDCProviderConfig: """Returns the ``OIDCProviderConfig`` with the given ID. Args: @@ -663,9 +757,18 @@ def get_oidc_provider_config(provider_id, app=None): client = _get_client(app) return client.get_oidc_provider_config(provider_id) + def create_oidc_provider_config( - provider_id, client_id, issuer, display_name=None, enabled=None, client_secret=None, - id_token_response_type=None, code_response_type=None, app=None): + provider_id: str, + client_id: str, + issuer: str, + display_name: Optional[str] = None, + enabled: Optional[bool] = None, + client_secret: Optional[str] = None, + id_token_response_type: Optional[bool] = None, + code_response_type: Optional[bool] = None, + app: Optional[firebase_admin.App] = None, +) -> _auth_providers.OIDCProviderConfig: """Creates a new OIDC provider config from the given parameters. OIDC provider support requires Google Cloud's Identity Platform (GCIP). To learn more about @@ -705,8 +808,16 @@ def create_oidc_provider_config( def update_oidc_provider_config( - provider_id, client_id=None, issuer=None, display_name=None, enabled=None, - client_secret=None, id_token_response_type=None, code_response_type=None, app=None): + provider_id: str, + client_id: Optional[str] = None, + issuer: Optional[str] = None, + display_name: Optional[str] = None, + enabled: Optional[bool] = None, + client_secret: Optional[str] = None, + id_token_response_type: Optional[bool] = None, + code_response_type: Optional[bool] = None, + app: Optional[firebase_admin.App] = None, +) -> _auth_providers.OIDCProviderConfig: """Updates an existing OIDC provider config with the given parameters. Args: @@ -717,16 +828,16 @@ def update_oidc_provider_config( Pass ``auth.DELETE_ATTRIBUTE`` to delete the current display name. enabled: A boolean indicating whether the provider configuration is enabled or disabled (optional). - app: An App instance (optional). client_secret: A string which sets the client secret for the new provider. This is required for the code flow. + id_token_response_type: A boolean which sets whether to enable the ID token response flow + for the new provider. By default, this is enabled if no response type is specified. + Having both the code and ID token response flows is currently not supported. code_response_type: A boolean which sets whether to enable the code response flow for the new provider. By default, this is not enabled if no response type is specified. A client secret must be set for this response type. Having both the code and ID token response flows is currently not supported. - id_token_response_type: A boolean which sets whether to enable the ID token response flow - for the new provider. By default, this is enabled if no response type is specified. - Having both the code and ID token response flows is currently not supported. + app: An App instance (optional). Returns: OIDCProviderConfig: The updated OIDC provider config instance. @@ -742,7 +853,10 @@ def update_oidc_provider_config( code_response_type=code_response_type) -def delete_oidc_provider_config(provider_id, app=None): +def delete_oidc_provider_config( + provider_id: str, + app: Optional[firebase_admin.App] = None, +) -> None: """Deletes the ``OIDCProviderConfig`` with the given ID. Args: @@ -759,7 +873,10 @@ def delete_oidc_provider_config(provider_id, app=None): def list_oidc_provider_configs( - page_token=None, max_results=_auth_providers.MAX_LIST_CONFIGS_RESULTS, app=None): + page_token: Optional[str] = None, + max_results: int = _auth_providers.MAX_LIST_CONFIGS_RESULTS, + app: Optional[firebase_admin.App] = None, +) -> _auth_providers._ListOIDCProviderConfigsPage: """Retrieves a page of OIDC provider configs from a Firebase project. The ``page_token`` argument governs the starting point of the page. The ``max_results`` @@ -786,7 +903,10 @@ def list_oidc_provider_configs( return client.list_oidc_provider_configs(page_token, max_results) -def get_saml_provider_config(provider_id, app=None): +def get_saml_provider_config( + provider_id: str, + app: Optional[firebase_admin.App] = None, +) -> _auth_providers.SAMLProviderConfig: """Returns the ``SAMLProviderConfig`` with the given ID. Args: @@ -806,8 +926,16 @@ def get_saml_provider_config(provider_id, app=None): def create_saml_provider_config( - provider_id, idp_entity_id, sso_url, x509_certificates, rp_entity_id, callback_url, - display_name=None, enabled=None, app=None): + provider_id: str, + idp_entity_id: str, + sso_url: str, + x509_certificates: list[str], + rp_entity_id: str, + callback_url: str, + display_name: Optional[str] = None, + enabled: Optional[bool] = None, + app: Optional[firebase_admin.App] = None, +) -> _auth_providers.SAMLProviderConfig: """Creates a new SAML provider config from the given parameters. SAML provider support requires Google Cloud's Identity Platform (GCIP). To learn more about @@ -848,8 +976,16 @@ def create_saml_provider_config( def update_saml_provider_config( - provider_id, idp_entity_id=None, sso_url=None, x509_certificates=None, - rp_entity_id=None, callback_url=None, display_name=None, enabled=None, app=None): + provider_id: str, + idp_entity_id: Optional[str] = None, + sso_url: Optional[str] = None, + x509_certificates: Optional[list[str]] = None, + rp_entity_id: Optional[str] = None, + callback_url: Optional[str] = None, + display_name: Optional[str] = None, + enabled: Optional[bool] = None, + app: Optional[firebase_admin.App] = None, +) -> _auth_providers.SAMLProviderConfig: """Updates an existing SAML provider config with the given parameters. Args: @@ -880,7 +1016,10 @@ def update_saml_provider_config( callback_url=callback_url, display_name=display_name, enabled=enabled) -def delete_saml_provider_config(provider_id, app=None): +def delete_saml_provider_config( + provider_id: str, + app: Optional[firebase_admin.App] = None, +) -> None: """Deletes the ``SAMLProviderConfig`` with the given ID. Args: @@ -897,7 +1036,10 @@ def delete_saml_provider_config(provider_id, app=None): def list_saml_provider_configs( - page_token=None, max_results=_auth_providers.MAX_LIST_CONFIGS_RESULTS, app=None): + page_token: Optional[str] = None, + max_results: int = _auth_providers.MAX_LIST_CONFIGS_RESULTS, + app: Optional[firebase_admin.App] = None, +) -> _auth_providers._ListSAMLProviderConfigsPage: """Retrieves a page of SAML provider configs from a Firebase project. The ``page_token`` argument governs the starting point of the page. The ``max_results`` diff --git a/firebase_admin/credentials.py b/firebase_admin/credentials.py index 7117b71a9..9af328f52 100644 --- a/firebase_admin/credentials.py +++ b/firebase_admin/credentials.py @@ -13,16 +13,30 @@ # limitations under the License. """Firebase credentials module.""" -import collections +import datetime import json import pathlib +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union, cast +from typing_extensions import TypeGuard import google.auth + from google.auth.credentials import Credentials as GoogleAuthCredentials from google.auth.transport import requests +from google.auth import crypt from google.oauth2 import credentials from google.oauth2 import service_account +if TYPE_CHECKING: + from _typeshed import StrPath + +__all__ = ( + 'AccessTokenInfo', + 'ApplicationDefault', + 'Base', + 'Certificate', + 'RefreshToken', +) _request = requests.Request() _scopes = [ @@ -34,18 +48,21 @@ 'https://www.googleapis.com/auth/userinfo.email' ] -AccessTokenInfo = collections.namedtuple('AccessTokenInfo', ['access_token', 'expiry']) -"""Data included in an OAuth2 access token. -Contains the access token string and the expiry time. The expirty time is exposed as a -``datetime`` value. -""" +class AccessTokenInfo(NamedTuple): + """Data included in an OAuth2 access token. + + Contains the access token string and the expiry time. The expirty time is exposed as a + ``datetime`` value. + """ + access_token: Any + expiry: Optional[datetime.datetime] class Base: """Provides OAuth2 access tokens for accessing Firebase services.""" - def get_access_token(self): + def get_access_token(self) -> AccessTokenInfo: """Fetches a Google OAuth2 access token using this credential instance. Returns: @@ -55,30 +72,31 @@ def get_access_token(self): google_cred.refresh(_request) return AccessTokenInfo(google_cred.token, google_cred.expiry) - def get_credential(self): + def get_credential(self) -> GoogleAuthCredentials: """Returns the Google credential instance used for authentication.""" raise NotImplementedError -class _ExternalCredentials(Base): + +class _ExternalCredentials(Base): # pyright: ignore[reportUnusedClass] """A wrapper for google.auth.credentials.Credentials typed credential instances""" - def __init__(self, credential: GoogleAuthCredentials): - super().__init__() + def __init__(self, credential: GoogleAuthCredentials) -> None: self._g_credential = credential - def get_credential(self): + def get_credential(self) -> GoogleAuthCredentials: """Returns the underlying Google Credential Returns: google.auth.credentials.Credentials: A Google Auth credential instance.""" return self._g_credential + class Certificate(Base): """A credential initialized from a JSON certificate keyfile.""" _CREDENTIAL_TYPE = 'service_account' - def __init__(self, cert): + def __init__(self, cert: Union['StrPath', dict[str, Any]]) -> None: """Initializes a credential from a Google service account certificate. Service account certificates can be downloaded as JSON files from the Firebase console. @@ -92,7 +110,6 @@ def __init__(self, cert): IOError: If the specified certificate file doesn't exist or cannot be read. ValueError: If the specified certificate is invalid. """ - super().__init__() if _is_file_path(cert): with open(cert, encoding='utf-8') as json_file: json_data = json.load(json_file) @@ -115,18 +132,18 @@ def __init__(self, cert): f'Failed to initialize a certificate credential. Caused by: "{error}"') from error @property - def project_id(self): + def project_id(self) -> Optional[str]: return self._g_credential.project_id @property - def signer(self): + def signer(self) -> crypt.Signer: return self._g_credential.signer @property - def service_account_email(self): + def service_account_email(self) -> str: return self._g_credential.service_account_email - def get_credential(self): + def get_credential(self) -> GoogleAuthCredentials: """Returns the underlying Google credential. Returns: @@ -137,16 +154,17 @@ def get_credential(self): class ApplicationDefault(Base): """A Google Application Default credential.""" - def __init__(self): + def __init__(self) -> None: """Creates an instance that will use Application Default credentials. The credentials will be lazily initialized when get_credential() or project_id() is called. See those methods for possible errors raised. """ - super().__init__() - self._g_credential = None # Will be lazily-loaded via _load_credential(). + # Will be lazily-loaded via _load_credential(). + self._g_credential: Optional[GoogleAuthCredentials] = None + self._project_id: Optional[str] - def get_credential(self): + def get_credential(self) -> GoogleAuthCredentials: """Returns the underlying Google credential. Raises: @@ -155,10 +173,10 @@ def get_credential(self): Returns: google.auth.credentials.Credentials: A Google Auth credential instance.""" self._load_credential() - return self._g_credential + return cast(GoogleAuthCredentials, self._g_credential) @property - def project_id(self): + def project_id(self) -> Optional[str]: """Returns the project_id from the underlying Google credential. Raises: @@ -169,16 +187,17 @@ def project_id(self): self._load_credential() return self._project_id - def _load_credential(self): + def _load_credential(self) -> None: if not self._g_credential: self._g_credential, self._project_id = google.auth.default(scopes=_scopes) + class RefreshToken(Base): """A credential initialized from an existing refresh token.""" _CREDENTIAL_TYPE = 'authorized_user' - def __init__(self, refresh_token): + def __init__(self, refresh_token: Union['StrPath', dict[str, Any]]) -> None: """Initializes a credential from a refresh token JSON file. The JSON must consist of client_id, client_secret and refresh_token fields. Refresh @@ -194,7 +213,6 @@ def __init__(self, refresh_token): IOError: If the specified file doesn't exist or cannot be read. ValueError: If the refresh token configuration is invalid. """ - super().__init__() if _is_file_path(refresh_token): with open(refresh_token, encoding='utf-8') as json_file: json_data = json.load(json_file) @@ -212,18 +230,18 @@ def __init__(self, refresh_token): self._g_credential = credentials.Credentials.from_authorized_user_info(json_data, _scopes) @property - def client_id(self): + def client_id(self) -> Optional[str]: return self._g_credential.client_id @property - def client_secret(self): + def client_secret(self) -> Optional[str]: return self._g_credential.client_secret @property - def refresh_token(self): + def refresh_token(self) -> Optional[str]: return self._g_credential.refresh_token - def get_credential(self): + def get_credential(self) -> GoogleAuthCredentials: """Returns the underlying Google credential. Returns: @@ -231,7 +249,7 @@ def get_credential(self): return self._g_credential -def _is_file_path(path): +def _is_file_path(path: Any) -> TypeGuard['StrPath']: try: pathlib.Path(path) return True diff --git a/firebase_admin/db.py b/firebase_admin/db.py index 800cbf8e3..de9cb520b 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -25,8 +25,21 @@ import os import sys import threading +from collections.abc import Callable +from typing import ( + Any, + Generic, + Literal, + NamedTuple, + Optional, + Union, + cast, + overload, +) +from typing_extensions import Self, TypeVar from urllib import parse +import google.auth.credentials import requests import firebase_admin @@ -35,6 +48,19 @@ from firebase_admin import _sseclient from firebase_admin import _utils +__all__ = ( + 'EmulatorConfig', + 'Event', + 'ListenerRegistration', + 'Query', + 'Reference', + 'TransactionAbortedError', + 'reference', +) + +_K = TypeVar('_K', default=Any) +_V = TypeVar('_V', default=Any) +_JsonT = TypeVar('_JsonT', bound='_Json', default='_Json') _DB_ATTRIBUTE = '_database' _INVALID_PATH_CHARACTERS = '[].?#$' @@ -45,9 +71,19 @@ ) _TRANSACTION_MAX_RETRIES = 25 _EMULATOR_HOST_ENV_VAR = 'FIREBASE_DATABASE_EMULATOR_HOST' - - -def reference(path='/', app=None, url=None): +_Json = Optional[Union[ + dict[str, '_Json'], + list['_Json'], + str, + float, +]] + + +def reference( + path: str = '/', + app: Optional[firebase_admin.App] = None, + url: Optional[str] = None, +) -> 'Reference': """Returns a database ``Reference`` representing the node at the specified path. If no path is specified, this function returns a ``Reference`` that represents the database @@ -71,7 +107,8 @@ def reference(path='/', app=None, url=None): client = service.get_client(url) return Reference(client=client, path=path) -def _parse_path(path): + +def _parse_path(path: Any) -> list[str]: """Parses a path string into a set of segments.""" if not isinstance(path, str): raise ValueError(f'Invalid path: "{path}". Path must be a string.') @@ -83,7 +120,7 @@ def _parse_path(path): class Event: """Represents a realtime update event received from the database.""" - def __init__(self, sse_event): + def __init__(self, sse_event: _sseclient.Event) -> None: self._sse_event = sse_event self._data = json.loads(sse_event.data) @@ -98,7 +135,7 @@ def path(self): return self._data['path'] @property - def event_type(self): + def event_type(self) -> str: """Event type string (put, patch).""" return self._sse_event.event_type @@ -106,7 +143,11 @@ def event_type(self): class ListenerRegistration: """Represents the addition of an event listener to a database reference.""" - def __init__(self, callback, sse): + def __init__( + self, + callback: Callable[[Event], None], + sse: _sseclient.SSEClient, + ) -> None: """Initializes a new listener with given parameters. This is an internal API. Use the ``db.Reference.listen()`` method to start a @@ -121,14 +162,14 @@ def __init__(self, callback, sse): self._thread = threading.Thread(target=self._start_listen) self._thread.start() - def _start_listen(self): + def _start_listen(self) -> None: # iterate the sse client's generator for sse_event in self._sse: # only inject data events if sse_event: self._callback(Event(sse_event)) - def close(self): + def close(self) -> None: """Stops the event listener represented by this registration This closes the SSE HTTP connection, and joins the background thread. @@ -140,36 +181,59 @@ def close(self): class Reference: """Reference represents a node in the Firebase realtime database.""" - def __init__(self, **kwargs): + @overload + def __init__( + self, + *, + segments: list[str], + client: Optional['_Client'] = None, + **kwargs: Any, + ) -> None: ... + @overload + def __init__( + self, + *, + path: str, + client: Optional['_Client'] = None, + **kwargs: Any, + ) -> None: ... + def __init__( + self, + *, + path: Optional[str] = None, + segments: Optional[list[str]] = None, + client: Optional['_Client'] = None, + **kwargs: Any, + ) -> None: """Creates a new Reference using the provided parameters. This method is for internal use only. Use db.reference() to obtain an instance of Reference. """ - self._client = kwargs.get('client') - if 'segments' in kwargs: - self._segments = kwargs.get('segments') + self._client = client + if segments is not None: + self._segments = segments else: - self._segments = _parse_path(kwargs.get('path')) + self._segments = _parse_path(path) self._pathurl = '/' + '/'.join(self._segments) @property - def key(self): + def key(self) -> Optional[str]: if self._segments: return self._segments[-1] return None @property - def path(self): + def path(self) -> str: return self._pathurl @property - def parent(self): + def parent(self) -> Optional['Reference']: if self._segments: return Reference(client=self._client, segments=self._segments[:-1]) return None - def child(self, path): + def child(self, path: Optional[str]) -> 'Reference': """Returns a Reference to the specified child node. The path may point to an immediate child of the current Reference, or a deeply nested @@ -191,7 +255,23 @@ def child(self, path): full_path = self._pathurl + '/' + path return Reference(client=self._client, path=full_path) - def get(self, etag=False, shallow=False): + @overload + def get( # pyright: ignore[reportOverlappingOverload] + self, + etag: Literal[True], + shallow: bool = False, + ) -> tuple[_Json, str]: ... + @overload + def get( + self, + etag: bool = False, + shallow: bool = False, + ) -> _Json: ... + def get( + self, + etag: bool = False, + shallow: bool = False, + ) -> Union[tuple[_Json, str], _Json]: """Returns the value, and optionally the ETag, at the current location of the database. Args: @@ -214,12 +294,12 @@ def get(self, etag=False, shallow=False): raise ValueError('etag and shallow cannot both be set to True.') headers, data = self._client.headers_and_body( 'get', self._add_suffix(), headers={'X-Firebase-ETag' : 'true'}) - return data, headers.get('ETag') + return data, cast(str, headers.get('ETag')) params = 'shallow=true' if shallow else None return self._client.body('get', self._add_suffix(), params=params) - def get_if_changed(self, etag): + def get_if_changed(self, etag: str) -> tuple[bool, Optional[Any], Optional[str]]: """Gets data in this location only if the specified ETag does not match. Args: @@ -245,7 +325,7 @@ def get_if_changed(self, etag): return True, resp.json(), resp.headers.get('ETag') - def set(self, value): + def set(self, value: _Json) -> None: """Sets the data at this location to the given value. The value must be JSON-serializable and not None. @@ -262,7 +342,11 @@ def set(self, value): raise ValueError('Value must not be None.') self._client.request('put', self._add_suffix(), json=value, params='print=silent') - def set_if_unchanged(self, expected_etag, value): + def set_if_unchanged( + self, + expected_etag: str, + value: _JsonT + ) -> tuple[bool, _JsonT, str]: """Conditonally sets the data at this location to the given value. Sets the data at this location to the given value only if ``expected_etag`` is same as the @@ -290,7 +374,7 @@ def set_if_unchanged(self, expected_etag, value): try: headers = self._client.headers( 'put', self._add_suffix(), json=value, headers={'if-match': expected_etag}) - return True, value, headers.get('ETag') + return True, value, cast(str, headers.get('ETag')) except exceptions.FailedPreconditionError as error: http_response = error.http_response if http_response is not None and 'ETag' in http_response.headers: @@ -300,7 +384,7 @@ def set_if_unchanged(self, expected_etag, value): raise error - def push(self, value=''): + def push(self, value: _Json = '') -> 'Reference': """Creates a new child node. The optional value argument can be used to provide an initial value for the child node. If @@ -320,10 +404,10 @@ def push(self, value=''): if value is None: raise ValueError('Value must not be None.') output = self._client.body('post', self._add_suffix(), json=value) - push_id = output.get('name') + push_id = cast(Optional[str], output.get('name')) return self.child(push_id) - def update(self, value): + def update(self, value: _Json) -> None: """Updates the specified child keys of this Reference to the provided values. Args: @@ -339,7 +423,7 @@ def update(self, value): raise ValueError('Dictionary must not contain None keys.') self._client.request('patch', self._add_suffix(), json=value, params='print=silent') - def delete(self): + def delete(self) -> None: """Deletes this node from the database. Raises: @@ -347,7 +431,7 @@ def delete(self): """ self._client.request('delete', self._add_suffix()) - def listen(self, callback): + def listen(self, callback: Callable[[Event], None]) -> ListenerRegistration: """Registers the ``callback`` function to receive realtime updates. The specified callback function will get invoked with ``db.Event`` objects for each @@ -373,7 +457,7 @@ def listen(self, callback): """ return self._listen_with_session(callback) - def transaction(self, transaction_update): + def transaction(self, transaction_update: Callable[[_Json], _Json]) -> _Json: """Atomically modifies the data at this location. Unlike a normal ``set()``, which just overwrites the data regardless of its previous state, @@ -416,7 +500,7 @@ def transaction(self, transaction_update): raise TransactionAbortedError('Transaction aborted after failed retries.') - def order_by_child(self, path): + def order_by_child(self, path: str) -> 'Query': """Returns a Query that orders data by child values. Returned Query can be used to set additional parameters, and execute complex database @@ -435,7 +519,7 @@ def order_by_child(self, path): raise ValueError(f'Illegal child path: {path}') return Query(order_by=path, client=self._client, pathurl=self._add_suffix()) - def order_by_key(self): + def order_by_key(self) -> 'Query': """Creates a Query that orderes data by key. Returned Query can be used to set additional parameters, and execute complex database @@ -446,7 +530,7 @@ def order_by_key(self): """ return Query(order_by='$key', client=self._client, pathurl=self._add_suffix()) - def order_by_value(self): + def order_by_value(self) -> 'Query': """Creates a Query that orderes data by value. Returned Query can be used to set additional parameters, and execute complex database @@ -457,16 +541,20 @@ def order_by_value(self): """ return Query(order_by='$value', client=self._client, pathurl=self._add_suffix()) - def _add_suffix(self, suffix='.json'): + def _add_suffix(self, suffix: str = '.json') -> str: return self._pathurl + suffix - def _listen_with_session(self, callback, session=None): + def _listen_with_session( + self, + callback: Callable[[Event], None], + session: Optional[requests.Session] = None, + ) -> ListenerRegistration: url = self._client.base_url + self._add_suffix() if not session: session = self._client.create_listener_session() try: - sse = _sseclient.SSEClient(url, session, **{"params": self._client.params}) + sse = _sseclient.SSEClient(url, session, params=self._client.params) return ListenerRegistration(callback, sse) except requests.exceptions.RequestException as error: raise _Client.handle_rtdb_error(error) @@ -485,8 +573,7 @@ class Query: OrderedDict. """ - def __init__(self, **kwargs): - order_by = kwargs.pop('order_by') + def __init__(self, *, client: '_Client', order_by: str, pathurl: str, **kwargs: Any) -> None: if not order_by or not isinstance(order_by, str): raise ValueError('order_by field must be a non-empty string') if order_by not in _RESERVED_FILTERS: @@ -495,14 +582,14 @@ def __init__(self, **kwargs): f'Invalid path argument: "{order_by}". Child path must not start with "/"') segments = _parse_path(order_by) order_by = '/'.join(segments) - self._client = kwargs.pop('client') - self._pathurl = kwargs.pop('pathurl') + self._client = client + self._pathurl = pathurl self._order_by = order_by - self._params = {'orderBy' : json.dumps(order_by)} + self._params: dict[str, Any] = {'orderBy' : json.dumps(order_by)} if kwargs: raise ValueError(f'Unexpected keyword arguments: {kwargs}') - def limit_to_first(self, limit): + def limit_to_first(self, limit: int) -> Self: """Creates a query with limit, and anchors it to the start of the window. Args: @@ -521,7 +608,7 @@ def limit_to_first(self, limit): self._params['limitToFirst'] = limit return self - def limit_to_last(self, limit): + def limit_to_last(self, limit: int) -> Self: """Creates a query with limit, and anchors it to the end of the window. Args: @@ -540,7 +627,7 @@ def limit_to_last(self, limit): self._params['limitToLast'] = limit return self - def start_at(self, start): + def start_at(self, start: _Json) -> Self: """Sets the lower bound for a range query. The Query will only return child nodes with a value greater than or equal to the specified @@ -560,7 +647,7 @@ def start_at(self, start): self._params['startAt'] = json.dumps(start) return self - def end_at(self, end): + def end_at(self, end: _Json) -> Self: """Sets the upper bound for a range query. The Query will only return child nodes with a value less than or equal to the specified @@ -580,7 +667,7 @@ def end_at(self, end): self._params['endAt'] = json.dumps(end) return self - def equal_to(self, value): + def equal_to(self, value: _Json) -> Self: """Sets an equals constraint on the Query. The Query will only return child nodes whose value is equal to the specified value. @@ -600,13 +687,13 @@ def equal_to(self, value): return self @property - def _querystr(self): - params = [] + def _querystr(self) -> str: + params: list[str] = [] for key in sorted(self._params): params.append(f'{key}={self._params[key]}') return '&'.join(params) - def get(self): + def get(self) -> Union[dict[str, _Json], list[_Json]]: """Executes this Query and returns the results. The results will be returned as a sorted list or an OrderedDict. @@ -626,32 +713,40 @@ def get(self): class TransactionAbortedError(exceptions.AbortedError): """A transaction was aborted aftr exceeding the maximum number of retries.""" - def __init__(self, message): - exceptions.AbortedError.__init__(self, message) + def __init__(self, message: str) -> None: + super().__init__(message) -class _Sorter: +class _Sorter(Generic[_K, _V]): """Helper class for sorting query results.""" - def __init__(self, results, order_by): + @overload + def __init__(self, results: dict[_K, _V], order_by: str) -> None: ... + @overload + def __init__( + self: '_Sorter[int, _V]', # pyright: ignore[reportInvalidTypeVarUse] + results: list[_V], + order_by: str, + ) -> None: ... + def __init__(self, results: Union[dict[_K, _V], list[_V]], order_by: str) -> None: if isinstance(results, dict): self.dict_input = True entries = [_SortEntry(k, v, order_by) for k, v in results.items()] elif isinstance(results, list): self.dict_input = False - entries = [_SortEntry(k, v, order_by) for k, v in enumerate(results)] + entries = [_SortEntry(cast(_K, k), v, order_by) for k, v in enumerate(results)] else: raise ValueError(f'Sorting not supported for "{type(results)}" object.') self.sort_entries = sorted(entries) - def get(self): + def get(self) -> Union[collections.OrderedDict[_K, _V], list[_V]]: if self.dict_input: return collections.OrderedDict([(e.key, e.value) for e in self.sort_entries]) return [e.value for e in self.sort_entries] -class _SortEntry: +class _SortEntry(Generic[_K, _V]): """A wrapper that is capable of sorting items in a dictionary.""" _type_none = 0 @@ -661,7 +756,7 @@ class _SortEntry: _type_string = 4 _type_object = 5 - def __init__(self, key, value, order_by): + def __init__(self, key: _K, value: _V, order_by: str) -> None: self._key = key self._value = value if order_by in ('$key', '$priority'): @@ -673,23 +768,23 @@ def __init__(self, key, value, order_by): self._index_type = _SortEntry._get_index_type(self._index) @property - def key(self): + def key(self) -> _K: return self._key @property - def index(self): + def index(self) -> Optional[Any]: return self._index @property - def index_type(self): + def index_type(self) -> int: return self._index_type @property - def value(self): + def value(self) -> _V: return self._value @classmethod - def _get_index_type(cls, index): + def _get_index_type(cls, index: Any) -> int: """Assigns an integer code to the type of the index. The index type determines how differently typed values are sorted. This ordering is based @@ -709,17 +804,18 @@ def _get_index_type(cls, index): return cls._type_object @classmethod - def _extract_child(cls, value, path): + def _extract_child(cls, value: Any, path: str) -> Optional[Any]: segments = path.split('/') current = value for segment in segments: if isinstance(current, dict): + current = cast(dict[str, Any], current) current = current.get(segment) else: return None return current - def _compare(self, other): + def _compare(self, other: '_SortEntry') -> Literal[-1, 0, 1]: """Compares two _SortEntry instances. If the indices have the same numeric or string type, compare them directly. Ties are @@ -734,39 +830,44 @@ def _compare(self, other): else: self_key, other_key = self.key, other.key - if self_key < other_key: + if self_key < other_key: # pyright: ignore[reportOperatorIssue] return -1 - if self_key > other_key: + if self_key > other_key: # pyright: ignore[reportOperatorIssue] return 1 return 0 - def __lt__(self, other): + def __lt__(self, other: '_SortEntry') -> bool: return self._compare(other) < 0 - def __le__(self, other): + def __le__(self, other: '_SortEntry') -> bool: return self._compare(other) <= 0 - def __gt__(self, other): + def __gt__(self, other: '_SortEntry') -> bool: return self._compare(other) > 0 - def __ge__(self, other): + def __ge__(self, other: '_SortEntry') -> bool: return self._compare(other) >= 0 - def __eq__(self, other): + def __eq__(self, other: '_SortEntry') -> bool: # pyright: ignore[reportIncompatibleMethodOverride] return self._compare(other) == 0 +class EmulatorConfig(NamedTuple): + base_url: str + namespace: str + + class _DatabaseService: """Service that maintains a collection of database clients.""" _DEFAULT_AUTH_OVERRIDE = '_admin_' - def __init__(self, app): + def __init__(self, app: firebase_admin.App) -> None: self._credential = app.credential db_url = app.options.get('databaseURL') if db_url: - self._db_url = db_url + self._db_url: Optional[str] = db_url else: self._db_url = None @@ -776,7 +877,7 @@ def __init__(self, app): else: self._auth_override = None self._timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) - self._clients = {} + self._clients: dict[tuple[str, str], _Client] = {} emulator_host = os.environ.get(_EMULATOR_HOST_ENV_VAR) if emulator_host: @@ -788,7 +889,7 @@ def __init__(self, app): else: self._emulator_host = None - def get_client(self, db_url=None): + def get_client(self, db_url: Optional[str] = None) -> '_Client': """Creates a client based on the db_url. Clients may be cached.""" if db_url is None: db_url = self._db_url @@ -813,7 +914,6 @@ def get_client(self, db_url=None): base_url = f'https://{parsed_url.netloc}' params = {} - if self._auth_override: params['auth_variable_override'] = self._auth_override @@ -823,9 +923,8 @@ def get_client(self, db_url=None): self._clients[client_cache_key] = client return self._clients[client_cache_key] - def _get_emulator_config(self, parsed_url): + def _get_emulator_config(self, parsed_url: parse.ParseResult) -> Optional[EmulatorConfig]: """Checks whether the SDK should connect to the RTDB emulator.""" - EmulatorConfig = collections.namedtuple('EmulatorConfig', ['base_url', 'namespace']) if parsed_url.scheme != 'https': # Emulator mode enabled by passing http URL via AppOptions base_url, namespace = _DatabaseService._parse_emulator_url(parsed_url) @@ -839,7 +938,7 @@ def _get_emulator_config(self, parsed_url): return None @classmethod - def _parse_emulator_url(cls, parsed_url): + def _parse_emulator_url(cls, parsed_url: parse.ParseResult) -> tuple[str, str]: """Parses emulator URL like http://localhost:8080/?ns=foo-bar""" query_ns = parse.parse_qs(parsed_url.query).get('ns') if parsed_url.scheme != 'http' or (not query_ns or len(query_ns) != 1 or not query_ns[0]): @@ -852,9 +951,10 @@ def _parse_emulator_url(cls, parsed_url): return base_url, namespace @classmethod - def _get_auth_override(cls, app): + def _get_auth_override(cls, app: firebase_admin.App) -> Optional[Union[dict[str, Any], str]]: """Gets and validates the database auth override to be used.""" - auth_override = app.options.get('databaseAuthVariableOverride', cls._DEFAULT_AUTH_OVERRIDE) + auth_override = cast(Optional[str], app.options.get( + 'databaseAuthVariableOverride', cls._DEFAULT_AUTH_OVERRIDE)) if auth_override == cls._DEFAULT_AUTH_OVERRIDE or auth_override is None: return auth_override if not isinstance(auth_override, dict): @@ -864,7 +964,7 @@ def _get_auth_override(cls, app): return auth_override - def close(self): + def close(self) -> None: for value in self._clients.values(): value.close() self._clients = {} @@ -877,7 +977,13 @@ class _Client(_http_client.JsonHttpClient): marshalling and unmarshalling of JSON data. """ - def __init__(self, credential, base_url, timeout, params=None): + def __init__( + self, + credential: Optional[google.auth.credentials.Credentials], + base_url: str, + timeout: int, + params: Optional[dict[str, Any]] = None, + ) -> None: """Creates a new _Client from the given parameters. This exists primarily to enable testing. For regular use, obtain _Client instances by @@ -897,7 +1003,7 @@ def __init__(self, credential, base_url, timeout, params=None): self.credential = credential self.params = params if params else {} - def request(self, method, url, **kwargs): + def request(self, method: str, url: str, **kwargs: Any) -> requests.Response: """Makes an HTTP call using the Python requests library. Extends the request() method of the parent JsonHttpClient class. Handles default @@ -929,11 +1035,11 @@ def request(self, method, url, **kwargs): except requests.exceptions.RequestException as error: raise _Client.handle_rtdb_error(error) - def create_listener_session(self): + def create_listener_session(self) -> _sseclient.KeepAuthSession: return _sseclient.KeepAuthSession(self.credential) @classmethod - def handle_rtdb_error(cls, error): + def handle_rtdb_error(cls, error: requests.RequestException) -> exceptions.FirebaseError: """Converts an error encountered while calling RTDB into a FirebaseError.""" if error.response is None: return _utils.handle_requests_error(error) @@ -942,7 +1048,7 @@ def handle_rtdb_error(cls, error): return _utils.handle_requests_error(error, message=message) @classmethod - def _extract_error_message(cls, response): + def _extract_error_message(cls, response: requests.Response) -> str: """Extracts an error message from an error response. If the server has sent a JSON response with an 'error' field, which is the typical @@ -953,7 +1059,7 @@ def _extract_error_message(cls, response): message = None try: # RTDB error format: {"error": "text message"} - data = response.json() + data: dict[str, str] = response.json() if isinstance(data, dict): message = data.get('error') except ValueError: diff --git a/firebase_admin/exceptions.py b/firebase_admin/exceptions.py index 947f36806..00143a117 100644 --- a/firebase_admin/exceptions.py +++ b/firebase_admin/exceptions.py @@ -31,6 +31,46 @@ subtype error handlers. """ +from typing import Optional, Union + +import httpx +import requests + +__all__ = ( + 'ABORTED', + 'ALREADY_EXISTS', + 'CANCELLED', + 'CONFLICT', + 'DATA_LOSS', + 'DEADLINE_EXCEEDED', + 'FAILED_PRECONDITION', + 'INTERNAL', + 'INVALID_ARGUMENT', + 'NOT_FOUND', + 'OUT_OF_RANGE', + 'PERMISSION_DENIED', + 'RESOURCE_EXHAUSTED', + 'UNAUTHENTICATED', + 'UNAVAILABLE', + 'UNKNOWN', + 'AbortedError', + 'AlreadyExistsError', + 'CancelledError', + 'ConflictError', + 'DataLossError', + 'DeadlineExceededError', + 'FailedPreconditionError', + 'FirebaseError', + 'InternalError', + 'InvalidArgumentError', + 'NotFoundError', + 'OutOfRangeError', + 'PermissionDeniedError', + 'ResourceExhaustedError', + 'UnauthenticatedError', + 'UnavailableError', + 'UnknownError', +) #: Error code for ``InvalidArgumentError`` type. INVALID_ARGUMENT = 'INVALID_ARGUMENT' @@ -95,52 +135,78 @@ class FirebaseError(Exception): this object. """ - def __init__(self, code, message, cause=None, http_response=None): - Exception.__init__(self, message) + def __init__( + self, + code: str, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(message) self._code = code self._cause = cause self._http_response = http_response @property - def code(self): + def code(self) -> str: return self._code @property - def cause(self): + def cause(self) -> Optional[Exception]: return self._cause @property - def http_response(self): + def http_response(self) -> Optional[Union[httpx.Response, requests.Response]]: return self._http_response class InvalidArgumentError(FirebaseError): """Client specified an invalid argument.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, INVALID_ARGUMENT, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(INVALID_ARGUMENT, message, cause, http_response) class FailedPreconditionError(FirebaseError): """Request can not be executed in the current system state, such as deleting a non-empty directory.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, FAILED_PRECONDITION, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(FAILED_PRECONDITION, message, cause, http_response) class OutOfRangeError(FirebaseError): """Client specified an invalid range.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, OUT_OF_RANGE, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(OUT_OF_RANGE, message, cause, http_response) class UnauthenticatedError(FirebaseError): """Request not authenticated due to missing, invalid, or expired OAuth token.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, UNAUTHENTICATED, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(UNAUTHENTICATED, message, cause, http_response) class PermissionDeniedError(FirebaseError): @@ -150,79 +216,134 @@ class PermissionDeniedError(FirebaseError): have permission, or the API has not been enabled for the client project. """ - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, PERMISSION_DENIED, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(PERMISSION_DENIED, message, cause, http_response) class NotFoundError(FirebaseError): """A specified resource is not found, or the request is rejected by undisclosed reasons, such as whitelisting.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, NOT_FOUND, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(NOT_FOUND, message, cause, http_response) class ConflictError(FirebaseError): """Concurrency conflict, such as read-modify-write conflict.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, CONFLICT, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(CONFLICT, message, cause, http_response) class AbortedError(FirebaseError): """Concurrency conflict, such as read-modify-write conflict.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, ABORTED, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(ABORTED, message, cause, http_response) class AlreadyExistsError(FirebaseError): """The resource that a client tried to create already exists.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, ALREADY_EXISTS, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(ALREADY_EXISTS, message, cause, http_response) class ResourceExhaustedError(FirebaseError): """Either out of resource quota or reaching rate limiting.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, RESOURCE_EXHAUSTED, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(RESOURCE_EXHAUSTED, message, cause, http_response) class CancelledError(FirebaseError): """Request cancelled by the client.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, CANCELLED, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(CANCELLED, message, cause, http_response) class DataLossError(FirebaseError): """Unrecoverable data loss or data corruption.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, DATA_LOSS, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(DATA_LOSS, message, cause, http_response) class UnknownError(FirebaseError): """Unknown server error.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, UNKNOWN, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(UNKNOWN, message, cause, http_response) class InternalError(FirebaseError): """Internal server error.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, INTERNAL, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(INTERNAL, message, cause, http_response) class UnavailableError(FirebaseError): """Service unavailable. Typically the server is down.""" - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, UNAVAILABLE, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(UNAVAILABLE, message, cause, http_response) class DeadlineExceededError(FirebaseError): @@ -233,5 +354,10 @@ class DeadlineExceededError(FirebaseError): request) and the request did not finish within the deadline. """ - def __init__(self, message, cause=None, http_response=None): - FirebaseError.__init__(self, DEADLINE_EXCEEDED, message, cause, http_response) + def __init__( + self, + message: str, + cause: Optional[Exception] = None, + http_response: Optional[Union[httpx.Response, requests.Response]] = None, + ) -> None: + super().__init__(DEADLINE_EXCEEDED, message, cause, http_response) diff --git a/firebase_admin/firestore.py b/firebase_admin/firestore.py index 52ea90671..496afe237 100644 --- a/firebase_admin/firestore.py +++ b/firebase_admin/firestore.py @@ -18,27 +18,28 @@ Firebase apps. This requires the ``google-cloud-firestore`` Python module. """ -from __future__ import annotations -from typing import Optional, Dict -from firebase_admin import App +from typing import Optional + +import firebase_admin from firebase_admin import _utils try: - from google.cloud import firestore + import google.cloud.firestore + # firestore defines __all__ for safe import * + from google.cloud.firestore import * # pyright: ignore[reportWildcardImportFromLibrary] from google.cloud.firestore_v1.base_client import DEFAULT_DATABASE - existing = globals().keys() - for key, value in firestore.__dict__.items(): - if not key.startswith('_') and key not in existing: - globals()[key] = value except ImportError as error: raise ImportError('Failed to import the Cloud Firestore library for Python. Make sure ' 'to install the "google-cloud-firestore" module.') from error +__all__ = ['client'] +__all__.extend(google.cloud.firestore.__all__) # pyright: ignore[reportUnsupportedDunderAll] + _FIRESTORE_ATTRIBUTE = '_firestore' -def client(app: Optional[App] = None, database_id: Optional[str] = None) -> firestore.Client: +def client(app: Optional[firebase_admin.App] = None, database_id: Optional[str] = None) -> Client: """Returns a client that can be used to interact with Google Cloud Firestore. Args: @@ -68,11 +69,11 @@ def client(app: Optional[App] = None, database_id: Optional[str] = None) -> fire class _FirestoreService: """Service that maintains a collection of firestore clients.""" - def __init__(self, app: App) -> None: - self._app: App = app - self._clients: Dict[str, firestore.Client] = {} + def __init__(self, app: firebase_admin.App) -> None: + self._app = app + self._clients: dict[str, Client] = {} - def get_client(self, database_id: Optional[str]) -> firestore.Client: + def get_client(self, database_id: Optional[str]) -> Client: """Creates a client based on the database_id. These clients are cached.""" database_id = database_id or DEFAULT_DATABASE if database_id not in self._clients: @@ -85,7 +86,7 @@ def get_client(self, database_id: Optional[str]) -> firestore.Client: 'or use service account credentials. Alternatively, set the ' 'GOOGLE_CLOUD_PROJECT environment variable.') - fs_client = firestore.Client( + fs_client = Client( credentials=credentials, project=project, database=database_id) self._clients[database_id] = fs_client diff --git a/firebase_admin/firestore_async.py b/firebase_admin/firestore_async.py index 4a197e9df..71694adb9 100644 --- a/firebase_admin/firestore_async.py +++ b/firebase_admin/firestore_async.py @@ -18,27 +18,31 @@ associated with Firebase apps. This requires the ``google-cloud-firestore`` Python module. """ -from __future__ import annotations -from typing import Optional, Dict -from firebase_admin import App +from typing import Optional + +import firebase_admin from firebase_admin import _utils try: - from google.cloud import firestore + import google.cloud.firestore + # firestore defines __all__ for safe import * + from google.cloud.firestore import * # pyright: ignore[reportWildcardImportFromLibrary] from google.cloud.firestore_v1.base_client import DEFAULT_DATABASE - existing = globals().keys() - for key, value in firestore.__dict__.items(): - if not key.startswith('_') and key not in existing: - globals()[key] = value except ImportError as error: raise ImportError('Failed to import the Cloud Firestore library for Python. Make sure ' 'to install the "google-cloud-firestore" module.') from error +__all__ = ['client'] +__all__.extend(google.cloud.firestore.__all__) # pyright: ignore[reportUnsupportedDunderAll] + -_FIRESTORE_ASYNC_ATTRIBUTE: str = '_firestore_async' +_FIRESTORE_ASYNC_ATTRIBUTE = '_firestore_async' -def client(app: Optional[App] = None, database_id: Optional[str] = None) -> firestore.AsyncClient: +def client( + app: Optional[firebase_admin.App] = None, + database_id: Optional[str] = None, +) -> AsyncClient: """Returns an async client that can be used to interact with Google Cloud Firestore. Args: @@ -68,11 +72,11 @@ def client(app: Optional[App] = None, database_id: Optional[str] = None) -> fire class _FirestoreAsyncService: """Service that maintains a collection of firestore async clients.""" - def __init__(self, app: App) -> None: - self._app: App = app - self._clients: Dict[str, firestore.AsyncClient] = {} + def __init__(self, app: firebase_admin.App) -> None: + self._app = app + self._clients: dict[str, AsyncClient] = {} - def get_client(self, database_id: Optional[str]) -> firestore.AsyncClient: + def get_client(self, database_id: Optional[str]) -> AsyncClient: """Creates an async client based on the database_id. These clients are cached.""" database_id = database_id or DEFAULT_DATABASE if database_id not in self._clients: @@ -85,7 +89,7 @@ def get_client(self, database_id: Optional[str]) -> firestore.AsyncClient: 'or use service account credentials. Alternatively, set the ' 'GOOGLE_CLOUD_PROJECT environment variable.') - fs_client = firestore.AsyncClient( + fs_client = AsyncClient( credentials=credentials, project=project, database=database_id) self._clients[database_id] = fs_client diff --git a/firebase_admin/functions.py b/firebase_admin/functions.py index 6db0fbb42..473d6985e 100644 --- a/firebase_admin/functions.py +++ b/firebase_admin/functions.py @@ -14,27 +14,32 @@ """Firebase Functions module.""" -from __future__ import annotations -from datetime import datetime, timedelta, timezone -from urllib import parse -import re +import base64 +import dataclasses +import datetime import json -from base64 import b64encode -from typing import Any, Optional, Dict -from dataclasses import dataclass -from google.auth.compute_engine import Credentials as ComputeEngineCredentials +import re +from urllib import parse +from typing import Any, Optional, cast +from typing_extensions import TypeGuard + import requests +from google.auth.credentials import Credentials as GoogleAuthCredentials +from google.auth.compute_engine import Credentials as ComputeEngineCredentials + import firebase_admin -from firebase_admin import App from firebase_admin import _http_client from firebase_admin import _utils +from firebase_admin import exceptions _FUNCTIONS_ATTRIBUTE = '_functions' __all__ = [ + 'Resource', + 'Task', 'TaskOptions', - + 'TaskQueue', 'task_queue', ] @@ -54,14 +59,14 @@ # Default canonical location ID of the task queue. _DEFAULT_LOCATION = 'us-central1' -def _get_functions_service(app) -> _FunctionsService: +def _get_functions_service(app: Optional[firebase_admin.App]) -> '_FunctionsService': return _utils.get_app_service(app, _FUNCTIONS_ATTRIBUTE, _FunctionsService) def task_queue( - function_name: str, - extension_id: Optional[str] = None, - app: Optional[App] = None - ) -> TaskQueue: + function_name: str, + extension_id: Optional[str] = None, + app: Optional[firebase_admin.App] = None, +) -> 'TaskQueue': """Creates a reference to a TaskQueue for a given function name. The function name can be either: @@ -89,9 +94,10 @@ def task_queue( """ return _get_functions_service(app).task_queue(function_name, extension_id) + class _FunctionsService: """Service class that implements Firebase Functions functionality.""" - def __init__(self, app: App): + def __init__(self, app: firebase_admin.App) -> None: self._project_id = app.project_id if not self._project_id: raise ValueError( @@ -102,28 +108,27 @@ def __init__(self, app: App): self._credential = app.credential.get_credential() self._http_client = _http_client.JsonHttpClient(credential=self._credential) - def task_queue(self, function_name: str, extension_id: Optional[str] = None) -> TaskQueue: + def task_queue(self, function_name: str, extension_id: Optional[str] = None) -> 'TaskQueue': """Creates a TaskQueue instance.""" return TaskQueue( function_name, extension_id, self._project_id, self._credential, self._http_client) @classmethod - def handle_functions_error(cls, error: Any): + def handle_functions_error(cls, error: requests.RequestException) -> exceptions.FirebaseError: """Handles errors received from the Cloud Functions API.""" - return _utils.handle_platform_error_from_requests(error) + class TaskQueue: """TaskQueue class that implements Firebase Cloud Tasks Queues functionality.""" def __init__( - self, - function_name: str, - extension_id: Optional[str], - project_id, - credential, - http_client - ) -> None: - + self, + function_name: str, + extension_id: Optional[str], + project_id: Optional[str], + credential: GoogleAuthCredentials, + http_client: _http_client.HttpClient[dict[str, Any]], + ) -> None: # Validate function_name _Validators.check_non_empty_string('function_name', function_name) @@ -144,8 +149,7 @@ def __init__( _Validators.check_non_empty_string('extension_id', self._extension_id) self._resource.resource_id = f'ext-{self._extension_id}-{self._resource.resource_id}' - - def enqueue(self, task_data: Any, opts: Optional[TaskOptions] = None) -> str: + def enqueue(self, task_data: Any, opts: Optional['TaskOptions'] = None) -> str: """Creates a task and adds it to the queue. Tasks cannot be updated after creation. This action requires `cloudtasks.tasks.create` IAM permission on the service account. @@ -172,7 +176,7 @@ def enqueue(self, task_data: Any, opts: Optional[TaskOptions] = None) -> str: headers=_FUNCTIONS_HEADERS, json={'task': task_payload.__dict__} ) - task_name = resp.get('name', None) + task_name = cast(str, resp['name']) task_resource = \ self._parse_resource_name(task_name, f'queues/{self._resource.resource_id}/tasks') return task_resource.resource_id @@ -203,8 +207,7 @@ def delete(self, task_id: str) -> None: except requests.exceptions.RequestException as error: raise _FunctionsService.handle_functions_error(error) - - def _parse_resource_name(self, resource_name: str, resource_id_key: str) -> Resource: + def _parse_resource_name(self, resource_name: str, resource_id_key: str) -> 'Resource': """Parses a full or partial resource path into a ``Resource``.""" if '/' not in resource_name: return Resource(resource_id=resource_name) @@ -215,7 +218,7 @@ def _parse_resource_name(self, resource_name: str, resource_id_key: str) -> Reso raise ValueError('Invalid resource name format.') return Resource(project_id=match[2], location_id=match[3], resource_id=match[4]) - def _get_url(self, resource: Resource, url_format: str) -> str: + def _get_url(self, resource: 'Resource', url_format: str) -> str: """Generates url path from a ``Resource`` and url format string.""" return url_format.format( project_id=resource.project_id, @@ -223,18 +226,18 @@ def _get_url(self, resource: Resource, url_format: str) -> str: resource_id=resource.resource_id) def _validate_task_options( - self, - data: Any, - resource: Resource, - opts: Optional[TaskOptions] = None - ) -> Task: + self, + data: dict[str, Any], + resource: 'Resource', + opts: Optional['TaskOptions'] = None, + ) -> 'Task': """Validate and create a Task from optional ``TaskOptions``.""" task_http_request = { 'url': '', 'oidc_token': { 'service_account_email': '' }, - 'body': b64encode(json.dumps(data).encode()).decode(), + 'body': base64.b64encode(json.dumps(data).encode()).decode(), 'headers': { 'Content-Type': 'application/json', } @@ -248,15 +251,15 @@ def _validate_task_options( raise ValueError( 'Both sechdule_delay_seconds and schedule_time cannot be set at the same time.') if opts.schedule_time is not None and opts.schedule_delay_seconds is None: - if not isinstance(opts.schedule_time, datetime): + if not isinstance(opts.schedule_time, datetime.datetime): raise ValueError('schedule_time should be UTC datetime.') task.schedule_time = opts.schedule_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ') if opts.schedule_delay_seconds is not None and opts.schedule_time is None: if not isinstance(opts.schedule_delay_seconds, int) \ or opts.schedule_delay_seconds < 0: raise ValueError('schedule_delay_seconds should be positive int.') - schedule_time = ( - datetime.now(timezone.utc) + timedelta(seconds=opts.schedule_delay_seconds)) + schedule_time = datetime.datetime.now(datetime.timezone.utc) + \ + datetime.timedelta(seconds=opts.schedule_delay_seconds) task.schedule_time = schedule_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ') if opts.dispatch_deadline_seconds is not None: if not isinstance(opts.dispatch_deadline_seconds, int) \ @@ -280,7 +283,12 @@ def _validate_task_options( task.http_request['url'] = opts.uri return task - def _update_task_payload(self, task: Task, resource: Resource, extension_id: str) -> Task: + def _update_task_payload( + self, + task: 'Task', + resource: 'Resource', + extension_id: Optional[str], + ) -> 'Task': """Prepares task to be sent with credentials.""" # Get function url from task or generate from resources if not _Validators.is_non_empty_string(task.http_request['url']): @@ -290,21 +298,22 @@ def _update_task_payload(self, task: Task, resource: Resource, extension_id: str if _Validators.is_non_empty_string(extension_id) and \ isinstance(self._credential, ComputeEngineCredentials): - id_token = self._credential.token + id_token = cast(str, self._credential.token) task.http_request['headers'] = \ {**task.http_request['headers'], 'Authorization': f'Bearer ${id_token}'} # Delete oidc token del task.http_request['oidc_token'] else: + # possible issue: _credential needs more specific annotation task.http_request['oidc_token'] = \ - {'service_account_email': self._credential.service_account_email} + {'service_account_email': self._credential.service_account_email} # pyright: ignore[reportAttributeAccessIssue] return task class _Validators: """A collection of data validation utilities.""" @classmethod - def check_non_empty_string(cls, label: str, value: Any): + def check_non_empty_string(cls, label: str, value: Any) -> None: """Checks if given value is a non-empty string and throws error if not.""" if not isinstance(value, str): raise ValueError(f'{label} "{value}" must be a string.') @@ -312,14 +321,14 @@ def check_non_empty_string(cls, label: str, value: Any): raise ValueError(f'{label} "{value}" must be a non-empty string.') @classmethod - def is_non_empty_string(cls, value: Any): + def is_non_empty_string(cls, value: Any) -> TypeGuard[str]: """Checks if given value is a non-empty string and returns bool.""" if not isinstance(value, str) or value == '': return False return True @classmethod - def is_task_id(cls, task_id: Any): + def is_task_id(cls, task_id: str) -> bool: """Checks if given value is a valid task id.""" reg = '^[A-Za-z0-9_-]+$' if re.match(reg, task_id) is not None and len(task_id) <= 500: @@ -327,7 +336,7 @@ def is_task_id(cls, task_id: Any): return False @classmethod - def is_url(cls, url: Any): + def is_url(cls, url: Any) -> TypeGuard[str]: """Checks if given value is a valid url.""" if not isinstance(url, str): return False @@ -340,7 +349,7 @@ def is_url(cls, url: Any): return False -@dataclass +@dataclasses.dataclass class TaskOptions: """Task Options that can be applied to a Task. @@ -399,13 +408,14 @@ class TaskOptions: http URL. """ schedule_delay_seconds: Optional[int] = None - schedule_time: Optional[datetime] = None + schedule_time: Optional[datetime.datetime] = None dispatch_deadline_seconds: Optional[int] = None task_id: Optional[str] = None - headers: Optional[Dict[str, str]] = None + headers: Optional[dict[str, str]] = None uri: Optional[str] = None -@dataclass + +@dataclasses.dataclass class Task: """Contains the relevant fields for enqueueing tasks that trigger Cloud Functions. @@ -419,13 +429,13 @@ class Task: schedule_time: The time when the task is scheduled to be attempted or retried. dispatch_deadline: The deadline for requests sent to the worker. """ - http_request: Dict[str, Optional[str | dict]] + http_request: dict[str, Any] name: Optional[str] = None schedule_time: Optional[str] = None dispatch_deadline: Optional[str] = None -@dataclass +@dataclasses.dataclass class Resource: """Contains the parsed address of a resource. diff --git a/firebase_admin/instance_id.py b/firebase_admin/instance_id.py index 812daf40b..8c57eb1dc 100644 --- a/firebase_admin/instance_id.py +++ b/firebase_admin/instance_id.py @@ -16,22 +16,25 @@ This module enables deleting instance IDs associated with Firebase projects. """ +from typing import Optional import requests +import firebase_admin from firebase_admin import _http_client from firebase_admin import _utils +__all__ = ('delete_instance_id',) _IID_SERVICE_URL = 'https://console.firebase.google.com/v1/' _IID_ATTRIBUTE = '_iid' -def _get_iid_service(app): +def _get_iid_service(app: Optional[firebase_admin.App]) -> '_InstanceIdService': return _utils.get_app_service(app, _IID_ATTRIBUTE, _InstanceIdService) -def delete_instance_id(instance_id, app=None): +def delete_instance_id(instance_id: str, app: Optional[firebase_admin.App] = None) -> None: """Deletes the specified instance ID and the associated data from Firebase. Note that Google Analytics for Firebase uses its own form of Instance ID to @@ -55,7 +58,7 @@ def delete_instance_id(instance_id, app=None): class _InstanceIdService: """Provides methods for interacting with the remote instance ID service.""" - error_codes = { + error_codes: dict[int, str] = { 400: 'Malformed instance ID argument.', 401: 'Request not authorized.', 403: 'Project does not match instance ID or the client does not have ' @@ -67,7 +70,7 @@ class _InstanceIdService: 503: 'Backend servers are over capacity. Try again later.' } - def __init__(self, app): + def __init__(self, app: firebase_admin.App) -> None: project_id = app.project_id if not project_id: raise ValueError( @@ -78,7 +81,7 @@ def __init__(self, app): self._client = _http_client.JsonHttpClient( credential=app.credential.get_credential(), base_url=_IID_SERVICE_URL) - def delete_instance_id(self, instance_id): + def delete_instance_id(self, instance_id: str) -> None: if not isinstance(instance_id, str) or not instance_id: raise ValueError('Instance ID must be a non-empty string.') path = f'project/{self._project_id}/instanceId/{instance_id}' @@ -88,7 +91,7 @@ def delete_instance_id(self, instance_id): msg = self._extract_message(instance_id, error) raise _utils.handle_requests_error(error, msg) - def _extract_message(self, instance_id, error): + def _extract_message(self, instance_id: str, error: requests.RequestException) -> Optional[str]: if error.response is None: return None status = error.response.status_code diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 749044436..4a9c4503c 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -14,14 +14,15 @@ """Firebase Cloud Messaging module.""" -from __future__ import annotations -from typing import Any, Callable, Dict, List, Optional, cast import concurrent.futures import json -import asyncio import logging -import requests +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Optional, Union, cast + +import asyncio import httpx +import requests import firebase_admin from firebase_admin import ( @@ -29,13 +30,11 @@ _messaging_encoder, _messaging_utils, _utils, - exceptions, - App + exceptions ) -logger = logging.getLogger(__name__) - -_MESSAGING_ATTRIBUTE = '_messaging' +if TYPE_CHECKING: + import httplib2 __all__ = [ @@ -75,6 +74,10 @@ 'unsubscribe_from_topic', ] +logger = logging.getLogger(__name__) + +_MESSAGING_ATTRIBUTE = '_messaging' + AndroidConfig = _messaging_utils.AndroidConfig AndroidFCMOptions = _messaging_utils.AndroidFCMOptions @@ -101,10 +104,11 @@ UnregisteredError = _messaging_utils.UnregisteredError -def _get_messaging_service(app: Optional[App]) -> _MessagingService: +def _get_messaging_service(app: Optional[firebase_admin.App]) -> '_MessagingService': return _utils.get_app_service(app, _MESSAGING_ATTRIBUTE, _MessagingService) -def send(message: Message, dry_run: bool = False, app: Optional[App] = None) -> str: + +def send(message: Message, dry_run: bool = False, app: Optional[firebase_admin.App] = None) -> str: """Sends the given message via Firebase Cloud Messaging (FCM). If the ``dry_run`` mode is enabled, the message will not be actually delivered to the @@ -124,11 +128,12 @@ def send(message: Message, dry_run: bool = False, app: Optional[App] = None) -> """ return _get_messaging_service(app).send(message, dry_run) + def send_each( - messages: List[Message], - dry_run: bool = False, - app: Optional[App] = None - ) -> BatchResponse: + messages: list[Message], + dry_run: bool = False, + app: Optional[firebase_admin.App] = None, +) -> 'BatchResponse': """Sends each message in the given list via Firebase Cloud Messaging. If the ``dry_run`` mode is enabled, the message will not be actually delivered to the @@ -148,11 +153,12 @@ def send_each( """ return _get_messaging_service(app).send_each(messages, dry_run) + async def send_each_async( - messages: List[Message], - dry_run: bool = False, - app: Optional[App] = None - ) -> BatchResponse: + messages: list[Message], + dry_run: bool = False, + app: Optional[firebase_admin.App] = None, +) -> 'BatchResponse': """Sends each message in the given list asynchronously via Firebase Cloud Messaging. If the ``dry_run`` mode is enabled, the message will not be actually delivered to the @@ -172,11 +178,12 @@ async def send_each_async( """ return await _get_messaging_service(app).send_each_async(messages, dry_run) + async def send_each_for_multicast_async( - multicast_message: MulticastMessage, - dry_run: bool = False, - app: Optional[App] = None - ) -> BatchResponse: + multicast_message: MulticastMessage, + dry_run: bool = False, + app: Optional[firebase_admin.App] = None, +) -> 'BatchResponse': """Sends the given mutlicast message to each token asynchronously via Firebase Cloud Messaging (FCM). @@ -208,7 +215,12 @@ async def send_each_for_multicast_async( ) for token in multicast_message.tokens] return await _get_messaging_service(app).send_each_async(messages, dry_run) -def send_each_for_multicast(multicast_message, dry_run=False, app=None): + +def send_each_for_multicast( + multicast_message: MulticastMessage, + dry_run: bool = False, + app: Optional[firebase_admin.App] = None, +) -> 'BatchResponse': """Sends the given mutlicast message to each token via Firebase Cloud Messaging (FCM). If the ``dry_run`` mode is enabled, the message will not be actually delivered to the @@ -239,7 +251,12 @@ def send_each_for_multicast(multicast_message, dry_run=False, app=None): ) for token in multicast_message.tokens] return _get_messaging_service(app).send_each(messages, dry_run) -def subscribe_to_topic(tokens, topic, app=None): + +def subscribe_to_topic( + tokens: Union[list[str], str], + topic: str, + app: Optional[firebase_admin.App] = None, +) -> 'TopicManagementResponse': """Subscribes a list of registration tokens to an FCM topic. Args: @@ -258,7 +275,12 @@ def subscribe_to_topic(tokens, topic, app=None): return _get_messaging_service(app).make_topic_management_request( tokens, topic, 'iid/v1:batchAdd') -def unsubscribe_from_topic(tokens, topic, app=None): + +def unsubscribe_from_topic( + tokens: Union[list[str], str], + topic: str, + app: Optional[firebase_admin.App] = None, +) -> 'TopicManagementResponse': """Unsubscribes a list of registration tokens from an FCM topic. Args: @@ -281,17 +303,17 @@ def unsubscribe_from_topic(tokens, topic, app=None): class ErrorInfo: """An error encountered when performing a topic management operation.""" - def __init__(self, index, reason): + def __init__(self, index: int, reason: str) -> None: self._index = index self._reason = reason @property - def index(self): + def index(self) -> int: """Index of the registration token to which this error is related to.""" return self._index @property - def reason(self): + def reason(self) -> str: """String describing the nature of the error.""" return self._reason @@ -299,12 +321,12 @@ def reason(self): class TopicManagementResponse: """The response received from a topic management operation.""" - def __init__(self, resp): + def __init__(self, resp: dict[str, Any]) -> None: if not isinstance(resp, dict) or 'results' not in resp: raise ValueError(f'Unexpected topic management response: {resp}.') self._success_count = 0 self._failure_count = 0 - self._errors = [] + self._errors: list[ErrorInfo] = [] for index, result in enumerate(resp['results']): if 'error' in result: self._failure_count += 1 @@ -313,17 +335,17 @@ def __init__(self, resp): self._success_count += 1 @property - def success_count(self): + def success_count(self) -> int: """Number of tokens that were successfully subscribed or unsubscribed.""" return self._success_count @property - def failure_count(self): + def failure_count(self) -> int: """Number of tokens that could not be subscribed or unsubscribed due to errors.""" return self._failure_count @property - def errors(self): + def errors(self) ->list[ErrorInfo]: """A list of ``messaging.ErrorInfo`` objects (possibly empty).""" return self._errors @@ -331,12 +353,12 @@ def errors(self): class BatchResponse: """The response received from a batch request to the FCM API.""" - def __init__(self, responses: List[SendResponse]) -> None: + def __init__(self, responses: list['SendResponse']) -> None: self._responses = responses self._success_count = sum(1 for resp in responses if resp.success) @property - def responses(self) -> List[SendResponse]: + def responses(self) -> list['SendResponse']: """A list of ``messaging.SendResponse`` objects (possibly empty).""" return self._responses @@ -352,27 +374,32 @@ def failure_count(self) -> int: class SendResponse: """The response received from an individual batched request to the FCM API.""" - def __init__(self, resp, exception): + def __init__( + self, + resp: Optional[dict[str, Any]], + exception: Optional[exceptions.FirebaseError], + ) -> None: self._exception = exception - self._message_id = None + self._message_id: Optional[str] = None if resp: self._message_id = resp.get('name', None) @property - def message_id(self): + def message_id(self) -> Optional[str]: """A message ID string that uniquely identifies the message.""" return self._message_id @property - def success(self): + def success(self) -> bool: """A boolean indicating if the request was successful.""" return self._message_id is not None and not self._exception @property - def exception(self): + def exception(self) -> Optional[exceptions.FirebaseError]: """A ``FirebaseError`` if an error occurs while sending the message to the FCM service.""" return self._exception + class _MessagingService: """Service class that implements Firebase Cloud Messaging (FCM) functionality.""" @@ -390,7 +417,7 @@ class _MessagingService: 'UNREGISTERED': UnregisteredError, } - def __init__(self, app: App) -> None: + def __init__(self, app: firebase_admin.App) -> None: project_id = app.project_id if not project_id: raise ValueError( @@ -409,7 +436,7 @@ def __init__(self, app: App) -> None: credential=self._credential, timeout=timeout) @classmethod - def encode_message(cls, message): + def encode_message(cls, message: Message) -> dict[str, Any]: if not isinstance(message, Message): raise ValueError('Message must be an instance of messaging.Message class.') return cls.JSON_ENCODER.default(message) @@ -428,14 +455,14 @@ def send(self, message: Message, dry_run: bool = False) -> str: raise self._handle_fcm_error(error) return cast(str, resp['name']) - def send_each(self, messages: List[Message], dry_run: bool = False) -> BatchResponse: + def send_each(self, messages: list[Message], dry_run: bool = False) -> BatchResponse: """Sends the given messages to FCM via the FCM v1 API.""" if not isinstance(messages, list): raise ValueError('messages must be a list of messaging.Message instances.') if len(messages) > 500: raise ValueError('messages must not contain more than 500 elements.') - def send_data(data): + def send_data(data: dict[str, Any]) -> SendResponse: try: resp = self._client.body( 'post', @@ -456,14 +483,14 @@ def send_data(data): message=f'Unknown error while making remote service calls: {error}', cause=error) - async def send_each_async(self, messages: List[Message], dry_run: bool = True) -> BatchResponse: + async def send_each_async(self, messages: list[Message], dry_run: bool = True) -> BatchResponse: """Sends the given messages to FCM via the FCM v1 API.""" if not isinstance(messages, list): raise ValueError('messages must be a list of messaging.Message instances.') if len(messages) > 500: raise ValueError('messages must not contain more than 500 elements.') - async def send_data(data): + async def send_data(data: dict[str, Any]) -> SendResponse: try: resp = await self._async_client.request( 'post', @@ -486,7 +513,12 @@ async def send_data(data): message=f'Unknown error while making remote service calls: {error}', cause=error) - def make_topic_management_request(self, tokens, topic, operation): + def make_topic_management_request( + self, + tokens: Union[list[str], str], + topic: str, + operation: str, + ) -> TopicManagementResponse: """Invokes the IID service for topic management functionality.""" if isinstance(tokens, str): tokens = [tokens] @@ -516,18 +548,18 @@ def make_topic_management_request(self, tokens, topic, operation): raise self._handle_iid_error(error) return TopicManagementResponse(resp) - def _message_data(self, message, dry_run): - data = {'message': _MessagingService.encode_message(message)} + def _message_data(self, message: Message, dry_run: bool) -> dict[str, Any]: + data: dict[str, Any] = {'message': _MessagingService.encode_message(message)} if dry_run: data['validate_only'] = True return data - def _postproc(self, _, body): + def _postproc(self, _: 'httplib2.Response', body: bytes) -> Any: """Handle response from batch API request.""" # This only gets called for 2xx responses. return json.loads(body.decode()) - def _handle_fcm_error(self, error): + def _handle_fcm_error(self, error: requests.RequestException) -> exceptions.FirebaseError: """Handles errors received from the FCM API.""" return _utils.handle_platform_error_from_requests( error, _MessagingService._build_fcm_error_requests) @@ -537,12 +569,12 @@ def _handle_fcm_httpx_error(self, error: httpx.HTTPError) -> exceptions.Firebase return _utils.handle_platform_error_from_httpx( error, _MessagingService._build_fcm_error_httpx) - def _handle_iid_error(self, error): + def _handle_iid_error(self, error: requests.RequestException) -> exceptions.FirebaseError: """Handles errors received from the Instance ID API.""" if error.response is None: raise _utils.handle_requests_error(error) - data = {} + data: dict[str, Any] = {} try: parsed_body = error.response.json() if isinstance(parsed_body, dict): @@ -567,41 +599,45 @@ def close(self) -> None: asyncio.run(self._async_client.aclose()) @classmethod - def _build_fcm_error_requests(cls, error, message, error_dict): + def _build_fcm_error_requests( + cls, + error: requests.RequestException, + message: str, + error_dict: dict[str, Any], + ) -> Optional[exceptions.FirebaseError]: """Parses an error response from the FCM API and creates a FCM-specific exception if appropriate.""" exc_type = cls._build_fcm_error(error_dict) - # pylint: disable=not-callable - return exc_type(message, cause=error, http_response=error.response) if exc_type else None + return exc_type(message, error, error.response) if exc_type else None @classmethod def _build_fcm_error_httpx( - cls, - error: httpx.HTTPError, - message: str, - error_dict: Optional[Dict[str, Any]] - ) -> Optional[exceptions.FirebaseError]: + cls, + error: httpx.HTTPError, + message: str, + error_dict: Optional[dict[str, Any]], + ) -> Optional[exceptions.FirebaseError]: """Parses a httpx error response from the FCM API and creates a FCM-specific exception if appropriate.""" exc_type = cls._build_fcm_error(error_dict) if isinstance(error, httpx.HTTPStatusError): - # pylint: disable=not-callable - return exc_type( - message, cause=error, http_response=error.response) if exc_type else None - # pylint: disable=not-callable - return exc_type(message, cause=error) if exc_type else None + return exc_type(message, error, error.response) if exc_type else None + return exc_type(message, error, None) if exc_type else None @classmethod def _build_fcm_error( - cls, - error_dict: Optional[Dict[str, Any]] - ) -> Optional[Callable[..., exceptions.FirebaseError]]: + cls, + error_dict: Optional[dict[str, Any]], + ) -> Optional[Callable[ + [str, Optional[Exception], Optional[Union[httpx.Response, requests.Response]]], + exceptions.FirebaseError + ]]: """Parses an error response to determine the appropriate FCM-specific error type.""" if not error_dict: return None fcm_code = None for detail in error_dict.get('details', []): if detail.get('@type') == 'type.googleapis.com/google.firebase.fcm.v1.FcmError': - fcm_code = detail.get('errorCode') + fcm_code = cast(str, detail.get('errorCode')) break return _MessagingService.FCM_ERROR_TYPES.get(fcm_code) if fcm_code else None diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index 3a77dd05f..38b2f69af 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -18,12 +18,13 @@ deleting, publishing and unpublishing Firebase ML models. """ - import datetime +import os import re import time -import os -from urllib import parse +import urllib.parse +from collections.abc import Callable, Iterator +from typing import TYPE_CHECKING, Any, Optional, Union, cast import requests @@ -32,19 +33,46 @@ from firebase_admin import _utils from firebase_admin import exceptions -# pylint: disable=import-error,no-member -try: +if TYPE_CHECKING: + import tensorflow as tf + from _typeshed import Incomplete from firebase_admin import storage - _GCS_ENABLED = True -except ImportError: - _GCS_ENABLED = False -# pylint: disable=import-error,no-member -try: - import tensorflow as tf - _TF_ENABLED = True -except ImportError: - _TF_ENABLED = False + _GCS_ENABLED: bool + _TF_ENABLED: bool +else: + Incomplete = Any + + # pylint: disable=import-error,no-member + try: + from firebase_admin import storage + _GCS_ENABLED = True + except ImportError: + _GCS_ENABLED = False + + # pylint: disable=import-error,no-member + try: + import tensorflow as tf + _TF_ENABLED = True + except ImportError: + _TF_ENABLED = False + +__all__ = ( + 'ListModelsPage', + 'Model', + 'ModelFormat', + 'TFLiteFormat', + 'TFLiteGCSModelSource', + 'TFLiteModelSource', + 'create_model', + 'delete_model', + 'get_model', + 'list_models', + 'publish_model', + 'unpublish_model', + 'update_model', +) + _ML_ATTRIBUTE = '_ml' _MAX_PAGE_SIZE = 100 @@ -59,7 +87,7 @@ r'^projects/(?P[a-z0-9-]{6,30})/operations/[^/]+$') -def _get_ml_service(app): +def _get_ml_service(app: Optional[firebase_admin.App]) -> '_MLService': """ Returns an _MLService instance for an App. Args: @@ -74,7 +102,7 @@ def _get_ml_service(app): return _utils.get_app_service(app, _ML_ATTRIBUTE, _MLService) -def create_model(model, app=None): +def create_model(model: 'Model', app: Optional[firebase_admin.App] = None) -> 'Model': """Creates a model in the current Firebase project. Args: @@ -88,7 +116,7 @@ def create_model(model, app=None): return Model.from_dict(ml_service.create_model(model), app=app) -def update_model(model, app=None): +def update_model(model: 'Model', app: Optional[firebase_admin.App] = None) -> 'Model': """Updates a model's metadata or model file. Args: @@ -102,7 +130,7 @@ def update_model(model, app=None): return Model.from_dict(ml_service.update_model(model), app=app) -def publish_model(model_id, app=None): +def publish_model(model_id: str, app: Optional[firebase_admin.App] = None) -> 'Model': """Publishes a Firebase ML model. A published model can be downloaded to client apps. @@ -118,7 +146,7 @@ def publish_model(model_id, app=None): return Model.from_dict(ml_service.set_published(model_id, publish=True), app=app) -def unpublish_model(model_id, app=None): +def unpublish_model(model_id: str, app: Optional[firebase_admin.App] = None) -> 'Model': """Unpublishes a Firebase ML model. Args: @@ -132,7 +160,7 @@ def unpublish_model(model_id, app=None): return Model.from_dict(ml_service.set_published(model_id, publish=False), app=app) -def get_model(model_id, app=None): +def get_model(model_id: str, app: Optional[firebase_admin.App] = None) -> 'Model': """Gets the model specified by the given ID. Args: @@ -146,7 +174,12 @@ def get_model(model_id, app=None): return Model.from_dict(ml_service.get_model(model_id), app=app) -def list_models(list_filter=None, page_size=None, page_token=None, app=None): +def list_models( + list_filter: Optional[str] = None, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + app: Optional[firebase_admin.App] = None, +) -> 'ListModelsPage': """Lists the current project's models. Args: @@ -165,7 +198,7 @@ def list_models(list_filter=None, page_size=None, page_token=None, app=None): ml_service.list_models, list_filter, page_size, page_token, app=app) -def delete_model(model_id, app=None): +def delete_model(model_id: str, app: Optional[firebase_admin.App] = None) -> None: """Deletes a model from the current project. Args: @@ -184,10 +217,15 @@ class Model: tags: Optional list of strings associated with your model. Can be used in list queries. model_format: A subclass of ModelFormat. (e.g. TFLiteFormat) Specifies the model details. """ - def __init__(self, display_name=None, tags=None, model_format=None): - self._app = None # Only needed for wait_for_unlo - self._data = {} - self._model_format = None + def __init__( + self, + display_name: Optional[str] = None, + tags: Optional[list[str]] = None, + model_format: Optional['ModelFormat'] = None, + ) -> None: + self._app: Optional[firebase_admin.App] = None # Only needed for wait_for_unlo + self._data: dict[str, Any] = {} + self._model_format: Optional[ModelFormat] = None if display_name is not None: self.display_name = display_name @@ -197,7 +235,11 @@ def __init__(self, display_name=None, tags=None, model_format=None): self.model_format = model_format @classmethod - def from_dict(cls, data, app=None): + def from_dict( + cls, + data: dict[str, Any], + app: Optional[firebase_admin.App] = None, + ) -> 'Model': """Create an instance of the object from a dict.""" data_copy = dict(data) tflite_format = None @@ -210,22 +252,22 @@ def from_dict(cls, data, app=None): model._app = app # pylint: disable=protected-access return model - def _update_from_dict(self, data): + def _update_from_dict(self, data: dict[str, Any]) -> None: copy = Model.from_dict(data) self.model_format = copy.model_format self._data = copy._data # pylint: disable=protected-access - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, self.__class__): # pylint: disable=protected-access return self._data == other._data and self._model_format == other._model_format return False - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) @property - def model_id(self): + def model_id(self) -> Optional[str]: """The model's ID, unique to the project.""" if not self._data.get('name'): return None @@ -233,74 +275,72 @@ def model_id(self): return model_id @property - def display_name(self): + def display_name(self) -> Optional[str]: """The model's display name, used to refer to the model in code and in the Firebase console.""" return self._data.get('displayName') @display_name.setter - def display_name(self, display_name): + def display_name(self, display_name: str) -> None: self._data['displayName'] = _validate_display_name(display_name) - return self @staticmethod - def _convert_to_millis(date_string): + def _convert_to_millis(date_string: Optional[str]) -> Optional[int]: if not date_string: return None format_str = '%Y-%m-%dT%H:%M:%S.%fZ' - epoch = datetime.datetime.utcfromtimestamp(0) - datetime_object = datetime.datetime.strptime(date_string, format_str) + epoch = datetime.datetime.fromtimestamp(0, datetime.timezone.utc) + datetime_object = datetime.datetime.strptime(date_string, format_str).replace(tzinfo=datetime.timezone.utc) millis = int((datetime_object - epoch).total_seconds() * 1000) return millis @property - def create_time(self): + def create_time(self) -> Optional[int]: """The time the model was created.""" return Model._convert_to_millis(self._data.get('createTime', None)) @property - def update_time(self): + def update_time(self) -> Optional[int]: """The time the model was last updated.""" return Model._convert_to_millis(self._data.get('updateTime', None)) @property - def validation_error(self): + def validation_error(self) -> Optional[str]: """Validation error message.""" return self._data.get('state', {}).get('validationError', {}).get('message') @property - def published(self): + def published(self) -> bool: """True if the model is published and available for clients to download.""" return bool(self._data.get('state', {}).get('published')) @property - def etag(self): + def etag(self) -> Optional[Incomplete]: """The entity tag (ETag) of the model resource.""" return self._data.get('etag') @property - def model_hash(self): + def model_hash(self) -> Optional[Incomplete]: """SHA256 hash of the model binary.""" return self._data.get('modelHash') @property - def tags(self): + def tags(self) -> Optional[list[str]]: """Tag strings, used for filtering query results.""" return self._data.get('tags') @tags.setter - def tags(self, tags): + def tags(self, tags: list[str]) -> None: self._data['tags'] = _validate_tags(tags) - return self @property - def locked(self): + def locked(self) -> bool: """True if the Model object is locked by an active operation.""" return bool(self._data.get('activeOperations') and - len(self._data.get('activeOperations')) > 0) + len(self._data['activeOperations']) > 0) - def wait_for_unlocked(self, max_time_seconds=None): + def wait_for_unlocked(self, max_time_seconds: Optional[float] = None) -> None: """Waits for the model to be unlocked. (All active operations complete) Args: @@ -313,7 +353,7 @@ def wait_for_unlocked(self, max_time_seconds=None): if not self.locked: return ml_service = _get_ml_service(self._app) - op_name = self._data.get('activeOperations')[0].get('name') + op_name = self._data['activeOperations'][0].get('name') model_dict = ml_service.handle_operation( ml_service.get_operation(op_name), wait_for_operation=True, @@ -321,19 +361,18 @@ def wait_for_unlocked(self, max_time_seconds=None): self._update_from_dict(model_dict) @property - def model_format(self): + def model_format(self) -> Optional['ModelFormat']: """The model's ``ModelFormat`` object, which represents the model's format and storage location.""" return self._model_format @model_format.setter - def model_format(self, model_format): + def model_format(self, model_format: Optional['ModelFormat']) -> None: if model_format is not None: _validate_model_format(model_format) self._model_format = model_format #Can be None - return self - def as_dict(self, for_upload=False): + def as_dict(self, for_upload: bool = False) -> dict[str, Any]: """Returns a serializable representation of the object.""" copy = dict(self._data) if self._model_format: @@ -343,7 +382,7 @@ def as_dict(self, for_upload=False): class ModelFormat: """Abstract base class representing a Model Format such as TFLite.""" - def as_dict(self, for_upload=False): + def as_dict(self, for_upload: bool = False) -> dict[str, Any]: """Returns a serializable representation of the object.""" raise NotImplementedError @@ -354,32 +393,32 @@ class TFLiteFormat(ModelFormat): Args: model_source: A TFLiteModelSource sub class. Specifies the details of the model source. """ - def __init__(self, model_source=None): - self._data = {} - self._model_source = None + def __init__(self, model_source: Optional['TFLiteModelSource'] = None) -> None: + self._data: dict[str, Any] = {} + self._model_source: Optional[TFLiteModelSource] = None if model_source is not None: self.model_source = model_source @classmethod - def from_dict(cls, data): + def from_dict(cls, data: dict[str, Any]) -> 'TFLiteFormat': """Create an instance of the object from a dict.""" data_copy = dict(data) tflite_format = TFLiteFormat(model_source=cls._init_model_source(data_copy)) tflite_format._data = data_copy # pylint: disable=protected-access return tflite_format - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, self.__class__): # pylint: disable=protected-access return self._data == other._data and self._model_source == other._model_source return False - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) @staticmethod - def _init_model_source(data): + def _init_model_source(data: dict[str, Any]) -> Optional['TFLiteModelSource']: """Initialize the ML model source.""" gcs_tflite_uri = data.pop('gcsTfliteUri', None) if gcs_tflite_uri: @@ -387,23 +426,23 @@ def _init_model_source(data): return None @property - def model_source(self): + def model_source(self) -> Optional['TFLiteModelSource']: """The TF Lite model's location.""" return self._model_source @model_source.setter - def model_source(self, model_source): + def model_source(self, model_source: Optional['TFLiteModelSource']) -> None: if model_source is not None: if not isinstance(model_source, TFLiteModelSource): raise TypeError('Model source must be a TFLiteModelSource object.') self._model_source = model_source # Can be None @property - def size_bytes(self): + def size_bytes(self) -> Optional[Incomplete]: """The size in bytes of the TF Lite model.""" return self._data.get('sizeBytes') - def as_dict(self, for_upload=False): + def as_dict(self, for_upload: bool = False) -> dict[str, Any]: """Returns a serializable representation of the object.""" copy = dict(self._data) if self._model_source: @@ -413,7 +452,7 @@ def as_dict(self, for_upload=False): class TFLiteModelSource: """Abstract base class representing a model source for TFLite format models.""" - def as_dict(self, for_upload=False): + def as_dict(self, for_upload: bool = False) -> dict[str, Any]: """Returns a serializable representation of the object.""" raise NotImplementedError @@ -425,13 +464,13 @@ class _CloudStorageClient: BLOB_NAME = 'Firebase/ML/Models/{0}' @staticmethod - def _assert_gcs_enabled(): + def _assert_gcs_enabled() -> None: if not _GCS_ENABLED: raise ImportError('Failed to import the Cloud Storage library for Python. Make sure ' 'to install the "google-cloud-storage" module.') @staticmethod - def _parse_gcs_tflite_uri(uri): + def _parse_gcs_tflite_uri(uri: str) -> tuple[str, str]: # GCS Bucket naming rules are complex. The regex is not comprehensive. # See https://cloud.google.com/storage/docs/naming for full details. matcher = _GCS_TFLITE_URI_PATTERN.match(uri) @@ -440,10 +479,13 @@ def _parse_gcs_tflite_uri(uri): return matcher.group('bucket_name'), matcher.group('blob_name') @staticmethod - def upload(bucket_name, model_file_name, app): + def upload( + bucket_name: Optional[str], + model_file_name: Union[str, os.PathLike[str]], + app: Optional[firebase_admin.App], + ) -> str: """Upload a model file to the specified Storage bucket.""" _CloudStorageClient._assert_gcs_enabled() - file_name = os.path.basename(model_file_name) bucket = storage.bucket(bucket_name, app=app) blob_name = _CloudStorageClient.BLOB_NAME.format(file_name) @@ -452,7 +494,7 @@ def upload(bucket_name, model_file_name, app): return _CloudStorageClient.GCS_URI.format(bucket.name, blob_name) @staticmethod - def sign_uri(gcs_tflite_uri, app): + def sign_uri(gcs_tflite_uri: str, app: Optional[firebase_admin.App]) -> str: """Makes the gcs_tflite_uri readable for GET for 10 minutes via signed_uri.""" _CloudStorageClient._assert_gcs_enabled() bucket_name, blob_name = _CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri) @@ -470,20 +512,29 @@ class TFLiteGCSModelSource(TFLiteModelSource): _STORAGE_CLIENT = _CloudStorageClient() - def __init__(self, gcs_tflite_uri, app=None): + def __init__( + self, + gcs_tflite_uri: str, + app: Optional[firebase_admin.App] = None, + ) -> None: self._app = app self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, self.__class__): return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access return False - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) @classmethod - def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None): + def from_tflite_model_file( + cls, + model_file_name: Union[str, os.PathLike[str]], + bucket_name: Optional[str] = None, + app: Optional[firebase_admin.App] = None, + ) -> 'TFLiteGCSModelSource': """Uploads the model file to an existing Google Cloud Storage bucket. Args: @@ -502,7 +553,7 @@ def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None): return TFLiteGCSModelSource(gcs_tflite_uri=gcs_uri, app=app) @staticmethod - def _assert_tf_enabled(): + def _assert_tf_enabled() -> None: if not _TF_ENABLED: raise ImportError('Failed to import the tensorflow library for Python. Make sure ' 'to install the tensorflow module.') @@ -511,13 +562,13 @@ def _assert_tf_enabled(): f'Expected tensorflow version 1.x or 2.x, but found {tf.version.VERSION}') @staticmethod - def _tf_convert_from_saved_model(saved_model_dir): + def _tf_convert_from_saved_model(saved_model_dir: Incomplete) -> Incomplete: # Same for both v1.x and v2.x converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) return converter.convert() @staticmethod - def _tf_convert_from_keras_model(keras_model): + def _tf_convert_from_keras_model(keras_model: Incomplete) -> Incomplete: """Converts the given Keras model into a TF Lite model.""" # Version 1.x conversion function takes a model file. Version 2.x takes the model itself. if tf.version.VERSION.startswith('1.'): @@ -530,8 +581,13 @@ def _tf_convert_from_keras_model(keras_model): return converter.convert() @classmethod - def from_saved_model(cls, saved_model_dir, model_file_name='firebase_ml_model.tflite', - bucket_name=None, app=None): + def from_saved_model( + cls, + saved_model_dir: Incomplete, + model_file_name: Union[str, os.PathLike[str]] = 'firebase_ml_model.tflite', + bucket_name: Optional[str] = None, + app: Optional[firebase_admin.App] = None, + ) -> 'TFLiteGCSModelSource': """Creates a Tensor Flow Lite model from the saved model, and uploads the model to GCS. Args: @@ -554,8 +610,13 @@ def from_saved_model(cls, saved_model_dir, model_file_name='firebase_ml_model.tf return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app) @classmethod - def from_keras_model(cls, keras_model, model_file_name='firebase_ml_model.tflite', - bucket_name=None, app=None): + def from_keras_model( + cls, + keras_model: os.PathLike[str], + model_file_name: str = 'firebase_ml_model.tflite', + bucket_name: Optional[str] = None, + app: Optional[firebase_admin.App] = None, + ) -> 'TFLiteGCSModelSource': """Creates a Tensor Flow Lite model from the keras model, and uploads the model to GCS. Args: @@ -578,25 +639,26 @@ def from_keras_model(cls, keras_model, model_file_name='firebase_ml_model.tflite return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app) @property - def gcs_tflite_uri(self): + def gcs_tflite_uri(self) -> str: """URI of the model file in Cloud Storage.""" return self._gcs_tflite_uri @gcs_tflite_uri.setter - def gcs_tflite_uri(self, gcs_tflite_uri): + def gcs_tflite_uri(self, gcs_tflite_uri: str) -> None: self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri) - def _get_signed_gcs_tflite_uri(self): + def _get_signed_gcs_tflite_uri(self) -> str: """Signs the GCS uri, so the model file can be uploaded to Firebase ML and verified.""" return TFLiteGCSModelSource._STORAGE_CLIENT.sign_uri(self._gcs_tflite_uri, self._app) - def as_dict(self, for_upload=False): + def as_dict(self, for_upload: bool = False) -> dict[str, Any]: """Returns a serializable representation of the object.""" if for_upload: return {'gcsTfliteUri': self._get_signed_gcs_tflite_uri()} return {'gcsTfliteUri': self._gcs_tflite_uri} + class ListModelsPage: """Represents a page of models in a Firebase project. @@ -605,7 +667,17 @@ class ListModelsPage: ``iterate_all()`` can be used to iterate through all the models in the Firebase project starting from this page. """ - def __init__(self, list_models_func, list_filter, page_size, page_token, app): + def __init__( + self, + list_models_func: Callable[ + [Optional[str], Optional[int], Optional[str]], + dict[str, Any] + ], + list_filter: Optional[str], + page_size: Optional[int], + page_token: Optional[str], + app: Optional[firebase_admin.App], + ) -> None: self._list_models_func = list_models_func self._list_filter = list_filter self._page_size = page_size @@ -614,28 +686,32 @@ def __init__(self, list_models_func, list_filter, page_size, page_token, app): self._list_response = list_models_func(list_filter, page_size, page_token) @property - def models(self): + def models(self) -> list[Model]: """A list of Models from this page.""" return [ - Model.from_dict(model, app=self._app) for model in self._list_response.get('models', []) + Model.from_dict(model, app=self._app) + for model in cast( + list[dict[str, Any]], + self._list_response.get('models', []), + ) ] @property - def list_filter(self): + def list_filter(self) -> Optional[str]: """The filter string used to filter the models.""" return self._list_filter @property - def next_page_token(self): + def next_page_token(self) -> str: """Token identifying the next page of results.""" - return self._list_response.get('nextPageToken', '') + return cast(str, self._list_response.get('nextPageToken', '')) @property - def has_next_page(self): + def has_next_page(self) -> bool: """True if more pages are available.""" return bool(self.next_page_token) - def get_next_page(self): + def get_next_page(self) -> Optional['ListModelsPage']: """Retrieves the next page of models if available. Returns: @@ -650,7 +726,7 @@ def get_next_page(self): self._app) return None - def iterate_all(self): + def iterate_all(self) -> '_ModelIterator': """Retrieves an iterator for Models. Returned iterator will iterate through all the models in the Firebase @@ -670,16 +746,16 @@ class _ModelIterator: When the whole page has been traversed, it loads another page. This class never keeps more than one page of entries in memory. """ - def __init__(self, current_page): + def __init__(self, current_page: ListModelsPage) -> None: if not isinstance(current_page, ListModelsPage): raise TypeError('Current page must be a ListModelsPage') self._current_page = current_page - self._index = 0 + self._index: int = 0 - def __next__(self): + def __next__(self) -> Model: if self._index == len(self._current_page.models): if self._current_page.has_next_page: - self._current_page = self._current_page.get_next_page() + self._current_page = cast(ListModelsPage, self._current_page.get_next_page()) self._index = 0 if self._index < len(self._current_page.models): result = self._current_page.models[self._index] @@ -687,11 +763,11 @@ def __next__(self): return result raise StopIteration - def __iter__(self): + def __iter__(self) -> Iterator[Model]: return self -def _validate_and_parse_name(name): +def _validate_and_parse_name(name: Any) -> tuple[str, str]: # The resource name is added automatically from API call responses. # The only way it could be invalid is if someone tries to # create a model from a dictionary manually and does it incorrectly. @@ -701,40 +777,41 @@ def _validate_and_parse_name(name): return matcher.group('project_id'), matcher.group('model_id') -def _validate_model(model, update_mask=None): +def _validate_model(model: Model, update_mask: Optional[str] = None) -> None: if not isinstance(model, Model): raise TypeError('Model must be an ml.Model.') if update_mask is None and not model.display_name: raise ValueError('Model must have a display name.') -def _validate_model_id(model_id): +def _validate_model_id(model_id: str) -> None: if not _MODEL_ID_PATTERN.match(model_id): raise ValueError('Model ID format is invalid.') -def _validate_operation_name(op_name): +def _validate_operation_name(op_name: Any) -> str: if not _OPERATION_NAME_PATTERN.match(op_name): raise ValueError('Operation name format is invalid.') return op_name -def _validate_display_name(display_name): +def _validate_display_name(display_name: Any) -> str: if not _DISPLAY_NAME_PATTERN.match(display_name): raise ValueError('Display name format is invalid.') return display_name -def _validate_tags(tags): +def _validate_tags(tags: Any) -> list[str]: if not isinstance(tags, list) or not \ all(isinstance(tag, str) for tag in tags): raise TypeError('Tags must be a list of strings.') + tags = cast(list[str], tags) if not all(_TAG_PATTERN.match(tag) for tag in tags): raise ValueError('Tag format is invalid.') return tags -def _validate_gcs_tflite_uri(uri): +def _validate_gcs_tflite_uri(uri: str) -> str: # GCS Bucket naming rules are complex. The regex is not comprehensive. # See https://cloud.google.com/storage/docs/naming for full details. if not _GCS_TFLITE_URI_PATTERN.match(uri): @@ -742,19 +819,19 @@ def _validate_gcs_tflite_uri(uri): return uri -def _validate_model_format(model_format): +def _validate_model_format(model_format: Any) -> ModelFormat: if not isinstance(model_format, ModelFormat): raise TypeError('Model format must be a ModelFormat object.') return model_format -def _validate_list_filter(list_filter): +def _validate_list_filter(list_filter: Optional[str]) -> None: if list_filter is not None: if not isinstance(list_filter, str): raise TypeError('List filter must be a string or None.') -def _validate_page_size(page_size): +def _validate_page_size(page_size: Optional[int]) -> None: if page_size is not None: if type(page_size) is not int: # pylint: disable=unidiomatic-typecheck # Specifically type() to disallow boolean which is a subtype of int @@ -764,7 +841,7 @@ def _validate_page_size(page_size): f'Page size must be a positive integer between 1 and {_MAX_PAGE_SIZE}') -def _validate_page_token(page_token): +def _validate_page_token(page_token: Optional[str]) -> None: if page_token is not None: if not isinstance(page_token, str): raise TypeError('Page token must be a string or None.') @@ -778,7 +855,7 @@ class _MLService: POLL_EXPONENTIAL_BACKOFF_FACTOR = 1.5 POLL_BASE_WAIT_TIME_SECONDS = 3 - def __init__(self, app): + def __init__(self, app: firebase_admin.App) -> None: self._project_id = app.project_id if not self._project_id: raise ValueError( @@ -797,14 +874,14 @@ def __init__(self, app): headers=ml_headers, base_url=_MLService.OPERATION_URL) - def get_operation(self, op_name): + def get_operation(self, op_name: str) -> dict[str, Any]: _validate_operation_name(op_name) try: return self._operation_client.body('get', url=op_name) except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) - def _exponential_backoff(self, current_attempt, stop_time): + def _exponential_backoff(self, current_attempt: int, stop_time: Optional[datetime.datetime]) -> None: """Sleeps for the appropriate amount of time. Or throws deadline exceeded.""" delay_factor = pow(_MLService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt) wait_time_seconds = delay_factor * _MLService.POLL_BASE_WAIT_TIME_SECONDS @@ -816,7 +893,12 @@ def _exponential_backoff(self, current_attempt, stop_time): wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1) time.sleep(wait_time_seconds) - def handle_operation(self, operation, wait_for_operation=False, max_time_seconds=None): + def handle_operation( + self, + operation: dict[str, Any], + wait_for_operation: bool = False, + max_time_seconds: Optional[float] = None, + ) -> dict[str, Any]: """Handles long running operations. Args: @@ -841,13 +923,14 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds if operation.get('done'): # Operations which are immediately done don't have an operation name if operation.get('response'): - return operation.get('response') + return cast(dict[str, Any], operation['response']) if operation.get('error'): - raise _utils.handle_operation_error(operation.get('error')) + error = cast(dict[str, Any], operation['error']) + raise _utils.handle_operation_error(error) raise exceptions.UnknownError(message='Internal Error: Malformed Operation.') op_name = _validate_operation_name(operation.get('name')) - metadata = operation.get('metadata', {}) + metadata = cast(dict[str, Any], operation.get('metadata', {})) metadata_type = metadata.get('@type', '') if not metadata_type.endswith('ModelOperationMetadata'): raise TypeError('Unknown type of operation metadata.') @@ -865,15 +948,16 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds if operation.get('done'): if operation.get('response'): - return operation.get('response') + return cast(dict[str, Any], operation['response']) if operation.get('error'): - raise _utils.handle_operation_error(operation.get('error')) + error = cast(dict[str, Any], operation['error']) + raise _utils.handle_operation_error(error) # If the operation is not complete or timed out, return a (locked) model instead return get_model(model_id).as_dict() - def create_model(self, model): + def create_model(self, model: Model) -> dict[str, Any]: _validate_model(model) try: return self.handle_operation( @@ -881,7 +965,7 @@ def create_model(self, model): except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) - def update_model(self, model, update_mask=None): + def update_model(self, model: Model, update_mask: Optional[str] = None) -> dict[str, Any]: _validate_model(model, update_mask) path = f'models/{model.model_id}' if update_mask is not None: @@ -892,7 +976,7 @@ def update_model(self, model, update_mask=None): except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) - def set_published(self, model_id, publish): + def set_published(self, model_id: str, publish: bool) -> dict[str, Any]: _validate_model_id(model_id) model_name = f'projects/{self._project_id}/models/{model_id}' model = Model.from_dict({ @@ -903,19 +987,24 @@ def set_published(self, model_id, publish): }) return self.update_model(model, update_mask='state.published') - def get_model(self, model_id): + def get_model(self, model_id: str) -> dict[str, Any]: _validate_model_id(model_id) try: return self._client.body('get', url=f'models/{model_id}') except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) - def list_models(self, list_filter, page_size, page_token): + def list_models( + self, + list_filter: Optional[str], + page_size: Optional[int], + page_token: Optional[str], + ) -> dict[str, Any]: """ lists Firebase ML models.""" _validate_list_filter(list_filter) _validate_page_size(page_size) _validate_page_token(page_token) - params = {} + params: dict[str, Any] = {} if list_filter: params['filter'] = list_filter if page_size: @@ -924,14 +1013,14 @@ def list_models(self, list_filter, page_size, page_token): params['page_token'] = page_token path = 'models' if params: - param_str = parse.urlencode(sorted(params.items()), True) + param_str = urllib.parse.urlencode(sorted(params.items()), True) path = path + '?' + param_str try: return self._client.body('get', url=path) except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) - def delete_model(self, model_id): + def delete_model(self, model_id: str) -> None: _validate_model_id(model_id) try: self._client.body('delete', url=f'models/{model_id}') diff --git a/firebase_admin/project_management.py b/firebase_admin/project_management.py index 73c100d3a..e952390ae 100644 --- a/firebase_admin/project_management.py +++ b/firebase_admin/project_management.py @@ -20,6 +20,15 @@ import base64 import re import time +from collections.abc import Callable +from typing import ( + Any, + Optional, + NoReturn, + TypeVar, + cast, + overload, +) import requests @@ -28,15 +37,31 @@ from firebase_admin import _http_client from firebase_admin import _utils +__all__ = ( + 'AndroidApp', + 'AndroidAppMetadata', + 'IOSApp', + 'IOSAppMetadata', + 'SHACertificate', + 'android_app', + 'create_android_app', + 'create_ios_app', + 'ios_app', + 'list_android_apps', + 'list_ios_apps', +) + +_T = TypeVar('_T') +_AppMetadataT = TypeVar('_AppMetadataT', bound='_AppMetadata') _PROJECT_MANAGEMENT_ATTRIBUTE = '_project_management' -def _get_project_management_service(app): +def _get_project_management_service(app: Optional[firebase_admin.App]) -> '_ProjectManagementService': return _utils.get_app_service(app, _PROJECT_MANAGEMENT_ATTRIBUTE, _ProjectManagementService) -def android_app(app_id, app=None): +def android_app(app_id: str, app: Optional[firebase_admin.App] = None) -> 'AndroidApp': """Obtains a reference to an Android app in the associated Firebase project. Args: @@ -49,7 +74,7 @@ def android_app(app_id, app=None): return AndroidApp(app_id=app_id, service=_get_project_management_service(app)) -def ios_app(app_id, app=None): +def ios_app(app_id: str, app: Optional[firebase_admin.App] = None) -> 'IOSApp': """Obtains a reference to an iOS app in the associated Firebase project. Args: @@ -62,7 +87,7 @@ def ios_app(app_id, app=None): return IOSApp(app_id=app_id, service=_get_project_management_service(app)) -def list_android_apps(app=None): +def list_android_apps(app: Optional[firebase_admin.App] = None) -> list['AndroidApp']: """Lists all Android apps in the associated Firebase project. Args: @@ -75,7 +100,7 @@ def list_android_apps(app=None): return _get_project_management_service(app).list_android_apps() -def list_ios_apps(app=None): +def list_ios_apps(app: Optional[firebase_admin.App] = None) -> list['IOSApp']: """Lists all iOS apps in the associated Firebase project. Args: @@ -87,7 +112,11 @@ def list_ios_apps(app=None): return _get_project_management_service(app).list_ios_apps() -def create_android_app(package_name, display_name=None, app=None): +def create_android_app( + package_name: str, + display_name: Optional[str] = None, + app: Optional[firebase_admin.App] = None, +) -> 'AndroidApp': """Creates a new Android app in the associated Firebase project. Args: @@ -101,7 +130,11 @@ def create_android_app(package_name, display_name=None, app=None): return _get_project_management_service(app).create_android_app(package_name, display_name) -def create_ios_app(bundle_id, display_name=None, app=None): +def create_ios_app( + bundle_id: str, + display_name: Optional[str] = None, + app: Optional[firebase_admin.App] = None, +) -> 'IOSApp': """Creates a new iOS app in the associated Firebase project. Args: @@ -115,25 +148,29 @@ def create_ios_app(bundle_id, display_name=None, app=None): return _get_project_management_service(app).create_ios_app(bundle_id, display_name) -def _check_is_string_or_none(obj, field_name): +def _check_is_string_or_none(obj: Any, field_name: str) -> Optional[str]: if obj is None or isinstance(obj, str): return obj raise ValueError(f'{field_name} must be a string.') -def _check_is_nonempty_string(obj, field_name): +def _check_is_nonempty_string(obj: Any, field_name: str) -> str: if isinstance(obj, str) and obj: return obj raise ValueError(f'{field_name} must be a non-empty string.') -def _check_is_nonempty_string_or_none(obj, field_name): +def _check_is_nonempty_string_or_none(obj: Any, field_name: str) -> Optional[str]: if obj is None: return None return _check_is_nonempty_string(obj, field_name) -def _check_not_none(obj, field_name): +@overload +def _check_not_none(obj: None, field_name: str) -> NoReturn: ... +@overload +def _check_not_none(obj: _T, field_name: str) -> _T: ... +def _check_not_none(obj: Optional[_T], field_name: str) -> _T: if obj is None: raise ValueError(f'{field_name} cannot be None.') return obj @@ -148,12 +185,12 @@ class AndroidApp: instead of instantiating it directly. """ - def __init__(self, app_id, service): + def __init__(self, app_id: str, service: '_ProjectManagementService') -> None: self._app_id = app_id self._service = service @property - def app_id(self): + def app_id(self) -> str: """Returns the app ID of the Android app to which this instance refers. Note: This method does not make an RPC. @@ -163,7 +200,7 @@ def app_id(self): """ return self._app_id - def get_metadata(self): + def get_metadata(self) -> 'AndroidAppMetadata': """Retrieves detailed information about this Android app. Returns: @@ -175,7 +212,7 @@ def get_metadata(self): """ return self._service.get_android_app_metadata(self._app_id) - def set_display_name(self, new_display_name): + def set_display_name(self, new_display_name: Optional[str]) -> None: """Updates the display name attribute of this Android app to the one given. Args: @@ -188,13 +225,13 @@ def set_display_name(self, new_display_name): FirebaseError: If an error occurs while communicating with the Firebase Project Management Service. """ - return self._service.set_android_app_display_name(self._app_id, new_display_name) + self._service.set_android_app_display_name(self._app_id, new_display_name) - def get_config(self): + def get_config(self) -> str: """Retrieves the configuration artifact associated with this Android app.""" return self._service.get_android_app_config(self._app_id) - def get_sha_certificates(self): + def get_sha_certificates(self) -> list['SHACertificate']: """Retrieves the entire list of SHA certificates associated with this Android app. Returns: @@ -206,7 +243,7 @@ def get_sha_certificates(self): """ return self._service.get_sha_certificates(self._app_id) - def add_sha_certificate(self, certificate_to_add): + def add_sha_certificate(self, certificate_to_add: 'SHACertificate') -> None: """Adds a SHA certificate to this Android app. Args: @@ -219,9 +256,9 @@ def add_sha_certificate(self, certificate_to_add): FirebaseError: If an error occurs while communicating with the Firebase Project Management Service. (For example, if the certificate_to_add already exists.) """ - return self._service.add_sha_certificate(self._app_id, certificate_to_add) + self._service.add_sha_certificate(self._app_id, certificate_to_add) - def delete_sha_certificate(self, certificate_to_delete): + def delete_sha_certificate(self, certificate_to_delete: 'SHACertificate') -> None: """Removes a SHA certificate from this Android app. Args: @@ -234,7 +271,7 @@ def delete_sha_certificate(self, certificate_to_delete): FirebaseError: If an error occurs while communicating with the Firebase Project Management Service. (For example, if the certificate_to_delete is not found.) """ - return self._service.delete_sha_certificate(certificate_to_delete) + self._service.delete_sha_certificate(certificate_to_delete) class IOSApp: @@ -246,12 +283,12 @@ class IOSApp: instead of instantiating it directly. """ - def __init__(self, app_id, service): + def __init__(self, app_id: str, service: '_ProjectManagementService') -> None: self._app_id = app_id self._service = service @property - def app_id(self): + def app_id(self) -> str: """Returns the app ID of the iOS app to which this instance refers. Note: This method does not make an RPC. @@ -261,7 +298,7 @@ def app_id(self): """ return self._app_id - def get_metadata(self): + def get_metadata(self) -> 'IOSAppMetadata': """Retrieves detailed information about this iOS app. Returns: @@ -273,7 +310,7 @@ def get_metadata(self): """ return self._service.get_ios_app_metadata(self._app_id) - def set_display_name(self, new_display_name): + def set_display_name(self, new_display_name: Optional[str]) -> None: """Updates the display name attribute of this iOS app to the one given. Args: @@ -286,9 +323,9 @@ def set_display_name(self, new_display_name): FirebaseError: If an error occurs while communicating with the Firebase Project Management Service. """ - return self._service.set_ios_app_display_name(self._app_id, new_display_name) + self._service.set_ios_app_display_name(self._app_id, new_display_name) - def get_config(self): + def get_config(self) -> str: """Retrieves the configuration artifact associated with this iOS app.""" return self._service.get_ios_app_config(self._app_id) @@ -296,7 +333,7 @@ def get_config(self): class _AppMetadata: """Detailed information about a Firebase Android or iOS app.""" - def __init__(self, name, app_id, display_name, project_id): + def __init__(self, name: str, app_id: str, display_name: Optional[str], project_id: str) -> None: # _name is the fully qualified resource name of this Android or iOS app; currently it is not # exposed to client code. self._name = _check_is_nonempty_string(name, 'name') @@ -305,7 +342,7 @@ def __init__(self, name, app_id, display_name, project_id): self._project_id = _check_is_nonempty_string(project_id, 'project_id') @property - def app_id(self): + def app_id(self) -> str: """The globally unique, Firebase-assigned identifier of this Android or iOS app. This ID is unique even across apps of different platforms. @@ -313,18 +350,18 @@ def app_id(self): return self._app_id @property - def display_name(self): + def display_name(self) -> Optional[str]: """The user-assigned display name of this Android or iOS app. Note that the display name can be None if it has never been set by the user.""" return self._display_name @property - def project_id(self): + def project_id(self) -> str: """The permanent, globally unique, user-assigned ID of the parent Firebase project.""" return self._project_id - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, type(self)): return False # pylint: disable=protected-access @@ -336,24 +373,31 @@ def __eq__(self, other): class AndroidAppMetadata(_AppMetadata): """Android-specific information about an Android Firebase app.""" - def __init__(self, package_name, name, app_id, display_name, project_id): + def __init__( + self, + package_name: str, + name: str, + app_id: str, + display_name: Optional[str], + project_id: str, + ) -> None: """Clients should not instantiate this class directly.""" super().__init__(name, app_id, display_name, project_id) self._package_name = _check_is_nonempty_string(package_name, 'package_name') @property - def package_name(self): + def package_name(self) -> str: """The canonical package name of this Android app as it would appear in the Play Store.""" return self._package_name - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return (super().__eq__(other) and self.package_name == other.package_name) - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) - def __hash__(self): + def __hash__(self) -> int: return hash( (self._name, self.app_id, self.display_name, self.project_id, self.package_name)) @@ -361,23 +405,30 @@ def __hash__(self): class IOSAppMetadata(_AppMetadata): """iOS-specific information about an iOS Firebase app.""" - def __init__(self, bundle_id, name, app_id, display_name, project_id): + def __init__( + self, + bundle_id: str, + name: str, + app_id: str, + display_name: Optional[str], + project_id: str, + ) -> None: """Clients should not instantiate this class directly.""" super().__init__(name, app_id, display_name, project_id) self._bundle_id = _check_is_nonempty_string(bundle_id, 'bundle_id') @property - def bundle_id(self): + def bundle_id(self) -> str: """The canonical bundle ID of this iOS app as it would appear in the iOS AppStore.""" return self._bundle_id - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return super().__eq__(other) and self.bundle_id == other.bundle_id - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) - def __hash__(self): + def __hash__(self) -> int: return hash((self._name, self.app_id, self.display_name, self.project_id, self.bundle_id)) @@ -390,7 +441,7 @@ class SHACertificate: _SHA_1_RE = re.compile('^[0-9A-Fa-f]{40}$') _SHA_256_RE = re.compile('^[0-9A-Fa-f]{64}$') - def __init__(self, sha_hash, name=None): + def __init__(self, sha_hash: str, name: Optional[str] = None) -> None: """Creates a new SHACertificate instance. Args: @@ -415,7 +466,7 @@ def __init__(self, sha_hash, name=None): 'The supplied certificate hash is neither a valid SHA-1 nor SHA_256 hash.') @property - def name(self): + def name(self) -> Optional[str]: """Returns the fully qualified resource name of this certificate, if known. Returns: @@ -425,7 +476,7 @@ def name(self): return self._name @property - def sha_hash(self): + def sha_hash(self) -> str: """Returns the certificate hash. Returns: @@ -434,7 +485,7 @@ def sha_hash(self): return self._sha_hash @property - def cert_type(self): + def cert_type(self) -> str: """Returns the type of the SHA certificate encoded in the hash. Returns: @@ -442,16 +493,16 @@ def cert_type(self): """ return self._cert_type - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, SHACertificate): return False return (self.name == other.name and self.sha_hash == other.sha_hash and self.cert_type == other.cert_type) - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) - def __hash__(self): + def __hash__(self) -> int: return hash((self.name, self.sha_hash, self.cert_type)) @@ -469,7 +520,7 @@ class _ProjectManagementService: IOS_APPS_RESOURCE_NAME = 'iosApps' IOS_APP_IDENTIFIER_NAME = 'bundleId' - def __init__(self, app): + def __init__(self, app: firebase_admin.App) -> None: project_id = app.project_id if not project_id: raise ValueError( @@ -485,73 +536,83 @@ def __init__(self, app): headers={'X-Client-Version': version_header}, timeout=timeout) - def get_android_app_metadata(self, app_id): + def get_android_app_metadata(self, app_id: str) -> AndroidAppMetadata: return self._get_app_metadata( platform_resource_name=_ProjectManagementService.ANDROID_APPS_RESOURCE_NAME, identifier_name=_ProjectManagementService.ANDROID_APP_IDENTIFIER_NAME, metadata_class=AndroidAppMetadata, app_id=app_id) - def get_ios_app_metadata(self, app_id): + def get_ios_app_metadata(self, app_id: str) -> IOSAppMetadata: return self._get_app_metadata( platform_resource_name=_ProjectManagementService.IOS_APPS_RESOURCE_NAME, identifier_name=_ProjectManagementService.IOS_APP_IDENTIFIER_NAME, metadata_class=IOSAppMetadata, app_id=app_id) - def _get_app_metadata(self, platform_resource_name, identifier_name, metadata_class, app_id): + def _get_app_metadata( + self, + platform_resource_name: str, + identifier_name: str, + metadata_class: Callable[[str, str, str, Optional[str], str], _AppMetadataT], + app_id: str, + ) -> _AppMetadataT: """Retrieves detailed information about an Android or iOS app.""" _check_is_nonempty_string(app_id, 'app_id') path = f'/v1beta1/projects/-/{platform_resource_name}/{app_id}' response = self._make_request('get', path) return metadata_class( response[identifier_name], - name=response['name'], - app_id=response['appId'], - display_name=response.get('displayName') or None, - project_id=response['projectId']) + response['name'], + response['appId'], + response.get('displayName') or None, + response['projectId']) - def set_android_app_display_name(self, app_id, new_display_name): + def set_android_app_display_name(self, app_id: str, new_display_name: Optional[str]) -> None: self._set_display_name( app_id=app_id, new_display_name=new_display_name, platform_resource_name=_ProjectManagementService.ANDROID_APPS_RESOURCE_NAME) - def set_ios_app_display_name(self, app_id, new_display_name): + def set_ios_app_display_name(self, app_id: str, new_display_name: Optional[str]) -> None: self._set_display_name( app_id=app_id, new_display_name=new_display_name, platform_resource_name=_ProjectManagementService.IOS_APPS_RESOURCE_NAME) - def _set_display_name(self, app_id, new_display_name, platform_resource_name): + def _set_display_name(self, app_id: str, new_display_name: Optional[str], platform_resource_name: str) -> None: """Sets the display name of an Android or iOS app.""" path = f'/v1beta1/projects/-/{platform_resource_name}/{app_id}?updateMask=displayName' request_body = {'displayName': new_display_name} self._make_request('patch', path, json=request_body) - def list_android_apps(self): + def list_android_apps(self) -> list[AndroidApp]: return self._list_apps( platform_resource_name=_ProjectManagementService.ANDROID_APPS_RESOURCE_NAME, app_class=AndroidApp) - def list_ios_apps(self): + def list_ios_apps(self) -> list[IOSApp]: return self._list_apps( platform_resource_name=_ProjectManagementService.IOS_APPS_RESOURCE_NAME, app_class=IOSApp) - def _list_apps(self, platform_resource_name, app_class): + def _list_apps( + self, + platform_resource_name: str, + app_class: Callable[[str, '_ProjectManagementService'], _T], + ) -> list[_T]: """Lists all the Android or iOS apps within the Firebase project.""" path = ( f'/v1beta1/projects/{self._project_id}/{platform_resource_name}?pageSize=' f'{_ProjectManagementService.MAXIMUM_LIST_APPS_PAGE_SIZE}' ) response = self._make_request('get', path) - apps_list = [] + apps_list: list[_T] = [] while True: - apps = response.get('apps') + apps = cast(list[dict[str, Any]], response.get('apps', [])) if not apps: break - apps_list.extend(app_class(app_id=app['appId'], service=self) for app in apps) + apps_list.extend(app_class(app['appId'], self) for app in apps) next_page_token = response.get('nextPageToken') if not next_page_token: break @@ -564,7 +625,7 @@ def _list_apps(self, platform_resource_name, app_class): response = self._make_request('get', path) return apps_list - def create_android_app(self, package_name, display_name=None): + def create_android_app(self, package_name: str, display_name: Optional[str] = None) -> AndroidApp: return self._create_app( platform_resource_name=_ProjectManagementService.ANDROID_APPS_RESOURCE_NAME, identifier_name=_ProjectManagementService.ANDROID_APP_IDENTIFIER_NAME, @@ -572,7 +633,7 @@ def create_android_app(self, package_name, display_name=None): display_name=display_name, app_class=AndroidApp) - def create_ios_app(self, bundle_id, display_name=None): + def create_ios_app(self, bundle_id: str, display_name: Optional[str] = None) -> IOSApp: return self._create_app( platform_resource_name=_ProjectManagementService.IOS_APPS_RESOURCE_NAME, identifier_name=_ProjectManagementService.IOS_APP_IDENTIFIER_NAME, @@ -581,12 +642,13 @@ def create_ios_app(self, bundle_id, display_name=None): app_class=IOSApp) def _create_app( - self, - platform_resource_name, - identifier_name, - identifier, - display_name, - app_class): + self, + platform_resource_name: str, + identifier_name: str, + identifier: str, + display_name: Optional[str], + app_class: Callable[[str, '_ProjectManagementService'], _T], + ) -> _T: """Creates an Android or iOS app.""" _check_is_string_or_none(display_name, 'display_name') path = f'/v1beta1/projects/{self._project_id}/{platform_resource_name}' @@ -596,9 +658,9 @@ def _create_app( response = self._make_request('post', path, json=request_body) operation_name = response['name'] poll_response = self._poll_app_creation(operation_name) - return app_class(app_id=poll_response['appId'], service=self) + return app_class(poll_response['appId'], self) - def _poll_app_creation(self, operation_name): + def _poll_app_creation(self, operation_name: object) -> dict[str, Any]: """Polls the Long-Running Operation repeatedly until it is done with exponential backoff.""" for current_attempt in range(_ProjectManagementService.MAXIMUM_POLLING_ATTEMPTS): delay_factor = pow( @@ -609,7 +671,7 @@ def _poll_app_creation(self, operation_name): poll_response, http_response = self._body_and_response('get', path) done = poll_response.get('done') if done: - response = poll_response.get('response') + response: Optional[dict[str, Any]] = poll_response.get('response') if response: return response @@ -618,45 +680,55 @@ def _poll_app_creation(self, operation_name): http_response=http_response) raise exceptions.DeadlineExceededError('Polling deadline exceeded.') - def get_android_app_config(self, app_id): + def get_android_app_config(self, app_id: str) -> str: return self._get_app_config( platform_resource_name=_ProjectManagementService.ANDROID_APPS_RESOURCE_NAME, app_id=app_id) - def get_ios_app_config(self, app_id): + def get_ios_app_config(self, app_id: str) -> str: return self._get_app_config( platform_resource_name=_ProjectManagementService.IOS_APPS_RESOURCE_NAME, app_id=app_id) - def _get_app_config(self, platform_resource_name, app_id): + def _get_app_config(self, platform_resource_name: str, app_id: str) -> str: path = f'/v1beta1/projects/-/{platform_resource_name}/{app_id}/config' response = self._make_request('get', path) # In Python 2.7, the base64 module works with strings, while in Python 3, it works with # bytes objects. This line works in both versions. return base64.standard_b64decode(response['configFileContents']).decode(encoding='utf-8') - def get_sha_certificates(self, app_id): + def get_sha_certificates(self, app_id: str) -> list[SHACertificate]: path = f'/v1beta1/projects/-/androidApps/{app_id}/sha' - response = self._make_request('get', path) - cert_list = response.get('certificates') or [] + response: dict[str, Any] = self._make_request('get', path) + cert_list: list[dict[str, Any]] = response.get('certificates') or [] return [SHACertificate(sha_hash=cert['shaHash'], name=cert['name']) for cert in cert_list] - def add_sha_certificate(self, app_id, certificate_to_add): + def add_sha_certificate(self, app_id: str, certificate_to_add: SHACertificate) -> None: path = f'/v1beta1/projects/-/androidApps/{app_id}/sha' sha_hash = _check_not_none(certificate_to_add, 'certificate_to_add').sha_hash cert_type = certificate_to_add.cert_type request_body = {'shaHash': sha_hash, 'certType': cert_type} self._make_request('post', path, json=request_body) - def delete_sha_certificate(self, certificate_to_delete): + def delete_sha_certificate(self, certificate_to_delete: SHACertificate) -> None: name = _check_not_none(certificate_to_delete, 'certificate_to_delete').name path = f'/v1beta1/{name}' self._make_request('delete', path) - def _make_request(self, method, url, json=None): + def _make_request( + self, + method: str, + url: str, + json: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: body, _ = self._body_and_response(method, url, json) return body - def _body_and_response(self, method, url, json=None): + def _body_and_response( + self, + method: str, + url: str, + json: Optional[dict[str, Any]] = None, + ) -> tuple[dict[str, Any], requests.Response]: try: return self._client.body_and_response(method=method, url=url, json=json) except requests.exceptions.RequestException as error: diff --git a/firebase_admin/remote_config.py b/firebase_admin/remote_config.py index 880804d3d..b6b4955e8 100644 --- a/firebase_admin/remote_config.py +++ b/firebase_admin/remote_config.py @@ -20,13 +20,32 @@ import json import logging import threading -from typing import Dict, Optional, Literal, Union, Any -from enum import Enum +import enum import re import hashlib +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Literal, Optional, Union + import requests -from firebase_admin import App, _http_client, _utils + import firebase_admin +from firebase_admin import _http_client +from firebase_admin import _utils +from firebase_admin import exceptions + +if TYPE_CHECKING: + from _typeshed import ConvertibleToFloat + +__all__ = ( + 'MAX_CONDITION_RECURSION_DEPTH', + 'CustomSignalOperator', + 'PercentConditionOperator', + 'ServerConfig', + 'ServerTemplate', + 'ValueSource', + 'get_server_template', + 'init_server_template', +) # Set up logging (you can customize the level and output) logging.basicConfig(level=logging.INFO) @@ -36,7 +55,8 @@ MAX_CONDITION_RECURSION_DEPTH = 10 ValueSource = Literal['default', 'remote', 'static'] # Define the ValueSource type -class PercentConditionOperator(Enum): + +class PercentConditionOperator(enum.Enum): """Enum representing the available operators for percent conditions. """ LESS_OR_EQUAL = "LESS_OR_EQUAL" @@ -44,7 +64,8 @@ class PercentConditionOperator(Enum): BETWEEN = "BETWEEN" UNKNOWN = "UNKNOWN" -class CustomSignalOperator(Enum): + +class CustomSignalOperator(enum.Enum): """Enum representing the available operators for custom signal conditions. """ STRING_CONTAINS = "STRING_CONTAINS" @@ -65,9 +86,10 @@ class CustomSignalOperator(Enum): SEMANTIC_VERSION_GREATER_EQUAL = "SEMANTIC_VERSION_GREATER_EQUAL" UNKNOWN = "UNKNOWN" + class _ServerTemplateData: """Parses, validates and encapsulates template data and metadata.""" - def __init__(self, template_data): + def __init__(self, template_data: dict[str, Any]) -> None: """Initializes a new ServerTemplateData instance. Args: @@ -82,7 +104,7 @@ def __init__(self, template_data): else: raise ValueError('Remote Config parameters must be a non-null object') else: - self._parameters = {} + self._parameters: dict[str, dict[str, Any]] = {} if 'conditions' in template_data: if template_data['conditions'] is not None: @@ -90,28 +112,28 @@ def __init__(self, template_data): else: raise ValueError('Remote Config conditions must be a non-null object') else: - self._conditions = [] + self._conditions: list[dict[str, Any]] = [] - self._version = '' + self._version: str = '' if 'version' in template_data: self._version = template_data['version'] - self._etag = '' + self._etag: str = '' if 'etag' in template_data and isinstance(template_data['etag'], str): self._etag = template_data['etag'] self._template_data_json = json.dumps(template_data) @property - def parameters(self): + def parameters(self) -> dict[str, dict[str, Any]]: return self._parameters @property - def etag(self): + def etag(self) -> str: return self._etag @property - def version(self): + def version(self) -> str: return self._version @property @@ -119,13 +141,17 @@ def conditions(self): return self._conditions @property - def template_data_json(self): + def template_data_json(self) -> str: return self._template_data_json class ServerTemplate: """Represents a Server Template with implementations for loading and evaluating the template.""" - def __init__(self, app: App = None, default_config: Optional[Dict[str, str]] = None): + def __init__( + self, + app: Optional[firebase_admin.App] = None, + default_config: Optional[dict[str, str]] = None, + ) -> None: """Initializes a ServerTemplate instance. Args: @@ -137,8 +163,8 @@ def __init__(self, app: App = None, default_config: Optional[Dict[str, str]] = N _REMOTE_CONFIG_ATTRIBUTE, _RemoteConfigService) # This gets set when the template is # fetched from RC servers via the load API, or via the set API. - self._cache = None - self._stringified_default_config: Dict[str, str] = {} + self._cache: Optional[_ServerTemplateData] = None + self._stringified_default_config: dict[str, str] = {} self._lock = threading.RLock() # RC stores all remote values as string, but it's more intuitive @@ -148,13 +174,13 @@ def __init__(self, app: App = None, default_config: Optional[Dict[str, str]] = N for key in default_config: self._stringified_default_config[key] = str(default_config[key]) - async def load(self): + async def load(self) -> None: """Fetches the server template and caches the data.""" rc_server_template = await self._rc_service.get_server_template() with self._lock: self._cache = rc_server_template - def evaluate(self, context: Optional[Dict[str, Union[str, int]]] = None) -> 'ServerConfig': + def evaluate(self, context: Optional[dict[str, Union[str, int]]] = None) -> 'ServerConfig': """Evaluates the cached server template to produce a ServerConfig. Args: @@ -170,14 +196,14 @@ def evaluate(self, context: Optional[Dict[str, Union[str, int]]] = None) -> 'Ser raise ValueError("""No Remote Config Server template in cache. Call load() before calling evaluate().""") context = context or {} - config_values = {} + config_values: dict[str, _Value] = {} with self._lock: template_conditions = self._cache.conditions template_parameters = self._cache.parameters # Initializes config Value objects with default values. - if self._stringified_default_config is not None: + if self._stringified_default_config: for key, value in self._stringified_default_config.items(): config_values[key] = _Value('default', value) self._evaluator = _ConditionEvaluator(template_conditions, @@ -185,7 +211,7 @@ def evaluate(self, context: Optional[Dict[str, Union[str, int]]] = None) -> 'Ser config_values) return ServerConfig(config_values=self._evaluator.evaluate()) - def set(self, template_data_json: str): + def set(self, template_data_json: str) -> None: """Updates the cache to store the given template is of type ServerTemplateData. Args: @@ -197,7 +223,7 @@ def set(self, template_data_json: str): with self._lock: self._cache = template_data - def to_json(self): + def to_json(self) -> str: """Provides the server template in a JSON format to be used for initialization later.""" if not self._cache: raise ValueError("""No Remote Config Server template in cache. @@ -209,30 +235,30 @@ def to_json(self): class ServerConfig: """Represents a Remote Config Server Side Config.""" - def __init__(self, config_values): + def __init__(self, config_values: dict[str, '_Value']): self._config_values = config_values # dictionary of param key to values - def get_boolean(self, key): + def get_boolean(self, key: str) -> bool: """Returns the value as a boolean.""" return self._get_value(key).as_boolean() - def get_string(self, key): + def get_string(self, key: str) -> str: """Returns the value as a string.""" return self._get_value(key).as_string() - def get_int(self, key): + def get_int(self, key: str) -> int: """Returns the value as an integer.""" return self._get_value(key).as_int() - def get_float(self, key): + def get_float(self, key: str) -> float: """Returns the value as a float.""" return self._get_value(key).as_float() - def get_value_source(self, key): + def get_value_source(self, key: str) -> ValueSource: """Returns the source of the value.""" return self._get_value(key).get_source() - def _get_value(self, key): + def _get_value(self, key: str) -> "_Value": return self._config_values.get(key, _Value('static')) @@ -240,7 +266,7 @@ class _RemoteConfigService: """Internal class that facilitates sending requests to the Firebase Remote Config backend API. """ - def __init__(self, app): + def __init__(self, app: firebase_admin.App) -> None: """Initialize a JsonHttpClient with necessary inputs. Args: @@ -258,7 +284,7 @@ def __init__(self, app): base_url=remote_config_base_url, headers=rc_headers, timeout=timeout) - async def get_server_template(self): + async def get_server_template(self) -> _ServerTemplateData: """Requests for a server template and converts the response to an instance of ServerTemplateData for storing the template parameters and conditions.""" try: @@ -271,12 +297,12 @@ async def get_server_template(self): template_data['etag'] = headers.get('etag') return _ServerTemplateData(template_data) - def _get_url(self): + def _get_url(self) -> str: """Returns project prefix for url, in the format of /v1/projects/${projectId}""" return f"/v1/projects/{self._project_id}/namespaces/firebase-server/serverRemoteConfig" @classmethod - def _handle_remote_config_error(cls, error: Any): + def _handle_remote_config_error(cls, error: requests.RequestException) -> exceptions.FirebaseError: """Handles errors received from the Cloud Functions API.""" return _utils.handle_platform_error_from_requests(error) @@ -284,13 +310,19 @@ def _handle_remote_config_error(cls, error: Any): class _ConditionEvaluator: """Internal class that facilitates sending requests to the Firebase Remote Config backend API.""" - def __init__(self, conditions, parameters, context, config_values): + def __init__( + self, + conditions: list[dict[str, Any]], + parameters: dict[str, dict[str, Any]], + context: dict[str, Any], + config_values: dict[str, '_Value'], + ) -> None: self._context = context self._conditions = conditions self._parameters = parameters self._config_values = config_values - def evaluate(self): + def evaluate(self) -> dict[str, '_Value']: """Internal function that evaluates the cached server template to produce a ServerConfig""" evaluated_conditions = self.evaluate_conditions(self._conditions, self._context) @@ -298,9 +330,9 @@ def evaluate(self): # Overlays config Value objects derived by evaluating the template. if self._parameters: for key, parameter in self._parameters.items(): - conditional_values = parameter.get('conditionalValues', {}) - default_value = parameter.get('defaultValue', {}) - parameter_value_wrapper = None + conditional_values: dict[str, Any] = parameter.get('conditionalValues', {}) + default_value: dict[str, Any] = parameter.get('defaultValue', {}) + parameter_value_wrapper: Optional[dict[str, Any]] = None # Iterates in order over condition list. If there is a value associated # with a condition, this checks if the condition is true. if evaluated_conditions: @@ -314,6 +346,7 @@ def evaluate(self): continue if parameter_value_wrapper: + # possible issue: Is `None` a valid value for `_Value`? parameter_value = parameter_value_wrapper.get('value') self._config_values[key] = _Value('remote', parameter_value) continue @@ -328,7 +361,11 @@ def evaluate(self): self._config_values[key] = _Value('remote', default_value.get('value')) return self._config_values - def evaluate_conditions(self, conditions, context)-> Dict[str, bool]: + def evaluate_conditions( + self, + conditions: list[dict[str, Any]], + context: dict[str, Any], + )-> dict[str, bool]: """Evaluates a list of conditions and returns a dictionary of results. Args: @@ -338,15 +375,20 @@ def evaluate_conditions(self, conditions, context)-> Dict[str, bool]: Returns: A dictionary that maps condition names to boolean evaluation results. """ - evaluated_conditions = {} + evaluated_conditions: dict[Any, Any] = {} for condition in conditions: + # possible issue: does condition always have `name`? evaluated_conditions[condition.get('name')] = self.evaluate_condition( - condition.get('condition'), context + condition['condition'], context ) return evaluated_conditions - def evaluate_condition(self, condition, context, - nesting_level: int = 0) -> bool: + def evaluate_condition( + self, + condition: dict[str, Any], + context: dict[str, Any], + nesting_level: int = 0, + ) -> bool: """Recursively evaluates a condition. Args: @@ -361,25 +403,28 @@ def evaluate_condition(self, condition, context, logger.warning("Maximum condition recursion depth exceeded.") return False if condition.get('orCondition') is not None: - return self.evaluate_or_condition(condition.get('orCondition'), + return self.evaluate_or_condition(condition['orCondition'], context, nesting_level + 1) if condition.get('andCondition') is not None: - return self.evaluate_and_condition(condition.get('andCondition'), + return self.evaluate_and_condition(condition['andCondition'], context, nesting_level + 1) if condition.get('true') is not None: return True if condition.get('false') is not None: return False if condition.get('percent') is not None: - return self.evaluate_percent_condition(condition.get('percent'), context) + return self.evaluate_percent_condition(condition['percent'], context) if condition.get('customSignal') is not None: - return self.evaluate_custom_signal_condition(condition.get('customSignal'), context) + return self.evaluate_custom_signal_condition(condition['customSignal'], context) logger.warning("Unknown condition type encountered.") return False - def evaluate_or_condition(self, or_condition, - context, - nesting_level: int = 0) -> bool: + def evaluate_or_condition( + self, + or_condition: dict[str, Any], + context: dict[str, Any], + nesting_level: int = 0, + ) -> bool: """Evaluates an OR condition. Args: @@ -390,16 +435,19 @@ def evaluate_or_condition(self, or_condition, Returns: True if any of the subconditions are true, False otherwise. """ - sub_conditions = or_condition.get('conditions') or [] + sub_conditions: list[dict[str, Any]] = or_condition.get('conditions') or [] for sub_condition in sub_conditions: result = self.evaluate_condition(sub_condition, context, nesting_level + 1) if result: return True return False - def evaluate_and_condition(self, and_condition, - context, - nesting_level: int = 0) -> bool: + def evaluate_and_condition( + self, + and_condition: dict[str, Any], + context: dict[str, Any], + nesting_level: int = 0, + ) -> bool: """Evaluates an AND condition. Args: @@ -410,15 +458,18 @@ def evaluate_and_condition(self, and_condition, Returns: True if all of the subconditions are met; False otherwise. """ - sub_conditions = and_condition.get('conditions') or [] + sub_conditions: list[dict[str, Any]] = and_condition.get('conditions') or [] for sub_condition in sub_conditions: result = self.evaluate_condition(sub_condition, context, nesting_level + 1) if not result: return False return True - def evaluate_percent_condition(self, percent_condition, - context) -> bool: + def evaluate_percent_condition( + self, + percent_condition: dict[str, Any], + context: dict[str, Any], + ) -> bool: """Evaluates a percent condition. Args: @@ -462,6 +513,7 @@ def evaluate_percent_condition(self, percent_condition, return norm_percent_lower_bound < instance_micro_percentile <= norm_percent_upper_bound logger.warning("Unknown percent operator: %s", percent_operator) return False + def hash_seeded_randomization_id(self, seeded_randomization_id: str) -> int: """Hashes a seeded randomization ID. @@ -476,8 +528,11 @@ def hash_seeded_randomization_id(self, seeded_randomization_id: str) -> int: hash64 = hash_object.hexdigest() return abs(int(hash64, 16)) - def evaluate_custom_signal_condition(self, custom_signal_condition, - context) -> bool: + def evaluate_custom_signal_condition( + self, + custom_signal_condition: dict[str, Any], + context: dict[str, Any], + ) -> bool: """Evaluates a custom signal condition. Args: @@ -487,18 +542,16 @@ def evaluate_custom_signal_condition(self, custom_signal_condition, Returns: True if the condition is met, False otherwise. """ - custom_signal_operator = custom_signal_condition.get('customSignalOperator') or {} - custom_signal_key = custom_signal_condition.get('customSignalKey') or {} - target_custom_signal_values = ( - custom_signal_condition.get('targetCustomSignalValues') or {}) + custom_signal_operator: Optional[str] = custom_signal_condition.get('customSignalOperator') + custom_signal_key: Optional[str] = custom_signal_condition.get('customSignalKey') + target_custom_signal_values: Optional[list[Any]] = ( + custom_signal_condition.get('targetCustomSignalValues')) - if not all([custom_signal_operator, custom_signal_key, target_custom_signal_values]): + if not (custom_signal_operator and custom_signal_key and target_custom_signal_values): logger.warning("Missing operator, key, or target values for custom signal condition.") return False - if not target_custom_signal_values: - return False - actual_custom_signal_value = context.get(custom_signal_key) or {} + actual_custom_signal_value: Optional[Any] = context.get(custom_signal_key) if not actual_custom_signal_value: logger.debug("Custom signal value not found in context: %s", custom_signal_key) @@ -519,7 +572,7 @@ def evaluate_custom_signal_condition(self, custom_signal_condition, if custom_signal_operator == CustomSignalOperator.STRING_CONTAINS_REGEX.value: return self._compare_strings(target_custom_signal_values, actual_custom_signal_value, - re.search) + lambda pattern, string: bool(re.search(pattern, string))) # For numeric operators only one target value is allowed. if custom_signal_operator == CustomSignalOperator.NUMERIC_LESS_THAN.value: @@ -587,7 +640,12 @@ def evaluate_custom_signal_condition(self, custom_signal_condition, logger.warning("Unknown custom signal operator: %s", custom_signal_operator) return False - def _compare_strings(self, target_values, actual_value, predicate_fn) -> bool: + def _compare_strings( + self, + target_values: list[str], + actual_value: str, + predicate_fn: Callable[[str, str], bool], + ) -> bool: """Compares the actual string value of a signal against a list of target values. Args: @@ -607,7 +665,13 @@ def _compare_strings(self, target_values, actual_value, predicate_fn) -> bool: return True return False - def _compare_numbers(self, custom_signal_key, target_value, actual_value, predicate_fn) -> bool: + def _compare_numbers( + self, + custom_signal_key: str, + target_value: 'ConvertibleToFloat', + actual_value: 'ConvertibleToFloat', + predicate_fn: Callable[[float], bool], + ) -> bool: try: target = float(target_value) actual = float(actual_value) @@ -618,8 +682,13 @@ def _compare_numbers(self, custom_signal_key, target_value, actual_value, predic custom_signal_key) return False - def _compare_semantic_versions(self, custom_signal_key, - target_value, actual_value, predicate_fn) -> bool: + def _compare_semantic_versions( + self, + custom_signal_key: str, + target_value: str, + actual_value: str, + predicate_fn: Callable[[Literal[-1, 0, 1]], bool], + ) -> bool: """Compares the actual semantic version value of a signal against a target value. Calls the predicate function with -1, 0, 1 if actual is less than, equal to, or greater than target. @@ -637,8 +706,13 @@ def _compare_semantic_versions(self, custom_signal_key, return self._compare_versions(custom_signal_key, str(actual_value), str(target_value), predicate_fn) - def _compare_versions(self, custom_signal_key, - sem_version_1, sem_version_2, predicate_fn) -> bool: + def _compare_versions( + self, + custom_signal_key: str, + sem_version_1: str, + sem_version_2: str, + predicate_fn: Callable[[Literal[-1, 0, 1]], bool], + ) -> bool: """Compares two semantic version strings. Args: @@ -671,7 +745,11 @@ def _compare_versions(self, custom_signal_key, custom_signal_key) return False -async def get_server_template(app: App = None, default_config: Optional[Dict[str, str]] = None): + +async def get_server_template( + app: Optional[firebase_admin.App] = None, + default_config: Optional[dict[str, str]] = None, +) -> ServerTemplate: """Initializes a new ServerTemplate instance and fetches the server template. Args: @@ -686,8 +764,12 @@ async def get_server_template(app: App = None, default_config: Optional[Dict[str await template.load() return template -def init_server_template(app: App = None, default_config: Optional[Dict[str, str]] = None, - template_data_json: Optional[str] = None): + +def init_server_template( + app: Optional[firebase_admin.App] = None, + default_config: Optional[dict[str, str]] = None, + template_data_json: Optional[str] = None, +) -> ServerTemplate: """Initializes a new ServerTemplate instance. Args: @@ -705,6 +787,7 @@ def init_server_template(app: App = None, default_config: Optional[Dict[str, str template.set(template_data_json) return template + class _Value: """Represents a value fetched from Remote Config. """ @@ -714,7 +797,7 @@ class _Value: DEFAULT_VALUE_FOR_FLOAT_NUMBER = 0.0 BOOLEAN_TRUTHY_VALUES = ['1', 'true', 't', 'yes', 'y', 'on'] - def __init__(self, source: ValueSource, value: str = DEFAULT_VALUE_FOR_STRING): + def __init__(self, source: ValueSource, value: Any = DEFAULT_VALUE_FOR_STRING) -> None: """Initializes a Value instance. Args: @@ -724,7 +807,7 @@ def __init__(self, source: ValueSource, value: str = DEFAULT_VALUE_FOR_STRING): "remote" indicates the value was defined by config produced by evaluating a template. value: The string value. """ - self.source = source + self.source: ValueSource = source self.value = value def as_string(self) -> str: @@ -739,7 +822,7 @@ def as_boolean(self) -> bool: return self.DEFAULT_VALUE_FOR_BOOLEAN return str(self.value).lower() in self.BOOLEAN_TRUTHY_VALUES - def as_int(self) -> float: + def as_int(self) -> int: """Returns the value as a number.""" if self.source == 'static': return self.DEFAULT_VALUE_FOR_INTEGER diff --git a/firebase_admin/storage.py b/firebase_admin/storage.py index d2f004be6..1a3d33c31 100644 --- a/firebase_admin/storage.py +++ b/firebase_admin/storage.py @@ -18,6 +18,8 @@ Firebase apps. This requires the ``google-cloud-storage`` Python module. """ +from typing import Optional + # pylint: disable=import-error,no-name-in-module try: from google.cloud import storage @@ -25,12 +27,16 @@ raise ImportError('Failed to import the Cloud Storage library for Python. Make sure ' 'to install the "google-cloud-storage" module.') from exception +from google.auth import credentials + +import firebase_admin from firebase_admin import _utils +__all__ = ('bucket',) _STORAGE_ATTRIBUTE = '_storage' -def bucket(name=None, app=None) -> storage.Bucket: +def bucket(name: Optional[str] = None, app: Optional[firebase_admin.App] = None) -> storage.Bucket: """Returns a handle to a Google Cloud Storage bucket. If the name argument is not provided, uses the 'storageBucket' option specified when @@ -59,20 +65,25 @@ class _StorageClient: 'x-goog-api-client': _utils.get_metrics_header(), } - def __init__(self, credentials, project, default_bucket): + def __init__( + self, + credentials: credentials.Credentials, + project: Optional[str], + default_bucket: Optional[str], + ) -> None: self._client = storage.Client( credentials=credentials, project=project, extra_headers=self.STORAGE_HEADERS) self._default_bucket = default_bucket @classmethod - def from_app(cls, app): + def from_app(cls, app: firebase_admin.App) -> '_StorageClient': credentials = app.credential.get_credential() default_bucket = app.options.get('storageBucket') # Specifying project ID is not required, but providing it when available # significantly speeds up the initialization of the storage client. return _StorageClient(credentials, app.project_id, default_bucket) - def bucket(self, name=None): + def bucket(self, name: Optional[str] = None) -> storage.Bucket: """Returns a handle to the specified Cloud Storage Bucket.""" bucket_name = name if name is not None else self._default_bucket if bucket_name is None: diff --git a/firebase_admin/tenant_mgt.py b/firebase_admin/tenant_mgt.py index 9e713d988..e1dc0d8b1 100644 --- a/firebase_admin/tenant_mgt.py +++ b/firebase_admin/tenant_mgt.py @@ -20,6 +20,8 @@ import re import threading +from collections.abc import Callable, Iterator +from typing import Any, Optional, cast import requests @@ -29,12 +31,6 @@ from firebase_admin import _http_client from firebase_admin import _utils - -_TENANT_MGT_ATTRIBUTE = '_tenant_mgt' -_MAX_LIST_TENANTS_RESULTS = 100 -_DISPLAY_NAME_PATTERN = re.compile('^[a-zA-Z][a-zA-Z0-9-]{3,19}$') - - __all__ = [ 'ListTenantsPage', 'Tenant', @@ -49,12 +45,15 @@ 'update_tenant', ] +_TENANT_MGT_ATTRIBUTE = '_tenant_mgt' +_MAX_LIST_TENANTS_RESULTS = 100 +_DISPLAY_NAME_PATTERN = re.compile('^[a-zA-Z][a-zA-Z0-9-]{3,19}$') TenantIdMismatchError = _auth_utils.TenantIdMismatchError TenantNotFoundError = _auth_utils.TenantNotFoundError -def auth_for_tenant(tenant_id, app=None): +def auth_for_tenant(tenant_id: str, app: Optional[firebase_admin.App] = None) -> auth.Client: """Gets an Auth Client instance scoped to the given tenant ID. Args: @@ -71,7 +70,7 @@ def auth_for_tenant(tenant_id, app=None): return tenant_mgt_service.auth_for_tenant(tenant_id) -def get_tenant(tenant_id, app=None): +def get_tenant(tenant_id: str, app: Optional[firebase_admin.App] = None) -> 'Tenant': """Gets the tenant corresponding to the given ``tenant_id``. Args: @@ -91,7 +90,11 @@ def get_tenant(tenant_id, app=None): def create_tenant( - display_name, allow_password_sign_up=None, enable_email_link_sign_in=None, app=None): + display_name: str, + allow_password_sign_up: Optional[bool] = None, + enable_email_link_sign_in: Optional[bool] = None, + app: Optional[firebase_admin.App] = None, +) -> 'Tenant': """Creates a new tenant from the given options. Args: @@ -117,8 +120,12 @@ def create_tenant( def update_tenant( - tenant_id, display_name=None, allow_password_sign_up=None, enable_email_link_sign_in=None, - app=None): + tenant_id: str, + display_name: Optional[str] = None, + allow_password_sign_up: Optional[bool] = None, + enable_email_link_sign_in: Optional[bool] = None, + app: Optional[firebase_admin.App] = None, +) -> 'Tenant': """Updates an existing tenant with the given options. Args: @@ -144,7 +151,7 @@ def update_tenant( enable_email_link_sign_in=enable_email_link_sign_in) -def delete_tenant(tenant_id, app=None): +def delete_tenant(tenant_id: str, app: Optional[firebase_admin.App] = None) -> None: """Deletes the tenant corresponding to the given ``tenant_id``. Args: @@ -160,7 +167,11 @@ def delete_tenant(tenant_id, app=None): tenant_mgt_service.delete_tenant(tenant_id) -def list_tenants(page_token=None, max_results=_MAX_LIST_TENANTS_RESULTS, app=None): +def list_tenants( + page_token: Optional[str] = None, + max_results: int = _MAX_LIST_TENANTS_RESULTS, + app: Optional[firebase_admin.App] = None, +) -> 'ListTenantsPage': """Retrieves a page of tenants from a Firebase project. The ``page_token`` argument governs the starting point of the page. The ``max_results`` @@ -183,12 +194,12 @@ def list_tenants(page_token=None, max_results=_MAX_LIST_TENANTS_RESULTS, app=Non FirebaseError: If an error occurs while retrieving the user accounts. """ tenant_mgt_service = _get_tenant_mgt_service(app) - def download(page_token, max_results): + def download(page_token: Optional[str], max_results: int) -> dict[str, Any]: return tenant_mgt_service.list_tenants(page_token, max_results) return ListTenantsPage(download, page_token, max_results) -def _get_tenant_mgt_service(app): +def _get_tenant_mgt_service(app: Optional[firebase_admin.App]) -> '_TenantManagementService': return _utils.get_app_service(app, _TENANT_MGT_ATTRIBUTE, _TenantManagementService) @@ -203,7 +214,7 @@ class Tenant: such as the display name, tenant identifier and email authentication configuration. """ - def __init__(self, data): + def __init__(self, data: dict[str, Any]) -> None: if not isinstance(data, dict): raise ValueError(f'Invalid data argument in Tenant constructor: {data}') if not 'name' in data: @@ -212,20 +223,20 @@ def __init__(self, data): self._data = data @property - def tenant_id(self): + def tenant_id(self) -> str: name = self._data['name'] return name.split('/')[-1] @property - def display_name(self): + def display_name(self) -> Optional[str]: return self._data.get('displayName') @property - def allow_password_sign_up(self): + def allow_password_sign_up(self) -> bool: return self._data.get('allowPasswordSignup', False) @property - def enable_email_link_sign_in(self): + def enable_email_link_sign_in(self) -> bool: return self._data.get('enableEmailLinkSignin', False) @@ -234,17 +245,17 @@ class _TenantManagementService: TENANT_MGT_URL = 'https://identitytoolkit.googleapis.com/v2' - def __init__(self, app): + def __init__(self, app: firebase_admin.App) -> None: credential = app.credential.get_credential() version_header = f'Python/Admin/{firebase_admin.__version__}' base_url = f'{self.TENANT_MGT_URL}/projects/{app.project_id}' self.app = app self.client = _http_client.JsonHttpClient( credential=credential, base_url=base_url, headers={'X-Client-Version': version_header}) - self.tenant_clients = {} + self.tenant_clients: dict[str, auth.Client] = {} self.lock = threading.RLock() - def auth_for_tenant(self, tenant_id): + def auth_for_tenant(self, tenant_id: str) -> auth.Client: """Gets an Auth Client instance scoped to the given tenant ID.""" if not isinstance(tenant_id, str) or not tenant_id: raise ValueError( @@ -256,9 +267,9 @@ def auth_for_tenant(self, tenant_id): client = auth.Client(self.app, tenant_id=tenant_id) self.tenant_clients[tenant_id] = client - return client + return client - def get_tenant(self, tenant_id): + def get_tenant(self, tenant_id: str) -> Tenant: """Gets the tenant corresponding to the given ``tenant_id``.""" if not isinstance(tenant_id, str) or not tenant_id: raise ValueError( @@ -271,10 +282,14 @@ def get_tenant(self, tenant_id): return Tenant(body) def create_tenant( - self, display_name, allow_password_sign_up=None, enable_email_link_sign_in=None): + self, + display_name: str, + allow_password_sign_up: Optional[bool] = None, + enable_email_link_sign_in: Optional[bool] = None, + ) -> Tenant: """Creates a new tenant from the given parameters.""" - payload = {'displayName': _validate_display_name(display_name)} + payload: dict[str, Any] = {'displayName': _validate_display_name(display_name)} if allow_password_sign_up is not None: payload['allowPasswordSignup'] = _auth_utils.validate_boolean( allow_password_sign_up, 'allowPasswordSignup') @@ -289,13 +304,17 @@ def create_tenant( return Tenant(body) def update_tenant( - self, tenant_id, display_name=None, allow_password_sign_up=None, - enable_email_link_sign_in=None): + self, + tenant_id: str, + display_name: Optional[str] = None, + allow_password_sign_up: Optional[bool] = None, + enable_email_link_sign_in: Optional[bool] = None + ) -> Tenant: """Updates the specified tenant with the given parameters.""" if not isinstance(tenant_id, str) or not tenant_id: raise ValueError('Tenant ID must be a non-empty string.') - payload = {} + payload: dict[str, Any] = {} if display_name is not None: payload['displayName'] = _validate_display_name(display_name) if allow_password_sign_up is not None: @@ -317,18 +336,22 @@ def update_tenant( raise _auth_utils.handle_auth_backend_error(error) return Tenant(body) - def delete_tenant(self, tenant_id): + def delete_tenant(self, tenant_id: str) -> None: """Deletes the tenant corresponding to the given ``tenant_id``.""" if not isinstance(tenant_id, str) or not tenant_id: raise ValueError( - f'Invalid tenant ID: {tenant_id}. Tenant ID must be a non-empty string.') + 'Invalid tenant ID: {0}. Tenant ID must be a non-empty string.'.format(tenant_id)) try: self.client.request('delete', f'/tenants/{tenant_id}') except requests.exceptions.RequestException as error: raise _auth_utils.handle_auth_backend_error(error) - def list_tenants(self, page_token=None, max_results=_MAX_LIST_TENANTS_RESULTS): + def list_tenants( + self, + page_token: Optional[str] = None, + max_results: int = _MAX_LIST_TENANTS_RESULTS, + ) -> dict[str, Any]: """Retrieves a batch of tenants.""" if page_token is not None: if not isinstance(page_token, str) or not page_token: @@ -340,7 +363,7 @@ def list_tenants(self, page_token=None, max_results=_MAX_LIST_TENANTS_RESULTS): 'Max results must be a positive integer less than or equal to ' f'{_MAX_LIST_TENANTS_RESULTS}.') - payload = {'pageSize': max_results} + payload: dict[str, Any] = {'pageSize': max_results} if page_token: payload['pageToken'] = page_token try: @@ -357,27 +380,32 @@ class ListTenantsPage: through all tenants in the Firebase project starting from this page. """ - def __init__(self, download, page_token, max_results): + def __init__( + self, + download: Callable[[Optional[str], int], dict[str, Any]], + page_token: Optional[str], + max_results: int, + ) -> None: self._download = download self._max_results = max_results self._current = download(page_token, max_results) @property - def tenants(self): + def tenants(self) -> list[Tenant]: """A list of ``ExportedUserRecord`` instances available in this page.""" return [Tenant(data) for data in self._current.get('tenants', [])] @property - def next_page_token(self): + def next_page_token(self) -> str: """Page token string for the next page (empty string indicates no more pages).""" return self._current.get('nextPageToken', '') @property - def has_next_page(self): + def has_next_page(self) -> bool: """A boolean indicating whether more pages are available.""" return bool(self.next_page_token) - def get_next_page(self): + def get_next_page(self) -> Optional['ListTenantsPage']: """Retrieves the next page of tenants, if available. Returns: @@ -408,16 +436,16 @@ class _TenantIterator: of entries in memory. """ - def __init__(self, current_page): + def __init__(self, current_page: ListTenantsPage) -> None: if not current_page: raise ValueError('Current page must not be None.') self._current_page = current_page self._index = 0 - def __next__(self): + def __next__(self) -> Tenant: if self._index == len(self._current_page.tenants): if self._current_page.has_next_page: - self._current_page = self._current_page.get_next_page() + self._current_page = cast(ListTenantsPage, self._current_page.get_next_page()) self._index = 0 if self._index < len(self._current_page.tenants): result = self._current_page.tenants[self._index] @@ -425,11 +453,11 @@ def __next__(self): return result raise StopIteration - def __iter__(self): + def __iter__(self) -> Iterator[Tenant]: return self -def _validate_display_name(display_name): +def _validate_display_name(display_name: Any) -> str: if not isinstance(display_name, str): raise ValueError('Invalid type for displayName') if not _DISPLAY_NAME_PATTERN.search(display_name): diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 000000000..772975727 --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,33 @@ +{ + "pythonVersion": "3.9", + "typeCheckingMode": "strict", + + "include": ["firebase_admin"], + + "ignore": [ + "integration", + "snippets", + "tests", + "setup.py", + ], + + // Suppress import cycle errors (using forward references as needed) + "reportImportCycles": "none", + + // Allow dependencies without type annotations or stubs + "reportIncompleteStub": "none", + "reportMissingTypeStubs": "none", + + // Permit usage of private members across modules + "reportPrivateUsage": "none", + + // Allow `isinstance` for type assertions and runtime checks + "reportUnnecessaryIsInstance": "none", + + // Warn when a previously ignored type check is no longer needed + "reportUnnecessaryTypeIgnoreComment": "warning", + "reportMissingParameterType": "warning", + "reportUnknownArgumentType": "warning", + "reportUnknownMemberType": "warning", + "reportUnknownVariableType": "warning" +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index ff15072a6..6315281f3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,10 +6,14 @@ pytest-localserver >= 0.4.1 pytest-asyncio >= 0.26.0 pytest-mock >= 3.6.1 respx == 0.22.0 +pyright >= 1.1.402 cachecontrol >= 0.14.3 google-api-core[grpc] >= 2.25.1, < 3.0.0dev; platform.python_implementation != 'PyPy' google-cloud-firestore >= 2.21.0; platform.python_implementation != 'PyPy' google-cloud-storage >= 3.1.1 pyjwt[crypto] >= 2.10.1 -httpx[http2] == 0.28.1 \ No newline at end of file +httpx[http2] == 0.28.1 +typing-extensions >= 4.12.0 +types-requests +types-httplib2 \ No newline at end of file diff --git a/setup.py b/setup.py index 21e29332e..1876070c6 100644 --- a/setup.py +++ b/setup.py @@ -21,8 +21,7 @@ from setuptools import setup -(major, minor) = (sys.version_info.major, sys.version_info.minor) -if major != 3 or minor < 9: +if sys.version_info < (3, 9): print('firebase_admin requires python >= 3.9', file=sys.stderr) sys.exit(1) @@ -43,6 +42,9 @@ 'google-cloud-storage>=3.1.1', 'pyjwt[crypto] >= 2.10.1', 'httpx[http2] == 0.28.1', + 'typing-extensions >= 4.12.0' + 'types-requests' + 'types-httplib2' ] setup( @@ -72,5 +74,6 @@ 'Programming Language :: Python :: 3.12', 'Programming Language :: Python :: 3.13', 'License :: OSI Approved :: Apache Software License', + 'Typing :: Typed', ], ) From 9774306574378544080d722af3d84dbb91d9c2cb Mon Sep 17 00:00:00 2001 From: ViktorSky Date: Tue, 1 Jul 2025 14:28:48 -0300 Subject: [PATCH 10/13] - remove redundant overriding methods - fix some pylint issues --- firebase_admin/_auth_utils.py | 117 -------------------------------- firebase_admin/_sseclient.py | 2 +- firebase_admin/db.py | 8 ++- firebase_admin/ml.py | 9 ++- firebase_admin/remote_config.py | 5 +- firebase_admin/tenant_mgt.py | 2 +- 6 files changed, 19 insertions(+), 124 deletions(-) diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index e702ff8f2..9cb7d5774 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -24,13 +24,11 @@ Literal, Optional, Protocol, - Union, cast, overload, ) from urllib import parse -import httpx import requests from typing_extensions import Self, TypeVar @@ -476,28 +474,12 @@ class UidAlreadyExistsError(exceptions.AlreadyExistsError): default_message = 'The user with the provided uid already exists' - def __init__( - self, - message: str, - cause: Optional[Exception], - http_response: Optional[Union[httpx.Response, requests.Response]], - ) -> None: - super().__init__(message, cause, http_response) - class EmailAlreadyExistsError(exceptions.AlreadyExistsError): """The user with the provided email already exists.""" default_message = 'The user with the provided email already exists' - def __init__( - self, - message: str, - cause: Optional[Exception], - http_response: Optional[Union[httpx.Response, requests.Response]], - ) -> None: - super().__init__(message, cause, http_response) - class InsufficientPermissionError(exceptions.PermissionDeniedError): """The credential used to initialize the SDK lacks required permissions.""" @@ -507,169 +489,70 @@ class InsufficientPermissionError(exceptions.PermissionDeniedError): 'https://firebase.google.com/docs/admin/setup for details ' 'on how to initialize the Admin SDK with appropriate permissions') - def __init__( - self, - message: str, - cause: Optional[Exception], - http_response: Optional[Union[httpx.Response, requests.Response]], - ) -> None: - super().__init__(message, cause, http_response) - class InvalidDynamicLinkDomainError(exceptions.InvalidArgumentError): """Dynamic link domain in ActionCodeSettings is not authorized.""" default_message = 'Dynamic link domain specified in ActionCodeSettings is not authorized' - def __init__( - self, - message: str, - cause: Optional[Exception], - http_response: Optional[Union[httpx.Response, requests.Response]], - ) -> None: - super().__init__(message, cause, http_response) - class InvalidIdTokenError(exceptions.InvalidArgumentError): """The provided ID token is not a valid Firebase ID token.""" default_message = 'The provided ID token is invalid' - def __init__( - self, - message: str, - cause: Optional[Exception] = None, - http_response: Optional[Union[httpx.Response, requests.Response]] = None, - ) -> None: - super().__init__(message, cause, http_response) - class PhoneNumberAlreadyExistsError(exceptions.AlreadyExistsError): """The user with the provided phone number already exists.""" default_message = 'The user with the provided phone number already exists' - def __init__( - self, - message: str, - cause: Optional[Exception], - http_response: Optional[Union[httpx.Response, requests.Response]], - ) -> None: - super().__init__(message, cause, http_response) - class UnexpectedResponseError(exceptions.UnknownError): """Backend service responded with an unexpected or malformed response.""" - def __init__( - self, - message: str, - cause: Optional[Exception] = None, - http_response: Optional[Union[httpx.Response, requests.Response]] = None, - ) -> None: - super().__init__(message, cause, http_response) - class UserNotFoundError(exceptions.NotFoundError): """No user record found for the specified identifier.""" default_message = 'No user record found for the given identifier' - def __init__( - self, - message: str, - cause: Optional[Exception] = None, - http_response: Optional[Union[httpx.Response, requests.Response]] = None, - ) -> None: - super().__init__(message, cause, http_response) - class EmailNotFoundError(exceptions.NotFoundError): """No user record found for the specified email.""" default_message = 'No user record found for the given email' - def __init__( - self, - message: str, - cause: Optional[Exception] = None, - http_response: Optional[Union[httpx.Response, requests.Response]] = None, - ) -> None: - super().__init__(message, cause, http_response) - class TenantNotFoundError(exceptions.NotFoundError): """No tenant found for the specified identifier.""" default_message = 'No tenant found for the given identifier' - def __init__( - self, - message: str, - cause: Optional[Exception] = None, - http_response: Optional[Union[httpx.Response, requests.Response]] = None, - ) -> None: - super().__init__(message, cause, http_response) - class TenantIdMismatchError(exceptions.InvalidArgumentError): """Missing or invalid tenant ID field in the given JWT.""" - def __init__(self, message: str) -> None: - super().__init__(message) - class ConfigurationNotFoundError(exceptions.NotFoundError): """No auth provider found for the specified identifier.""" default_message = 'No auth provider found for the given identifier' - def __init__( - self, - message: str, - cause: Optional[Exception] = None, - http_response: Optional[Union[httpx.Response, requests.Response]] = None, - ) -> None: - super().__init__(message, cause, http_response) - class UserDisabledError(exceptions.InvalidArgumentError): """An operation failed due to a user record being disabled.""" default_message = 'The user record is disabled' - def __init__( - self, - message: str, - cause: Optional[Exception] = None, - http_response: Optional[Union[httpx.Response, requests.Response]] = None, - ) -> None: - super().__init__(message, cause, http_response) - class TooManyAttemptsTryLaterError(exceptions.ResourceExhaustedError): """Rate limited because of too many attempts.""" - def __init__( - self, - message: str, - cause: Optional[Exception] = None, - http_response: Optional[Union[httpx.Response, requests.Response]] = None, - ) -> None: - super().__init__(message, cause, http_response) - class ResetPasswordExceedLimitError(exceptions.ResourceExhaustedError): """Reset password emails exceeded their limits.""" - def __init__( - self, - message: str, - cause: Optional[Exception] = None, - http_response: Optional[Union[httpx.Response, requests.Response]] = None, - ) -> None: - super().__init__(message, cause, http_response) - _CODE_TO_EXC_TYPE = { 'CONFIGURATION_NOT_FOUND': ConfigurationNotFoundError, diff --git a/firebase_admin/_sseclient.py b/firebase_admin/_sseclient.py index ea0d5ac23..796b65d82 100644 --- a/firebase_admin/_sseclient.py +++ b/firebase_admin/_sseclient.py @@ -163,7 +163,7 @@ def __next__(self) -> Optional['Event']: return event def next(self) -> Optional['Event']: - return self.__next__() + return next(self) class Event: diff --git a/firebase_admin/db.py b/firebase_admin/db.py index de9cb520b..8545e914d 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -36,11 +36,12 @@ cast, overload, ) -from typing_extensions import Self, TypeVar from urllib import parse import google.auth.credentials import requests +from typing_extensions import Self, TypeVar + import firebase_admin from firebase_admin import exceptions @@ -849,7 +850,10 @@ def __gt__(self, other: '_SortEntry') -> bool: def __ge__(self, other: '_SortEntry') -> bool: return self._compare(other) >= 0 - def __eq__(self, other: '_SortEntry') -> bool: # pyright: ignore[reportIncompatibleMethodOverride] + def __eq__( # pyright: ignore[reportIncompatibleMethodOverride] + self, + other: '_SortEntry', + ) -> bool: return self._compare(other) == 0 diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index 38b2f69af..3e906cd84 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -290,7 +290,8 @@ def _convert_to_millis(date_string: Optional[str]) -> Optional[int]: return None format_str = '%Y-%m-%dT%H:%M:%S.%fZ' epoch = datetime.datetime.fromtimestamp(0, datetime.timezone.utc) - datetime_object = datetime.datetime.strptime(date_string, format_str).replace(tzinfo=datetime.timezone.utc) + datetime_object = datetime.datetime.strptime( + date_string, format_str).replace(tzinfo=datetime.timezone.utc) millis = int((datetime_object - epoch).total_seconds() * 1000) return millis @@ -881,7 +882,11 @@ def get_operation(self, op_name: str) -> dict[str, Any]: except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) - def _exponential_backoff(self, current_attempt: int, stop_time: Optional[datetime.datetime]) -> None: + def _exponential_backoff( + self, + current_attempt: int, + stop_time: Optional[datetime.datetime], + ) -> None: """Sleeps for the appropriate amount of time. Or throws deadline exceeded.""" delay_factor = pow(_MLService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt) wait_time_seconds = delay_factor * _MLService.POLL_BASE_WAIT_TIME_SECONDS diff --git a/firebase_admin/remote_config.py b/firebase_admin/remote_config.py index b6b4955e8..c9af51883 100644 --- a/firebase_admin/remote_config.py +++ b/firebase_admin/remote_config.py @@ -302,7 +302,10 @@ def _get_url(self) -> str: return f"/v1/projects/{self._project_id}/namespaces/firebase-server/serverRemoteConfig" @classmethod - def _handle_remote_config_error(cls, error: requests.RequestException) -> exceptions.FirebaseError: + def _handle_remote_config_error( + cls, + error: requests.RequestException, + ) -> exceptions.FirebaseError: """Handles errors received from the Cloud Functions API.""" return _utils.handle_platform_error_from_requests(error) diff --git a/firebase_admin/tenant_mgt.py b/firebase_admin/tenant_mgt.py index e1dc0d8b1..41c21548c 100644 --- a/firebase_admin/tenant_mgt.py +++ b/firebase_admin/tenant_mgt.py @@ -340,7 +340,7 @@ def delete_tenant(self, tenant_id: str) -> None: """Deletes the tenant corresponding to the given ``tenant_id``.""" if not isinstance(tenant_id, str) or not tenant_id: raise ValueError( - 'Invalid tenant ID: {0}. Tenant ID must be a non-empty string.'.format(tenant_id)) + f'Invalid tenant ID: {tenant_id}. Tenant ID must be a non-empty string.') try: self.client.request('delete', f'/tenants/{tenant_id}') From 1c40ceba47c98ea5bf8dd166702f0fab9cb656af Mon Sep 17 00:00:00 2001 From: ViktorSky Date: Fri, 25 Jul 2025 16:25:20 -0300 Subject: [PATCH 11/13] rollback unintentional changes --- snippets/messaging/cloud_messaging.py | 22 +++++++++++ tests/test_functions.py | 56 +++++++++------------------ 2 files changed, 41 insertions(+), 37 deletions(-) diff --git a/snippets/messaging/cloud_messaging.py b/snippets/messaging/cloud_messaging.py index 6fb525231..3efd223ea 100644 --- a/snippets/messaging/cloud_messaging.py +++ b/snippets/messaging/cloud_messaging.py @@ -244,6 +244,28 @@ def send_each(): print(f'{response.success_count} messages were sent successfully') # [END send_each] + +def send_each_for_multicast(): + # [START send_each_for_multicast] + # Create a list containing up to 500 registration tokens. + # These registration tokens come from the client FCM SDKs. + registration_tokens = [ + 'YOUR_REGISTRATION_TOKEN_1', + # ... + 'YOUR_REGISTRATION_TOKEN_N', + ] + + message = messaging.MulticastMessage( + data={'score': '850', 'time': '2:45'}, + tokens=registration_tokens, + ) + response = messaging.send_each_for_multicast(message) + # See the BatchResponse reference documentation + # for the contents of response. + print(f'{response.success_count} messages were sent successfully') + # [END send_each_for_multicast] + + def send_each_for_multicast_and_handle_errors(): # [START send_each_for_multicast_error] # These registration tokens come from the client FCM SDKs. diff --git a/tests/test_functions.py b/tests/test_functions.py index 52e92c1b2..fdafbc4cc 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -14,7 +14,7 @@ """Test cases for the firebase_admin.functions module.""" -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta import json import time import pytest @@ -33,6 +33,8 @@ _CLOUD_TASKS_URL + 'projects/test-project/locations/us-central1/queues/test-function-name/tasks' _DEFAULT_TASK_URL = _CLOUD_TASKS_URL + _DEFAULT_TASK_PATH _DEFAULT_RESPONSE = json.dumps({'name': _DEFAULT_TASK_PATH}) +_ENQUEUE_TIME = datetime.utcnow() +_SCHEDULE_TIME = _ENQUEUE_TIME + timedelta(seconds=100) class TestTaskQueue: @classmethod @@ -183,46 +185,26 @@ def _instrument_functions_service(self, app=None, status=200, payload=_DEFAULT_R testutils.MockAdapter(payload, status, recorder)) return functions_service, recorder - def test_task_options_delay_seconds(self): - _, recorder = self._instrument_functions_service() - enqueue_time = datetime.now(timezone.utc) - expected_schedule_time = enqueue_time + timedelta(seconds=100) - task_opts_params = { + @pytest.mark.parametrize('task_opts_params', [ + { 'schedule_delay_seconds': 100, 'schedule_time': None, 'dispatch_deadline_seconds': 200, 'task_id': 'test-task-id', 'headers': {'x-test-header': 'test-header-value'}, 'uri': 'https://google.com' - } - queue = functions.task_queue('test-function-name') - task_opts = functions.TaskOptions(**task_opts_params) - queue.enqueue(_DEFAULT_DATA, task_opts) - - assert len(recorder) == 1 - task = json.loads(recorder[0].body.decode())['task'] - - task_schedule_time = datetime.fromisoformat(task['schedule_time'].replace('Z', '+00:00')) - delta = abs(task_schedule_time - expected_schedule_time) - assert delta <= timedelta(seconds=1) - - assert task['dispatch_deadline'] == '200s' - assert task['http_request']['headers']['x-test-header'] == 'test-header-value' - assert task['http_request']['url'] in ['http://google.com', 'https://google.com'] - assert task['name'] == _DEFAULT_TASK_PATH - - def test_task_options_utc_time(self): - _, recorder = self._instrument_functions_service() - enqueue_time = datetime.now(timezone.utc) - expected_schedule_time = enqueue_time + timedelta(seconds=100) - task_opts_params = { + }, + { 'schedule_delay_seconds': None, - 'schedule_time': expected_schedule_time, + 'schedule_time': _SCHEDULE_TIME, 'dispatch_deadline_seconds': 200, 'task_id': 'test-task-id', 'headers': {'x-test-header': 'test-header-value'}, 'uri': 'http://google.com' - } + }, + ]) + def test_task_options(self, task_opts_params): + _, recorder = self._instrument_functions_service() queue = functions.task_queue('test-function-name') task_opts = functions.TaskOptions(**task_opts_params) queue.enqueue(_DEFAULT_DATA, task_opts) @@ -230,8 +212,9 @@ def test_task_options_utc_time(self): assert len(recorder) == 1 task = json.loads(recorder[0].body.decode())['task'] - task_schedule_time = datetime.fromisoformat(task['schedule_time'].replace('Z', '+00:00')) - assert task_schedule_time == expected_schedule_time + schedule_time = datetime.fromisoformat(task['schedule_time'][:-1]) + delta = abs(schedule_time - _SCHEDULE_TIME) + assert delta <= timedelta(seconds=15) assert task['dispatch_deadline'] == '200s' assert task['http_request']['headers']['x-test-header'] == 'test-header-value' @@ -240,8 +223,7 @@ def test_task_options_utc_time(self): def test_schedule_set_twice_error(self): _, recorder = self._instrument_functions_service() - opts = functions.TaskOptions( - schedule_delay_seconds=100, schedule_time=datetime.now(timezone.utc)) + opts = functions.TaskOptions(schedule_delay_seconds=100, schedule_time=datetime.utcnow()) queue = functions.task_queue('test-function-name') with pytest.raises(ValueError) as excinfo: queue.enqueue(_DEFAULT_DATA, opts) @@ -252,9 +234,9 @@ def test_schedule_set_twice_error(self): @pytest.mark.parametrize('schedule_time', [ time.time(), - str(datetime.now(timezone.utc)), - datetime.now(timezone.utc).isoformat(), - datetime.now(timezone.utc).isoformat() + 'Z', + str(datetime.utcnow()), + datetime.utcnow().isoformat(), + datetime.utcnow().isoformat() + 'Z', '', ' ' ]) def test_invalid_schedule_time_error(self, schedule_time): From 2e5047adba0adf6f7bc3c85d01cf84ebe6b37877 Mon Sep 17 00:00:00 2001 From: ViktorSky Date: Sat, 26 Jul 2025 13:09:41 -0300 Subject: [PATCH 12/13] add missing commas --- setup.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 1876070c6..1670cf8e3 100644 --- a/setup.py +++ b/setup.py @@ -42,9 +42,9 @@ 'google-cloud-storage>=3.1.1', 'pyjwt[crypto] >= 2.10.1', 'httpx[http2] == 0.28.1', - 'typing-extensions >= 4.12.0' - 'types-requests' - 'types-httplib2' + 'typing-extensions >= 4.12.0', + 'types-requests', + 'types-httplib2', ] setup( From 5c096a8813559e7eac5fd30d98ace9148854670a Mon Sep 17 00:00:00 2001 From: ViktorSky Date: Sun, 27 Jul 2025 00:08:50 -0300 Subject: [PATCH 13/13] add py.typed mark --- firebase_admin/py.typed | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 firebase_admin/py.typed diff --git a/firebase_admin/py.typed b/firebase_admin/py.typed new file mode 100644 index 000000000..e69de29bb